# PrimeKG
> ```bash
> wget -O ./data/kg.csv https://dataverse.harvard.edu/api/access/datafile/6180620
> ```
- Goal: Given drug node and disease node, score whether an indication edge exists.
- Dataset:
  - 20 biomedical sources
  - 100k+ nodes, 4M+ edges, 29 edge types
  - drug–disease indications/contraindications/off-label edges
  - fetch kg.csv from Harvard Dataverse

In [5]:
# Libraries: torch torch-geometric scipy sklearn gdl

import os
import math
import random
import pandas as pd
import numpy as np
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import roc_auc_score, average_precision_score

from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv
from torch_geometric.transforms import ToUndirected, RandomLinkSplit
from torch_geometric.loader import LinkNeighborLoader

# ---- Reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Configuration
CFG = {
    "KG_CSV": "./data/kg.csv",                 # <-- path to PrimeKG kg.csv
    "EMBED_DIM": 128,
    "HIDDEN": 128,
    "OUT_DIM": 128,
    "LR": 1e-3,
    "WEIGHT_DECAY": 1e-4,
    "EPOCHS": 50,
    "PATIENCE": 8,                      # early stopping on val AP
    "SCHED_STEP": 20,                   # step LR after this many epochs
    "SCHED_GAMMA": 0.5,
    "BATCH_SIZE": 4096,                 # for mini-batch loader
    "NUM_NEIGHBORS": [15, 10],          # neighbor sampling fanouts
    "NEG_SAMPLING_RATIO": 1.0,          # negatives per positive (mini-batch)
}

# our supervision target edge type:
TARGET_ET = ('drug','indication','disease')
REVERSE_ET = ('disease','rev_indication','drug')


## Load PrimeKG & Build Hetero Graph

In [6]:
# ---- Load edges
kg = pd.read_csv(CFG["KG_CSV"], low_memory=False)

# normalize columns for robust matching
for col in ["x_type","y_type","relation","display_relation"]:
    if col in kg.columns:
        kg[col] = kg[col].astype(str).str.lower()

# --- Select positive "indication" edges --- #
pos = kg[
    (kg['x_type']=='drug') &
    (kg['y_type']=='disease') &
    (kg['display_relation'].str.contains('indication', na=False))
][['x_id','y_id']].drop_duplicates()

# Build initial node sets from positives
drug_ids = pd.Index(sorted(set(pos['x_id'])))
disease_ids = pd.Index(sorted(set(pos['y_id'])))

# Optionally enrich message passing with extra relations (if available):
EXTRA_RELATIONS = [
    # (src_type, display_relation_substring, dst_type, canonical_edge_type_name)
    ('drug',   None,       'protein', 'acts_on'),         # accept all drug-protein relations
    ('gene',   None,       'disease', 'gene_assoc'),      # accept all gene-disease relations
    # add more if desired (e.g., pathway membership)
]

# Build per-type ID maps (we will extend when adding extra relations)
type2ids = {
    'drug': set(drug_ids),
    'disease': set(disease_ids),
    'protein': set(),
    'gene': set(),
}
edges_by_type = defaultdict(list)  # maps (stype, etype, dtype) -> list of (src, dst)

# Add indication edges:
for s,d in pos.itertuples(index=False):
    type2ids['drug'].add(s)
    type2ids['disease'].add(d)
    edges_by_type[('drug','indication','disease')].append((s,d))

# Add extra relations if present
for stype, rel_substr, dtype, canonical in EXTRA_RELATIONS:
    sub = kg[(kg['x_type']==stype) & (kg['y_type']==dtype)].copy()
    if rel_substr is not None:
        sub = sub[sub['display_relation'].str.contains(rel_substr, na=False)]
    # if no filter, we include all relations between stype and dtype
    if len(sub)==0:
        continue
    # Track IDs and edges
    for s,d in sub[['x_id','y_id']].drop_duplicates().itertuples(index=False):
        type2ids[stype].add(s)
        type2ids[dtype].add(d)
        edges_by_type[(stype, canonical, dtype)].append((s,d))

