In [6]:
from pathlib import Path
from collections import Counter, defaultdict
import numpy as np
import gzip, json, time


In [7]:
from pathlib import Path
import os

def find_project_root(start: Path = None) -> Path:
    if start is None:
        start = Path.cwd()
    for p in [start] + list(start.parents):
        if (p / "code").exists() and (p / "data" / "raw").exists():
            return p
    return start

project_root = find_project_root()
os.chdir(project_root)
print("CWD set to:", Path.cwd())
print("Has data/raw:", (Path.cwd() / "data" / "raw").exists())


CWD set to: D:\Shiraz University\HomeWorks\Ostad Moosavi\LinkPrediction
Has data/raw: True


In [8]:
raw_dir = Path.cwd() / "data" / "raw"


In [9]:
def ensure_dir(path: Path):
    path.mkdir(parents=True, exist_ok=True)

def open_maybe_gz(path: Path):
    return gzip.open(path, "rt", encoding="utf-8", errors="ignore") if str(path).endswith(".gz") else open(path, "rt", encoding="utf-8", errors="ignore")

def detect_delimiter(path: Path):
    with open_maybe_gz(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            return "\t" if "\t" in line else ","
    return "\t"

def iter_edges(path: Path, delim: str):
    with open_maybe_gz(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split(delim)
            if len(parts) < 3:
                continue
            yield parts[0].strip(), parts[1].strip(), parts[2].strip()

def allocate_quotas(rel_counts: dict, total_edges: int, min_per_relation: int):
    rels = list(rel_counts.keys())
    R = len(rels)
    if R == 0:
        raise ValueError("هیچ relationای در فایل پیدا نشد!")

    min_per_relation = min(min_per_relation, max(1, total_edges // R))
    base = min_per_relation * R
    remaining = max(0, total_edges - base)

    total = sum(rel_counts.values())
    quotas = {r: min_per_relation for r in rels}

    if remaining > 0 and total > 0:
        frac = []
        for r in rels:
            share = remaining * (rel_counts[r] / total)
            add = int(np.floor(share))
            quotas[r] += add
            frac.append((share - add, r))

        diff = total_edges - sum(quotas.values())
        frac.sort(reverse=True)
        i = 0
        while diff > 0 and len(frac) > 0:
            quotas[frac[i % len(frac)][1]] += 1
            diff -= 1
            i += 1

    return quotas

def pick_largest_file(raw_dir: Path, include_substr=None, exclude_substr=None):
    files = [p for p in raw_dir.glob("*") if p.is_file()]
    if include_substr is not None:
        files = [p for p in files if include_substr.lower() in p.name.lower()]
    if exclude_substr is not None:
        files = [p for p in files if exclude_substr.lower() not in p.name.lower()]
    if not files:
        return None
    files.sort(key=lambda p: p.stat().st_size, reverse=True)
    return files[0]

def edge_key(h, r, t):
    return h + "\t" + r + "\t" + t


In [10]:
# ========= تنظیمات =========
raw_dir = Path("data/raw")
if not raw_dir.exists():
    raise FileNotFoundError("پوشه data/raw پیدا نشد. اول به ریشه پروژه cd کن.")

# 1) فایل train subgraph (اگر دارید) -> برای جلوگیری از overlap
train_subgraph_path = pick_largest_file(raw_dir, include_substr="subgraph")

# 2) فایل اصلی برای نمونه‌گیری (ترجیحاً drkg.tsv یا drkg.tsv.gz)
source_path = pick_largest_file(raw_dir, exclude_substr="subgraph") or train_subgraph_path

# اندازه‌ی تست
test_edges = 20_000
min_per_relation = 100
seed = 42

# برای اینکه تست قابل امتیازدهی توسط مدل‌های train شده باشد
require_entities_in_train = True
require_relations_in_train = True

print("source_path:", source_path)
print("train_subgraph_path:", train_subgraph_path)
print("test_edges:", test_edges, "| min_per_relation:", min_per_relation)


source_path: data\raw\drkg.tsv
train_subgraph_path: data\raw\drkg_subgraph_120k.tsv
test_edges: 20000 | min_per_relation: 100


In [11]:
train_edges = set()
train_entities = set()
train_relations = set()

if train_subgraph_path is not None and train_subgraph_path.exists():
    delim_train = detect_delimiter(train_subgraph_path)
    print("Reading train subgraph:", train_subgraph_path, "| delim:", "TAB" if delim_train=="\t" else "COMMA")
    n = 0
    for h, r, t in iter_edges(train_subgraph_path, delim_train):
        train_edges.add(edge_key(h, r, t))
        train_entities.add(h); train_entities.add(t)
        train_relations.add(r)
        n += 1
    print("train edges:", len(train_edges))
    print("train entities:", len(train_entities))
    print("train relations:", len(train_relations))
else:
    print("No train_subgraph found -> overlap prevention disabled.")


Reading train subgraph: data\raw\drkg_subgraph_120k.tsv | delim: TAB
train edges: 118305
train entities: 37614
train relations: 107


In [12]:
if source_path is None or not source_path.exists():
    raise FileNotFoundError("هیچ فایل source پیدا نشد. یک فایل drkg.tsv(.gz) یا subgraph داخل data/raw بگذار.")

delim = detect_delimiter(source_path)
print("Detected delimiter for source:", "TAB" if delim=="\t" else "COMMA")

rel_counts = Counter()
total_lines = 0
eligible_lines = 0

t0 = time.time()
for h, r, t in iter_edges(source_path, delim):
    total_lines += 1

    if train_edges and edge_key(h, r, t) in train_edges:
        continue
    if require_entities_in_train and train_entities:
        if (h not in train_entities) or (t not in train_entities):
            continue
    if require_relations_in_train and train_relations:
        if r not in train_relations:
            continue

    rel_counts[r] += 1
    eligible_lines += 1

    if total_lines % 1_000_000 == 0:
        print(f"scanned {total_lines:,} | eligible {eligible_lines:,} | rels {len(rel_counts)}")

dt = time.time() - t0
print("Done pass1.")
print("Total scanned:", f"{total_lines:,}")
print("Eligible edges:", f"{eligible_lines:,}")
print("Num relations (eligible):", len(rel_counts))
print("Top5:", rel_counts.most_common(5))
print("time(sec):", round(dt, 2))


Detected delimiter for source: TAB
scanned 1,000,000 | eligible 959,705 | rels 16
scanned 2,000,000 | eligible 1,794,160 | rels 57
scanned 4,000,000 | eligible 3,502,084 | rels 80
scanned 5,000,000 | eligible 4,444,454 | rels 95
Done pass1.
Total scanned: 5,874,261
Eligible edges: 5,283,458
Num relations (eligible): 95
Top5: [('DRUGBANK::ddi-interactor-in::Compound:Compound', 1348683), ('Hetionet::AeG::Anatomy:Gene', 481849), ('Hetionet::GpBP::Gene:Biological Process', 430843), ('STRING::REACTION::Gene:Gene', 392380), ('STRING::CATALYSIS::Gene:Gene', 336716)]
time(sec): 11.91


In [13]:
if sum(rel_counts.values()) < test_edges:
    print("WARNING: eligible edges کمتر از test_edges است. test_edges را کم می‌کنیم.")
    test_edges = int(sum(rel_counts.values()))

quotas = allocate_quotas(rel_counts, test_edges, min_per_relation)
print("Total quota:", sum(quotas.values()))

rng = np.random.default_rng(seed)
seen = defaultdict(int)
reservoirs = {r: [] for r in quotas.keys()}

t0 = time.time()
total_lines = 0
eligible_lines = 0

for h, r, t in iter_edges(source_path, delim):
    total_lines += 1

    if train_edges and edge_key(h, r, t) in train_edges:
        continue
    if require_entities_in_train and train_entities:
        if (h not in train_entities) or (t not in train_entities):
            continue
    if require_relations_in_train and train_relations:
        if r not in train_relations:
            continue

    q = quotas.get(r, 0)
    if q <= 0:
        continue

    eligible_lines += 1
    seen[r] += 1
    k = len(reservoirs[r])

    if k < q:
        reservoirs[r].append((h, r, t))
    else:
        j = rng.integers(0, seen[r])
        if j < q:
            reservoirs[r][j] = (h, r, t)

    if total_lines % 1_000_000 == 0:
        picked_now = sum(len(v) for v in reservoirs.values())
        print(f"scanned {total_lines:,} | eligible {eligible_lines:,} | picked {picked_now:,}")

dt = time.time() - t0
print("Done pass2. time(sec):", round(dt, 2))

all_edges = []
for r in reservoirs:
    all_edges.extend(reservoirs[r])

rng.shuffle(all_edges)

print("Picked edges:", len(all_edges))
print("Relations covered:", sum(1 for r in reservoirs if len(reservoirs[r])>0), "/", len(reservoirs))


Total quota: 20000
scanned 1,000,000 | eligible 959,705 | picked 4,060
scanned 2,000,000 | eligible 1,794,160 | picked 9,567
scanned 4,000,000 | eligible 3,502,084 | picked 14,808
scanned 5,000,000 | eligible 4,444,454 | picked 19,592
Done pass2. time(sec): 26.98
Picked edges: 19592
Relations covered: 95 / 95


In [14]:
out_path = raw_dir / f"drkg_test_holdout_{test_edges//1000}k.tsv"
with open(out_path, "w", encoding="utf-8") as f:
    for h, r, t in all_edges:
        f.write(f"{h}\t{r}\t{t}\n")

print("Saved TEST:", out_path)
print("Edges:", len(all_edges))

out_metrics = Path("output") / "metrics"
ensure_dir(out_metrics)

summary = {
    "source_path": str(source_path),
    "train_subgraph_path": str(train_subgraph_path) if train_subgraph_path else None,
    "test_path": str(out_path),
    "test_edges": int(len(all_edges)),
    "require_entities_in_train": bool(require_entities_in_train),
    "require_relations_in_train": bool(require_relations_in_train),
    "num_relations_covered": int(sum(1 for r in reservoirs if len(reservoirs[r])>0)),
    "num_relations_total_eligible": int(len(reservoirs)),
    "top_relations_in_test": Counter([r for _, r, _ in all_edges]).most_common(10),
}

with open(out_metrics / "test_holdout_summary.json", "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2, ensure_ascii=False)

print("Saved summary:", out_metrics / "test_holdout_summary.json")
summary


Saved TEST: data\raw\drkg_test_holdout_20k.tsv
Edges: 19592
Saved summary: output\metrics\test_holdout_summary.json


{'source_path': 'data\\raw\\drkg.tsv',
 'train_subgraph_path': 'data\\raw\\drkg_subgraph_120k.tsv',
 'test_path': 'data\\raw\\drkg_test_holdout_20k.tsv',
 'test_edges': 19592,
 'require_entities_in_train': True,
 'require_relations_in_train': True,
 'num_relations_covered': 95,
 'num_relations_total_eligible': 95,
 'top_relations_in_test': [('DRUGBANK::ddi-interactor-in::Compound:Compound',
   2780),
  ('Hetionet::AeG::Anatomy:Gene', 1058),
  ('Hetionet::GpBP::Gene:Biological Process', 956),
  ('STRING::REACTION::Gene:Gene', 880),
  ('STRING::CATALYSIS::Gene:Gene', 769),
  ('STRING::BINDING::Gene:Gene', 710),
  ('STRING::OTHER::Gene:Gene', 659),
  ('Hetionet::Gr>G::Gene:Gene', 609),
  ('Hetionet::GiG::Gene:Gene', 378),
  ('INTACT::PHYSICAL ASSOCIATION::Gene:Gene', 334)]}