In [None]:
import json
import numpy as np
from tqdm.auto import tqdm
import os
from copy import deepcopy
import matplotlib.pyplot as plt
from utils import choose, rand_matching, kagebunshin, split
from collections import defaultdict
import math

In [None]:
P = 10007                      # prime modulus: every value is an integer in 0 ~ P-1 and arithmetic is mod P
core_v_multiplicity = 10       # number of distinct core entities that share each value
num_core_entities = P * core_v_multiplicity

all_values = [f"<{v}>" for v in range(P)]
core_entities = [f"<c_{k}>" for k in range(num_core_entities)]
# core entity "<c_k>" holds value k mod P, i.e. the value list repeated core_v_multiplicity times
core_entity2values = {ent: all_values[k % P] for k, ent in enumerate(core_entities)}

In [None]:
def build_token_dict(l, num_token=1, comb=None):
    """
    Map every entity name in l to the list of tokens that spells it.

    l: a list of distinct entity names, each of the form "<name>"
    num_token: number of tokens per entity
    comb: multiplicity of the shared first/last-name tokens; must be None when
          num_token == 1 and an int otherwise

    returns: a dict mapping each entity in l to its token list
    """
    if num_token == 1:
        # single-token entities are simply their own token
        assert comb is None
        return {name: [name] for name in l}

    assert comb is not None
    # tokens 0 and 1 are "first"/"last" name tokens shared between entities (expanded by
    # utils.kagebunshin(..., comb) -- presumably each repeated comb times, confirm);
    # every further token is unique to its entity
    assert num_token >= 2
    tail_len = num_token - 2
    assert type(comb) == int

    n_names = math.ceil(len(l) / comb)
    firsts = kagebunshin(["<f_{}>".format(k) for k in range(n_names)], comb)
    lasts = kagebunshin(["<l_{}>".format(k) for k in range(n_names)], comb)
    if comb == 1:
        # no sharing: pair the k-th first name with the k-th last name
        pairs = [(firsts[k], lasts[k]) for k in range(len(firsts))]
    else:
        pairs = rand_matching(firsts, lasts)
    pairs = pairs[:len(l)]

    token_dict = {}
    for idx, name in enumerate(l):
        # names must look like "<...>" with exactly one pair of angle brackets
        assert name.count("<") == name.count(">") == 1
        assert name[0] == "<" and name[-1] == ">"
        stem = name.strip("><")
        unique_tail = ["<{}_{}>".format(stem, k) for k in range(tail_len)]
        token_dict[name] = [*pairs[idx], *unique_tail]
    return token_dict

def mod_add(a_, b_, P_=P):
    """
    Modular addition: returns (a_ + b_) mod P_.

    a_, b_: either two ints, or two value tokens of the form "<k>" (e.g. "<17>");
            in both forms the operands must lie in 0 ~ P_-1.
    P_: modulus (defaults to the global prime P).

    returns: an int when given two ints, otherwise a value token "<c>".
    raises AssertionError if a token is malformed or an operand is out of range.
    """
    ints_in = type(a_) == type(b_) == int
    if ints_in:
        a, b = a_, b_
    else:
        assert a_.startswith('<') and a_.endswith(">") and b_.startswith('<') and b_.endswith(">")
        a, b = int(a_[1:-1]), int(b_[1:-1])
    # validate both input forms identically, so out-of-range ints are rejected
    # instead of being silently folded into range
    assert 0 <= a <= P_-1 and 0 <= b <= P_-1
    # Python's % with a positive modulus always yields a result in 0 ~ P_-1
    c = (a + b) % P_
    return c if ints_in else '<' + str(c) + '>'

