In [21]:
# Column names assigned to every COGS split file when it is read.
_COGS_COLUMNS = ('sentence', 'LF', 'type')


def _read_cogs_split(split_name):
    """Read ./COGS/<split_name>.tsv (tab-separated) into a DataFrame.

    The columns are named sentence, LF and type.
    """
    return pd.read_csv(f"./COGS/{split_name}.tsv", sep="\t", names=list(_COGS_COLUMNS))


train_df = _read_cogs_split("train")
dev_df = _read_cogs_split("dev")
dev_tiny_df = _read_cogs_split("dev_tiny")
test_df = _read_cogs_split("test")
gen_df = _read_cogs_split("gen")
gen_lexical_df = _read_cogs_split("gen_lexical")
gen_structural_df = _read_cogs_split("gen_structural")
# Separate frame read from the same file as train_df, kept under its own name.
train_df_original = _read_cogs_split("train")


In [22]:
from utils.second_looks_utils import *
from utils.train_utils import *
import spacy
import json

# Pool of numeric tokens taken from the target vocabulary; variable indices are
# randomly re-sampled from this pool.
# The pool is sorted so that its order does not depend on Python's per-process
# string-hash randomisation: iterating a set of strings can yield a different
# order on every run, which would make random.sample() over the pool
# irreproducible even with a fixed random seed.
existing_digit_pool = sorted(
    {
        token
        for token, _ in load_vocab("./data/tgt_vocab.txt").items()
        if token.isnumeric()
    }
)

def _translate_primitive(text, phi):
    """Rewrite the LF of a single-word example whose LF contains ``LAMBDA``.

    A 7-token LF (e.g. ``LAMBDA a . shark ( a )``) is returned unchanged.
    Otherwise ``phi`` is cut at every occurrence of ``text``: the chunk that
    contains ``LAMBDA`` is kept as the head, and a ``"<text> ( <var> ) AND"``
    term is inserted right after it, where ``<var>`` is the third token of the
    last chunk without ``LAMBDA``.  For example
    ``LAMBDA a . LAMBDA e . crawl . agent ( e , a )`` becomes
    ``LAMBDA a . LAMBDA e . crawl ( e ) AND agent ( e , a )``.

    Raises:
        Exception: if splitting on ``text`` leaves no chunk without ``LAMBDA``
            (typically because ``text`` does not occur in ``phi``).
    """
    if len(phi.split()) == 7:
        return text, phi

    cleaned_phi = []
    verb_args = None
    for chunk in phi.split(text):
        if "LAMBDA" in chunk:
            cleaned_phi.append(chunk.strip())
        else:
            # Drop the " . " that used to glue the predicate name to its role.
            stripped = chunk.strip(" .")
            verb_args = stripped.split()[2]
            cleaned_phi.append(stripped)

    if verb_args is None:
        # Without this guard the f-string below would fail with an opaque
        # UnboundLocalError.
        raise Exception(f"Primitive LF could not be parsed: {phi}")

    return text, " ".join(cleaned_phi[:1] + [f"{text} ( {verb_args} ) AND"] + cleaned_phi[1:])


