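# Evaluating GPT-2 medium fine-tuned on SST-2

This notebook runs the `michelecafagna26/gpt2-medium-finetuned-sst2-sentiment` sequence classifier on `books_test.csv`. The input text is in the `body` column and the gold labels are in `label`. It reports accuracy and F1; the last run gave accuracy 0.78 and F1 0.83.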

In [1]:
import os

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import GPT2ForSequenceClassification, GPT2Tokenizer

In [2]:
# Same checkpoint as the classifier loaded below, so the vocabulary matches.
tokenizer = GPT2Tokenizer.from_pretrained("michelecafagna26/gpt2-medium-finetuned-sst2-sentiment")
# GPT-2 has no dedicated padding token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the RIGHT (GPT2Tokenizer's usual default, pinned here explicitly).
# GPT2ForSequenceClassification classifies from the last non-padding token,
# and some transformers versions locate it assuming padding sits at the end
# of each row, so left padding could pool the wrong position.
tokenizer.padding_side = "right"

In [3]:
# A single CSV is loaded as a DatasetDict with one split named "train" (the
# name `load_dataset` gives a lone data file), even though it holds test data;
# later cells therefore read dataset_test['train'].
dataset_test = load_dataset("csv", data_files={"train": "books_test.csv"})

In [4]:
# CPU-only inference. Cap intra-op threads at MAX_CPU_THREADS, but never ask
# for more threads than the machine reports CPUs; oversubscription can slow
# PyTorch down. os.cpu_count() may return None, hence the `or 1` fallback.
MAX_CPU_THREADS = 16
device = torch.device('cpu')
torch.set_num_threads(min(MAX_CPU_THREADS, os.cpu_count() or 1))

In [5]:

# Fine-tuned sequence classifier. The base "gpt2" checkpoint is not a drop-in
# alternative here: its classification head would be randomly initialised.
# `.eval()` (already the from_pretrained default) is made explicit so dropout
# is guaranteed off for evaluation.
model = (
    GPT2ForSequenceClassification.from_pretrained("michelecafagna26/gpt2-medium-finetuned-sst2-sentiment")
    .to(device)
    .eval()
)

# The classification head locates each row's last real token via the pad id,
# so it must be the same id the tokenizer pads with.
model.config.pad_token_id = tokenizer.pad_token_id


In [6]:
# Run the classifier over the `body` column in fixed-size batches and collect
# the raw logits; predictions are derived from them in the metrics cell.
# list() materialises the column: some `datasets` versions return a lazy
# column object instead of a plain list, which the tokenizer may not accept.
data = list(dataset_test['train']['body'])
batch_size = 10  # texts per forward pass; tune for speed vs. memory

logits_list = []
for start in range(0, len(data), batch_size):
    print(f"Processing {start}")
    batch = data[start:start + batch_size]
    # padding=True pads only to the longest text in this batch; truncation
    # with no max_length caps inputs at the tokenizer's model_max_length.
    inputs = tokenizer(batch, truncation=True, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits_list.append(model(**inputs).logits)

# torch.cat raises on an empty list, so an empty CSV yields an empty
# (0, num_labels) tensor instead of crashing here.
logits = (
    torch.cat(logits_list)
    if logits_list
    else torch.empty((0, model.config.num_labels), device=device)
)


Processing 0
Processing 10
Processing 20
Processing 30
Processing 40
Processing 50
Processing 60
Processing 70
Processing 80
Processing 90


In [7]:
# Score the predictions against the gold `label` column of the test CSV.
load_accuracy = evaluate.load("accuracy")
load_f1 = evaluate.load("f1")
labels = list(dataset_test['train']['label'])

# Predicted class = index of the largest logit in each row.
# NOTE(review): assumes the model's class indices use the same encoding as the
# CSV's `label` column (e.g. 0 = negative, 1 = positive) — confirm against
# model.config.id2label.
predictions = logits.argmax(dim=-1).cpu().numpy()
assert len(predictions) == len(labels), "expected one prediction per labelled row"

accuracy = load_accuracy.compute(predictions=predictions, references=labels)
# The f1 metric is computed with its default averaging settings.
f1 = load_f1.compute(predictions=predictions, references=labels)
metrics = {"accuracy": accuracy, "f1": f1}

In [9]:
metrics

{'accuracy': {'accuracy': 0.78}, 'f1': {'f1': 0.8333333333333334}}

In [8]:
metrics

{'accuracy': {'accuracy': 0.78}, 'f1': {'f1': 0.8333333333333334}}

In [8]:
metrics

{'accuracy': {'accuracy': 0.78}, 'f1': {'f1': 0.8333333333333334}}