## Loading dataset

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from datasets import DatasetDict, load_dataset, concatenate_datasets
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

# Fetch the IMDB movie-review sentiment dataset (all of its published splits).
imdb_hf = load_dataset("imdb")

In [None]:
# Pool the original train and test splits, then re-split them 70/15/15 into
# train / eval / test. Without a fixed seed, train_test_split draws a fresh
# random split on every run, so results would not be reproducible.
split_seed = 42
dataset = concatenate_datasets([imdb_hf["train"], imdb_hf["test"]])
train_test = dataset.train_test_split(test_size=0.3, seed=split_seed)
eval_test = train_test["test"].train_test_split(test_size=0.5, seed=split_seed)

imdb = DatasetDict(
    {
        "train": train_test["train"],  # 70% of the pooled data
        "eval": eval_test["train"],  # 15%: validation split
        "test": eval_test["test"],  # 15%: held-out test split
    }
)

## Loading the pretrained model

In [None]:
# Tokenizer and encoder come from the same checkpoint so their vocabularies agree.
model_name = "bert-base-uncased"
tokenizer, transformer_model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)

In [None]:
# Disable gradients for every encoder parameter so the pre-trained weights stay fixed.
transformer_model.requires_grad_(False)

## Preparing data

In [None]:
# Set up the data collator and dataloaders.
# DataCollatorWithPadding pads each DataLoader batch to the length of its own
# longest sequence. The tokenizer therefore only truncates. Padding inside
# `map` would pad whole map-batches to their longest example and waste compute
# on pad tokens.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

max_length = 512  # truncate reviews to at most 512 tokens
batch_size = 64


def preprocess_function(examples):
    """Tokenize a batch of reviews, truncating each to ``max_length`` tokens."""
    return tokenizer(examples["text"], truncation=True, max_length=max_length)


train_dataset = imdb["train"].shuffle(seed=42)
# Validate on the dedicated "eval" split. Using "test" here would leak the
# held-out test set into model selection and make eval and test scores identical.
eval_dataset = imdb["eval"]
test_dataset = imdb["test"]

tokenized_train = train_dataset.map(preprocess_function, batched=True, remove_columns=["text"])
tokenized_eval = eval_dataset.map(preprocess_function, batched=True, remove_columns=["text"])
tokenized_test = test_dataset.map(preprocess_function, batched=True, remove_columns=["text"])

# Eval/test loaders are not shuffled: later code pairs their predictions with
# the dataset's "label" column in order.
train_dataloader = DataLoader(tokenized_train, shuffle=True, batch_size=batch_size, collate_fn=data_collator)
eval_dataloader = DataLoader(tokenized_eval, batch_size=batch_size, collate_fn=data_collator)
test_dataloader = DataLoader(tokenized_test, batch_size=batch_size, collate_fn=data_collator)

In [None]:
# Choose where tensors and modules will live: the first CUDA GPU if present,
# otherwise the CPU. (On Apple silicon, torch.device("mps") could be used when
# torch.backends.mps.is_available().)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## Designing the classification head

