In [67]:
from pathlib import Path
import csv
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [68]:
from torch_geometric.nn import RGATConv

In [69]:
def ensure_dir(path: Path):
    """Create ``path`` as a directory.

    Missing parent directories are created as well, and an already existing
    directory is left untouched (no error is raised).
    """
    path.mkdir(exist_ok=True, parents=True)

def load_yaml(path: Path):
    """Parse the UTF-8 encoded YAML file at ``path`` and return its content.

    The file is read with ``yaml.safe_load``; ``yaml`` is imported lazily so
    that it is only required when this function is actually called.
    """
    import yaml

    with open(path, mode="r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def find_project_root(start: Path = None) -> Path:
    """Locate the project root by walking upwards from ``start``.

    The root is the first directory (``start`` itself, then each of its
    parents in turn) that contains both a ``code`` and a ``data`` entry.

    Args:
        start: Directory to begin the search at; the current working
            directory is used when it is ``None``.

    Returns:
        The first matching directory, or the starting directory itself when
        no ancestor qualifies.
    """
    origin = Path.cwd() if start is None else start
    for candidate in (origin, *origin.parents):
        has_code = (candidate / "code").exists()
        if has_code and (candidate / "data").exists():
            return candidate
    return origin

# Resolve the project root once (searching upwards from the current working
# directory); later cells join the config, data and output paths onto it.
project_root = find_project_root()
project_root


WindowsPath('D:/Shiraz University/HomeWorks/Ostad Moosavi/LinkPrediction')

In [70]:
def set_seed(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random``, NumPy's global generator, PyTorch (CPU and
    all CUDA devices) and the ``PYTHONHASHSEED`` environment variable, and
    switches cuDNN to deterministic mode with auto-tuning disabled.

    Args:
        seed: Seed value applied to all generators.
    """
    import os
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python / NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU first, then every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Favour reproducibility over cuDNN's benchmark-driven kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


In [71]:
# Read the YAML configuration stored at <project_root>/code/config.yaml.
config_path = project_root / "code" / "config.yaml"
cfg = load_yaml(config_path)

# Seed the RNGs before anything stochastic runs; 42 is used when the config
# has no "seed" key.
set_seed(cfg.get("seed", 42))

# Processed-data and output locations come from the config and are joined
# onto the project root.
proc_dir = project_root / cfg["data"]["processed_dir"]
out_dir  = project_root / cfg["output"]["dir"]

# Create the sub-directories that the model checkpoint and the CSV training
# log are written to in later cells.
ensure_dir(out_dir / "models")
ensure_dir(out_dir / "logs")

print("proc_dir:", proc_dir)
print("out_dir :", out_dir)


proc_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
out_dir : D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output


In [72]:
# Load the serialized graph dictionary; it provides the edge list, the
# per-edge type ids and the node / relation counts read below.
g = torch.load(proc_dir / "graph_edges.pt")
edge_index = g["edge_index"]
edge_type  = g["edge_type"]

num_nodes = int(g["num_nodes"])
num_relations = int(g["num_relations"])

# Edge indices used in a later cell to slice edge_index / edge_type down to
# the training graph.
keep_idx = np.load(proc_dir / "train_graph_edge_idx.npy")

# Positive target pairs per split; row 0 is used as heads and row 1 as tails
# by the training and evaluation code.
splits = np.load(proc_dir / "split_target_edges.npz")
train_pos = splits["train_pos"]
val_pos   = splits["val_pos"]

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code for untrusted files — confirm it is needed here.
negs = np.load(proc_dir / "negatives.npz", allow_pickle=True)
val_neg = negs["val_neg"]

print("edge_index:", edge_index.shape)
print("edge_type :", edge_type.shape)
print("num_nodes:", num_nodes, "num_relations:", num_relations)
print("train_pos:", train_pos.shape, "val_pos:", val_pos.shape)
print("val_neg :", val_neg.shape)


edge_index: torch.Size([2, 118308])
edge_type : torch.Size([118308])
num_nodes: 37614 num_relations: 107
train_pos: (2, 299) val_pos: (2, 37)
val_neg : (2, 1850)


In [73]:
# Convert the kept-edge indices to a long tensor so they can index tensors.
keep_idx_t = torch.tensor(keep_idx, dtype=torch.long)

# Training graph: only the selected edge columns and their matching types.
# This is the graph the encoder runs on for both training and validation.
ei_train = edge_index[:, keep_idx_t]
et_train = edge_type[keep_idx_t]

print("ei_train:", ei_train.shape)
print("et_train:", et_train.shape)


ei_train: torch.Size([2, 118233])
et_train: torch.Size([118233])


In [74]:
def build_pair_set(heads, tails):
    """Return the set of ``(head, tail)`` tuples formed by pairing the inputs.

    The two iterables are paired position by position; pairing stops at the
    end of the shorter one and duplicate pairs collapse into one entry.
    """
    return {(head, tail) for head, tail in zip(heads, tails)}

def sample_negatives_on_the_fly(batch_heads, tail_candidates, existing_pairs_set,
                               num_negs_per_pos=1, seed=0):
    """Draw corrupted-tail negatives for a batch of head nodes.

    For every head, ``num_negs_per_pos`` tails are drawn uniformly from
    ``tail_candidates``. Each draw is retried up to 30 times while the
    resulting ``(head, tail)`` pair is contained in ``existing_pairs_set``;
    if all retries collide, one more tail is drawn and accepted unchecked.

    Args:
        batch_heads: Iterable of head node ids.
        tail_candidates: Iterable of tail node ids to sample from.
        existing_pairs_set: Container of ``(head, tail)`` pairs to avoid.
        num_negs_per_pos: Number of negatives generated per head.
        seed: Seed for the private ``random.Random`` generator used here.

    Returns:
        Two parallel lists ``(neg_heads, neg_tails)``.
    """
    import random

    max_tries = 30
    generator = random.Random(seed)
    pool = list(tail_candidates)

    def draw_tail(head):
        # Rejection sampling with a bounded number of attempts.
        for _ in range(max_tries):
            candidate = generator.choice(pool)
            if (head, candidate) not in existing_pairs_set:
                return candidate
        # Every attempt hit a known pair: accept one last draw unchecked.
        return generator.choice(pool)

    neg_h, neg_t = [], []
    for head in batch_heads:
        for _ in range(num_negs_per_pos):
            tail = draw_tail(head)
            neg_h.append(head)
            neg_t.append(tail)
    return neg_h, neg_t

# Pool of tail nodes for training negatives: the distinct values in row 1
# (the tails) of the validation negatives.
tail_candidates = set(val_neg[1].tolist())
# Known positive (head, tail) training pairs that the negative sampler tries
# to avoid.
existing_pairs  = build_pair_set(train_pos[0], train_pos[1])

print("num tail candidates:", len(tail_candidates))
print("num existing train pairs:", len(existing_pairs))


num tail candidates: 1175
num existing train pairs: 299


In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGATConv

class RGATEncoder(nn.Module):
    """Two-layer relational GAT encoder on top of a learned node-embedding table.

    Node features are the rows of an ``nn.Embedding``; they are passed through
    two ``RGATConv`` layers with an ELU non-linearity and dropout in between.

    Args:
        num_nodes: Number of rows in the embedding table.
        num_relations: Number of relation types given to both conv layers.
        dim: Embedding size and ``out_channels`` of both conv layers.
        heads: Number of attention heads of the first layer (the second layer
            always uses one head and takes ``dim * heads`` input channels).
        dropout: Dropout probability used inside both conv layers and between
            them.
        num_bases: Forwarded to both ``RGATConv`` layers.
        num_blocks: Forwarded to both ``RGATConv`` layers.
    """

    def __init__(self, num_nodes, num_relations, dim=32, heads=1, dropout=0.2,
                 num_bases=8, num_blocks=None):
        super().__init__()
        self.emb = nn.Embedding(num_nodes, dim)
        self.dropout = dropout

        # Settings that are identical for both relational attention layers.
        shared = dict(
            out_channels=dim,
            num_relations=num_relations,
            dropout=dropout,
            num_bases=num_bases,     # important
            num_blocks=num_blocks,   # optional
        )

        # Layer 1: multi-head attention with concatenated head outputs.
        self.conv1 = RGATConv(in_channels=dim, heads=heads, concat=True, **shared)
        # Layer 2: single head, no concatenation.
        self.conv2 = RGATConv(in_channels=dim * heads, heads=1, concat=False, **shared)

    def forward(self, edge_index, edge_type):
        """Return the node representations computed over the given typed edges."""
        hidden = self.conv1(self.emb.weight, edge_index, edge_type)
        hidden = F.dropout(F.elu(hidden), p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index, edge_type)

class MLPLinkScorer(nn.Module):
    """MLP that turns a pair of node embeddings into one link logit.

    The input features are the concatenation of both embeddings and their
    element-wise product, hence the first linear layer takes ``3 * dim``
    inputs.

    Args:
        dim: Size of each node embedding.
        hidden: Width of the hidden layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, dim, hidden=128, dropout=0.2):
        super().__init__()
        layers = [
            nn.Linear(3 * dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, h_u, h_v):
        """Score each ``(h_u, h_v)`` pair; the trailing size-1 axis is squeezed away."""
        features = torch.cat((h_u, h_v, h_u * h_v), dim=-1)
        logits = self.net(features)
        return logits.squeeze(-1)

def batch_score(z, heads, tails, scorer):
    """Score node pairs by indexing ``z`` with ``heads`` / ``tails`` and calling ``scorer``.

    Args:
        z: Node representation container indexed by node id.
        heads: Index (e.g. a long tensor) selecting the head rows of ``z``.
        tails: Index selecting the tail rows of ``z``.
        scorer: Callable taking ``(head_embeddings, tail_embeddings)``.

    Returns:
        Whatever ``scorer`` returns for the selected embeddings.
    """
    head_emb = z[heads]
    tail_emb = z[tails]
    return scorer(head_emb, tail_emb)




In [76]:
def binary_metrics(y_true, y_score):
    """Compute binary classification metrics from labels and raw logits.

    Args:
        y_true: Array-like of 0/1 labels (cast to ``int``).
        y_score: Array-like of scores / logits (cast to ``float``).

    Returns:
        Dict with the keys ``"roc_auc"``, ``"pr_auc"`` and ``"accuracy@0"``.
        The two AUC entries are ``None`` when they cannot be computed (for
        example scikit-learn is not installed or the metric is undefined for
        the given labels); in that case a ``RuntimeWarning`` is emitted so the
        failure is visible instead of being silently swallowed.
    """
    import warnings

    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score).astype(float)

    out = {}
    try:
        from sklearn.metrics import roc_auc_score, average_precision_score
        out["roc_auc"] = float(roc_auc_score(y_true, y_score))
        out["pr_auc"]  = float(average_precision_score(y_true, y_score))
    except Exception as exc:
        # Best-effort: keep returning None for both metrics, but tell the
        # caller why, since downstream model selection relies on roc_auc.
        warnings.warn(
            f"binary_metrics: roc_auc / pr_auc unavailable "
            f"({type(exc).__name__}: {exc})",
            RuntimeWarning,
            stacklevel=2,
        )
        out["roc_auc"] = None
        out["pr_auc"]  = None

    y_pred = (y_score >= 0).astype(int)  # threshold at 0 because the scores are logits
    out["accuracy@0"] = float((y_pred == y_true).mean())
    return out


