# Graph Neural Networks for Link Prediction

Refer to this [blog post](https://medium.com/@tanishjain/224w-final-project-46c1054f2aa4) for more details!

In [None]:
# Wipe conflicting installs (Torch, PyG, and PyG's optional compiled extensions)
!pip -q uninstall -y torch torchvision torchaudio torch_geometric torch_scatter torch_sparse torch_cluster torch_spline_conv

# Install PyTorch built for CUDA 12.1 (fits Colab GPU).
# NOTE: this pins Torch 2.4.0, although the training cell's docstring talks about
# Torch 2.6; the logged run further below used these 2.4.0 wheels.
!pip install -q --index-url https://download.pytorch.org/whl/cu121 \
  torch==2.4.0+cu121 torchvision==0.19.0+cu121 torchaudio==2.4.0+cu121

# Install PyG 2.5.3. The -f page indexes wheels built for this exact Torch/CUDA
# combination, but only torch_geometric itself is requested here: the compiled
# extensions uninstalled above (torch_scatter, torch_sparse, ...) are NOT reinstalled.
!pip install -q torch_geometric==2.5.3 \
  -f https://data.pyg.org/whl/torch-2.4.0+cu121.html

# Extra packages. deepsnap and PyDrive are not imported by the training cell;
# ogb IS required (it provides the ogbl-ddi dataset and its Evaluator).
!pip install -q git+https://github.com/snap-stanford/deepsnap.git
!pip install -q PyDrive
!pip install ogb

[0m  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
# -*- coding: utf-8 -*-
"""
OGB-ddi Link Prediction with Enh4SAGEConv (Torch 2.6 compatible)
================================================================

This single-file script integrates the following fixes and improvements:

Core fixes (most impactful):
  1) NEGATIVE SAMPLING WITHOUT LEAKAGE:
     - Build two graphs:
         - edge_index_train : train-only (undirected, deduped) for message passing
         - edge_index_all   : train ∪ valid ∪ test (undirected, deduped) for negative sampling
     - This avoids sampling true edges (from valid/test) as negatives.

  2) EVALUATION WITH FLAT OR GROUPED NEGATIVES:
     - Scoring handles both the flat [M, 2] and the grouped [G, R, 2] shape of
       OGB's `edge_neg`.

Training tweaks:
  - Linear LR warmup, then cosine decay (no warm restarts)
  - Slightly lower dropout & edge_drop for DDI
  - Fewer negatives per positive (neg_ratio=2)
  - Node embeddings frozen during the warmup phase (warmup_epochs=10)
  - Early stopping on validation Hits@20 (patience=40)
  - AMP + gradient clipping + AdamW
  - Proximal regularizer pulling the embeddings toward their initial values

Preserves:
  - Google Drive artifact/result paths & file names
  - 512-dim external node embeddings
  - Enh4SAGEConv + Enh4SAGEStack with residuals, LayerNorm, DropEdge, JK-Max
  - Plots of the training loss and Hits@20 curves

Version notes:
  Originally tested with Torch 2.6.0+cu124, PyG 2.6.1, ogb 1.3.x in Colab
  (Aug 2025). The install cell pins Torch 2.4.0+cu121 / PyG 2.5.3 instead, and
  those are the versions reported by the logged run. The torch.load shim below
  targets Torch >= 2.6, where `weights_only` defaults to True.

Out-of-memory mitigations:
  1) PYTORCH_CUDA_ALLOC_CONF can be set at the very top of the script. That
     line is currently commented out; uncomment it to enable it.
  2) Before each evaluation:
       if torch.cuda.is_available():
           torch.cuda.empty_cache()
           torch.cuda.reset_peak_memory_stats()
"""

import os
# os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:256")

import sys, math, random, json
from pathlib import Path

# ----------------------------- Minimal deps ----------------------------------
def ensure_pkg(pkg: str, pip_name: "str | None" = None) -> None:
    """Import `pkg`; if it is missing, pip-install it into this interpreter.

    Args:
        pkg: Importable module name (e.g. "ogb", "networkx", "pydrive").
        pip_name: Requirement passed to pip when it differs from `pkg`
            (e.g. "networkx>=3.0", "PyDrive"). Defaults to `pkg`.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
        Any non-ImportError raised while importing an installed package is
        propagated instead of being masked by a (pointless) reinstall.
    """
    import importlib

    try:
        importlib.import_module(pkg)
    except ImportError:  # also covers ModuleNotFoundError
        import subprocess

        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pip_name or pkg])
        # Import finders cache directory listings; without this, a package
        # installed mid-session may still be reported missing by later imports.
        importlib.invalidate_caches()

# Runtime safety net for non-Torch packages used by this notebook. Torch and
# torch_geometric are deliberately not handled here: the install cell pins them.
for _module_name, _pip_spec in (
    ("ogb", None),
    ("networkx", "networkx>=3.0"),
    ("pydrive", "PyDrive"),
):
    ensure_pkg(_module_name, _pip_spec)

