# 1. Definition of GraphSAGE

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero

import random, numpy as np

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader

In [2]:

class HomoSAGE(nn.Module):
    """Homogeneous GraphSAGE backbone (to be replicated per relation by to_hetero).

    The network is a stack of ``layers`` SAGEConv layers. Every layer except the
    last is followed by optional batch normalisation, ReLU and dropout; the last
    layer returns raw ``out_dim``-wide outputs.

    Args:
        hidden: Output width of every intermediate SAGEConv layer.
        out_dim: Output width of the final SAGEConv layer.
        layers: Total number of SAGEConv layers; must be at least 1.
        dropout: Dropout probability used after each intermediate layer.
        aggr: Aggregation scheme forwarded to every SAGEConv.
        bn: If True, a BatchNorm1d(hidden) follows each intermediate layer.

    Raises:
        ValueError: If ``layers`` is smaller than 1.
    """
    def __init__(self, hidden: int, out_dim: int, layers: int = 2,
                    dropout: float = 0.5, aggr: str = "mean", bn: bool = False):
        super().__init__()
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got {layers}")
        self.layers = layers
        self.dropout = dropout
        self.use_bn = bn
        self.bns = nn.ModuleList()
        self.convs = nn.ModuleList()
        # (layers - 1) hidden-width layers followed by one output-width layer;
        # for layers == 1 this is just the single output layer.
        for width in [hidden] * (layers - 1) + [out_dim]:
            self.convs.append(SAGEConv((-1, -1), width, aggr=aggr))
        if bn:
            for _ in range(layers - 1):
                self.bns.append(nn.BatchNorm1d(hidden))
        # Dropout must be a submodule rather than a functional call that reads
        # `self.training`: to_hetero traces this module symbolically, and a
        # plain Python bool would be frozen into the traced graph, leaving
        # dropout switched on even after `model.eval()`.
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Run all layers on ``x`` with connectivity ``edge_index``; return last-layer output."""
        h = x
        for i, conv in enumerate(self.convs[:-1]):
            h = conv(h, edge_index)
            if self.use_bn:
                h = self.bns[i](h)
            h = torch.relu(h)
            h = self.drop(h)
        return self.convs[-1](h, edge_index)

def build_hetero_sage(metadata, hidden: int, out_dim: int, layers: int = 2,
                        dropout: float = 0.5, aggr: str = "mean",
                        bn: bool = False, aggr_rel: str = "sum"):
    """Build a heterogeneous GraphSAGE model from a HomoSAGE backbone.

    A HomoSAGE is constructed from ``hidden``, ``out_dim``, ``layers``,
    ``dropout``, ``aggr`` and ``bn`` and then passed through ``to_hetero``
    together with ``metadata``; ``aggr_rel`` is forwarded as the ``aggr``
    argument of ``to_hetero``. The converted model is returned.
    """
    backbone_config = dict(
        hidden=hidden,
        out_dim=out_dim,
        layers=layers,
        dropout=dropout,
        aggr=aggr,
        bn=bn,
    )
    return to_hetero(HomoSAGE(**backbone_config), metadata=metadata, aggr=aggr_rel)

# 2. Utils

In [3]:
def resolve_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then MPS, otherwise CPU."""
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        # MPS is used only when it is both available and built.
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)