# Freeze ID maps to list->index mappings
ntype2index = {}
for ntype, idset in type2ids.items():
    ids = sorted(idset)
    ntype2index[ntype] = {k:i for i,k in enumerate(ids)}
    type2ids[ntype] = pd.Index(ids)  # store as index for ordering

# Build HeteroData
data = HeteroData()
for ntype, ids in type2ids.items():
    data[ntype].num_nodes = len(ids)

# Fill edges
def build_edge_index(pairs, src_map, dst_map):
    src = torch.tensor([src_map[s] for s,_ in pairs], dtype=torch.long)
    dst = torch.tensor([dst_map[d] for _,d in pairs], dtype=torch.long)
    return torch.stack([src, dst], dim=0)

for (stype, etype, dtype), pairs in edges_by_type.items():
    if len(pairs)==0: 
        continue
    ei = build_edge_index(pairs, ntype2index[stype], ntype2index[dtype])
    data[(stype, etype, dtype)].edge_index = ei

# Add reverse edges so messages flow both ways
data = ToUndirected(merge=False)(data)

print("Node counts:", {nt: data[nt].num_nodes for nt in data.node_types})
print("Edge types:", data.edge_types)

Node counts: {'drug': 2068, 'disease': 1937, 'protein': 0, 'gene': 0}
Edge types: [('drug', 'indication', 'disease'), ('disease', 'rev_indication', 'drug')]


## Features

In [9]:
class NodeEmbeddings(nn.Module):
    """
    Featureless baseline: one learnable embedding per node.
    You can later replace per-type entries with real features (same dim) if available.
    """
    def __init__(self, data: HeteroData, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.emb = nn.ModuleDict()
        for ntype in data.node_types:
            num_nodes = data[ntype].num_nodes
            self.emb[ntype] = nn.Embedding(num_nodes, embed_dim)
            nn.init.xavier_uniform_(self.emb[ntype].weight)

    def forward_full(self):
        # Used in full-batch training (no loaders)
        return {ntype: self.emb[ntype].weight for ntype in self.emb.keys()}

    def forward_on_batch(self, batch):
        # Used with LinkNeighborLoader; select by local->global mapping (n_id)
        x = {}
        for ntype in batch.node_types:
            n_id = batch[ntype].n_id  # global ids of nodes inside the batch
            x[ntype] = self.emb[ntype].weight[n_id]
        return x

## Train, Validation, Test Split

In [10]:
splitter = RandomLinkSplit(
    num_val=0.1, num_test=0.1,
    edge_types=[TARGET_ET],
    rev_edge_types=[REVERSE_ET],
    add_negative_train_samples=True,  # negatives for train (full-batch path)
    is_undirected=True
)
train_data, val_data, test_data = splitter(data)

# Move to device for full-batch path; mini-batch loaders handle device per-batch.
train_data = train_data.to(device)
val_data   = val_data.to(device)
test_data  = test_data.to(device)

## Encoder & Heads

In [11]:
class HeteroSAGE(nn.Module):
    def __init__(self, hidden: int, out_dim: int, edge_types, dropout: float = 0.2, use_bn: bool = False):
        super().__init__()
        self.dropout = dropout
        self.use_bn = use_bn

        self.conv1 = HeteroConv(
            {et: SAGEConv((-1, -1), hidden) for et in edge_types},
            aggr='sum'
        )
        self.conv2 = HeteroConv(
            {et: SAGEConv((-1, -1), out_dim) for et in edge_types},
            aggr='sum'
        )
        if use_bn:
            self.bn1 = nn.ModuleDict({nt: nn.BatchNorm1d(hidden) for nt in data.node_types})
            self.bn2 = nn.ModuleDict({nt: nn.BatchNorm1d(out_dim) for nt in data.node_types})

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        if self.use_bn:
            x_dict = {k: self.bn1[k](v) for k,v in x_dict.items()}
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        x_dict = {k: F.dropout(v, p=self.dropout, training=self.training) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        if self.use_bn:
            x_dict = {k: self.bn2[k](v) for k,v in x_dict.items()}
        return x_dict  # dict: ntype -> [num_nodes, out_dim]

class EdgePredictor(nn.Module):
    """Concatenate embeddings of (u,v) then MLP -> logits."""
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, 1)
        )

    def forward(self, z_src, z_dst):
        return self.mlp(torch.cat([z_src, z_dst], dim=-1)).view(-1)