In [None]:
# Define the custom classification head.
class ClassificationHead(nn.Module):
    """Transformer encoder, then dropout and a linear layer on the first token.

    The encoder and the linear layer are moved to the module-level ``device``
    when the head is constructed.

    Args:
        transformer_model: pre-trained encoder exposing ``config.hidden_size``.
            It is called with ``input_ids``/``attention_mask`` keywords, and
            element ``[0]`` of its output is used as the per-token hidden states.
        num_classes: number of output logits.
    """

    def __init__(self, transformer_model, num_classes):
        super().__init__()
        hidden_size = transformer_model.config.hidden_size
        self.transformer_model = transformer_model.to(device)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_classes).to(device)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits for each sequence in the batch."""
        token_states = self.transformer_model(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]
        # Position 0 of every sequence: the [CLS] token for BERT-style tokenizers.
        # Assumes token_states is (batch, seq, hidden); TODO confirm for other encoders.
        cls_embedding = token_states[:, 0]
        return self.classifier(self.dropout(cls_embedding))

## Training the classification head

In [None]:
# Hyperparameters for training the classification head.
num_classes = 2        # binary sentiment (positive / negative)
learning_rate = 2e-5   # AdamW step size
num_epochs = 5         # full passes over the training split

In [None]:
# Create the classification model.
model = ClassificationHead(transformer_model, num_classes)

# Give AdamW only the parameters that can actually be updated. The encoder
# weights have requires_grad=False, so they never receive gradients. Passing them
# to the optimizer only makes it hold references it will always skip.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
# The model returns raw logits; CrossEntropyLoss applies log-softmax itself.
loss_fn = nn.CrossEntropyLoss()

In [None]:
from pycm import ConfusionMatrix


def train_epoch(model, train_dataloader, optimizer, loss_fn, device):
    """Run one optimisation pass over ``train_dataloader``.

    Each batch must be a mapping with ``input_ids``, ``attention_mask`` and
    ``labels`` tensors. The model is put in training mode and the optimizer is
    stepped once per batch.

    Returns:
        The mean of the per-batch losses, measured before each optimizer step.
    """
    model.train()
    running_loss = 0
    for batch in train_dataloader:
        ids, mask, targets = (
            batch[key].to(device) for key in ("input_ids", "attention_mask", "labels")
        )
        optimizer.zero_grad()
        batch_loss = loss_fn(model(ids, mask), targets)
        running_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return running_loss / len(train_dataloader)


def eval_epoch(model, eval_dataloader, loss_fn, device):
    """Evaluate ``model`` on ``eval_dataloader`` without updating it.

    Each batch must be a mapping with ``input_ids``, ``attention_mask`` and
    ``labels`` tensors.

    Returns:
        A ``(mean_loss, y_preds)`` tuple. ``mean_loss`` is the average per-batch
        loss, and ``y_preds`` is the argmax class index for every example, in
        dataloader order.
    """
    model.eval()
    eval_loss = 0.0
    y_preds = []
    # No gradients are needed during evaluation. Disabling autograd avoids
    # building a computation graph per batch, which saves memory and time.
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            output = model(input_ids, attention_mask)
            loss = loss_fn(output, labels)
            eval_loss += loss.item()
            y_preds.extend(output.argmax(dim=1).tolist())
    return eval_loss / len(eval_dataloader), y_preds


def test_model(model, test_dataloader, loss_fn, device):
    """Evaluate ``model`` on the test split and build its confusion matrix.

    Each batch must be a mapping with ``input_ids``, ``attention_mask`` and
    ``labels`` tensors.

    Returns:
        A ``(mean_loss, confusion_matrix)`` tuple. The pycm ``ConfusionMatrix``
        compares the ``"label"`` column of ``test_dataloader.dataset`` with the
        argmax predictions. This pairing is only valid if the dataloader does not
        shuffle, so that predictions come back in dataset order.
    """
    model.eval()
    test_loss = 0.0
    y_preds = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            output = model(input_ids, attention_mask)
            loss = loss_fn(output, labels)
            test_loss += loss.item()
            y_preds.extend(output.argmax(dim=1).tolist())
    test_loss /= len(test_dataloader)
    test_cm = ConfusionMatrix(test_dataloader.dataset["label"], y_preds, digit=5)
    return test_loss, test_cm


def train_model(
    model,
    train_dataloader,
    eval_dataloader,
    test_dataloader,
    optimizer,
    loss_fn,
    device,
    num_epochs,
):
    """Train for ``num_epochs`` epochs, validating after each, then test once.

    Progress is printed after every epoch, and the final test accuracy is
    printed at the end.

    Returns:
        A dict with one ``"epoch_<i>"`` entry (``i`` counted from 0) per epoch.
        Each entry holds ``training_loss``, ``validation_loss`` and
        ``validation_metrics`` (a pycm ``ConfusionMatrix``). The dict also has
        top-level ``"test_loss"`` and ``"test_metrics"`` entries.
    """
    stats = {}
    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_dataloader, optimizer, loss_fn, device)
        val_loss, val_preds = eval_epoch(model, eval_dataloader, loss_fn, device)
        val_cm = ConfusionMatrix(eval_dataloader.dataset["label"], val_preds, digit=5)

        stats[f"epoch_{epoch}"] = dict(
            training_loss=train_loss,
            validation_loss=val_loss,
            validation_metrics=val_cm,
        )

        print(
            f"Epoch = {epoch+1}/{num_epochs}\t Training Loss = {train_loss:.2f}\t "
            f"Validation Loss = {val_loss:.2f}\t Validation Accuracy = {val_cm.Overall_ACC:.2f}"
        )

    test_loss, test_cm = test_model(model, test_dataloader, loss_fn, device)
    stats.update(test_loss=test_loss, test_metrics=test_cm)

    print(f"\nTest Accuracy = {test_cm.Overall_ACC:.2f}")

    return stats

In [None]:
# Train the head, validating after every epoch, then evaluate once on the test split.
stats = train_model(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    device=device,
    num_epochs=num_epochs,
)

## Saving the results

In [None]:
import pickle
from pathlib import Path

# Persist the per-epoch and test statistics (including the pycm confusion
# matrices) so they can be analysed later without retraining.
results_path = Path("results") / "imdb_bert.pickle"
# open() raises FileNotFoundError if the target directory does not exist yet.
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("wb") as file:
    pickle.dump(stats, file)

# To reload the statistics later (only unpickle files you trust):
# with results_path.open("rb") as file:
#     stats = pickle.load(file)