def eval_binary(encoder, scorer, edge_index, edge_type, pos, neg, device):
    """Evaluate the encoder/scorer pair on positive and negative node pairs.

    Both modules are switched to eval mode (and left in it), the encoder is
    run once on the given graph without gradient tracking, and the logits of
    all pairs are handed to ``binary_metrics`` with label 1 for ``pos`` and
    label 0 for ``neg``.

    Args:
        encoder: Module called as ``encoder(edge_index, edge_type)``.
        scorer: Module passed to ``batch_score`` together with the encodings.
        edge_index: Edge list given to the encoder.
        edge_type: Edge types given to the encoder.
        pos: Indexable whose items 0 / 1 hold the positive heads / tails.
        neg: Indexable whose items 0 / 1 hold the negative heads / tails.
        device: Device the index tensors are created on.

    Returns:
        The metrics dict produced by ``binary_metrics``.
    """
    encoder.eval()
    scorer.eval()

    def as_index(values):
        # Node ids must be long tensors on the model's device to index z.
        return torch.tensor(values, dtype=torch.long, device=device)

    with torch.no_grad():
        z = encoder(edge_index, edge_type)

        pos_h, pos_t = as_index(pos[0]), as_index(pos[1])
        neg_h, neg_t = as_index(neg[0]), as_index(neg[1])

        pos_logits = batch_score(z, pos_h, pos_t, scorer).detach().cpu().numpy()
        neg_logits = batch_score(z, neg_h, neg_t, scorer).detach().cpu().numpy()

    labels = np.concatenate([np.ones_like(pos_logits), np.zeros_like(neg_logits)])
    scores = np.concatenate([pos_logits, neg_logits])
    return binary_metrics(labels, scores)