import torch
# ---- Torch >= 2.6 fix: torch.load defaults to weights_only=True, which breaks
#      loading OGB/PyG processed files. Force weights_only=False for ALL
#      torch.load calls unless the caller sets it explicitly.
# SECURITY: weights_only=False unpickles arbitrary objects, i.e. a file can run
# code when loaded. Only load trusted files (the OGB download, your own Drive
# artifacts).
def _install_torch_load_compat():
    """Monkey-patch `torch.load` exactly once so `weights_only` defaults to False."""
    import functools
    import inspect

    current = torch.load
    if getattr(current, "_weights_only_compat", False):
        return  # already patched, e.g. because this notebook cell was re-run
    try:
        # inspect.signature follows __wrapped__, unlike __code__.co_varnames,
        # so a decorated torch.load is still detected correctly.
        accepts_flag = "weights_only" in inspect.signature(current).parameters
    except (TypeError, ValueError):  # signature not introspectable
        accepts_flag = False
    if not accepts_flag:
        return  # older Torch without the flag: nothing to override

    @functools.wraps(current)
    def _load_compat(*args, **kwargs):
        kwargs.setdefault("weights_only", False)
        return current(*args, **kwargs)

    _load_compat._weights_only_compat = True
    torch.load = _load_compat


_install_torch_load_compat()  # monkey-patch early, before any dataset/embedding load

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import networkx as nx
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

import torch_geometric as pyg
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import negative_sampling, to_networkx
from torch_geometric.data import Data
from torch.utils.data import DataLoader  # safer generic loader

# ------------------------------ Colab Drive ----------------------------------
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

# ------------------------------ Paths ----------------------------------------
# os.path.join copes with the trailing "/" on BASE_DIR, so derived paths no
# longer contain a doubled separator (".../FINAL-CODE//artifacts").
BASE_DIR = "/content/drive/MyDrive/CS145/neurips/FINAL-CODE/"
ART_EMB_DIR = os.path.join(BASE_DIR, "artifacts")      # input: precomputed node embeddings
ART_DIR = os.path.join(BASE_DIR, "artifacts_seed7")    # output: SPD, edge_attr, best checkpoints
RES_DIR = os.path.join(BASE_DIR, "results_seed7")      # output: plots
os.makedirs(ART_DIR, exist_ok=True)
os.makedirs(RES_DIR, exist_ok=True)

EMB_PATH = os.path.join(ART_EMB_DIR, "projected_embeddings_512.pt")  # [N,512]
SPD_PATH = os.path.join(ART_DIR, "shortest_paths.pt")
EA_PATH = os.path.join(ART_DIR, "edge_attr.pt")

# ------------------------------ Device/seed ----------------------------------
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print("Device:", device)
print("Torch:", torch.__version__, "CUDA:", torch.version.cuda if _has_cuda else "CPU")
print("PyG:", pyg.__version__)

