In [5]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


In [6]:
# torch_geometric
from torch_geometric.nn import RGATConv
from torch_geometric.utils import k_hop_subgraph

import networkx as nx


In [7]:
def ensure_dir(path: Path) -> Path:
    """Create ``path`` as a directory, including missing parents, if it is absent.

    Args:
        path: Directory to create. A plain string or other path-like value is
            accepted as well and is converted to ``Path``.

    Returns:
        The directory as a ``Path`` so the call can be chained or assigned.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    directory = Path(path)
    # exist_ok=True makes the call idempotent, so re-running the cell is safe.
    directory.mkdir(parents=True, exist_ok=True)
    return directory

def load_yaml(path: Path):
    """Read the UTF-8 file at ``path`` and return what ``yaml.safe_load`` parses from it."""
    # Imported inside the function, so PyYAML is only needed when this helper is called.
    import yaml

    with open(path, mode="r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def load_json(path: Path):
    """Return the object decoded from the UTF-8 JSON file at ``path``."""
    with open(path, mode="r", encoding="utf-8") as stream:
        payload = json.load(stream)
    return payload

def save_json(obj, path: Path):
    """Write ``obj`` to ``path`` as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) and the
    output is indented by two spaces.

    Args:
        obj: JSON-serializable object.
        path: Destination file; it is overwritten if it exists.

    Raises:
        TypeError: If ``obj`` contains a value ``json`` cannot serialize. The
            destination file is left untouched in that case.
    """
    # Serialize before opening: opening with "w" truncates immediately, so an
    # encoding error raised afterwards would leave an empty or partial file behind.
    text = json.dumps(obj, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    When CUDA is available, every CUDA generator is seeded as well.
    """
    import random

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def find_project_root(start: Path = None) -> Path:
    """Locate the nearest directory that contains ``code``, ``data`` and ``output``.

    The search begins at ``start`` (the current working directory when ``start``
    is ``None``) and walks up through its parents. If no directory holds all
    three entries, the starting directory itself is returned.
    """
    origin = Path.cwd() if start is None else start
    markers = ("code", "data", "output")
    candidates = (origin, *origin.parents)
    return next(
        (
            folder
            for folder in candidates
            if all((folder / marker).exists() for marker in markers)
        ),
        origin,
    )


In [8]:
# Walk up from the current working directory until the code/data/output layout is found.
project_root = find_project_root(start=Path.cwd())
project_root


WindowsPath('D:/Shiraz University/HomeWorks/Ostad Moosavi/LinkPrediction')

In [9]:
# Read the project configuration; every path below is resolved relative to the project root.
config_path = project_root / "code" / "config.yaml"
cfg = load_yaml(config_path)

proc_dir = project_root / cfg["data"]["processed_dir"]
out_dir = project_root / cfg["output"]["dir"]

# Output sub-folders for figures, case-study subgraphs and metric tables.
fig_dir, sub_dir, met_dir = (
    out_dir / leaf for leaf in ("figures", "subgraphs", "metrics")
)
for directory in (fig_dir, sub_dir, met_dir):
    ensure_dir(directory)

print("proc_dir:", proc_dir)
print("out_dir :", out_dir)


proc_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
out_dir : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output


In [10]:
# load graph + maps
# map_location="cpu" keeps this cell working on CPU-only machines even if the
# tensors were saved from a CUDA device; the training edges are moved to the
# chosen device later, together with the model.
g = torch.load(proc_dir / "graph_edges.pt", map_location="cpu")

edge_index = g["edge_index"]
edge_type  = g["edge_type"]
num_nodes  = int(g["num_nodes"])
num_relations = int(g["num_relations"])

# Edges and relation types are indexed in parallel below, so their lengths must agree.
assert edge_type.numel() == edge_index.size(1), (
    f"edge_type has {edge_type.numel()} entries but edge_index has {edge_index.size(1)} edges"
)

# id -> name lookups; keys are looked up as strings (str(id)) by the cells below.
id2entity = load_json(proc_dir / "id2entity.json")
id2relation = load_json(proc_dir / "id2relation.json")