In [77]:
# Model hyper-parameters. NOTE(review): dim and heads are hard-coded here
# while dropout and the MLP width are read from the config — confirm intended.
dim = 32
heads = 2
dropout = cfg["model"].get("dropout", 0.2)

encoder = RGATEncoder(num_nodes, num_relations, dim=dim, heads=heads, dropout=dropout)
scorer  = MLPLinkScorer(dim=dim, hidden=cfg["model"].get("mlp_hidden", 128), dropout=dropout)

# Use CUDA only when it is available and not disabled via train.use_cuda.
device = torch.device("cuda" if torch.cuda.is_available() and cfg["train"].get("use_cuda", True) else "cpu")
encoder.to(device)
scorer.to(device)

# The training graph has to live on the same device as the model.
ei_train = ei_train.to(device)
et_train = et_train.to(device)

print("device:", device)


device: cpu


In [78]:
# Training hyper-parameters, each with a default for a missing config key.
lr = cfg["train"].get("lr", 1e-3)
epochs = cfg["train"].get("epochs", 20)
batch_size = cfg["train"].get("batch_size", 2048)
neg_per_pos = cfg["train"].get("num_negs_per_pos_train", 1)

# A single Adam optimiser updates the encoder and the scorer jointly.
opt = torch.optim.Adam(list(encoder.parameters()) + list(scorer.parameters()), lr=lr)

