In [1]:
from pathlib import Path

while Path.cwd().name != 'ambignli':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/ambignli


In [2]:
import pandas as pd
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from modeling.multitask_model import RobertaForMultitaskSequenceClassification
from utils.utils import predict_nli
from torch import sigmoid
from collections import Counter
from utils.utils import ensure_dir

In [3]:
def predict_nli(premise, hypothesis, model, tokenizer):
    x = tokenizer(premise, hypothesis, return_tensors='pt', max_length=128, truncation=True).to('cuda')
    logits = model(**x).logits
    # multi-task model
    if hasattr(model, 'output_heads'):
        probs = logits.softmax(dim=-1).squeeze(0)
        return {model.config.id2label[i]: probs[i,1].item() for i in range(len(probs))}
    # multi-label model
    elif model.config.problem_type == 'multi_label_classification':
        logits = logits.squeeze(0)
        probs = sigmoid(logits)
        return {model.config.id2label[i]: probs[i].item() for i in range(len(probs))}
    # classification model
    else:
        probs = logits.softmax(dim=1).squeeze(0)
        return {model.config.id2label[i]: probs[i].item() for i in range(len(probs))}

In [4]:
multilabel_model = RobertaForSequenceClassification.from_pretrained('models/roberta-large-wanli-multilabel').to('cuda')
tokenizer = RobertaTokenizer.from_pretrained('models/roberta-large-wanli')

In [5]:
def compute_example_ambiguity(df):
    df['ambiguity_score'] = None
    df['predicted_labels'] = None
    for i, row in tqdm(df.iterrows(), total=len(df.index)):
        premise, hypothesis = row['premise'], row['hypothesis']
        probs = predict_nli(premise, hypothesis, multilabel_model, tokenizer)
        preds = set([l for l, p in probs.items() if p > 0.04])
        # ambiguity score is the probability assigned to the second-highest label
        sorted_probs = sorted([p for p in probs.values()], reverse=True)
        s = sorted_probs[1]
        df.at[i, 'ambiguity_score'] = s
        df.at[i, 'predicted_labels'] = ', '.join(sorted(preds))
    return df

In [6]:
gen_dir = Path('generated_data/wanli_disagreement_p0.9_davinci-002')
df_wanli_disagreement_instruct = pd.read_json(gen_dir / 'filtered_examples.jsonl', lines=True)

In [None]:
df = compute_example_ambiguity(df_wanli_disagreement_instruct)

 40%|███▉      | 30938/77870 [07:40<11:23, 68.67it/s]

In [12]:
thres = 0.05
sub_df = df.loc[df['ambiguity_score'] > thres]
sub_df

Unnamed: 0,id,premise,hypothesis,nearest_neighbors,ambiguity_score,predicted_labels
0,0,The proposal was met with some skepticism from...,The proposal was met with some optimism from t...,"[82936, 245722, 331487, 19994]",0.925016,"contradiction, neutral"
1,2,The company's decision to downsize was met wit...,The company's decision to downsize was met wit...,"[82936, 245722, 331487, 19994]",0.886042,"contradiction, neutral"
2,3,The amount of money that was spent on the proj...,The amount of money that was saved on the proj...,"[82936, 245722, 331487, 19994]",0.630185,"contradiction, neutral"
3,4,We cannot be sure that the meeting will be pro...,We cannot be sure that the meeting will not be...,"[82936, 245722, 331487, 19994]",0.126387,"contradiction, neutral"
4,5,The company will only offer the position to so...,The company will only offer the position to so...,"[214335, 8249, 65040, 102411]",0.791046,"entailment, neutral"
...,...,...,...,...,...,...
77863,104063,Most people in the United States speak English.,English is the official language of the United...,"[22805, 173022, 66665, 188215]",0.747514,"entailment, neutral"
77864,104064,The novel is a fiction.,The movie is based on a true story.,"[22805, 173022, 66665, 188215]",0.662432,"contradiction, neutral"
77865,104065,"The poet T.S. Eliot wrote, ""We shall not cease...",We never really know a place until we leave it.,"[22805, 173022, 66665, 188215]",0.066416,"contradiction, neutral"
77867,104067,The researchers say that this is the first stu...,This is the first study to look at the long-te...,"[133594, 371112, 155042, 348420]",0.744364,"entailment, neutral"


In [13]:
old_balanced_df = pd.read_json(gen_dir / 'balanced_examples_old.jsonl', lines=True)
old_ids = old_balanced_df.id.tolist()

In [14]:
# include all examples with contradiction label
con_mask = sub_df['predicted_labels'].str.contains('contradiction')
balanced_df = sub_df[con_mask]
# get label distribution
counter = [ls.split(', ') for ls in balanced_df.predicted_labels.tolist()]
counter = Counter([l for ls in counter for l in ls])
# patch up with entailment examples
num_entailment_needed = counter['contradiction'] - counter['entailment']
ent_mask = sub_df['predicted_labels'].str.contains('entailment')
# balanced_df = pd.concat([balanced_df, sub_df[~con_mask][ent_mask].sample(num_entailment_needed)])
# use examples from previous data
past_df = sub_df.loc[sub_df['id'].isin(old_ids)][~con_mask][ent_mask]
balanced_df = pd.concat([balanced_df, past_df])
balanced_df = pd.concat([balanced_df, sub_df.loc[~sub_df['id'].isin(old_ids)][~con_mask][ent_mask].sample(num_entailment_needed-len(past_df))])

  past_df = sub_df.loc[sub_df['id'].isin(old_ids)][~con_mask][ent_mask]
  balanced_df = pd.concat([balanced_df, sub_df.loc[~sub_df['id'].isin(old_ids)][~con_mask][ent_mask].sample(num_entailment_needed-len(past_df))])


In [17]:
balanced_df.predicted_labels.value_counts()

contradiction, neutral                6864
entailment, neutral                   6864
contradiction, entailment, neutral    2621
contradiction, entailment              602
Name: predicted_labels, dtype: int64

In [24]:
print(len(set(balanced_df.id.tolist()).intersection(set(old_ids))))
print(len(set(old_ids)))

16860
16974


In [18]:
balanced_df.sample(frac=1).to_csv('annotation/balanced_examples.csv', index=False)
balanced_df.sample(frac=1).to_json(gen_dir / 'balanced_examples.jsonl', lines=True, orient='records')

## create new batch

In [9]:
balanced_df = pd.read_csv('annotation/ambignli/balanced_examples.csv')
annotated_ids = pd.read_json('annotation/ambignli/annotated_examples.jsonl', lines=True)['id'].tolist()

In [10]:
# TODO: this may re-sample examples that were discarded
batch_size = 100
remaining_pool_df = balanced_df[~balanced_df.id.isin(annotated_ids)]
print(len(remaining_pool_df))
ensure_dir('annotation/batches/nextbatch')
remaining_pool_df.sample(batch_size).to_csv('annotation/batches/nextbatch/examples.csv', index=False)

15495