# Optional node-level heads (if you add node-level tasks later)
class NodeClassifier(nn.Module):
    def __init__(self, dim, num_classes):
        super().__init__()
        self.head = nn.Linear(dim, num_classes)
    def forward(self, z):  # z: [N, dim]
        return self.head(z)

class NodeRegressor(nn.Module):
    def __init__(self, dim, out_dim=1):
        super().__init__()
        self.head = nn.Linear(dim, out_dim)
    def forward(self, z):  # z: [N, dim]
        return self.head(z)

## Full‑Batch Link Prediction Training

In [12]:
# Instantiate modules
embeds   = NodeEmbeddings(train_data, CFG["EMBED_DIM"]).to(device)
encoder  = HeteroSAGE(hidden=CFG["HIDDEN"], out_dim=CFG["OUT_DIM"],
                      edge_types=list(train_data.edge_types), dropout=0.2, use_bn=False).to(device)
predictor= EdgePredictor(dim=CFG["OUT_DIM"]).to(device)

params = list(embeds.parameters()) + list(encoder.parameters()) + list(predictor.parameters())
opt = torch.optim.Adam(params, lr=CFG["LR"], weight_decay=CFG["WEIGHT_DECAY"])
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=CFG["SCHED_STEP"], gamma=CFG["SCHED_GAMMA"])
loss_fn = nn.BCEWithLogitsLoss()

def _get_edge_batch(split_data, et):
    ei = split_data[et].edge_label_index
    y  = split_data[et].edge_label.float()
    return ei, y

def forward_once(split_data):
    x_dict = embeds.forward_full()
    z_dict = encoder(x_dict, split_data.edge_index_dict)
    ei, y  = _get_edge_batch(split_data, TARGET_ET)
    src, dst = ei[0], ei[1]
    z_src = z_dict[TARGET_ET[0]][src]   # 'drug'
    z_dst = z_dict[TARGET_ET[2]][dst]   # 'disease'
    logits = predictor(z_src, z_dst)
    return logits, y

def evaluate(split_data):
    encoder.eval(); embeds.eval(); predictor.eval()
    with torch.no_grad():
        logits, y = forward_once(split_data)
        prob = torch.sigmoid(logits).detach().cpu().numpy()
        y_np = y.detach().cpu().numpy()
        return roc_auc_score(y_np, prob), average_precision_score(y_np, prob)

best_val_ap, patience, best = -1.0, CFG["PATIENCE"], None

# ---- Training loop (full-batch)
for epoch in range(1, CFG["EPOCHS"] + 1):
    encoder.train(); embeds.train(); predictor.train()
    opt.zero_grad()

    logits, y = forward_once(train_data)
    loss = loss_fn(logits, y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(params, 5.0)
    opt.step(); sched.step()

    # Evaluate
    val_auroc, val_ap = evaluate(val_data)
    if val_ap > best_val_ap:
        best_val_ap = val_ap
        patience = CFG["PATIENCE"]
        best = {
            'encoder': encoder.state_dict(),
            'predictor': predictor.state_dict(),
        }
    else:
        patience -= 1

    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:03d} | loss {loss.item():.4f} | val AUROC {val_auroc:.3f} | val AP {val_ap:.3f} | lr {sched.get_last_lr()[0]:.1e}")

    if patience == 0:
        print("Early stopping triggered.")
        break

# Test with best
if best is not None:
    encoder.load_state_dict(best['encoder'])
    predictor.load_state_dict(best['predictor'])

test_auroc, test_ap = evaluate(test_data)
print(f"TEST AUROC {test_auroc:.3f} | AP {test_ap:.3f}")

# Export a sanity check of embedding shapes
with torch.no_grad():
    z_dict = encoder(embeds.forward_full(), train_data.edge_index_dict)
    print({k: v.shape for k,v in z_dict.items()})  # e.g., {'drug': [N_d, 128], 'disease': [N_dis, 128], ...}


