In [None]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy
import networkx as nx

In [None]:
def compare(a,b):
    """Numerically compare two value tokens of the form "<N>".

    Returns:
        0 if a < b, 1 if a == b, 2 if a > b.
    """
    lhs, rhs = (int(tok.strip("<>")) for tok in (a, b))
    if lhs == rhs:
        return 1
    return 2 if lhs > rhs else 0
    
def build_dicts(entities):
    """Build index <-> token mappings for a sequence of tokens.

    Tokens are numbered in first-occurrence order; later duplicates are ignored.

    Args:
        entities: iterable of hashable tokens.

    Returns:
        (ind2entity, entity2ind): the list of unique tokens, and a dict mapping
        each token to its position in that list.
    """
    entity2ind = dict()
    ind2entity = []
    for entity in entities:
        # Dict lookup is O(1); scanning the ind2entity list made this O(n^2).
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

In [None]:
vocab = []

# Seed the global RNGs once, before the first random draw (atomic_KB below), so
# that "Restart & Run All" regenerates exactly the same dataset.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

num_entities = 1000
entities = ["<e_{}>".format(i) for i in range(num_entities)]
vocab = vocab + entities
ind2entity, entity2ind = build_dicts(entities)

num_attributes = 20
attributes = ["<attr_{}>".format(i) for i in range(num_attributes)]
vocab = vocab + attributes
ind2attribute, attribute2ind = build_dicts(attributes)

num_vals_per_attr = 20  # values range from [0, num_vals_per_attr-1]
values = ["<{}>".format(i) for i in range(num_vals_per_attr)]
vocab = vocab + values

# randomly assign values to people's attributes
atomic_KB = np.random.randint(low=0, high=num_vals_per_attr, size=(num_entities, num_attributes))     #  [entity id, attribute id] -> value

In [None]:
def choose(arr, ratio_or_count):
    """Sample elements of `arr` uniformly without replacement (global NumPy RNG).

    Args:
        arr: indexable sequence to sample from.
        ratio_or_count: a float is a fraction of len(arr) (rounded to the nearest
            integer); an int is an absolute count. NumPy scalars are accepted.

    Returns:
        list of the sampled elements, in the order they were drawn.

    Raises:
        TypeError: if ratio_or_count is not a float or an int (bools are rejected).
        ValueError: raised by NumPy if the count exceeds len(arr).
    """
    # bool is a subclass of int; the original exact-type check rejected it, keep that.
    if isinstance(ratio_or_count, bool):
        raise TypeError("ratio_or_count must be a float ratio or an int count, got bool")
    if isinstance(ratio_or_count, (float, np.floating)):
        # float() + int() so NumPy floats also yield a plain int sample size.
        num = int(round(float(ratio_or_count) * len(arr)))
    elif isinstance(ratio_or_count, (int, np.integer)):
        num = int(ratio_or_count)
    else:
        # An explicit raise survives `python -O`, unlike the former `assert False`.
        raise TypeError(
            "ratio_or_count must be a float ratio or an int count, got {}".format(type(ratio_or_count).__name__))
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]
    
def split(arr, ratio):
    """Randomly partition `arr` into [train, test] (global NumPy RNG).

    Exactly round(ratio * len(arr)) elements go to train, the rest to test;
    both parts keep the original relative order of `arr`.

    Returns:
        [train, test] as a two-element list of lists.
    """
    # A set makes each membership test O(1); the list made the scan O(n * k).
    train_inds = set(np.random.choice(len(arr), round(ratio*len(arr)), replace=False).tolist())
    train = [x for i, x in enumerate(arr) if i in train_inds]
    test = [x for i, x in enumerate(arr) if i not in train_inds]
    return [train, test]

In [None]:
# special tokens
vocab = vocab + ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]

