In [None]:
# Auto-reload edited local modules (e.g. classifier_utils, imported in a later cell)
# so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Mount Google Drive (Colab-only) so the project directory below is reachable.
from google.colab import drive
drive.mount('/content/drive')

# cd into the project's classification directory; later cells use paths relative to it
# (e.g. '../../data/', 'model/', 'pred/').
# NOTE(review): hard-coded to one user's Drive layout — adjust when running elsewhere.
%cd /content/drive/My\ Drive/Georgia_Tech/Spring_2021/sbic_stereotypes/src/classification

In [None]:
!pip install transformers
!pip install datasets

import torch
import pandas as pd
import numpy as np

# ================= MODIFY THESE PARAMS =================
CLASSIFY_COL = 'whoTarget'  # label column of the SBIC data being predicted
MODEL_NAME = f'model/{CLASSIFY_COL}/checkpoint-1280/'  # fine-tuned checkpoint directory
# =======================================================

# Fixed locations/identifiers; paths are relative to the working directory set by %cd.
DATA_DIR = '../../data/'
BASE_MODEL = 'bert-base-uncased'  # source of the tokenizer vocabulary
CLEAN_DATA_FILE = f'data/train_{CLASSIFY_COL}.csv'

In [None]:
from classifier_utils import *

# Load the SBIC v2 dev split and pass it through the project helper, which prepares it
# for scoring CLASSIFY_COL (the helper also receives the training-data path CLEAN_DATA_FILE).
# NOTE(review): the exact filtering / label handling lives in classifier_utils — confirm there.
df = prep_df_for_classification(
    pd.read_csv(DATA_DIR + 'SBIC.v2.dev.csv'),
    CLEAN_DATA_FILE,
    CLASSIFY_COL,
)

In [None]:
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import Trainer, TrainingArguments
from tqdm import tqdm
import math

import os

# Tokenizer vocabulary comes from the base model; classifier weights from the fine-tuned checkpoint.
tokenizer = BertTokenizer.from_pretrained(BASE_MODEL)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)

# Use the Colab GPU when one is attached; one-post-at-a-time inference on CPU is slow.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()  # inference mode: disables dropout

num_examples = df.shape[0]
# One row of raw logits per post; width follows the checkpoint's label count
# instead of a hard-coded 2.
outputs = np.empty((num_examples, model.config.num_labels))

# Pure inference: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
    for i in tqdm(range(num_examples)):
        # .iloc is positional, so row i always lines up with outputs[i] and with the
        # positional label array saved below, regardless of df's index labels.
        # truncation=True clips posts longer than the model's maximum input length
        # instead of failing inside the model.
        inputs = tokenizer(df["post"].iloc[i], truncation=True, return_tensors='pt').to(device)
        outputs[i] = model(**inputs)['logits'].cpu().numpy()

# Cache logits and gold labels so the metrics cell can be re-run without inference.
os.makedirs('pred', exist_ok=True)
np.savetxt('pred/' + CLASSIFY_COL + '_predictions.csv', outputs)
np.savetxt('pred/' + CLASSIFY_COL + '_labels.csv', df[CLASSIFY_COL].to_numpy())

In [None]:
# Reload the cached logits/labels so metrics can be recomputed without re-running inference.
# ndmin keeps the arrays 2-D / 1-D even when the files hold a single example.
logits = np.loadtxt('pred/' + CLASSIFY_COL + "_predictions.csv", ndmin=2)
y = np.loadtxt('pred/' + CLASSIFY_COL + '_labels.csv', ndmin=1)
assert logits.shape[0] == y.shape[0], "prediction and label counts differ"

# Softmax is monotonic, so argmax over the raw logits picks the same class
# and avoids np.exp overflow on large logits.
y_hat = np.argmax(logits, axis=1)

# Confusion-matrix counts, treating class 1 as the positive class.
tp = np.sum((y_hat == 1) & (y == 1))
fp = np.sum((y_hat == 1) & (y == 0))
tn = np.sum((y_hat == 0) & (y == 0))
fn = np.sum((y_hat == 0) & (y == 1))

print("tp: ", tp)
print("fp: ", fp)
print("tn: ", tn)
print("fn: ", fn)


def _safe_ratio(num, den):
    """Return num / den as a float, or 0.0 when den is 0.

    Mirrors scikit-learn's zero_division=0 convention instead of yielding nan
    when, e.g., the model never predicts the positive class.
    """
    return float(num) / float(den) if den else 0.0


precision = _safe_ratio(tp, tp + fp)
recall = _safe_ratio(tp, tp + fn)
f1 = _safe_ratio(2 * precision * recall, precision + recall)

print("precision: ", precision)
print("recall: ", recall)
print("f1: ", f1)