In [1]:
!pip install torch



In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m63.7/63.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m72.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successf

In [3]:
"""
Enhanced R-GCN Link Prediction Pipeline
Features:
- Comprehensive training metrics (Loss, AUC, AP, Hits@K, MRR)
- Top-10 disease predictions per compound with positive/negative edge labels
- Clean formatted output
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv
from torch_geometric.data import HeteroData
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score

# ================================================================
# Reproducibility
# ================================================================
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# ================================================================
# Load HeteroData graph
# ================================================================
print("=" * 80)
print("LOADING GRAPH DATA")
print("=" * 80)

data = torch.load('graph.pt', weights_only=False)
print(f"\nLoaded HeteroData with:")
print(f"  - Node types: {list(data.node_types)}")
print(f"  - Edge types: {len(data.edge_types)}")

num_compounds = data['Compound'].num_nodes
num_diseases = data['Disease'].num_nodes
print(f"\n  - Compounds: {num_compounds}")
print(f"  - Diseases: {num_diseases}")

assert data['Compound'].num_nodes == num_compounds
assert data['Disease'].num_nodes == num_diseases

# ================================================================
# Split treats relation
# ================================================================
print("\n" + "=" * 80)
print("SPLITTING EDGES")
print("=" * 80)

def split_edges(edge_index, train_ratio=0.7, val_ratio=0.15, seed=SEED):
    torch.manual_seed(seed)
    num_edges = edge_index.size(1)
    perm = torch.randperm(num_edges)
    train_end = int(train_ratio * num_edges)
    val_end = train_end + int(val_ratio * num_edges)
    return {
        "train": edge_index[:, perm[:train_end]],
        "val": edge_index[:, perm[train_end:val_end]],
        "test": edge_index[:, perm[val_end:]]
    }

rel_treats = ('Compound', 'treats', 'Disease')
orig_treats = data[rel_treats].edge_index
splits_treats = split_edges(orig_treats)

print(f"\n'treats' relation split:")
print(f"  - Train: {splits_treats['train'].size(1)} edges")
print(f"  - Val:   {splits_treats['val'].size(1)} edges")
print(f"  - Test:  {splits_treats['test'].size(1)} edges")

# ================================================================
# Build train_data (only treats edges reduced)
# ================================================================
train_data = data.clone()

for rel in data.edge_types:
    if rel == rel_treats:
        train_data[rel].edge_index = splits_treats['train']
    else:
        train_data[rel].edge_index = data[rel].edge_index

    src, name, dst = rel
    rev = (dst, name + "_rev", src)
    if rev in data.edge_types:
        train_data[rev].edge_index = train_data[rel].edge_index.flip(0)

# ================================================================
# Heterogeneous ‚Üí Homogeneous conversion
# ================================================================
print("\n" + "=" * 80)
print("CONVERTING TO HOMOGENEOUS GRAPH")
print("=" * 80)

def hetero_to_homo(hetero_data, default_dim=128):
    node_offset = {}
    offset = 0
    feat_dim = None
    node_features = []

    for ntype in hetero_data.node_types:
        node_offset[ntype] = offset
        n = hetero_data[ntype].num_nodes

        if hasattr(hetero_data[ntype], "x") and hetero_data[ntype].x is not None:
            ft = hetero_data[ntype].x
            if feat_dim is None:
                feat_dim = ft.size(1)
            else:
                assert ft.size(1) == feat_dim
        else:
            if feat_dim is None:
                feat_dim = default_dim
            ft = torch.zeros((n, feat_dim))
        offset += n
        node_features.append(ft)

    x = torch.cat(node_features, dim=0)
    x = F.normalize(x, p=2, dim=1)

    edge_index_list = []
    edge_type_list = []
    relation_names = []
    relname2id = {}

    for rel_id, edge_type in enumerate(hetero_data.edge_types):
        src, relname, dst = edge_type
        eidx = hetero_data[edge_type].edge_index.clone()
        eidx[0] += node_offset[src]
        eidx[1] += node_offset[dst]
        edge_index_list.append(eidx)
        edge_type_list.append(torch.full((eidx.size(1),), rel_id, dtype=torch.long))
        relation_names.append(edge_type)
        relname2id[edge_type] = rel_id

    edge_index = torch.cat(edge_index_list, dim=1)
    edge_type = torch.cat(edge_type_list, dim=0)

    # Add self-loop relation
    N = x.size(0)
    loop_edges = torch.arange(N).unsqueeze(0).repeat(2, 1)
    self_rel_id = len(relation_names)

    edge_index = torch.cat([edge_index, loop_edges], dim=1)
    edge_type = torch.cat([edge_type, torch.full((N,), self_rel_id)], dim=0)

    relation_names.append(('SELF', 'SELF', 'SELF'))
    relname2id[('SELF', 'SELF', 'SELF')] = self_rel_id

    return x, edge_index, edge_type, node_offset, relation_names, relname2id

x, edge_index, edge_type, node_offset, relation_names, relname2id = hetero_to_homo(train_data)
treats_rel_id = relname2id[rel_treats]

print(f"\nHomogeneous graph:")
print(f"  - Total nodes: {x.size(0)}")
print(f"  - Total edges: {edge_index.size(1)}")
print(f"  - Relations: {len(relation_names)}")
print(f"  - 'treats' relation ID: {treats_rel_id}")

# ================================================================
# Negative Sampling
# ================================================================
print("\n" + "=" * 80)
print("NEGATIVE SAMPLING")
print("=" * 80)

pos_mat = np.zeros((num_compounds, num_diseases), dtype=bool)
all_pos = torch.cat([splits_treats['train'], splits_treats['val'], splits_treats['test']], dim=1)
pos_mat[all_pos[0].numpy(), all_pos[1].numpy()] = True

def sample_neg(k):
    negs = []
    while len(negs) < k:
        need = k - len(negs)
        samp = max(need * 5, 1000)
        c = np.random.randint(0, num_compounds, samp)
        d = np.random.randint(0, num_diseases, samp)
        mask = ~pos_mat[c, d]
        gc = c[mask][:need]
        gd = d[mask][:need]
        for ci, di in zip(gc, gd):
            negs.append([ci, di])
    return torch.tensor(negs[:k]).t()

neg_train = sample_neg(splits_treats['train'].size(1))
neg_val = sample_neg(splits_treats['val'].size(1))
neg_test = sample_neg(splits_treats['test'].size(1))

print(f"\nNegative samples generated:")
print(f"  - Train: {neg_train.size(1)}")
print(f"  - Val:   {neg_val.size(1)}")
print(f"  - Test:  {neg_test.size(1)}")

# ================================================================
# R-GCN Model
# ================================================================
class RGCNLinkPredictor(nn.Module):
    def __init__(self, in_dim, hid, num_rel, num_bases=8, dropout=0.2, temperature=1.0):
        super().__init__()
        self.conv1 = RGCNConv(in_dim, hid, num_rel, num_bases=num_bases)
        self.conv2 = RGCNConv(hid, hid, num_rel, num_bases=num_bases)
        self.norm1 = nn.LayerNorm(hid)
        self.norm2 = nn.LayerNorm(hid)
        self.relation_emb = nn.Parameter(torch.randn(num_rel, hid) * 0.1)
        self.dropout = dropout
        self.temperature = temperature

    def encode(self, x, ei, et):
        h = self.conv1(x, ei, et)
        h = self.norm1(h)
        h = F.leaky_relu(h, 0.2)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.conv2(h, ei, et)
        h = self.norm2(h)
        h = F.leaky_relu(h, 0.2)
        return F.normalize(h, dim=1)

    def decode(self, z, edge_index, rel_ids):
        z_src = z[edge_index[0]]
        z_dst = z[edge_index[1]]
        r = self.relation_emb[rel_ids]
        return (z_src * r * z_dst).sum(dim=1) / self.temperature

    def forward(self, x, ei, et, pred_edges, rel_ids):
        z = self.encode(x, ei, et)
        return self.decode(z, pred_edges, rel_ids)

# ================================================================
# Training Setup
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n" + "=" * 80)
print(f"DEVICE: {device}")
print("=" * 80)

model = RGCNLinkPredictor(x.size(1), 128, len(relation_names)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

x = x.to(device)
edge_index = edge_index.to(device)
edge_type = edge_type.to(device)

# ================================================================
# Evaluation Function
# ================================================================
def evaluate(model, x, ei, et, pos_edges, neg_edges, node_offset, rel_id, k=10):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, ei, et)

        pos = pos_edges.clone()
        pos[0] += node_offset['Compound']
        pos[1] += node_offset['Disease']
        neg = neg_edges.clone()
        neg[0] += node_offset['Compound']
        neg[1] += node_offset['Disease']

        pos = pos.to(device)
        neg = neg.to(device)

        rel_pos = torch.full((pos.size(1),), rel_id, dtype=torch.long, device=device)
        rel_neg = torch.full((neg.size(1),), rel_id, dtype=torch.long, device=device)

        pos_scores = torch.sigmoid(model.decode(z, pos, rel_pos)).cpu().numpy()
        neg_scores = torch.sigmoid(model.decode(z, neg, rel_neg)).cpu().numpy()

        y_true = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
        y_score = np.concatenate([pos_scores, neg_scores])

        auc = roc_auc_score(y_true, y_score)
        ap = average_precision_score(y_true, y_score)

        # Compute Hits@K and MRR
        hits_at_k = []
        reciprocal_ranks = []
        for i in range(len(pos_scores)):
            pos_s = pos_scores[i]
            rank = (neg_scores > pos_s).sum() + 1
            hits_at_k.append(1.0 if rank <= k else 0.0)
            reciprocal_ranks.append(1.0 / rank)

        return {
            "auc": auc,
            "ap": ap,
            f"hits@{k}": np.mean(hits_at_k),
            "mrr": np.mean(reciprocal_ranks),
            "pos_prob": pos_scores.mean(),
            "neg_prob": neg_scores.mean()
        }

# ================================================================
# Training Loop
# ================================================================
def train_epoch():
    model.train()

    pos = splits_treats['train'].clone()
    pos[0] += node_offset['Compound']
    pos[1] += node_offset['Disease']

    neg = neg_train.clone()
    neg[0] += node_offset['Compound']
    neg[1] += node_offset['Disease']

    pos = pos.to(device)
    neg = neg.to(device)

    rel_pos = torch.full((pos.size(1),), treats_rel_id, dtype=torch.long, device=device)
    rel_neg = torch.full((neg.size(1),), treats_rel_id, dtype=torch.long, device=device)

    pos_scores = model(x, edge_index, edge_type, pos, rel_pos)
    neg_scores = model(x, edge_index, edge_type, neg, rel_neg)

    logits = torch.cat([pos_scores, neg_scores])
    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])

    loss = F.binary_cross_entropy_with_logits(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        pos_prob = torch.sigmoid(pos_scores).mean().item()
        neg_prob = torch.sigmoid(neg_scores).mean().item()

    return loss.item(), pos_prob, neg_prob

print("\n" + "=" * 80)
print("TRAINING")
print("=" * 80)
print(f"\n{'Epoch':<8} {'Loss':<10} {'Pos Prob':<10} {'Neg Prob':<10} {'ROC-AUC':<10} {'AP':<10} {'Hits@10':<10} {'MRR':<10}")
print("-" * 88)

best_auc = 0
for epoch in range(1, 201):
    loss, pos_prob, neg_prob = train_epoch()

    if epoch % 10 == 0:
        val_metrics = evaluate(model, x, edge_index, edge_type, splits_treats['val'],
                               neg_val, node_offset, treats_rel_id)

        print(f"{epoch:<8} {loss:<10.4f} {pos_prob:<10.4f} {neg_prob:<10.4f} "
              f"{val_metrics['auc']:<10.4f} {val_metrics['ap']:<10.4f} "
              f"{val_metrics['hits@10']:<10.4f} {val_metrics['mrr']:<10.4f}")

        if val_metrics["auc"] > best_auc:
            best_auc = val_metrics["auc"]
            torch.save(model.state_dict(), "best_rgcn_model.pt")

# ================================================================
# Test Evaluation
# ================================================================
print("\n" + "=" * 80)
print("TEST EVALUATION")
print("=" * 80)

model.load_state_dict(torch.load("best_rgcn_model.pt"))
test_metrics = evaluate(model, x, edge_index, edge_type, splits_treats['test'],
                        neg_test, node_offset, treats_rel_id)

print(f"\nTest Results:")
print(f"  - ROC-AUC:  {test_metrics['auc']:.4f}")
print(f"  - AP:       {test_metrics['ap']:.4f}")
print(f"  - Hits@10:  {test_metrics['hits@10']:.4f}")
print(f"  - MRR:      {test_metrics['mrr']:.4f}")

# ================================================================
# Predict all compound-disease pairs
# ================================================================
print("\n" + "=" * 80)
print("GENERATING PREDICTIONS")
print("=" * 80)

def predict_all_pairs():
    model.eval()
    with torch.no_grad():
        z = model.encode(x, edge_index, edge_type)
        comp_start = node_offset['Compound']
        dis_start = node_offset['Disease']

        zc = z[comp_start : comp_start + num_compounds]
        zd = z[dis_start : dis_start + num_diseases]

        r = model.relation_emb[treats_rel_id]
        scores = torch.sigmoid((zc * r) @ zd.T)
        return scores.cpu().numpy()

scores = predict_all_pairs()
np.save("compound_disease_predictions.npy", scores)
print(f"\nSaved full prediction matrix: {scores.shape}")

# ================================================================
# Create CSV with Top-10 predictions per compound
# ================================================================
print("\nCreating top-10 predictions CSV...")

# Build set of existing positive edges
existing_edges = set()
for i in range(all_pos.size(1)):
    existing_edges.add((all_pos[0, i].item(), all_pos[1, i].item()))

rows = []
for compound_id in range(num_compounds):
    # Get all scores for this compound
    compound_scores = scores[compound_id]

    # Get top-10 disease indices
    top10_indices = np.argsort(compound_scores)[-10:][::-1]

    for rank, disease_id in enumerate(top10_indices, 1):
        score = float(compound_scores[disease_id])

        # Check if this is an existing positive edge
        is_positive = (compound_id, disease_id) in existing_edges
        edge_type = "Positive" if is_positive else "Negative (New Prediction)"

        rows.append({
            'Compound_ID': compound_id,
            'Disease_ID': disease_id,
            'Rank': rank,
            'Score': score,
            'Edge_Type': edge_type
        })

df = pd.DataFrame(rows)
df.to_csv('top10_disease_predictions_per_compound.csv', index=False)

print(f"Saved: top10_disease_predictions_per_compound.csv")
print(f"  - Total predictions: {len(df)}")
print(f"  - Positive edges: {(df['Edge_Type'] == 'Positive').sum()}")
print(f"  - Negative edges (new): {(df['Edge_Type'] == 'Negative (New Prediction)').sum()}")

# ================================================================
# Query Functions for Specific Compounds
# ================================================================

def predict_for_compound(compound_id, top_k=10):
    """
    Get top-K disease predictions for a specific compound.

    Args:
        compound_id: Integer ID of the compound
        top_k: Number of top predictions to return

    Returns:
        DataFrame with predictions
    """
    if compound_id < 0 or compound_id >= num_compounds:
        raise ValueError(f"Invalid compound_id. Must be between 0 and {num_compounds-1}")

    compound_scores = scores[compound_id]
    top_indices = np.argsort(compound_scores)[-top_k:][::-1]

    results = []
    for rank, disease_id in enumerate(top_indices, 1):
        score = float(compound_scores[disease_id])
        is_positive = (compound_id, disease_id) in existing_edges
        edge_type = "Positive" if is_positive else "Negative (New Prediction)"

        results.append({
            'Compound_ID': compound_id,
            'Disease_ID': disease_id,
            'Rank': rank,
            'Score': score,
            'Edge_Type': edge_type
        })

    return pd.DataFrame(results)

def predict_for_compounds(compound_ids, top_k=10):
    """
    Get top-K disease predictions for multiple compounds.

    Args:
        compound_ids: List of compound IDs
        top_k: Number of top predictions per compound

    Returns:
        DataFrame with all predictions
    """
    all_results = []
    for cid in compound_ids:
        df_temp = predict_for_compound(cid, top_k)
        all_results.append(df_temp)

    return pd.concat(all_results, ignore_index=True)

def predict_compound_disease_pair(compound_id, disease_id):
    """
    Get prediction score for a specific compound-disease pair.

    Args:
        compound_id: Integer ID of the compound
        disease_id: Integer ID of the disease

    Returns:
        Dictionary with prediction details
    """
    if compound_id < 0 or compound_id >= num_compounds:
        raise ValueError(f"Invalid compound_id. Must be between 0 and {num_compounds-1}")
    if disease_id < 0 or disease_id >= num_diseases:
        raise ValueError(f"Invalid disease_id. Must be between 0 and {num_diseases-1}")

    score = float(scores[compound_id, disease_id])
    is_positive = (compound_id, disease_id) in existing_edges

    return {
        'Compound_ID': compound_id,
        'Disease_ID': disease_id,
        'Score': score,
        'Edge_Type': "Positive" if is_positive else "Negative (New Prediction)"
    }

# ================================================================
# Example Usage
# ================================================================
print("\n" + "=" * 80)
print("QUERY EXAMPLES")
print("=" * 80)

# Example 1: Single compound
print("\n[1] Predictions for Compound ID = 0:")
result1 = predict_for_compound(compound_id=0, top_k=5)
print(result1.to_string(index=False))

# Example 2: Multiple compounds
print("\n[2] Predictions for Compounds [5, 10, 15]:")
result2 = predict_for_compounds(compound_ids=[5, 10, 15], top_k=3)
print(result2.to_string(index=False))

# Example 3: Specific pair
print("\n[3] Prediction for Compound 0 - Disease 10:")
result3 = predict_compound_disease_pair(compound_id=0, disease_id=10)
print(f"  Score: {result3['Score']:.4f}")
print(f"  Edge Type: {result3['Edge_Type']}")

print("\n" + "=" * 80)
print("COMPLETED SUCCESSFULLY")
print("=" * 80)
print("\nüìã AVAILABLE FUNCTIONS:")
print("  ‚Ä¢ predict_for_compound(compound_id, top_k=10)")
print("  ‚Ä¢ predict_for_compounds(compound_ids, top_k=10)")
print("  ‚Ä¢ predict_compound_disease_pair(compound_id, disease_id)")
print("\nüíæ SAVED FILES:")
print("  ‚Ä¢ compound_disease_predictions.npy - Full prediction matrix")
print("  ‚Ä¢ top10_disease_predictions_per_compound.csv - All compounds top-10")
print("  ‚Ä¢ best_rgcn_model.pt - Trained model weights")

LOADING GRAPH DATA

Loaded HeteroData with:
  - Node types: ['Anatomy', 'Biological_Process', 'Cellular_Component', 'Compound', 'Disease', 'Gene', 'Molecular_Function', 'Pathway', 'Pharmacologic_Class', 'Side_Effect', 'Symptom']
  - Edge types: 38

  - Compounds: 1552
  - Diseases: 137

SPLITTING EDGES

'treats' relation split:
  - Train: 528 edges
  - Val:   113 edges
  - Test:  114 edges

CONVERTING TO HOMOGENEOUS GRAPH

Homogeneous graph:
  - Total nodes: 47031
  - Total edges: 4498632
  - Relations: 39
  - 'treats' relation ID: 14

NEGATIVE SAMPLING

Negative samples generated:
  - Train: 528
  - Val:   113
  - Test:  114

DEVICE: cuda

TRAINING

Epoch    Loss       Pos Prob   Neg Prob   ROC-AUC    AP         Hits@10    MRR       
----------------------------------------------------------------------------------------
10       0.6837     0.5078     0.4982     0.8826     0.8776     0.5841     0.2990    
20       0.6651     0.5226     0.4938     0.9288     0.9361     0.7345     0.628