def set_seed(seed: int = 42, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available, all CUDA generators are seeded as well, cuDNN
    benchmarking is turned off and the cuDNN deterministic flag is set to
    ``deterministic``. When ``deterministic`` is True, deterministic
    algorithms are additionally requested from PyTorch in warn-only mode.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = deterministic

    if deterministic:
        torch.use_deterministic_algorithms(True, warn_only=True)

@torch.no_grad()
def masked_accuracy(logits: torch.Tensor, y: torch.Tensor, idx: torch.Tensor) -> float:
    """Return the fraction of rows selected by ``idx`` whose argmax over the last axis equals ``y``.

    Returns NaN when ``idx`` selects nothing.
    """
    if idx.numel() > 0:
        predicted = torch.argmax(logits[idx], dim=-1)
        hits = predicted.eq(y[idx]).to(torch.float32)
        return hits.mean().item()
    return float("nan")

In [4]:
# Select the compute device once; the training cell moves the model and each mini-batch onto it.
device = resolve_device()

In [5]:
# Display the selected device as the cell output.
device

device(type='mps')

# 3. Load Data

In [9]:
def load_mag_hetero(root: str = "../data/ogb"):
    """
    Load OGB-MAG with metapath2vec features and an undirected-graph transform.

    Args:
        root: Directory passed to ``OGB_MAG`` as its dataset root.

    Returns a 5-tuple (ds, y, train_input_nodes, val_input_nodes, test_input_nodes):
    - ds: the ``OGB_MAG`` dataset object (index it with ``ds[0]`` to get the graph)
    - y:  labels of the 'paper' nodes, flattened to one dimension
    - train_input_nodes / val_input_nodes / test_input_nodes:
          ``('paper', mask)`` tuples built from the paper train/val/test masks,
          in the form the loaders take as ``input_nodes``
    """
    transform = T.ToUndirected(merge=True)
    ds = OGB_MAG(root, preprocess='metapath2vec', transform=transform)

    data = ds[0]
    y = data["paper"].y.view(-1)

    train_input_nodes = ('paper', data['paper'].train_mask)
    test_input_nodes = ('paper', data['paper'].test_mask)
    val_input_nodes = ('paper', data['paper'].val_mask)
    return ds, y, train_input_nodes, val_input_nodes, test_input_nodes

# Load the dataset and unpack the paper labels plus the ('paper', mask) seed specs per split.
dataset, y, train_input_nodes, val_input_nodes, test_input_nodes = load_mag_hetero()
# The graph itself is the first (index 0) element of the dataset.
data = dataset[0]
# Used as the output width of the model built in the training cell.
num_classes = dataset.num_classes
# Display the graph summary as the cell output.
data

HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389],
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)

In [16]:
# Show the (node types, edge types) metadata; the same call supplies `metadata` to build_hetero_sage in train().
data.metadata()

(['paper', 'author', 'institution', 'field_of_study'],
 [('author', 'affiliated_with', 'institution'),
  ('author', 'writes', 'paper'),
  ('paper', 'cites', 'paper'),
  ('paper', 'has_topic', 'field_of_study'),
  ('institution', 'rev_affiliated_with', 'author'),
  ('paper', 'rev_writes', 'author'),
  ('field_of_study', 'rev_has_topic', 'paper')])

# 3. Setup

In [9]:
seed = 42
# Seed the Python / NumPy / PyTorch RNGs before any model or loader is built,
# so that weight initialisation and shuffling start from a known state.
set_seed(seed)

# Number of seed nodes per mini-batch. This is the single source of truth for
# the loaders (previously the loaders used 1024 while an unused `batch_size`
# variable said 2048).
batch_size = 1024

# Keyword arguments shared by every NeighborLoader created in make_loaders().
kwargs = {'batch_size': batch_size, 'num_workers': 6, 'persistent_workers': True}

# Model
hidden = 128
layers = 2
agg = "mean"          # neighbour aggregation inside each SAGEConv
aggr_rel = "sum"      # aggregation across relations in to_hetero
neighbors = "15,10"   # comma-separated fan-out per layer, parsed in train()

# Optimisation
epochs = 50
lr = 0.01
weight_decay = 5e-4
dropout = 0.5


# 4. Train

In [None]:
def make_loaders(data, train_input_nodes, val_input_nodes, test_input_nodes,
                    num_neighbors=(15, 10)):
    """Create the training and validation NeighborLoaders over ``data``.

    Both loaders sample ``num_neighbors`` neighbours per layer and receive the
    notebook-level ``kwargs`` dict as extra keyword arguments; only the training
    loader shuffles. ``test_input_nodes`` is accepted but not used here, and no
    test loader is returned.

    Returns:
        (train_loader, val_loader)
    """
    def _build(seed_nodes, **extra):
        # A fresh list per loader, exactly as if each were constructed separately.
        return NeighborLoader(
            data, num_neighbors=list(num_neighbors),
            input_nodes=seed_nodes, **extra, **kwargs)

    train_loader = _build(train_input_nodes, shuffle=True)
    val_loader = _build(val_input_nodes)
    return train_loader, val_loader

