## Dependencies

In [1]:
import torch
import torch.nn as nn
#!pip install torch_geometric
#!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
#!pip install pandas
import torch_geometric
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.conv import HeteroConv
from torch_geometric.transforms import ToUndirected
from torch.nn.functional import cross_entropy
from torch.nn import Linear, ReLU, Softmax
from tqdm import tqdm

print(torch.__version__)
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


2.8.0+cpu
False


In [2]:
dataset = OGB_MAG(root='GNN-SSL-Project-for-Deep-Learning/data/',
                  transform=ToUndirected())[0]
print(dataset)

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389],
  },
  author={ num_nodes=1134649 },
  institution={ num_nodes=8740 },
  field_of_study={ num_nodes=59965 },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)


In [3]:
num_neighbors = [15, 10, 5]
batch_size = 64

train_batch = NeighborLoader(dataset, 
                        num_neighbors=num_neighbors, 
                        input_nodes=('paper', dataset['paper'].train_mask),
                        batch_size=batch_size, 
                        shuffle=True, 
                        num_workers=0)
test_batch = NeighborLoader(dataset, 
                        num_neighbors=num_neighbors, 
                        input_nodes=('paper', dataset['paper'].test_mask),
                        batch_size=batch_size, 
                        shuffle=True, 
                        num_workers=0)
val_batch = NeighborLoader(dataset, 
                        num_neighbors=num_neighbors, 
                        input_nodes=('paper', dataset['paper'].val_mask),
                        batch_size=batch_size, 
                        shuffle=True, 
                        num_workers=0)
i=0
for batch in train_batch:
    print(batch)
    if i>5:
        break
    i+=1
    
    