def mod_subtract(a_, b_, P_=P):
    """
    Modular subtraction: returns (a_ - b_) mod P_.

    a_, b_: either two ints, or two value tokens of the form "<k>" (e.g. "<17>");
            in both forms the operands must lie in 0 ~ P_-1.
    P_: modulus (defaults to the global prime P).

    returns: an int when given two ints, otherwise a value token "<c>".
    raises AssertionError if a token is malformed or an operand is out of range.
    """
    ints_in = type(a_) == type(b_) == int
    if ints_in:
        a, b = a_, b_
    else:
        assert a_.startswith('<') and a_.endswith(">") and b_.startswith('<') and b_.endswith(">")
        a, b = int(a_[1:-1]), int(b_[1:-1])
    # validate both input forms identically, so out-of-range ints are rejected
    # instead of being silently folded into range
    assert 0 <= a <= P_-1 and 0 <= b <= P_-1
    # Python's % with a positive modulus always yields a result in 0 ~ P_-1
    c = (a - b) % P_
    return c if ints_in else '<' + str(c) + '>'

In [None]:
# draw the two small second-argument values used by every relation (from 10..49, so never 0)
secondarg_values_numerical = choose(list(range(10, 50)), 2)
X_value_lower_bound, X_value_upper_bound = 200, 800
max_ = 4    # number of closure rounds below

# banned source values: the second-arg values, closed under subtracting a second-arg value (max_ rounds)
values_banned_source = set(secondarg_values_numerical)
for _ in range(max_):
    snapshot = set(values_banned_source)
    values_banned_source.update(
        mod_subtract(src, step) for src in snapshot for step in secondarg_values_numerical
    )

# banned target values: the banned source set widened by +/- a second-arg value (max_ rounds)
values_banned_target = set(values_banned_source)
for _ in range(max_):
    snapshot = set(values_banned_target)
    for src in snapshot:
        for step in secondarg_values_numerical:
            values_banned_target.add(mod_add(src, step))
            values_banned_target.add(mod_subtract(src, step))
print(secondarg_values_numerical, len(values_banned_source), len(values_banned_target))

# candidate X values: never banned, and inside [X_value_lower_bound, X_value_upper_bound]
values_available_target = [
    v for v in set(range(P)) - values_banned_target
    if X_value_lower_bound <= v <= X_value_upper_bound
]

In [None]:
branching_factor = 10          # children per node in each bridge level
num_X_entities = 20            # number of problem instances
X_entities = [f"<x_{k}>" for k in range(num_X_entities)]
# assign random distinct values for testing, drawn from the allowed window so calculations stay easy
X_values = [f'<{v}>' for v in choose(values_available_target, num_X_entities)]

In [None]:
# one br1 entity per (instance, child) and one br2 entity per (br1 entity, child)
num_bridge1_entities = num_X_entities * branching_factor
bridge1_entities = [f"<br1_{k}>" for k in range(num_bridge1_entities)]

num_bridge2_entities = num_X_entities * branching_factor ** 2
bridge2_entities = [f"<br2_{k}>" for k in range(num_bridge2_entities)]

In [None]:
all_entities = core_entities + bridge1_entities + bridge2_entities + X_entities

comb = None
ent_token_dict = build_token_dict(all_entities)
# comb = 10
# ent_token_dict = build_token_dict(all_entities, num_token=2, comb=comb)

# Flatten every entity's token list. With multi-token entities (comb set) the shared
# first/last tokens repeat, so de-duplicate while keeping first-seen order; a plain
# set() would make the token ids depend on Python's per-process string hash seed.
ent_token_set = list(dict.fromkeys(tok for toks in ent_token_dict.values() for tok in toks))

vocab = dict()    # map token to id
# numbers always come first, so value token "<k>" gets id k (later cells rely on id2token[k] == "<k>")
for tok in all_values + ent_token_set + ["<=>", "<pad>"]:
    assert tok not in vocab
    vocab[tok] = len(vocab)
id2token = {idx: tok for tok, idx in vocab.items()}
len(vocab)

In [None]:
def form_attr_item(arr, ty, entity_token_dict, vocab):
    """
    Build one attribute example: "<entity tokens> <=>" -> "<value>".

    arr: pair [entity, value], where value is a token present in vocab (e.g. "<17>")
    ty: label stored under "type"
    entity_token_dict: entity -> list of tokens spelling it
    vocab: token -> id

    returns: dict with "input_text" ([entity token ids, [id of "<=>"]]),
             "target_text" ([[value id]]) and "type".
    """
    entity, value = arr
    entity_ids = [vocab[tok] for tok in entity_token_dict[entity]]
    equals_id = vocab['<=>']
    return {
        "input_text": [entity_ids, [equals_id]],
        "target_text": [[vocab[value]]],
        "type": ty,
    }

