## Feature Engineering.

Feature Engineering. Process the datasets to ensure that the test sets contain new relations/entities (i.e., ones not seen during training).

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Inductive relation-based partitioning for Knowledge Graph datasets.
Generates NL-25 / NL-50 / NL-75 / NL-100 splits where relations are unseen in train.

Assumptions:
- Input datasets follow:
  data/DATASET/train.txt
  data/DATASET/valid.txt
  data/DATASET/test.txt
- Triplets are tab-separated: h \t r \t t
- Entities may appear in any split (allowed)
- Relations selected as "new" do NOT appear in train
- valid and test contain ONLY new relations
- Reproducible via fixed seed
- No external deps beyond Python stdlib (+ optional numpy, not required)

Author: you + ChatGPT
"""

import os
import random
from collections import defaultdict

# =========================
# CONFIG
# =========================
# Root folder holding one sub-directory per dataset; joined with
# DATASET_NAME in main() to locate train.txt / valid.txt / test.txt.
BASE_DATA_DIR = "data"
# Name of the dataset sub-directory to partition.
DATASET_NAME = "WN18RR"      # change if needed
# Seed passed to random.seed() at the start of main().
SEED = 42

# Split name -> fraction of all relations that are selected as "new"
# (kept out of train). The split name is also used as the output
# sub-directory name under the dataset directory.
ALPHAS = {
    "NL-25": 0.25,
    "NL-50": 0.50,
    "NL-75": 0.75,
    "NL-100": 1.00,
}

# =========================
# UTILS
# =========================
def read_triples(path):
    """Load (head, relation, tail) triples from a tab-separated text file.

    Args:
        path: File to read; it is opened as UTF-8 text.

    Returns:
        A list of ``(h, r, t)`` string tuples in file order. Lines that are
        empty after stripping surrounding whitespace are skipped.

    Raises:
        ValueError: If a non-blank line does not contain exactly three
            tab-separated fields. The message names the file and the
            1-based line number so the offending row can be found.
    """
    triples = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split("\t")
            # Validate explicitly: a bare tuple-unpack would fail with an
            # opaque "not enough values to unpack" and no file context.
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{lineno}: expected 3 tab-separated fields, "
                    f"got {len(fields)}: {line!r}"
                )
            triples.append(tuple(fields))
    return triples


def write_triples(path, triples):
    """Write triples to ``path`` as UTF-8 text, one ``h<TAB>r<TAB>t`` per line.

    Args:
        path: Destination file; it is created or truncated.
        triples: Iterable of 3-item ``(h, r, t)`` sequences.
    """
    with open(path, "w", encoding="utf-8") as out:
        # Build the generator only after the file is open so the iterable is
        # consumed lazily while writing.
        out.writelines(
            f"{head}\t{rel}\t{tail}\n" for head, rel, tail in triples
        )


def ensure_dir(path):
    """Create directory ``path`` (and any missing parents) if it is absent.

    An already existing directory is left untouched.
    """
    os.makedirs(name=path, exist_ok=True)


# =========================
# MAIN LOGIC
# =========================
def _build_relation_split(rel2triples, all_relations, alpha):
    """Partition triples into train/valid/test for one fraction of new relations.

    Uses the module-level ``random`` generator, so results depend on its
    current state (``main`` seeds it once before the first call).

    Args:
        rel2triples: Mapping relation -> list of ``(h, r, t)`` triples.
        all_relations: List of every relation, in a deterministic order.
        alpha: Fraction of relations to treat as "new" (held out of train).

    Returns:
        ``(new_relations, train_split, valid_split, test_split)`` where
        ``new_relations`` is the list of held-out relations, ``train_split``
        holds every triple of the remaining relations, and the shuffled
        new-relation triples are cut in half into valid (first half) and
        test (second half, which gets the extra triple on odd counts).
    """
    num_relations = len(all_relations)
    # Clamp so an out-of-range alpha can never yield a negative or oversized count.
    num_new_rel = max(0, min(int(round(num_relations * alpha)), num_relations))

    shuffled_relations = all_relations[:]
    random.shuffle(shuffled_relations)

    # Keep these as LISTS. Iterating a set of str follows hash order, which
    # varies between interpreter runs unless PYTHONHASHSEED is fixed; that
    # would change the input order of the shuffle below and break the
    # fixed-seed reproducibility of the valid/test partition.
    new_relations = shuffled_relations[:num_new_rel]
    old_relations = shuffled_relations[num_new_rel:]

    train_split = [triple for r in old_relations for triple in rel2triples[r]]
    new_relation_triples = [
        triple for r in new_relations for triple in rel2triples[r]
    ]

    # Split new-relation triples into valid/test (50/50)
    random.shuffle(new_relation_triples)
    mid = len(new_relation_triples) // 2
    valid_split = new_relation_triples[:mid]
    test_split = new_relation_triples[mid:]

    return new_relations, train_split, valid_split, test_split


def main():
    """Generate every split listed in ALPHAS for the configured dataset.

    Reads train.txt, valid.txt and test.txt from BASE_DATA_DIR/DATASET_NAME,
    pools all triples, and for each ``(split_name, alpha)`` in ALPHAS writes
    train.txt / valid.txt / test.txt into a ``split_name`` sub-directory of
    the dataset directory. Relations selected as new appear only in valid
    and test; all other relations appear only in train. Progress and
    statistics are printed to stdout.

    Raises:
        AssertionError: If a generated split violates the relation
            separation described above.
    """
    random.seed(SEED)

    dataset_dir = os.path.join(BASE_DATA_DIR, DATASET_NAME)

    print(f"\n[INFO] Loading base dataset: {DATASET_NAME}")

    # Pool the original splits (train, then valid, then test).
    # NOTE(review): a triple present in more than one base file is kept
    # once per occurrence; confirm the base dataset has no such duplicates.
    all_triples = []
    for filename in ("train.txt", "valid.txt", "test.txt"):
        all_triples.extend(read_triples(os.path.join(dataset_dir, filename)))

    # Group triples by relation
    rel2triples = defaultdict(list)
    for h, r, t in all_triples:
        rel2triples[r].append((h, r, t))

    # Sorted so the seeded shuffle always starts from the same order.
    all_relations = sorted(rel2triples.keys())
    num_relations = len(all_relations)

    print(f"[STATS] Total triples      : {len(all_triples)}")
    print(f"[STATS] Total relations    : {num_relations}")

    for split_name, alpha in ALPHAS.items():
        print("\n" + "=" * 60)
        print(f"[SPLIT] Generating {split_name} (alpha={alpha})")

        new_relations, train_split, valid_split, test_split = (
            _build_relation_split(rel2triples, all_relations, alpha)
        )

        # Safety checks
        new_relation_set = set(new_relations)
        train_rels = {r for _, r, _ in train_split}
        valid_rels = {r for _, r, _ in valid_split}
        test_rels = {r for _, r, _ in test_split}

        assert train_rels.isdisjoint(new_relation_set), "Leakage: new relations in train!"
        assert valid_rels.issubset(new_relation_set), "Invalid relation in valid!"
        assert test_rels.issubset(new_relation_set), "Invalid relation in test!"

        # Output directory
        out_dir = os.path.join(dataset_dir, split_name)
        ensure_dir(out_dir)

        write_triples(os.path.join(out_dir, "train.txt"), train_split)
        write_triples(os.path.join(out_dir, "valid.txt"), valid_split)
        write_triples(os.path.join(out_dir, "test.txt"), test_split)

        # Report
        print(f"[STATS] #new relations     : {len(new_relations)}")
        print(f"[STATS] train triples     : {len(train_split)}")
        print(f"[STATS] valid triples     : {len(valid_split)}")
        print(f"[STATS] test  triples     : {len(test_split)}")
        print(f"[PATH ] Written to        : {out_dir}")

    print("\n[DONE] All NL-* splits generated successfully.")


# Generate the splits only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
