In [1]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F


In [1]:
from torch_geometric.nn import GATConv, RGCNConv

# RGATConv ممکنه در بعضی نسخه‌ها نباشه
try:
    from torch_geometric.nn import RGATConv
    HAS_RGAT = True
except Exception as e:
    HAS_RGAT = False
    print("⚠️ cant find rgat")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def ensure_dir(path: Path):
    path.mkdir(parents=True, exist_ok=True)

def load_yaml(path: Path):
    import yaml
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def find_project_root(start: Path = None) -> Path:
    if start is None:
        start = Path.cwd()
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data").exists():
            return p
    return start

project_root = find_project_root()
project_root


WindowsPath('D:/Shiraz University/HomeWorks/Ostad Moosavi/LinkPrediction')

In [4]:
config_path = project_root / "code" / "config.yaml"
cfg = load_yaml(config_path)

proc_dir = project_root / cfg["data"]["processed_dir"]
out_dir  = project_root / cfg["output"]["dir"]

metrics_dir = out_dir / "metrics"
ensure_dir(metrics_dir)

print("proc_dir:", proc_dir)
print("out_dir :", out_dir)
print("metrics_dir:", metrics_dir)


proc_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
out_dir : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output
metrics_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\metrics


In [5]:
g = torch.load(proc_dir / "graph_edges.pt")
edge_index = g["edge_index"]
edge_type  = g["edge_type"]
num_nodes = int(g["num_nodes"])
num_relations = int(g["num_relations"])

keep_idx = np.load(proc_dir / "train_graph_edge_idx.npy")

splits = np.load(proc_dir / "split_target_edges.npz")
test_pos = splits["test_pos"]

negs = np.load(proc_dir / "negatives.npz", allow_pickle=True)
test_neg = negs["test_neg"]
K = int(negs["K"])

print("edge_index:", edge_index.shape)
print("edge_type :", edge_type.shape)
print("num_nodes:", num_nodes, "num_relations:", num_relations)
print("test_pos shape:", test_pos.shape)
print("test_neg shape:", test_neg.shape, "K:", K)


edge_index: torch.Size([2, 118308])
edge_type : torch.Size([118308])
num_nodes: 37614 num_relations: 107
test_pos shape: (2, 38)
test_neg shape: (2, 1900) K: 50


In [6]:
def filter_edges_by_idx(edge_index, edge_type, keep_idx):
    keep_idx_t = torch.tensor(keep_idx, dtype=torch.long)
    ei = edge_index[:, keep_idx_t]
    et = edge_type[keep_idx_t]
    return ei, et


In [7]:
def binary_metrics(y_true, y_score):
    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score).astype(float)

    out = {}
    try:
        from sklearn.metrics import roc_auc_score, average_precision_score
        out["roc_auc"] = float(roc_auc_score(y_true, y_score))
        out["pr_auc"]  = float(average_precision_score(y_true, y_score))
    except Exception:
        out["roc_auc"] = None
        out["pr_auc"]  = None

    # threshold=0 روی logits
    y_pred = (y_score >= 0).astype(int)
    out["accuracy@0"] = float((y_pred == y_true).mean())
    return out

def sampled_ranking_metrics(pos_scores, neg_scores, ks=(1,3,10)):
    """
    pos_scores: [N]
    neg_scores: [N, K]
    rank = 1 + count(neg > pos)
    """
    pos_scores = np.asarray(pos_scores).reshape(-1)
    neg_scores = np.asarray(neg_scores)
    N = len(pos_scores)
    K = neg_scores.shape[1]

    ranks = []
    for i in range(N):
        rank = 1 + int((neg_scores[i] > pos_scores[i]).sum())
        ranks.append(rank)

    ranks = np.array(ranks)
    out = {}
    out["mrr"] = float((1.0 / ranks).mean())
    for k in ks:
        out[f"hits@{k}"] = float((ranks <= k).mean())
    out["mean_rank"] = float(ranks.mean())
    out["num_test"] = int(N)
    out["num_negs_per_pos"] = int(K)
    return out


In [8]:
class GATEncoder(nn.Module):
    def __init__(self, num_nodes, dim, heads=4, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.conv1 = GATConv(dim, dim, heads=heads, dropout=dropout, concat=True)
        self.conv2 = GATConv(dim * heads, dim, heads=1, dropout=dropout, concat=False)
        self.dropout = dropout

    def forward(self, edge_index):
        x = self.emb.weight
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class RGCNEncoder(nn.Module):
    def __init__(self, num_nodes, num_relations, dim=128, dropout=0.2, num_bases=None):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout
        self.conv1 = RGCNConv(dim, dim, num_relations, num_bases=num_bases)
        self.conv2 = RGCNConv(dim, dim, num_relations, num_bases=num_bases)

    def forward(self, edge_index, edge_type):
        x = self.emb.weight
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x

class RGATEncoder(nn.Module):
    def __init__(self, num_nodes, num_relations, dim=32, heads=2, dropout=0.2, num_bases=8):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout

        self.conv1 = RGATConv(
            in_channels=dim,
            out_channels=dim,
            num_relations=num_relations,
            heads=heads,
            concat=True,
            dropout=dropout,
            num_bases=num_bases,   # ✅ خیلی مهم
        )

        self.conv2 = RGATConv(
            in_channels=dim * heads,
            out_channels=dim,
            num_relations=num_relations,
            heads=1,
            concat=False,
            dropout=dropout,
            num_bases=num_bases,   # ✅ خیلی مهم
        )

    def forward(self, edge_index, edge_type):
        x = self.emb.weight
        x = self.conv1(x, edge_index, edge_type)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x

class MLPLinkScorer(nn.Module):
    def __init__(self, dim, hidden=128, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim * 3, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)
        )

    def forward(self, h_u, h_v):
        x = torch.cat([h_u, h_v, h_u * h_v], dim=-1)
        return self.net(x).squeeze(-1)

