In [1]:
import numpy as np
import torch
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.nn import Sequential, Linear
from torch.nn import ReLU
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.loader import NeighborLoader, HGTLoader


In [None]:
# Toy data: random tensors that only illustrate how node and edge values are
# stored. torch's RNG is seeded so that every "Restart & Run All" rebuilds
# exactly the same toy graph.
SEED = 0
torch.manual_seed(SEED)

NUM_AUTHORS, NUM_PAPERS = 10, 20
NUM_WRITE_EDGES, NUM_CITE_EDGES = 50, 15

# Node values: one feature row per node, plus one target value per author.
authors = torch.rand((NUM_AUTHORS, 8))   # 8 random features per author
papers = torch.rand((NUM_PAPERS, 4))     # 4 random features per paper
authors_y = torch.rand(NUM_AUTHORS)      # one random float target per author

# Edge values for author -> paper ("write"). Endpoints are drawn uniformly
# with replacement, so duplicate edges can occur.
write_from = torch.randint(0, NUM_AUTHORS, (NUM_WRITE_EDGES,))
write_to = torch.randint(0, NUM_PAPERS, (NUM_WRITE_EDGES,))
# Shape (2, num_edges): row 0 holds the source index of each edge, row 1 the
# destination index.
write = torch.stack((write_from, write_to)).long()

# Edge values for paper -> paper ("cite"), built the same way.
cite_from = torch.randint(0, NUM_PAPERS, (NUM_CITE_EDGES,))
cite_to = torch.randint(0, NUM_PAPERS, (NUM_CITE_EDGES,))
cite = torch.stack((cite_from, cite_to)).long()

In [None]:
#-------------------------Register HeteroData
# Pattern: declare the whole graph as ONE dictionary passed to HeteroData.
# String keys name node types; (source, relation, destination) tuples name
# edge types. Each value is the dictionary of tensors stored for that type.
data = HeteroData({
    'author': {'x': authors, 'y': authors_y},
    'paper': {'x': papers},
    ('author', 'write', 'paper'): {'edge_index': write},
    ('paper', 'cite', 'paper'): {'edge_index': cite},
})

# Show the (node_types, edge_types) pair. It is printed explicitly because a
# bare expression in the middle of a cell is evaluated but never displayed.
print(data.metadata())

# Collapse the many node and edge types into a single type of each.
homogeneus_data = data.to_homogeneous()


In [None]:

# If you want to store the data: to_dict() returns the graph as a plain
# dictionary. Keep the result in a variable, otherwise it is computed and
# immediately thrown away.
data_dict = data.to_dict()

#-------------------------Example of model with HeteroData
# Add train/val/test masks to the node types that carry labels ('y'); here
# only 'author' does. Split sizes are given as fractions because the
# RandomNodeSplit defaults (num_val=500, num_test=1000) are absolute counts
# larger than the 10 labelled authors: they would put every author in the
# validation set and leave the train and test masks empty.
transform = T.RandomNodeSplit(num_val=0.2, num_test=0.2)
data = transform(data)

#---------------------Model 1 
class GNN(torch.nn.Module):
    """Two stacked SAGEConv layers with a ReLU between them.

    Both layers are built with ``in_channels=(-1, -1)``, which in PyG defers
    the input sizes to the first forward pass (lazy initialization).

    Args:
        hidden_channels: output size of the first SAGEConv layer.
        out_channels: output size of the second SAGEConv layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        lazy_in_channels = (-1, -1)
        self.conv1 = SAGEConv(lazy_in_channels, hidden_channels)
        self.conv2 = SAGEConv(lazy_in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1, a ReLU, then conv2, and return conv2's output."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)

# Build the homogeneous GNN, then let to_hetero convert it into a model that
# works on the node/edge types listed in data.metadata(). aggr='sum' sums the
# embeddings that different edge types produce for the same node type.
# NOTE(review): in the toy graph 'author' is never the destination of an edge
# type, so message passing cannot update author embeddings; adding reverse
# edges (e.g. T.ToUndirected()) would address that -- confirm intent.
homogeneous_gnn = GNN(hidden_channels=64, out_channels=2)
model = to_hetero(homogeneous_gnn, data.metadata(), aggr='sum')

##---------------------Model 2
# Alternative definition with torch_geometric.nn.Sequential: two SAGEConv
# layers, each followed by a ReLU, then a Linear head with 2 outputs.
# in_channels must be (-1, -1) so that BOTH the source and the destination
# feature sizes are inferred lazily. The previous (-1, 1) fixed the
# destination input size at 1, which cannot match the real feature widths.
model = Sequential('x, edge_index', [
    (SAGEConv((-1, -1), 64), 'x, edge_index ->x'),
    ReLU(inplace=True),
    (SAGEConv((-1, -1), 64), 'x, edge_index ->x'),
    ReLU(inplace=True),
    (Linear(-1, 2), 'x -> x'),
])

# Convert to a heterogeneous model; this rebinds `model`, replacing Model 1.
model = to_hetero(model, data.metadata(), aggr='sum')

#-------------------------Train Data
# Load the OGB-MAG heterogeneous graph. preprocess='metapath2vec' and the
# ToUndirected transform are passed straight to the dataset; per the PyG docs
# the former adds structural features to featureless node types and the
# latter adds reverse edges.
# NOTE: from here on `data` refers to the OGB-MAG graph, not the toy graph.
dataset = OGB_MAG(root='.data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]

# Printed explicitly: a bare expression in the middle of a cell is not shown.
print(data.metadata())

# Seed nodes for sampling: the 'paper' nodes selected by the training mask.
train_input_nodes = ('paper', data['paper'].train_mask)

# Without an explicit batch_size the underlying DataLoader default of 1 is
# used, i.e. one seed paper per mini-batch.
BATCH_SIZE = 128
train_loader = NeighborLoader(
    data,
    num_neighbors=[10] * 2,  # 10 neighbours per edge type, for 2 iterations
    batch_size=BATCH_SIZE,
    shuffle=True,
    input_nodes=train_input_nodes,
)

# Inspect only the first sampled mini-batch.
for batch in train_loader:
    print(batch)
    break