def set_all_seeds(seed=7):
    """Seed every RNG this script relies on and make cuDNN deterministic.

    Seeds Torch (CPU and all CUDA devices), NumPy's global RNG and Python's
    `random`, in that order, then sets cuDNN to deterministic mode with
    benchmarking disabled.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_all_seeds(7)
# Opt out of TorchDynamo / torch.compile unless the environment already chose.
# NOTE(review): torch is already imported at this point; confirm the installed
# version still honours these variables when they are set this late.
for _flag in ("PYTORCH_DISABLE_DYNAMO", "TORCH_COMPILE_DISABLE"):
    os.environ.setdefault(_flag, "1")

# ------------------------------ Dataset --------------------------------------
dataset = PygLinkPropPredDataset(name="ogbl-ddi", root='./dataset/')
split_edge = dataset.get_edge_split()   # per-split 'edge' / 'edge_neg' tensors
evaluator = Evaluator(name='ogbl-ddi')
data_obj = dataset[0]

num_nodes = data_obj.num_nodes
# Full graph as shipped by OGB. It is only reported here; training builds its
# own graphs from split_edge below.
edge_index_raw = data_obj.edge_index.to(device)
print(f"Loaded raw graph: {num_nodes} nodes, {edge_index_raw.size(1)} edges")

# ---------------------------- Build Train/All graphs -------------------------
def to_undirected_coalesce(ei: torch.Tensor) -> torch.Tensor:
    """Symmetrise a [2, E] edge index and drop duplicate edges.

    Each edge (u, v) is paired with its reverse (v, u), then duplicates are
    removed. `unique` sorts, so the result is ordered lexicographically by
    (source, target) and returned as a contiguous [2, E'] tensor.
    """
    both_directions = torch.cat((ei, ei.flip(0)), dim=1)  # [2, 2E]
    unique_pairs = both_directions.t().unique(dim=0)      # [E', 2], sorted rows
    return unique_pairs.t().contiguous()

# OGB split edges are stored row-wise (one pair per row); transpose to [2, E].
train_ei, valid_ei, test_ei = (
    split_edge[split]['edge'].to(device).t().contiguous()
    for split in ('train', 'valid', 'test')
)

# Message passing only ever sees TRAIN edges; the union of all splits is used
# solely to keep known positives out of negative sampling.
edge_index_train = to_undirected_coalesce(train_ei)
edge_index_all = to_undirected_coalesce(torch.cat([train_ei, valid_ei, test_ei], dim=1))

print(f"edge_index_train: {edge_index_train.size(1)} undirected edges")
print(f"edge_index_all  : {edge_index_all.size(1)} undirected edges (for neg sampling only)")

# ---------------------- Shortest paths & edge attributes ---------------------
def get_spd_matrix(G: nx.Graph, anchors, max_spd=5):
    """Capped shortest-path (hop) distance from every node to every anchor.

    Assumes the nodes of `G` are labelled 0..N-1 (row i of the result belongs
    to node i), as for graphs produced by `to_networkx`.

    Args:
        G: Graph to search.
        anchors: Iterable of anchor node ids.
        max_spd: Distance cap. Distances above it, and nodes that are
            unreachable from an anchor, are reported as `max_spd`.

    Returns:
        float32 array of shape [N, len(anchors)].
    """
    # Start every entry at the cap: nodes the BFS never reaches (other
    # components, or farther than `max_spd` hops) keep it. The former zero fill
    # made unreachable nodes indistinguishable from the anchor itself (distance 0).
    spd = np.full((G.number_of_nodes(), len(anchors)), max_spd, dtype=np.float32)
    for i, a in enumerate(anchors):
        # cutoff bounds the BFS depth; anything deeper would be capped anyway.
        lengths = nx.single_source_shortest_path_length(G, int(a), cutoff=max_spd)
        for node, dist in lengths.items():
            spd[int(node), i] = dist
    return spd

def compute_anchor_distances(num_nodes, edge_index, num_anchors=500, max_path_length=5, device='cpu', seed=7):
    """
    Compute capped SPD from every node to random anchors on the TRAIN graph
    (pass train edges only, to avoid leakage).

    Args:
        num_nodes: Number of nodes N.
        edge_index: [2, E] edge index of the graph to measure distances on.
        num_anchors: Number of anchors A (clipped to N).
        max_path_length: Distance cap forwarded to `get_spd_matrix`.
        device: Device of the returned tensor.
        seed: Seed for anchor selection. A private `RandomState(seed)` yields
            the same anchors as the former `np.random.seed(seed)` +
            `np.random.choice(...)`, but leaves NumPy's global RNG untouched.

    Returns:
        float32 tensor of shape [N, A].
    """
    # Build a minimal PyG Data for the train graph
    d = Data(num_nodes=num_nodes, edge_index=edge_index)
    G = to_networkx(d, to_undirected=True)
    rng = np.random.RandomState(seed)
    anchors = rng.choice(G.number_of_nodes(), size=min(num_anchors, num_nodes), replace=False)
    spd = get_spd_matrix(G, anchors, max_spd=max_path_length)
    return torch.tensor(spd, dtype=torch.float32, device=device)  # [N, A]

def prepare_edge_attributes(shortest_paths_to_anchors, edge_index, num_samples=5):
    """
    Build per-edge attributes from anchor distances (built on TRAIN edges):
      - for each of S samples, pick min(200, A) anchors (fixed seed 42)
      - feature = mean over both endpoints and over the sampled anchors -> [E, S]
      - per-column min-max normalize to [0, 1]

    Both averages are linear, so averaging each NODE over the sampled anchors
    first ([N, S]) and then averaging the two endpoints gives the same values
    as averaging per edge first, up to float rounding. It avoids materialising
    the [2, E, A], [E, A] and [E, S, 200] intermediates (~20 GB for ogbl-ddi's
    ~2.1M training edges with A=500).

    Args:
        shortest_paths_to_anchors: [N, A] float tensor.
        edge_index: [2, E] long tensor.
        num_samples: Number of anchor subsets S.

    Returns:
        [E, S] float tensor on the device of `shortest_paths_to_anchors`.
    """
    A = shortest_paths_to_anchors.size(1)

    rng = np.random.default_rng(42)
    pick = min(200, A)
    masks = np.stack([rng.choice(A, size=pick, replace=False) for _ in range(num_samples)], axis=0)  # [S, pick]
    masks_t = torch.tensor(masks, device=shortest_paths_to_anchors.device, dtype=torch.long)

    node_feat = shortest_paths_to_anchors[:, masks_t].mean(dim=2)  # [N, S]
    ea = node_feat[edge_index].mean(dim=0)                         # [E, S]

    a_max = ea.max(dim=0, keepdim=True).values
    a_min = ea.min(dim=0, keepdim=True).values
    return (ea - a_min) / (a_max - a_min + 1e-6)

# Always recompute on the TRAIN graph (no leakage) and overwrite any previous files.
print("Computing shortest_paths and edge_attr on TRAIN graph (overwriting any existing files to avoid leakage)...")
shortest_paths = compute_anchor_distances(
    num_nodes, edge_index_train, num_anchors=500, max_path_length=5, device=device
)
edge_attr_full = prepare_edge_attributes(shortest_paths, edge_index_train, num_samples=5)

torch.save(shortest_paths, SPD_PATH)
torch.save(edge_attr_full, EA_PATH)
print(f"Saved shortest_paths -> {SPD_PATH}", f"Saved edge_attr     -> {EA_PATH}", sep="\n")

print("edge_attr shape:", tuple(edge_attr_full.shape))  # [E_train, S]
assert edge_attr_full.dim() == 2  # one row per train edge column, one column per sample

# ---------------------------- 512-dim embeddings -----------------------------
Z = torch.load(EMB_PATH, map_location=device).float()
# Explicit check instead of `assert`, so it still runs under `python -O`.
if Z.ndim != 2 or Z.shape[1] != 512 or Z.shape[0] != num_nodes:
    raise ValueError(f"Expected [{num_nodes},512], got {Z.shape}")

# Trainable node features, initialised from the external 512-dim embeddings.
emb = nn.Embedding.from_pretrained(Z, freeze=False).to(device)
# Snapshot of the initial features, used as the prior of the proximal regulariser.
# NOTE(review): E0 is built with nn.Embedding (random init, then overwritten)
# rather than from_pretrained; presumably this keeps the Torch RNG stream seen
# by later weight initialisations unchanged, so it is kept as is.
E0 = nn.Embedding(num_nodes, Z.size(1)).to(device)
with torch.no_grad():
    E0.weight.copy_(emb.weight)
E0.weight.requires_grad_(False)  # the snapshot is never trained

# ----------------------------- Enh4SAGEConv ----------------------------------
from typing import Union, Tuple
from torch import Tensor
from torch.nn import Linear
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size

class Enh4SAGEConv(MessagePassing):
    r"""GraphSAGE-style convolution whose messages are conditioned on edge features.

    For each edge j -> i the message is ``relu(x_j + lin_edge(e_ji))``, or
    ``relu(x_j)`` when no edge features are given. Messages are aggregated with
    ``aggr`` (default "mean") and the update is::

        out_i = lin_l(aggr_j msg_ji) + lin_r(x_i)    # lin_r only if root_weight

    optionally followed by row-wise L2 normalisation.

    Args:
        in_channels: Input size, or a (source, target) pair of sizes.
        out_channels: Output size.
        edge_attr_dim: Size of each edge-feature vector. It is projected to the
            source input size so it can be added to ``x_j``.
        normalize: L2-normalise output rows.
        root_weight: Add a bias-free linear transform of the target features.
        bias: Whether ``lin_l`` has a bias.
        **kwargs: Forwarded to ``MessagePassing``.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, edge_attr_dim: int, normalize: bool = False,
                 root_weight: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'mean')
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight
        self.edge_attr_dim = edge_attr_dim

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        # Projects edge features into the source feature space (added to x_j).
        self.lin_edge = Linear(edge_attr_dim, in_channels[0], bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()
        self.lin_edge.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        if isinstance(x, Tensor):
            x = (x, x)  # the same features act as source and target
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        out = self.lin_l(out)
        if self.root_weight and x[1] is not None:
            out = out + self.lin_r(x[1])
        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)
        return out

    def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
        # forward() declares edge_attr optional; previously None crashed inside
        # lin_edge. Without edge features the message is just relu(x_j).
        if edge_attr is None:
            return F.relu(x_j)
        return F.relu(x_j + self.lin_edge(edge_attr))

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels})'