# Each attribute token doubles as a comparison-query token. It gets three answer
# tokens "<attr_k_0>", "<attr_k_1>", "<attr_k_2>", indexed by the result of
# compare(): 0 = first value smaller, 1 = equal, 2 = first value greater.
comp_q_tokens = attributes
comp2labels = {}
for comp_q_token in comp_q_tokens:
    bare = comp_q_token.strip("<>")
    labels = ["<{}_{}>".format(bare, i) for i in range(3)]
    comp2labels[comp_q_token] = labels
    vocab = vocab + labels

In [None]:
# Sanity check: no token may appear twice in the vocabulary.
assert len(set(vocab)) == len(vocab)
print("vocab size:", len(vocab))

In [None]:
def format_atomic(entity, attr, val, t):
    """Render the atomic fact "entity's attr has value val" as a masked query.

    Returns a dict with:
        input_text:  "<q>{entity}{attr}<mask></q>"
        target_text: input_text followed by "<a><{val}></a>"
        type:        the category tag `t`
    """
    query = "<q>" + entity + attr + "<mask>" + "</q>"
    answer = "<a>" + "<{}>".format(val) + "</a>"
    return {"input_text": query, "target_text": query + answer, "type": t}

def format_comp(comp_q_token, ent_1, ent_2, label, t):
    """Render a comparison query between two entities.

    Returns a dict with:
        input_text:  "{comp_q_token}<q>{ent_1}<mask>{ent_2}</q>"
        target_text: input_text followed by "<a>{label}</a>"
        type:        the category tag `t`
    """
    question = comp_q_token + "<q>" + ent_1 + "<mask>" + ent_2 + "</q>"
    return {
        "input_text": question,
        "target_text": question + "<a>" + label + "</a>",
        "type": t,
    }

# Fraction of entities assigned to the in-distribution (ID) split.
num_id_entities_ratio = 0.9

# Output buckets; each is a list of formatted fact dicts.
id_atomic_facts = []
ood_atomic_facts = []
train_inferred_topleft = []
train_inferred_cross = []
test_inferred_ood = []
test_inferred_iid = []

def compare_ent(ent_1, ent_2, attr):
    """Compare two entities on attribute `attr` using the values in atomic_KB.

    Returns 0, 1 or 2 if ent_1's value is smaller than, equal to, or greater
    than ent_2's (same convention as compare()).
    """
    as_token = "<{}>".format
    lhs = atomic_KB[entity2ind[ent_1], attribute2ind[attr]]
    rhs = atomic_KB[entity2ind[ent_2], attribute2ind[attr]]
    return compare(as_token(lhs), as_token(rhs))

def _append_both_orders(bucket, comp_q_token, ent_1, ent_2, ty):
    """Append the comparison (ent_1, ent_2) and its flip (ent_2, ent_1) to `bucket`, labelled from atomic_KB."""
    for a, b in ((ent_1, ent_2), (ent_2, ent_1)):
        label = comp2labels[comp_q_token][compare_ent(a, b, comp_q_token)]
        bucket.append(format_comp(comp_q_token, a, b, label, t=ty))


# Fraction of ID-ID entity pairs held out as in-distribution comparison test queries.
iid_test_ratio = 0.1

# All unordered entity pairs. They are the same for every attribute, so build them once.
all_pairs = list(itertools.combinations(entities, 2))