def form_rel_item(arr, ty, entity_token_dict, vocab):
    """
    Build one relation example: "<e1 tokens> <e2 tokens>" -> "<e3 tokens>".

    arr: triple [e1, e2, e3] of entity names
    ty: label stored under "type"
    entity_token_dict: entity -> list of tokens spelling it
    vocab: token -> id

    The two input entities are shown in random order: one np.random.uniform()
    draw per call, e1 first when it is < 0.5, otherwise e2 first.

    returns: dict with "input_text" (two token-id lists), "target_text"
             ([e3 token ids]) and "type".
    """
    e1, e2, e3 = arr

    def ids(entity):
        return [vocab[tok] for tok in entity_token_dict[entity]]

    first, second = (e1, e2) if np.random.uniform() < 0.5 else (e2, e1)
    return {
        "input_text": [ids(first), ids(second)],
        "target_text": [ids(e3)],
        "type": ty,
    }

In [None]:
# attribute facts (entity <=> value) for the problem-instance entities; these are test-only
X_attr = [
    form_attr_item([x_ent, x_val], 'X_attr', entity_token_dict=ent_token_dict, vocab=vocab)
    for x_ent, x_val in zip(X_entities, X_values)
]

In [None]:
# map each value -- both as the int v and as the token "<v>" -- to the (shared) list of core entities holding it
v2group = dict()
for value in range(P):
    members = [core_entities[k * P + value] for k in range(core_v_multiplicity)]
    for member in members:
        assert core_entity2values[member] == f'<{value}>'
    v2group[value] = v2group[f'<{value}>'] = members

In [None]:
# training data for teaching the rules: facts e1 + e2 = e3 over core entities, for every
# non-banned source value a and every second-arg value v (value(e3) = a + v mod P)
downsampling_ratio = 0.3    # passed to utils.split; presumably the kept training fraction -- confirm
core_rel_full = []
for a in tqdm(range(P)):
    if a in values_banned_source:
        continue
    for v_added in secondarg_values_numerical:
        lhs_group = v2group[a]
        rhs_group = v2group[mod_add(a, v_added)]
        addend_group = v2group[v_added]
        for e1 in lhs_group:
            for e3 in rhs_group:
                # pair each (e1, e3) with one randomly picked entity holding v_added
                e2 = choose(addend_group, 1)[0]
                core_rel_full.append(
                    form_rel_item([e1, e2, e3], 'core_rel', entity_token_dict=ent_token_dict, vocab=vocab)
                )
print(len(core_rel_full))
core_rel, core_rel_test = split(core_rel_full, downsampling_ratio)

In [None]:
# attribute facts (entity <=> value) for every core entity
core_attr = [
    form_attr_item([ent, core_entity2values[ent]], 'core_attr', entity_token_dict=ent_token_dict, vocab=vocab)
    for ent in core_entities
]
len(core_rel), len(core_attr)

In [None]:
# Build, for each problem instance X, a two-level tree of bridge entities and record its relation facts:
#   X   -- 1 fact each -->  branching_factor br1 children
#   br1 -- 1 fact each -->  branching_factor br2 children
#   br2 -- num_core_connect_to_leaf facts -->  core entities (whose values are taught via core_attr)
# A fact (a, b, c) encodes value(a) + value(b) = value(c) (mod P), where b is a core entity holding
# one of the second-arg values. Some facts are deliberately dropped ("nullified") so that parts of
# the tree are disconnected from entities with known values.
drop_prob = 0.6     # chance of poisoning: fraction of br1 nodes whose subtree is nullified, and of br2 nodes nullified under each br1
drop_backward_prob = 0.5    # for a nullified br2 node: probability of dropping its br1 edge (backward) rather than its core links (forward)
num_core_connect_to_leaf = 6    # core-entity facts attached to each br2 node that keeps its forward links

br1_attr = []    # value attributes of br1 entities (used only as validation probes, not written to train.json)
br2_attr = []    # value attributes of br2 entities (likewise probe-only)

