In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, TextClassificationPipeline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import os
from tqdm import tqdm
from scipy.stats import ttest_ind

In [23]:
# Directory holding the dataset splits, and the split evaluated below.
path = './soc/data/majority_gab_dataset_25k/'
mode = 'test'

# The split is stored as JSON Lines (one record per line).
split_file = os.path.join(path, mode + '.jsonl')
data = pd.read_json(split_file, lines=True)

# Raw input strings that are fed to the classifier.
text = data['Text'].tolist()

In [24]:
# Binary gold label: 1 when the sum of the 'cv' and 'hd' columns is positive, else 0.
true_label = ((data['cv'] + data['hd']) > 0).astype(int).to_list()

# Use the first GPU when CUDA is available; otherwise fall back to CPU (-1)
# instead of failing on a hard-coded device index.
device = 0 if torch.cuda.is_available() else -1

# Run one fine-tuned checkpoint per seed over `text` and write its predictions
# to <model_dir>/prediction/gab_<mode>.csv.
for seed in tqdm(range(10)):
    model_dir = f'./ear_bert/entropybert-gab25k-{seed}-0.01/'
    model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # top_k=None makes the pipeline return a score for every label, which the
    # two comprehensions below rely on.
    pipe = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        top_k=None,
        device=device,
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    results = pipe(text)

    # One score per input for each of the two labels.
    prob_non_hate = [x['score'] for y in results for x in y if x['label'] == 'LABEL_0']
    prob_hate = [x['score'] for y in results for x in y if x['label'] == 'LABEL_1']

    # Predict 0 only when LABEL_0 strictly outscores LABEL_1; ties resolve to 1.
    pred_label = [0 if a > b else 1 for (a, b) in zip(prob_non_hate, prob_hate)]

    out_dict = {
        'Text': text,
        'pred_label': pred_label,
        'true_label': true_label,
    }

    results_df = pd.DataFrame.from_dict(out_dict)
    os.makedirs(os.path.join(model_dir, 'prediction'), exist_ok=True)
    results_df.to_csv(os.path.join(model_dir, f'prediction/gab_{mode}.csv'), index=False)

100%|██████████| 10/10 [15:18<00:00, 91.87s/it]


### Get overall metrics for each seed

In [25]:
# Per-seed scores; each list receives one entry per seed.
acc = []
f1 = []
precision = []
recall = []

for seed in range(10):
    # Read the file the prediction cell writes (gab_{mode}.csv) so both cells
    # stay in sync when `mode` changes.
    # Alternative inputs (swap in to score other runs):
    #   f'./soc/runs/majority_gab_es_reg_nb5_h5_is_bal_pos_seed_{seed}/prediction/hatecheck.csv'
    #   f'./ear_bert/entropybert-gab25k-{seed}-0.01/prediction/hatecheck.csv'
    pred_path = f'./ear_bert/entropybert-gab25k-{seed}-0.01/prediction/gab_{mode}.csv'

    # Dedicated names: reusing `path` / `data` here would overwrite the dataset
    # directory and the loaded split that the loading and prediction cells use,
    # breaking any re-run of those cells.
    pred_df = pd.read_csv(pred_path)
    pred = pred_df['pred_label'].to_list()
    label = pred_df['true_label'].to_list()

    acc.append(accuracy_score(y_true=label, y_pred=pred))
    f1.append(f1_score(y_true=label, y_pred=pred))
    precision.append(precision_score(y_true=label, y_pred=pred))
    recall.append(recall_score(y_true=label, y_pred=pred))

# Each *_all list wraps this run's per-seed scores as a single entry.
acc_all = [acc]
f1_all = [f1]
precision_all = [precision]
recall_all = [recall]

In [30]:
# Mean and standard deviation of the per-seed recall scores.
# NumPy's default ddof=0 applies, i.e. the population standard deviation.
recall_arr = np.asarray(recall)
(recall_arr.mean(), recall_arr.std())

(0.5671814671814672, 0.05312073451781375)