In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.metrics import roc_auc_score, average_precision_score

from models_sage import HeteroGraphSAGE
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE


'cpu'

In [2]:
# Notebook-safe paths
ROOT = Path.cwd().parent

DATA_DIR = ROOT / "data" / "data_cleaned"
GRAPH_PATH = ROOT / "outputs" / "data.pt"
OUT_DIR = ROOT / "outputs"
OUT_DIR.mkdir(exist_ok=True)

print("ROOT:", ROOT)
print("DATA_DIR:", DATA_DIR)
print("GRAPH_PATH:", GRAPH_PATH)


ROOT: C:\Users\ayish\OneDrive\Documents\GNN
DATA_DIR: C:\Users\ayish\OneDrive\Documents\GNN\data\data_cleaned
GRAPH_PATH: C:\Users\ayish\OneDrive\Documents\GNN\outputs\data.pt


### Load node preserving labelled splits

In [3]:
def load_split(name, le_circ, le_dis):
    df = pd.read_csv(DATA_DIR / name)

    # Encode node names → integer IDs
    circ_ids = le_circ.transform(df["circRNA"].astype(str))
    dis_ids  = le_dis.transform(df["disease"].astype(str))

    # Build edge index: [2, num_edges]
    edges = torch.from_numpy(np.vstack([circ_ids, dis_ids])).long()

    # Labels (already numeric)
    labels = torch.tensor(df["label"].values, dtype=torch.float)

    return edges, labels


In [4]:
encoders = torch.load(OUT_DIR / "label_encoders.pt")
le_circ = encoders["circRNA"]
le_dis  = encoders["disease"]

train_edges, train_labels = load_split("circRNA_disease_train.csv", le_circ, le_dis)
val_edges, val_labels     = load_split("circRNA_disease_val.csv", le_circ, le_dis)
test_edges, test_labels   = load_split("circRNA_disease_test.csv", le_circ, le_dis)

train_edges.shape, train_labels.shape


(torch.Size([2, 950]), torch.Size([950]))

### Move Splits to Device

In [5]:
train_edges, train_labels = train_edges.to(DEVICE), train_labels.to(DEVICE)
val_edges, val_labels     = val_edges.to(DEVICE), val_labels.to(DEVICE)
test_edges, test_labels   = test_edges.to(DEVICE), test_labels.to(DEVICE)


### Load HeteroGraph

In [6]:
print("Loading heterogeneous graph...")
data = torch.load(GRAPH_PATH, map_location=DEVICE)
data


Loading heterogeneous graph...


HeteroData(
  circRNA={ x=[828, 4] },
  miRNA={ x=[521, 4] },
  disease={ x=[122, 4] },
  (circRNA, interacts, miRNA)={ edge_index=[2, 896] },
  (miRNA, interacts, disease)={ edge_index=[2, 828] },
  (circRNA, associated, disease)={ edge_index=[2, 985] },
  (miRNA, rev_interacts, circRNA)={ edge_index=[2, 896] },
  (disease, rev_interacts, miRNA)={ edge_index=[2, 828] },
  (disease, rev_associated, circRNA)={ edge_index=[2, 985] }
)

### Initialize GraphSAGE model