def translate(text, phi):
    """Convert one (sentence, logical form) pair into the rewritten LF format.

    Three cases are handled:

    * ``phi`` is a single token: it is wrapped as ``LAMBDA a . <phi> ( a )``.
    * ``phi`` contains ``LAMBDA``: see ``_translate_primitive``.
    * otherwise ``phi`` is split into conjuncts on ``AND`` / ``;``, each
      conjunct is parsed with the project helpers (``parse_np``,
      ``parse_pred``, ``parse_mod``), and the LF is rebuilt as
      ``<definition terms joined by ' ; '> ; <role terms joined by ' AND '>``.
      Finally the ``x`` and ``_`` tokens are dropped and every numeric index
      is replaced by a token drawn at random from the module-level
      ``existing_digit_pool``.

    Args:
        text: the sentence; it is tokenised on whitespace.
        phi: the original logical form as a whitespace-tokenised string.

    Returns:
        A ``(text, new_lf)`` tuple; ``text`` is passed through unchanged.

    Raises:
        Exception: if a conjunct matches none of ``np_re``, ``pred_re`` or
            ``mod_re``, or if a ``LAMBDA`` LF cannot be rewritten.
        AssertionError: if an entity argument that is not an ``x _ <n>``
            variable does not occur exactly once in ``text``.
        ValueError: raised by ``random.sample`` when ``existing_digit_pool``
            holds fewer tokens than there are distinct indices in the LF.
    """
    if len(phi.split()) == 1:
        return text, f"LAMBDA a . {phi} ( a )"
    if "LAMBDA" in phi:
        return _translate_primitive(text, phi)

    # parse
    text_split = text.split()
    data = []
    for conj in re.split(r"\s*(?:AND|;)\s*", phi):
        if np_re.search(conj):
            d = parse_np(conj)
        elif pred_re.search(conj):
            d = parse_pred(conj)
            if "x _" not in d['entvar']:
                # The argument is a bare word rather than a variable: remember
                # the word and replace it by a variable indexed with the
                # word's token position in the sentence.
                d['entvar_name'] = d['entvar']
                assert text_split.count(d['entvar']) == 1
                name_idx = text_split.index(d['entvar'])
                d['entvar'] = f"x _ {name_idx}"
        elif mod_re.search(conj):
            d = parse_mod(conj)
        else:
            raise Exception(f"Conjunct could not be parsed: {conj}")
        data.append(d)

    # collect
    def_terms = []
    role_terms = []
    for d in data:
        if d['type'] == 'np':
            if d['definiteness'] == '*':
                def_terms.append(f"* {d['pred']} ( {d['entvar']} )")
            else:
                def_terms.append(f"{d['pred']} ( {d['entvar']} )")
        elif d['type'] == 'role':
            # Emit the predicate term only once per (predicate, event) pair.
            event_term = f"{d['pred']} ( {d['eventvar']} )"
            if event_term not in role_terms:
                role_terms.append(event_term)
            role_terms.append(f"{d['role']} ( {d['eventvar']} , {d['entvar']} )")
            if "entvar_name" in d:
                def_terms.append(f"{d['entvar_name']} ( {d['entvar']} )")
        elif d['type'] == 'mod':
            role_terms.append(f"nmod . {d['pred']} ( {d['e1']} , {d['e2']} )")

    # De-duplicate while keeping first-seen order (a set would make the order
    # of equal-index terms depend on string hashing), then sort by the index,
    # which is the second-to-last token of every definition term.
    def_terms = list(dict.fromkeys(def_terms))
    def_terms.sort(key=lambda term: int(term.split()[-2]))

    # combine
    def_part = " ; ".join(def_terms)
    role_part = " AND ".join(role_terms)
    if def_part == "":
        terms = role_part
    elif role_part == "":
        terms = def_part
    else:
        terms = def_part + " ; " + role_part

    # final step, remove biases: map every index to a randomly drawn token
    tokens = terms.split()
    # First-seen order keeps the shuffle reproducible under a fixed seed.
    current_digit_pool = list(dict.fromkeys(t for t in tokens if t.isnumeric()))
    random.shuffle(current_digit_pool)
    sample_random_digit = random.sample(existing_digit_pool, k=len(current_digit_pool))
    digit_mapping = dict(zip(current_digit_pool, sample_random_digit))

    new_terms = [
        digit_mapping[t] if t.isnumeric() else t
        for t in tokens
        if t != "_" and t != "x"
    ]
    return text, " ".join(new_terms)

# sampled_n: how many independently translated copies of the training set to make.
# append_k: how many start offsets the augmentation loop further down iterates over.
sampled_n = 5
append_k = 3072


def _translate_columns(frame):
    """Apply translate to every (sentence, LF) row of frame.

    Returns a two-column frame holding the translated sentence and LF.
    """
    return frame[['sentence', 'LF']].apply(
        lambda row: translate(*row), axis=1, result_type='expand'
    )


# translate draws random indices, so each copy of the training set gets its
# own index assignment.
train_dfs = []
for copy_idx in range(sampled_n):
    train_df_i = train_df.copy()
    train_df_i[['sentence', 'LF']] = _translate_columns(train_df_i)
    train_dfs.append(train_df_i)

# NOTE: the frames below are overwritten in place, so running this part twice
# on the same kernel would feed already-translated LFs back into translate.
for split_frame in (
    train_df_original,
    dev_df,
    test_df,
    gen_df,
    gen_lexical_df,
    gen_structural_df,
    dev_tiny_df,
):
    split_frame[['sentence', 'LF']] = _translate_columns(split_frame)

