In [20]:
torch_version = str(torch.__version__)
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $scatter_src
!pip install torch-sparse -f $sparse_src
!pip install torch-geometric
!pip install -q git+https://github.com/snap-stanford/deepsnap.git

Looking in links: https://pytorch-geometric.com/whl/torch-2.8.0+cu126.html
Looking in links: https://pytorch-geometric.com/whl/torch-2.8.0+cu126.html
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [21]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero
import tqdm
import torch.nn.functional as F
import torch_geometric.transforms as T

In [22]:
# Lets start by loading the data

data = torch.load("hetero_data_no_coauthor.pt", weights_only=False)
data

HeteroData(
  author={ node_id=[90941] },
  paper={
    node_id=[63854],
    x=[63854, 256],
  },
  (author, writes, paper)={ edge_index=[2, 320187] },
  (paper, rev_writes, author)={ edge_index=[2, 320187] }
)

In [23]:
# Do the Train, Val, Test Split
# training (80%), validation (10%), and testing edges (10%).
# Across the training edges, we use 70% of edges for message passing,
# and 30% of edges for supervision. (This is from a tutorial by PyG, we can change this later)
# We further want to generate fixed negative edges for evaluation with a ratio of 2:1. (Again a Hyperparameter we can tune later)
# Negative edges during training will be generated on-the-fly (How?, again this is from the tutorial, need to check later)
transform = T.RandomLinkSplit(
    num_val=0.1, # Validation set percentage
    num_test=0.1, # Test set percentage
    disjoint_train_ratio=0.3, # Percentage of training edges used for supervision, these will not be used for message passing
    neg_sampling_ratio=2.0, # Ratio of negative to positive edges for validation and testing, dont know how this is related to `add_negative_train_samples`, need to check later
    add_negative_train_samples=False, # AYYY NO idea, why this set to False, but somehow it works worse with True ???, Need it investigate later, Prolly because we do LinkNeighborLoader which samples neg edges for us?
    edge_types=("author", "writes", "paper"), # Any ways, these are the edge types we want to predict
    rev_edge_types=("paper", "rev_writes", "author"), # Reverse edge types, so we dont accidentally bleed information into validation/test set
)

train_data, val_data, test_data = transform(data)
train_data

HeteroData(
  author={ node_id=[90941] },
  paper={
    node_id=[63854],
    x=[63854, 256],
  },
  (author, writes, paper)={
    edge_index=[2, 179306],
    edge_label=[76845],
    edge_label_index=[2, 76845],
  },
  (paper, rev_writes, author)={ edge_index=[2, 179306] }
)

In [24]:
# In the first hop, we sample at most 20 neighbors.
# In the second hop, we sample at most 10 neighbors.
# In addition, during training, we want to sample negative edges on-the-fly with
# a ratio of 2:1.
# We can make use of the `loader.LinkNeighborLoader` from PyG:

# This loader is actually SAMPLING the full graph, by first sampling 64 random nodes then 32 neighbors of each node previously sampled node to create a sparse subgraph etc...
# We should be able to load the graph fully into memory, but how would one train that?
# We could probably use the previous random link split to do full batch training, but somehow we would not sample random negative edges then?
# Need to check different loaders which sample the full graph and then do negative sampling on-the-fly
edge_label_index = train_data["author", "writes", "paper"].edge_label_index
edge_label = train_data["author", "writes", "paper"].edge_label

