In [1]:
from pathlib import Path

while Path.cwd().name != 'ambient':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/ambient/notebooks
/mmfs1/gscratch/xlab/alisaliu/ambient


In [2]:
import pandas as pd
from mturk.back_translation import back_translate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import random
import numpy as np
from tqdm import tqdm

In [3]:
# NLLB-200 (distilled, 600M) translation model; these objects are passed to back_translate
# in create_example_csv to produce distractor sentences.
# The two tokenizers differ only in src_lang: 'eng_Latn' (English) vs 'yor_Latn' (Yoruba).
# Presumably one encodes the English input and the other the intermediate translation
# (TODO confirm against mturk.back_translation.back_translate).
NLLB_CHECKPOINT = "facebook/nllb-200-distilled-600M"
s_tokenizer = AutoTokenizer.from_pretrained(NLLB_CHECKPOINT, src_lang='eng_Latn')
t_tokenizer = AutoTokenizer.from_pretrained(NLLB_CHECKPOINT, src_lang='yor_Latn')
mt_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_CHECKPOINT)

In [4]:
# Sample a fixed set of AmbiEnt test examples in which exactly one of premise/hypothesis is
# ambiguous (XOR of the two flags). Every model's outputs are later restricted to these ids.
sample_size = 50
# Seed both the pandas sample and Python's `random` (used by create_example_csv to place
# distractor slots) so Restart & Run All reproduces the same examples.
SEED = 0
random.seed(SEED)
test_df = pd.read_json('annotation/AmbiEnt/test.jsonl', lines=True)
test_df = test_df[test_df['premise_ambiguous'] ^ test_df['hypothesis_ambiguous']]
sample_ids = test_df.sample(sample_size, random_state=SEED).id.tolist()

In [10]:
def create_example_csv(df, model_name=None, num_candidates=3):
    """Build human-evaluation examples from one model's predicted disambiguations and save them.

    For each row, the model's predicted rewrites of the ambiguous sentence are placed into
    `num_candidates` slots. The slots that receive a distractor are chosen at random; the
    predicted rewrites fill the remaining slots in their original order. Each distractor slot
    holds a back-translation of the ambiguous sentence.

    Args:
        df: one row per example, with columns 'id', 'premise', 'hypothesis',
            'premise_ambiguous' and 'predicted_rewrites' (a mapping whose keys are used as
            labels and whose values are the rewritten sentences).
        model_name: name used in the output filename. Defaults to the notebook-global
            `model` (the loop variable of the calling cell), matching the original behavior.
        num_candidates: number of candidate slots per example (3 by default).

    Returns:
        The list of example dicts, which is also written to
        `annotation/human_eval/examples_by_source/{model_name}_{sample_size}.csv`
        (`sample_size` is the notebook global).

    Raises:
        ValueError: if a row has more predicted rewrites than `num_candidates`.
    """
    if model_name is None:
        # Backward-compatible default: read the calling cell's global loop variable.
        model_name = model

    examples = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        # Only premise_ambiguous is checked: if the premise is not ambiguous, the hypothesis
        # is treated as the ambiguous sentence.
        ambiguous_sentence_key = 'premise' if row['premise_ambiguous'] else 'hypothesis'
        other_sentence_key = 'hypothesis' if row['premise_ambiguous'] else 'premise'
        ambiguous_sentence = row[ambiguous_sentence_key]

        disambiguations = list(row['predicted_rewrites'].values())
        labels = list(row['predicted_rewrites'].keys())
        num_distractors = num_candidates - len(disambiguations)
        if num_distractors < 0:
            raise ValueError(
                f"example {row['id']} has {len(disambiguations)} predicted rewrites, "
                f"more than num_candidates={num_candidates}"
            )

        # Back-translation runs the MT model, so skip it when no slot needs a distractor.
        # NOTE(review): the return value is stored as-is in each distractor slot; confirm whether
        # back_translate returns a single string or a one-element list for a one-element input.
        distractor_sentence = (
            back_translate([ambiguous_sentence], mt_model, s_tokenizer, t_tokenizer)
            if num_distractors > 0
            else None
        )
        # Randomly choose which slots hold the distractor.
        distractor_idxs = random.sample(range(num_candidates), num_distractors)

        rewrites = iter(zip(disambiguations, labels))
        candidate_disambiguations, candidate_labels = [], []
        for j in range(num_candidates):
            if j in distractor_idxs:
                candidate_disambiguations.append(distractor_sentence)
                candidate_labels.append(None)  # distractors carry no label
            else:
                disambiguation, label = next(rewrites)
                candidate_disambiguations.append(disambiguation)
                candidate_labels.append(label)

        ex = {
            'id': row['id'],
            'premise': row['premise'],
            'hypothesis': row['hypothesis'],
            'ambiguous_sent_html': f'<span class="{ambiguous_sentence_key}">{ambiguous_sentence_key}</span>',
            'ambiguous_sent': ambiguous_sentence,
            'distractor_idxs': distractor_idxs,
            'labels': candidate_labels,
        }

        # Per slot k (1-based): the candidate replaces the ambiguous side, the other side is
        # copied unchanged, and the candidate is also exposed as interpretation{k}.
        for k, candidate in enumerate(candidate_disambiguations, start=1):
            ex[f'{ambiguous_sentence_key}{k}'] = candidate
            ex[f'{other_sentence_key}{k}'] = row[other_sentence_key]
            ex[f'interpretation{k}'] = candidate

        examples.append(ex)

    pd.DataFrame(examples).to_csv(
        f'annotation/human_eval/examples_by_source/{model_name}_{sample_size}.csv', index=False
    )
    return examples

