In [1]:
from pathlib import Path

while Path.cwd().name != 'ambient':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/ambignli/notebooks
/mmfs1/gscratch/xlab/alisaliu/ambignli


In [5]:
import os
import re
from collections import defaultdict, Counter
from itertools import combinations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import cohen_kappa_score

from mturk.annotation_utils import read_batch, statistics_for_worker
from utils.constants import id2label, NLI_LABELS
from utils.transformation_rules import get_rule

In [13]:
# Load the cleaned annotations; the file is JSON Lines (one record per line).
# NOTE(review): the displayed frame has an 'Unnamed: 0' column, which looks like a pandas
# index saved by an earlier CSV step — confirm whether it should be dropped.
CLEANED_EXAMPLES_PATH = 'annotation/ambignli/cleaned_examples.jsonl'
annotated_df = pd.read_json(CLEANED_EXAMPLES_PATH, lines=True)
print(annotated_df.shape[0])
annotated_df.head(3)

2020


Unnamed: 0,id,worker_ids,premise,hypothesis,annotations,disambiguations
0,4,"[A3FVGZKEEKXBON, A15WACUALQNT90]",We cannot be sure that the meeting will be pro...,We cannot be sure that the meeting will not be...,"[entailment|neutral, entailment]",{'neutral': [{'premise': 'We cannot be sure th...
1,21,"[A1HBYLYXKD7VWX, A15WACUALQNT90]",The first step is to contact your state's depa...,The next step is to contact your state's depar...,"[entailment|contradiction, contradiction]",{'entailment': [{'premise': 'The first step is...
2,39,"[A362GR9VFKI1V4, A15WACUALQNT90]",The person who told me the story is an unrelia...,I can't believe the story because the person w...,"[entailment, neutral]",{}


In [7]:
def format_label_colors(annotations, labels=None):
    """
    Wrap each NLI label word in `annotations` with an HTML <span> whose class is the
    lowercase label, so the label can be colored via CSS.

    Both the all-lowercase and all-uppercase spelling of each label are matched
    (e.g. 'neutral' and 'NEUTRAL'). Only whole words match, so a word that merely
    contains a label (e.g. 'neutrality') is left untouched.

    annotations: text to format
    labels: label names to color; defaults to NLI_LABELS
    returns: the formatted string
    """
    if labels is None:
        labels = NLI_LABELS
    spellings = [spelling for label in labels for spelling in (label.lower(), label.upper())]
    if not spellings:
        return annotations
    pattern = re.compile(r'\b(' + '|'.join(map(re.escape, spellings)) + r')\b')
    # One pass over the text, so markup inserted for one match can never be re-matched by another label.
    return pattern.sub(lambda m: f'<span class="{m.group(0).lower()}">{m.group(0)}</span>', annotations)

In [8]:
def get_spans_changed(original, revised):
    """
    Return (start, end) character spans in `revised` for the segments that the rule
    from get_rule introduces, i.e. segment symbols in the rule's right-hand side that do
    not occur in its left-hand side.

    get_rule returns a rule string '<before> -> <after>' and a `key` dict mapping
    segment symbols to their text. Each character of <after> is looked up as one symbol
    (presumably single letters — confirm against utils.transformation_rules), and
    segments are assumed to be separated by exactly one space in `revised`.
    Raises KeyError if a symbol in <after> is missing from `key`.
    """
    rule, key = get_rule(original, revised)
    before, after = rule.split(' -> ')
    spans = []
    cursor = 0
    for symbol in after:
        width = len(key[symbol])
        if symbol not in before:
            spans.append((cursor, cursor + width))
        cursor += width + 1  # step over the single space assumed to follow each segment
    return spans

def format_revision(original, revised):
    """
    Return `revised` with every span reported by get_spans_changed wrapped in <b>...</b>.

    The spans are processed left to right; each earlier insertion shifts later spans
    to the right by the combined length of the opening and closing tags.
    """
    open_tag, close_tag = '<b>', '</b>'
    tag_width = len(open_tag) + len(close_tag)  # == 7, the shift added per wrapped span
    result = revised
    for n, (start, end) in enumerate(get_spans_changed(original, revised)):
        lo, hi = start + n * tag_width, end + n * tag_width
        result = result[:lo] + open_tag + result[lo:hi] + close_tag + result[hi:]
    return result

In [14]:
error_ct = 0  # disambiguations skipped because format_revision raised KeyError


def _format_disambiguation(premise, hypothesis, revised_premise, revised_hypothesis):
    """
    Render one disambiguation as two HTML lines: the (possibly revised) premise and hypothesis.

    A side that differs from the original is marked with a prime (P' / H') and has its
    changed spans bolded by format_revision; an unchanged side is shown as-is (P / H).
    Propagates KeyError from format_revision.
    """
    premise_ambiguous = revised_premise != premise
    hypothesis_ambiguous = revised_hypothesis != hypothesis
    if premise_ambiguous:
        revised_premise = format_revision(premise, revised_premise)
    if hypothesis_ambiguous:
        revised_hypothesis = format_revision(hypothesis, revised_hypothesis)
    premise_tag = "P'" if premise_ambiguous else "P"
    hypothesis_tag = "H'" if hypothesis_ambiguous else "H"
    return f"{premise_tag}: {revised_premise}<br>{hypothesis_tag}: {revised_hypothesis}<br>"


validation_data = []
for _, row in annotated_df.iterrows():
    # only examples with at least two annotators are kept in validation_data
    if len(row['worker_ids']) < 2:
        continue

    premise, hypothesis = row['premise'], row['hypothesis']

    disambiguation_text = ''
    for label, disambiguations in row['disambiguations'].items():
        rendered = []
        for disambiguation in disambiguations:
            # read the keys outside the try so a malformed record still fails loudly
            revised_premise, revised_hypothesis = disambiguation['premise'], disambiguation['hypothesis']
            try:
                rendered.append(_format_disambiguation(premise, hypothesis, revised_premise, revised_hypothesis))
            except KeyError:
                error_ct += 1
        # Skip labels with no renderable disambiguation rather than emitting a bare header.
        if rendered:
            # Color only the label header: coloring the whole text would also wrap label
            # words that happen to appear inside the premise/hypothesis sentences.
            disambiguation_text += format_label_colors(f'{label.upper()}<br>') + ''.join(rendered)

    validation_data.append({
        'id': row['id'],
        'worker_ids': '<em>' + ', '.join(row['worker_ids']) + '</em>',
        'premise': premise,
        'hypothesis': hypothesis,
        'annotations': format_label_colors(', '.join(row['annotations'])),
        'disambiguations': disambiguation_text,
    })

print(f'skipped {error_ct} disambiguations whose revised spans could not be formatted')

In [15]:
# Number of examples kept for validation, i.e. those with at least two annotators.
len(validation_data)

1973

In [16]:
# Collect the ids of examples already included in a validation batch.
# Batch folders live under batches_dir and are named `batch_<numeric id>`.
batches_dir = Path('annotation/validation/batches')
# sorted() so batches are read (and printed) in a deterministic order; directory listing order is arbitrary.
dirs = sorted(
    entry.name for entry in batches_dir.iterdir()
    if entry.is_dir() and entry.name.startswith('batch_')
)

validated_ids = []
for batch_dir in dirs:
    print(batch_dir)
    batch_id = int(batch_dir.split('_')[-1])
    batch_df = read_batch(batch_id, batch_dir=batches_dir)
    validated_ids += batch_df.id.tolist()

batch_368416
batch_368417
batch_369355
batch_369686
batch_367632
batch_371185
batch_367636
batch_372388
batch_367494
batch_368414
batch_369357
batch_372199
batch_369687
batch_369356


In [18]:
# Examples from validation_data that do not yet appear in any validation batch, written out
# in shuffled order. A fixed seed makes the output file reproducible across runs.
# NOTE(review): `validation_df` was never defined in the saved notebook (NameError on a fresh
# kernel); this reconstruction assumes "remaining" = validation_data minus validated_ids — confirm.
# It also assumes ids read from batch files have the same dtype as annotated_df ids — verify.
SHUFFLE_SEED = 0
validated_id_set = set(validated_ids)
validation_df = pd.DataFrame(
    [example for example in validation_data if example['id'] not in validated_id_set]
)
validation_df.sample(frac=1, random_state=SHUFFLE_SEED).to_csv('annotation/validation/remaining_examples.csv', index=False)