In [5]:
import random
import argparse

def load_edges(path):
    """
    Load knowledge-graph triples from a whitespace-delimited text file.

    Each usable line has the form ``head rel tail``, with the fields
    separated by any run of spaces and/or tabs. Lines with fewer than three
    fields (blank lines included) are skipped, and any fields after the
    third are ignored.

    The file is decoded explicitly as UTF-8 so that non-ASCII entity or
    relation names load the same way on every platform, instead of depending
    on the locale's default encoding.

    Args:
        path: Location of the triple file.

    Returns:
        A pair ``(edges, nodes)``: ``edges`` is a list of
        ``(head, rel, tail)`` tuples in file order (repeated lines are kept),
        and ``nodes`` is the set of every head and tail encountered.
    """
    edges = []
    nodes = set()
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # split() with no argument already discards leading/trailing
            # whitespace, including the trailing newline.
            parts = line.split()
            if len(parts) < 3:
                continue
            head, rel, tail = parts[:3]
            edges.append((head, rel, tail))
            nodes.update((head, tail))
    return edges, nodes

def split_edges(edges, nodes, train_frac, dev_frac, test_frac, seed=None):
    """
    Split edges into train/dev/test for a transductive KG completion task,
    ensuring every node that has an edge appears in at least one training edge.

    Strategy: for every endpoint, one randomly chosen incident edge (its
    "safety" edge) is pinned to train. Because those edges never leave train,
    any other edge can be moved to dev or test without uncovering a node, so
    dev and test are simply taken from the shuffled non-safety edges.

    Args:
        edges: Iterable of hashable ``(head, rel, tail)`` triples. Repeated
            triples are collapsed to their first occurrence so the same fact
            cannot end up in both train and an evaluation split.
        nodes: Accepted for interface compatibility. The endpoints that need
            coverage are derived from ``edges`` themselves, so nodes without
            any edge are ignored and endpoints missing from this collection
            are still handled.
        train_frac: Accepted for interface compatibility; not used. Train
            receives every edge that is not placed in dev or test.
        dev_frac: Fraction of the (de-duplicated) edge count targeted for dev.
        test_frac: Fraction of the (de-duplicated) edge count targeted for test.
        seed: When not ``None``, a private generator seeded with this value is
            used, making the split reproducible for a given edge order. When
            ``None``, the shared module-level ``random`` generator is used.

    Returns:
        ``(train, dev, test)`` as three lists of triples. The lists are
        pairwise disjoint and together contain every distinct input edge.
        dev/test may be smaller than their targets if too few non-safety
        edges exist.
    """
    # A private generator keeps seeded runs reproducible without touching the
    # global random state; the module itself exposes the same choice/shuffle.
    rng = random if seed is None else random.Random(seed)

    # Collapse repeated triples, keeping first-seen order.
    unique_edges = list(dict.fromkeys(edges))

    # 1) Index incident edges per endpoint. The dict preserves first-seen
    #    order, so the iteration below is deterministic (iterating a set of
    #    strings is not, because of hash randomisation).
    incident = {}
    for e in unique_edges:
        h, _, t = e
        incident.setdefault(h, []).append(e)
        if t != h:
            incident.setdefault(t, []).append(e)

    # 2) Pin one 'safety' edge per endpoint to train.
    safety_edges = set()
    for incident_edges in incident.values():
        safety_edges.add(rng.choice(incident_edges))

    # 3) Everything else is free to move; shuffle the candidates.
    remaining = [e for e in unique_edges if e not in safety_edges]
    rng.shuffle(remaining)

    # Targets are clamped at zero so a negative fraction yields an empty split
    # rather than a negative slice bound.
    total = len(unique_edges)
    dev_target = max(0, int(total * dev_frac))
    test_target = max(0, int(total * test_frac))

    # 4) Carve dev, then test, off the front of the shuffled candidates.
    dev = remaining[:dev_target]
    test = remaining[dev_target:dev_target + test_target]
    leftovers = remaining[dev_target + test_target:]

    # 5) Final train is safety edges (in input order, for determinism)
    #    followed by the leftovers.
    train = [e for e in unique_edges if e in safety_edges] + leftovers
    return train, dev, test

def save_edges(edges, path):
    """
    Write edges to a file, one triple per line, as head<TAB>rel<TAB>tail.

    The file is created or truncated and encoded explicitly as UTF-8, so the
    output does not depend on the platform's default locale encoding.

    Args:
        edges: Iterable of ``(head, rel, tail)`` triples.
        path: Destination file path.
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{h}\t{r}\t{t}\n" for h, r, t in edges)





In [6]:
# Read the complete triple file into memory.
edges, nodes = load_edges("../graph_data/triples.tsv")

# Partition the triples. No seed is passed, so the random draws are not
# pinned for this call.
train, dev, test = split_edges(
    edges=edges,
    nodes=nodes,
    train_frac=0.8,
    dev_frac=0.01,
    test_frac=0.01,
)

# Persist each partition next to the source file.
save_edges(train, path="../graph_data/train.tsv")
save_edges(dev, path="../graph_data/dev.tsv")
save_edges(test, path="../graph_data/test.tsv")