HeteroData(
  paper={
    x=[54617, 128],
    year=[54617],
    y=[54617],
    train_mask=[54617],
    val_mask=[54617],
    test_mask=[54617],
    n_id=[54617],
    input_id=[64],
    batch_size=64,
  },
  author={
    num_nodes=29814,
    n_id=[29814],
  },
  institution={
    num_nodes=877,
    n_id=[877],
  },
  field_of_study={
    num_nodes=7326,
    n_id=[7326],
  },
  (author, affiliated_with, institution)={
    edge_index=[2, 995],
    e_id=[995],
  },
  (author, writes, paper)={
    edge_index=[2, 38063],
    e_id=[38063],
  },
  (paper, cites, paper)={
    edge_index=[2, 48720],
    e_id=[48720],
  },
  (paper, has_topic, field_of_study)={
    edge_index=[2, 12394],
    e_id=[12394],
  },
  (institution, rev_affiliated_with, author)={
    edge_index=[2, 3809],
    e_id=[3809],
  },
  (paper, rev_writes, author)={
    edge_index=[2, 10199],
    e_id=[10199],
  },
  (field_of_study, rev_has_topic, paper)={
    edge_index=[2, 53351],
    e_id=[53351],
  }
)
HeteroData(
  paper=

## Working on single batch

In [4]:
batch = next(iter(train_batch))
class graphSAGESINGLE(nn.Module):
    def __init__(self,edge_types,hidden_dim,output_dim):
        super().__init__()
        self.conv1 = HeteroConv({edge_type : SAGEConv((-1,-1),hidden_dim) for edge_type in edge_types},aggr='sum')
        self.conv2 = HeteroConv({edge_type : SAGEConv((-1,-1),output_dim) for edge_type in edge_types},aggr='sum')

    def forward(self,x_dict,edge_index_dict):
        x_dict = self.conv1(x_dict,edge_index_dict)
        x_dict = {k:ReLU()(v) for k,v in x_dict.items()}
        x_dict = self.conv2(x_dict,edge_index_dict)
        return x_dict['paper']

hidden_dim = 16
out_channels = max(batch['paper'].y).item() + 1

model = graphSAGESINGLE(batch.edge_types,hidden_dim,out_channels)

featless = [t for t in batch.node_types if 'x' not in batch[t]]
print(featless)
emb = nn.ModuleDict({
    t: nn.Embedding(dataset[t].num_nodes, 128)
    for t in featless
})

def build_x_dict(batch):
    x_dict = {}
    for node_type in batch.node_types:
        if node_type in featless:
            x_dict[node_type] = emb[node_type](batch[node_type].n_id)
        else:
            x_dict[node_type] = batch[node_type].x
    return x_dict

x_dict = build_x_dict(batch)
edge_index_dict = {edge_type : batch[edge_type].edge_index for edge_type in batch.edge_types}

opt = torch.optim.Adam(list(model.parameters())+list(emb.parameters()), lr=0.01)


epochs = 50
for epoch in range(epochs):
    model.train()
    total_loss = 0
    x_dict = build_x_dict(batch)
    edge_index_dict = {edge_type : batch[edge_type].edge_index for edge_type in batch.edge_types}
    out = model(x_dict,edge_index_dict)
    loss = cross_entropy(out, batch['paper'].y)
    prediction = torch.argmax(Softmax(dim=1)(out),dim=-1)
    acc = (prediction == batch['paper'].y).sum()/batch['paper'].y.shape[0]
    opt.zero_grad()
    loss.backward()
    opt.step()
    total_loss += loss.item() 
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, accuracy: {acc}")

['author', 'institution', 'field_of_study']
Epoch 1/50, Loss: 5.8836, accuracy: 0.0047023785300552845
Epoch 2/50, Loss: 5.6676, accuracy: 0.03981347382068634
Epoch 3/50, Loss: 5.4660, accuracy: 0.03454288840293884
Epoch 4/50, Loss: 5.2665, accuracy: 0.02999725751578808
Epoch 5/50, Loss: 5.0587, accuracy: 0.05533132329583168
Epoch 6/50, Loss: 4.8655, accuracy: 0.07543399184942245
Epoch 7/50, Loss: 4.7122, accuracy: 0.0817234218120575
Epoch 8/50, Loss: 4.5992, accuracy: 0.11172068119049072
Epoch 9/50, Loss: 4.5081, accuracy: 0.11730474978685379
Epoch 10/50, Loss: 4.4298, accuracy: 0.12668991088867188
Epoch 11/50, Loss: 4.3601, accuracy: 0.13881812989711761
Epoch 12/50, Loss: 4.2939, accuracy: 0.15417924523353577
Epoch 13/50, Loss: 4.2296, accuracy: 0.16689525544643402
Epoch 14/50, Loss: 4.1672, accuracy: 0.17432110011577606
Epoch 15/50, Loss: 4.1031, accuracy: 0.18139424920082092
Epoch 16/50, Loss: 4.0363, accuracy: 0.18805596232414246


KeyboardInterrupt: 

In [None]:

x_dict = build_x_dict(batch)
edge_index_dict = {edge_type : batch[edge_type].edge_index for edge_type in batch.edge_types}
out = model(x_dict,edge_index_dict)
prediction = torch.argmax(Softmax(dim=1)(out),dim=-1)

acc = (prediction == batch['paper'].y).sum()/batch['paper'].y.shape[0]
print(acc)

51789
tensor(0.0316)


In [50]:
hidden_channels = 64
out_channels = max(dataset['paper'].y).item() + 1


featless = [t for t in dataset.node_types if 'x' not in dataset[t]]
emb = nn.ModuleDict({
    t: nn.Embedding(dataset[t].num_nodes, 128)
    for t in featless
})
emb = emb.to(device)

class graphSAGEmodel(nn.Module):
    def __init__(self, edge_types, hidden, out_channels):
        super().__init__()
        edge_types = edge_types
        # layer 1
        self.conv1 = HeteroConv(
            {et: SAGEConv((-1, -1), hidden) for et in edge_types},
            aggr='sum'
        )
        # layer 2
        self.conv2 = HeteroConv(
            {et: SAGEConv((-1, -1), hidden) for et in edge_types},
            aggr='sum'
        )
        self.head = Linear(hidden, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        #x_dict = {k: ReLU(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        # return logits for papers only
        return self.head(x_dict['paper'])
    
model = graphSAGEmodel(dataset.edge_types, hidden_channels, out_channels)
model = model.to(device)
opt = torch.optim.Adam(list(model.parameters())+list(emb.parameters()), lr=0.01)
print(model)

graphSAGEmodel(
  (conv1): HeteroConv(num_relations=7)
  (conv2): HeteroConv(num_relations=7)
  (head): Linear(in_features=64, out_features=349, bias=True)
)


In [46]:
def build_x_dict(batch):
    x_dict = {}
    for node_type in batch.node_types:
        if node_type in featless:
            x_dict[node_type] = emb[node_type](batch[node_type].n_id)
        else:
            x_dict[node_type] = batch[node_type].x
    return x_dict



In [56]:
@torch.no_grad()
def infer(loader):
    model.eval()
    preds, ys = [], []
    for batch in loader:
        batch = batch.to(device)
        bs = batch['paper'].batch_size
        x_dict = build_x_dict(batch)
        edge_index_dict = {edge_type : batch[edge_type].edge_index for edge_type in batch.edge_types}
        logits = model(x_dict, edge_index_dict)[:bs]
        preds.append(logits.argmax(-1).cpu())
        ys.append(batch['paper'].y[:bs].view(-1).cpu())
    return torch.cat(preds), torch.cat(ys)

num_epochs = 5
for epoch in range(1, num_epochs):
    model.train()
    tr_loss = 0
    for batch in tqdm(train_batch, desc=f"Epoch {epoch}/{num_epochs}"):
        batch = batch.to(device)
        x_dict = build_x_dict(batch)
        edge_index_dict = {edge_type : batch[edge_type].edge_index for edge_type in batch.edge_types}
        logits = model(x_dict, edge_index_dict)
        loss = cross_entropy(logits, batch['paper'].y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        tr_loss += loss.item()
    val_pred, val_true = infer(val_batch)
    val_metric = (val_pred == val_true).sum().item() / val_true.size(0)
    print(f"Epoch {epoch:02d} | loss {tr_loss:.4f} | val {val_metric:.4f}")


Epoch 1/5: 100%|██████████| 9838/9838 [13:30<00:00, 12.14it/s]


Epoch 01 | loss 18033.1235 | val 0.8608


Epoch 2/5:  11%|█         | 1050/9838 [01:20<11:11, 13.08it/s]


KeyboardInterrupt: 