In [None]:
import os
import openai
import time
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, classification_report

# Take the OpenAI API key from the OPENAI_API_KEY environment variable
# (os.getenv yields None when the variable is not set).
openai.api_key = os.getenv("OPENAI_API_KEY")

import json
import numpy as np

np.random.seed(0)

# Read the validation subsets; the context managers close the file handles
# instead of leaving them to the garbage collector.
with open("../data/valid_subset_text.json") as reader:
    features_valid = json.load(reader)
with open("../dataset_bias/valid_subset_rest_narrative.json") as reader:
    features_rest = json.load(reader)

# NOTE(review): allow_pickle=True unpickles arbitrary objects from the file;
# only load .npy files that come from a trusted source.
gen_examplars_res = np.load("../data/examplars_gda_rest.npy", allow_pickle=True)

def compare(d1, d2):
    """Return True when d2 holds an equal value for every key of d1.

    Only the keys of d1 are inspected, so d2 may carry extra keys. A key of
    d1 that is missing from d2 raises KeyError.
    """
    for key in d1:
        if not (d1[key] == d2[key]):
            return False
    return True

# Keep the validation items that match no record of features_rest.
features_ori = [
    item
    for item in features_valid
    if not any(compare(d, item) for d in features_rest)
]

In [None]:

# Path made relative to the parent directory like every other data path in
# this notebook; the file handle is closed by the context manager.
with open("../dataset_bias/valid_subset_text_tense_bias_vague.json") as reader:
    features_valid_tense_all = json.load(reader)


In [50]:
# Load the bias subsets; each file is opened in a context manager so its
# handle is closed as soon as the JSON has been parsed.
with open("../dataset_bias/valid_subset_text_tense_bias_vague.json") as reader:
    features_valid_tense_all = json.load(reader)
with open("../dataset_bias/valid_subset_text_erp.json") as reader:
    features_valid_erp = json.load(reader)
with open("../dataset_bias/valid_text_features_matres_dep_bias_subset.json") as reader:
    features_valid_dep = json.load(reader)
with open("../dataset_bias/valid_text_features_matres_narrative_bias_subset.json") as reader:
    features_valid_narrative = json.load(reader)

In [42]:
# Show which keys the first narrative-bias record carries.
features_valid_narrative[0].keys()

dict_keys(['input_ids', 'event_pos', 'event_pos_end', 'event_pair', 'labels'])

In [43]:
from transformers import AutoTokenizer
# Tokenizer used below to turn stored token ids back into text.
tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-large")

# Decode each narrative-bias record: the full token sequence becomes 'text',
# the tokens at the first two 'event_pos' entries become 'e1' and 'e2'.
features_valid_narrative_text = []
for record in features_valid_narrative:
    token_ids = record['input_ids']
    event_positions = record['event_pos']
    features_valid_narrative_text.append({
        'text': tokenizer.decode(token_ids),
        'e1': tokenizer.decode(token_ids[event_positions[0]]),
        'e2': tokenizer.decode(token_ids[event_positions[1]]),
        'labels': record['labels'],
    })

In [45]:
# Persist the decoded narrative-bias records as JSON.
narrative_text_path = "../dataset_bias/valid_subset_text_narrative.json"
with open(narrative_text_path, "w") as writer:
    json.dump(features_valid_narrative_text, writer)

In [24]:
def _load_examplars(path):
    """Load an exemplar array from a .npy file, with unpickling of stored objects allowed."""
    return np.load(path, allow_pickle=True)


# One exemplar array per subset.
erp_examplars = _load_examplars("../dataset_bias/gpt3_data/valid_gpt2xl_hard_gpt3_examplars.npy")
tense_examplars = _load_examplars("../dataset_bias/gpt3_data/tense_gpt3_examplars_all.npy")
narrative_examplars = _load_examplars("../data/examplars_gda_narrative_bias.npy")
dep_examplars = _load_examplars("../data/examplars_gda_dep_bias.npy")
rest_examplars = _load_examplars("../data/examplars_gda_rest.npy")

In [48]:
# Number of exemplar groups loaded for the narrative-bias subset.
len(narrative_examplars)

477

In [49]:
# Number of narrative-bias records, for comparison with the exemplar count.
len(features_valid_narrative)

3171

