In [9]:
from pathlib import Path
from collections import deque
import os, json, difflib, time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F


In [10]:
def ensure_dir(path: Path):
    path.mkdir(parents=True, exist_ok=True)

def load_yaml(path: Path):
    import yaml
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def load_json(path: Path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def find_project_root(start: Path = None) -> Path:
    if start is None:
        start = Path.cwd()
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data").exists() and (p / "output").exists():
            return p
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data").exists():
            return p
    return start


In [11]:
project_root = find_project_root()
os.chdir(project_root)
print("CWD:", Path.cwd())

config_path = project_root / "code" / "config.yaml"
cfg = load_yaml(config_path)

proc_dir = project_root / cfg["data"]["processed_dir"]
out_dir  = project_root / cfg["output"]["dir"]

models_dir  = out_dir / "models"
queries_dir = out_dir / "queries"
ensure_dir(queries_dir)

print("proc_dir  :", proc_dir)
print("out_dir   :", out_dir)
print("models_dir:", models_dir, "| exists:", models_dir.exists())
print("queries_dir:", queries_dir)


CWD: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction
proc_dir  : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
out_dir   : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output
models_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\models | exists: True
queries_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\queries


In [12]:
g_path = proc_dir / "graph_edges.pt"
if not g_path.exists():
    raise FileNotFoundError(f"graph_edges.pt not found: {g_path}")

g = torch.load(g_path, map_location="cpu")
edge_index = g["edge_index"]   # [2, E]
edge_type  = g["edge_type"]    # [E]
num_nodes  = int(g["num_nodes"])
num_relations = int(g["num_relations"])

print("num_nodes:", num_nodes, " | num_relations:", num_relations, " | E:", edge_index.shape[1])

id2entity   = load_json(proc_dir / "id2entity.json")
id2relation = load_json(proc_dir / "id2relation.json")

entity2id = {v: int(k) for k, v in id2entity.items()}

print("sample entity:", list(entity2id.items())[:3])
print("sample relation:", list(id2relation.items())[:3])


num_nodes: 37614  | num_relations: 107  | E: 118308
sample entity: [('Gene::3313', 0), ('Gene::5521', 1), ('Compound::DB11767', 2)]
sample relation: [('0', 'Hetionet::GiG::Gene:Gene'), ('1', 'DRUGBANK::ddi-interactor-in::Compound:Compound'), ('2', 'GNBR::Ud::Gene:Disease')]


In [13]:
keep_path = proc_dir / "train_graph_edge_idx.npy"
if keep_path.exists():
    keep_idx = np.load(keep_path).astype(np.int64)
    keep_idx = keep_idx[keep_idx < edge_index.shape[1]]
    ei_train = edge_index[:, keep_idx]
    et_train = edge_type[keep_idx]
    print("Using TRAIN graph:", ei_train.shape, et_train.shape)
else:
    ei_train = edge_index
    et_train = edge_type
    print("Using FULL graph:", ei_train.shape, et_train.shape)

ei_np = ei_train.numpy()
et_np = et_train.numpy()
adj = [[] for _ in range(num_nodes)]
for i in range(ei_np.shape[1]):
    u = int(ei_np[0, i]); v = int(ei_np[1, i]); r = int(et_np[i])
    adj[u].append((v, r))

print("Adj built.")


Using TRAIN graph: torch.Size([2, 118233]) torch.Size([118233])
Adj built.


In [14]:
try:
    from torch_geometric.nn import GATConv, RGCNConv, RGATConv
except Exception as e:
    raise ImportError("torch_geometric is not installed. Install it first, then rerun.") from e


class MLPLinkScorer(nn.Module):
    """Flexible MLP scorer. in_mult is inferred from checkpoint (2 or 3 or 4)."""
    def __init__(self, dim: int, hidden: int = 128, dropout: float = 0.2, in_mult: int = 2):
        super().__init__()
        self.dim = int(dim)
        self.in_mult = int(in_mult)
        self.net = nn.Sequential(
            nn.Linear(self.dim * self.in_mult, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, z: torch.Tensor, heads: torch.Tensor, tails: torch.Tensor) -> torch.Tensor:
        h = z[heads]
        t = z[tails]
        if self.in_mult == 2:
            x = torch.cat([h, t], dim=1)
        elif self.in_mult == 3:
            x = torch.cat([h, t, h * t], dim=1)
        elif self.in_mult == 4:
            x = torch.cat([h, t, torch.abs(h - t), h * t], dim=1)
        else:
            x = torch.cat([h, t], dim=1)
        return self.net(x).view(-1)


class GATEncoder(nn.Module):
    def __init__(self, num_nodes, dim=64, heads=2, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout
        self.conv1 = GATConv(dim, dim, heads=heads, dropout=dropout, concat=True)
        self.conv2 = GATConv(dim * heads, dim, heads=1, dropout=dropout, concat=False)

    def forward(self, edge_index):
        x = self.emb.weight
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x


class RGCNEncoder(nn.Module):
    def __init__(self, num_nodes, num_relations, dim=128, dropout=0.2, num_bases=None):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout
        self.conv1 = RGCNConv(dim, dim, num_relations, num_bases=num_bases)
        self.conv2 = RGCNConv(dim, dim, num_relations, num_bases=num_bases)

    def forward(self, edge_index, edge_type):
        x = self.emb.weight
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x


class RGATEncoder(nn.Module):
    def __init__(self, num_nodes, num_relations, dim=32, heads=2, dropout=0.2, num_bases=8):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout
        self.conv1 = RGATConv(dim, dim, num_relations=num_relations, heads=heads, concat=True,
                              dropout=dropout, num_bases=num_bases)
        self.conv2 = RGATConv(dim * heads, dim, num_relations=num_relations, heads=1, concat=False,
                              dropout=dropout, num_bases=num_bases)

    def forward(self, edge_index, edge_type, return_attention=False):
        x = self.emb.weight
        if return_attention:
            x, att1 = self.conv1(x, edge_index, edge_type, return_attention_weights=True)
        else:
            x = self.conv1(x, edge_index, edge_type)
            att1 = None

        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        if return_attention:
            x, att2 = self.conv2(x, edge_index, edge_type, return_attention_weights=True)
            return x, att1, att2

        x = self.conv2(x, edge_index, edge_type)
        return x


In [15]:
def _rename_prefix(sd: dict, src: str, dst: str) -> dict:
    out = {}
    for k, v in sd.items():
        if k.startswith(src + "."):
            out[dst + k[len(src):]] = v
        else:
            out[k] = v
    return out

def _infer_scorer_in_mult(sd_sc: dict, dim: int):
    key = None
    for cand in ["net.0.weight", "mlp.0.weight"]:
        if cand in sd_sc:
            key = cand
            break
    if key is None:
        for k in sd_sc.keys():
            if k.endswith(".0.weight"):
                key = k
                break
    if key is None:
        return 2, 128
    in_features = int(sd_sc[key].shape[1])
    hidden = int(sd_sc[key].shape[0])
    in_mult = max(1, in_features // int(dim))
    return in_mult, hidden

def _infer_gat_heads(sd_enc: dict, dim: int):
    # PyG: conv1.att_src shape = [1, heads, out_channels]
    for k in ["conv1.att_src", "conv1.att_dst", "conv1.att_l", "conv1.att"]:
        if k in sd_enc:
            t = sd_enc[k]
            return int(t.shape[1])  # âœ… heads

    # fallback: bias length = heads*out_channels (concat=True)
    if "conv1.bias" in sd_enc:
        b = sd_enc["conv1.bias"]
        if b.numel() % int(dim) == 0:
            return int(b.numel() // int(dim))

    # fallback: lin rows = heads*out_channels
    for k in ["conv1.lin_src.weight", "conv1.lin.weight", "conv1.lin_l.weight"]:
        if k in sd_enc:
            w = sd_enc[k]
            if w.shape[0] % int(dim) == 0:
                return int(w.shape[0] // int(dim))

    return 2

def _infer_rgcn_num_bases(sd_enc: dict):
    if "conv1.comp" in sd_enc and "conv1.weight" in sd_enc:
        return int(sd_enc["conv1.weight"].shape[0])
    return None

def _infer_rgat_heads_bases(sd_enc: dict):
    heads = None
    if "conv1.q" in sd_enc:
        heads = int(sd_enc["conv1.q"].shape[1])
    num_bases = int(sd_enc["conv1.basis"].shape[0]) if "conv1.basis" in sd_enc else 8
    return heads, num_bases

def find_ckpt(models_dir: Path, model_name: str):
    p = models_dir / f"{model_name}.pt"
    if p.exists():
        return p
    cands = sorted(models_dir.glob(f"*{model_name}*.pt"))
    return cands[0] if cands else None

def load_model(model_name: str, cfg: dict, g: dict, models_dir: Path):
    ckpt_path = find_ckpt(models_dir, model_name)
    if ckpt_path is None:
        return None, None, None, None

    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd_enc = ckpt["encoder"]
    sd_sc  = ckpt["scorer"]

    dim = int(sd_enc["emb.weight"].shape[1])
    dropout = float(cfg["model"].get("dropout", 0.2))

    if any(k.startswith("mlp.") for k in sd_sc.keys()) and not any(k.startswith("net.") for k in sd_sc.keys()):
        sd_sc = _rename_prefix(sd_sc, "mlp", "net")

    in_mult, hidden = _infer_scorer_in_mult(sd_sc, dim)

    scorer = MLPLinkScorer(dim=dim, hidden=hidden, dropout=dropout, in_mult=in_mult)
    scorer.load_state_dict(sd_sc, strict=True)

    if model_name == "gat":
        heads = _infer_gat_heads(sd_enc, dim)
        encoder = GATEncoder(num_nodes=int(g["num_nodes"]), dim=dim, heads=heads, dropout=dropout)
        encoder.load_state_dict(sd_enc, strict=True)
        return encoder, scorer, ckpt_path, {"dim": dim, "heads": heads, "in_mult": in_mult}

    if model_name == "rgcn":
        num_bases = _infer_rgcn_num_bases(sd_enc)
        encoder = RGCNEncoder(num_nodes=int(g["num_nodes"]), num_relations=int(g["num_relations"]),
                              dim=dim, dropout=dropout, num_bases=num_bases)
        encoder.load_state_dict(sd_enc, strict=True)
        return encoder, scorer, ckpt_path, {"dim": dim, "num_bases": num_bases, "in_mult": in_mult}

    if model_name == "rgat":
        heads, num_bases = _infer_rgat_heads_bases(sd_enc)
        if heads is None:
            heads = int(cfg["model"].get("rgat_heads", 2))
        encoder = RGATEncoder(num_nodes=int(g["num_nodes"]), num_relations=int(g["num_relations"]),
                              dim=dim, heads=heads, dropout=dropout, num_bases=num_bases)
        encoder.load_state_dict(sd_enc, strict=True)
        return encoder, scorer, ckpt_path, {"dim": dim, "heads": heads, "num_bases": num_bases, "in_mult": in_mult}

    return None, None, None, None


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() and cfg["train"].get("use_cuda", True) else "cpu")
print("device:", device)

models_to_try = ["gat", "rgcn", "rgat"]
models = {}

for name in models_to_try:
    enc, sc, path, info = load_model(name, cfg, g, models_dir)
    if enc is None:
        print(f"[{name}] not found / load failed -> skip")
        continue
    enc.to(device).eval()
    sc.to(device).eval()
    models[name] = {"encoder": enc, "scorer": sc, "path": path, "info": info}
    print(f"[{name}] loaded:", path.name, "| info:", info)

if len(models) == 0:
    raise RuntimeError("No models loaded. Check output/models/*.pt")

ei_train_d = ei_train.to(device)
et_train_d = et_train.to(device)


device: cpu
[gat] loaded: gat.pt | info: {'dim': 128, 'heads': 4, 'in_mult': 3}
[rgcn] loaded: rgcn.pt | info: {'dim': 128, 'num_bases': 30, 'in_mult': 3}
[rgat] loaded: rgat.pt | info: {'dim': 32, 'heads': 2, 'num_bases': 8, 'in_mult': 3}


In [17]:
@torch.no_grad()
def compute_z(model_name: str, encoder, ei, et):
    if model_name == "gat":
        return encoder(ei)
    if model_name == "rgcn":
        return encoder(ei, et)
    if model_name == "rgat":
        return encoder(ei, et, return_attention=False)
    return None

Z = {}
for name, m in models.items():
    print("Computing embeddings for", name, "...")
    z = compute_z(name, m["encoder"], ei_train_d, et_train_d)
    Z[name] = z
    print(" z shape:", tuple(z.shape))
print("Embeddings ready.")


Computing embeddings for gat ...
 z shape: (37614, 128)
Computing embeddings for rgcn ...
 z shape: (37614, 128)
Computing embeddings for rgat ...
 z shape: (37614, 32)
Embeddings ready.


In [18]:
def resolve_entity(query: str, prefix: str = None, topn: int = 10):
    q = (query or "").strip()
    if not q:
        return []
    q_low = q.lower()

    if q in entity2id:
        return [q]

    if prefix is not None and not q.startswith(prefix):
        cand = prefix + q
        if cand in entity2id:
            return [cand]

    hits = []
    for e in entity2id.keys():
        if prefix is not None and not e.startswith(prefix):
            continue
        if q_low in e.lower():
            hits.append(e)
            if len(hits) >= topn:
                break
    if hits:
        return hits

    pool = [e for e in entity2id.keys() if (prefix is None or e.startswith(prefix))]
    close = difflib.get_close_matches(q, pool, n=topn, cutoff=0.6)
    return close

def pick_first_id(query: str, prefix: str):
    cands = resolve_entity(query, prefix=prefix, topn=10)
    if not cands:
        return None, []
    return entity2id[cands[0]], cands

COMPOUND_PREFIX = "Compound::"
DISEASE_PREFIX  = "Disease::"


In [19]:
compound_ids = [eid for e, eid in entity2id.items() if e.startswith(COMPOUND_PREFIX)]
disease_ids  = [eid for e, eid in entity2id.items() if e.startswith(DISEASE_PREFIX)]
print("num compounds:", len(compound_ids))
print("num diseases :", len(disease_ids))


num compounds: 7974
num diseases : 1867


In [20]:
@torch.no_grad()
def score_pairs(model_name: str, head_ids: np.ndarray, tail_ids: np.ndarray, batch: int = 4096):
    z = Z[model_name]
    scorer = models[model_name]["scorer"]

    head_ids = np.asarray(head_ids, dtype=np.int64)
    tail_ids = np.asarray(tail_ids, dtype=np.int64)

    out = []
    for s in range(0, len(head_ids), batch):
        h = torch.tensor(head_ids[s:s+batch], dtype=torch.long, device=device)
        t = torch.tensor(tail_ids[s:s+batch], dtype=torch.long, device=device)
        logits = scorer(z, h, t)
        out.append(logits.detach().cpu().numpy())
    logits = np.concatenate(out, axis=0)
    probs = 1.0 / (1.0 + np.exp(-logits))
    return logits, probs

def pretty_entity(eid: int):
    return id2entity.get(str(int(eid)), f"ent_{eid}")

def pretty_relation(rid: int):
    return id2relation.get(str(int(rid)), f"rel_{rid}")


In [21]:
def shortest_path_with_rel(src: int, dst: int, max_hops: int = 3):
    src = int(src); dst = int(dst)
    if src == dst:
        return []
    prev = {src: None}
    q = deque([(src, 0)])
    while q:
        u, d = q.popleft()
        if d >= max_hops:
            continue
        for v, r in adj[u]:
            if v in prev:
                continue
            prev[v] = (u, r)
            if v == dst:
                path = []
                cur = dst
                while cur != src:
                    pu, pr = prev[cur]
                    path.append((pu, pr, cur))
                    cur = pu
                path.reverse()
                return path
            q.append((v, d+1))
    return None

def explain_pair(src: int, dst: int, max_hops: int = 3):
    path = shortest_path_with_rel(src, dst, max_hops=max_hops)
    if path is None:
        print(f"No directed path found within {max_hops} hops.")
        return
    if len(path) == 0:
        print("Same node.")
        return
    print("Path:")
    for u, r, v in path:
        print(f"  {pretty_entity(u)}  --[{pretty_relation(r)}]-->  {pretty_entity(v)}")


In [22]:
def recommend_drugs_for_disease(disease_query: str, top_k: int = 20, model_name: str = "rgat"):
    if model_name not in models:
        raise ValueError(f"Model '{model_name}' not loaded. Available: {list(models.keys())}")

    disease_id, cands = pick_first_id(disease_query, DISEASE_PREFIX)
    if disease_id is None:
        print("Disease not found. suggestions:", resolve_entity(disease_query, DISEASE_PREFIX, topn=10))
        return None

    heads = np.array(compound_ids, dtype=np.int64)
    tails = np.full_like(heads, disease_id)

    _, probs = score_pairs(model_name, heads, tails, batch=4096)

    idx = np.argsort(-probs)[:top_k]
    top_heads = heads[idx]
    top_probs = probs[idx]

    df = pd.DataFrame({
        "compound_id": top_heads,
        "compound": [pretty_entity(i) for i in top_heads],
        "disease_id": int(disease_id),
        "disease": pretty_entity(disease_id),
        f"prob_{model_name}": top_probs
    })

    stamp = time.strftime("%Y%m%d_%H%M%S")
    out_csv = queries_dir / f"recommend_{model_name}_{stamp}.csv"
    df.to_csv(out_csv, index=False)

    print("Disease resolved to:", pretty_entity(disease_id))
    print("Saved:", out_csv)
    return df


In [23]:
def test_drug_disease(drug_query: str, disease_query: str, max_hops: int = 3):
    drug_id, drug_cands = pick_first_id(drug_query, COMPOUND_PREFIX)
    if drug_id is None:
        print("Drug/Compound not found. suggestions:", resolve_entity(drug_query, COMPOUND_PREFIX, topn=10))
        return None

    disease_id, dis_cands = pick_first_id(disease_query, DISEASE_PREFIX)
    if disease_id is None:
        print("Disease not found. suggestions:", resolve_entity(disease_query, DISEASE_PREFIX, topn=10))
        return None

    rows = []
    for name in models.keys():
        head = np.array([drug_id], dtype=np.int64)
        tail = np.array([disease_id], dtype=np.int64)
        logits, probs = score_pairs(name, head, tail, batch=1)
        rows.append({
            "model": name,
            "logit": float(logits[0]),
            "prob": float(probs[0]),
        })

    df = pd.DataFrame(rows).sort_values("prob", ascending=False).reset_index(drop=True)

    print("Drug   :", pretty_entity(drug_id))
    print("Disease:", pretty_entity(disease_id))
    display(df)

    print("\n--- Path explanation (graph) ---")
    explain_pair(drug_id, disease_id, max_hops=max_hops)

    stamp = time.strftime("%Y%m%d_%H%M%S")
    out_json = queries_dir / f"pair_{stamp}.json"
    payload = {
        "drug": {"query": drug_query, "id": int(drug_id), "entity": pretty_entity(drug_id), "candidates": drug_cands},
        "disease": {"query": disease_query, "id": int(disease_id), "entity": pretty_entity(disease_id), "candidates": dis_cands},
        "scores": rows,
        "max_hops": int(max_hops)
    }
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print("Saved:", out_json)
    return df


In [24]:
# --- Example 1: Recommend drugs for a disease ---
df = recommend_drugs_for_disease("Disease::MESH:D014774", top_k=20, model_name="rgat")
df

# --- Example 2: Test one drug-disease pair on all available models + show a path ---
res = test_drug_disease("Compound::DB01234", "Disease::MESH:D014774", max_hops=3)
res


Disease resolved to: Disease::MESH:D017674
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\queries\recommend_rgat_20260218_115149.csv
Drug   : Compound::DB01234
Disease: Disease::MESH:D017674


Unnamed: 0,model,logit,prob
0,gat,-3.338781,0.03426448
1,rgat,-7.606811,0.0004968077
2,rgcn,-18.566849,8.640125e-09



--- Path explanation (graph) ---
No directed path found within 3 hops.
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\queries\pair_20260218_115149.json


Unnamed: 0,model,logit,prob
0,gat,-3.338781,0.03426448
1,rgat,-7.606811,0.0004968077
2,rgcn,-18.566849,8.640125e-09
