In [1]:
from collections import Counter, defaultdict
import random

In [2]:
# Path of the OpenBG500 training triples file. Despite the name, this is a
# file path (it is passed straight to open()), not a directory.
data_dir = "./knowledge_graph_completion/data/OpenBG500/OpenBG500_train.tsv"

In [3]:
# 1. Read the raw OpenBG500 training file, count how often every relation and
#    every entity occurs, and index the triples by relation and by entity.
relation_counter = Counter()      # relation id -> number of triples using it
entity_counter   = Counter()      # entity id -> occurrences as head or tail
rel2triples = defaultdict(list)   # relation id -> [(h, r, t), ...]
ent2triples = defaultdict(list)   # entity id -> [(h, r, t), ...] where it is head or tail

with open(data_dir, 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line at end of file)
            # instead of failing on the three-way unpacking below.
            continue

        fields = line.split('\t')
        if len(fields) != 3:
            # Still a ValueError, as before, but now it says where the bad
            # record is instead of "not enough values to unpack".
            raise ValueError(
                f"{data_dir}:{line_no}: expected 3 tab-separated fields, "
                f"got {len(fields)}: {line!r}"
            )
        h, r, t = fields

        relation_counter[r] += 1
        entity_counter[h]   += 1
        entity_counter[t]   += 1

        triple = (h, r, t)
        rel2triples[r].append(triple)
        ent2triples[h].append(triple)
        ent2triples[t].append(triple)

In [4]:
# 2. Bucket the relations by frequency (low / mid / high).
LOW_FREQ_MAX = 100      # example threshold: low-frequency means <= 100 triples
HIGH_FREQ_MIN = 5000    # example threshold: high-frequency means >= 5000 triples

low_freq_relations = [r for r, cnt in relation_counter.items() if cnt <= LOW_FREQ_MAX]
high_freq_relations = [r for r, cnt in relation_counter.items() if cnt >= HIGH_FREQ_MIN]
# Mid-frequency is everything strictly between the two thresholds. Comparing
# the counts directly selects the same relations, in the same order, as
# "not in low and not in high", without rebuilding a set union per relation.
mid_freq_relations = [
    r for r, cnt in relation_counter.items()
    if LOW_FREQ_MAX < cnt < HIGH_FREQ_MIN
]

In [5]:
# Number of relations in the low-, mid- and high-frequency buckets, in that order.
tuple(len(bucket) for bucket in (low_freq_relations, mid_freq_relations, high_freq_relations))

(97, 377, 26)

In [6]:
# 3. Initialise the list that will hold the sampled training triples.
sample_triples = []

In [7]:
# 3.1 Include every triple of every low-frequency relation.
# The list is rebuilt from scratch rather than extended in place, so running
# this cell a second time cannot append the same triples twice.
sample_triples = [
    triple
    for r in low_freq_relations
    for triple in rel2triples[r]
]

In [8]:
# Size of the sample after adding the low-frequency relations.
len(sample_triples)

5843

In [9]:
# 3.2 [Modified] Walk the mid-frequency relations from rarest to most common
#     and add each relation's complete triple list for as long as it still
#     fits into the 10000-triple budget.
# NOTE: this cell extends sample_triples in place and derives the budget from
# its current length; re-running it on its own may append relations a second
# time, so re-run from the initialisation cell instead.
mid_freq_relations.sort(key=lambda rel_id: relation_counter[rel_id])  # ascending frequency
remaining_capacity = 10000 - len(sample_triples)

for rel_id in mid_freq_relations:
    rel_size = relation_counter[rel_id]
    if rel_size > remaining_capacity:
        # Relations are sorted by ascending size, so none of the later ones
        # can fit either: stop here.
        break
    # rel2triples gives direct access to all triples of this relation.
    sample_triples += rel2triples[rel_id]
    remaining_capacity -= rel_size

In [10]:
# Budget left unused after the mid-frequency pass.
remaining_capacity

106

In [11]:
# Final size of the sampled triple set (low- plus mid-frequency relations).
len(sample_triples)

9894

## Map entity and relation IDs to text (ent2text / rel2text)

In [12]:
import csv