print("num_nodes:", num_nodes)
print("num_relations:", num_relations)
print("edges:", edge_index.size(1))

# train-graph edges (remove val/test target edges)
keep_idx = np.load(proc_dir / "train_graph_edge_idx.npy")

def filter_edges_by_idx(edge_index, edge_type, keep_idx):
    """Select the edges at positions ``keep_idx``.

    Args:
        edge_index: Tensor indexed as ``edge_index[:, i]`` for edge ``i``.
        edge_type: Tensor indexed as ``edge_type[i]`` for edge ``i``.
        keep_idx: Positions of the edges to keep; converted to a long tensor.

    Returns:
        ``(edge_index[:, keep_idx], edge_type[keep_idx])``.
    """
    selector = torch.tensor(keep_idx, dtype=torch.long)
    kept_edges = edge_index[:, selector]
    kept_types = edge_type[selector]
    return kept_edges, kept_types

# Keep only the edges listed in keep_idx for message passing.
ei_train, et_train = filter_edges_by_idx(
    edge_index=edge_index,
    edge_type=edge_type,
    keep_idx=keep_idx,
)
print("train edges:", ei_train.shape[1])


num_nodes: 37614
num_relations: 107
edges: 118308
train edges: 118233


In [11]:
# models
class MLPLinkScorer(nn.Module):
    """Two-layer MLP that scores (head, tail) node pairs from their embeddings.

    The head and tail embeddings are concatenated and passed through
    Linear -> ReLU -> Dropout -> Linear; no activation follows the last layer.

    Args:
        dim: Size of a single node embedding (the MLP input is ``2 * dim``).
        hidden: Width of the hidden layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, dim: int, hidden: int = 128, dropout: float = 0.2):
        super().__init__()
        layers = [
            nn.Linear(2 * dim, hidden),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, z, heads, tails):
        """Return a flat tensor with one score per ``(heads[i], tails[i])`` pair.

        ``z`` is indexed along its first dimension by ``heads`` and ``tails``.
        """
        head_emb, tail_emb = z[heads], z[tails]
        pair_features = torch.cat((head_emb, tail_emb), dim=1)
        scores = self.mlp(pair_features)
        return scores.view(-1)

class RGATEncoder(nn.Module):
    """Two-layer RGAT encoder over a learned embedding table.

    The node features are the rows of ``self.emb`` (one per node). The first
    ``RGATConv`` uses ``heads`` attention heads with concatenated outputs
    (``dim * heads`` channels); the second uses a single head and maps back to
    ``dim`` channels.

    Args:
        num_nodes: Number of rows in the embedding table.
        num_relations: Number of relation types passed to both conv layers.
        dim: Embedding size and output size of the encoder.
        heads: Attention heads of the first layer.
        dropout: Dropout probability, used inside both conv layers and between them.
        num_bases: ``num_bases`` argument forwarded to both conv layers.
    """

    def __init__(self, num_nodes, num_relations, dim=32, heads=2, dropout=0.2, num_bases=8):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout

        # Settings shared by both relational attention layers.
        shared = dict(num_relations=num_relations, dropout=dropout, num_bases=num_bases)

        self.conv1 = RGATConv(
            in_channels=dim,
            out_channels=dim,
            heads=heads,
            concat=True,
            **shared,
        )
        self.conv2 = RGATConv(
            in_channels=dim * heads,
            out_channels=dim,
            heads=1,
            concat=False,
            **shared,
        )

    def forward(self, edge_index, edge_type, return_attention=False):
        """Encode all nodes.

        Returns the node representations, or ``(representations, att1, att2)``
        when ``return_attention`` is true, where ``att1``/``att2`` are whatever
        the first/second conv layer returns for ``return_attention_weights=True``.
        """
        h = self.emb.weight

        if not return_attention:
            h = self.conv1(h, edge_index, edge_type)
            h = F.dropout(F.elu(h), p=self.dropout, training=self.training)
            return self.conv2(h, edge_index, edge_type)

        h, att1 = self.conv1(h, edge_index, edge_type, return_attention_weights=True)
        h = F.dropout(F.elu(h), p=self.dropout, training=self.training)
        h, att2 = self.conv2(h, edge_index, edge_type, return_attention_weights=True)
        return h, att1, att2


In [14]:
# load rgat checkpoint (safe: infer dim/heads/bases from state_dict)
ckpt_path = out_dir / "models" / "rgat.pt"
if not ckpt_path.exists():
    raise FileNotFoundError(f"rgat checkpoint not found: {ckpt_path}")

# NOTE: torch.load unpickles the file; only load checkpoints produced by this project.
ckpt = torch.load(ckpt_path, map_location="cpu")
sd_enc = ckpt["encoder"]
# NOTE(review): scorer weights are presumably stored under "scorer" next to
# "encoder" -- confirm against the training notebook. Absent key -> None.
sd_scorer = ckpt.get("scorer")

# Model sizes come from the saved tensors where possible, else from the config.
dim_ckpt = int(sd_enc["emb.weight"].shape[1])
heads_ckpt = int(sd_enc["conv1.q"].shape[1]) if "conv1.q" in sd_enc else int(cfg["model"].get("rgat_heads", 2))
num_bases_ckpt = int(sd_enc["conv1.basis"].shape[0]) if "conv1.basis" in sd_enc else int(cfg["model"].get("rgat_num_bases", 8))

dropout = float(cfg["model"].get("dropout", 0.2))
if isinstance(sd_scorer, dict) and "mlp.0.weight" in sd_scorer:
    # The first linear layer's row count is the hidden width the scorer was trained with.
    hidden = int(sd_scorer["mlp.0.weight"].shape[0])
else:
    hidden = int(ckpt.get("cfg", cfg).get("model", {}).get("mlp_hidden", cfg["model"].get("mlp_hidden", 128)))

encoder = RGATEncoder(num_nodes=num_nodes,
                      num_relations=num_relations,
                      dim=dim_ckpt,
                      heads=heads_ckpt,
                      dropout=dropout,
                      num_bases=num_bases_ckpt)
scorer = MLPLinkScorer(dim=dim_ckpt, hidden=hidden, dropout=dropout)

encoder.load_state_dict(sd_enc, strict=True)

# Previously the scorer was created but never restored, leaving it with random
# weights. Restore it when the checkpoint carries a matching state dict.
if isinstance(sd_scorer, dict) and set(sd_scorer.keys()) == set(scorer.state_dict().keys()):
    scorer.load_state_dict(sd_scorer, strict=True)
else:
    print("note: no compatible 'scorer' weights in checkpoint -> scorer is randomly initialised")


device = torch.device("cuda" if torch.cuda.is_available() and cfg["train"].get("use_cuda", True) else "cpu")
encoder.to(device)
scorer.to(device)

ei_train = ei_train.to(device)
et_train = et_train.to(device)

# Inference only: switch off dropout in both modules.
encoder.eval()
scorer.eval()
print("loaded rgat:", ckpt_path.name, "| dim:", dim_ckpt, "| heads:", heads_ckpt, "| bases:", num_bases_ckpt, "| device:", device)


loaded rgat: rgat.pt | dim: 32 | heads: 2 | bases: 8 | device: cpu


In [15]:
# extract attention (layer1)
@torch.no_grad()
def get_edge_attention(encoder, edge_index, edge_type):
    """Return the first-layer attention as a flat NumPy array.

    ``encoder`` is called with ``return_attention=True`` and must return
    ``(z, att1, att2)`` where ``att1`` is a ``(edge_index_used, alpha)`` pair.
    A 2-D ``alpha`` is averaged over its second dimension; any other shape is
    flattened. Only ``att1`` is used.
    """
    _, layer1_att, _ = encoder(edge_index, edge_type, return_attention=True)
    _, alpha = layer1_att

    per_edge = alpha.mean(dim=1) if alpha.dim() == 2 else alpha.view(-1)
    return per_edge.detach().cpu().numpy()

w_all = get_edge_attention(encoder, ei_train, et_train)

# The cells below index w and et_train in parallel, so w must hold exactly one
# weight per training edge.
E = int(ei_train.size(1))
if len(w_all) != E:
    # Truncating/padding keeps the notebook running, but the weights may no
    # longer line up with the edges -- make that visible instead of silent.
    print(f"WARNING: got {len(w_all)} attention weights for {E} edges; "
          "truncating/padding, per-edge alignment is not guaranteed")
if len(w_all) >= E:
    w = w_all[:E]
else:
    # fallback: pad with the mean weight (0.0 when there are no weights at all)
    w = np.pad(w_all, (0, E - len(w_all)), constant_values=np.mean(w_all) if len(w_all) > 0 else 0.0)

et_cpu = et_train.detach().cpu().numpy()
print("w:", w.shape, "| et:", et_cpu.shape)
print("w stats:", float(np.min(w)), float(np.mean(w)), float(np.max(w)))


w: (118233,) | et: (118233,)
w stats: 0.0011619306169450283 0.25422683358192444 1.0


In [16]:
# relation-level attention stats + plot
topk = int(cfg.get("attention", {}).get("topk_relations_plot", 20))

# One row per edge; float64 so per-relation means are accumulated in double precision.
edge_att = pd.DataFrame({"relation_id": et_cpu, "attention": w.astype(np.float64)})

# Mean attention and edge count per relation (vectorised instead of a Python loop).
df_rel = (
    edge_att.groupby("relation_id", sort=False)["attention"]
    .agg(avg_attention="mean", count="size")
    .reset_index()
)
# Relation names are keyed by the stringified id; fall back to a synthetic name.
df_rel.insert(
    1,
    "relation_name",
    [id2relation.get(str(rid), f"rel_{rid}") for rid in df_rel["relation_id"].tolist()],
)
df_rel = df_rel.sort_values("avg_attention", ascending=False).reset_index(drop=True)

csv_path = met_dir / "attention_by_relation.csv"
df_rel.to_csv(csv_path, index=False, encoding="utf-8")
print("saved:", csv_path)

# plot topk
top = df_rel.head(topk).copy()
labels = top["relation_name"].tolist()
vals = top["avg_attention"].tolist()
# Shorten long relation names so the tick labels stay readable.
labels = [s if len(s) <= 45 else (s[:45] + "...") for s in labels]

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(range(len(vals)), vals)
ax.set_xticks(list(range(len(vals))))
ax.set_xticklabels(labels, rotation=70, ha="right")
ax.set_ylabel("Mean layer-1 attention")
ax.set_title("Average Attention by Relation (Top)")
fig.tight_layout()

fig_path = fig_dir / "attention_by_relation.png"
fig.savefig(fig_path, dpi=200)
plt.close(fig)
print("saved:", fig_path)


saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\metrics\attention_by_relation.csv
saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\attention_by_relation.png


In [17]:
# case studies: k-hop subgraphs (edge width ~ attention)
spl_path = proc_dir / "split_target_edges.npz"
if not spl_path.exists():
    print("split_target_edges.npz not found -> skip subgraphs")
else:
    spl = np.load(spl_path)
    test_pos = spl["test_pos"]  # [2, N]
    num_cases = int(cfg.get("attention", {}).get("num_cases", 3))
    k_hops = int(cfg.get("attention", {}).get("k_hops", 2))

    # Seed both the global RNGs and the local generator so the picked cases are reproducible.
    set_seed(int(cfg.get("seed", 42)))
    rng = np.random.default_rng(int(cfg.get("seed", 42)))

    if test_pos.shape[1] == 0:
        print("no test positives -> skip")
    else:
        picks = rng.choice(test_pos.shape[1], size=min(num_cases, test_pos.shape[1]), replace=False)

        # for k_hop_subgraph edge_mask, we need CPU edge_index
        ei_cpu = ei_train.detach().cpu()
        et_cpu_t = et_train.detach().cpu()
        # Loop-invariant: convert the relation ids to NumPy once, not once per case.
        et_all_np = et_cpu_t.numpy()

        for ci, j in enumerate(picks, start=1):
            src = int(test_pos[0, j])
            dst = int(test_pos[1, j])

            subset, sub_ei, mapping, edge_mask = k_hop_subgraph(
                [src, dst],
                num_hops=k_hops,
                edge_index=ei_cpu,
                relabel_nodes=True,
                num_nodes=num_nodes
            )

            # attention for subgraph edges (aligned with input edge list):
            # edge_mask selects exactly the columns that make up sub_ei.
            mask_np = edge_mask.detach().cpu().numpy()
            w_sub = w[mask_np]
            et_sub = et_all_np[mask_np]

            # build nx graph
            Gnx = nx.DiGraph()
            subset_np = subset.detach().cpu().numpy().tolist()
            sub_ei_np = sub_ei.detach().cpu().numpy()

            for new_id, old_id in enumerate(subset_np):
                lab = id2entity.get(str(old_id), str(old_id))
                if len(lab) > 35:
                    lab = lab[:35] + "..."
                Gnx.add_node(new_id, label=lab)

            for e in range(sub_ei_np.shape[1]):
                u = int(sub_ei_np[0, e]); v = int(sub_ei_np[1, e])
                att = float(w_sub[e])
                # A DiGraph stores a single edge per (u, v). When several relations
                # link the same pair, keep the one with the highest attention.
                if Gnx.has_edge(u, v) and Gnx[u][v]["weight"] >= att:
                    continue
                rid = int(et_sub[e])
                rel = id2relation.get(str(rid), f"r{rid}")
                Gnx.add_edge(u, v, weight=att, relation=rel)

            # Read the widths back in the graph's own edge order. Gnx.edges() iterates
            # by source node, not in insertion order, so a list built while inserting
            # would attach the widths to the wrong edges when drawing.
            edgelist = list(Gnx.edges())
            widths = np.array([Gnx[u][v]["weight"] for u, v in edgelist], dtype=float)
            if widths.size > 0:
                # Min-max rescale the attention to line widths in [0.7, 4.7].
                wmin, wmax = float(widths.min()), float(widths.max())
                widths = 0.7 + 4.0 * (widths - wmin) / (wmax - wmin + 1e-9)
            widths = widths.tolist()

            # Highlight the two endpoints of the test edge this case is built around;
            # mapping holds their relabelled ids in the subgraph.
            target_ids = set(mapping.detach().cpu().numpy().tolist())
            node_colors = ["tab:orange" if n in target_ids else "#1f78b4" for n in Gnx.nodes()]

            fig = plt.figure(figsize=(10, 7))
            pos = nx.spring_layout(Gnx, seed=1)

            nx.draw_networkx_nodes(Gnx, pos, node_size=400, node_color=node_colors)
            nx.draw_networkx_edges(Gnx, pos, edgelist=edgelist, width=widths, arrows=True, alpha=0.8)
            nx.draw_networkx_labels(Gnx, pos, labels={n: Gnx.nodes[n]["label"] for n in Gnx.nodes()}, font_size=7)

            plt.title(f"Case {ci}: {k_hops}-hop subgraph (edge width ~ attention)")
            plt.axis("off")
            plt.tight_layout()

            out_path = sub_dir / f"case_{ci:03d}.png"
            plt.savefig(out_path, dpi=200)
            plt.close(fig)
            print("saved subgraph:", out_path)


saved subgraph: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\subgraphs\case_001.png
saved subgraph: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\subgraphs\case_002.png
saved subgraph: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\subgraphs\case_003.png


In [18]:
# save a short summary json for report
# Paths are stored with forward slashes so the summary reads the same on every OS.
generated_files = {
    "figure_attention_by_relation": (fig_dir / "attention_by_relation.png").as_posix(),
    "csv_attention_by_relation": (met_dir / "attention_by_relation.csv").as_posix(),
}
summary = {
    "rgat_checkpoint": (out_dir / "models" / "rgat.pt").as_posix(),
    "dim": dim_ckpt,
    "heads": heads_ckpt,
    "num_bases": num_bases_ckpt,
    "topk": int(cfg.get("attention", {}).get("topk_relations_plot", 20)),
    "generated_files": generated_files,
}

summary_path = met_dir / "attention_summary.json"
save_json(summary, summary_path)
print("saved:", summary_path)


saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\metrics\attention_summary.json
