In [1]:
#%pip install datasets torch transformers sentencepiece ipywidgets protobuf polars tqdm

In [2]:
from datasets import load_dataset

dataset = load_dataset("tommasobonomo/sem_augmented_fever_nli")

# Show the schema and size of the two splits that are filtered below.
for split_name in ("validation", "test"):
    print(dataset[split_name])

Dataset({
    features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
    num_rows: 2270
})
Dataset({
    features: ['id', 'premise', 'hypothesis', 'label', 'wsd', 'srl'],
    num_rows: 2281
})


In [3]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"> using {device}")

# Registries keyed by Hugging Face model id; `inference` looks models up here.
MODELS = {}
TOKENIZERS = {}

model_name_base = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
model_name_large = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"
# Loaded in a later cell: all three models do not fit in GPU memory at the same time.
model_name_large_2 = "Joelzhang/deberta-v3-large-snli_mnli_fever_anli_R1_R2_R3-nli"

for model_name in (model_name_base, model_name_large):
    print(f"> loading {model_name}")
    TOKENIZERS[model_name] = AutoTokenizer.from_pretrained(model_name)
    MODELS[model_name] = (
        AutoModelForSequenceClassification
        .from_pretrained(model_name)
        .to(device)
    )

> using cuda
> loading MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
> loading MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli


In [4]:
def inference(model_name: str, premise: str, hypothesis: str) -> str:
    """Predict the NLI label of one (premise, hypothesis) pair with a loaded model.

    Args:
        model_name: key into the global ``MODELS`` / ``TOKENIZERS`` registries.
        premise: premise text.
        hypothesis: hypothesis text.

    Returns:
        One of ``"ENTAILMENT"``, ``"NEUTRAL"`` or ``"CONTRADICTION"``: the class
        with the highest logit.
    """
    tokenizer = TOKENIZERS[model_name]
    model = MODELS[model_name]

    # NOTE(review): assumes every loaded model orders its output logits as
    # entailment / neutral / contradiction -- confirm via model.config.id2label.
    label_names = ["ENTAILMENT", "NEUTRAL", "CONTRADICTION"]

    # truncation=False: the pair is never cut, whatever its length.
    encoded = tokenizer(premise, hypothesis, truncation=False, return_tensors="pt").to(device)

    # Pure inference: disable autograd so no computation graph is built per call.
    with torch.no_grad():
        # Pass the full encoding (input_ids, attention_mask, ...) as the tokenizer produced it.
        logits = model(**encoded).logits[0]

    # Softmax is monotonic, so argmax over raw logits selects the same class as argmax
    # over probabilities, without the ties that rounding probabilities to 0.1% could create.
    return label_names[int(torch.argmax(logits))]

In [5]:
import polars as pl

# Keep only the pairs that *every* currently loaded model labels exactly as the gold label.
first_round = {
    'cid': [],
    'premise': [],
    'hypothesis': [],
    'label': [],
}

for split in ['test', 'validation']:
    for sample in tqdm(dataset[split], desc=f"> processing {split} data"):
        premise, hypothesis, label = sample['premise'], sample['hypothesis'], sample['label']

        # A pair with an empty side cannot be judged; skip it without running any model.
        if '' in (premise, hypothesis):
            continue

        # all() stops at the first model that disagrees with the gold label,
        # so later (more expensive) models are not run on already-rejected pairs.
        every_model_agrees = all(
            inference(model_name, premise, hypothesis) == label
            for model_name in MODELS.keys()
        )
        if not every_model_agrees:
            continue

        for column, value in (
            ('cid', sample['id']),
            ('premise', premise),
            ('hypothesis', hypothesis),
            ('label', label),
        ):
            first_round[column].append(value)

first_round = pl.from_dict(first_round)
first_round.height

> processing test data: 100%|██████████| 2281/2281 [04:12<00:00,  9.04it/s]
> processing validation data: 100%|██████████| 2270/2270 [03:02<00:00, 12.46it/s]


3144

In [6]:
# NOTE: despite its name, this file holds first-round survivors from BOTH the test and validation splits.
first_round_csv_path = 'fever_validation_filtered_first_round.csv'
first_round.write_csv(first_round_csv_path, separator=',')

Repeat the filtering on the first-round survivors, this time with the third model. It is loaded separately because all three models do not fit in GPU memory at the same time.

In [7]:
import gc

# Drop the first-round models before loading the third one (all three do not fit in GPU
# memory together). Rebinding the registries presumably drops the last references to those
# models (verify no other variable still holds one). gc.collect() + empty_cache() then return
# the freed blocks to the driver instead of leaving them reserved in PyTorch's caching allocator.
MODELS = {}
TOKENIZERS = {}
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

for model_name in [model_name_large_2]:
    print(f"> loading {model_name}")
    TOKENIZERS[model_name] = AutoTokenizer.from_pretrained(model_name)
    MODELS[model_name] = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

> loading Joelzhang/deberta-v3-large-snli_mnli_fever_anli_R1_R2_R3-nli