# -------------------------- Enhanced SAGE Stack ------------------------------
class Enh4SAGEStack(nn.Module):
    """
    Residuals + LayerNorm + DropEdge + JK-Max around Enh4SAGEConv.

    Layer layout: conv(in -> hidden), (num_layers - 2) x conv(hidden -> hidden),
    conv(hidden -> hidden); each conv has its own LayerNorm. The final hidden
    representation is mapped to `out_channels` by `proj` (Identity when
    hidden_channels == out_channels, otherwise a Linear).

    Args:
        in_channels: Input node-feature size.
        hidden_channels: Width of every conv layer.
        out_channels: Size of the returned node embeddings.
        num_layers: Number of Enh4SAGEConv layers (must be >= 2).
        dropout: Dropout probability applied after each layer.
        edge_attr_dim: Number of columns of `edge_attr`.
        edge_drop: While training, per-layer probability of dropping each edge
            column (DropEdge); 0 disables it.
        jk: "max" = element-wise max over all layer outputs; any other value
            uses the last layer's output.
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers, dropout, edge_attr_dim, edge_drop=0.05, jk="max"):
        super().__init__()
        assert num_layers >= 2
        self.edge_drop = edge_drop
        self.dropout   = dropout
        self.jk        = jk

        # NOTE: the construction order below fixes the state_dict keys and the
        # order in which parameter initializers run; keep it unchanged so saved
        # checkpoints and seeded runs stay comparable.
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.convs.append(Enh4SAGEConv(in_channels, hidden_channels, edge_attr_dim))
        self.norms.append(nn.LayerNorm(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(Enh4SAGEConv(hidden_channels, hidden_channels, edge_attr_dim))
            self.norms.append(nn.LayerNorm(hidden_channels))
        self.convs.append(Enh4SAGEConv(hidden_channels, hidden_channels, edge_attr_dim))
        self.norms.append(nn.LayerNorm(hidden_channels))

        self.proj = nn.Identity() if hidden_channels == out_channels else nn.Linear(hidden_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise convs and norms; a Linear `proj` gets Xavier-uniform weights and zero bias."""
        for c in self.convs:
            c.reset_parameters()
        for n in self.norms:
            if hasattr(n, "reset_parameters"):
                n.reset_parameters()
        if isinstance(self.proj, nn.Linear):
            nn.init.xavier_uniform_(self.proj.weight); nn.init.zeros_(self.proj.bias)

    def forward(self, x, edge_index, edge_attr, *_):
        """Return node embeddings (one row per row of `x`, `out_channels` columns).

        `edge_attr` rows must line up with `edge_index` columns (they are masked
        together). Extra positional arguments are ignored; callers pass a
        trailing None.
        """
        xs = []
        E = edge_index.size(1)
        for i, conv in enumerate(self.convs):
            if self.training and self.edge_drop > 0:
                # DropEdge with a fresh mask per layer. Each column (directed
                # edge) is dropped independently, so for a symmetric edge_index
                # the two directions of an edge are not dropped together.
                mask = torch.rand(E, device=edge_index.device) > self.edge_drop
                ei = edge_index[:, mask]
                ea = edge_attr[mask]
            else:
                ei, ea = edge_index, edge_attr
            h = conv(x, ei, ea)
            h = self.norms[i](h)
            h = F.relu(h)
            # Residual connection whenever widths match: always true after the
            # first layer, and for the first layer too when in == hidden.
            if h.shape == x.shape:
                h = h + x
            h = F.dropout(h, p=self.dropout, training=self.training)
            x = h
            xs.append(h)
        # Jumping knowledge: element-wise max across all layer outputs.
        if self.jk == "max":
            h_out = torch.stack(xs, dim=0).max(dim=0).values
        else:
            h_out = xs[-1]
        return self.proj(h_out)

