In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, TextClassificationPipeline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import os
from tqdm import tqdm
from scipy.stats import ttest_ind

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# HateCheck test suite; the first CSV column becomes the DataFrame index.
data = pd.read_csv(
    './data/hatecheck-data/test_suite_cases.csv',
    index_col=0,
)
data.head()

In [None]:
# EAR regularisation: score every HateCheck test case with each trained
# checkpoint (one per seed x round) and save the predictions next to the model
# under <model_dir>/prediction/hatecheck.csv.
EAR_SEEDS = range(10)
EAR_ROUNDS = range(1, 5)
# Pipeline device: GPU 0 when CUDA is available, otherwise CPU (-1).
# A hard-coded device=0 crashes on CPU-only machines.
pipe_device = 0 if torch.cuda.is_available() else -1

# Read the suite once; each checkpoint writes into its own copy below.
hatecheck_cases = pd.read_csv('./data/hatecheck-data/test_suite_cases.csv', index_col=0)
# Gold labels do not depend on the model, so derive them once (1 = anything not 'non-hateful').
true_labels = [0 if x == 'non-hateful' else 1 for x in hatecheck_cases['label_gold']]
texts = hatecheck_cases['test_case'].to_list()

for seed in tqdm(EAR_SEEDS, total=len(EAR_SEEDS)):
    # `round_idx`, not `round`: a global `round` would shadow the builtin in later cells.
    for round_idx in EAR_ROUNDS:
        model_dir = f'./ear_bert/entropybert-gab25k-{seed}-0.01/R{round_idx}'
        model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
        tokenizer = AutoTokenizer.from_pretrained(model_dir, do_lower_case=True)
        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, top_k=None, device=pipe_device)
        results = pipe(texts)
        # With top_k=None each input yields a list of {'label', 'score'} dicts; key them
        # by label so a missing label raises KeyError instead of misaligning columns.
        label_scores = [{entry['label']: entry['score'] for entry in per_case} for per_case in results]

        predictions = hatecheck_cases.copy()
        predictions['prob_non_hate'] = [s['LABEL_0'] for s in label_scores]
        predictions['prob_hate'] = [s['LABEL_1'] for s in label_scores]
        # Ties resolve to 1 (hateful), as in the original `0 if a > b else 1`.
        predictions['pred_label'] = [0 if s['LABEL_0'] > s['LABEL_1'] else 1 for s in label_scores]
        predictions['true_label'] = true_labels

        prediction_dir = os.path.join(model_dir, 'prediction')
        os.makedirs(prediction_dir, exist_ok=True)
        predictions.to_csv(os.path.join(prediction_dir, 'hatecheck.csv'), index=False)

        # Drop this checkpoint before loading the next one so GPU memory does not pile up.
        del pipe, model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

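## Overall performance metrics

Accuracy, F1, precision and recall on HateCheck for every EAR-regularised checkpoint (10 seeds × 4 rounds), with hateful (label 1) as the positive class.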

In [None]:
# Per-checkpoint metrics from the saved HateCheck predictions.
# Each *_all list is nested [seed][round - 1]: one inner list of 4 rounds per seed.
# sklearn's binary defaults apply, so label 1 (hateful) is the positive class.
acc_all = []
f1_all = []
precision_all = []
recall_all = []

PREDICTION_PATH = './ear_bert/entropybert-gab25k-{seed}-0.01/R{round_idx}/prediction/hatecheck.csv'

for seed in range(10):
    acc = []
    f1 = []
    precision = []
    recall = []
    # `round_idx`, not `round`: leaving `round = 4` in the kernel would break later round(x, n) calls.
    for round_idx in range(1, 5):
        data = pd.read_csv(PREDICTION_PATH.format(seed=seed, round_idx=round_idx))
        pred = data['pred_label'].to_list()
        label = data['true_label'].to_list()
        acc.append(accuracy_score(y_true=label, y_pred=pred))
        f1.append(f1_score(y_true=label, y_pred=pred))
        precision.append(precision_score(y_true=label, y_pred=pred))
        recall.append(recall_score(y_true=label, y_pred=pred))
    acc_all.append(acc)
    f1_all.append(f1)
    precision_all.append(precision)
    recall_all.append(recall)