## Fine-Tuning a Pre-Trained Transformer for Diverse Classification

This notebook takes the pre-trained encoder weights (`pretrained_transformer_encoder.pth`) and fine-tunes them on a **combined dataset** to create a versatile classifier.

**Our Goal:** Build a single model that can perform both:
1.  **Topic Classification** (using AG News)
2.  **Sentiment Analysis** (using IMDB movie reviews)

In [None]:
# Cell 1: Imports and Device Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import re
from collections import Counter
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import json

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Cell 2: Model Architecture 

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position signals to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the buffer ``pe``. Even feature columns hold sines and odd feature
    columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are
            supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle table to the number of cosine slots available.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, sequence, feature)``: the table is sliced
        along dimension 1 and broadcast over dimension 0.
        """
        return x + self.pe[:, :x.size(1), :].to(x.device)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly
            because each head works on ``d_model // num_heads`` features.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Fail early with a clear message instead of an opaque ``view`` error
        # the first time ``forward`` runs.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape ``(B, L, d_model)`` into ``(B, num_heads, L, d_k)``."""
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Attend over ``x`` and return ``(output, attention_weights)``.

        Args:
            x: Tensor of shape ``(B, L, d_model)``.
            mask: Optional ``(B, L)`` tensor; key positions where it equals 0
                (or ``False``) are excluded from attention.

        Returns:
            A tuple of the projected output with shape ``(B, L, d_model)`` and
            the attention weights with shape ``(B, num_heads, L, L)``.
        """
        B, L, _ = x.size()
        Q = self._split_heads(self.W_Q(x), B, L)
        K = self._split_heads(self.W_K(x), B, L)
        V = self._split_heads(self.W_V(x), B, L)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # (B, L) -> (B, 1, 1, L) so it broadcasts over heads and queries.
            mask = mask.unsqueeze(1).unsqueeze(2)
            # Use the dtype's most negative finite value (not -inf) so a fully
            # masked row still yields finite softmax output.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
        return self.W_O(out), attn

class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention then a feed-forward network.

    Each sub-layer is wrapped in a residual connection, and LayerNorm is
    applied after the residual addition.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        feed_forward_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the attention layer."""
        # The attention weights are returned by the attention layer but are
        # not needed here.
        attended, _weights = self.mha(x, mask)
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.ff(hidden))