train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[64, 32, 16],
    neg_sampling_ratio=2.0,
    edge_label_index=(("author", "writes", "paper"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

  neighbor_sampler = NeighborSampler(


In [25]:

# Simple 3 hop GNN
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(
            hidden_channels,
            hidden_channels,
            aggr="mean",
            project=False,
        )
        self.conv2 = SAGEConv(
            hidden_channels,
            hidden_channels,
            aggr="mean",
            project=False,
        )
        self.conv3 = SAGEConv(
            hidden_channels,
            hidden_channels,
            aggr="mean",
            project=False,
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x


# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    def forward(
        self,
        x_user: torch.Tensor,
        x_movie: torch.Tensor,
        edge_label_index: torch.Tensor,
    ) -> torch.Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_movie = x_movie[edge_label_index[1]]
        return (edge_feat_user * edge_feat_movie).sum(dim=-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels: int, data: HeteroData):
        super().__init__()

        self.hidden_channels = hidden_channels

        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)

        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())

        # Instantiate link classifier:
        self.classifier = Classifier()

    def forward(self, data: HeteroData) -> torch.Tensor:

        # Set the initial user embeddings to all ones for all authors
        # This makes sure the graph can generalize to unseen authors during inference
        author_embedding = torch.ones(
            (data["author"].num_nodes, self.hidden_channels),
            device=data["paper"].x.device,
        )

        # Extract paper embeddings from the data object
        paper_embedding = data["paper"].x

        # Noew we can create the x_dict required for the GNN
        x_dict = {
            "author": author_embedding,
            "paper": paper_embedding,
        }

        # "x_dict" now holds feature matrices of all node types
        # "edge_index_dict" holds all edge indices, i.e. the connections between users and movies
        # The GNN will predict new embeddings for all node types, we can even check how the user embeddings change
        gnn_pred = self.gnn(x_dict, data.edge_index_dict)

        # Finally we can use the classifier to get the final link predictions
        # This can be done either with the dot product of the updated embeddings
        # or more involved with a linear projection head or smth similar
        cls_pred = self.classifier(
            gnn_pred["author"],
            gnn_pred["paper"],
            data["author", "writes", "paper"].edge_label_index,
        )

        return cls_pred

In [26]:
from torch_scatter import scatter_mean
class BaselineNoGraphModel(torch.nn.Module):
    def __init__(self, hidden_channels: int, data: HeteroData):
        super().__init__()
        self.hidden_channels = hidden_channels

        # Use the correct per-type input sizes
        paper_in = data["paper"].num_features
        # author_in = data["author"].num_features  # not used in this baseline

        # Project paper features to hidden size
        self.lin_paper = torch.nn.Linear(paper_in, hidden_channels, bias=True)

        # Optional extra transform on the aggregated author representation
        self.lin_author = torch.nn.Linear(hidden_channels, hidden_channels, bias=True)

        self.classifier = Classifier()  # assumes signature: (author_emb, paper_emb, edge_label_index) -> scores

    def forward(self, data: HeteroData) -> torch.Tensor:
        edge_type = ("author", "writes", "paper")
        edge_index = data[edge_type].edge_index
        author_ids, paper_ids = edge_index[0], edge_index[1]

        # 1) Paper embeddings
        paper_x = data["paper"].x  # [num_papers, paper_in]
        paper_h = self.lin_paper(paper_x)  # [num_papers, hidden]

        # 2) Build author embeddings by averaging their authored papers' embeddings
        num_authors = data["author"].num_nodes
        # paper_h[paper_ids] picks each written paper's embedding; scatter to author_ids
        author_h = scatter_mean(
            paper_h[paper_ids],
            author_ids,
            dim=0,
            dim_size=num_authors,  # ensures we get a row for every author (zeros for authors with no papers)
        )
        author_h = self.lin_author(author_h)  # [num_authors, hidden]

        # 3) Score candidate pairs (author, paper) at edge_label_index
        cls_pred = self.classifier(
            author_h,
            paper_h,
            data[edge_type].edge_label_index,
        )
        return cls_pred

In [27]:
LR = 0.001
EPOCHS = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model(hidden_channels=256, data=data)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

model = model.to(device)

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):

        optimizer.zero_grad()
        sampled_data.to(device)

        y_pred = model(sampled_data)
        y_true = sampled_data["author", "writes", "paper"].edge_label

        loss = F.binary_cross_entropy_with_logits(y_pred, y_true)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * y_pred.numel()
        total_examples += y_pred.numel()

    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

100%|██████████| 601/601 [00:44<00:00, 13.44it/s]


Epoch: 000, Loss: 0.6041


100%|██████████| 601/601 [00:35<00:00, 16.82it/s]


Epoch: 001, Loss: 0.5136


100%|██████████| 601/601 [00:36<00:00, 16.35it/s]


Epoch: 002, Loss: 0.4832


100%|██████████| 601/601 [00:35<00:00, 16.74it/s]


Epoch: 003, Loss: 0.4603


100%|██████████| 601/601 [00:35<00:00, 16.82it/s]


Epoch: 004, Loss: 0.4493


100%|██████████| 601/601 [00:35<00:00, 16.80it/s]


Epoch: 005, Loss: 0.4396


100%|██████████| 601/601 [00:35<00:00, 16.83it/s]


Epoch: 006, Loss: 0.4275


100%|██████████| 601/601 [00:35<00:00, 16.74it/s]


