In [6]:
import torch
import torch.nn.functional as F
import torch_geometric.transforms
from torch_geometric.nn import GATv2Conv, Linear, to_hetero
from src.model.GAT.gat_encoder import GATv2Encoder
from torch_geometric.datasets import OGB_MAG

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset = OGB_MAG(root='./data/ogb_mag', preprocess='metapath2vec', transform=torch_geometric.transforms.ToUndirected())
data = dataset[0]

In [8]:
from torch_geometric.nn import GATConv, Linear, to_hetero
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),
                ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
                ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes,
                  num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

with torch.no_grad():  # Initialize lazy modules.
     out = model(data.x_dict, data.edge_index_dict)

In [9]:
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['paper'].train_mask
    loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

In [10]:
train()

IndexError: too many indices for tensor of dimension 2