In [None]:
!pip install transformer_lens
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys

if IN_COLAB:
    # Install packages
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python


In [None]:
!pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [2]:
from transformers import GPT2Tokenizer
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from datasets import load_dataset


data_path = "/content/IMDB Dataset.csv"
# Load IMDB dataset (3000 train, 1000 test)
dataset = load_dataset('csv', data_files= data_path)
train_subset = dataset["train"].shuffle(seed=42).select(range(7000))
test_subset = dataset["train"].shuffle(seed=42).select(range(2000))
dataset = {"train": train_subset, "test": test_subset}

Generating train split: 0 examples [00:00, ? examples/s]

In [36]:
test_subset[:5]
# print how many positive and negative reviews are in the dataset
print("Number of positive reviews: ", sum(1 for sentiment in test_subset['sentiment'] if sentiment == 'positive'))
print("Number of negative reviews: ", sum(1 for sentiment in test_subset['sentiment'] if sentiment == 'negative'))

Number of positive reviews:  984
Number of negative reviews:  1016


In [28]:
from torch.utils.data import DataLoader, Dataset
label_map = {"negative": 0, "positive": 1}
class IMDBDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.texts = dataset["review"]
        self.labels = [label_map[label] for label in dataset["sentiment"]]
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encodings = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = encodings["input_ids"].squeeze()
        attention_mask = encodings["attention_mask"].squeeze()
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return input_ids, attention_mask, label, idx

In [5]:
class SentimentClassifier(nn.Module):
    def __init__(self, transformer, hidden_dim=768, num_classes=2):
        super(SentimentClassifier, self).__init__()
        self.transformer = transformer
        self.classifier = nn.Linear(50257, num_classes)  # Maps hidden state → sentiment classes

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(input_ids)  # (batch, seq_len, hidden_dim)
        pooled_output = outputs.mean(dim=1)  # Mean pool across sequence length
        logits = self.classifier(pooled_output)  # Shape: (batch_size, 2)
        return logits

In [None]:
import torch
from transformer_lens import HookedTransformer
# load the base model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load HookedTransformer GPT-2 model
base_model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Create the sentiment classifier model
base_classifier = SentimentClassifier(base_model).to(device)

In [29]:
test_dataset = IMDBDataset(dataset["test"], tokenizer)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)

In [31]:
base_classifier.eval()
correct = 0
total = 0
# store the data points  where it predicts wrong
# make an empty dataset to store the wrong predictions
wrong_predictions = []
with torch.no_grad():
    for input_ids, attention_mask, labels, idx in test_loader:
        input_ids, attention_mask, labels, idx = input_ids.to(device), attention_mask.to(device), labels.to(device), idx.to(device)
        logits = base_classifier(input_ids, attention_mask)
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        wrong_predictions.extend(idx[predicted != labels].tolist())

print(f"Accuracy: {correct / total:.2f}")
print(f"Number of wrong predictions: {total - correct}")
print("Correct predictions: ", correct)

Accuracy: 0.52
Number of wrong predictions: 964
Correct predictions:  1036


In [33]:
# convert the wrong predictions to a dataset
wrong_dataset = [test_dataset[i] for i in wrong_predictions]
wrong_loader = DataLoader(wrong_dataset, batch_size=4, shuffle=True)

In [18]:
for batch in wrong_loader:
    print(labels)
    break

tensor([0, 0, 1, 1], device='cuda:0')


In [10]:
loaded_model = SentimentClassifier(base_model).to(device)
loaded_model.load_state_dict(torch.load("/content/gpt2-small-imdb-finetuned.pt"))

<All keys matched successfully>

In [11]:
loaded_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for input_ids, attention_mask, labels in test_loader:
        input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
        logits = loaded_model(input_ids, attention_mask)
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {correct / total:.2f}")
print("Correct: ", correct)
print("Total: ", total)

Accuracy: 0.74
Correct:  1484
Total:  2000


In [35]:
# do evaluation on the wrong dataset
correct = 0
total = 0
with torch.no_grad():
    for input_ids, attention_mask, labels, idx in wrong_loader:
        input_ids, attention_mask, labels, idx = input_ids.to(device), attention_mask.to(device), labels.to(device), idx.to(device)
        logits = loaded_model(input_ids, attention_mask)
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy: {correct / total:.2f}")
print("Correct: ", correct)
print("Total: ", total)

Accuracy: 0.51
Correct:  491
Total:  964