In [52]:
def _assign_subset_examplars(target, subset_features, subset_examplars):
    """Copy each subset record's exemplars onto the matching validation items.

    Args:
        target: list with one slot per item of `features_valid`; updated in place.
        subset_features: records of one subset, matched against
            `features_valid` with `compare`.
        subset_examplars: exemplars indexed in the same order as
            `subset_features`.

    Returns:
        A list parallel to `subset_features`; entry i holds every index j for
        which compare(features_valid[j], subset_features[i]) is true.
    """
    subset_2_all_idx = [
        [j for j, valid_item in enumerate(features_valid) if compare(valid_item, subset_item)]
        for subset_item in subset_features
    ]
    for i, matched in enumerate(subset_2_all_idx):
        for idx in matched:
            target[idx] = subset_examplars[i]
    return subset_2_all_idx


# One slot per validation item; [""] stays in place for items no subset matches.
all_examplars = [[""] for _ in range(len(features_valid))]

# The subsets are applied in this order, so a later subset overwrites an
# earlier one whenever a validation item matches both.
rest_2_all_idx = _assign_subset_examplars(all_examplars, features_rest, rest_examplars)
erp_2_all_idx = _assign_subset_examplars(all_examplars, features_valid_erp, erp_examplars)
# The assignment pass now walks every tense record; it previously iterated
# over range(len(features_valid_erp)), which skipped or overran tense entries.
tense_2_all_idx = _assign_subset_examplars(all_examplars, features_valid_tense_all, tense_examplars)
# NOTE(review): narrative_examplars is assumed to be index-aligned with
# features_valid_narrative; the lengths displayed earlier in this notebook
# (477 vs 3171) suggest it may not be -- confirm.
narrative_2_all_idx = _assign_subset_examplars(all_examplars, features_valid_narrative, narrative_examplars)
dep_2_all_idx = _assign_subset_examplars(all_examplars, features_valid_dep, dep_examplars)
        

In [57]:
# Offset at which the question part starts in the first exemplar of item 0.
all_examplars[0][0].find('\n\nQ')

225

In [67]:
# Text that remains after cutting the context and the fixed question lead-in
# from the second exemplar of item 0.
all_examplars[0][1][all_examplars[0][1].find('\n\nQ'):][len('\n\nQ: What\'s the temporal relation between the event "'):]

'waiting" and "said"? \nChoice A: waiting happens before said. \nChoice B: waiting happens after said. \nChoice C: waiting happens during said. \nChoice D: unknown. \nAnswer only with A, B, C, or D. \n\nA: Choice B'

In [99]:
def parse_examplar(text):
    """Split one exemplar prompt into its context, events and answer letter.

    The text is expected to start with 'Given the context:', followed by the
    context, a blank line, the question that quotes both events, and to end
    with the answer letter.

    Args:
        text: the full exemplar string.

    Returns:
        dict with the keys 'context', 'e1', 'e2' and 'labels', where 'labels'
        is one of 'A', 'B', 'C', 'D'.

    Raises:
        ValueError: if the question, the two quoted events or a valid answer
            letter cannot be located in `text`.
    """
    prefix_context = 'Given the context:\n'
    suffix_context = '\n\nQ:'
    question_prefix = '\n\nQ: What\'s the temporal relation between the event "'

    question_start = text.find(suffix_context)
    if question_start == -1:
        # find() returning -1 would otherwise silently slice from the end.
        raise ValueError(f"no question marker found in examplar: {text!r}")

    # Drop the context header only when it is really there.
    context_start = len(prefix_context) if text.startswith(prefix_context) else 0
    context = text[context_start:question_start]

    if not text.startswith(question_prefix, question_start):
        raise ValueError(f"unexpected question wording in examplar: {text!r}")
    rest = text[question_start + len(question_prefix):]

    # rest looks like: e1" and "e2"? ...
    parts = rest.split('"')
    if len(parts) < 3:
        raise ValueError(f"could not find two quoted events in examplar: {text!r}")
    e1 = parts[0]
    e2 = parts[2]

    # The answer letter is the last non-whitespace character.
    labels = rest.rstrip()[-1:]
    if labels not in ('A', 'B', 'C', 'D'):
        raise ValueError(f"examplar does not end with a choice letter: {text!r}")

    return {'context':context, 'e1':e1, 'e2':e2, 'labels':labels}