# Start a fresh CSV log ("w" truncates an existing file) and write the header;
# the training loop appends one row per epoch.
log_path = out_dir / "logs" / "rgat_train.csv"
with open(log_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "loss", "val_roc_auc", "val_pr_auc", "val_acc@0"])

# Sentinel below any value the training loop can assign to val_auc from a
# computed ROC-AUC; the loop saves a checkpoint whenever it is exceeded.
best_auc = -1.0
best_path = out_dir / "models" / "rgat.pt"

print("log_path:", log_path)
print("best_path:", best_path)


log_path: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\logs\rgat_train.csv
best_path: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\output\models\rgat.pt


In [79]:
# Train encoder and scorer jointly with binary cross-entropy on positive pairs
# and freshly sampled corrupted-tail negatives, validate after every epoch,
# append the metrics to the CSV log and checkpoint the best model by ROC-AUC.
checkpoint_written = False  # becomes True once this run has saved weights to best_path

for epoch in range(1, epochs + 1):
    encoder.train()
    scorer.train()

    # Visit the positive training pairs in a new random order each epoch.
    idx = np.random.permutation(train_pos.shape[1])
    total_loss = 0.0
    n_batches = 0

    for start in range(0, len(idx), batch_size):
        batch_idx = idx[start:start+batch_size]

        bh = train_pos[0][batch_idx].tolist()
        bt = train_pos[1][batch_idx].tolist()

        # negatives (corrupt tail); the seed varies with epoch and batch offset
        nh, nt = sample_negatives_on_the_fly(
            bh, tail_candidates, existing_pairs,
            num_negs_per_pos=neg_per_pos,
            seed=cfg.get("seed", 42) + epoch + start
        )

        bh_t = torch.tensor(bh, dtype=torch.long, device=device)
        bt_t = torch.tensor(bt, dtype=torch.long, device=device)
        nh_t = torch.tensor(nh, dtype=torch.long, device=device)
        nt_t = torch.tensor(nt, dtype=torch.long, device=device)

        opt.zero_grad()

        # The encoder is re-run for every batch so gradients reach its weights.
        z = encoder(ei_train, et_train)

        pos_logits = batch_score(z, bh_t, bt_t, scorer)
        neg_logits = batch_score(z, nh_t, nt_t, scorer)

        y = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
        logits = torch.cat([pos_logits, neg_logits])

        loss = F.binary_cross_entropy_with_logits(logits, y)
        loss.backward()
        opt.step()

        total_loss += float(loss.item())
        n_batches += 1

    avg_loss = total_loss / max(1, n_batches)

    # validation
    val_m = eval_binary(encoder, scorer, ei_train, et_train, val_pos, val_neg, device)
    roc_auc = val_m.get("roc_auc")
    val_auc = roc_auc if roc_auc is not None else -1.0

    with open(log_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([epoch, avg_loss, roc_auc, val_m.get("pr_auc"), val_m.get("accuracy@0")])

    print(f"[R-GAT] epoch={epoch} loss={avg_loss:.4f} val_roc_auc={val_auc}")

    # save best
    improved = val_auc > best_auc
    # Without a usable ROC-AUC, val_auc (-1.0) can never beat the -1.0 sentinel
    # in best_auc, so nothing would ever be saved. While no scored checkpoint
    # exists (best_auc still negative), keep the latest weights instead.
    unscored_fallback = roc_auc is None and best_auc < 0
    if improved or unscored_fallback:
        if improved:
            best_auc = val_auc
        torch.save(
            {"encoder": encoder.state_dict(), "scorer": scorer.state_dict(), "cfg": cfg},
            best_path
        )
        checkpoint_written = True

# Only claim a saved model when this run actually wrote one.
if checkpoint_written:
    print(f"[R-GAT] Done. Best val roc_auc={best_auc:.4f}, saved: {best_path}")
else:
    print(f"[R-GAT] Done. Best val roc_auc={best_auc:.4f}; no checkpoint was written in this run: {best_path}")


[R-GAT] epoch=1 loss=0.6938 val_roc_auc=0.48293644996347695
[R-GAT] epoch=2 loss=0.6932 val_roc_auc=0.4944777209642075
[R-GAT] epoch=3 loss=0.6929 val_roc_auc=0.5118115412710007
[R-GAT] epoch=4 loss=0.6925 val_roc_auc=0.5212783053323594
[R-GAT] epoch=5 loss=0.6923 val_roc_auc=0.5276113951789627
[R-GAT] epoch=6 loss=0.6918 val_roc_auc=0.5347479912344777
[R-GAT] epoch=7 loss=0.6919 val_roc_auc=0.5440832724616509
[R-GAT] epoch=8 loss=0.6909 val_roc_auc=0.552527392257122
[R-GAT] epoch=9 loss=0.6903 val_roc_auc=0.5604382761139518
[R-GAT] epoch=10 loss=0.6890 val_roc_auc=0.5704894083272462
[R-GAT] epoch=11 loss=0.6893 val_roc_auc=0.5830971512052594
[R-GAT] epoch=12 loss=0.6884 val_roc_auc=0.5939663988312637
[R-GAT] epoch=13 loss=0.6874 val_roc_auc=0.6035208181154128
[R-GAT] epoch=14 loss=0.6858 val_roc_auc=0.612483564645727
[R-GAT] epoch=15 loss=0.6847 val_roc_auc=0.6192768444119796
[R-GAT] epoch=16 loss=0.6829 val_roc_auc=0.6245215485756027
[R-GAT] epoch=17 loss=0.6792 val_roc_auc=0.6303798

In [80]:
# Sanity check: confirm that the checkpoint and the training log are on disk.
print("saved model exists?", best_path.exists())
print("log exists?", log_path.exists())


saved model exists? True
log exists? True
