In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

# Load the IMDB dataset
sst5 = load_dataset("SetFit/sst5")

Downloading readme:   0%|          | 0.00/421 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/171k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/343k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8544 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1101 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2210 [00:00<?, ? examples/s]

In [5]:
# Load the pre-trained transformer model and tokenizer
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
transformer_model = AutoModel.from_pretrained(model_name)

In [6]:
# Freeze the pre-trained model parameters
for param in transformer_model.parameters():
    param.requires_grad = False

In [33]:
# Set up the data collator and dataloaders
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

max_length = 512
batch_size = 32

def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True, max_length=max_length)


small_train_dataset = (sst5["train"].shuffle(seed=42).select([i for i in list(range(3000))]))
small_eval_dataset = sst5["test"].shuffle(seed=42).select([i for i in list(range(300))])
small_test_dataset = sst5["test"].select([i for i in list(range(300, 600))])

tokenized_train = small_train_dataset.map(preprocess_function, batched=True, remove_columns=["text", "label_text"])
tokenized_eval = small_eval_dataset.map(preprocess_function, batched=True, remove_columns=["text", "label_text"])
tokenized_test = small_test_dataset.map(preprocess_function, batched=True, remove_columns=["text", "label_text"])

train_dataloader = DataLoader(tokenized_train, shuffle=True, batch_size=batch_size, collate_fn=data_collator)
eval_dataloader = DataLoader(tokenized_eval, batch_size=batch_size, collate_fn=data_collator)
test_dataloader = DataLoader(tokenized_test, batch_size=batch_size, collate_fn=data_collator)

In [9]:
# Move the model to the GPU (if available)
device = "mps" if torch.backends.mps.is_available() else "cpu"
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


In [12]:
# Define the custom classification head
class ClassificationHead(nn.Module):
    def __init__(self, transformer_model, num_classes):
        super().__init__()
        self.transformer_model = transformer_model.to(device)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(transformer_model.config.hidden_size, num_classes).to(device)

    def forward(self, input_ids, attention_mask):
        output = self.transformer_model(input_ids=input_ids, attention_mask=attention_mask)[0]
        output = self.dropout(output[:, 0])  # Take the CLS token representation
        output = self.classifier(output)
        return output

In [13]:
# Set hyperparameters
num_classes = 5
learning_rate = 2e-5
num_epochs = 3

In [14]:
# Create the classification model
model = ClassificationHead(transformer_model, num_classes)

# Set up the optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [34]:
# Train
for epoch in range(num_epochs):
    model.train()
    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        output = model(input_ids, attention_mask)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

    model.eval()
    eval_loss = 0
    correct_predictions = 0
    total_samples = 0
    for batch in eval_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        output = model(input_ids, attention_mask)
        loss = loss_fn(output, labels)
        eval_loss += loss.item()
        correct_predictions += (output.argmax(dim=1) == labels).sum().item()
        total_samples += labels.size(0)

    # Show metrics
    accuracy = correct_predictions / total_samples
    print(
        f"Epoch {epoch+1}/{num_epochs}, Eval Loss: {eval_loss/len(eval_dataloader)}, Accuracy: {accuracy:.4f}"
    )

Epoch 1/3, Eval Loss: 1.5707220792770387, Accuracy: 0.3033
Epoch 2/3, Eval Loss: 1.5620791792869568, Accuracy: 0.3133
Epoch 3/3, Eval Loss: 1.5561365485191345, Accuracy: 0.3033
