In [None]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
import itertools
import os
from copy import deepcopy

In [None]:
def compare(a,b):
    """Three-way comparison of two bracketed integer tokens.

    Each argument is a string such as "<7>": leading/trailing "<" and ">"
    characters are stripped and the remainder is parsed with int().

    Returns:
        0 if a < b, 1 if a == b, 2 if a > b.
    """
    left, right = int(a.strip("<>")), int(b.strip("<>"))
    if left == right:
        return 1
    return 0 if left < right else 2
    
def build_dicts(entities):
    """Build index <-> item lookup tables, dropping duplicates.

    Args:
        entities: iterable of hashable items (they are used as dict keys).

    Returns:
        (ind2entity, entity2ind) where ``ind2entity`` is the list of distinct
        items in first-seen order and ``entity2ind`` maps each item to its
        position in that list.
    """
    entity2ind = dict()
    ind2entity = []
    for entity in entities:
        # Test membership against the dict (O(1)) rather than scanning the
        # list; both containers always hold exactly the same items.
        if entity not in entity2ind:
            entity2ind[entity] = len(ind2entity)
            ind2entity.append(entity)
    return ind2entity, entity2ind

In [None]:
# The vocabulary is accumulated piece by piece; start from an empty list.
vocab = []

# entity tokens: <e_0>, <e_1>, ...
num_entities = 1000
entities = [f"<e_{i}>" for i in range(num_entities)]
vocab = vocab + entities
ind2entity, entity2ind = build_dicts(entities)

# attribute tokens: <attr_0>, <attr_1>, ...
num_attributes = 20
attributes = [f"<attr_{i}>" for i in range(num_attributes)]
vocab = vocab + attributes
ind2attribute, attribute2ind = build_dicts(attributes)

# value tokens: <0>, <1>, ...
num_vals_per_attr = 20  # values range from [0, num_vals_per_attr-1]
values = [f"<{i}>" for i in range(num_vals_per_attr)]
vocab = vocab + values

# Randomly assign a value to every (entity, attribute) pair.
# atomic_KB[entity id, attribute id] -> value
# NOTE(review): no RNG seed is set in the cells visible here, so this table
# presumably differs from run to run - confirm whether that is intended.
atomic_KB = np.random.randint(0, num_vals_per_attr, size=(num_entities, num_attributes))

In [None]:
def rand_flip(tup):
    """Return a new tuple with the elements of ``tup`` in random order.

    The elements are copied into a list, shuffled in place with
    ``random.shuffle`` and converted back to a tuple; ``tup`` is not modified.
    """
    shuffled = [*tup]
    random.shuffle(shuffled)
    return (*shuffled,)
    
def choose(arr, ratio_or_count):
    """Randomly pick elements of ``arr`` without replacement.

    Args:
        arr: indexable sequence to sample from.
        ratio_or_count: a float is interpreted as a fraction of ``len(arr)``
            (the product is rounded with ``round``); an int is interpreted as
            an absolute number of elements. NumPy integer / floating scalars
            are accepted in addition to the built-in types.

    Returns:
        ``arr`` itself (the same object, not a copy) when the requested number
        is >= ``len(arr)``; otherwise a new list with the sampled elements in
        the order drawn by ``np.random.choice``.

    Raises:
        AssertionError: if ``ratio_or_count`` is neither an integer nor a
            float (bools are rejected as well).
    """
    # isinstance (rather than an exact type() match) lets NumPy scalars and
    # subclasses through. bool is an int subclass, so reject it explicitly.
    if isinstance(ratio_or_count, bool):
        raise AssertionError("ratio_or_count must be an int or a float, got bool")
    if isinstance(ratio_or_count, (float, np.floating)):
        num = int(round(float(ratio_or_count) * len(arr)))
    elif isinstance(ratio_or_count, (int, np.integer)):
        num = int(ratio_or_count)
    else:
        # Raised explicitly so the check also holds when asserts are disabled.
        raise AssertionError(
            "ratio_or_count must be an int or a float, got {}".format(type(ratio_or_count).__name__)
        )
    if num >= len(arr):
        return arr
    rand_inds = np.random.choice(len(arr), num, replace=False).tolist()
    return [arr[i] for i in rand_inds]
    