@torch.no_grad()
def evaluate(model, loader, device):
    """Run ``model`` in eval mode over ``loader`` and gather the 'paper' seed outputs.

    For every batch, the first ``batch_size`` rows of the 'paper' output are
    kept together with the matching leading entries of ``n_id``. Both are moved
    to the CPU and concatenated over all batches.

    Returns:
        (seed node ids, seed logits), each concatenated along dim 0.
    """
    model.eval()
    seed_ids, seed_logits = [], []
    for subgraph in loader:
        subgraph = subgraph.to(device)
        paper_out = model(subgraph.x_dict, subgraph.edge_index_dict)["paper"]
        paper_store = subgraph["paper"]
        # The leading `batch_size` 'paper' rows are the seeds of this batch.
        num_seeds = paper_store.batch_size
        seed_logits.append(paper_out[:num_seeds].cpu())
        # `n_id` is read to recover the seeds' ids in the full graph.
        seed_ids.append(paper_store.n_id[:num_seeds].cpu())
    return torch.cat(seed_ids, dim=0), torch.cat(seed_logits, dim=0)

In [None]:
def train(data, y, train_input_nodes, val_input_nodes, test_input_nodes):
    """Train the heterogeneous GraphSAGE model and keep the best-on-validation weights.

    The model hyperparameters (``hidden``, ``layers``, ``dropout``, ``agg``,
    ``aggr_rel``), the sampling fan-out string ``neighbors``, the optimisation
    settings (``epochs``, ``lr``, ``weight_decay``), ``num_classes`` and
    ``device`` are read from notebook-level variables.

    Args:
        data: Heterogeneous graph; its ``metadata()`` defines the model's types.
        y: Labels of the 'paper' nodes, indexed by global paper id.
        train_input_nodes: Seed-node spec for the training loader.
        val_input_nodes: Seed-node spec for the validation loader.
        test_input_nodes: Forwarded to ``make_loaders``; no test evaluation is
            performed here because ``make_loaders`` returns no test loader.

    Returns:
        (model, best_val): the model with the best-validation weights restored
        (or the final weights if validation accuracy never exceeded 0), and the
        best validation accuracy observed.
    """
    model = build_hetero_sage(
        metadata=data.metadata(), hidden=hidden, out_dim=num_classes,
        layers=layers, dropout=dropout, aggr=agg, aggr_rel=aggr_rel
    ).to(device)

    # Loaders: "15,10" -> (15, 10), one fan-out per layer.
    num_neighbors = tuple(int(x) for x in neighbors.split(","))
    train_loader, val_loader = make_loaders(
        data, train_input_nodes, val_input_nodes, test_input_nodes, num_neighbors=num_neighbors
    )

    @torch.no_grad()
    def init_params():
        # Initialize lazy parameters via forwarding a single batch to the model:
        batch = next(iter(train_loader))
        batch = batch.to(device)
        model(batch.x_dict, batch.edge_index_dict)

    # Must run before the optimizer is created so it sees materialised parameters.
    init_params()
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_val, best_state = 0.0, None
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0  # sum of per-batch losses over the epoch
        for batch in train_loader:
            batch = batch.to(device)
            out_dict = model(batch.x_dict, batch.edge_index_dict)
            # Only the first `batch_size` 'paper' rows are seeds with a training signal.
            k = batch["paper"].batch_size
            logits = out_dict["paper"][:k]
            y_seed = batch["paper"].y[:k]
            loss = F.cross_entropy(logits, y_seed)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            total_loss += loss.item()

        # Eval
        val_idx_all, val_logits = evaluate(model, val_loader, device)
        val_acc = masked_accuracy(val_logits, y[val_idx_all], torch.arange(val_idx_all.numel()))

        if val_acc > best_val:
            best_val = val_acc
            # clone(): on a CPU device `.cpu()` returns the live tensor itself,
            # so without a copy the snapshot would follow later updates.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }

        print(f"Epoch {epoch:03d} | loss {total_loss:.3f} | val {val_acc:.4f} ")

    # Restore the best checkpoint so the returned model matches `best_val`.
    # TODO: final test evaluation needs a test loader from make_loaders().
    if best_state is not None:
        model.load_state_dict(best_state, strict=True)
    return model, best_val

In [None]:
# Run the full training loop; one line is printed per epoch with the summed loss and validation accuracy.
train(data, y, train_input_nodes, val_input_nodes, test_input_nodes)

Epoch 001 | loss 1534.478 | val 0.3755 
Epoch 002 | loss 1368.777 | val 0.3899 
Epoch 003 | loss 1341.780 | val 0.3873 


KeyboardInterrupt: 

libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipelibc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe

libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::system_error: Broken pipe
libc++abi: terminating due to uncaught exception of type std::__1::sys

: 