In [1]:
from pathlib import Path
import json
import gzip
from collections import Counter

import torch
import numpy as np


In [2]:
def ensure_dir(path: Path):
    """Create ``path`` as a directory, together with any missing parents.

    Args:
        path: Directory to create. A plain string (for example a value read
            straight from the config) or any other path-like object is also
            accepted and converted to ``Path`` first.

    Does nothing if the directory already exists.
    """
    # Coerce so that string paths do not fail with AttributeError on .mkdir.
    Path(path).mkdir(parents=True, exist_ok=True)

def save_json(obj, path: Path):
    """Write ``obj`` to ``path`` as UTF-8 JSON indented by two spaces.

    The parent directory of ``path`` is created first if needed. Non-ASCII
    characters are written as-is rather than escaped.
    """
    target_dir = path.parent
    ensure_dir(target_dir)
    with open(path, mode="w", encoding="utf-8") as sink:
        json.dump(obj, sink, indent=2, ensure_ascii=False)

def load_yaml(path: Path):
    """Parse the UTF-8 YAML file at ``path`` and return the loaded object."""
    # yaml is imported inside the function, so it is only required when
    # this function is actually called.
    import yaml
    with open(path, mode="r", encoding="utf-8") as stream:
        parsed = yaml.safe_load(stream)
    return parsed

def open_maybe_gz(path: Path, mode="rt"):
    """Open ``path`` with gzip if its name ends in ``.gz``, otherwise normally.

    Args:
        path: File to open.
        mode: Open mode. Text modes are opened as UTF-8 with undecodable bytes
            dropped (``errors="ignore"``). Binary modes (containing ``"b"``)
            are opened without any encoding arguments.

    Returns:
        The open file object; the caller is responsible for closing it
        (typically via ``with``).
    """
    is_gzip = str(path).endswith(".gz")

    if "b" in mode:
        # encoding/errors are rejected in binary mode by both open() and gzip.open().
        return gzip.open(path, mode) if is_gzip else open(path, mode)

    if is_gzip:
        # gzip.open treats a mode without "t" (e.g. "r") as binary, whereas the
        # builtin open treats it as text; add "t" so both branches agree.
        if "t" not in mode:
            mode = mode + "t"
        return gzip.open(path, mode, encoding="utf-8", errors="ignore")

    return open(path, mode, encoding="utf-8", errors="ignore")

def read_tsv_edges(path: Path, max_edges: int = None):
    """Lazily yield ``(head, relation, tail)`` triples from a tab-separated file.

    Args:
        path: Edge file; opened through ``open_maybe_gz`` in text mode.
        max_edges: Maximum number of triples to yield, or ``None`` for no
            limit. Blank and malformed lines do not count towards the limit.

    Yields:
        The first three tab-separated fields of each usable line. Lines that
        are empty after stripping whitespace, or that have fewer than three
        fields, are skipped; fields beyond the third are ignored.
    """
    yielded = 0
    with open_maybe_gz(path, "rt") as f:
        for line in f:
            # Count edges actually produced rather than raw lines, so skipped
            # lines cannot eat into the budget.
            if max_edges is not None and yielded >= max_edges:
                break
            line = line.strip()
            if not line:
                continue
            parts = line.split("\t")
            if len(parts) < 3:
                continue
            yield parts[0], parts[1], parts[2]
            yielded += 1


In [3]:
def parse_entity_type(entity: str) -> str:
    """Return the type prefix of an entity identifier.

    The text before the first ``"::"`` is returned when that separator is
    present; otherwise the text before the first ``":"``. Identifiers that
    contain neither separator map to ``"UNK"``.
    """
    # "::" is tried first so that it wins over a plain ":".
    for separator in ("::", ":"):
        if separator in entity:
            prefix, _, _ = entity.partition(separator)
            return prefix
    return "UNK"