# ------------------------------ Predictor ------------------------------------
class LinkPredictor(nn.Module):
    """MLP edge scorer that returns one logit per node pair.

    A pair (hi, hj) is featurised as [hi*hj, |hi-hj|, hi, hj] (4 * in_channels
    features) and passed through `num_layers` Linear layers. Every layer except
    the last is followed by ReLU + Dropout; the last one outputs a single logit.
    """
    def __init__(self, in_channels, hidden_channels=512, num_layers=3, dropout=0.2):
        super().__init__()
        widths = [4 * in_channels] + [hidden_channels] * (num_layers - 1)
        blocks = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            blocks.extend((nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(dropout)))
        blocks.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, hi, hj):
        pair_feats = torch.cat([hi * hj, (hi - hj).abs(), hi, hj], dim=-1)
        return self.mlp(pair_feats).view(-1)  # logits, flattened to one per pair

# ------------------------- Eval (Hits@K via OGB) -----------------------------
@torch.no_grad()
def evaluate(model, predictor, edge_index_msg, edge_attr_msg, x, batch_size, split_edge, evaluator):
    """Compute OGB Hits@{20, 50, 100} on the valid and test splits.

    Node embeddings come from one full-graph forward pass over the
    message-passing graph; positive and negative pairs are then scored in
    mini-batches.

    Args:
        model: GNN called as model(x, edge_index_msg, edge_attr_msg, None).
        predictor: Pair scorer returning logits; sigmoid is applied here.
        edge_index_msg, edge_attr_msg: Graph used for message passing.
        x: Node features. AMP is enabled when they live on CUDA.
        batch_size: Number of pairs scored per predictor call.
        split_edge: OGB split dict whose 'valid'/'test' entries hold 'edge'
            ([M, 2]) and 'edge_neg' ([M, 2] or grouped [G, R, 2]).
        evaluator: OGB Evaluator; `evaluator.K` is set per cutoff and the
            caller's value is restored before returning.

    Returns:
        {"Hits@K": (valid_hits, test_hits)} for K in 20, 50, 100.
    """
    model.eval(); predictor.eval()

    # AMP only for CUDA features. autocast is keyed on x's device type so CPU
    # runs do not request a CUDA autocast context.
    use_amp = x.is_cuda
    amp_device = x.device.type

    with torch.amp.autocast(amp_device, enabled=use_amp):
        h = model(x, edge_index_msg, edge_attr_msg, None)

    def score_pairs(edge_pairs):
        """[M, 2] node pairs -> [M] sigmoid scores on CPU, batch_size pairs at a time."""
        out = []
        for idx in torch.arange(edge_pairs.size(0), device=edge_pairs.device).split(batch_size):
            e = edge_pairs[idx].t()
            with torch.amp.autocast(amp_device, enabled=use_amp):
                out.append(torch.sigmoid(predictor(h[e[0]], h[e[1]])).cpu())
        return torch.cat(out, dim=0)

    def score_negatives(neg):
        """Score flat [M, 2] or grouped [G, R, 2] negatives as ONE 1-D pool.

        OGB's Hits@K evaluator rejects a y_pred_neg that is not 1-D (every
        positive is ranked against a single pool of negatives), so grouped
        negatives are flattened instead of being returned as [G, R].
        """
        return score_pairs(neg.reshape(-1, 2))

    preds = {}
    for name in ("valid", "test"):
        pos = split_edge[name]['edge'].to(h.device)
        neg = split_edge[name]['edge_neg'].to(h.device)
        preds[name] = (score_pairs(pos), score_negatives(neg))

    results = {}
    saved_k = getattr(evaluator, "K", None)
    try:
        for K in (20, 50, 100):
            evaluator.K = K  # the Hits cutoff is selected through evaluator.K
            results[f'Hits@{K}'] = tuple(
                evaluator.eval({'y_pred_pos': pos_pred, 'y_pred_neg': neg_pred})[f'hits@{K}']
                for pos_pred, neg_pred in (preds["valid"], preds["test"])
            )
    finally:
        if saved_k is not None:
            evaluator.K = saved_k
    return results

