In [27]:
%matplotlib inline
import os
import random
import numpy as np
from collections import defaultdict as ddict
from tqdm import tqdm
import pickle

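## Load dataset

Read the entity/relation id maps and the train/valid/test triple files of the selected dataset (each file starts with a count line).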

In [28]:
# Dataset directory to partition. Other directories used with this notebook:
# '../WN18RR' and '../FB15K237' — swap the value below to switch datasets.
data_path = '../NELL-995'

In [29]:
def _read_id_map(path):
    """Read a ``name id`` mapping file into a dict.

    The first line of the file holds a count; every following non-blank line is
    ``<name> <integer id>``.

    Returns:
        (declared_count, name2id): the integer from the first line and the
        name -> id dictionary.
    """
    # `with` guarantees the handle is closed (the original leaked both files).
    with open(path) as id_file:
        declared_count = int(id_file.readline())
        name2id = {}
        for line in id_file:
            if not line.strip():  # tolerate trailing/blank lines instead of crashing on unpack
                continue
            name, idx = line.split()
            name2id[name] = int(idx)
    return declared_count, name2id


num_ent, ent2id = _read_id_map(os.path.join(data_path, 'entity2id.txt'))
id2ent = {v: k for k, v in ent2id.items()}

num_rel, rel2id = _read_id_map(os.path.join(data_path, 'relation2id.txt'))
id2rel = {v: k for k, v in rel2id.items()}

In [11]:
def _read_triples(path):
    """Read one ``*2id.txt`` split file.

    The first line holds a count; every following non-blank line is
    ``head tail relation`` (note the on-disk column order), which is
    reordered here to ``[head, relation, tail]``.

    Returns:
        (declared_count, split_triples): the integer from the first line and
        the list of ``[h, r, t]`` integer triples in file order.
    """
    with open(path) as split_file:
        declared_count = int(split_file.readline())
        split_triples = []
        for line in split_file:
            if not line.strip():  # tolerate trailing/blank lines
                continue
            h, t, r = map(int, line.split())
            split_triples.append([h, r, t])
    return declared_count, split_triples


num_train, train_triples = _read_triples(os.path.join(data_path, 'train2id.txt'))
num_valid, valid_triples = _read_triples(os.path.join(data_path, 'valid2id.txt'))
num_test, test_triples = _read_triples(os.path.join(data_path, 'test2id.txt'))

# Every triple of the three splits, in order train -> valid -> test; the
# original split membership is discarded later when triples are re-split per client.
triples = train_triples + valid_triples + test_triples

In [16]:
# One row per triple with columns [head, relation, tail], so relations can be
# sliced column-wise (triples[:, 1]) in the cells below.
triples = np.asarray(triples)

## Randomly partition relations across clients

In [17]:
# Number of clients; each client receives a disjoint subset of the relations.
num_client = 3
# Fixed seed so the relation partition is reproducible under Restart & Run All.
REL_SPLIT_SEED = 0

rel_pool = np.unique(triples[:, 1])
if not 1 <= num_client <= len(rel_pool):
    raise ValueError(
        f'num_client must be between 1 and {len(rel_pool)}, got {num_client}')

# Shuffle the relation ids once and cut them into num_client chunks.
# np.array_split makes chunk sizes differ by at most one, so no client ends up
# with an empty relation set (which the previous round()-based sizing allowed,
# e.g. 6 relations over 4 clients -> 2, 2, 2, 0).
rel_rng = np.random.default_rng(REL_SPLIT_SEED)
client_rel = np.array_split(rel_rng.permutation(rel_pool), num_client)

## Assign triples to clients by relation

In [18]:
# Build a relation -> owning-client lookup once, instead of scanning every
# client's relation array for every triple. setdefault keeps the FIRST client
# that lists a relation, matching the original first-match-then-break rule.
rel2client = {}
for client_idx in range(num_client):
    for rel in np.asarray(client_rel[client_idx]).tolist():
        rel2client.setdefault(rel, client_idx)

client_triples = [[] for _ in range(num_client)]
for tri in triples.tolist():  # tri is [h, r, t] as plain Python ints
    owner = rel2client.get(tri[1])
    # Triples whose relation no client owns are dropped, as before.
    if owner is not None:
        client_triples[owner].append(tri)

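## Split each client's triples into train/valid/test

Roughly 20% of each client's triples are held out, but only triples whose head, tail and relation each still have more than two remaining occurrences; the held-out set is then halved into validation and test.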

In [25]:
# Fraction of each client's triples to hold out (then halved into valid/test).
HELDOUT_RATIO = 0.2
# Seed for the per-client shuffles so the split is reproducible.
CLIENT_SPLIT_SEED = 0


def _reindex_client_triples(all_triples):
    """Assign client-local, contiguous entity/relation ids in first-seen order.

    Args:
        all_triples: iterable of ``[h, r, t]`` global-id triples of one client.

    Returns:
        (rows, ent_freq, rel_freq) where each row is
        ``[h, r, t, local_h, local_r, local_t]`` and the counters hold the number
        of occurrences across all of the client's triples (an entity is counted
        once per end, so a self-loop counts twice).
    """
    ent_reidx = {}
    rel_reidx = {}
    ent_freq = ddict(int)
    rel_freq = ddict(int)
    rows = []
    for h, r, t in all_triples:
        ent_freq[h] += 1
        ent_freq[t] += 1
        rel_freq[r] += 1
        # len(...) is evaluated before insertion, so a new key gets the next free id.
        local_h = ent_reidx.setdefault(h, len(ent_reidx))
        local_t = ent_reidx.setdefault(t, len(ent_reidx))
        local_r = rel_reidx.setdefault(r, len(rel_reidx))
        rows.append([h, r, t, local_h, local_r, local_t])
    return rows, ent_freq, rel_freq


def _split_heldout(rows, ent_freq, rel_freq, heldout_ratio, rng):
    """Shuffle ``rows`` in place and divide them into (train_rows, heldout_rows).

    A row is held out only while its head, tail and relation each have more
    than 2 remaining occurrences; holding it out decrements those counters, so
    every held-out entity/relation keeps appearing in training. Scanning stops
    once the held-out list exceeds ``int(len(rows) * heldout_ratio)`` rows
    (i.e. it ends with that cap + 1 rows); all unscanned rows go to train.
    """
    rng.shuffle(rows)
    heldout_cap = int(len(rows) * heldout_ratio)
    train_rows, heldout_rows = [], []
    # Index of the first unscanned row; stays len(rows) if the cap is never hit
    # (and is well-defined for an empty client, unlike a leaked loop variable).
    stop = len(rows)
    for pos, row in enumerate(rows):
        h, r, t = row[:3]
        if ent_freq[h] > 2 and ent_freq[t] > 2 and rel_freq[r] > 2:
            heldout_rows.append(row)
            ent_freq[h] -= 1
            ent_freq[t] -= 1
            rel_freq[r] -= 1
        else:
            train_rows.append(row)
        if len(heldout_rows) > heldout_cap:
            stop = pos + 1
            break
    train_rows.extend(rows[stop:])
    return train_rows, heldout_rows


def _edge_arrays(rows):
    """Convert reindexed rows to the per-split edge dict.

    ``edge_index``/``edge_type`` use client-local ids and ``*_ori`` use global
    ids; the edge_index arrays have shape (2, num_edges). An empty split yields
    (2, 0) / (0,) arrays instead of raising an IndexError.
    """
    arr = np.asarray(rows, dtype=int).reshape(-1, 6)  # single conversion per split
    return {'edge_index': arr[:, [3, 5]].T, 'edge_type': arr[:, 4],
            'edge_index_ori': arr[:, [0, 2]].T, 'edge_type_ori': arr[:, 1]}


split_rng = random.Random(CLIENT_SPLIT_SEED)
client_data = []

for client_idx in tqdm(range(num_client)):
    rows, ent_freq, rel_freq = _reindex_client_triples(client_triples[client_idx])
    client_train_triples, heldout = _split_heldout(
        rows, ent_freq, rel_freq, HELDOUT_RATIO, split_rng)

    # Halve the held-out rows: first half -> valid, second half -> test.
    split_rng.shuffle(heldout)
    half = len(heldout) // 2
    client_valid_triples = heldout[:half]
    client_test_triples = heldout[half:]

    client_data.append({'train': _edge_arrays(client_train_triples),
                        'test': _edge_arrays(client_test_triples),
                        'valid': _edge_arrays(client_valid_triples)})

100%|██████████| 3/3 [00:01<00:00,  1.90it/s]


## Save dataset

In [30]:
# Persist the per-client splits; the context manager guarantees the file is
# flushed and closed even on non-refcounting interpreters or on error.
with open('test.pkl', 'wb') as out_file:
    pickle.dump(client_data, out_file)