def build_mappings_and_edges(edge_iter, relations_keep=None, entity_types_keep=None):
    """Assign integer ids to entities/relations and collect the edges as id lists.

    Args:
        edge_iter: Iterable of ``(head, relation, tail)`` string triples.
        relations_keep: Optional substring filter. An edge is kept only if its
            relation contains at least one of these substrings. A single
            string is treated as one substring. ``None`` or an empty value
            disables the filter.
        entity_types_keep: Optional set of allowed entity types. An edge is
            kept only if the types of both head and tail (as given by
            ``parse_entity_type``) are in it. A single string is treated as
            one type. ``None`` or an empty value disables the filter.

    Returns:
        An 8-tuple ``(entity2id, relation2id, id2entity, id2relation,
        id2etype, edges_src, edges_rel, edges_dst)``. Ids are assigned in order
        of first appearance among the edges that pass the filters. The three
        edge lists are parallel and hold head id, relation id and tail id.
    """
    # A scalar string (e.g. `relations_keep: treats` in YAML) would otherwise be
    # iterated character by character / substring-matched, silently keeping
    # almost everything.
    if isinstance(relations_keep, str):
        relations_keep = [relations_keep]
    if isinstance(entity_types_keep, str):
        entity_types_keep = [entity_types_keep]

    # Materialise the filters once: safe for one-shot iterables and gives
    # constant-time type membership checks inside the loop.
    relation_substrings = tuple(relations_keep) if relations_keep else ()
    allowed_types = frozenset(entity_types_keep) if entity_types_keep else frozenset()

    entity2id = {}
    relation2id = {}
    id2entity = []
    id2relation = []
    id2etype = []

    edges_src, edges_rel, edges_dst = [], [], []

    def get_ent_id(ent: str):
        """Return the id of ``ent``, registering it (and its type) if new."""
        if ent in entity2id:
            return entity2id[ent]
        idx = len(id2entity)
        entity2id[ent] = idx
        id2entity.append(ent)
        id2etype.append(parse_entity_type(ent))
        return idx

    def get_rel_id(rel: str):
        """Return the id of ``rel``, registering it if new."""
        if rel in relation2id:
            return relation2id[rel]
        idx = len(id2relation)
        relation2id[rel] = idx
        id2relation.append(rel)
        return idx

    for h, r, t in edge_iter:
        # filter relations by substring if provided
        if relation_substrings and not any(k in r for k in relation_substrings):
            continue

        # filter entity types if provided
        if allowed_types:
            if parse_entity_type(h) not in allowed_types:
                continue
            if parse_entity_type(t) not in allowed_types:
                continue

        # Ids are only allocated for edges that survived both filters.
        hid = get_ent_id(h)
        rid = get_rel_id(r)
        tid = get_ent_id(t)

        edges_src.append(hid)
        edges_rel.append(rid)
        edges_dst.append(tid)

    return entity2id, relation2id, id2entity, id2relation, id2etype, edges_src, edges_rel, edges_dst

def make_edge_tensors(edges_src, edges_dst, edges_rel):
    """Convert the parallel edge id lists into int64 tensors.

    Returns:
        ``(edge_index, edge_type)``: ``edge_index`` stacks the source ids
        (row 0) over the destination ids (row 1); ``edge_type`` holds the
        relation ids.
    """
    endpoint_rows = [edges_src, edges_dst]
    relation_ids = torch.tensor(edges_rel, dtype=torch.long)
    return torch.tensor(endpoint_rows, dtype=torch.long), relation_ids

def relation_counts(edge_type: torch.Tensor):
    """Return a plain dict mapping each relation id to its number of edges.

    Keys appear in order of first occurrence in ``edge_type``.
    """
    tally = Counter()
    tally.update(edge_type.tolist())
    return {rel_id: n_edges for rel_id, n_edges in tally.items()}

def degree_stats(edge_index: torch.Tensor, num_nodes: int):
    """Count, per node, how many edge endpoints refer to it.

    Each edge adds one to its source node and one to its destination node,
    so a self-loop adds two to the same node.

    Returns:
        An int64 tensor of length ``num_nodes``.
    """
    degree = torch.zeros(num_nodes, dtype=torch.long)
    # Row 0 holds the sources, row 1 the destinations; accumulate both.
    for endpoints in (edge_index[0], edge_index[1]):
        degree.scatter_add_(0, endpoints, torch.ones_like(endpoints))
    return degree


In [2]:
def find_project_root(start: Path = None) -> Path:
    """Find the nearest directory that contains both ``code`` and ``data``.

    The search begins at ``start`` (the current working directory when
    ``start`` is ``None``) and moves up through its parents. If no such
    directory is found, the starting directory itself is returned.
    """
    origin = Path.cwd() if start is None else start
    candidates = (origin, *origin.parents)
    return next(
        (
            folder
            for folder in candidates
            if (folder / "code").exists() and (folder / "data").exists()
        ),
        origin,  # fallback when no ancestor qualifies
    )

# Locate the project root, searching upward from the current working directory.
project_root = find_project_root()
# Bare expression: shown as the cell's output.
project_root


WindowsPath('D:/Shiraz University/HomeWorks/Ostad Moosavi/LinkPrediction')

In [5]:
# Read the run configuration from code/config.yaml under the project root.
config_path = project_root / "code" / "config.yaml"
cfg = load_yaml(config_path)
# Bare expression: shown as the cell's output.
cfg