In [8]:
# Re-check the first-round survivors with the model(s) loaded now. Each survivor gets a fresh
# consecutive 0-based 'id'; 'cid' keeps the original dataset id.
second_round = {
    'id': [],
    'cid': [],
    'premise': [],
    'hypothesis': [],
    'label': [],
}

first_round_rows = first_round.iter_rows(named=True)
for sample in tqdm(first_round_rows, desc="> processing first round", total=first_round.height):
    premise, hypothesis, label = sample['premise'], sample['hypothesis'], sample['label']

    # Reject as soon as one model disagrees with the gold label.
    if not all(inference(model_name, premise, hypothesis) == label for model_name in MODELS.keys()):
        continue

    # The number of rows kept so far is exactly the next consecutive id.
    second_round['id'].append(len(second_round['id']))
    second_round['cid'].append(sample['cid'])
    second_round['premise'].append(premise)
    second_round['hypothesis'].append(hypothesis)
    second_round['label'].append(label)

> processing first round: 100%|██████████| 3144/3144 [03:09<00:00, 16.59it/s]


In [9]:
second_round = pl.from_dict(second_round)

# Report how many first-round pairs the third model rejected, then persist the survivors.
n_dropped = first_round.height - second_round.height
print(f"{second_round.height} elements (-{n_dropped})")
second_round.write_csv('fever_validation_filtered_second_round.csv', separator=',')

3076 elements (-68)


In [10]:
# Draw a label-balanced annotation sample: SAMPLES_PER_LABEL pairs per gold label.
SAMPLES_PER_LABEL = 50
SAMPLING_SEED = 42

sampled_e = second_round.filter(pl.col('label') == 'ENTAILMENT').sample(n=SAMPLES_PER_LABEL, seed=SAMPLING_SEED)
sampled_n = second_round.filter(pl.col('label') == 'NEUTRAL').sample(n=SAMPLES_PER_LABEL, seed=SAMPLING_SEED)
sampled_c = second_round.filter(pl.col('label') == 'CONTRADICTION').sample(n=SAMPLES_PER_LABEL, seed=SAMPLING_SEED)
print(sampled_e.head())
print(sampled_n.head())
print(sampled_c.head())

balanced = pl.concat([sampled_e, sampled_n, sampled_c])
# Shuffle whole rows so premise / hypothesis / label always travel together.
# `select(pl.all().shuffle(seed=...))` shuffles every column as a separate expression and
# relies on each column receiving the identical permutation to keep rows intact.
concat = balanced.sample(fraction=1.0, shuffle=True, seed=SAMPLING_SEED)
# 'id' is unique per row (assigned when second_round was built), so sorting by it must
# reproduce the unshuffled frame exactly if no row was torn apart.
assert concat.sort('id').equals(balanced.sort('id'))
print(concat)

shape: (5, 5)
┌──────┬────────┬─────────────────────────────────┬─────────────────────────────────┬────────────┐
│ id   ┆ cid    ┆ premise                         ┆ hypothesis                      ┆ label      │
│ ---  ┆ ---    ┆ ---                             ┆ ---                             ┆ ---        │
│ i64  ┆ str    ┆ str                             ┆ str                             ┆ str        │
╞══════╪════════╪═════════════════════════════════╪═════════════════════════════════╪════════════╡
│ 520  ┆ 201374 ┆ Varsity Blues (film) . The fil… ┆ Varsity Blues (film) was filme… ┆ ENTAILMENT │
│ 957  ┆ 109431 ┆ Visigoths . In or around 589 ,… ┆ The culture of their Hispano-R… ┆ ENTAILMENT │
│ 922  ┆ 70935  ┆ China . , it is the world 's s… ┆ The world's second largest eco… ┆ ENTAILMENT │
│ 1667 ┆ 179029 ┆ Steve Ditko . Ditko studied un… ┆ Steve Ditko studied at school.  ┆ ENTAILMENT │
│ 1004 ┆ 135477 ┆ Google Search . These include … ┆ Scores for sports games can be… ┆ ENTAILMEN

In [11]:
# Build the annotation sheet: the sampled pairs plus empty columns for the annotators to fill,
# with a fresh consecutive 0-based 'id' reflecting the shuffled order.
n_rows = concat.height
to_csv = pl.from_dict({
    'id': list(range(n_rows)),
    'cid': concat['cid'].to_list(),
    'premise': concat['premise'].to_list(),
    'hypothesis': concat['hypothesis'].to_list(),
    'alternative hypothesis': [''] * n_rows,
    'label': concat['label'].to_list(),
    'new hypothesis': [''] * n_rows,
    'new label': [''] * n_rows,
    'change type': [''] * n_rows,
})
to_csv.write_csv("sampled_fever_validation_filtered.csv", separator=',')
print("csv written.")

# Sanity check: the sheet should contain the same number of pairs for every label.
q = to_csv.lazy().group_by("label").len()
df = q.collect()
print(df)

csv written.
shape: (3, 2)
┌───────────────┬─────┐
│ label         ┆ len │
│ ---           ┆ --- │
│ str           ┆ u32 │
╞═══════════════╪═════╡
│ NEUTRAL       ┆ 50  │
│ ENTAILMENT    ┆ 50  │
│ CONTRADICTION ┆ 50  │
└───────────────┴─────┘