In [11]:
# Build one candidate-disambiguation CSV per model on the shared sample of test ids.
# The loop variable must stay named `model`: create_example_csv reads the global `model`
# to name its output file.
generation_sources = ('gpt-4', 'llama-65b', 'text-davinci-003', 'davinci', 'flan-t5-xxl', 'gpt-3.5-turbo')
for model in generation_sources:
    df = pd.read_json(f'results/generative_evaluation/{model}-n4.jsonl', lines=True)
    df = df[df['id'].isin(sample_ids)]
    examples = create_example_csv(df)

100%|██████████| 50/50 [04:49<00:00,  5.78s/it]
100%|██████████| 50/50 [04:47<00:00,  5.76s/it]
100%|██████████| 50/50 [04:48<00:00,  5.77s/it]
100%|██████████| 50/50 [04:48<00:00,  5.78s/it]
100%|██████████| 50/50 [04:48<00:00,  5.78s/it]
100%|██████████| 50/50 [04:47<00:00,  5.76s/it]


## Combine examples from all model sources

In [12]:
# Models whose per-source example CSVs are pooled into the human-evaluation set.
models = [
    'gpt-3.5-turbo',
    'gpt-4',
    'llama-65b',
    'text-davinci-003',
    'davinci',
    'flan-t5-xxl',
]

In [13]:
# Pool the per-model example CSVs into one frame. Each row is tagged with the model that
# produced its candidates, then the rows are shuffled (presumably so annotators do not see
# one model's examples in a block).
SHUFFLE_SEED = 0  # fixed so re-running reproduces the same example order and batches
model_dfs = []
for model in models:
    # Must match the filename create_example_csv writes: {model}_{sample_size}.csv
    # (previously hard-coded as _50, which breaks if sample_size changes).
    model_df = pd.read_csv(f'annotation/human_eval/examples_by_source/{model}_{sample_size}.csv')
    model_df['source'] = model
    model_dfs.append(model_df)

# ignore_index avoids duplicate index labels (each per-model frame starts at 0).
example_df = pd.concat(model_dfs, ignore_index=True).sample(frac=1, random_state=SHUFFLE_SEED)

In [14]:
# Persist the full pooled, shuffled example set (before it is split into batches).
examples_csv_path = 'annotation/human_eval/examples.csv'
example_df.to_csv(examples_csv_path, index=False)

In [15]:
# Split the pooled examples into consecutive CSV batches of `batch_size` rows for annotation.
# .iloc slicing clamps at the end of the frame, so the last batch holds the remaining rows
# without explicit min() clamping.
batch_size = 100
num_examples = len(example_df.index)
batch_dir = Path('annotation/human_eval/next_batches')
batch_dir.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories
for j, i in enumerate(range(0, num_examples, batch_size)):
    example_df.iloc[i:i + batch_size].to_csv(batch_dir / f'batch_{j}.csv', index=False)