In [1]:
import pandas as pd
from sklearn.metrics import matthews_corrcoef
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer
from sklearn.metrics import balanced_accuracy_score, precision_recall_fscore_support, accuracy_score, f1_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Test split with one prediction column per model; the Llama cell further
# down writes its predictions back into this same file.
results_path = 'polnli_test_results.csv'
test = pd.read_csv(results_path)

In [3]:
def metrics(df, preds, group_by=None, true_col='entailment', average='weighted'):
    """Score one or more prediction columns of ``df`` against a gold-label column.

    Parameters
    ----------
    df : pd.DataFrame
        Frame holding the gold labels and one column per model's predictions.
    preds : str or list of str
        Prediction column name(s) to evaluate. A single string is treated as a
        one-element list (iterating a bare string would walk its characters).
    group_by : str, optional
        Column to break the scores down by (e.g. ``'task'`` or ``'dataset'``).
        ``None`` scores the whole frame at once.
    true_col : str, default 'entailment'
        Column holding the gold labels.
    average : str, default 'weighted'
        ``average`` mode forwarded to ``sklearn.metrics.f1_score``.

    Returns
    -------
    pd.DataFrame
        One row per prediction column (and per group when ``group_by`` is
        given) with ``MCC``, ``Accuracy`` and ``F1`` columns. The index is
        ``Column`` when ungrouped, or (``Column``, ``group_by.capitalize()``)
        when grouped.

    Raises
    ------
    KeyError
        If ``group_by`` is not None and is not a column of ``df``.
    """
    if isinstance(preds, str):
        preds = [preds]

    def get_metrics(y_true, y_pred):
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average=average),
        }

    score_cols = ['MCC', 'Accuracy', 'F1']

    if group_by is None:
        rows = []
        for col in preds:
            # Named `row` (not `metrics`) so it does not shadow this function.
            row = get_metrics(df[true_col], df[col])
            row['Column'] = col
            rows.append(row)
        # Explicit columns keep set_index working even when `preds` is empty.
        return pd.DataFrame(rows, columns=score_cols + ['Column']).set_index('Column')

    # Any column can be used for grouping; previously anything other than
    # 'dataset'/'task' was silently ignored and ungrouped scores came back.
    if group_by not in df.columns:
        raise KeyError(f'group_by column {group_by!r} not found in df')

    group_label = str(group_by).capitalize()
    rows = []
    for col in preds:
        for group_name, group in df.groupby(group_by):
            row = get_metrics(group[true_col], group[col])
            row['Column'] = col
            row[group_label] = group_name
            rows.append(row)
    return (
        pd.DataFrame(rows, columns=score_cols + ['Column', group_label])
        .set_index(['Column', group_label])
    )

# Llama 3.1 8B Instruct (zero-shot)

In [4]:
# Zero-shot baseline: Llama 3.1 8B Instruct, weights loaded in bfloat16 and
# placed on the GPU.
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

pipe = pipeline(
    task="text-generation",
    model=model_id,
    device_map="cuda",
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:42<00:00, 10.69s/it]


In [8]:
# Hypotheses are spliced mid-sentence into the prompt ("... determine if {hypothesis}"),
# so they are stored fully lower-cased.
hypotheses = test['augmented_hypothesis']
test['augmented_hypothesis'] = hypotheses.str.lower()

In [9]:
# Prompt template: '0' means the hypothesis holds for the text, '1' means it does not.
# Filled via str.format with `doc` (the premise) and `hypothesis`.
user_message = (
    "You are a classifier that can only respond with 0 or 1. "
    "I'm going to show you a short text sample and I want you to determine if {hypothesis}. "
    "Here is the text:\n"
    "{doc}\n"
    "\n"
    "If it is true that {hypothesis}, return 0. "
    "If it is not true that {hypothesis}, return 1.\n"
    "Do not explain your answer, and only return 0 or 1.\n"
)

In [10]:
%%time
data = test
res = []

for i in data.index:
    doc = data.loc[i, 'premise']
    hypothesis = data.loc[i, 'augmented_hypothesis']
    messages = [
        {"role": "user", "content": user_message.format(doc = doc, hypothesis = hypothesis)},
    ]
    prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=2, do_sample=False, return_full_text = False, pad_token_id=pipe.tokenizer.eos_token_id, temperature = 0)
    res.extend(outputs)

res = [text['generated_text'] for text in res]
# return a list of unique responses from the model
print(set(res))

  attn_output = torch.nn.functional.scaled_dot_product_attention(
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


{'0', '1'}
CPU times: total: 24min 12s
Wall time: 25min 51s


In [None]:
# Map each raw generation to a binary label. Per the prompt, '0' means the
# hypothesis holds and '1' means it does not. Whitespace is stripped, and empty
# generations are tolerated (indexing `text[0]` would raise IndexError on them).
# Anything that is not '0'/'1' is counted and reported, then scored as 0.
llama_labels = [text.strip()[:1] for text in res]
n_unparsed = sum(label not in ('0', '1') for label in llama_labels)
if n_unparsed:
    print(f'warning: {n_unparsed} generations were neither 0 nor 1; scoring them as 0')
test['llama'] = [1 if label == '1' else 0 for label in llama_labels]
# NOTE: this overwrites the CSV read at the top of the notebook, including the
# lower-cased `augmented_hypothesis` column.
test.to_csv('polnli_test_results.csv', index = False)

In [14]:
# Overall scores for every model on the full test set (group_by defaults to None).
metrics(test, ['base_polnli', 'base_nli', 'large_nli', 'llama'])

Unnamed: 0_level_0,MCC,Accuracy,F1
Column,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
base_polnli,0.894269,0.948978,0.948852
base_nli,0.658454,0.834635,0.830205
large_nli,0.731021,0.869062,0.866256
llama,0.730997,0.862358,0.863467


In [15]:
# Same models, with scores broken down per task.
metrics(test, group_by='task', preds=['base_polnli', 'base_nli', 'large_nli', 'llama'])

Unnamed: 0_level_0,Unnamed: 1_level_0,MCC,Accuracy,F1
Column,Task,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
base_polnli,event extraction,0.813742,0.906774,0.907042
base_polnli,hatespeech and toxicity,0.84141,0.946036,0.945248
base_polnli,stance detection,0.9206,0.961546,0.961488
base_polnli,topic classification,0.926241,0.963834,0.963826
base_nli,event extraction,0.528918,0.746858,0.74571
base_nli,hatespeech and toxicity,0.494319,0.844437,0.822392
base_nli,stance detection,0.553824,0.786101,0.781703
base_nli,topic classification,0.875762,0.937653,0.937047
large_nli,event extraction,0.718805,0.850209,0.850508
large_nli,hatespeech and toxicity,0.571824,0.861426,0.854152