for comp_q_token in tqdm(comp_q_tokens):
    # randomly partition the entities (a fresh ID/OOD split is drawn for each attribute)
    id_entities, ood_entities = split(entities, num_id_entities_ratio)
    # Sets give O(1) membership tests in the ~500k-pair loop below.
    id_set, ood_set = set(id_entities), set(ood_entities)

    # atomic facts: this attribute's value for every entity
    for entity in id_entities:
        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]
        id_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='id_atomic'))
    for entity in ood_entities:
        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]
        ood_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='ood_atomic'))

    # Comparison facts for every pair (in both orders), bucketed by the pair's split membership.
    for ent_1, ent_2 in all_pairs:
        if ent_1 in ood_set and ent_2 in ood_set:
            _append_both_orders(test_inferred_ood, comp_q_token, ent_1, ent_2, 'test_inferred_ood')
        elif ent_1 in id_set and ent_2 in id_set:
            # one uniform draw per ID-ID pair, in pair order (RNG stream unchanged)
            if np.random.uniform() < iid_test_ratio:
                _append_both_orders(test_inferred_iid, comp_q_token, ent_1, ent_2, 'test_inferred_iid')
            else:
                _append_both_orders(train_inferred_topleft, comp_q_token, ent_1, ent_2, 'train_inferred_topleft')
        else:
            # split() partitions the entities, so this pair has one ID and one OOD entity
            _append_both_orders(train_inferred_cross, comp_q_token, ent_1, ent_2, 'train_inferred_cross')

print(len(id_atomic_facts), len(ood_atomic_facts), "|", len(train_inferred_topleft), len(train_inferred_cross), len(test_inferred_ood), len(test_inferred_iid))

In [None]:
# Probe queries evaluated during training: held-out OOD / IID comparisons plus atomic facts.
probes = []
test_size = 3000

comp_facts_test_ds = choose(test_inferred_ood, test_size)
probes.extend(comp_facts_test_ds)
for pool in (id_atomic_facts, test_inferred_iid):
    probes.extend(choose(pool, test_size))
probes.extend(ood_atomic_facts)

In [None]:
# Only a small fraction of the inferred (comparison) facts is kept for training.
train_downsampling = 0.03
dataset_name = "cplx_reasoning"
out_dir = "data/{}".format(dataset_name)
os.makedirs(out_dir, exist_ok=True)

train_inferred_topleft_ds = choose(train_inferred_topleft, train_downsampling)
train_inferred_cross_ds = choose(train_inferred_cross, train_downsampling)

# Also probe on subsamples of the training comparisons themselves.
probes = probes + choose(train_inferred_topleft_ds, test_size)
probes = probes + choose(train_inferred_cross_ds, test_size)

splits = {
    "train": id_atomic_facts + train_inferred_topleft_ds + train_inferred_cross_ds,
    "valid": comp_facts_test_ds,
    "vocab": vocab,
}
for name, payload in splits.items():
    with open("{}/{}.json".format(out_dir, name), "w", encoding='utf-8') as f:
        json.dump(payload, f)

In [None]:
# prepare the test set
#
# "Hard" OOD queries: ordered pairs of OOD entities whose relation is deducible
# from cross comparisons + ID atomic values, but NOT from the training
# comparison facts alone. A label-balanced sample of them is appended to probes.

with open("data/{}/train.json".format(dataset_name), encoding='utf-8') as f:
    train = json.load(f)

# train_dict[fact type][attribute index] -> list of training items
train_dict = dict()
all_atomics = dict()
for item in tqdm(train):
    t = item['type']
    attr = int(item['input_text'].split("attr_")[1].split(">")[0])
    assert 0 <= attr < num_attributes
    train_dict.setdefault(t, dict()).setdefault(attr, []).append(item)


def _parse_comp_item(item):
    """Return (e1, e2, label) from a comparison item's target_text.

    Entity names are returned without angle brackets (e.g. "e_7"); label is the
    last digit of the answer token (0: e1 < e2, 1: equal, 2: e1 > e2).
    """
    temp = item['target_text'].strip("><").split("><")
    return temp[2], temp[4], int(temp[7][-1])


def _add_ge_edges(G, e1, e2, label):
    """Encode one comparison in G, where an edge a->b means a >= b."""
    if label in (1, 2):   # e1 >= e2
        G.add_edge(e1, e2)
    if label in (0, 1):   # e2 >= e1
        G.add_edge(e2, e1)