def batch_score(node_emb, heads, tails, scorer):
    return scorer(node_emb[heads], node_emb[tails])


In [9]:
import torch

def _infer_gat_heads_from_state_dict(sd, dim):
    # PyG GATConv معمولاً lin_src.weight با شکل [heads*out_channels, in_channels] داره
    # در encoder ما out_channels = dim و in_channels = dim
    for key in ["conv1.lin_src.weight", "conv1.lin.weight", "conv1.lin_l.weight"]:
        if key in sd:
            h = sd[key].shape[0] // dim
            return int(h)
    # fallback
    return 2

def _infer_rgcn_num_bases_from_state_dict(sd):
    # ✅ RGCNConv وقتی num_bases داشته باشه، کلید comp ایجاد میشه
    if "conv1.comp" in sd and "conv1.weight" in sd:
        return int(sd["conv1.weight"].shape[0])  # مثلا 30
    return None

def _infer_rgat_num_bases_from_state_dict(sd):
    return int(sd["conv1.basis"].shape[0]) if "conv1.basis" in sd else 8

def load_model(model_name, cfg, g, out_dir):
    ckpt_path = out_dir / "models" / f"{model_name}.pt"
    if not ckpt_path.exists():
        return None, None

    ckpt = torch.load(ckpt_path, map_location="cpu")
    enc_sd = ckpt["encoder"]
    sc_sd  = ckpt["scorer"]

    # dim را همیشه از emb.weight می‌گیریم (قابل اعتمادترین)
    if "emb.weight" not in enc_sd:
        print(f"[{model_name}] emb.weight not found in checkpoint -> skip")
        return None, None
    dim_ckpt = int(enc_sd["emb.weight"].shape[1])

    dropout = float(cfg["model"].get("dropout", 0.2))
    hidden  = int(ckpt.get("cfg", cfg).get("model", {}).get("mlp_hidden", cfg["model"].get("mlp_hidden", 128)))

    # ---------------- GAT ----------------
    if model_name == "gat":
        heads_ckpt = _infer_gat_heads_from_state_dict(enc_sd, dim_ckpt)

        encoder = GATEncoder(
            num_nodes=int(g["num_nodes"]),
            dim=dim_ckpt,
            heads=heads_ckpt,
            dropout=dropout
        )
        encoder.load_state_dict(enc_sd, strict=True)

        scorer = MLPLinkScorer(dim=dim_ckpt, hidden=hidden, dropout=dropout)
        scorer.load_state_dict(sc_sd, strict=True)
        return encoder, scorer

    # ---------------- RGCN ----------------
    if model_name == "rgcn":
        num_bases_ckpt = _infer_rgcn_num_bases_from_state_dict(enc_sd)

        # ✅ اگر چک‌پوینت decomposition داشته ولی گراف فعلی num_relations متفاوت باشه، همسان نیست
        if "conv1.comp" in enc_sd:
            if enc_sd["conv1.comp"].shape[0] != int(g["num_relations"]):
                print(f"[rgcn] relation mismatch: ckpt={enc_sd['conv1.comp'].shape[0]} graph={int(g['num_relations'])} -> skip")
                return None, None

        encoder = RGCNEncoder(
            num_nodes=int(g["num_nodes"]),
            num_relations=int(g["num_relations"]),
            dim=dim_ckpt,
            num_bases=num_bases_ckpt,   # ✅ این خط کلید حل مشکل است
            dropout=dropout
        )
        encoder.load_state_dict(enc_sd, strict=True)

        scorer = MLPLinkScorer(dim=dim_ckpt, hidden=hidden, dropout=dropout)
        scorer.load_state_dict(sc_sd, strict=True)
        return encoder, scorer
    # ---------------- RGAT ----------------
    if model_name == "rgat":
        # heads را از conv1.q می‌گیریم (در چک‌پوینت‌های RGAT ما جواب می‌دهد)
        if "conv1.q" not in enc_sd:
            print("[rgat] conv1.q not found -> skip")
            return None, None
        heads_ckpt = int(enc_sd["conv1.q"].shape[1])

        num_bases_ckpt = _infer_rgat_num_bases_from_state_dict(enc_sd)

        encoder = RGATEncoder(
            num_nodes=int(g["num_nodes"]),
            num_relations=int(g["num_relations"]),
            dim=dim_ckpt,
            heads=heads_ckpt,
            dropout=dropout,
            num_bases=num_bases_ckpt
        )
        encoder.load_state_dict(enc_sd, strict=True)

        scorer = MLPLinkScorer(dim=dim_ckpt, hidden=hidden, dropout=dropout)
        scorer.load_state_dict(sc_sd, strict=True)
        return encoder, scorer

    return None, None