In [7]:
model = HeteroGraphSAGE(
    in_channels=data["circRNA"].x.size(1),
    hidden_channels=64,
    out_channels=64,
    dropout=0.2
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

model


HeteroGraphSAGE(
  (conv1): HeteroConv(num_relations=6)
  (conv2): HeteroConv(num_relations=6)
  (act): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (res_lin_circ): Linear(in_features=4, out_features=64, bias=True)
  (res_lin_mir): Linear(in_features=4, out_features=64, bias=True)
  (res_lin_dis): Linear(in_features=4, out_features=64, bias=True)
)

In [9]:
best_val_aupr = 0.0

print("\nStarting GraphSAGE training...\n")

for epoch in range(1, 51):
    # -------- Training --------
    model.train()
    optimizer.zero_grad()

    emb = model(data.x_dict, data.edge_index_dict)
    circ_emb, dis_emb = emb["circRNA"], emb["disease"]

    logits = (circ_emb[train_edges[0]] * dis_emb[train_edges[1]]).sum(dim=1)
    loss = loss_fn(logits, train_labels)

    loss.backward()
    optimizer.step()

    # -------- Validation --------
    model.eval()
    with torch.no_grad():
        emb = model(data.x_dict, data.edge_index_dict)
        circ_emb, dis_emb = emb["circRNA"], emb["disease"]

        val_logits = (circ_emb[val_edges[0]] * dis_emb[val_edges[1]]).sum(dim=1)
        val_scores = torch.sigmoid(val_logits).cpu().numpy()
        val_true   = val_labels.cpu().numpy()

        auc = roc_auc_score(val_true, val_scores)
        aupr = average_precision_score(val_true, val_scores)

    print(
        f"Epoch {epoch:03d} | "
        f"Loss {loss.item():.4f} | "
        f"Val AUC {auc:.4f} | "
        f"Val AUPR {aupr:.4f}"
    )

    if aupr > best_val_aupr:
        best_val_aupr = aupr
        torch.save(model.state_dict(), OUT_DIR / "sage_best_model.pth")
        print("   → Saved best model")



Starting GraphSAGE training...

Epoch 001 | Loss 0.7278 | Val AUC 0.5580 | Val AUPR 0.1196
   → Saved best model
Epoch 002 | Loss 0.6713 | Val AUC 0.5929 | Val AUPR 0.1287
   → Saved best model
Epoch 003 | Loss 0.6178 | Val AUC 0.6081 | Val AUPR 0.1349
   → Saved best model
Epoch 004 | Loss 0.5744 | Val AUC 0.6150 | Val AUPR 0.1388
   → Saved best model
Epoch 005 | Loss 0.5333 | Val AUC 0.6321 | Val AUPR 0.1465
   → Saved best model
Epoch 006 | Loss 0.5050 | Val AUC 0.6474 | Val AUPR 0.1533
   → Saved best model
Epoch 007 | Loss 0.4801 | Val AUC 0.6605 | Val AUPR 0.1594
   → Saved best model
Epoch 008 | Loss 0.4569 | Val AUC 0.6729 | Val AUPR 0.1656
   → Saved best model
Epoch 009 | Loss 0.4387 | Val AUC 0.6868 | Val AUPR 0.1744
   → Saved best model
Epoch 010 | Loss 0.4227 | Val AUC 0.7032 | Val AUPR 0.1856
   → Saved best model
Epoch 011 | Loss 0.4105 | Val AUC 0.7167 | Val AUPR 0.1967
   → Saved best model
Epoch 012 | Loss 0.4009 | Val AUC 0.7323 | Val AUPR 0.2167
   → Saved best m

In [10]:
print("\nEvaluating on test set...")

model.load_state_dict(torch.load(OUT_DIR / "sage_best_model.pth", map_location=DEVICE))
model.eval()

with torch.no_grad():
    emb = model(data.x_dict, data.edge_index_dict)
    circ_emb, dis_emb = emb["circRNA"], emb["disease"]

    test_logits = (circ_emb[test_edges[0]] * dis_emb[test_edges[1]]).sum(dim=1)
    test_scores = torch.sigmoid(test_logits).cpu().numpy()
    test_true   = test_labels.cpu().numpy()

    test_auc = roc_auc_score(test_true, test_scores)
    test_aupr = average_precision_score(test_true, test_scores)

print("\nFINAL TEST RESULTS")
print(f"AUC  = {test_auc:.4f}")
print(f"AUPR = {test_aupr:.4f}")



Evaluating on test set...

FINAL TEST RESULTS
AUC  = 0.7852
AUPR = 0.4298