def split(arr, ratio):
    """Randomly partition ``arr`` into two lists.

    ``round(ratio * len(arr))`` distinct positions are drawn with
    ``np.random.choice``; elements at those positions go to the first list,
    all others to the second. Both lists keep the original relative order.

    Args:
        arr: sequence to partition.
        ratio: fraction of elements that should land in the first list.

    Returns:
        ``[train, test]`` - a two-element list of lists.
    """
    # A set makes the per-element membership test O(1); with a plain list the
    # loop below is quadratic in len(arr).
    train_inds = set(np.random.choice(len(arr), round(ratio*len(arr)), replace=False).tolist())
    train, test = [], []
    for i, item in enumerate(arr):
        if i in train_inds:
            train.append(item)
        else:
            test.append(item)
    return [train, test]

In [None]:
# special tokens
special_tokens = ["<mask>", "<sep>", "<a>", "</a>", "<q>", "</q>"]
vocab = vocab + special_tokens

# Every attribute doubles as a comparison-query token and gets three label
# tokens, e.g. <attr_0> -> <attr_0_0>, <attr_0_1>, <attr_0_2>.
comp_q_tokens = attributes
comp2labels = dict()
for comp_q_token in comp_q_tokens:
    stem = comp_q_token.strip("<>")
    label_tokens = ["<{}_{}>".format(stem, k) for k in range(3)]
    comp2labels[comp_q_token] = label_tokens
    vocab = vocab + label_tokens

In [None]:
# sanity check: no token may occur twice in the vocabulary
vocab_size = len(vocab)
assert vocab_size == len(set(vocab))
print("vocab size:", vocab_size)

In [None]:
def format_atomic(entity, attr, val, t):
    """Build the record for one atomic fact.

    Args:
        entity: entity token string.
        attr: attribute token string.
        val: value, rendered as "<val>".
        t: tag stored under the "type" key.

    Returns:
        dict with "input_text" (entity + attr), "target_text"
        (input_text + "<val>" + "</a>") and "type".
    """
    prompt = "".join((entity, attr))
    answer = "<{}>".format(val) + "</a>"
    return dict(input_text=prompt, target_text=prompt + answer, type=t)

def format_comp(comp_q_token, ent_1, ent_2, label, t):
    """Build the record for one comparison query.

    Args:
        comp_q_token: token identifying the attribute being compared.
        ent_1: first entity token.
        ent_2: second entity token.
        label: label token appended as the answer.
        t: tag stored under the "type" key.

    Returns:
        dict with "input_text"
        (comp_q_token + "<q>" + ent_1 + "<mask>" + ent_2), "target_text"
        (input_text + label + "</a>") and "type".
    """
    prompt = "".join((comp_q_token, "<q>", ent_1, "<mask>", ent_2))
    answer = "".join((label, "</a>"))
    return dict(input_text=prompt, target_text=prompt + answer, type=t)

# fraction of entities passed to split() as the in-distribution (ID) share
num_id_entities_ratio = 0.9

# accumulators for the generated records, one list per record type
id_atomic_facts = []
ood_atomic_facts = []
train_inferred = []
test_inferred_iid = []
test_inferred_ood = []

def compare_ent(ent_1, ent_2, attr):
    """Compare two entities on one attribute.

    Looks up both entities' values for ``attr`` in ``atomic_KB`` and passes
    them, rendered as "<value>" tokens, to ``compare``.

    Returns:
        whatever ``compare`` returns for (value of ent_1, value of ent_2).
    """
    row_1 = entity2ind[ent_1]
    col = attribute2ind[attr]
    val_1 = atomic_KB[row_1, col]
    val_2 = atomic_KB[entity2ind[ent_2], col]
    token_1 = "<{}>".format(val_1)
    token_2 = "<{}>".format(val_2)
    return compare(token_1, token_2)

