# Testing GPT-2 fine-tuned on SST-2

In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
from datasets import load_dataset
import numpy as np
import evaluate

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("michelecafagna26/gpt2-medium-finetuned-sst2-sentiment")
tokenizer.pad_token = tokenizer.eos_token

In [3]:
dataset_test = load_dataset("csv", data_files="books_test.csv")

In [4]:
device = torch.device('cuda')

In [5]:
model = GPT2ForSequenceClassification.from_pretrained("michelecafagna26/gpt2-medium-finetuned-sst2-sentiment").to(device)
model.config.pad_token_id = model.config.eos_token_id


In [6]:
logits_list = []
data = dataset_test['train']['body']
batch_size = 2

for start in range(0, len(data), batch_size):
    print("Processing " + str(start))
    batch = data[start:start + batch_size]
    # tokenize all first?
    inputs = tokenizer(batch, truncation=True, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits_list.append(model(**inputs).logits)
logits = torch.cat(logits_list)


Processing 0
Processing 2
Processing 4
Processing 6
Processing 8
Processing 10
Processing 12
Processing 14
Processing 16
Processing 18
Processing 20
Processing 22
Processing 24
Processing 26
Processing 28
Processing 30
Processing 32
Processing 34
Processing 36
Processing 38
Processing 40
Processing 42
Processing 44
Processing 46
Processing 48
Processing 50
Processing 52
Processing 54
Processing 56
Processing 58
Processing 60
Processing 62
Processing 64
Processing 66
Processing 68
Processing 70
Processing 72
Processing 74
Processing 76
Processing 78
Processing 80
Processing 82
Processing 84
Processing 86
Processing 88
Processing 90
Processing 92
Processing 94
Processing 96
Processing 98


In [7]:
load_accuracy = evaluate.load("accuracy")
load_f1 = evaluate.load("f1")
labels = dataset_test['train']['label']
predictions = np.argmax(logits.data.cpu().numpy(), axis=-1)
accuracy = load_accuracy.compute(predictions=predictions, references=labels)
f1 = load_f1.compute(predictions=predictions, references=labels)
metrics = {"accuracy": accuracy, "f1": f1}

In [9]:
print(metrics)

{'accuracy': {'accuracy': 0.78}, 'f1': {'f1': 0.8333333333333334}}


## The accuracy is 0.78, while the F1-score is 0.8333333333333334