In [100]:
# Parse every exemplar string of every validation item into its fields.
all_examplars_raw = [
    [parse_examplar(examplar_text) for examplar_text in examplar_group]
    for examplar_group in all_examplars
]

# Generate exemplars for the other examples (`features_ori`):

In [None]:
label2letter = {0:"A", 1:"B", 2:"C", 3:"D"}

saved_gpt3_examplars_0 = []

# Give up after this many failed attempts per request, so that a permanent
# failure (e.g. a missing API key) surfaces instead of looping forever.
max_attempts = 10

for item in tqdm(features_ori, total=len(features_ori)):

    e1 = item['e1']
    e2 = item['e2']

    # One generation prompt per label, in label order 0..3.
    counter_factual_prompt = [f"Generate a paragraph where event {e1} happens before {e2}:", 
                              f"Generate a paragraph where event {e1} happens after {e2}:", 
                              f"Generate a paragraph where event {e1} happens in the same time as {e2}:", 
                              f"Generate a paragraph where the temporal relation of {e1} and {e2} cannot be determined based on the context:", 
                             ]
    gpt3_context = []
    examplars = []
    for label_id, prompt in enumerate(counter_factual_prompt):
        for attempt in range(max_attempts):
            try:
                gpt3_context.append(
                    openai.Completion.create(
                                model="text-davinci-003",
                                prompt=prompt,
                                max_tokens=40,
                                temperature=0
                    )["choices"][0]["text"].strip()
                )
                break
            except Exception:
                # `Exception` (not a bare except) lets KeyboardInterrupt stop
                # the cell; the last failure is re-raised.
                if attempt == max_attempts - 1:
                    raise
                time.sleep(20)
        time.sleep(2)
        # Wrap the generated paragraph in the multiple-choice question, with
        # the letter of the label the paragraph was generated for as answer.
        examplars.append("Given the context:\n" + gpt3_context[-1] + f"\n\nQ: What's the temporal relation between the event \"{e1}\" and \"{e2}\"? \nChoice A: {e1} happens before {e2}. \nChoice B: {e1} happens after {e2}. \nChoice C: {e1} happens during {e2}. \nChoice D: unknown. \nAnswer only with A, B, C, or D. \n\nA: Choice {label2letter[label_id]}")

    saved_gpt3_examplars_0.append(examplars)
    

## Use the exemplars to do CDA

In [92]:
# Zero-shot predictions; further down they are indexed by the position of the
# item in features_valid.
with open("./results/template_2_zeroshot_pred.json") as reader:
    zs_preds = json.load(reader)

In [104]:
choice_2_num = {'A':0,'B':1,'C':2,'D':3}
convert_dict_rev = {0:'BEFORE', 1:'AFTER', 2:'EQUAL', 3:'VAGUE'}
# Letter -> relation word, frozen at definition time. A later cell of this
# notebook rebinds `convert_dict_rev` to letters; reading that global at call
# time would then put 'A'/'B'/'C'/'D' into the prompt instead of the relation.
_CHOICE_2_RELATION = {choice: convert_dict_rev[num] for choice, num in choice_2_num.items()}

def construct_icl_examplar(item):
    """Render one parsed exemplar as an in-context example.

    Args:
        item: dict with the keys 'context', 'e1', 'e2' and 'labels', where
            'labels' is one of 'A', 'B', 'C', 'D'.

    Returns:
        The prompt string, ending in 'Answer: ' followed by BEFORE, AFTER,
        EQUAL or VAGUE.

    Raises:
        KeyError: if 'labels' is not one of the four choice letters.
    """
    context = item['context']
    e1 = item['e1']
    e2 = item['e2']
    relation = _CHOICE_2_RELATION[item['labels']]
    prompt = f"Determine the temporal order from \"{e1}\" to \"{e2}\" in the following sentence: \"{context}\". Only answer one word from AFTER, BEFORE, EQUAL, VAGUE. Answer: {relation}"
    return prompt

In [105]:
# Show the first parsed exemplar of the first validation item.
all_examplars_raw[0][0]

{'context': 'The crowd was eagerly waiting for the event to begin. Everyone was filled with anticipation and excitement as they waited for the curtains to open. The atmosphere was electric as people chatted and laughed,',
 'e1': 'waiting',
 'e2': 'said',
 'labels': 'A'}

