In [None]:
import torch
import torch_geometric
import networkx as nx

# Report the library versions in use (torch, networkx, torch_geometric),
# then whether a CUDA device is visible to PyTorch.
for _lib in (torch, nx, torch_geometric):
    print(_lib.__version__)
print(torch.cuda.is_available())

2.0.1+cu118
2.8.8
2.6.1
True


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv, Linear

In [12]:
class HeteroGNN(nn.Module):
    """Single-layer heterogeneous GraphSAGE model with one output head per node type.

    A ``HeteroConv`` holds one ``SAGEConv`` per relation and sums the messages
    that arrive at each destination node type. The convolved features go
    through a ReLU and then through a separate ``Linear`` layer per node type.

    Args:
        hidden_channels: Output size of every ``SAGEConv``.
        out_channels: Output size of each per-node-type ``Linear`` head.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # One SAGEConv per relation. (-1, -1) leaves the source/target input
        # sizes to be inferred by PyG on the first forward pass.
        self.conv1 = HeteroConv({
            ('paper', 'written_by', 'author'): SAGEConv((-1, -1), hidden_channels),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
            ('paper', 'cites', 'paper'): SAGEConv((-1, -1), hidden_channels),
        }, aggr='sum')
        # Separate projection head for each node type.
        self.lin_dict = nn.ModuleDict({
            node_type: Linear(hidden_channels, out_channels)
            for node_type in ['paper', 'author']
        })

    def forward(self, data: HeteroData):
        """Compute output features for every node type produced by the convolution.

        Args:
            data: Graph providing ``x_dict`` and ``edge_index_dict``.

        Returns:
            Dict mapping node type to its projected feature tensor.
        """
        hidden_dict = self.conv1(data.x_dict, data.edge_index_dict)
        # ReLU between the convolution and the linear head: without it the two
        # stages compose into a single linear map and the hidden layer adds
        # no expressive power.
        return {
            node_type: self.lin_dict[node_type](F.relu(hidden))
            for node_type, hidden in hidden_dict.items()
        }

In [8]:
# Toy heterogeneous graph with 'paper' and 'author' node types.
data = HeteroData()

# Random 16-dimensional features: 100 papers first, then 30 authors.
data['paper'].x = torch.randn(100, 16)
data['author'].x = torch.randn(30, 16)

# Authorship edges: paper i <-> author i for i in 0..3, stored once per direction.
# Row 0 holds source indices, row 1 holds target indices.
data['paper', 'written_by', 'author'].edge_index = (
    torch.arange(4, dtype=torch.long).repeat(2, 1)
)
data['author', 'writes', 'paper'].edge_index = (
    torch.arange(4, dtype=torch.long).repeat(2, 1)
)

# Citation edges, listed as (source paper, target paper) pairs and then
# transposed into the 2 x num_edges layout used above.
data['paper', 'cites', 'paper'].edge_index = torch.tensor(
    [(0, 1), (2, 3), (5, 8), (7, 10)],
    dtype=torch.long,
).t().contiguous()

In [9]:
# Display the feature tensors stored in `data`, keyed by node type.
data.x_dict

{'paper': tensor([[ 0.6657,  0.5537,  1.9262,  ..., -0.4909, -0.9882,  1.3422],
         [-0.6322,  0.4012, -1.5001,  ..., -1.4602,  1.3033, -1.8400],
         [-1.8812, -0.4963,  1.3473,  ..., -0.3284, -0.9288, -0.7713],
         ...,
         [ 0.6736,  1.4270, -0.6056,  ...,  1.2398, -1.3212, -0.1970],
         [-0.3299, -0.6561,  0.8177,  ...,  0.0757, -0.7885, -1.0609],
         [-0.4785,  0.4440,  1.3176,  ..., -0.8703, -2.1342, -1.0828]]),
 'author': tensor([[ 1.5064e-02,  6.2708e-01,  5.1050e-01,  1.7278e+00, -1.4251e-02,
           4.3664e-01,  1.1219e-02,  7.1715e-01, -7.9840e-01,  1.1662e+00,
           1.1805e-01,  7.3316e-01, -2.0145e+00, -4.4689e-01,  1.0381e+00,
           5.9129e-01],
         [ 4.9385e-01,  5.8562e-01, -5.1970e-01, -2.2129e+00, -8.2794e-02,
           6.9124e-01,  4.7222e-01, -9.6065e-01,  7.3267e-01, -1.4560e+00,
          -4.6076e-01, -8.1869e-01, -2.8971e-02, -4.2630e-01,  5.3083e-01,
           1.8312e-01],
         [ 1.0721e+00,  8.4299e-02, -1.33

In [10]:
# Display the edge index tensors stored in `data`, keyed by (source, relation, target).
data.edge_index_dict

{('paper',
  'written_by',
  'author'): tensor([[0, 1, 2, 3],
         [0, 1, 2, 3]]),
 ('author',
  'writes',
  'paper'): tensor([[0, 1, 2, 3],
         [0, 1, 2, 3]]),
 ('paper',
  'cites',
  'paper'): tensor([[ 0,  2,  5,  7],
         [ 1,  3,  8, 10]])}

In [13]:
# Instantiate the model: 32 channels out of the convolution, 16 out of each linear head.
model = HeteroGNN(hidden_channels=32, out_channels=16)

In [47]:
# Forward pass over the toy graph; `out` maps each node type to its output tensor.
out = model(data)

In [46]:
# Print the feature-matrix shape of every node type in the input graph.
print("Input node features by type:")
for node_type, node_store in data.node_items():
    print(f"{node_type}: {node_store.x.shape}")

Input node features by type:
paper: torch.Size([100, 16])
author: torch.Size([30, 16])


In [22]:
# Print the shape of the model output for every node type.
print("Output node features by type:")
for type_name, embedding in out.items():
    print(f"{type_name}: {embedding.size()}")

Output node features by type:
author: torch.Size([30, 16])
paper: torch.Size([100, 16])