Epoch 001 | loss 0.6932 | val AUROC 0.606 | val AP 0.555 | lr 1.0e-03
Epoch 005 | loss 0.6919 | val AUROC 0.769 | val AP 0.756 | lr 1.0e-03
Epoch 010 | loss 0.6878 | val AUROC 0.837 | val AP 0.839 | lr 1.0e-03
Epoch 015 | loss 0.6666 | val AUROC 0.851 | val AP 0.855 | lr 1.0e-03
Epoch 020 | loss 0.5964 | val AUROC 0.854 | val AP 0.855 | lr 5.0e-04
Epoch 025 | loss 0.5276 | val AUROC 0.858 | val AP 0.855 | lr 5.0e-04
Early stopping triggered.
TEST AUROC 0.861 | AP 0.861
{'disease': torch.Size([1937, 128]), 'drug': torch.Size([2068, 128])}


## Mini‑Batch Link Prediction

In [17]:
# ---- Build loaders only if you want mini-batch path
# We create loaders for train/val/test on the TARGET_ET.
# For train, we let the loader create negatives (neg_sampling_ratio). For val/test, do the same for consistent metrics.

def make_loader(split_data, batch_size, num_neighbors, shuffle, neg_sampling_ratio):
    return LinkNeighborLoader(
        data=split_data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        shuffle=shuffle,
        edge_label_index=(TARGET_ET, split_data[TARGET_ET].edge_label_index),
        edge_label=split_data[TARGET_ET].edge_label if split_data is not train_data else None,
        neg_sampling_ratio=neg_sampling_ratio
    )

train_loader = make_loader(train_data.cpu(), CFG["BATCH_SIZE"], CFG["NUM_NEIGHBORS"], True,  CFG["NEG_SAMPLING_RATIO"])
val_loader   = make_loader(val_data.cpu(),   CFG["BATCH_SIZE"], CFG["NUM_NEIGHBORS"], False, CFG["NEG_SAMPLING_RATIO"])
test_loader  = make_loader(test_data.cpu(),  CFG["BATCH_SIZE"], CFG["NUM_NEIGHBORS"], False, CFG["NEG_SAMPLING_RATIO"])

# Fresh models for mini-batch experiment (to compare apples-to-apples)
embeds_mb   = NodeEmbeddings(data, CFG["EMBED_DIM"]).to(device)
encoder_mb  = HeteroSAGE(CFG["HIDDEN"], CFG["OUT_DIM"], list(data.edge_types), dropout=0.2, use_bn=False).to(device)
predictor_mb= EdgePredictor(CFG["OUT_DIM"]).to(device)

params_mb = list(embeds_mb.parameters()) + list(encoder_mb.parameters()) + list(predictor_mb.parameters())
opt_mb = torch.optim.Adam(params_mb, lr=CFG["LR"], weight_decay=CFG["WEIGHT_DECAY"])
sched_mb = torch.optim.lr_scheduler.StepLR(opt_mb, step_size=CFG["SCHED_STEP"], gamma=CFG["SCHED_GAMMA"])
loss_fn = nn.BCEWithLogitsLoss()