In [106]:
# Render that exemplar as an in-context example to check the prompt wording.
construct_icl_examplar(all_examplars_raw[0][0])

'Determine the temporal order from "waiting" to "said" in the following sentence: "The crowd was eagerly waiting for the event to begin. Everyone was filled with anticipation and excitement as they waited for the curtains to open. The atmosphere was electric as people chatted and laughed,". Only answer one word from AFTER, BEFORE, EQUAL, VAGUE. Answer: BEFORE'

In [112]:
all_gens_cda = []

# Give up after this many failed attempts per request, so that a permanent
# failure (e.g. a missing API key) surfaces instead of looping forever.
max_attempts = 30

for i, item in tqdm(enumerate(features_valid), total=len(features_valid)):
    context = item['text'].replace("[CLS]", "").replace("[SEP]", "").strip()
    e1 = item['e1']
    e2 = item['e2']

    # Use three of the four exemplars: the one whose index equals zs_preds[i]
    # is left out (presumably the zero-shot predicted label -- confirm).
    examplars = [construct_icl_examplar(all_examplars_raw[i][k]) for k in range(4) if k != zs_preds[i]]
    examplar = "\n\n".join(examplars)

    prompt = f"Determine the temporal order from \"{e1}\" to \"{e2}\" in the following sentence: \"{context}\". Only answer one word from AFTER, BEFORE, EQUAL, VAGUE. "
    for attempt in range(max_attempts):
        try:
            all_gens_cda.append(openai.Completion.create(
                        model="text-davinci-003",
                        prompt=examplar + "\n\n" + prompt,
                        max_tokens=20,
                        temperature=0
            ))
            break
        except Exception:
            # `Exception` (not a bare except) lets KeyboardInterrupt stop the
            # cell; the last failure is re-raised.
            if attempt == max_attempts - 1:
                raise
            time.sleep(2)
    time.sleep(2)
    


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [53:32<00:00,  3.21s/it]


In [121]:
convert_dict = {'BEFORE':0, 'AFTER':1, 'EQUAL':2, 'VAGUE':3, 'EQU':2}
def parse_result(ans):
    """Map a completion such as 'Answer: BEFORE' to its numeric label.

    The 'Answer:' lead-in is removed only when it is present (compared
    case-insensitively), so a bare 'BEFORE' is understood as well.

    Args:
        ans: the completion text.

    Returns:
        0, 1, 2 or 3 according to `convert_dict`; 3 for any answer that is
        not a key of `convert_dict`.
    """
    answer = ans.strip()
    prefix = "Answer:"
    if answer.upper().startswith(prefix.upper()):
        answer = answer[len(prefix):]
    return convert_dict.get(answer.strip().upper(), 3)

In [122]:
# Turn every raw completion into a numeric label with parse_result.
results_cda = []
for gen in all_gens_cda:
    completion_text = gen["choices"][0]["text"].strip()
    results_cda.append(parse_result(completion_text))

In [115]:
from collections import Counter
# Tally the distinct raw answer strings returned by the model.
Counter([gen["choices"][0]["text"].strip() for gen in all_gens_cda])

Counter({'Answer: BEFORE': 883,
         'Answer: AFTER': 93,
         'Answer: V': 21,
         'Answer: EQU': 3})

In [123]:
# Tally the parsed numeric labels.
Counter(results_cda)

Counter({0: 883, 1: 93, 3: 21, 2: 3})

In [124]:

# Gold label of each validation item: the first entry of its 'labels' field.
labels = [item['labels'][0] for item in features_valid]
# Macro- and micro-averaged F1 of the parsed predictions.
print(f1_score(labels, results_cda, average='macro'), f1_score(labels, results_cda, average="micro"))



0.19850458861631082 0.472


In [None]:
# with open("results/template_1_threeshot_subsets_pred_2.json", "w") as writer:
#     json.dump(threeshot_preds_2, writer)

# newly generated

In [127]:
# Exemplar arrays for the features_ori items ("conflict") and the features_rest items.
all_examplars_conflict = np.load("../data/examplars_gda_conflict.npy", allow_pickle=True)
all_examplars_rest = np.load("../data/examplars_gda_rest.npy", allow_pickle=True)

# One slot per validation item, filled in by the two passes below.
all_examplars_new = [[] for _ in features_valid]

