In [4]:
import pandas as pd
import torch
from transformers import AutoTokenizer, BertForSequenceClassification, TextClassificationPipeline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import os
from tqdm import tqdm
from scipy.stats import ttest_ind

In [5]:
# Read the `mode` split of the dataset (JSON Lines: one JSON record per line)
# and keep its raw texts for the inference loop.
mode = 'test'
path = './soc/data/majority_gab_dataset_25k/'
jsonl_file = os.path.join(path, f'{mode}.jsonl')
data = pd.read_json(jsonl_file, lines=True)
text = data['Text'].tolist()

In [8]:
# Binary gold label: 1 when either the `cv` or the `hd` column is positive, else 0.
true_label = ((data['cv'] + data['hd']) > 0).astype(int).to_list()

# Run directory of the model being evaluated; `{seed}` is filled per iteration.
# Other run directories used with this notebook:
#   ./soc/runs/majority_gab_es_vanilla_bal_seed_{seed}/
#   ./ear_bert/entropybert-gab25k-{seed}-0.01/
MODEL_DIR_TEMPLATE = './soc/runs/majority_gab_es_reg_nb5_h5_is_bal_pos_seed_{seed}'
SEEDS = range(10)

# The tokenizer does not depend on the seed, so load it once rather than per iteration.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


def scores_to_labels(results):
    """Convert pipeline outputs into hard 0/1 predictions.

    `results` holds one list of {'label': ..., 'score': ...} dicts per input
    text. Returns 0 when the LABEL_0 score is strictly greater than the
    LABEL_1 score and 1 otherwise (ties go to 1). A missing label raises
    KeyError instead of silently misaligning predictions with texts.
    """
    labels = []
    for per_text in results:
        scores = {entry['label']: entry['score'] for entry in per_text}
        labels.append(0 if scores['LABEL_0'] > scores['LABEL_1'] else 1)
    return labels


for seed in tqdm(SEEDS, total=len(SEEDS)):
    model_dir = MODEL_DIR_TEMPLATE.format(seed=seed)
    model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)
    # NOTE(review): texts are fed one at a time and each is padded to 512 tokens;
    # a batch_size with dynamic padding would likely be much faster — verify outputs match.
    pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, top_k=None, device=0,
                                      truncation=True, padding="max_length", max_length=512)
    pred_label = scores_to_labels(pipe(text))
    assert len(pred_label) == len(text) == len(true_label), "prediction count mismatch"

    results_df = pd.DataFrame({'Text': text, 'pred_label': pred_label, 'true_label': true_label})
    pred_dir = os.path.join(model_dir, 'prediction')
    os.makedirs(pred_dir, exist_ok=True)
    results_df.to_csv(os.path.join(pred_dir, f'gab_{mode}.csv'), index=False)

    # Drop references so this seed's model can be freed before the next one is loaded.
    del pipe, model

100%|██████████| 10/10 [06:14<00:00, 37.49s/it]


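### Get overall metrics

Score each seed's saved test-set predictions (accuracy, F1, precision, recall) and collect the per-seed values for every model.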

In [9]:
# Per-model accumulators: each execution of the scoring cell appends one list of
# per-seed values to every one of these (four independent list objects).
acc_all, f1_all, precision_all, recall_all = [], [], [], []

In [32]:

# Prediction CSV of the model being scored; `{seed}` is filled per iteration.
# Other templates scored in earlier executions of this cell:
#   ./soc/runs/majority_gab_es_vanilla_bal_seed_{seed}/prediction/gab_test.csv
#   ./soc/runs/majority_gab_es_reg_nb5_h5_is_bal_pos_seed_{seed}/prediction/gab_test.csv
# NOTE(review): every execution appends one entry to acc_all/f1_all/precision_all/
# recall_all, and later cells index those lists by position, so their meaning
# depends on how many times (and with which template) this cell was run.
# TODO: key the results by model name so Restart & Run All reproduces them.
PRED_CSV_TEMPLATE = './ear_bert/entropybert-gab25k-{seed}-0.01/prediction/gab_test.csv'


def score_prediction_csv(pred_csv):
    """Return (accuracy, f1, precision, recall) for one prediction CSV.

    The CSV must have `pred_label` and `true_label` columns. F1, precision
    and recall use scikit-learn's defaults (binary average, positive class 1).
    """
    pred_df = pd.read_csv(pred_csv)
    y_pred = pred_df['pred_label'].to_list()
    y_true = pred_df['true_label'].to_list()
    return (
        accuracy_score(y_true=y_true, y_pred=y_pred),
        f1_score(y_true=y_true, y_pred=y_pred),
        precision_score(y_true=y_true, y_pred=y_pred),
        recall_score(y_true=y_true, y_pred=y_pred),
    )


# Local names deliberately differ from `path`/`data` so the dataset directory and
# test DataFrame loaded at the top of the notebook are not overwritten by this cell.
acc, f1, precision, recall = [], [], [], []
for seed in range(10):
    seed_acc, seed_f1, seed_precision, seed_recall = score_prediction_csv(
        PRED_CSV_TEMPLATE.format(seed=seed))
    acc.append(seed_acc)
    f1.append(seed_f1)
    precision.append(seed_precision)
    recall.append(seed_recall)

acc_all.append(acc)
f1_all.append(f1)
precision_all.append(precision)
recall_all.append(recall)

In [30]:
# Mean and population standard deviation (numpy default ddof=0) of recall over seeds.
recall_mean = np.mean(recall)
recall_std = np.std(recall)
(recall_mean, recall_std)

(0.5671814671814672, 0.05312073451781375)

In [37]:
# Recall (in %) of the most recently scored model, summarised over seeds.
# Index -1 rather than a fixed position: recall_all has only as many entries as
# the scoring cell has been executed, so a hard-coded index like 3 raises
# IndexError after Restart & Run All (in the saved session, 3 == -1).
metric = 100 * np.asarray(recall_all[-1])
print(np.mean(metric))
print(np.std(metric))  # population std (ddof=0), as elsewhere in this notebook

56.71814671814671
5.312073451781375


In [53]:
# Two-sample Student's t-test (scipy default equal_var=True) on per-seed recall.
# NOTE(review): positions 1 and 0 refer to the second and first executions of the
# scoring cell — confirm which models those were before interpreting the p-value.
second_scored, first_scored = recall_all[1], recall_all[0]
ttest_ind(second_scored, first_scored)

Ttest_indResult(statistic=0.39080471157197444, pvalue=0.7005284798854077)

In [54]:
# Per-seed recall for every scored model: one inner list per execution of the scoring cell.
recall_all

[[0.6138996138996139,
  0.6447876447876448,
  0.5521235521235521,
  0.7258687258687259,
  0.6563706563706564,
  0.5868725868725869,
  0.61003861003861,
  0.583011583011583,
  0.6177606177606177,
  0.5907335907335908],
 [0.5907335907335908,
  0.5212355212355212,
  0.6640926640926641,
  0.6640926640926641,
  0.6293436293436293,
  0.7065637065637066,
  0.6640926640926641,
  0.5984555984555985,
  0.5907335907335908,
  0.640926640926641],
 [0.7112758486149044,
  0.6500195083886071,
  0.6246586031993757,
  0.7019118220834959,
  0.759266484588373,
  0.7740928599297698,
  0.5673039406944986,
  0.6984003121342177,
  0.6757705813499805,
  0.6410456496293406],
 [0.5019305019305019,
  0.5444015444015444,
  0.583011583011583,
  0.5096525096525096,
  0.6756756756756757,
  0.5675675675675675,
  0.5173745173745173,
  0.6254826254826255,
  0.6061776061776062,
  0.5405405405405406]]