In [128]:
import pickle

import numpy as np
import pandas as pd
import torch
from torchmetrics.classification import BinaryAccuracy

from project.pipeline.balance_data import balance_data_indices
from project.pipeline.data_loader import DataLoader

In [129]:
# Read the previously saved toxic-bert scores back from disk.
# pickle.load can execute arbitrary code, so only point this at a file you produced yourself.
with open('toxic_bert_results.pickle', mode='rb') as results_file:
    toxic_bert_results = pickle.load(results_file)

In [130]:
# Wrap the unpickled scores in a DataFrame and preview the first five rows.
df_results = pd.DataFrame(data=toxic_bert_results)
df_results.head(n=5)

Unnamed: 0,toxicity,severe_toxicity,obscene,threat,insult,identity_attack
0,0.671682,0.006681,0.603532,0.000814,0.01657,0.000918
1,0.913387,0.026864,0.206393,0.006488,0.428129,0.611231
2,0.998545,0.451745,0.993932,0.006212,0.917578,0.058288
3,0.012694,0.000109,0.000558,0.000169,0.00058,0.000266
4,0.905675,0.03399,0.168355,0.059374,0.110174,0.495062


In [131]:
def sexist_heuristic(row, threshold=0.3):
    """Flag a row as positive when any of its scores reaches ``threshold``.

    Parameters
    ----------
    row : pandas.Series
        One row of scores; only its ``.values`` are inspected, the labels
        are ignored.
    threshold : float, default 0.3
        Inclusive cut-off: a score exactly equal to ``threshold`` counts as
        a hit. The default reproduces the original hard-coded value.

    Returns
    -------
    numpy.bool_
        True if at least one value in ``row`` is ``>= threshold``, otherwise
        False (including for an empty row).
    """
    scores = np.asarray(row.values)
    return np.any(scores >= threshold)

In [132]:
# Apply the heuristic row by row to the score columns only. Dropping an
# existing 'sexist' column first keeps this cell idempotent: on a re-run the
# previously computed boolean flag is not fed back into the heuristic
# alongside the float scores.
df_results['sexist'] = (
    df_results
    .drop(columns='sexist', errors='ignore')
    .apply(sexist_heuristic, axis=1)
)
df_results.head()

Unnamed: 0,toxicity,severe_toxicity,obscene,threat,insult,identity_attack,sexist
0,0.671682,0.006681,0.603532,0.000814,0.01657,0.000918,True
1,0.913387,0.026864,0.206393,0.006488,0.428129,0.611231,True
2,0.998545,0.451745,0.993932,0.006212,0.917578,0.058288,True
3,0.012694,0.000109,0.000558,0.000169,0.00058,0.000266,False
4,0.905675,0.03399,0.168355,0.059374,0.110174,0.495062,True


In [133]:
# Take the `df` attribute of a freshly constructed project DataLoader and preview it.
# NOTE(review): where DataLoader reads from is not visible here - confirm the data source.
df_data = DataLoader().df
df_data.head()

Unnamed: 0,rewire_id,text,label_sexist,label_category,label_vector
0,sexism2022_english-7358,"Damn, this writing was pretty chaotic",not sexist,none,none
1,sexism2022_english-2367,"Yeah, and apparently a bunch of misogynistic v...",not sexist,none,none
2,sexism2022_english-3073,How the FUCK is this woman still an MP!!!???,not sexist,none,none
3,sexism2022_english-14895,Understand. Know you're right. At same time I ...,not sexist,none,none
4,sexism2022_english-4118,Surprized they didn't stop and rape some women,not sexist,none,none


In [134]:
# Ask the project helper for the indices of a label-balanced subset of df_data.
# NOTE(review): presumably this samples rows so both 'label_sexist' classes are equally
# represented; if the sampling is random, confirm it is seeded so results are reproducible.
balanced_df_indices = balance_data_indices(df_data, df_data['label_sexist'])

In [135]:
# Ground truth: True for every row whose label is anything other than "not sexist".
# The comparison is vectorised rather than looping over the labels in Python, and
# `target` is bound exactly once (to the tensor) instead of first to a Series.
target = torch.tensor((df_data['label_sexist'] != "not sexist").to_numpy(dtype=bool))
# Heuristic output as a tensor, in df_results row order.
predictions = torch.tensor(df_results['sexist'].to_numpy())
# The two frames are compared positionally, so they must describe the same rows.
assert len(predictions) == len(target), "df_results and df_data have different numbers of rows"

In [138]:
# Score the heuristic predictions against the targets over every row.
# NOTE(review): this `metric` object is reused by the later accuracy cells; torchmetrics
# metrics may keep internal state, so confirm each call's return value covers only the
# tensors passed in that call.
metric = BinaryAccuracy()
acc = metric(predictions, target)
print(f"Accuracy on whole dataset: {acc}")

Accuracy on whole dataset: 0.5740714073181152


In [140]:
# Baseline: score an all-zeros prediction tensor (one entry per target) against the
# same targets, i.e. what a constant "negative" guess would achieve.
acc = metric(torch.zeros(len(target)), target)
print(f"Accuracy on whole dataset with prediction for everything 'non-sexist': {acc}")

Accuracy on whole dataset with prediction for everything 'non-sexist': 0.7572857141494751


In [139]:
# Same comparison as for the whole dataset, restricted to the entries selected by
# `balanced_df_indices` (the same index is applied to predictions and targets).
acc = metric(predictions[balanced_df_indices], target[balanced_df_indices])
print(f"Accuracy on balanced dataset: {acc}")

Accuracy on balanced dataset: 0.621100664138794


In [141]:
# All-zeros baseline on the subset selected by `balanced_df_indices`: the zero tensor
# is sized to match the selected targets.
acc = metric(torch.zeros(len(target[balanced_df_indices])), target[balanced_df_indices])
print(f"Accuracy on balanced dataset with prediction for everything 'non-sexist': {acc}")

Accuracy on balanced dataset with prediction for everything 'non-sexist': 0.5