Epoch: 007, Loss: 0.4210


100%|██████████| 601/601 [00:35<00:00, 16.83it/s]


Epoch: 008, Loss: 0.4167


100%|██████████| 601/601 [00:35<00:00, 16.86it/s]


Epoch: 009, Loss: 0.4119


100%|██████████| 601/601 [00:38<00:00, 15.77it/s]


Epoch: 010, Loss: 0.4100


100%|██████████| 601/601 [00:35<00:00, 16.87it/s]


Epoch: 011, Loss: 0.4071


100%|██████████| 601/601 [00:36<00:00, 16.60it/s]


Epoch: 012, Loss: 0.4053


100%|██████████| 601/601 [00:35<00:00, 16.84it/s]


Epoch: 013, Loss: 0.4012


100%|██████████| 601/601 [00:35<00:00, 16.94it/s]


Epoch: 014, Loss: 0.3979


100%|██████████| 601/601 [00:35<00:00, 16.93it/s]


Epoch: 015, Loss: 0.3964


100%|██████████| 601/601 [00:35<00:00, 16.92it/s]


Epoch: 016, Loss: 0.3929


100%|██████████| 601/601 [00:35<00:00, 16.91it/s]


Epoch: 017, Loss: 0.3939


100%|██████████| 601/601 [00:35<00:00, 16.80it/s]


Epoch: 018, Loss: 0.3894


100%|██████████| 601/601 [00:35<00:00, 16.73it/s]

Epoch: 019, Loss: 0.3877





