In [1]:
"""
create AmbiEnt from the validation batches
"""

from pathlib import Path

while Path.cwd().name != 'ambient':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/ambient/notebooks
/mmfs1/gscratch/xlab/alisaliu/ambient


In [2]:
import os
from mturk.annotation_utils import read_batch, clean_validation_batch, statistics_for_worker
import numpy as np
import pandas as pd
import seaborn as sns
from collections import Counter

In [3]:
def get_num_rewrites(row, key: str):
    """
    Count the distinct rewrites of one sentence of an example.

    row: example supporting row['disambiguations'] (a mapping whose values are
        lists of dicts) and row[key]
    key: one of premise, hypothesis
    return: number of distinct d[key] values, over all disambiguations d,
        that differ from row[key]
    """
    candidates = flatten_list_of_lists(row['disambiguations'].values())
    distinct_rewrites = set()
    for candidate in candidates:
        # A disambiguation that leaves this sentence untouched is not a rewrite.
        if candidate[key] != row[key]:
            distinct_rewrites.add(candidate[key])
    return len(distinct_rewrites)

def flatten_list_of_lists(list_of_lists):
    """
    Concatenate the items of every sublist, in order, into a single new list.
    Only one level of nesting is flattened.
    """
    flattened = []
    for sublist in list_of_lists:
        flattened.extend(sublist)
    return flattened

In [10]:
# Read every batch_<id> directory, tally HITs per worker, and build the cleaned
# validation frame.
batches_dir = Path('annotation/validation/batches')
# os.listdir returns entries in arbitrary order; sort numerically by batch id so
# the concatenated frame (and anything order-sensitive downstream) is the same
# on every run and every filesystem.
dirs = sorted(
    (d for d in os.listdir(batches_dir) if (os.path.isdir(batches_dir / d) and d.startswith('batch_'))),
    key=lambda d: int(d.split('_')[1]),
)
hits_per_annotator = Counter()

batch_dfs = []
for batch_dir in dirs:
    # Directory names look like 'batch_<id>'; the numeric id is what read_batch takes.
    batch_id = int(batch_dir.split('_')[1])
    batch_df = read_batch(batch_id, batch_dir=batches_dir)
    batch_dfs.append(batch_df)
    # value_counts() yields one count per worker_id; accumulate them across batches.
    hits_per_annotator += batch_df.worker_id.value_counts()

validated_df = pd.concat(batch_dfs)
print(f'Number of examples annotated: {len(validated_df.index)}')
validated_df = clean_validation_batch(validated_df)
print(f'Number of examples in validated dataset: {len(validated_df)}')
# Concatenation leaves one index run per batch; renumber rows 0..n-1 so that
# label-based access by row index is unambiguous.
validated_df = validated_df.reset_index(drop=True)
hits_per_annotator

Number of examples annotated: 2167
Number of examples in validated dataset: 1503


Counter({'A3AA2VKV87R6PG': 345,
         'A1KBELVHWNE4D5': 1473,
         'A2AX828Q4WXK3Z': 248,
         'A14KPHOYAQCFWH': 101})

In [11]:
# Default every example to "not ambiguous" with no reformatted disambiguations;
# the loop below fills in the real values row by row.
validated_df['premise_ambiguous'] = False
validated_df['hypothesis_ambiguous'] = False
validated_df['reformatted_disambiguations'] = None

for idx, example in validated_df.iterrows():
    # Neither sentence may have exactly one distinct rewrite.
    assert not (get_num_rewrites(example, 'premise') == 1 or get_num_rewrites(example, 'hypothesis') == 1)

    # 'a|b' -> 'a, b'
    validated_df.at[idx, 'gold'] = example['gold'].replace('|', ', ')

    # Flatten {label: [disambiguation, ...]} into one list of dicts, each
    # carrying its label under the 'label' key.
    labeled_disambiguations = []
    for label, disambiguations in example['disambiguations'].items():
        for disambiguation in disambiguations:
            labeled_disambiguations.append(disambiguation | {'label': label})
    validated_df.at[idx, 'reformatted_disambiguations'] = labeled_disambiguations

    # A sentence is flagged as ambiguous once it has at least two distinct rewrites.
    for key in ('premise', 'hypothesis'):
        if get_num_rewrites(example, key) < 2:
            continue
        validated_df.at[idx, f'{key}_ambiguous'] = True

In [12]:
# Remove the validator id and the original label -> rewrites mapping, then expose
# the flattened list as 'disambiguations' and the gold annotation as 'labels'.
validated_df = validated_df.drop(columns='validator_id')
column_order = list(validated_df.columns)
column_order.remove('disambiguations')
validated_df = validated_df[column_order].rename(
    columns={'reformatted_disambiguations': 'disambiguations', 'gold': 'labels'}
)

In [13]:
# Only a cell's final expression is rendered, so print the row count first and
# keep the preview as the last line; a head() call placed before the print
# would be evaluated and silently discarded.
print(len(validated_df.index))
validated_df.head(3)

1503


In [14]:
# Fraction of validated examples whose premise or hypothesis is flagged ambiguous.
is_ambiguous = validated_df['premise_ambiguous'] | validated_df['hypothesis_ambiguous']
int(is_ambiguous.sum()) / len(validated_df)

0.2907518296739854

In [15]:
# Shuffle the rows before writing one JSON record per line. The seed is fixed so
# that re-running the notebook writes the examples in the same order every time.
SHUFFLE_SEED = 0
validated_df.sample(frac=1, random_state=SHUFFLE_SEED).to_json('annotation/AmbiEnt/validated_examples.jsonl', orient='records', lines=True)