class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` encoder blocks applied one after another."""

    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerEncoderBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        )

    def forward(self, x, mask=None):
        """Pass ``x`` through every block in order, giving each the same ``mask``."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

class EncoderClassifier(nn.Module):
    """Token embedding + positional encoding + encoder stack + linear head.

    The class logits are computed from the encoder output at sequence
    position 0, which is where the dataset code in this notebook places the
    ``[CLS]`` token.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, num_classes):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model)
        self.encoder = TransformerEncoder(d_model, num_heads, d_ff, num_layers)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, input_ids, mask=None):
        """Return class logits for a batch of token-id sequences.

        ``mask`` is passed through unchanged to the encoder.
        """
        token_embeddings = self.embedding_layer(input_ids)
        encoded = self.encoder(self.pos_enc(token_embeddings), mask)
        first_position = encoded[:, 0, :]
        return self.classifier(first_position)

In [None]:
# Cell 3: Load, Combine, and Unify Datasets
import random

print("Loading datasets...")
ag_news_ds = load_dataset("ag_news")
imdb_ds = load_dataset("imdb")

# Unified label space: ids 0-3 are the AG News topics and ids 4-5 are the
# IMDB sentiments.
# NOTE(review): assumes AG News labels are 0-3 in this order and IMDB uses
# 0 = negative, 1 = positive -- confirm against the dataset cards.
unified_class_map = {
    0: "World News",
    1: "Sports News",
    2: "Business News",
    3: "Sci/Tech News",
    4: "Negative Review",
    5: "Positive Review"
}
NUM_CLASSES = len(unified_class_map)

# AG News labels are used as-is; IMDB labels are shifted by 4 so the two
# datasets occupy disjoint ranges of the unified label space.
combined_data = [
    {'text': item['text'], 'label': item['label']}
    for item in ag_news_ds['train']
]
combined_data.extend(
    {'text': item['text'], 'label': item['label'] + 4}
    for item in imdb_ds['train']
)

# Shuffle with a fixed seed. An unseeded shuffle here would change the input
# order on every run and therefore defeat ``random_state`` in the split below.
random.Random(42).shuffle(combined_data)

train_texts, val_texts, train_labels, val_labels = train_test_split(
    [item['text'] for item in combined_data],
    [item['label'] for item in combined_data],
    test_size=0.1,
    random_state=42
)

train_split = [{'text': t, 'label': l} for t, l in zip(train_texts, train_labels)]
val_split = [{'text': t, 'label': l} for t, l in zip(val_texts, val_labels)]

print(f"Created a diverse dataset with {len(train_split)} training and {len(val_split)} validation examples.")
print(f"Total number of classes: {NUM_CLASSES}")

In [None]:
# Cell 4: Vocabulary and Dataset Class
print("Building vocabulary from combined data...")

def simple_tokenizer(text):
    """Lowercase ``text`` and return its runs of word characters, in order."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r'\b\w+\b', lowered)]

# Count every token occurrence across the training texts only.
word_counts = Counter(
    token
    for example in tqdm(train_split)
    for token in simple_tokenizer(example['text'])
)

# Ids 0-2 are reserved for the special tokens; every word seen at least
# ``min_freq`` times gets the next free id, in first-seen order.
min_freq = 5
vocab = {"<pad>": 0, "<unk>": 1, "[CLS]": 2}
frequent_words = [word for word, count in word_counts.items() if count >= min_freq]
for index, word in enumerate(frequent_words, start=len(vocab)):
    vocab[word] = index
offset = len(vocab)

VOCAB_SIZE = len(vocab)
print(f"Vocabulary size: {VOCAB_SIZE}")

class ClassificationDataset(Dataset):
    """Map-style dataset that turns ``{'text', 'label'}`` records into tensors.

    Each text is tokenized, prefixed with ``[CLS]``, mapped to vocabulary ids
    (unknown tokens become ``<unk>``) and then truncated or right-padded with
    ``<pad>`` to exactly ``max_seq_len`` ids.
    """

    def __init__(self, data, vocab, max_seq_len):
        self.tokenizer = simple_tokenizer
        self.max_seq_len = max_seq_len
        self.vocab = vocab
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{"input_ids": LongTensor, "label": LongTensor}`` for record ``idx``."""
        record = self.data[idx]
        text, label = record['text'], record['label']

        unk_id = self.vocab["<unk>"]
        words = ['[CLS]', *self.tokenizer(text)]
        ids = [self.vocab.get(word, unk_id) for word in words]

        # Truncate first (a no-op for short inputs), then pad any shortfall.
        ids = ids[:self.max_seq_len]
        shortfall = self.max_seq_len - len(ids)
        if shortfall > 0:
            ids.extend([self.vocab["<pad>"]] * shortfall)

        input_ids = torch.tensor(ids, dtype=torch.long)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return {"input_ids": input_ids, "label": label_tensor}

In [None]:
# Cell 5: Hyperparameters & Instantiation



# --- Architecture ---
# NOTE(review): presumably these must match the configuration used when
# 'pretrained_transformer_encoder.pth' was produced, otherwise
# load_state_dict below will fail -- confirm against the pre-training run.
D_MODEL = 256
NUM_HEADS = 8
D_FF = 1024
NUM_ENCODER_LAYERS = 4

# --- Data / optimisation ---
MAX_SEQ_LEN = 256
BATCH_SIZE = 32
LEARNING_RATE = 2e-5
EPOCHS = 3

model_config = dict(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    d_ff=D_FF,
    num_layers=NUM_ENCODER_LAYERS,
    num_classes=NUM_CLASSES,
)
model = EncoderClassifier(**model_config).to(device)

# Only the encoder stack receives pre-trained weights; the embedding layer and
# the classifier head keep their fresh initialisation.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
ENCODER_WEIGHTS_PATH = 'pretrained_transformer_encoder.pth'
print(f"Loading pre-trained encoder weights from {ENCODER_WEIGHTS_PATH}")
encoder_state = torch.load(ENCODER_WEIGHTS_PATH, map_location=device)
model.encoder.load_state_dict(encoder_state)
print("✅ Weights loaded successfully!")

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataset = ClassificationDataset(train_split, vocab, MAX_SEQ_LEN)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataset = ClassificationDataset(val_split, vocab, MAX_SEQ_LEN)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Cell 6: Training and Evaluation Loops

def train_epoch(model, dataloader, loss_fn, optimizer, device, pad_token_id=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module called as ``model(input_ids, mask=padding_mask)`` and
            returning per-class logits.
        dataloader: Iterable of batches with ``"input_ids"`` and ``"label"``
            tensors.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
            The reported average assumes it returns the batch *mean*, as
            ``nn.CrossEntropyLoss()`` does by default.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        pad_token_id: Id treated as padding when building the attention mask.
            Defaults to ``vocab["<pad>"]`` from the notebook's global ``vocab``.

    Returns:
        ``(avg_loss, accuracy)`` averaged over the samples actually processed,
        or ``(0.0, 0.0)`` if the dataloader yields no batches.
    """
    if pad_token_id is None:
        # Resolve the global lookup once instead of on every batch.
        pad_token_id = vocab["<pad>"]

    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch["input_ids"].to(device)
        labels = batch["label"].to(device)
        padding_mask = input_ids != pad_token_id

        logits = model(input_ids, mask=padding_mask)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        correct += (torch.argmax(logits, dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, correct / seen

def evaluate(model, dataloader, loss_fn, device, pad_token_id=None):
    """Compute loss and accuracy of ``model`` on ``dataloader`` without gradients.

    Args:
        model: Module called as ``model(input_ids, mask=padding_mask)`` and
            returning per-class logits.
        dataloader: Iterable of batches with ``"input_ids"`` and ``"label"``
            tensors.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
            The reported average assumes it returns the batch *mean*, as
            ``nn.CrossEntropyLoss()`` does by default.
        device: Device the batch tensors are moved to.
        pad_token_id: Id treated as padding when building the attention mask.
            Defaults to ``vocab["<pad>"]`` from the notebook's global ``vocab``.

    Returns:
        ``(avg_loss, accuracy)`` averaged over the samples actually processed,
        or ``(0.0, 0.0)`` if the dataloader yields no batches.
    """
    if pad_token_id is None:
        # Resolve the global lookup once instead of on every batch.
        pad_token_id = vocab["<pad>"]

    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            labels = batch["label"].to(device)
            padding_mask = input_ids != pad_token_id

            logits = model(input_ids, mask=padding_mask)
            loss = loss_fn(logits, labels)

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (torch.argmax(logits, dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return loss_sum / seen, correct / seen

In [None]:
# Cell 7 (REVISED): Main Fine-Tuning Execution with Freezing

# One warm-up epoch with the encoder frozen, then EPOCHS epochs with everything
# trainable. The total is derived from EPOCHS so the progress headers stay
# correct when that hyperparameter changes.
total_epochs = EPOCHS + 1

print("--- Stage 1: Training the Classifier Head (Encoder Frozen) ---")

# Only ``model.encoder`` is frozen here; the embedding layer and the classifier
# head both remain trainable during this stage.
model.encoder.requires_grad_(False)

print(f"\n--- Epoch 1/{total_epochs} ---")
train_loss, train_acc = train_epoch(model, train_loader, loss_fn, optimizer, device)
print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
val_loss, val_acc = evaluate(model, val_loader, loss_fn, device)
print(f"Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.4f}")


print("\n--- Stage 2: Fine-Tuning the Full Model (Encoder Unfrozen) ---")

model.encoder.requires_grad_(True)

for epoch in range(1, EPOCHS + 1):
    # Stage 1 already consumed epoch number 1, hence the +1 in the header.
    print(f"\n--- Epoch {epoch + 1}/{total_epochs} ---")
    train_loss, train_acc = train_epoch(model, train_loader, loss_fn, optimizer, device)
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    val_loss, val_acc = evaluate(model, val_loader, loss_fn, device)
    print(f"Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.4f}")

In [None]:
# Cell 8: Inference with the Diverse Classifier

def predict_category(text, model, vocab, class_map, max_len, device):
    """Classify a single string and return its human-readable class name.

    The text is tokenized, prefixed with ``[CLS]``, mapped to ids and
    truncated or padded to ``max_len`` before one forward pass. The arg-max
    class index is looked up in ``class_map``. The model is switched to eval
    mode as a side effect.
    """
    model.eval()

    unk_id = vocab["<unk>"]
    ids = [vocab.get(tok, unk_id) for tok in ['[CLS]', *simple_tokenizer(text)]]

    # Truncate first (a no-op for short inputs), then pad any shortfall.
    ids = ids[:max_len]
    shortfall = max_len - len(ids)
    if shortfall > 0:
        ids.extend([vocab["<pad>"]] * shortfall)

    # Wrap in an outer list to form a batch of size one.
    batch = torch.tensor([ids], dtype=torch.long).to(device)
    attention_mask = (batch != vocab["<pad>"]).to(device)

    with torch.no_grad():
        logits = model(batch, mask=attention_mask)

    best_class = logits.argmax(dim=1).item()
    return class_map[best_class]

# One example for a news topic and one for each sentiment polarity.
news_headline = "The space agency announced a new mission to Mars to search for signs of ancient life."
movie_review_1 = "The movie was absolutely fantastic! The acting was superb and the plot was thrilling."
movie_review_2 = "I was so bored throughout the entire film. It was a complete waste of time."

for sample_text in (news_headline, movie_review_1, movie_review_2):
    predicted = predict_category(sample_text, model, vocab, unified_class_map, MAX_SEQ_LEN, device)
    print(f"'{sample_text}' -> Prediction: {predicted}")

In [None]:
# Cell 9: Save the Fine-Tuned Model and Vocabulary

# Output locations for the two artifacts needed to reuse the classifier:
# the model weights and the token-to-id vocabulary they were trained with.
FINETUNED_MODEL_PATH = 'finetuned_diverse_classifier.pth'
FINETUNED_VOCAB_PATH = 'finetuned_vocab.json'

# Save only the state dict (parameters and buffers), not the module object.
torch.save(model.state_dict(), FINETUNED_MODEL_PATH)
print(f"Fine-tuned model saved to {FINETUNED_MODEL_PATH}")

with open(FINETUNED_VOCAB_PATH, 'w') as f:
    f.write(json.dumps(vocab, indent=4))
print(f"Fine-tuned vocabulary saved to {FINETUNED_VOCAB_PATH}")