In [2]:
from pathlib import Path

while Path.cwd().name != 'ambient':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/ambignli/notebooks
/mmfs1/gscratch/xlab/alisaliu/ambignli


In [3]:
import pandas as pd
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from modeling.multitask_model import RobertaForMultitaskSequenceClassification
from utils.utils import predict_nli
from utils.mturk_utils import read_batch
from torch import sigmoid
from collections import Counter
from utils.utils import ensure_dir
import numpy as np
import os

## find possibly ambiguous examples

In [4]:
multilabel_model = RobertaForSequenceClassification.from_pretrained('models/roberta-large-wanli-multilabel').to('cuda')
tokenizer = RobertaTokenizer.from_pretrained('models/roberta-large-wanli')

In [5]:
def compute_example_ambiguity(df):
    df['ambiguity_score'] = None
    df['predicted_labels'] = None
    for i, row in tqdm(df.iterrows(), total=len(df.index)):
        premise, hypothesis = row['premise'], row['hypothesis']
        probs = predict_nli(premise, hypothesis, multilabel_model, tokenizer)
        preds = set([l for l, p in probs.items() if p > 0.04])
        # ambiguity score is the probability assigned to the second-highest label
        sorted_probs = sorted([p for p in probs.values()], reverse=True)
        s = sorted_probs[1]
        df.at[i, 'ambiguity_score'] = s
        df.at[i, 'predicted_labels'] = ', '.join(sorted(preds))
    return df

In [7]:
gen_dir = Path('generated_data/wanli_disagreement_p0.9_davinci-002')
df_wanli_disagreement_instruct = pd.read_json(gen_dir / 'filtered_examples.jsonl', lines=True)

In [8]:
df = compute_example_ambiguity(df_wanli_disagreement_instruct)

100%|██████████| 77564/77564 [18:52<00:00, 68.46it/s]


In [9]:
thres = 0.05
sub_df = df.loc[df['ambiguity_score'] > thres]
sub_df

Unnamed: 0,id,premise,hypothesis,nearest_neighbors,ambiguity_score,predicted_labels
0,0,The proposal was met with some skepticism from...,The proposal was met with some optimism from t...,"[82936, 245722, 331487, 19994]",0.925016,"contradiction, neutral"
1,2,The company's decision to downsize was met wit...,The company's decision to downsize was met wit...,"[82936, 245722, 331487, 19994]",0.886042,"contradiction, neutral"
2,3,The amount of money that was spent on the proj...,The amount of money that was saved on the proj...,"[82936, 245722, 331487, 19994]",0.630185,"contradiction, neutral"
3,4,We cannot be sure that the meeting will be pro...,We cannot be sure that the meeting will not be...,"[82936, 245722, 331487, 19994]",0.126387,"contradiction, neutral"
4,5,The company will only offer the position to so...,The company will only offer the position to so...,"[214335, 8249, 65040, 102411]",0.791046,"entailment, neutral"
...,...,...,...,...,...,...
77557,104063,Most people in the United States speak English.,English is the official language of the United...,"[22805, 173022, 66665, 188215]",0.747514,"entailment, neutral"
77558,104064,The novel is a fiction.,The movie is based on a true story.,"[22805, 173022, 66665, 188215]",0.662432,"contradiction, neutral"
77559,104065,"The poet T.S. Eliot wrote, ""We shall not cease...",We never really know a place until we leave it.,"[22805, 173022, 66665, 188215]",0.066416,"contradiction, neutral"
77561,104067,The researchers say that this is the first stu...,This is the first study to look at the long-te...,"[133594, 371112, 155042, 348420]",0.744364,"entailment, neutral"


In [11]:
# include all examples with contradiction label
con_mask = sub_df['predicted_labels'].str.contains('contradiction')
balanced_df = sub_df[con_mask]
# get label distribution
counter = [ls.split(', ') for ls in balanced_df.predicted_labels.tolist()]
counter = Counter([l for ls in counter for l in ls])
# patch up with entailment examples
num_entailment_needed = counter['contradiction'] - counter['entailment']
ent_mask = sub_df['predicted_labels'].str.contains('entailment')
balanced_df = pd.concat([balanced_df, sub_df[~con_mask][ent_mask].sample(num_entailment_needed)])

  past_df = sub_df.loc[sub_df['id'].isin(old_ids)][~con_mask][ent_mask]
  balanced_df = pd.concat([balanced_df, sub_df.loc[~sub_df['id'].isin(old_ids)][~con_mask][ent_mask].sample(num_entailment_needed-len(past_df))])


In [12]:
balanced_df.predicted_labels.value_counts()

contradiction, neutral                6850
entailment, neutral                   6850
contradiction, entailment, neutral    2531
contradiction, entailment              595
Name: predicted_labels, dtype: int64

In [15]:
balanced_df.sample(frac=1).to_csv('annotation/AmbiEnt/balanced_examples.csv', index=False)
balanced_df.sample(frac=1).to_json(gen_dir / 'balanced_examples.jsonl', lines=True, orient='records')