# -- 1. Load the entity and relation descriptions -- #
def load_id2text(path):
    """Read a two-column TSV file (``id<TAB>text``) into a dict.

    Each line is stripped of surrounding whitespace and split on the first
    tab only, so the text part may itself contain tabs. Blank lines and lines
    without a tab (an id with no text) are skipped rather than raising, which
    lets lookups with a fallback such as ``ent2text.get(head, head)`` keep the
    raw id for those entries.

    Args:
        path: Path of the UTF-8 encoded TSV file.

    Returns:
        A dict mapping each id to its text. If an id occurs more than once,
        the last occurrence wins.
    """
    id2text = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            key, sep, text = line.strip().partition('\t')
            if not sep:
                # Blank line, or an id whose description is empty (strip()
                # has already removed the trailing tab in that case).
                continue
            id2text[key] = text
    return id2text


ent2text = load_id2text('./knowledge_graph_completion/data/OpenBG500/OpenBG500_entity2text.tsv')
rel2text = load_id2text('./knowledge_graph_completion/data/OpenBG500/OpenBG500_relation2text.tsv')

In [13]:
# Replace every id by its text description; an id that has no description is
# kept unchanged.
train_mapped = [
    (ent2text.get(h, h), rel2text.get(r, r), ent2text.get(t, t))
    for h, r, t in sample_triples
]

In [14]:
# Number of mapped triples (one per sampled triple).
len(train_mapped)

9894

## Train/test split

In [15]:
import random
from collections import defaultdict


In [26]:
# 1. Split parameters: hold out a 0.12 fraction of the mapped triples for testing.
test_ratio = 0.12
total = len(train_mapped)
n_test = int(test_ratio * total)   # target test-set size, truncated to an integer

# Seed the global RNG so the random draws in the following cells are repeatable.
random.seed(42)

In [27]:
# Target number of test triples.
n_test

1187

In [28]:
# 1. Group the sample indices by relation text.
rel2indices = defaultdict(list)
for position, mapped in enumerate(train_mapped):
    rel2indices[mapped[1]].append(position)

# 2. Draw test indices from every relation group in proportion to its size.
#    NOTE(review): the max(1, ...) floor means every relation contributes at
#    least one index, so a relation with a single triple is placed entirely
#    in the test selection at this stage.
test_indices = set()
for group in rel2indices.values():
    quota = max(1, int(len(group) * test_ratio))
    # Never ask for more indices than the group holds.
    quota = min(quota, len(group))
    test_indices |= set(random.sample(group, quota))

In [29]:
# 3. Adjust the selection to exactly n_test indices.
if len(test_indices) > n_test:
    # Too many: keep a random subset. random.sample() requires a sequence;
    # passing a set raises TypeError on Python 3.11+ (deprecated since 3.9).
    # Sorting first also makes the draw independent of set iteration order.
    test_indices = set(random.sample(sorted(test_indices), n_test))
elif len(test_indices) < n_test:
    # Too few: top up with random indices that are not selected yet.
    all_indices = set(range(total))
    needed = n_test - len(test_indices)
    test_indices.update(random.sample(sorted(all_indices - test_indices), needed))

In [30]:
# 4. Partition the mapped triples into test and training lists, keeping the
#    original order within each list.
test_triples = []
train_triples = []
for idx in range(total):
    bucket = test_triples if idx in test_indices else train_triples
    bucket.append(train_mapped[idx])


In [31]:
# 6. Report the split sizes and validate them.
print(f"总样本: {total} 条")
print(f"测试集: {len(test_triples)} 条")
print(f"训练集: {len(train_triples)} 条")

# Validation: every sample must land in exactly one of the two sets, and the
# test set must have the requested size.
assert len(test_triples) + len(train_triples) == total, \
    "train/test split does not cover every sample exactly once"
assert len(test_triples) == n_test, \
    "test set size differs from the n_test target"

总样本: 9894 条
测试集: 1187 条
训练集: 8707 条


## Save the splits to CSV

In [32]:
import csv

In [33]:
def save_triples_to_csv(triples, path):
    """Write triples to a CSV file with the header row ``head,relation,tail``.

    Args:
        triples: Iterable of rows; each row is written as one CSV record.
        path: Destination file path. The file is created or overwritten and
            encoded as UTF-8.
    """
    header = ['head', 'relation', 'tail']
    # newline='' leaves line-ending handling to the csv module.
    with open(path, mode='w', encoding='utf-8', newline='') as csv_file:
        out = csv.writer(csv_file)
        out.writerow(header)
        for row in triples:
            out.writerow(row)

In [34]:
# Write the training and test splits to CSV files in the current working directory.
for split_rows, filename in ((train_triples, 'train.csv'), (test_triples, 'test.csv')):
    save_triples_to_csv(split_rows, filename)