In [10]:
from pathlib import Path
import os

def find_project_root(start=None):
    start = Path.cwd() if start is None else Path(start)
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data").exists() and (p / "output").exists():
            return p
    return start

project_root = find_project_root()
os.chdir(project_root)  # ✅ مهم: از این به بعد relative path ها درست می‌شوند

print("CWD:", Path.cwd())
print("rgcn exists?", (Path("output/models/rgcn.pt")).exists())
print("files:", list(Path("output/models").glob("*.pt")))


CWD: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction
rgcn exists? True
files: [WindowsPath('output/models/gat.pt'), WindowsPath('output/models/rgat.pt'), WindowsPath('output/models/rgcn.pt')]


In [11]:
def eval_one(model_name, encoder, scorer, g, keep_idx, test_pos, test_neg, K, device):
    # گراف پیام‌رسانی = فقط یال‌های keep_idx (بدون لیک test)
    if model_name == "gat":
        ei, _ = filter_edges_by_idx(g["edge_index"], torch.zeros(g["edge_index"].size(1)), keep_idx)
        ei = ei.to(device)
        et = None
    else:
        ei, et = filter_edges_by_idx(g["edge_index"], g["edge_type"], keep_idx)
        ei = ei.to(device)
        et = et.to(device)

    # reshape negatives: [2, N, K]
    N = test_pos.shape[1]
    test_neg_reshaped = test_neg.reshape(2, N, K)

    encoder.eval()
    scorer.eval()

    with torch.no_grad():
        if model_name == "gat":
            z = encoder(ei)
        else:
            z = encoder(ei, et)

        # positives
        pos_h = torch.tensor(test_pos[0], dtype=torch.long, device=device)
        pos_t = torch.tensor(test_pos[1], dtype=torch.long, device=device)
        pos_logits = batch_score(z, pos_h, pos_t, scorer).detach().cpu().numpy()

        # negatives (flatten)
        neg_h = torch.tensor(test_neg_reshaped[0].reshape(-1), dtype=torch.long, device=device)
        neg_t = torch.tensor(test_neg_reshaped[1].reshape(-1), dtype=torch.long, device=device)
        neg_logits = batch_score(z, neg_h, neg_t, scorer).detach().cpu().numpy()

        # binary metrics: pos vs all neg edges
        y_true = np.concatenate([np.ones_like(pos_logits), np.zeros_like(neg_logits)])
        y_score = np.concatenate([pos_logits, neg_logits])
        m_bin = binary_metrics(y_true, y_score)

        # ranking metrics: compare pos with its K negatives
        neg_logits_mat = neg_logits.reshape(N, K)
        m_rank = sampled_ranking_metrics(pos_logits, neg_logits_mat, ks=(1,3,10))

    out = {"model": model_name}
    out.update({f"bin_{k}": v for k, v in m_bin.items()})
    out.update(m_rank)
    return out


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() and cfg["train"].get("use_cuda", True) else "cpu")
print("device:", device)

results = []
model_names = ["gat", "rgcn", "rgat"]

for name in model_names:
    encoder, scorer = load_model(name, cfg, g, out_dir)
    if encoder is None:
        print(f"Skip {name} (not available or load failed).")
        continue

    encoder.to(device)
    scorer.to(device)

    res = eval_one(name, encoder, scorer, g, keep_idx, test_pos, test_neg, K, device)
    results.append(res)

    # save per-model json
    with open(metrics_dir / f"metrics_{name}.json", "w", encoding="utf-8") as f:
        json.dump(res, f, indent=2, ensure_ascii=False)

df = pd.DataFrame(results)
df.to_csv(metrics_dir / "comparison.csv", index=False)

df


device: cpu


Unnamed: 0,model,bin_roc_auc,bin_pr_auc,bin_accuracy@0,mrr,hits@1,hits@3,hits@10,mean_rank,num_test,num_negs_per_pos
0,gat,0.575748,0.041689,0.853457,0.159405,0.052632,0.157895,0.315789,22.684211,38,50
1,rgcn,0.639806,0.04226,0.889577,0.17938,0.078947,0.157895,0.421053,17.131579,38,50
2,rgat,0.73232,0.052569,0.849845,0.190849,0.052632,0.157895,0.605263,14.763158,38,50


In [13]:
cols = ["model", "bin_roc_auc", "bin_pr_auc", "mrr", "hits@10"]
display(df[cols])
print("Saved:", metrics_dir / "comparison.csv")


Unnamed: 0,model,bin_roc_auc,bin_pr_auc,mrr,hits@10
0,gat,0.575748,0.041689,0.159405,0.315789
1,rgcn,0.639806,0.04226,0.17938,0.421053
2,rgat,0.73232,0.052569,0.190849,0.605263


Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\metrics\comparison.csv