def run_epoch_loader(loader, train_mode=True):
    if train_mode:
        encoder_mb.train(); embeds_mb.train(); predictor_mb.train()
    else:
        encoder_mb.eval(); embeds_mb.eval(); predictor_mb.eval()

    all_probs, all_labels = [], []
    total_loss = 0.0
    for batch in loader:
        batch = batch.to(device)
        if train_mode:
            opt_mb.zero_grad()

        # Build mini-batch features from global embeddings via n_id
        x_dict = embeds_mb.forward_on_batch(batch)
        z_dict = encoder_mb(x_dict, batch.edge_index_dict)

        # edge labels are on the batch object (pos + sampled negs)
        ei = batch[TARGET_ET].edge_label_index
        y  = batch[TARGET_ET].edge_label.float()

        src, dst = ei[0], ei[1]
        z_src = z_dict[TARGET_ET[0]][src]
        z_dst = z_dict[TARGET_ET[2]][dst]
        logits = predictor_mb(z_src, z_dst)
        loss = loss_fn(logits, y)

        if train_mode:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params_mb, 5.0)
            opt_mb.step()

        total_loss += loss.item() * y.numel()
        prob = torch.sigmoid(logits).detach().cpu().numpy()
        lab  = y.detach().cpu().numpy()
        all_probs.append(prob); all_labels.append(lab)

    if not train_mode:
        # compute metrics across all mini-batches
        probs = np.concatenate(all_probs, axis=0)
        labs  = np.concatenate(all_labels, axis=0)
        auroc = roc_auc_score(labs, probs)
        ap    = average_precision_score(labs, probs)
    else:
        auroc = ap = float('nan')
    mean_loss = total_loss / max(1, sum(len(a) for a in all_labels))
    return mean_loss, auroc, ap

best_val_ap_mb, patience_mb, best_mb = -1.0, CFG["PATIENCE"], None

for epoch in range(1, CFG["EPOCHS"] + 1):
    train_loss, _, _ = run_epoch_loader(train_loader, train_mode=True)
    val_loss, val_auroc, val_ap = run_epoch_loader(val_loader, train_mode=False)
    sched_mb.step()

    if val_ap > best_val_ap_mb:
        best_val_ap_mb = val_ap
        patience_mb = CFG["PATIENCE"]
        best_mb = {
            'encoder': encoder_mb.state_dict(),
            'predictor': predictor_mb.state_dict(),
            'embeds': embeds_mb.state_dict(),  # keep if you want to resume MB training later
        }
    else:
        patience_mb -= 1

    if epoch % 5 == 0 or epoch == 1:
        print(f"[MB] Epoch {epoch:03d} | train loss {train_loss:.4f} | val AUROC {val_auroc:.3f} | val AP {val_ap:.3f} | lr {sched_mb.get_last_lr()[0]:.1e}")

    if patience_mb == 0:
        print("[MB] Early stopping.")
        break

# Test with best
if best_mb is not None:
    encoder_mb.load_state_dict(best_mb['encoder'])
    predictor_mb.load_state_dict(best_mb['predictor'])

test_loss, test_auroc_mb, test_ap_mb = run_epoch_loader(test_loader, train_mode=False)
print(f"[MB] TEST AUROC {test_auroc_mb:.3f} | AP {test_ap_mb:.3f}")


ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'

## Save & Load for Transfer Learning

In [None]:
# ---- Save (full-batch or mini-batch encoder; use the one you prefer)
torch.save(encoder.state_dict(), "./data/primekg_hetero_encoder.pt")

# ---- Load elsewhere
# enc_transfer = HeteroSAGE(
#     hidden=CFG["HIDDEN"], out_dim=CFG["OUT_DIM"],
#     edge_types=list(data.edge_types), dropout=0.2, use_bn=False
# ).to(device)
# enc_transfer.load_state_dict(torch.load("./data/primekg_hetero_encoder.pt", map_location=device))
# enc_transfer.eval()

# Example: export embeddings now (from current graph)
# with torch.no_grad():
#     z_dict = enc_transfer(embeds.forward_full(), data.to(device).edge_index_dict)
#     print({k: v.shape for k,v in z_dict.items()})


## Export Node Embeddings to CSV

In [None]:
drug_index = type2ids.get('drug', pd.Index([]))
disease_index = type2ids.get('disease', pd.Index([]))

with torch.no_grad():
    z_dict = encoder(embeds.forward_full(), data.edge_index_dict)

def save_embeddings(ntype, index, z):
    if len(index) == 0: return
    arr = z.detach().cpu().numpy()
    cols = [f"z{i}" for i in range(arr.shape[1])]
    df = pd.DataFrame(arr, index=list(index), columns=cols)
    df.to_csv(f"emb_{ntype}_{arr.shape[1]}.csv")

save_embeddings('drug', drug_index, z_dict['drug'])
save_embeddings('disease', disease_index, z_dict['disease'])

## Temp