# --------------------------- Training primitives -----------------------------
def bpr_loss(pos_logits: torch.Tensor, neg_logits: torch.Tensor):
    """
    Pairwise (BPR-style) ranking loss: mean of softplus(-(pos - neg)).

    With B positives and N negatives, each positive is paired with
    R = max(1, N // B) consecutive negatives. Surplus negatives are dropped;
    when N < B the last negative is repeated to fill the [B, R] grid.
    Returns a scalar 0.0 when either input is empty.
    """
    n_pos, n_neg = pos_logits.numel(), neg_logits.numel()
    if not (n_pos and n_neg):
        return torch.tensor(0.0, device=pos_logits.device)

    per_pos = max(1, n_neg // n_pos)
    total = n_pos * per_pos

    if n_neg < total:
        filler = neg_logits[-1:].expand(total - n_neg)
        neg_logits = torch.cat([neg_logits, filler], dim=0)
    else:
        neg_logits = neg_logits[:total]

    margin = pos_logits.view(n_pos, 1) - neg_logits.view(n_pos, per_pos)  # [B, R]
    return F.softplus(-margin).mean()

def maybe_negative_sampling(edge_index, num_nodes, num_neg):
    """Sample up to `num_neg` node pairs that are not edges of `edge_index`.

    Thin wrapper around PyG `negative_sampling`. It tries the 'sparse' method
    first and, if that raises anything, retries with 'dense' (a compatibility
    fallback). Returns a [2, M] tensor.
    NOTE(review): M may be smaller than `num_neg`; confirm for the installed PyG.
    """
    common = dict(edge_index=edge_index, num_nodes=num_nodes, num_neg_samples=num_neg)
    try:
        return negative_sampling(method='sparse', **common)
    except Exception:  # NOTE(review): deliberately broad; it also hides real errors such as OOM
        return negative_sampling(method='dense', **common)

def train_one_epoch(model, predictor, x, E0, edge_index_msg, edge_attr_msg, edge_index_for_negs,
                    pos_train_edges, optimizer, scaler, batch_size,
                    neg_ratio=2, lam=5e-4, hard_frac=0.5, hard_mul=1.5, use_amp=True):
    """Run one epoch over the positive training edges and return the mean loss.

    Per mini-batch of positive edges:
      1. full-graph GNN forward on the message-passing graph -> embeddings h;
      2. score the positives plus `neg_ratio` random negatives per positive,
         sampled against `edge_index_for_negs` (the caller passes all known
         positives, so valid/test edges are not used as negatives);
      3. loss = BPR ranking loss (hardest negatives duplicated)
                + 0.15 * BCE + proximal term pulling the touched rows of `x`
                toward their snapshot in `E0`;
      4. backward (through `scaler` when AMP is active), clip model and
         predictor gradient norms to 1.0, then step the optimizer.

    The device is taken from `x` (not a module-level global), and the
    non-deprecated `torch.amp.autocast` API is used.

    Returns:
        Average loss per positive edge over the epoch.
    """
    model.train(); predictor.train()
    total_loss, total_examples = 0.0, 0
    bce = nn.BCEWithLogitsLoss()

    dev = x.device
    amp_on = use_amp and dev.type == "cuda"

    n = pos_train_edges.size(0)
    for perm in DataLoader(range(n), batch_size=batch_size, shuffle=True):
        idx = torch.as_tensor(perm, dtype=torch.long, device=dev)
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(dev.type, enabled=amp_on):
            # Full-graph forward every step so h reflects the current weights.
            h = model(x, edge_index_msg, edge_attr_msg, None)

            pos_edge = pos_train_edges[idx].t()     # [2, B]
            pos_logits = predictor(h[pos_edge[0]], h[pos_edge[1]])  # [B]

            neg_samples = maybe_negative_sampling(
                edge_index=edge_index_for_negs,
                num_nodes=x.size(0),
                num_neg=pos_logits.numel() * neg_ratio
            )
            neg_logits = predictor(h[neg_samples[0]], h[neg_samples[1]])  # [B*R]

            # Hard-negative mixing: append the top `hard_frac` fraction of the
            # highest-scoring negatives ceil(hard_mul) more times, so the
            # ranking loss weights them more.
            if hard_frac > 0.0 and neg_logits.numel() > 0:
                k = max(1, int(hard_frac * neg_logits.numel()))
                hard_vals, _ = torch.topk(neg_logits, k=k, largest=True, sorted=False)
                extra = hard_vals.repeat_interleave(int(math.ceil(hard_mul)))
                neg_logits_eff = torch.cat([neg_logits, extra], dim=0)
            else:
                neg_logits_eff = neg_logits

            loss_rank = bpr_loss(pos_logits, neg_logits_eff)

            # Auxiliary BCE on the original (non-duplicated) negatives.
            labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)], dim=0)
            logits = torch.cat([pos_logits, neg_logits], dim=0)
            loss_bce = bce(logits, labels)

            # Proximal term on the rows of x touched by this batch; E0 may be an
            # embedding module or a plain tensor.
            touched = torch.unique(torch.cat([pos_edge.reshape(-1), neg_samples.reshape(-1)], dim=0))
            prior = (E0(touched) if callable(E0) else E0[touched]).detach()
            prox = lam * (x[touched] - prior).pow(2).mean()

            loss = loss_rank + 0.15 * loss_bce + prox

        if scaler is not None and amp_on:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
            scaler.step(optimizer); scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
            optimizer.step()

        total_loss += loss.item() * pos_logits.size(0)
        total_examples += pos_logits.size(0)

    return total_loss / max(1, total_examples)