# per instance: (entity, entity) edge -> set of fact triples along that edge
edge2fact = [defaultdict(set) for _ in range(len(X_entities))]

pointer_br1, pointer_br2 = 0, 0    # index of the next unused bridge entity at each level
for iii in range(len(X_entities)):

    # diagnostics only (reported by the print at the end of each instance)
    core_ents_used = set()
    all_values_used = set()    # vocab ids of node values; value tokens come first in vocab, so id == value

    x_ent, x_val = X_entities[iii], X_values[iii]
    all_values_used.add(vocab[x_val])

    # br1 children whose entire subtree will be nullified
    nodes_dropall = choose(list(range(branching_factor)), round(drop_prob * branching_factor))
    for node_id in range(branching_factor):
        # grab a new bridge1 entity
        e_br1 = bridge1_entities[pointer_br1]
        # decide whether to nullify all children of the node
        doomify = node_id in nodes_dropall
        # randomly decide the 2nd arg and the entity from available ones
        # (id2token[k] == "<k>" because value tokens occupy vocab ids 0..P-1)
        v_added = id2token[choose(secondarg_values_numerical, 1)[0]]
        e2 = choose(v2group[v_added], 1)[0]
        ent_pair = (x_ent, e_br1)
        # random orientation: br1 = X - v (fact br1 + v = X) or br1 = X + v (fact X + v = br1);
        # the X-br1 fact itself is never dropped
        if np.random.uniform() < 0.5:
            v_br1 = mod_subtract(x_val, v_added)
            br1_attr.append(form_attr_item([e_br1, v_br1], 'br1_attr', entity_token_dict=ent_token_dict, vocab=vocab))
            edge2fact[iii][ent_pair].add((e_br1, e2, x_ent))
        else:
            v_br1 = mod_add(x_val, v_added)
            br1_attr.append(form_attr_item([e_br1, v_br1], 'br1_attr', entity_token_dict=ent_token_dict, vocab=vocab))
            edge2fact[iii][ent_pair].add((x_ent, e2, e_br1))
        core_ents_used |= {e2}
        all_values_used.add(vocab[v_br1])
        # vocab id == numeric value, so this checks v_br1 is not a banned source value
        assert vocab[v_br1] not in values_banned_source

        # br2 children of this br1 node that get nullified (every child is nullified anyway if doomify)
        nodes_dropped_2 = choose(list(range(branching_factor)), round(drop_prob * branching_factor))
        for node_id_2 in range(branching_factor):
            # whether nullify this node
            nullify = doomify or (node_id_2 in nodes_dropped_2)
            # cut backward or forward (if nullify); drawn for every node, even when not nullified
            if np.random.uniform() < drop_backward_prob:
                drop_backward = True
            else:
                drop_backward = False

            e_br2 = bridge2_entities[pointer_br2]
            v_added = id2token[choose(secondarg_values_numerical, 1)[0]]
            e2 = choose(v2group[v_added], 1)[0]
            ent_pair = (e_br1, e_br2)
            # random orientation: br2 = br1 - v (fact br2 + v = br1) or br2 = br1 + v (fact br1 + v = br2)
            if np.random.uniform() < 0.5:
                v_br2 = mod_subtract(v_br1, v_added)
                br2_attr.append(form_attr_item([e_br2, v_br2], 'br2_attr', entity_token_dict=ent_token_dict, vocab=vocab))
                if nullify and drop_backward:
                    # backward drop: omit the br1-br2 fact
                    pass
                else:
                    edge2fact[iii][ent_pair].add((e_br2, e2, e_br1))
            else:
                v_br2 = mod_add(v_br1, v_added)
                br2_attr.append(form_attr_item([e_br2, v_br2], 'br2_attr', entity_token_dict=ent_token_dict, vocab=vocab))
                if nullify and drop_backward:
                    # backward drop: omit the br1-br2 fact
                    pass
                else:
                    edge2fact[iii][ent_pair].add((e_br1, e2, e_br2))
            core_ents_used |= {e2}
            all_values_used.add(vocab[v_br2])
            assert vocab[v_br2] not in values_banned_source

            if nullify and not drop_backward:
                # drop forward: no connection to nodes with known values
                pass
            else:
                # link to known entities. constant branching factor here.
                for _ in range(num_core_connect_to_leaf):
                    # randomly decide the value to take and randomly assign the source entity from core entities
                    v_added = id2token[choose(secondarg_values_numerical, 1)[0]]
                    e2 = choose(v2group[v_added], 1)[0]
                    # random orientation: core e1 + v = br2, or br2 + v = core e1
                    if np.random.uniform() < 0.5:
                        v_taken = mod_subtract(v_br2, v_added)
                        e1 = choose(v2group[v_taken], 1)[0]
                        edge2fact[iii][(e_br2, e1)].add((e1, e2, e_br2))
                    else:
                        v_taken = mod_add(v_br2, v_added)
                        e1 = choose(v2group[v_taken], 1)[0]
                        edge2fact[iii][(e_br2, e1)].add((e_br2, e2, e1))
                    core_ents_used |= {e1, e2}
                    all_values_used.add(vocab[v_taken])
                    assert vocab[v_taken] not in values_banned_source

            pointer_br2 += 1
        pointer_br1 += 1
    # per-instance diagnostics; max/min are vocab ids, which equal the numeric values
    print(f"{len(core_ents_used)} core entities are used. {len(all_values_used)} distinct values for nodes; max: {max(all_values_used)}, min: {min(all_values_used)}")

