In [69]:
import numpy as np
import torch
from torch_geometric.data import HeteroData

In [70]:
authors = torch.rand((10,8))
papers = torch.rand((20,4))
authors_y = torch.rand(10).round()

write_from = torch.tensor(np.random.choice(10,50, replace = True))
write_to = torch.tensor(np.random.choice(20, 50, replace = True))
write = torch.concat((write_from, write_to)).reshape(-1,50).long()

cite_from = torch.tensor(np.random.choice(20,15, replace = True))
cite_to = torch.tensor(np.random.choice(20,15, replace=True))
cite = torch.concat((cite_from, cite_to)).reshape(-1,15).long()

In [71]:
data = HeteroData({'author':{'x':authors,'y':authors_y}, 'papers':{'x':papers}},
                    author__write__paper={'edge_index':write},paper__cite__paper = {'edge_index':cite})


In [72]:
import torch_geometric.transforms as T
transform = T.RandomNodeSplit(num_val=3, num_test=3)
data = transform(data)

In [73]:
data['author'].train_mask

tensor([False,  True, False, False, False, False, False,  True,  True,  True])

In [74]:
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self,hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1,-1), hidden_channels)
        self.conv2 = SAGEConv((-1,-1), out_channels)

    def forward(self, x, edge_index):
        print(x)
        x = torch.relu(self.conv1(x, edge_index))
        print(x)
        x = self.conv2(x, edge_index)
        return x

model = GNN(64,1)
model = to_hetero(model, data.metadata(), aggr='sum')
model

Proxy(relu)


GraphModule(
  (conv1): ModuleDict(
    (author): SAGEConv((-1, -1), 64)
    (papers): SAGEConv((-1, -1), 64)
  )
  (conv2): ModuleDict(
    (author): SAGEConv((-1, -1), 1)
    (papers): SAGEConv((-1, -1), 1)
  )
)

In [75]:
import torch.nn as nn
def train(model, optimizer, data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x_dict, data.edge_index_dict)
    mask = data['author'].train_mask
    print(out)
    print(data['author'].y[mask])
    loss = nn.MSELoss(out[mask], data["author"].y[mask])
    loss.backward()
    optimizer.step()
    return model,float(loss)

In [76]:
@torch.no_grad()
def test(model, optimzizer, data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=1)

    losses = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        loss = nn.MSELoss(pred[mask], data['author'].y[mask])/mask.sum()
        losses.append(float(loss))
    return losses

In [77]:
import torch.optim as optim
optimizer = optim.Adam(model.parameters())

for epoch in range(1,10):
    #print(data.x_dict)
    #print(data.edge_index_dict)
    model, loss = train(model,optimizer, data)
    train_loss, val_loss, test_loss = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_loss:.4f}, '
        f'Val: {val_loss:.4f}, Test: {test_loss:.4f}')

tensor([0., 1., 1., 1.])


KeyError: tensor([False,  True, False, False, False, False, False,  True,  True,  True])