In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.metrics import roc_auc_score, average_precision_score

from models_sage import HeteroGraphSAGE
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE


'cpu'

In [2]:
# Reproducibility control
import torch, random, numpy as np

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [3]:
SEEDS = [42, 43, 44, 45, 46]


In [4]:
# Notebook-safe paths
ROOT = Path.cwd().parent

DATA_DIR = ROOT / "data" / "data_cleaned"
GRAPH_PATH = ROOT / "outputs" / "data.pt"
OUT_DIR = ROOT / "outputs"
OUT_DIR.mkdir(exist_ok=True)

print("ROOT:", ROOT)
print("DATA_DIR:", DATA_DIR)
print("GRAPH_PATH:", GRAPH_PATH)


ROOT: C:\Users\ayish\OneDrive\Documents\circRNA-disease-gnn
DATA_DIR: C:\Users\ayish\OneDrive\Documents\circRNA-disease-gnn\data\data_cleaned
GRAPH_PATH: C:\Users\ayish\OneDrive\Documents\circRNA-disease-gnn\outputs\data.pt


### Load node preserving labelled splits

In [5]:
def load_split(name, le_circ, le_dis):
    df = pd.read_csv(DATA_DIR / name)

    # Encode node names → integer IDs
    circ_ids = le_circ.transform(df["circRNA"].astype(str))
    dis_ids  = le_dis.transform(df["disease"].astype(str))

    # Build edge index: [2, num_edges]
    edges = torch.from_numpy(np.vstack([circ_ids, dis_ids])).long()

    # Labels (already numeric)
    labels = torch.tensor(df["label"].values, dtype=torch.float)

    return edges, labels


In [6]:
encoders = torch.load(
    OUT_DIR / "label_encoders.pt",
    weights_only=False
)

le_circ = encoders["circRNA"]
le_dis  = encoders["disease"]


train_edges, train_labels = load_split("circRNA_disease_train.csv", le_circ, le_dis)
val_edges, val_labels     = load_split("circRNA_disease_val.csv", le_circ, le_dis)
test_edges, test_labels   = load_split("circRNA_disease_test.csv", le_circ, le_dis)

train_edges.shape, train_labels.shape


(torch.Size([2, 929]), torch.Size([929]))

### Move Splits to Device

In [7]:
train_edges, train_labels = train_edges.to(DEVICE), train_labels.to(DEVICE)
val_edges, val_labels     = val_edges.to(DEVICE), val_labels.to(DEVICE)
test_edges, test_labels   = test_edges.to(DEVICE), test_labels.to(DEVICE)


### Load HeteroGraph

In [8]:
print("Loading heterogeneous graph...")
data = torch.load(
    GRAPH_PATH,
    map_location=DEVICE,
    weights_only=False
)

data


Loading heterogeneous graph...


HeteroData(
  circRNA={ x=[828, 6] },
  miRNA={ x=[521, 6] },
  disease={ x=[122, 6] },
  (circRNA, interacts, miRNA)={ edge_index=[2, 896] },
  (miRNA, interacts, disease)={ edge_index=[2, 828] },
  (circRNA, associated, disease)={ edge_index=[2, 985] },
  (circRNA, gip_sim, circRNA)={
    edge_index=[2, 685584],
    edge_weight=[685584],
  },
  (miRNA, gip_sim, miRNA)={
    edge_index=[2, 271441],
    edge_weight=[271441],
  },
  (miRNA, rev_interacts, circRNA)={ edge_index=[2, 896] },
  (disease, rev_interacts, miRNA)={ edge_index=[2, 828] },
  (disease, rev_associated, circRNA)={ edge_index=[2, 985] }
)

### Initialize GraphSAGE model

In [9]:
class LinkPredictor(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim * 3, dim),
            nn.ReLU(),
            nn.Linear(dim, 1)
        )

    def forward(self, z_c, z_d):
        x = torch.cat([z_c, z_d, z_c * z_d], dim=1)
        return self.mlp(x).squeeze(-1)

In [10]:
results = []

