In [10]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt


In [11]:
def ensure_dir(path: Path):
    path.mkdir(parents=True, exist_ok=True)

def load_yaml(path: Path):
    import yaml
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def load_json(path: Path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def find_project_root(start: Path = None) -> Path:
    if start is None:
        start = Path.cwd()
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data").exists() and (p / "output").exists():
            return p
    # fallback: at least code+data
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data").exists():
            return p
    return start


In [12]:
project_root = find_project_root()
project_root


WindowsPath('D:/Shiraz University/HomeWorks/Ostad Moosavi/LinkPrediction')

In [13]:
config_path = project_root / "code" / "config.yaml"
cfg = load_yaml(config_path)

proc_dir = project_root / cfg["data"]["processed_dir"]
out_dir  = project_root / cfg["output"]["dir"]

fig_dir = out_dir / "figures"
log_dir = out_dir / "logs"
metrics_dir = out_dir / "metrics"

ensure_dir(fig_dir)

print("proc_dir:", proc_dir)
print("out_dir :", out_dir)
print("fig_dir :", fig_dir)
print("log_dir :", log_dir, "| exists:", log_dir.exists())
print("metrics :", metrics_dir, "| exists:", metrics_dir.exists())


proc_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
out_dir : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output
fig_dir : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures
log_dir : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\logs | exists: True
metrics : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\metrics | exists: True


In [14]:
# --- 1) Dataset: Top relations by count (from graph_meta.json) ---
meta_path = proc_dir / "graph_meta.json"
id2rel_path = proc_dir / "id2relation.json"

if not meta_path.exists():
    raise FileNotFoundError(f"graph_meta.json not found: {meta_path}")
if not id2rel_path.exists():
    raise FileNotFoundError(f"id2relation.json not found: {id2rel_path}")

meta = load_json(meta_path)
id2relation = load_json(id2rel_path)

top = meta.get("top_relations", [])
if len(top) == 0:
    print("No top_relations found in graph_meta.json -> skip plot.")
else:
    # top: list of [rid, count]
    labels = []
    counts = []
    for rid, c in top[:20]:
        rid_str = str(rid)
        lab = id2relation.get(rid_str, f"rel_{rid}")
        labels.append(lab if len(lab) <= 45 else (lab[:45] + "..."))
        counts.append(int(c))

    plt.figure(figsize=(10, 5))
    plt.bar(range(len(counts)), counts)
    plt.xticks(range(len(counts)), labels, rotation=70, ha="right")
    plt.title("Top Relations by Count")
    plt.tight_layout()
    out_path = fig_dir / "dataset_rel_counts.png"
    plt.savefig(out_path, dpi=200)
    plt.close()

    # also save a small csv (useful for report)
    df_rel = pd.DataFrame({"relation": labels, "count": counts})
    df_rel.to_csv(fig_dir / "dataset_rel_counts_top20.csv", index=False)

    print("Saved:", out_path)
    print("Saved:", fig_dir / "dataset_rel_counts_top20.csv")


Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\dataset_rel_counts.png
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\dataset_rel_counts_top20.csv


In [15]:
# --- 2) Dataset: Degree histogram ---
g_path = proc_dir / "graph_edges.pt"
if not g_path.exists():
    raise FileNotFoundError(f"graph_edges.pt not found: {g_path}")

g = torch.load(g_path, map_location="cpu")
edge_index = g["edge_index"]
num_nodes = int(g["num_nodes"])

# degree (in+out) for directed graph
deg = torch.zeros(num_nodes, dtype=torch.long)
ones = torch.ones(edge_index.size(1), dtype=torch.long)
deg.scatter_add_(0, edge_index[0], ones)
deg.scatter_add_(0, edge_index[1], ones)

deg_np = deg.numpy()

plt.figure(figsize=(7, 4))
plt.hist(deg_np, bins=50)
plt.title("Degree Histogram (in+out)")
plt.xlabel("degree")
plt.ylabel("count")
plt.tight_layout()
out_path = fig_dir / "dataset_degree_hist.png"
plt.savefig(out_path, dpi=200)
plt.close()

# optional: save summary stats
stats = {
    "num_nodes": int(num_nodes),
    "num_edges": int(edge_index.size(1)),
    "deg_min": int(deg.min().item()),
    "deg_max": int(deg.max().item()),
    "deg_mean": float(deg.float().mean().item()),
    "deg_median": float(np.median(deg_np)),
    "deg_p95": float(np.percentile(deg_np, 95)),
}
with open(fig_dir / "dataset_degree_stats.json", "w", encoding="utf-8") as f:
    json.dump(stats, f, indent=2, ensure_ascii=False)

print("Saved:", out_path)
print("Saved:", fig_dir / "dataset_degree_stats.json")
print(stats)


Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\dataset_degree_hist.png
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\dataset_degree_stats.json
{'num_nodes': 37614, 'num_edges': 118308, 'deg_min': 1, 'deg_max': 373, 'deg_mean': 6.2906365394592285, 'deg_median': 2.0, 'deg_p95': 24.0}


In [16]:
# --- 3) Training curves (loss) from output/logs/*_train.csv ---
def load_log(name: str):
    p = log_dir / f"{name}_train.csv"
    if p.exists():
        return pd.read_csv(p), p
    return None, p

logs = {}
paths = {}
for m in ["gat", "rgcn", "rgat"]:
    df, p = load_log(m)
    logs[m] = df
    paths[m] = p

for m in ["gat", "rgcn", "rgat"]:
    print(m, "->", "FOUND" if logs[m] is not None else "MISSING", "|", paths[m])

plt.figure(figsize=(7, 4))
has_any = False
for m, df in logs.items():
    if df is None:
        continue
    if "epoch" in df.columns and "loss" in df.columns:
        plt.plot(df["epoch"], df["loss"], label=m)
        has_any = True

if has_any:
    plt.legend()
    plt.title("Training Loss Curves")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.tight_layout()
    out_path = fig_dir / "train_loss_curves.png"
    plt.savefig(out_path, dpi=200)
    plt.close()
    print("Saved:", out_path)
else:
    plt.close()
    print("No training logs found (or missing columns epoch/loss). Skip loss plot.")


gat -> FOUND | D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\logs\gat_train.csv
rgcn -> FOUND | D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\logs\rgcn_train.csv
rgat -> FOUND | D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\logs\rgat_train.csv
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\train_loss_curves.png


In [17]:
# --- 4) Model comparison plot (from output/metrics/comparison.csv) ---
comp_path = metrics_dir / "comparison.csv"
if not comp_path.exists():
    print("comparison.csv not found -> skip model comparison plot:", comp_path)
else:
    comp = pd.read_csv(comp_path)
    print("comparison columns:", list(comp.columns))

    # try to find model column
    model_col = "model" if "model" in comp.columns else (comp.columns[0] if len(comp.columns) else None)
    if model_col is None:
        print("Could not determine model column -> skip.")
    else:
        # pick metrics that exist
        candidate_metrics = [
            "bin_roc_auc", "roc_auc", "pr_auc",
            "mrr", "hits@10", "hits@3", "hits@1",
            "accuracy@0"
        ]
        metrics = [m for m in candidate_metrics if m in comp.columns]
        metrics = metrics[:3]  # keep plot readable

        if len(metrics) == 0:
            print("No known metric columns found in comparison.csv -> skip.")
        else:
            x = np.arange(len(comp[model_col]))
            width = 0.25 if len(metrics) >= 3 else (0.35 if len(metrics) == 2 else 0.5)

            plt.figure(figsize=(8, 4))
            for i, met in enumerate(metrics):
                vals = comp[met].values
                plt.bar(x + i * width, vals, width=width, label=met)

            plt.xticks(x + width * (len(metrics)-1) / 2.0, comp[model_col].values)
            plt.legend()
            plt.title("Model Comparison (Test)")
            plt.tight_layout()
            out_path = fig_dir / "model_comparison.png"
            plt.savefig(out_path, dpi=200)
            plt.close()
            print("Saved:", out_path)

            # also save a cleaned table for report
            comp[[model_col] + metrics].to_csv(fig_dir / "model_comparison_table.csv", index=False)
            print("Saved:", fig_dir / "model_comparison_table.csv")


comparison columns: ['model', 'bin_roc_auc', 'bin_pr_auc', 'bin_accuracy@0', 'mrr', 'hits@1', 'hits@3', 'hits@10', 'mean_rank', 'num_test', 'num_negs_per_pos']
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\model_comparison.png
Saved: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures\model_comparison_table.csv


In [18]:
print("[Figures] Done. Check:", fig_dir)
print("Files in figures:")
for p in sorted(fig_dir.glob("*")):
    print(" -", p.name)


[Figures] Done. Check: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\figures
Files in figures:
 - attention_by_relation.png
 - dataset_degree_hist.png
 - dataset_degree_stats.json
 - dataset_rel_counts.png
 - dataset_rel_counts_top20.csv
 - model_comparison.png
 - model_comparison_table.csv
 - train_loss_curves.png
