In [3]:
import pickle

import numpy as np
import torch
from einops import rearrange
from torchmetrics.classification import BinaryAccuracy

from project.pipeline.balance_data import balance_data_indices_reduction
from project.pipeline.data_loader import DataLoader
from project.task_a.train_mlp import build_model

In [4]:
# Read the pickled mapping from disk; the file handle is only needed for the load itself.
# NOTE(review): pickle.load executes arbitrary code from the file — only use with a trusted artefact.
with open('toxic_bert_results.pickle', 'rb') as handle:
    feature_table = pickle.load(handle)

# Stack the mapping's values as rows, then swap the two axes ("f n -> n f").
# Presumably f indexes feature columns and n indexes samples — TODO confirm against the
# script that produced the pickle.
stacked_values = torch.tensor(list(feature_table.values()))
features = rearrange(stacked_values, "f n -> n f")

In [8]:
# Location of the saved feed-forward model, relative to the notebook's working directory.
model_path = "../trained_agents/feed_forward.pt"
# build_model is a project helper; its result is called directly on the feature tensor below.
# NOTE(review): presumably this restores trained weights from model_path — confirm in train_mlp.
mlp = build_model(model_path)

In [14]:
# Score every row with the MLP. This is evaluation only, so autograd is switched off:
# otherwise a computation graph is kept alive for the whole dataset and the resulting
# tensor carries requires_grad into the metric calls further down.
# NOTE(review): if build_model does not already put the model in eval mode, layers such as
# dropout or batch-norm would still behave as in training — confirm.
with torch.no_grad():
    raw_scores = mlp(features)

# The pattern requires a trailing axis of length 1 and drops it: (n, 1) -> (n,).
is_sexist = rearrange(raw_scores, "n 1 -> n")
print(is_sexist.shape)

torch.Size([14000])


In [15]:
# Instantiate the project's DataLoader and keep its `df` attribute as the labelled dataset.
df_data = DataLoader().df
# Shown as the cell result; it should equal the number of scored rows
# (both are 14000 in the saved outputs).
len(df_data)

14000

In [16]:
# Ask the project helper for an index selection computed from the label column; the result
# is used below to index both the prediction and the target tensors.
# NOTE(review): presumably this downsamples the majority class to equal class sizes — the
# all-zeros baseline scoring exactly 0.5 on this subset is consistent with that; confirm.
balanced_df_indices = balance_data_indices_reduction(df_data['label_sexist'])

In [17]:
# Turn the string labels into a boolean tensor: True for every label other than "not sexist".
raw_labels = df_data['label_sexist']
is_positive = [label != "not sexist" for label in raw_labels]
target = torch.tensor(np.array(is_positive))

# The model scores are used as predictions unchanged.
predictions = is_sexist

In [18]:
# One BinaryAccuracy instance is created here and reused by the following cells,
# so this cell must run before them.
metric = BinaryAccuracy()
# NOTE(review): predictions are the raw model outputs; per the torchmetrics docs float
# predictions are thresholded (default 0.5) — confirm the model's output range suits that.
acc = metric(predictions, target)
print(f"Accuracy on whole dataset: {acc}")

Accuracy on whole dataset: 0.6895714402198792


In [19]:
# Majority-class baseline: predict 0 ("non-sexist") for every row of the full dataset.
# Compare this number with the model's whole-dataset accuracy printed above.
all_negative = torch.zeros(len(target))
acc = metric(all_negative, target)
print(f"Accuracy on whole dataset with prediction for everything 'non-sexist': {acc}")

Accuracy on whole dataset with prediction for everything 'non-sexist': 0.7572857141494751


In [20]:
# Restrict both predictions and targets to the balanced index selection before scoring.
balanced_predictions = predictions[balanced_df_indices]
balanced_target = target[balanced_df_indices]
acc = metric(balanced_predictions, balanced_target)
print(f"Accuracy on balanced dataset: {acc}")

Accuracy on balanced dataset: 0.667304277420044


In [21]:
# Same all-zeros ("non-sexist") baseline, evaluated on the balanced subset only.
balanced_target = target[balanced_df_indices]
all_negative_balanced = torch.zeros(len(balanced_target))
acc = metric(all_negative_balanced, balanced_target)
print(f"Accuracy on balanced dataset with prediction for everything 'non-sexist': {acc}")

Accuracy on balanced dataset with prediction for everything 'non-sexist': 0.5