# every pre-allocated bridge entity was consumed exactly once
assert pointer_br1 == num_bridge1_entities and pointer_br2 == num_bridge2_entities

In [None]:
# """
# dump the facts for synthesizing data for LLM testing 
# """
# all_facts = dict()
# for iii in range(len(edge2fact)):
#     all_facts[iii] = set()
#     count = 0
#     for _, val in edge2fact[iii].items():
#         count += len(val)
#         for temp in val:
#             all_facts[iii].add(temp)
#     assert count == len(all_facts[iii])
#     all_facts[iii] = list(all_facts[iii])
# with open(f"saved_files/all_facts_{branching_factor}.json", "w", encoding='utf-8') as f:
#     json.dump(all_facts, f)
# with open(f"saved_files/X_values_{branching_factor}.json", "w", encoding='utf-8') as f:
#     json.dump(X_values, f)

In [None]:
# relation facts of every instance's tree (deduplicated per edge); these go into train.json
X_rel = [
    form_rel_item([e1, e2, e3], 'X_rel', entity_token_dict=ent_token_dict, vocab=vocab)
    for instance_edges in edge2fact
    for edge_facts in instance_edges.values()
    for (e1, e2, e3) in edge_facts
]
print(len(core_rel), len(core_attr), len(X_rel), len(br1_attr), len(br2_attr), len(X_attr))

In [None]:
test_size = 512 * 19    # number of examples drawn into each validation probe

# output directory name encodes the generation parameters (plus comb when multi-token entities are used)
comb_tag = "" if comb is None else f"comb{comb}"
dataset = f"gsmfinal{comb_tag}.{P}.{branching_factor}.{drop_prob}"
out_dir = f"data/{dataset}"
os.makedirs(out_dir, exist_ok=True)

with open(f"{out_dir}/train.json", "w", encoding='utf-8') as f:
    json.dump(core_rel + core_attr + X_rel, f)

# validation probes: a random subset of each split, drawn in this fixed order.
# NOTE(review): X_attr and br1_attr hold fewer than test_size items; this relies on
# utils.choose tolerating n > len(items) -- confirm.
probe_sources = {
    "core_rel": core_rel,
    "core_rel_test": core_rel_test,
    "core_attr": core_attr,
    "br1_attr": br1_attr,
    "br2_attr": br2_attr,
    "X_rel": X_rel,
    "X_attr": X_attr,
}
probes = {name: choose(items, test_size) for name, items in probe_sources.items()}

with open(f"{out_dir}/valid.json", "w", encoding='utf-8') as f:
    json.dump(probes, f)
with open(f"{out_dir}/vocab.json", "w", encoding='utf-8') as f:
    json.dump(vocab, f)