In [1]:
import pandas as pd
from anthropic import Anthropic
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the test-set results table. Later cells read its `premise`,
# `augmented_hypothesis` and `entailment` columns and add a `sonnet` column.
test = pd.read_csv(
    '../data/polnli_test_results.csv',
)

In [6]:
def metrics(df, preds, group_by=None):
    """Score prediction columns against the `entailment` column.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame holding the reference labels in `entailment`, every column
        named in `preds`, and the `group_by` column when grouping is used.
    preds : iterable of str
        Names of the prediction columns to evaluate.
    group_by : {'dataset', 'task', None}, default None
        Column to split the rows on before scoring. Any other value is
        treated like None, i.e. the whole frame is scored at once.

    Returns
    -------
    pandas.DataFrame
        Columns `MCC`, `Accuracy` and `F1` (weighted average), indexed by
        `Column`, or by (`Column`, capitalised `group_by`) when grouping.
        An empty `preds` yields an empty frame with the same layout.
    """
    true_col = 'entailment'
    grouped = group_by in ('dataset', 'task')
    index_cols = ['Column', group_by.capitalize()] if grouped else ['Column']

    def get_metrics(y_true, y_pred):
        # One row of scores for a single (labels, predictions) pair.
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average='weighted')
        }

    rows = []
    for col in preds:
        if grouped:
            for group_name, group in df.groupby(group_by):
                row = get_metrics(group[true_col], group[col])
                row['Column'] = col
                row[group_by.capitalize()] = group_name
                rows.append(row)
        else:
            row = get_metrics(df[true_col], df[col])
            row['Column'] = col
            rows.append(row)

    # Naming the columns explicitly keeps the frame well-formed when `preds`
    # is empty; otherwise set_index would fail on a column-less frame.
    results_df = pd.DataFrame(rows, columns=['MCC', 'Accuracy', 'F1'] + index_cols)

    if grouped:
        return results_df.set_index(index_cols)
    return results_df.set_index('Column')

In [4]:
# Do not embed the API key in the notebook: with no `api_key` argument the
# Anthropic client reads it from the ANTHROPIC_API_KEY environment variable.
claude = Anthropic()

In [5]:
# Prompt template; `{doc}` and `{hypothesis}` are filled in with str.format.
# The model is told to answer 0 when the hypothesis holds and 1 when it does not.
user_message = """You are a classifier that can only respond with 0 or 1. I'm going to show you a short text sample and I want you to determine if {hypothesis}. Here is the text:
{doc}

If it is true that {hypothesis}, return 0. If it is not true that {hypothesis}, return 1.
Do not explain your answer, and only return 0 or 1.
"""

In [6]:
%%time
user_message = user_message
model = "claude-3-5-sonnet-20240620"
data = test
labels = []

for i in data.index:
    doc = data.loc[i, 'premise']
    hypothesis = data.loc[i, 'augmented_hypothesis']
    res = claude.messages.create(
        max_tokens=2,
        messages=[
            {
                "role": "user",
                "content": user_message.format(doc = doc, hypothesis = hypothesis),
            }
        ],
        model=model,
        temperature = 0
    )
    labels.append(res.content[0].text)

CPU times: total: 35.2 s
Wall time: 4h 8min 42s


In [7]:
# Decode the raw text replies into integer codes in a single assignment that
# is rebuilt from `labels`, so re-running the cell gives the same result.
# `Series.map` is used rather than `Series.replace`, which emits a
# FutureWarning about silent downcasting on object columns.
label_codes = {'1': 1, '0': 0, '2': -1}
test['sonnet'] = (
    pd.Series(labels, index=test.index)
    .str.strip()
    .map(label_codes)
    # Replies outside the mapping get the same -1 code that '2' receives, so
    # the column stays integer-typed and scoreable.
    # NOTE(review): assumes -1 means "invalid reply" - confirm.
    .fillna(-1)
    .astype(int)
)
# NOTE(review): this writes to the working directory, while the frame was read
# from '../data/polnli_test_results.csv' - confirm which location is intended.
test.to_csv('polnli_test_results.csv', index = False)

  test['sonnet'] = test['sonnet'].replace({'1':1, '0':0, '2':-1})


In [7]:
# Overall scores for each model's prediction column (no grouping).
metrics(
    test,
    preds=['base_polnli', 'large_polnli', 'llama', 'sonnet'],
    group_by=None,
)

Unnamed: 0_level_0,MCC,Accuracy,F1
Column,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
base_polnli,0.894269,0.948978,0.948852
large_polnli,0.915911,0.959326,0.95918
llama,0.730997,0.862358,0.863467
sonnet,0.815902,0.910517,0.909423


In [8]:
# Scores for each model's prediction column, broken down by the `task` column.
metrics(
    test,
    preds=['base_polnli', 'large_polnli', 'llama', 'sonnet'],
    group_by='task',
)

Unnamed: 0_level_0,Unnamed: 1_level_0,MCC,Accuracy,F1
Column,Task,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
base_polnli,event extraction,0.813742,0.906774,0.907042
base_polnli,hatespeech and toxicity,0.84141,0.946036,0.945248
base_polnli,stance detection,0.9206,0.961546,0.961488
base_polnli,topic classification,0.926241,0.963834,0.963826
large_polnli,event extraction,0.819699,0.909567,0.909838
large_polnli,hatespeech and toxicity,0.881535,0.95936,0.959026
large_polnli,stance detection,0.969009,0.984979,0.984972
large_polnli,topic classification,0.924496,0.962503,0.962322
llama,event extraction,0.808244,0.905726,0.90548
llama,hatespeech and toxicity,0.55906,0.782145,0.799067