# All unordered entity pairs. The list does not depend on the attribute, so it
# is built once here instead of once per loop iteration.
all_pairs = list(itertools.combinations(entities, 2))

for comp_q_token in tqdm(comp_q_tokens):
    # Fresh random ID/OOD entity partition for every attribute; sets give
    # O(1) membership tests in the pair loop below.
    id_entities, ood_entities = split(entities, num_id_entities_ratio)
    id_entities, ood_entities = set(id_entities), set(ood_entities)

    # Atomic facts: each entity's value for this attribute, tagged by partition.
    for entity in id_entities:
        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]
        id_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='id_atomic'))

    for entity in ood_entities:
        val = atomic_KB[entity2ind[entity], attribute2ind[comp_q_token]]
        ood_atomic_facts.append(format_atomic(entity, comp_q_token, val, t='ood_atomic'))

    # Inferred (comparison) facts: pick the destination for each pair, then
    # emit the pair in both orders.
    labels = comp2labels[comp_q_token]
    for (ent_1, ent_2) in all_pairs:
        if ent_1 in ood_entities and ent_2 in ood_entities:
            ty, bucket = 'test_inferred_ood', test_inferred_ood
        elif ent_1 in id_entities and ent_2 in id_entities:
            # One uniform draw per ID-ID pair: draws below 0.1 go to the IID
            # test set, the rest to training.
            if np.random.uniform() < 0.1:
                ty, bucket = 'test_inferred_iid', test_inferred_iid
            else:
                ty, bucket = 'train_inferred', train_inferred
        else:
            # pairs mixing an ID and an OOD entity are not used
            continue

        # original order first, then the flipped order with its own label
        for first, second in ((ent_1, ent_2), (ent_2, ent_1)):
            label = labels[compare_ent(first, second, comp_q_token)]
            bucket.append(format_comp(comp_q_token, first, second, label, t=ty))

# Summary of generated record counts: atomic (id, ood) | inferred (train, test iid, test ood).
# The closing parenthesis belongs after the last count so all five values are printed.
print(len(id_atomic_facts), len(ood_atomic_facts), "|", len(train_inferred), len(test_inferred_iid), len(test_inferred_ood))

In [None]:
# upper bound on the number of records sampled from each pool
test_size = 3000

# sample of OOD comparison facts; kept under its own name as well as in probes
comp_facts_test_ds = choose(test_inferred_ood, test_size)

# probes = OOD comparison sample, then samples of ID atomic, OOD atomic and
# IID test comparison facts (drawn in that order)
probes = list(comp_facts_test_ds)
for pool in (id_atomic_facts, ood_atomic_facts, test_inferred_iid):
    probes = probes + choose(pool, test_size)

In [None]:
# Write one dataset directory per inferred/atomic ratio, downsampling the
# inferred facts included in training to ratio * (number of ID atomic facts).
for inf_atom_ratio in [12.6, 9.0, 7.2, 3.6]:
    dataset_name = "comparison.{}.{}".format(num_entities, inf_atom_ratio)
    os.makedirs("data/{}".format(dataset_name), exist_ok=True)

    n_train_inferred = round(inf_atom_ratio*len(id_atomic_facts))
    train_inferred_ds = choose(train_inferred, n_train_inferred)

    # the test file additionally probes a sample of the training inferred facts
    probes_ = probes + choose(train_inferred_ds, test_size)

    print("train/test atomic, # train inferred:", len(id_atomic_facts), len(ood_atomic_facts), len(train_inferred_ds))

    # (path template, payload) for every file of this dataset; the vocabulary
    # is written alongside the data splits
    outputs = [
        ("data/{}/train.json", id_atomic_facts + ood_atomic_facts + train_inferred_ds),
        ("data/{}/valid.json", comp_facts_test_ds),
        ("data/{}/test.json", probes_),
        ("data/{}/vocab.json", vocab),
    ]
    for path_template, payload in outputs:
        with open(path_template.format(dataset_name), "w", encoding='utf-8') as f:
            json.dump(payload, f)