# For every features_rest record, the positions of the validation items it matches.
rest_2_all_idx = [
    [pos for pos, valid_item in enumerate(features_valid) if compare(valid_item, rest_item)]
    for rest_item in features_rest
]
for rest_pos, matched in enumerate(rest_2_all_idx):
    for valid_pos in matched:
        all_examplars_new[valid_pos] = all_examplars_rest[rest_pos]

# Same for the features_ori records; these overwrite any slot set above.
conflict_2_all_idx = [
    [pos for pos, valid_item in enumerate(features_valid) if compare(valid_item, ori_item)]
    for ori_item in features_ori
]
for ori_pos, matched in enumerate(conflict_2_all_idx):
    for valid_pos in matched:
        all_examplars_new[valid_pos] = all_examplars_conflict[ori_pos]

In [139]:
# Parse every newly loaded exemplar string into its fields.
all_examplars_raw_new = [
    [parse_examplar(examplar_text) for examplar_text in examplar_group]
    for examplar_group in all_examplars_new
]

In [149]:
# Bare comparison of the two parsed lists; it is not the last expression of
# the cell, so its value is not displayed.
all_examplars_raw_new == all_examplars_raw

# Count the items for which at least one exemplar differs in some field
# between the new and the old parse.
cnt = 0
for new_group, old_group in zip(all_examplars_raw_new, all_examplars_raw):
    differs = any(
        not (new_option[key] == old_option[key])
        for new_option, old_option in zip(new_group, old_group)
        for key in new_option
    )
    if differs:
        cnt += 1

In [146]:
# Pick the first exemplar of the first item from the new and the old parse.
option0=all_examplars_raw_new[0][0] 
option1 = all_examplars_raw[0][0]

In [147]:
# True when both parses agree on every field of option0.
all(option0[key] == option1[key] for key in option0)

True

In [153]:
all_gens_cda_1 = []

# Give up after this many failed attempts per request, so that a permanent
# failure (e.g. a missing API key) surfaces instead of looping forever.
max_attempts = 30

for i, item in tqdm(enumerate(features_valid), total=len(features_valid)):
    context = item['text'].replace("[CLS]", "").replace("[SEP]", "").strip()
    e1 = item['e1']
    e2 = item['e2']

    # Use three of the four exemplars: the one whose index equals zs_preds[i]
    # is left out (presumably the zero-shot predicted label -- confirm).
    examplars = [construct_icl_examplar(all_examplars_raw_new[i][k]) for k in range(4) if k != zs_preds[i]]
    examplar = "\n\n".join(examplars)

    prompt = f"Determine the temporal order from \"{e1}\" to \"{e2}\" in the following sentence: \"{context}\". Only answer one word from AFTER, BEFORE, EQUAL, VAGUE. "
    for attempt in range(max_attempts):
        try:
            all_gens_cda_1.append(openai.Completion.create(
                        model="text-davinci-003",
                        prompt=examplar + "\n\n" + prompt,
                        max_tokens=20,
                        temperature=0
            ))
            break
        except Exception:
            # `Exception` (not a bare except) lets KeyboardInterrupt stop the
            # cell; the last failure is re-raised.
            if attempt == max_attempts - 1:
                raise
            time.sleep(2)
    time.sleep(2)
    


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [47:03<00:00,  2.82s/it]


In [154]:
# Turn every raw completion of the second run into a numeric label.
results_cda_1 = []
for gen in all_gens_cda_1:
    completion_text = gen["choices"][0]["text"].strip()
    results_cda_1.append(parse_result(completion_text))

In [155]:
# Tally the parsed numeric labels of the second run.
Counter(results_cda_1)

Counter({0: 910, 1: 73, 3: 14, 2: 3})

In [156]:
# Macro- and micro-averaged F1 of the second run against the gold labels.
macro_f1_cda_1 = f1_score(labels, results_cda_1, average='macro')
micro_f1_cda_1 = f1_score(labels, results_cda_1, average="micro")
print(macro_f1_cda_1, micro_f1_cda_1)

0.19786903857278318 0.485


# template 1

In [158]:
# Show the raw exemplar strings attached to the first validation item.
all_examplars_new[0]

