In [1]:
from pathlib import Path
import json
import numpy as np
import torch

In [2]:
def ensure_dir(path: Path):
    """Create ``path`` as a directory, including any missing parents.

    Calling this on a directory that already exists is a no-op.

    Args:
        path: Directory to create. A plain string (or any ``os.PathLike``)
            is accepted in addition to a ``Path``.

    Returns:
        The directory as a ``Path``, so the call can be chained.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    path = Path(path)  # tolerate str / os.PathLike callers, not only Path
    path.mkdir(parents=True, exist_ok=True)
    return path

def load_json(path: Path):
    """Read the UTF-8 encoded JSON file at ``path`` and return the parsed object."""
    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def load_yaml(path: Path):
    """Parse the UTF-8 encoded YAML file at ``path`` with ``yaml.safe_load``.

    ``yaml`` is imported inside the function, so it is only needed when
    this function is actually called.
    """
    import yaml

    with open(path, mode="r", encoding="utf-8") as handle:
        content = yaml.safe_load(handle)
    return content

def find_project_root(start: Path = None) -> Path:
    """Walk upwards from ``start`` to the directory holding ``code`` and ``data``.

    Args:
        start: Directory to begin the search from. Defaults to the current
            working directory. Relative paths are made absolute first.

    Returns:
        The nearest directory (``start`` itself or one of its ancestors) that
        contains both a ``code`` and a ``data`` entry. When no such directory
        exists, the absolute ``start`` is returned unchanged as a fallback.
    """
    import os

    if start is None:
        start = Path.cwd()
    # A relative Path has no usable ancestors (Path(".").parents is empty), so
    # anchor it to an absolute, normalised path before walking upwards.
    # abspath is used instead of resolve() so symlinks are left untouched.
    start = Path(os.path.abspath(start))
    for candidate in (start, *start.parents):
        if (candidate / "code").exists() and (candidate / "data").exists():
            return candidate
    return start

# Resolve the project root once; the following cells derive the config and
# processed-data paths from it.
project_root = find_project_root()
project_root  # bare expression so the notebook displays the resolved path


WindowsPath('D:/Shiraz University/HomeWorks/Ostad Moosavi/LinkPrediction')

In [3]:
# Locate and parse the experiment configuration under <project_root>/code.
config_path = project_root.joinpath("code", "config.yaml")
cfg = load_yaml(config_path)

# Directory for processed artifacts, taken from the config; created if missing.
proc_dir = project_root / cfg["data"]["processed_dir"]
ensure_dir(proc_dir)

print("proc_dir:", proc_dir)


proc_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed


In [4]:
# NOTE: torch.load relies on pickle, so only load graph files that come from
# a trusted source.
graph_path = proc_dir / "graph_edges.pt"
g = torch.load(graph_path)
edge_index, edge_type = g["edge_index"], g["edge_type"]
num_nodes = int(g["num_nodes"])

# JSON lookup tables stored in the same directory as the graph file.
relation2id = load_json(proc_dir / "relation2id.json")
id2etype = load_json(proc_dir / "id2etype.json")

# Quick summary of what was loaded.
print("edge_index shape:", edge_index.shape)
print("edge_type shape:", edge_type.shape)
print("num_nodes:", num_nodes)
print("num_relations:", int(g["num_relations"]))


edge_index shape: torch.Size([2, 118308])
edge_type shape: torch.Size([118308])
num_nodes: 37614
num_relations: 107


In [5]:
def build_pair_set(heads, tails):
    """Return the set of ``(head, tail)`` tuples formed by pairing the inputs.

    Pairing is positional and stops at the shorter input; repeated pairs
    collapse into a single set entry.
    """
    return {(head, tail) for head, tail in zip(heads, tails)}

def sample_negatives_for_pairs(pos_heads, pos_tails, tail_candidates, existing_pairs_set,
                               num_negs_per_pos=1, seed=42):
    """Sample corrupted-tail negatives for each positive ``(head, tail)`` pair.

    For every positive pair the head is kept and ``num_negs_per_pos`` tails are
    drawn from ``tail_candidates`` such that ``(head, tail)`` is not contained
    in ``existing_pairs_set``.

    Sampling is done by rejection (up to 50 draws). If that fails, the
    candidates are scanned exhaustively and a valid one is chosen, so a known
    positive is never returned as a negative while a valid tail exists. Only
    when *every* candidate is a known positive for the head is a random
    candidate emitted (with a warning), which keeps the output length stable.

    Args:
        pos_heads: Iterable of head ids of the positive pairs.
        pos_tails: Iterable of tail ids, paired positionally with ``pos_heads``.
        tail_candidates: Iterable of ids that may be used as negative tails.
        existing_pairs_set: Container of ``(head, tail)`` tuples that must not
            be produced as negatives.
        num_negs_per_pos: Number of negatives to generate per positive pair.
        seed: Seed for the private ``random.Random`` generator.

    Returns:
        ``(neg_h, neg_t)``: two parallel lists, each holding
        ``num_negs_per_pos`` entries per positive pair.

    Raises:
        ValueError: If ``tail_candidates`` is empty.
    """
    import random
    import warnings

    rng = random.Random(seed)

    tail_candidates = list(tail_candidates)
    if len(tail_candidates) == 0:
        raise ValueError("tail_candidates is empty!")

    max_tries = 50
    # head -> list of valid tails; filled lazily, only for heads where
    # rejection sampling ran out of tries.
    valid_tails_by_head = {}

    neg_h, neg_t = [], []
    for h, _t in zip(pos_heads, pos_tails):
        for _ in range(num_negs_per_pos):
            for _try in range(max_tries):
                t2 = rng.choice(tail_candidates)
                if (h, t2) not in existing_pairs_set:
                    break
            else:
                # Rejection sampling failed: look for a valid tail exhaustively
                # instead of blindly emitting a possibly-positive pair.
                if h not in valid_tails_by_head:
                    valid_tails_by_head[h] = [
                        c for c in tail_candidates
                        if (h, c) not in existing_pairs_set
                    ]
                valid_tails = valid_tails_by_head[h]
                if valid_tails:
                    t2 = rng.choice(valid_tails)
                else:
                    warnings.warn(
                        f"No valid negative tail exists for head {h!r}; "
                        "emitting a random candidate that is a known positive."
                    )
                    t2 = rng.choice(tail_candidates)
            neg_h.append(h)
            neg_t.append(t2)

    return neg_h, neg_t


In [6]:
def pick_target_relation_ids(relation2id: dict, target_substrings):
    """Return the ids of relations whose name contains any target substring.

    Args:
        relation2id: Mapping from relation name to relation id.
        target_substrings: Substrings to look for in the relation names.
            Either an iterable of strings or a single string.

    Returns:
        Sorted list of the distinct matching relation ids (empty if none match).
    """
    if isinstance(target_substrings, str):
        # A bare string would otherwise be iterated character by character and
        # match almost every relation name.
        target_substrings = [target_substrings]
    else:
        # Materialise once so one-shot iterables can be reused per relation.
        target_substrings = list(target_substrings)

    matched_ids = {
        rid
        for rel, rid in relation2id.items()
        if any(s in rel for s in target_substrings)
    }
    return sorted(matched_ids)

# Relation-name fragments that define the prediction target (read from config).
target_substrings = cfg["task"]["target_relation_substrings"]
target_rel_ids = pick_target_relation_ids(relation2id, target_substrings)

print("target_substrings:", target_substrings)
print("target_rel_ids:", target_rel_ids)

# Stop early when the configured substrings match no relation at all.
if not target_rel_ids:
    raise ValueError("هیچ relation مطابق target_relation_substrings پیدا نشد!")


target_substrings: ['treats']
target_rel_ids: [73]


In [7]:
# Boolean mask over all edges, True where the relation id is a target relation.
# It is allocated on edge_type's device: combining a CPU mask in place with a
# tensor that was loaded onto another device raises a device-mismatch error.
target_mask = torch.zeros(edge_type.size(0), dtype=torch.bool, device=edge_type.device)
for rid in target_rel_ids:
    target_mask |= (edge_type == rid)

# Positions (into edge_index / edge_type) of the target edges, as a NumPy array.
target_idx = torch.where(target_mask)[0].cpu().numpy()

print("num target edges (before type filter):", len(target_idx))

# Nothing to split if the graph holds no edge of the target relation(s).
if len(target_idx) == 0:
    raise ValueError("هیچ یال هدفی توی گراف پیدا نشد! احتمالاً فیلتر preprocess سخت بوده.")

# Head / tail node ids of the target edges as plain Python lists.
heads = edge_index[0, target_idx].cpu().numpy().tolist()
tails = edge_index[1, target_idx].cpu().numpy().tolist()

print("sample head/tail ids:", heads[:5], tails[:5])


num target edges (before type filter): 374
sample head/tail ids: [747, 897, 1009, 1754, 2581] [748, 898, 164, 2563, 2582]


In [8]:
# Optional filter: keep only target edges whose head / tail node types (looked
# up in id2etype by the stringified node id) match the configured types.
use_type_filter = cfg["splits"].get("use_type_filter", True)
head_type = cfg["splits"].get("head_type", None)
tail_type = cfg["splits"].get("tail_type", None)

if use_type_filter and head_type and tail_type:
    # Positions (into heads / tails / target_idx) that pass the type check.
    # Nodes missing from id2etype are treated as type "UNK".
    keep = [
        i
        for i, (h, t) in enumerate(zip(heads, tails))
        if id2etype.get(str(h), "UNK") == head_type
        and id2etype.get(str(t), "UNK") == tail_type
    ]

    heads = [heads[i] for i in keep]
    tails = [tails[i] for i in keep]
    # Explicit integer index array so an empty ``keep`` is still a valid index.
    target_idx = target_idx[np.asarray(keep, dtype=np.int64)]
    print(f"After type filter ({head_type}->{tail_type}): {len(target_idx)} positives")

    # Without this guard an empty positive set would flow through the split and
    # be written to disk silently.
    if len(target_idx) == 0:
        raise ValueError(
            f"No target edges left after the type filter ({head_type}->{tail_type}); "
            "check splits.head_type / splits.tail_type in the config."
        )
else:
    print("Type filter is OFF or head_type/tail_type not set in config.")


After type filter (Compound->Disease): 374 positives


In [9]:
# Shuffle the positive target edges reproducibly, then cut train / val / test.
n = len(target_idx)
# heads / tails are indexed with the same permutation as target_idx below, so
# the three must be aligned.
assert len(heads) == n and len(tails) == n, "heads/tails are out of sync with target_idx"

seed = cfg.get("seed", 42)
rng = np.random.default_rng(seed)
perm = rng.permutation(n)

train_ratio = cfg["splits"].get("train_ratio", 0.8)
val_ratio   = cfg["splits"].get("val_ratio", 0.1)

# Ratios that are negative or sum to more than 1 would silently produce an
# empty test split (and a negative n_test), so reject them up front.
if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1.0 + 1e-9:
    raise ValueError(
        f"Invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio}; "
        "both must be >= 0 and sum to at most 1."
    )

# Train and val sizes are truncated; the test split receives the remainder.
n_train = int(n * train_ratio)
n_val   = int(n * val_ratio)
n_test  = n - n_train - n_val

idx_train = perm[:n_train]
idx_val   = perm[n_train:n_train+n_val]
idx_test  = perm[n_train+n_val:]
assert len(idx_train) + len(idx_val) + len(idx_test) == n
assert len(idx_test) == n_test

heads_np = np.array(heads)
tails_np = np.array(tails)

# Each split is a (2, num_edges) int64 array: row 0 = heads, row 1 = tails.
train_pos = np.array([heads_np[idx_train], tails_np[idx_train]], dtype=np.int64)
val_pos   = np.array([heads_np[idx_val],   tails_np[idx_val]], dtype=np.int64)
test_pos  = np.array([heads_np[idx_test],  tails_np[idx_test]], dtype=np.int64)

print("train_pos:", train_pos.shape, "val_pos:", val_pos.shape, "test_pos:", test_pos.shape)


train_pos: (2, 299) val_pos: (2, 37) test_pos: (2, 38)


In [10]:
# Positions (into the full edge list) of the held-out val / test target edges.
remove_target_idx = np.concatenate((target_idx[idx_val], target_idx[idx_test]))
remove_set = set(remove_target_idx.tolist())

# Every other edge position stays in the training graph.
# NOTE(review): only these exact edge positions are dropped; if the graph also
# stores reverse or duplicate copies of a held-out pair they would remain --
# confirm against the preprocessing step.
all_idx = np.arange(edge_type.size(0))
is_kept = np.isin(all_idx, remove_target_idx, invert=True)
keep_idx = all_idx[is_kept].astype(np.int64)

print("train_graph edges kept:", len(keep_idx), "/", edge_type.size(0))


train_graph edges kept: 118233 / 118308


In [11]:

# Candidate tails for negative sampling: nodes whose id2etype entry equals
# tail_type; when no tail_type is configured every node id is a candidate.
if not tail_type:
    tail_candidates = list(range(num_nodes))
else:
    tail_candidates = []
    for node_id in range(num_nodes):
        if id2etype.get(str(node_id), "UNK") == tail_type:
            tail_candidates.append(node_id)

print("num tail candidates:", len(tail_candidates))


num tail candidates: 1867


In [12]:

# All target positives (train, val and test together) form the exclusion set
# handed to the negative sampler.
all_pos_pairs = build_pair_set(heads_np.tolist(), tails_np.tolist())

# Number of negatives sampled per evaluation positive.
K = cfg["splits"].get("num_negs_per_pos_eval", 50)

# Validation and test use different seeds (seed + 1 / seed + 2).
val_neg_h, val_neg_t = sample_negatives_for_pairs(
    pos_heads=val_pos[0].tolist(),
    pos_tails=val_pos[1].tolist(),
    tail_candidates=tail_candidates,
    existing_pairs_set=all_pos_pairs,
    num_negs_per_pos=K,
    seed=seed + 1,
)

test_neg_h, test_neg_t = sample_negatives_for_pairs(
    pos_heads=test_pos[0].tolist(),
    pos_tails=test_pos[1].tolist(),
    tail_candidates=tail_candidates,
    existing_pairs_set=all_pos_pairs,
    num_negs_per_pos=K,
    seed=seed + 2,
)

# Same layout as the positives: row 0 = heads, row 1 = tails.
val_neg = np.array([val_neg_h, val_neg_t], dtype=np.int64)
test_neg = np.array([test_neg_h, test_neg_t], dtype=np.int64)

print("val_neg:", val_neg.shape, "test_neg:", test_neg.shape, "K:", K)


val_neg: (2, 1850) test_neg: (2, 1900) K: 50


In [13]:
# Positive target edges of each split.
np.savez(
    proc_dir / "split_target_edges.npz",
    train_pos=train_pos,
    val_pos=val_pos,
    test_pos=test_pos,
)

# Sampled evaluation negatives together with the per-positive count K.
np.savez(
    proc_dir / "negatives.npz",
    val_neg=val_neg,
    test_neg=test_neg,
    K=K,
)

# Edge positions that remain in the training graph.
np.save(proc_dir / "train_graph_edge_idx.npy", keep_idx)

print("Saved:")
for saved_name in ("split_target_edges.npz", "negatives.npz", "train_graph_edge_idx.npy"):
    print(" - " + saved_name)


Saved:
 - split_target_edges.npz
 - negatives.npz
 - train_graph_edge_idx.npy


In [14]:
# Confirm that each expected artifact is present in proc_dir.
for artifact in ("split_target_edges.npz", "negatives.npz", "train_graph_edge_idx.npy"):
    artifact_path = proc_dir / artifact
    print(artifact, "exists?", artifact_path.exists())


split_target_edges.npz exists? True
negatives.npz exists? True
train_graph_edge_idx.npy exists? True