for seed in SEEDS:
    print(f"\n==============================")
    print(f"Running experiment with SEED = {seed}")
    print(f"==============================")

    set_seed(seed)

    # ---- model must be re-created AFTER seed ----
    model = HeteroGraphSAGE(
        in_channels=data["circRNA"].x.size(1),
        hidden_channels=64,
        out_channels=64,
        dropout=0.2
    ).to(DEVICE)
    predictor = LinkPredictor(dim=64).to(DEVICE)


    optimizer = torch.optim.Adam(
    list(model.parameters()) + list(predictor.parameters()),
    lr=1e-3
)

    loss_fn = nn.BCEWithLogitsLoss()

    best_val_aupr = 0.0
    best_val_loss = None


    # -------- Training loop --------
    for epoch in range(1, 51):
        model.train()
        optimizer.zero_grad()

        emb = model(data.x_dict, data.edge_index_dict)
        circ_emb, dis_emb = emb["circRNA"], emb["disease"]

        logits = predictor(circ_emb[train_edges[0]],dis_emb[train_edges[1]])
        loss = loss_fn(logits, train_labels)


        loss.backward()
        for name, p in predictor.named_parameters():
            print(name, p.grad.norm().item())

        optimizer.step()

        # -------- Validation --------
        model.eval()
        predictor.eval()

        with torch.no_grad():
            emb = model(data.x_dict, data.edge_index_dict)
            circ_emb, dis_emb = emb["circRNA"], emb["disease"]

            val_logits = predictor(circ_emb[val_edges[0]],dis_emb[val_edges[1]])
            val_scores = torch.sigmoid(val_logits).cpu().numpy()
            val_true   = val_labels.cpu().numpy()

            auc = roc_auc_score(val_true, val_scores)
            aupr = average_precision_score(val_true, val_scores)
        print(
        f"Epoch {epoch:03d} | "
        f"Loss {loss.item():.4f} | "
        f"Val AUC {auc:.4f} | "
        f"Val AUPR {aupr:.4f}"
        )

        if aupr > best_val_aupr:
            best_val_aupr = aupr
            best_val_loss = loss.item()
            torch.save(
                {
                    "model": model.state_dict(),
                    "predictor": predictor.state_dict()
                },
                OUT_DIR / f"sage_best_model_seed{seed}.pth"
            )

            print("   → Saved best model")

    # -------- Test evaluation --------
    ckpt = torch.load(
        OUT_DIR / f"sage_best_model_seed{seed}.pth",
        map_location=DEVICE
    )
    model.load_state_dict(ckpt["model"])
    predictor.load_state_dict(ckpt["predictor"])

    model.eval()
    predictor.eval()


    with torch.no_grad():
        emb = model(data.x_dict, data.edge_index_dict)
        circ_emb, dis_emb = emb["circRNA"], emb["disease"]

        test_logits = predictor(
            circ_emb[test_edges[0]],
            dis_emb[test_edges[1]]
        )
        test_scores = torch.sigmoid(test_logits).cpu().numpy()

        test_true   = test_labels.cpu().numpy()

        test_auc  = roc_auc_score(test_true, test_scores)
        test_aupr = average_precision_score(test_true, test_scores)

    print(f"SEED {seed} | Test AUC {test_auc:.4f} | Test AUPR {test_aupr:.4f}")

    results.append((seed, best_val_loss, test_auc, test_aupr))



Running experiment with SEED = 42
3Layer
mlp.0.weight 0.1529686003923416
mlp.0.bias 0.135014146566391
mlp.2.weight 0.15002262592315674
mlp.2.bias 0.4126432240009308
3Layer
Epoch 001 | Loss 0.6654 | Val AUC 0.8741 | Val AUPR 0.4865
   → Saved best model
3Layer
mlp.0.weight 0.13129237294197083
mlp.0.bias 0.11336798965930939
mlp.2.weight 0.15494771301746368
mlp.2.bias 0.39927923679351807
3Layer
Epoch 002 | Loss 0.6433 | Val AUC 0.8798 | Val AUPR 0.4997
   → Saved best model
3Layer
mlp.0.weight 0.12306120246648788
mlp.0.bias 0.1040099561214447
mlp.2.weight 0.1810811311006546
mlp.2.bias 0.3890114426612854
3Layer
Epoch 003 | Loss 0.6268 | Val AUC 0.8869 | Val AUPR 0.6146
   → Saved best model
3Layer
mlp.0.weight 0.12400288134813309
mlp.0.bias 0.10271754115819931
mlp.2.weight 0.20682412385940552
mlp.2.bias 0.3810746669769287
3Layer
Epoch 004 | Loss 0.6144 | Val AUC 0.8947 | Val AUPR 0.6526
   → Saved best model
3Layer
mlp.0.weight 0.13091687858104706
mlp.0.bias 0.10637342929840088
mlp.2.weig

In [11]:
results_df = pd.DataFrame(
    results,
    columns=["seed", "best_val_loss", "auc", "aupr"]
)

print("\n===== FINAL RESULTS =====")
print(results_df)
print("\nMean ± Std")
print(
    results_df[["best_val_loss", "auc", "aupr"]]
    .agg(["mean", "std"])
)



===== FINAL RESULTS =====
   seed  best_val_loss       auc      aupr
0    42       0.364543  0.898877  0.587469
1    43       0.464123  0.905987  0.633714
2    44       0.255183  0.883161  0.555311
3    45       0.256997  0.903421  0.573914
4    46       0.299408  0.893817  0.517792

Mean ± Std
      best_val_loss       auc      aupr
mean       0.328051  0.897053  0.573640
std        0.088071  0.009044  0.042588