def reindex(LFs, existing_digit_pool):
    """Merge several LFs into one, giving every index a freshly sampled token.

    Each numeric token is identified by ``(position of its LF in LFs, value)``,
    so the same number occurring in two different LFs is mapped to two
    different tokens, while repeated occurrences inside one LF stay linked.
    Replacement tokens are drawn without repetition from
    ``existing_digit_pool``.

    After renaming, each LF is split on ``" ; "``: everything before the last
    ``" ; "`` goes to the merged prefix, and the last segment is split on
    ``" AND "`` and appended to the merged body.

    Args:
        LFs: sequence of LF strings (whitespace-tokenised).
        existing_digit_pool: sequence of string tokens to sample from.

    Returns:
        ``"<prefix terms joined by ' ; '> ; <body terms joined by ' AND '>"``.
        When either side is empty the ``" ; "`` separator is omitted, so the
        result never starts or ends with a dangling separator.

    Raises:
        ValueError: raised by ``random.sample`` when the pool holds fewer
            tokens than there are distinct ``(position, value)`` pairs.
    """
    # Every distinct (LF position, index value) pair needs its own new token.
    curr_digit = {
        (pos, int(item))
        for pos, lf in enumerate(LFs)
        for item in lf.split()
        if item.isnumeric()
    }
    sampled_digits = random.sample(existing_digit_pool, k=len(curr_digit))
    digit_map = dict(zip(curr_digit, sampled_digits))

    new_LF_prefix = []
    new_LF_body = []
    for pos, lf in enumerate(LFs):
        renamed = " ".join(
            digit_map[(pos, int(item))] if item.isnumeric() else item
            for item in lf.split()
        )
        # The last ' ; '-separated segment is the body; the rest is prefix.
        *prefix_terms, body_segment = renamed.split(" ; ")
        new_LF_prefix.extend(prefix_terms)
        new_LF_body.extend(body_segment.split(" AND "))

    prefix = " ; ".join(new_LF_prefix)
    body = " AND ".join(new_LF_body)
    # Only emit the separator when both sides are present; unconditional
    # concatenation would produce a leading " ; " for prefix-less LFs.
    if prefix == "":
        return body
    if body == "":
        return prefix
    return prefix + " ; " + body

def _joining_initial(sentence):
    """Return the first character of sentence, lower-cased if its first word is 'The' or 'A'."""
    first_char = sentence[0]
    if sentence.split()[0] in {'The', 'A'}:
        return first_char.lower()
    return first_char


# Offsets (0, 6, 12, ...) counted back from the end of the length-sorted frame.
start_indexes = [6 * step for step in range(append_k)]
append_data = []

for copy_idx in range(sampled_n):
    # Ascending by sentence length in characters, so negative positions pick
    # the longest sentences first.
    train_df_sorted = train_dfs[copy_idx].sort_values(
        by="sentence", key=lambda col: col.str.len()
    )
    for start_index in start_indexes:
        head_row, mid_row, tail_row = (
            train_df_sorted.iloc[-rank - start_index] for rank in (1, 2, 3)
        )
        # Glue three sentences together: drop the last character of the first
        # two, and replace the first character of the second and third with
        # its (possibly lower-cased) form.
        joined_sentence = "".join([
            head_row.sentence[:-1],
            _joining_initial(mid_row.sentence),
            mid_row.sentence[1:-1],
            _joining_initial(tail_row.sentence),
            tail_row.sentence[1:],
        ])
        joined_lf = reindex(
            [head_row.LF, mid_row.LF, tail_row.LF], existing_digit_pool
        )
        append_data.append([joined_sentence, joined_lf, 'length_ood'])

append_data = pd.DataFrame(append_data, columns=['sentence', 'LF', 'type'])

In [None]:
# Stack the translated training copies, add the concatenated examples, and
# drop exact duplicate rows.
train_df = pd.concat([pd.concat(train_dfs), append_data]).drop_duplicates()
# NOTE(review): train_df is not written anywhere below; the training file is
# exported from train_df_original instead. Confirm that this is intended.

dataset_postfix = "COGS"

# Output file stem -> frame to export, in the order the files are written.
_recogs_exports = {
    "train": train_df_original,
    "dev": dev_df,
    "test": test_df,
    "gen": gen_df,
    "gen_lexical": gen_lexical_df,
    "gen_structural": gen_structural_df,
    "dev_tiny": dev_tiny_df,
}
for _split_name, _split_frame in _recogs_exports.items():
    # Tab-separated, no index column and no header row.
    _split_frame.to_csv(
        f'./{dataset_postfix}/RECOGS{_split_name}.tsv', sep='\t', index=False, header=False
    )