def _build_comparison_graph(items):
    """Build the >= graph implied by a list of comparison training items."""
    G = nx.DiGraph()
    for item in items:
        _add_ge_edges(G, *_parse_comp_item(item))
    return G


def _parse_atomics(items):
    """Map entity name (without brackets) -> integer value for atomic-fact items."""
    atomics = dict()
    for item in items:
        temp = item['target_text'].strip("><").split("><")
        e, val = temp[1], int(temp[6])
        assert 0 <= val < num_vals_per_attr
        atomics[e] = val
    return atomics


def _deducible_pairs(G, attr, known_entities, expected_ood):
    """Return {(attr, e1, e2, label)} for every ordered pair of graph nodes outside
    `known_entities` whose relation follows from reachability in G.

    A path one way only implies a strict inequality (some edge on it must come from
    a strict fact, otherwise the reverse path would exist); paths both ways imply equality.
    """
    ood_nodes = sorted(set(G.nodes) - set(known_entities))
    assert len(ood_nodes) == expected_ood
    # One reachability search per node instead of two has_path searches per pair.
    reachable = {n: nx.descendants(G, n) for n in ood_nodes}
    deduced = set()
    for e1, e2 in itertools.permutations(ood_nodes, 2):
        greater_or_equal = e2 in reachable[e1]
        smaller_or_equal = e1 in reachable[e2]
        if greater_or_equal and smaller_or_equal:
            res = 1
        elif greater_or_equal:
            res = 2
        elif smaller_or_equal:
            res = 0
        else:
            continue  # relation not deducible from G
        deduced.add((attr, e1, e2, res))
    return deduced


# Same count split() leaves in the OOD part for each attribute.
num_ood_per_attr = num_entities - round(num_id_entities_ratio * num_entities)

# (attr, e1, e2, label) deducible from cross comparisons + ID atomic values.
# Local names differ from the earlier cells' globals so e.g. train_inferred_cross is not clobbered.
solvable_dict = set()
for attr in tqdm(range(num_attributes)):
    G = _build_comparison_graph(train_dict['train_inferred_cross'][attr])
    atomics = _parse_atomics(train_dict['id_atomic'][attr])
    # ID atomic values give a direct comparison between every pair of ID entities.
    for e1, e2 in itertools.combinations(atomics, 2):
        v1, v2 = atomics[e1], atomics[e2]
        _add_ge_edges(G, e1, e2, 1 + (v1 > v2) - (v1 < v2))  # 0: <, 1: ==, 2: >
    all_atomics[attr] = atomics
    solvable_dict |= _deducible_pairs(G, attr, atomics, num_ood_per_attr)

# (attr, e1, e2, label) deducible from the training comparison facts alone.
easy_dict = set()
for attr in tqdm(range(num_attributes)):
    train_inferred_items = train_dict['train_inferred_cross'][attr] + train_dict['train_inferred_topleft'][attr]
    G = _build_comparison_graph(train_inferred_items)
    atomics = _parse_atomics(train_dict['id_atomic'][attr])
    easy_dict |= _deducible_pairs(G, attr, atomics, num_ood_per_attr)

solvable_hard_set = solvable_dict - easy_dict
print(len(solvable_hard_set))

# Group in sorted order: set iteration order of string tuples depends on
# PYTHONHASHSEED, which would make the random sample below irreproducible.
group_by_label = dict()
for g in sorted(solvable_hard_set):
    group_by_label.setdefault(g[-1], []).append(g)

num_hard_per_label = 50
test_hard = []
for label in range(3):
    test_hard += choose(group_by_label[label], num_hard_per_label)

test = [
    format_comp("<attr_{}>".format(attr), "<"+e1+">", "<"+e2+">", "<attr_{}_{}>".format(attr, label), 'test_hard')
    for (attr, e1, e2, label) in test_hard
]

probes = probes + test
with open("data/{}/test.json".format(dataset_name), "w", encoding='utf-8') as f:
    json.dump(probes, f)