In [28]:
def evaluate_model(model, data):
    model.eval()
    with torch.no_grad():
        y_pred = model(data)

    y_pred = y_pred.cpu().numpy()
    y_true = data["author", "writes", "paper"].edge_label.cpu().numpy()

    # binary thresholding at 0.5
    y_pred = (y_pred >= 0.5)

    FP = ((y_true == 0) & (y_pred == 1)).sum().item()
    TP = ((y_true == 1) & (y_pred == 1)).sum().item()
    FN = ((y_true == 1) & (y_pred == 0)).sum().item()
    TN = ((y_true == 0) & (y_pred == 0)).sum().item()

    precision = TP / (TP + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
    accuracy = (TP + TN) / (TP + TN + FP + FN + 1e-8)

    return precision, recall, f1_score, accuracy


test_data.to(device)
precision, recall, f1_score, accuracy = evaluate_model(model, test_data)
print("Evaluating on Test set...")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print("--------------------------------------------------")
val_data.to(device)
precision, recall, f1_score, accuracy = evaluate_model(model, val_data)
print("Evaluating on validation set...")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")

Evaluating on Test set...
Precision: 0.8251
Recall: 0.6633
F1 Score: 0.7354
Accuracy: 0.8409
--------------------------------------------------
Evaluating on validation set...
Precision: 0.8316
Recall: 0.6448
F1 Score: 0.7264
Accuracy: 0.8381


In [29]:
from sklearn.metrics import roc_auc_score

with torch.no_grad():
    y_pred = model(val_data)

y_pred = y_pred.cpu().numpy()
y_true = val_data["author", "writes", "paper"].edge_label.cpu().numpy()

auc = roc_auc_score(y_true, y_pred)
print(f"Validation AUC: {auc:.4f}")

Validation AUC: 0.9062


In [30]:

@torch.no_grad()
def evaluate_ranking_metrics(
    model,
    data,
    edge_type=("author", "writes", "paper"),
    ks=(1, 3, 5, 10),
    reduce="macro",  # 'macro' = average over heads (recommended)
    device=None,
):
    """
    Compute ranking-style metrics for link prediction / recommendation:
      - Hits@K:   fraction of heads with >=1 positive in top-K
      - Recall@K: average over heads of (positives in top-K / total positives)
      - Precision@K: average over heads of (positives in top-K / K)
      - MRR:      mean reciprocal rank of the first positive per head
      - MAP:      mean average precision over heads
      - NDCG@K:   average normalized DCG at K over heads

    Assumptions:
      - model(data) -> scores aligned with edge_label (1D)
      - data[edge_type].edge_label in {0,1}
      - data[edge_type].edge_label_index[0] are the "head" IDs to group by

    Notes:
      - Heads with zero positives are skipped for metrics that require a positive
        (MRR, MAP, Recall@K, NDCG@K). For Precision@K and Hits@K we include all heads.
      - Set `device` if you want to force inference on a specific device.
    """
    if device is not None:
        data = data.to(device)
        model = model.to(device)
    model.eval()

    scores = model(data).detach()
    labels = data[edge_type].edge_label
    head_ids = data[edge_type].edge_label_index[0]

    # move to cpu numpy
    scores = scores.cpu().numpy().astype(np.float64)
    labels = labels.cpu().numpy().astype(np.int64)
    head_ids = head_ids.cpu().numpy().astype(np.int64)

    # group indices by head
    # heads_idx_map: head_id -> np.array(indices of edges for that head)
    # This is robust if head_ids are not contiguous or sorted.
    heads_idx_map = {}
    for i, h in enumerate(head_ids):
        heads_idx_map.setdefault(int(h), []).append(i)

    # containers
    hits_at_k = {k: [] for k in ks}
    prec_at_k = {k: [] for k in ks}
    rec_at_k  = {k: [] for k in ks}
    ndcg_at_k = {k: [] for k in ks}
    mrr_vals = []
    ap_vals  = []

    # helper: DCG with binary relevance
    def dcg_at_k(y_true_sorted, k):
        # y_true_sorted: binary labels sorted by descending score
        rel = y_true_sorted[:k]
        if rel.size == 0:
            return 0.0
        # log2 positions start at 2 for rank 1
        discounts = 1.0 / np.log2(np.arange(2, rel.size + 2))
        return np.sum(rel * discounts)

    for h, idxs in heads_idx_map.items():
        idxs = np.array(idxs, dtype=np.int64)
        y = labels[idxs]
        s = scores[idxs]

        # sort by score desc
        order = np.argsort(-s)
        y_sorted = y[order]

        num_pos = int(y.sum())

        # Precision/Recall/Hits/NDCG@K
        for k in ks:
            topk = y_sorted[:k]
            hits_at_k[k].append(1.0 if topk.sum() > 0 else 0.0)
            prec_at_k[k].append(float(topk.sum()) / max(k, 1))

            if num_pos > 0:
                rec_at_k[k].append(float(topk.sum()) / num_pos)
                # NDCG
                dcg = dcg_at_k(y_sorted, k)
                ideal_sorted = np.sort(y)[::-1]  # best-case ranking
                idcg = dcg_at_k(ideal_sorted, k)
                ndcg_at_k[k].append(dcg / idcg if idcg > 0 else 0.0)

        # MRR + MAP only defined if there is at least one positive
        if num_pos > 0:
            # MRR
            pos_ranks = np.where(y_sorted == 1)[0]  # 0-based ranks
            first_rank = pos_ranks[0] + 1  # 1-based
            mrr_vals.append(1.0 / first_rank)

            # AP
            # Precision at each position of a relevant item, averaged over #relevant
            cum_pos = 0
            prec_sum = 0.0
            for rank, rel in enumerate(y_sorted, start=1):
                if rel == 1:
                    cum_pos += 1
                    prec_sum += cum_pos / rank
            ap_vals.append(prec_sum / num_pos)

    # aggregate
    def avg(lst):
        return float(np.mean(lst)) if len(lst) > 0 else 0.0

    results = {
        "num_heads": len(heads_idx_map),
        "MRR": avg(mrr_vals),
        "MAP": avg(ap_vals),
    }

    for k in ks:
        results[f"Hits@{k}"] = avg(hits_at_k[k])
        results[f"Precision@{k}"] = avg(prec_at_k[k])
        # Recall@K & NDCG@K are averaged over heads with positives only
        results[f"Recall@{k}"] = avg(rec_at_k[k])
        results[f"NDCG@{k}"] = avg(ndcg_at_k[k])

    return results


In [31]:
import numpy as np

# Example usage:
test_data = test_data.to(device)
metrics = evaluate_ranking_metrics(model, test_data,
    edge_type=("author","writes","paper"), ks=(1,3,10), device=device)
print("Ranking metrics on Test:")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")


Ranking metrics on Test:
num_heads: 56505.0000
MRR: 0.9565
MAP: 0.9542
Hits@1: 0.3447
Precision@1: 0.3447
Recall@1: 0.7717
NDCG@1: 0.9170
Hits@3: 0.3753
Precision@3: 0.1678
Recall@3: 0.9725
NDCG@3: 0.9642
Hits@10: 0.3759
Precision@10: 0.0561
Recall@10: 0.9990
NDCG@10: 0.9669


In [32]:
LR = 0.001
EPOCHS = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BaselineNoGraphModel(hidden_channels=256, data=data)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

model = model.to(device)

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):

        optimizer.zero_grad()
        sampled_data.to(device)

        y_pred = model(sampled_data)
        y_true = sampled_data["author", "writes", "paper"].edge_label

        loss = F.binary_cross_entropy_with_logits(y_pred, y_true)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * y_pred.numel()
        total_examples += y_pred.numel()

    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