{'seed': 42,
 'data': {'raw_edges_path': 'data/raw/drkg_subgraph_120k.tsv',
  'processed_dir': 'data/processed'},
 'output': {'dir': 'output'},
 'preprocess': {'max_edges': 2000000,
  'relations_keep': None,
  'entity_types_keep': None},
 'task': {'target_relation_substrings': ['treats']},
 'splits': {'train_ratio': 0.8,
  'val_ratio': 0.1,
  'use_type_filter': True,
  'head_type': 'Compound',
  'tail_type': 'Disease',
  'num_negs_per_pos_eval': 50},
 'train': {'use_cuda': True,
  'lr': 0.001,
  'epochs': 30,
  'batch_size': 512,
  'num_negs_per_pos_train': 1},
 'model': {'dim': 128,
  'dropout': 0.2,
  'mlp_hidden': 128,
  'gat_heads': 4,
  'rgat_heads': 4,
  'rgcn_num_bases': 30},
 'attention': {'topk_relations_plot': 20, 'num_cases': 3, 'k_hops': 2}}

In [6]:
# Input file and output directory, both given in the config relative to the project root.
raw_path = project_root.joinpath(cfg["data"]["raw_edges_path"])
out_dir = project_root.joinpath(cfg["data"]["processed_dir"])

# Optional preprocessing settings; a missing key yields None.
max_edges = cfg["preprocess"].get("max_edges")
relations_keep = cfg["preprocess"].get("relations_keep")
entity_types_keep = cfg["preprocess"].get("entity_types_keep")

# Echo the resolved settings so they are recorded with the cell output.
print("raw_path:", raw_path)
print("out_dir:", out_dir)
print("max_edges:", max_edges)
print("relations_keep:", relations_keep)
print("entity_types_keep:", entity_types_keep)


raw_path: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\raw\drkg_subgraph_120k.tsv
out_dir: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
max_edges: 2000000
relations_keep: None
entity_types_keep: None


In [7]:
# Make sure the output directory exists before anything is written to it.
ensure_dir(out_dir)

print(f"[Preprocess] Reading: {raw_path}")
edge_iter = read_tsv_edges(raw_path, max_edges=max_edges)

# Build the id mappings and the edge id lists from the edge stream.
(
    entity2id, relation2id,
    id2entity, id2relation, id2etype,
    src, rel, dst,
) = build_mappings_and_edges(edge_iter, relations_keep, entity_types_keep)

# Note the argument order here: sources, destinations, then relations.
edge_index, edge_type = make_edge_tensors(src, dst, rel)

num_nodes, num_rels = len(id2entity), len(id2relation)

print(f"[Preprocess] Nodes={num_nodes}, Relations={num_rels}, Edges={edge_index.size(1)}")


[Preprocess] Reading: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\raw\drkg_subgraph_120k.tsv
[Preprocess] Nodes=37614, Relations=107, Edges=118308


In [8]:
# simple stats
rel_cnt = relation_counts(edge_type)
deg = degree_stats(edge_index, num_nodes=num_nodes)

# Summary stored in graph_meta.json. The degree entries fall back to zero when
# there are no nodes; top_relations holds the 20 most frequent
# (relation id, edge count) pairs.
meta = {
    "num_nodes": num_nodes,
    "num_relations": num_rels,
    "num_edges": int(edge_index.size(1)),
    "relations_keep": relations_keep,
    "entity_types_keep": entity_types_keep,
    "max_edges": max_edges,
    "degree_max": int(deg.max().item()) if num_nodes > 0 else 0,
    "degree_mean": float(deg.float().mean().item()) if num_nodes > 0 else 0.0,
    "top_relations": sorted(rel_cnt.items(), key=lambda kv: kv[1], reverse=True)[:20],
}

# save json maps (the id -> name maps use string keys)
save_json(entity2id, out_dir / "entity2id.json")
save_json(relation2id, out_dir / "relation2id.json")
save_json({str(idx): name for idx, name in enumerate(id2entity)}, out_dir / "id2entity.json")
save_json({str(idx): name for idx, name in enumerate(id2relation)}, out_dir / "id2relation.json")
save_json({str(idx): etype for idx, etype in enumerate(id2etype)}, out_dir / "id2etype.json")
save_json(meta, out_dir / "graph_meta.json")

# save graph tensors
torch.save(
    dict(
        edge_index=edge_index,
        edge_type=edge_type,
        num_nodes=num_nodes,
        num_relations=num_rels,
    ),
    out_dir / "graph_edges.pt",
)

print(f"[Preprocess] Saved all files to: {out_dir}")
print("[Preprocess] graph_meta top relations:", meta["top_relations"][:5])


[Preprocess] Saved all files to: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction\data\processed
[Preprocess] graph_meta top relations: [(1, 20939), (16, 8672), (12, 8177), (5, 6292), (24, 5441)]


In [9]:
# Print the name of every entry in the output directory, in sorted order.
print("Processed files:")
for p in sorted(out_dir.glob("*")):
    print(" -", p.name)

Processed files:
 - entity2id.json
 - graph_edges.pt
 - graph_meta.json
 - id2entity.json
 - id2etype.json
 - id2relation.json
 - relation2id.json
