In [1]:
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
import torch.optim as O
import torch.nn.functional as F
from torch_geometric.nn import GATConv, Linear, to_hetero
import torch

# Seed torch's global RNG so parameter initialisation (and any torch-driven
# neighbour sampling) is repeatable under Restart & Run All.
# NOTE(review): some CUDA kernels remain nondeterministic even when seeded.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Add reverse edge types (e.g. "rev_writes") so every node type can receive
# messages. This instance is the one actually handed to the dataset below.
transform = T.ToUndirected()

# NOTE(review): `preprocess="metapath2vec"` presumably supplies input features
# for node types that ship without them — confirm against the OGB_MAG docs.
dataset = OGB_MAG(root="./data", preprocess="metapath2vec", transform=transform)
data = dataset[0]

NUM_NEIGHBORS = 15  # neighbours sampled per node, per edge type, per hop
NUM_HOPS = 2        # sampling depth; matches the two GATConv layers of the model
BATCH_SIZE = 128    # seed "paper" nodes per mini-batch

train_loader = NeighborLoader(
    data,
    num_neighbors=[NUM_NEIGHBORS] * NUM_HOPS,
    batch_size=BATCH_SIZE,
    # Seeds are the "paper" nodes in the training split.
    # NOTE(review): no `shuffle` argument is passed, so seeds are visited in the
    # loader's default order every epoch — consider shuffle=True for training.
    input_nodes=("paper", data["paper"].train_mask),
)

# Draw one mini-batch to sanity-check the sampler before training.
batch = next(iter(train_loader))

In [3]:
class GAT(torch.nn.Module):
    """Two-layer graph attention network with a linear skip path per layer.

    Written for homogeneous input and intended to be converted with
    ``to_hetero``. Every ``-1`` input size is resolved lazily on the first
    forward pass, so one definition serves node types whose feature widths
    differ.

    Args:
        hidden_channels: output width of the first layer.
        out_channels: output width of the second (final) layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Layer 1: neighbour attention plus a linear projection of the node itself.
        # (-1, -1) lazily sizes the (source, target) inputs. Self-loops are off
        # because source and target may be different node types.
        self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)
        self.lin1 = Linear(-1, hidden_channels)
        # Layer 2: same pattern, projecting to the output width.
        self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)
        self.lin2 = Linear(-1, out_channels)

    def forward(self, x, edge_index):
        """Return out_channels-wide features for each target node."""
        hidden = (self.conv1(x, edge_index) + self.lin1(x)).relu()
        return self.conv2(hidden, edge_index) + self.lin2(hidden)


# Build the homogeneous GAT, then let `to_hetero` duplicate each layer per
# edge type in the graph schema. Outputs that land on the same destination
# node type from different edge types are combined with aggr="sum".
model = to_hetero(
    GAT(hidden_channels=64, out_channels=dataset.num_classes),
    data.metadata(),
    aggr="sum",
)
model = model.to(DEVICE)

# The -1 input sizes make parameters lazy. Materialise them with one dry run
# *before* building the optimizer (PyTorch lazy-module guidance: move to
# device -> dry run -> optimizer). A freshly sampled mini-batch is used so the
# `batch` inspected in the data cell is left untouched.
init_batch = next(iter(train_loader)).to(DEVICE)
with torch.no_grad():
    model(init_batch.x_dict, init_batch.edge_index_dict)
del init_batch

# All arguments below equal Adam's defaults. They are spelled out so they are
# easy to tune.
optimizer = O.Adam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

model

GraphModule(
  (conv1): ModuleDict(
    (author__affiliated_with__institution): GATConv((-1, -1), 64, heads=1)
    (author__writes__paper): GATConv((-1, -1), 64, heads=1)
    (paper__cites__paper): GATConv((-1, -1), 64, heads=1)
    (paper__has_topic__field_of_study): GATConv((-1, -1), 64, heads=1)
    (institution__rev_affiliated_with__author): GATConv((-1, -1), 64, heads=1)
    (paper__rev_writes__author): GATConv((-1, -1), 64, heads=1)
    (field_of_study__rev_has_topic__paper): GATConv((-1, -1), 64, heads=1)
  )
  (lin1): ModuleDict(
    (paper): Linear(-1, 64, bias=True)
    (author): Linear(-1, 64, bias=True)
    (institution): Linear(-1, 64, bias=True)
    (field_of_study): Linear(-1, 64, bias=True)
  )
  (conv2): ModuleDict(
    (author__affiliated_with__institution): GATConv((-1, -1), 349, heads=1)
    (author__writes__paper): GATConv((-1, -1), 349, heads=1)
    (paper__cites__paper): GATConv((-1, -1), 349, heads=1)
    (paper__has_topic__field_of_study): GATConv((-1, -1), 349

In [4]:
def train():
    """Run one optimisation epoch over ``train_loader``; return the mean loss.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader`` and
    ``DEVICE``. Only the seed "paper" nodes of each mini-batch contribute to
    the loss.

    Returns:
        float: cross-entropy averaged per training example (not per batch).

    Raises:
        ValueError: if ``train_loader`` yields no seed nodes, which would
            otherwise surface as a bare ZeroDivisionError.
    """
    model.train()

    total_examples = 0
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(DEVICE)
        # The first `batch_size` "paper" rows are the seed nodes (NeighborLoader's
        # convention). The rest are sampled neighbours and must not be scored.
        batch_size = batch["paper"].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)
        loss = F.cross_entropy(out["paper"][:batch_size], batch["paper"].y[:batch_size])
        loss.backward()
        optimizer.step()

        # cross_entropy returns the batch mean. Re-weight by batch size so a
        # smaller final batch does not skew the epoch average.
        total_examples += batch_size
        total_loss += loss.item() * batch_size

    if total_examples == 0:
        raise ValueError("train_loader produced no training examples")
    return total_loss / total_examples

In [5]:
NUM_EPOCHS = 1  # increase for a real training run
assert NUM_EPOCHS >= 1, "need at least one epoch to report a loss"

# Mean per-example training loss of each epoch, in order.
epoch_losses = [train() for _ in range(NUM_EPOCHS)]
avg_loss = epoch_losses[-1]
avg_loss

2.153773753613016