# ------------------------------- Hyperparams ---------------------------------
H = dict(
    epochs=500,
    hidden_dim=512,
    dropout=0.2,            # slightly lower for DDI
    num_layers=3,
    lr_main=1e-3,           # GNN + predictor
    lr_emb=5e-4,            # node-embedding table
    weight_decay=0.01,      # AdamW
    batch_size=64 * 1024,   # positive edges per training step (also eval scoring batch)
    neg_ratio=2,            # fewer, cleaner negatives
    lam_prox=5e-4,          # proximal pull toward the initial embeddings
    edge_drop=0.05,         # gentler DropEdge
    use_amp=True,
    warmup_epochs=10,       # linear LR warmup; embeddings frozen meanwhile
    patience=40,            # early stopping by Val@20
)

# ----------------------------- Build + Optimizer -----------------------------
edge_attr_full = edge_attr_full.to(device)

model = Enh4SAGEStack(
    in_channels=emb.embedding_dim,
    hidden_channels=H["hidden_dim"],
    out_channels=H["hidden_dim"],
    num_layers=H["num_layers"],
    dropout=H["dropout"],
    edge_attr_dim=edge_attr_full.size(1),
    edge_drop=H["edge_drop"],
    jk="max"
).to(device)

predictor = LinkPredictor(
    in_channels=H["hidden_dim"],
    hidden_channels=H["hidden_dim"],
    num_layers=3,
    dropout=H["dropout"]
).to(device)

def count_params(m): return sum(p.numel() for p in m.parameters())
print("Parameters:")
print("  GNN       :", count_params(model))
print("  Predictor :", count_params(predictor))
print("  Embedding :", count_params(emb))

# GNN and predictor use lr_main; the node-embedding table has its own lr_emb.
param_groups = [
    {"params": model.parameters()},
    {"params": predictor.parameters()},
    {"params": emb.parameters(), "lr": H["lr_emb"]},
]
optimizer = torch.optim.AdamW(param_groups, lr=H["lr_main"], weight_decay=H["weight_decay"])

def lr_lambda_warmup(epoch):
    """Linear warmup multiplier: (epoch + 1) / warmup_epochs during warmup, then 1.0."""
    if epoch < H["warmup_epochs"]:
        return float(epoch + 1) / float(max(1, H["warmup_epochs"]))
    return 1.0

warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_warmup)
# Cosine decay (no warm restarts) over the post-warmup epochs. Both schedulers
# share `optimizer`; the training loop steps exactly one of them per epoch.
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, H["epochs"] - H["warmup_epochs"])
)

# torch.amp.GradScaler("cuda", ...) replaces the deprecated
# torch.cuda.amp.GradScaler, which emits a FutureWarning.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and H["use_amp"]))

# --------------------------------- Train ------------------------------------
pos_train_edges = split_edge['train']['edge'].to(device)  # one positive pair per row

best = {"val": 0.0, "test": 0.0, "epoch": -1}        # best by validation Hits@20 (model selection)
best_test = {"val": 0.0, "test": 0.0, "epoch": -1}   # best by test Hits@20 (reference only)
train_loss_hist, val_hist, test_hist = [], [], []
epochs_no_improve = 0

# Optional: quick sanity check that the false-negative rate is ~0 when negatives
# are sampled against ALL positives (train ∪ valid ∪ test).
with torch.no_grad():
    neg_tmp = maybe_negative_sampling(edge_index_all, num_nodes, 200000)  # [2, M]
    # Encode each ordered pair (a, b) as a * N + b and test membership with one
    # vectorised torch.isin call, instead of building Python sets of ~2.7M
    # tuples. Both directions are included so the check is undirected.
    pos_keys = torch.cat([
        edge_index_all[0] * num_nodes + edge_index_all[1],
        edge_index_all[1] * num_nodes + edge_index_all[0],
    ])
    neg_keys = neg_tmp[0] * num_nodes + neg_tmp[1]
    leak = torch.isin(neg_keys, pos_keys).sum().item() / max(1, neg_keys.numel())
    print(f"[Sanity] Approx false-negative rate vs ALL positives: {leak:.6f}")

# Freeze node embeddings for warmup (stabilize)
emb.requires_grad_(False)