100%|██████████| 601/601 [00:26<00:00, 22.86it/s]


Epoch: 000, Loss: 0.5402


100%|██████████| 601/601 [00:25<00:00, 23.55it/s]


Epoch: 001, Loss: 0.4958


100%|██████████| 601/601 [00:25<00:00, 23.95it/s]


Epoch: 002, Loss: 0.4853


100%|██████████| 601/601 [00:25<00:00, 23.60it/s]


Epoch: 003, Loss: 0.4835


100%|██████████| 601/601 [00:26<00:00, 22.83it/s]


Epoch: 004, Loss: 0.4795


100%|██████████| 601/601 [00:25<00:00, 23.26it/s]


Epoch: 005, Loss: 0.4775


100%|██████████| 601/601 [00:25<00:00, 23.91it/s]


Epoch: 006, Loss: 0.4764


100%|██████████| 601/601 [00:24<00:00, 24.08it/s]


Epoch: 007, Loss: 0.4744


100%|██████████| 601/601 [00:25<00:00, 23.91it/s]


Epoch: 008, Loss: 0.4729


100%|██████████| 601/601 [00:25<00:00, 23.65it/s]


Epoch: 009, Loss: 0.4711


100%|██████████| 601/601 [00:25<00:00, 23.69it/s]


Epoch: 010, Loss: 0.4707


100%|██████████| 601/601 [00:25<00:00, 23.71it/s]


Epoch: 011, Loss: 0.4698


100%|██████████| 601/601 [00:25<00:00, 23.28it/s]


Epoch: 012, Loss: 0.4674


100%|██████████| 601/601 [00:25<00:00, 23.59it/s]


Epoch: 013, Loss: 0.4674


100%|██████████| 601/601 [00:25<00:00, 23.70it/s]


Epoch: 014, Loss: 0.4659


100%|██████████| 601/601 [00:25<00:00, 23.64it/s]


Epoch: 015, Loss: 0.4649


100%|██████████| 601/601 [00:25<00:00, 23.72it/s]


Epoch: 016, Loss: 0.4655


100%|██████████| 601/601 [00:25<00:00, 23.69it/s]


Epoch: 017, Loss: 0.4632


100%|██████████| 601/601 [00:25<00:00, 23.67it/s]


Epoch: 018, Loss: 0.4629


100%|██████████| 601/601 [00:25<00:00, 23.96it/s]

Epoch: 019, Loss: 0.4622





In [33]:
import numpy as np

# Example usage:
test_data = test_data.to(device)
metrics = evaluate_ranking_metrics(model, test_data,
    edge_type=("author","writes","paper"), ks=(1,3,10), device=device)
print("Ranking metrics on Test:")
for k, v in metrics.items():
    print(f"{k}: {v:.4f}")


Ranking metrics on Test:
num_heads: 56505.0000
MRR: 0.9511
MAP: 0.9482
Hits@1: 0.3409
Precision@1: 0.3409
Recall@1: 0.7620
NDCG@1: 0.9071
Hits@3: 0.3752
Precision@3: 0.1675
Recall@3: 0.9713
NDCG@3: 0.9593
Hits@10: 0.3759
Precision@10: 0.0561
Recall@10: 0.9990
NDCG@10: 0.9626


In [34]:

test_data.to(device)
precision, recall, f1_score, accuracy = evaluate_model(model, test_data)
print("Evaluating on Test set...")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print("--------------------------------------------------")
val_data.to(device)
precision, recall, f1_score, accuracy = evaluate_model(model, val_data)
print("Evaluating on validation set...")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")

Evaluating on Test set...
Precision: 0.7739
Recall: 0.3906
F1 Score: 0.5191
Accuracy: 0.7588
--------------------------------------------------
Evaluating on validation set...
Precision: 0.7860
Recall: 0.3873
F1 Score: 0.5190
Accuracy: 0.7606
