In [None]:
import json
import numpy as np
import random
from tqdm.auto import tqdm
from utils import choose, rand_matching, kagebunshin, random_pairing
import matplotlib.pyplot as plt
import os

# Seed the random number generators so that the random draws made by this
# notebook are repeatable across Restart & Run All.  Both the stdlib and the
# NumPy generators are seeded: the notebook itself uses `random`, and the
# helpers imported from `utils` are not visible here and may use either one.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

num_GA_entities = 50000           # number of group A entities

# multiplicity
comb = None     # abstract concept level
# comb = 10      # attaching surface-form names

# group B is 60% the size of group A (int() truncates toward zero)
num_GB_entities = int(num_GA_entities * 0.6)

# entity names: "<s_i>" for group A, "<t_i>" for group B
GA_entities = ["<s_{}>".format(i) for i in range(num_GA_entities)]
GB_entities = ["<t_{}>".format(i) for i in range(num_GB_entities)]
all_entities = GA_entities + GB_entities

In [None]:
def build_token_dict(l, num_token=1, comb=None):
    """
    Assign a token sequence to every entity name.

    l: a list of distinct entity names; in the multi-token case each name is
       checked to be of the form "<name>"
    num_token: number of tokens used to spell one entity
    comb: multiplicity of the first/last name tokens; must be None when
          num_token == 1, otherwise an int that divides len(l)

    returns: a dict mapping each entity in l to its token list
    """
    if num_token == 1:
        # single-token case: the entity name is its own (only) token
        assert comb is None
        return dict((entity, [entity]) for entity in l)

    assert comb is not None
    # the first two tokens carry the multiplicity; any remaining positions
    # are filled with tokens unique to the entity
    assert num_token >= 2
    n_tail = num_token - 2
    assert type(comb) is int and len(l) % comb == 0

    # len(l) // comb distinct first names and as many distinct last names
    n_names = len(l) // comb
    first = ["<f_{}>".format(k) for k in range(n_names)]
    last = ["<l_{}>".format(k) for k in range(n_names)]
    # NOTE(review): kagebunshin / rand_matching come from utils and are not
    # visible here; presumably each name is replicated `comb` times and the
    # two lists are then randomly paired up -- confirm against utils.
    first = kagebunshin(first, comb)
    last = kagebunshin(last, comb)
    if comb != 1:
        tokens = rand_matching(first, last)
    else:
        # no multiplicity: pair the i-th first name with the i-th last name
        tokens = [(f_tok, last[pos]) for pos, f_tok in enumerate(first)]
    assert len(tokens) == len(l)

    token_dict = {}
    for pos, entity in enumerate(l):
        # sanity check: exactly one pair of angle brackets, wrapping the name
        assert entity.count("<") == 1 and entity.count(">") == 1
        assert entity.startswith("<") and entity.endswith(">")
        stem = entity.strip("><")
        tail = ["<{}_{}>".format(stem, k) for k in range(n_tail)]
        token_dict[entity] = [*tokens[pos], *tail]
    return token_dict
    
def form_items(arr, ty, entity_token_dict, vocab):
    """
    Turn one (head, relation, tail) triple into a one-element list of examples.

    arr: the triple [h, r, t]
    ty: label stored under the "type" key of the example
    entity_token_dict: maps an entity name to its list of tokens
    vocab: maps a token to its integer id

    returns: a list holding a single dict with keys "input_text",
             "target_text" and "type"
    """
    head, relation, tail = arr

    def encode(entity):
        # ids of the tokens that spell out `entity`
        return [vocab[tok] for tok in entity_token_dict[entity]]

    # input: head tokens, the relation, then a mask slot; target: tail tokens
    source_ids = [encode(head), [vocab[relation]], [vocab["<mask>"]]]
    target_ids = [encode(tail)]

    return [{
        "input_text": source_ids,
        "target_text": target_ids,
        "type": ty,
    }]

In [None]:
# Token list for every entity: a single token per entity at the abstract
# level (comb is None), otherwise a two-token name with multiplicity `comb`.
if comb is None:
    ent_token_dict = build_token_dict(all_entities)
else:
    assert type(comb) == int
    ent_token_dict = build_token_dict(all_entities, num_token=2, comb=comb)

# De-duplicate the entity tokens while keeping first-seen order.  Iterating a
# set of strings instead would follow the string hashes, which Python
# randomizes per process, so the ids saved to vocab.json would differ on every
# kernel restart.
ent_tokens = list(dict.fromkeys(
    tok for toks in ent_token_dict.values() for tok in toks
))
ent_token_set = set(ent_tokens)   # same set of entity tokens as before

# ids are assigned consecutively: entity tokens first, then "<mask>"
assert "<mask>" not in ent_token_set
vocab = {tok: idx for idx, tok in enumerate(ent_tokens + ["<mask>"])}

In [None]:
num_rel_pairs = 6
train, atomic, test = [], [], []

for i in tqdm(range(num_rel_pairs)):
    # every relation comes with a separate inverse token; both are appended
    # to the vocabulary before any triple that uses them is encoded
    rel = "<r_{}>".format(i)
    rel_inv = "<r_{}_inv>".format(i)     # add _inv if want r != r^{-1}
    for tok in (rel, rel_inv):
        assert tok not in vocab
        vocab[tok] = len(vocab)

    # group A: both directions go into the training set for learning the rules
    for ent1, ent2 in random_pairing(GA_entities):
        for triple in ([ent1, rel, ent2], [ent2, rel_inv, ent1]):
            train.extend(form_items(triple, 'train', entity_token_dict=ent_token_dict, vocab=vocab))

    # group B: one direction goes into the training set ('atomic'), the
    # opposite direction is held out as a test item
    for ent1, ent2 in random_pairing(GB_entities):
        forward, backward = [ent1, rel, ent2], [ent2, rel_inv, ent1]
        # coin flip picks which direction is the one seen during training
        if random.uniform(0,1) <= 0.5:
            seen, held_out = forward, backward
        else:
            seen, held_out = backward, forward
        atomic.extend(form_items(seen, 'atomic', entity_token_dict=ent_token_dict, vocab=vocab))
        test.extend(form_items(held_out, 'test', entity_token_dict=ent_token_dict, vocab=vocab))

print("train/atomic/test:", len(train), len(atomic), len(test))

In [None]:
# the pad token is added last, so it takes the next free id after all
# tokens already present in the vocabulary
tok = "<pad>"
assert tok not in vocab
vocab[tok] = len(vocab)

# every token must map to its own distinct id
vocab_size = len(vocab)
assert vocab_size == len(set(vocab.keys())) == len(set(vocab.values()))
print(vocab_size)

In [None]:
test_size = 3000

# the dataset name encodes the configuration: the multiplicity (when one is
# set) followed by the sizes of group A and group B
if comb is not None:
    dataset = "inversionidcomb{}.{}.{}".format(comb, num_GA_entities, num_GB_entities)
else:
    dataset = "inversionid.{}.{}".format(num_GA_entities, num_GB_entities)

out_dir = "data/{}".format(dataset)
os.makedirs(out_dir, exist_ok=True)

# training file: the group A facts plus the 'atomic' direction of group B
with open("{}/train.json".format(out_dir), "w", encoding='utf-8') as f:
    json.dump(train + atomic, f)

# probes: `choose` (from utils) is applied to each split with test_size
probes = {
    split: choose(items, test_size)
    for split, items in (("train", train), ("atomic", atomic), ("test", test))
}

for filename, payload in (("valid.json", probes), ("vocab.json", vocab)):
    with open("{}/{}".format(out_dir, filename), "w", encoding='utf-8') as f:
        json.dump(payload, f)