for epoch in range(1, H["epochs"] + 1):
    x_feats = emb.weight  # trainable features (grad may be disabled initially)

    loss = train_one_epoch(
        model, predictor, x_feats, E0,
        edge_index_msg=edge_index_train, edge_attr_msg=edge_attr_full,
        edge_index_for_negs=edge_index_all,
        pos_train_edges=pos_train_edges,
        optimizer=optimizer, scaler=scaler,
        batch_size=H["batch_size"], neg_ratio=H["neg_ratio"], lam=H["lam_prox"],
        hard_frac=0.5, hard_mul=1.5, use_amp=H["use_amp"]
    )
    train_loss_hist.append(loss)

    # Unfreeze embeddings after warmup phase
    if epoch == H["warmup_epochs"] + 1:
        emb.requires_grad_(True)

    if epoch <= H["warmup_epochs"]:
        warmup.step()
    else:
        cosine.step()

    if torch.cuda.is_available():
      torch.cuda.empty_cache()
      torch.cuda.reset_peak_memory_stats()

    results = evaluate(
        model, predictor,
        edge_index_msg=edge_index_train, edge_attr_msg=edge_attr_full, x=x_feats,
        batch_size=H["batch_size"], split_edge=split_edge, evaluator=evaluator
    )
    val20, test20 = results["Hits@20"]
    val_hist.append(val20); test_hist.append(test20)

    lr_str = ", ".join([f"{pg['lr']:.6f}" for pg in optimizer.param_groups[:2]])
    print(f"Epoch {epoch:03d} | loss {loss:.5f} | lr [{lr_str}] | Val@20 {val20:.4f}  Test@20 {test20:.4f}")

    improved = val20 > best["val"]
    improved_test = test20 > best_test["test"]
    if improved:
        best.update({"val": val20, "test": test20, "epoch": epoch})
        print(f"Best: epoch {best['epoch']} | Val@20 {best['val']:.4f} | Test@20 {best['test']:.4f}")
        # Save best checkpoints
        torch.save(model.state_dict(), f"{ART_DIR}/best_model.pt")
        torch.save(predictor.state_dict(), f"{ART_DIR}/best_predictor.pt")
        torch.save(emb.state_dict(), f"{ART_DIR}/best_emb.pt")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if improved_test:
        best_test.update({"val": val20, "test": test20, "epoch": epoch})
        print(f"Best Test: epoch {best_test['epoch']} | Val@20 {best_test['val']:.4f} | Test@20 {best_test['test']:.4f}")

    if epochs_no_improve >= H["patience"]:
        print(f"Early stopping at epoch {epoch} (no improvement for {H['patience']} epochs).")
        break



print(f"\nBest: epoch {best['epoch']} | Val@20 {best['val']:.4f} | Test@20 {best['test']:.4f}")
print(f"\nBest Test: epoch {best_test['epoch']} | Val@20 {best_test['val']:.4f} | Test@20 {best_test['test']:.4f}")

# --------------------------------- Plot -------------------------------------
# History index 0 corresponds to epoch 1, so plot against the real epoch numbers.
# Loss and Hits@20 share one y-axis (labelled generically as 'Metric').
epochs_axis = range(1, len(train_loss_hist) + 1)
plt.figure(figsize=(10,6))
plt.title('Link Prediction on OGB-ddi with Enh4SAGEConv (Torch 2.6 fixed)')
plt.plot(epochs_axis, train_loss_hist, label='train loss')
plt.plot(epochs_axis, val_hist, label='Hits@20 val')
plt.plot(epochs_axis, test_hist, label='Hits@20 test')
plt.xlabel('Epoch'); plt.ylabel('Metric')
plt.grid(True); plt.legend()
plot_path = f"{RES_DIR}/{num_nodes}_Enh4Sage_T26_plot.png"
plt.savefig(plot_path); plt.close()
print(f"Plot saved to {plot_path}")


Mounted at /content/drive
Device: cuda
Torch: 2.4.0+cu121 CUDA: 12.1
PyG: 2.5.3
Loaded raw graph: 4267 nodes, 2135822 edges
edge_index_train: 2135822 undirected edges
edge_index_all  : 2669778 undirected edges (for neg sampling only)
Computing shortest_paths and edge_attr on TRAIN graph (overwriting any existing files to avoid leakage)...
Saved shortest_paths -> /content/drive/MyDrive/CS145/Week11-GCN_SciBert_Ini//artifacts_seed7/shortest_paths.pt
Saved edge_attr     -> /content/drive/MyDrive/CS145/Week11-GCN_SciBert_Ini//artifacts_seed7/edge_attr.pt
edge_attr shape: (2135822, 5)
Parameters:
  GNN       : 1585152
  Predictor : 1312257
  Embedding : 2184704


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda" and H["use_amp"]))


[Sanity] Approx false-negative rate vs ALL positives: 0.000000


  with torch.cuda.amp.autocast(enabled=(use_amp and device.type=='cuda')):


Epoch 001 | loss 0.58087 | lr [0.000200, 0.000200] | Val@20 0.0516  Test@20 0.0810
Best: epoch 1 | Val@20 0.0516 | Test@20 0.0810
Best Test: epoch 1 | Val@20 0.0516 | Test@20 0.0810
Epoch 002 | loss 0.26724 | lr [0.000300, 0.000300] | Val@20 0.4038  Test@20 0.3853
Best: epoch 2 | Val@20 0.4038 | Test@20 0.3853
Best Test: epoch 2 | Val@20 0.4038 | Test@20 0.3853
Epoch 003 | loss 0.10326 | lr [0.000400, 0.000400] | Val@20 0.5010  Test@20 0.5108
Best: epoch 3 | Val@20 0.5010 | Test@20 0.5108
Best Test: epoch 3 | Val@20 0.5010 | Test@20 0.5108
Epoch 004 | loss 0.06011 | lr [0.000500, 0.000500] | Val@20 0.5624  Test@20 0.5407
Best: epoch 4 | Val@20 0.5624 | Test@20 0.5407
Best Test: epoch 4 | Val@20 0.5624 | Test@20 0.5407
Epoch 005 | loss 0.04344 | lr [0.000600, 0.000600] | Val@20 0.5489  Test@20 0.6070
Best Test: epoch 5 | Val@20 0.5489 | Test@20 0.6070
Epoch 006 | loss 0.03491 | lr [0.000700, 0.000700] | Val@20 0.5870  Test@20 0.6347
Best: epoch 6 | Val@20 0.5870 | Test@20 0.6347
Best Te