array(['Given the context:\nThe crowd was eagerly waiting for the event to begin. Everyone was filled with anticipation and excitement as they waited for the curtains to open. The atmosphere was electric as people chatted and laughed,\n\nQ: What\'s the temporal relation between the event "waiting" and "said"? \nChoice A: waiting happens before said. \nChoice B: waiting happens after said. \nChoice C: waiting happens during said. \nChoice D: unknown. \nAnswer only with A, B, C, or D. \n\nA: Choice A',
       'Given the context:\nThe event was finally here. Everyone had been waiting for weeks, and the anticipation was palpable. As the clock ticked closer to the start time, the crowd grew more and more excited.\n\nQ: What\'s the temporal relation between the event "waiting" and "said"? \nChoice A: waiting happens before said. \nChoice B: waiting happens after said. \nChoice C: waiting happens during said. \nChoice D: unknown. \nAnswer only with A, B, C, or D. \n\nA: Choice B',
       'Giv

In [159]:
all_gens_cda_template_1 = []

# Give up after this many failed attempts per request, so that a permanent
# failure (e.g. a missing API key) surfaces instead of looping forever.
max_attempts = 30

for i, item in tqdm(enumerate(features_valid), total=len(features_valid)):
    context = item['text'].replace("[CLS]", "").replace("[SEP]", "").strip()
    e1 = item['e1']
    e2 = item['e2']

    # Use three of the four raw exemplar strings: the one whose index equals
    # zs_preds[i] is left out (presumably the zero-shot predicted label -- confirm).
    examplars = [all_examplars_new[i][k] for k in range(4) if k != zs_preds[i]]
    examplar = "\n\n".join(examplars)

    # Same multiple-choice layout as the exemplars, left open after "Choice ".
    prompt = "Given the context:\n" + context + f"\n\nQ: What's the temporal relation between the event \"{e1}\" and \"{e2}\"? \nChoice A: {e1} happens before {e2}. \nChoice B: {e1} happens after {e2}. \nChoice C: {e1} happens during {e2}. \nChoice D: unknown. \nAnswer only with A, B, C, or D. \n\nA: Choice "
    for attempt in range(max_attempts):
        try:
            all_gens_cda_template_1.append(openai.Completion.create(
                        model="text-davinci-003",
                        prompt=examplar + "\n\n" + prompt,
                        max_tokens=20,
                        temperature=0
            ))
            break
        except Exception:
            # `Exception` (not a bare except) lets KeyboardInterrupt stop the
            # cell; the last failure is re-raised.
            if attempt == max_attempts - 1:
                raise
            time.sleep(2)
    time.sleep(2)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [45:46<00:00,  2.75s/it]


In [160]:
# NOTE(review): this cell rebinds `convert_dict`, `convert_dict_rev` and
# `parse_result`, which earlier cells of this notebook define with different
# meanings (relation words instead of letters). Re-running an earlier cell
# after this one picks up these definitions.
convert_dict = {'A':0, 'B':1, 'C':2, 'D':3}
convert_dict_rev = {0:'A', 1:'B', 2:'C', 3:'D'}
def parse_result(ans):
    """Map a choice-letter completion to its numeric label.

    Accepts the bare letter in either case, with surrounding whitespace, and
    a letter followed by a non-alphanumeric character such as 'B: ...' or 'C.'.

    Args:
        ans: the completion text.

    Returns:
        0, 1, 2 or 3 for A, B, C or D; 3 for anything else.
    """
    answer = ans.strip().upper()
    if answer in convert_dict:
        return convert_dict[answer]
    # A leading choice letter counts only when the next character cannot be
    # part of a longer word (so 'ANSWER' is not read as choice A).
    if len(answer) > 1 and answer[0] in convert_dict and not answer[1].isalnum():
        return convert_dict[answer[0]]
    return 3


# Turn every template-1 completion into a numeric label.
template_1_preds_1 = []
for gen in all_gens_cda_template_1:
    completion_text = gen['choices'][0]['text'].strip()
    template_1_preds_1.append(parse_result(completion_text))

# Macro- and micro-averaged F1 against the gold labels.
macro_f1_template_1 = f1_score(labels, template_1_preds_1, average='macro')
micro_f1_template_1 = f1_score(labels, template_1_preds_1, average="micro")
print(macro_f1_template_1, micro_f1_template_1)


0.3061144839549003 0.495
