# LLM Response Classifier - PyTorch Lightning Training

This notebook uses PyTorch Lightning to train a transformer-based model that, given a prompt and two LLM responses, predicts whether response A wins, response B wins, or the two tie.

## Features
- PyTorch Lightning for simplified training
- Multiple transformer model support (BERT, DistilBERT, RoBERTa, ELECTRA, DeBERTa, ALBERT)
- Automatic GPU/CPU detection
- Multi-GPU training support
- Checkpoint saving and resuming

## 1. Imports and Configuration

In [None]:
import os, time, sys
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from transformers import (BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer,
                          RobertaModel, RobertaTokenizer, ElectraModel, ElectraTokenizerFast,
                          DebertaModel, DebertaTokenizer, AlbertModel, AlbertTokenizer, AutoTokenizer)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# Add parent directory to path to import local modules
sys.path.append('..')
from core import ResponseDataset, ResponseScorer

# Set matmul precision for better performance on modern GPUs
torch.set_float32_matmul_precision('medium')

## 2. Model Registry

In [None]:
# Supported base models: Hugging Face checkpoint name -> (encoder class, tokenizer class).
# BASE_MODEL must be one of these keys. Note that the dataset cell unpacks only the
# encoder class from the pair and builds the tokenizer with AutoTokenizer instead.
MODEL_REGISTRY = {
    "bert-base-uncased": (BertModel, BertTokenizer),
    "distilbert-base-uncased": (DistilBertModel, DistilBertTokenizer),
    "roberta-base": (RobertaModel, RobertaTokenizer),
    "google/electra-base-discriminator": (ElectraModel, ElectraTokenizerFast),
    "microsoft/deberta-base": (DebertaModel, DebertaTokenizer),
    "albert-base-v2": (AlbertModel, AlbertTokenizer),
}

## 3. PyTorch Lightning Module

In [None]:
class LightningResponseScorer(pl.LightningModule):
    """Lightning wrapper that trains a ``ResponseScorer`` with cross-entropy loss.

    Args:
        model_class: Encoder class forwarded to ``ResponseScorer.from_pretrained``.
            It is excluded from the saved hyperparameters.
        base_model_name: Model name forwarded to ``ResponseScorer.from_pretrained``.
        weights_dir: Directory forwarded to ``ResponseScorer.from_pretrained``.
        lr: Learning rate handed to the AdamW optimizer.
        smoke_test: Forwarded to ``ResponseScorer.from_pretrained`` and kept on
            the module as ``self.smoke_test``.
    """

    # Batch keys fed to ``forward``, in positional order.
    _INPUT_KEYS = (
        "input_ids_a",
        "attention_mask_a",
        "input_ids_b",
        "attention_mask_b",
    )

    def __init__(self, model_class, base_model_name, weights_dir, lr=1e-5, smoke_test=False):
        super().__init__()
        # model_class is a class object, so it is kept out of the hyperparameters.
        self.save_hyperparameters(ignore=['model_class'])
        self.lr = lr
        self.smoke_test = smoke_test

        # Create a simple logging function for model initialization
        def log_fn(msg, rank_specific=False):
            print(msg)

        # Initialize the ResponseScorer model
        self.model = ResponseScorer.from_pretrained(
            model_class=model_class,
            base_model_name=base_model_name,
            weights_dir=weights_dir,
            smoke_test=smoke_test,
            rank=self.global_rank if hasattr(self, 'global_rank') else 0,
            log_fn=log_fn
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, ids_a, mask_a, ids_b, mask_b):
        """Delegate to the wrapped ``ResponseScorer`` and return its output."""
        return self.model(ids_a, mask_a, ids_b, mask_b)

    def _shared_step(self, batch):
        """Run one batch through the model and return ``(loss, accuracy)``.

        ``batch`` must provide the four ``_INPUT_KEYS`` entries plus ``"label"``.
        Accuracy is the fraction of rows whose argmax over the last logit
        dimension equals the label.
        """
        labels = batch["label"]
        logits = self(*(batch[key] for key in self._INPUT_KEYS))
        loss = self.loss_fn(logits, labels)
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute the training loss; log loss/accuracy per step and per epoch."""
        loss, acc = self._shared_step(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss; log epoch-level loss/accuracy."""
        loss, acc = self._shared_step(batch)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, sync_dist=True)

        return loss

    def test_step(self, batch, batch_idx):
        """Compute the test loss; log epoch-level loss/accuracy."""
        loss, acc = self._shared_step(batch)

        self.log('test_loss', loss, on_epoch=True, sync_dist=True)
        self.log('test_acc', acc, on_epoch=True, sync_dist=True)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters at ``self.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

## 4. Configuration Parameters

In [None]:
# Run settings. A smoke test shrinks the run to one epoch with batch size 1.
SMOKE_TEST = False  # Set to True for quick testing
BASE_MODEL = 'distilbert-base-uncased'  # Must be a key of MODEL_REGISTRY
TAG = ''  # Optional sub-directory used to organize runs
NUM_EPOCHS = 1 if SMOKE_TEST else 3
BATCH_SIZE = 1 if SMOKE_TEST else 32
LR = 1e-5
MAX_LEN = 512
NUM_WORKERS = 4

# Locations of the training data, cached weights, and checkpoints
DATA_PATH = '../data/train.csv'
WEIGHTS_DIR = os.path.join('../weights', BASE_MODEL)
ROOT_CKPT = '../checkpoints'

# Timestamp string of an earlier run to resume (e.g., '20231215_120000'),
# or None to start a new run
RESUME_TIMESTAMP = None

# Echo the effective settings, one per line
print("\n".join([
    "Configuration:",
    f"  Base Model: {BASE_MODEL}",
    f"  Epochs: {NUM_EPOCHS}",
    f"  Batch Size: {BATCH_SIZE}",
    f"  Learning Rate: {LR}",
    f"  Smoke Test: {SMOKE_TEST}",
]))

## 5. Setup Checkpoint Directory

In [None]:
# Reject unknown model names before touching the filesystem
if BASE_MODEL not in MODEL_REGISTRY:
    raise ValueError(f"Unsupported base_model '{BASE_MODEL}'. Available: {list(MODEL_REGISTRY.keys())}")

# A fresh run gets a new timestamp; a resumed run reuses the given one
TIMESTAMP = time.strftime("%Y%m%d_%H%M%S") if RESUME_TIMESTAMP is None else RESUME_TIMESTAMP
CKPT_DIR = os.path.join(ROOT_CKPT, TIMESTAMP, TAG)

# When resuming, the timestamp directory must already exist
if RESUME_TIMESTAMP is not None and not os.path.isdir(os.path.join(ROOT_CKPT, RESUME_TIMESTAMP)):
    raise ValueError(f"Checkpoint directory {os.path.join(ROOT_CKPT, RESUME_TIMESTAMP)} does not exist.")

for _needed_dir in (CKPT_DIR, WEIGHTS_DIR):
    os.makedirs(_needed_dir, exist_ok=True)

print(f"Checkpoint directory: {CKPT_DIR}")

## 6. Load and Prepare Dataset

In [None]:
# Builds the tokenizer, reads the training CSV, and splits it into
# train / validation / test datasets.
print("Loading dataset...")

# Load tokenizer
# Only the encoder class is taken from the registry; the registry's tokenizer
# class is discarded and AutoTokenizer resolves the tokenizer from BASE_MODEL.
model_class, _ = MODEL_REGISTRY[BASE_MODEL]
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load and split dataset
df = pd.read_csv(DATA_PATH)
print(f"Total samples: {len(df)}")

# 80% train; the remaining 20% is halved into validation and test (10% each).
# random_state is fixed so the split is the same on every run.
df_train, df_temp = train_test_split(df, test_size=0.2, random_state=42)
df_val, df_test = train_test_split(df_temp, test_size=0.5, random_state=42)

print(f"Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}")

# Create datasets
# ResponseDataset comes from the local `core` module; each split shares the
# same tokenizer and MAX_LEN.
train_dataset = ResponseDataset(df_train, tokenizer, MAX_LEN)
val_dataset = ResponseDataset(df_val, tokenizer, MAX_LEN)
test_dataset = ResponseDataset(df_test, tokenizer, MAX_LEN)

print("Dataset ready.")

## 7. Create DataLoaders

In [None]:
# Worker processes are used only when CUDA is available; on CPU all loading
# happens in the main process to avoid multiprocessing issues.
num_workers = NUM_WORKERS if torch.cuda.is_available() else 0


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the notebook-wide batch size and worker settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )


# Only the training split is shuffled
train_loader = _make_loader(train_dataset, True)
val_loader = _make_loader(val_dataset, False)
test_loader = _make_loader(test_dataset, False)

print(f"DataLoaders created with {num_workers} workers")

## 8. Initialize Model

In [None]:
# Instantiates the Lightning module around the chosen encoder class and
# reports its parameter count.
print("Initializing model...")

model = LightningResponseScorer(
    model_class=model_class,
    base_model_name=BASE_MODEL,
    weights_dir=WEIGHTS_DIR,
    lr=LR,
    smoke_test=SMOKE_TEST
)

# Counts every parameter registered on the Lightning module.
total_params = sum(p.numel() for p in model.parameters())
print(f"ResponseScorer has {total_params:,} parameters")

## 9. Setup Callbacks and Logger

In [None]:
# Setup callbacks: write one checkpoint per epoch into CKPT_DIR and keep them all
checkpoint_callback = ModelCheckpoint(
    dirpath=CKPT_DIR,
    filename='epoch{epoch:02d}',
    # By default Lightning expands "{epoch:02d}" to "epoch={epoch:02d}", which
    # would name the files "epochepoch=00.ckpt". Disabling the auto-inserted
    # metric name yields the intended "epoch00.ckpt", whose numeric suffix the
    # resume logic sorts on.
    auto_insert_metric_name=False,
    save_top_k=-1,  # Save all checkpoints
    every_n_epochs=1,
    verbose=True
)

# Setup logger: metrics are written as CSV under CKPT_DIR/lightning_logs
csv_logger = CSVLogger(save_dir=CKPT_DIR, name='lightning_logs')

print("Callbacks and logger configured")

## 10. Determine Checkpoint for Resuming

In [None]:
# Determine checkpoint path for resuming
import re

# Accepts "epoch03.ckpt" as well as "epochepoch=03.ckpt" (the name Lightning
# produces when it auto-inserts the metric name into the filename template),
# each with an optional "-vN" suffix that Lightning adds to avoid overwriting
# an existing file. Other "*.ckpt" files are ignored instead of crashing int().
_CKPT_NAME = re.compile(r"epoch(?:epoch=)?(\d+)(?:-v(\d+))?\.ckpt")


def _ckpt_order(filename):
    """Return ``(epoch, version)`` parsed from a matching checkpoint file name."""
    epoch, version = _CKPT_NAME.fullmatch(filename).groups()
    return int(epoch), int(version or 0)


# ckpt_path stays None for a fresh run or when no usable checkpoint exists
ckpt_path = None
if RESUME_TIMESTAMP is not None:
    ckpt_files = sorted(
        (f for f in os.listdir(CKPT_DIR) if _CKPT_NAME.fullmatch(f)),
        key=_ckpt_order,
    )
    if ckpt_files:
        # Highest epoch (then highest version) is the most recent checkpoint
        latest_ckpt = ckpt_files[-1]
        ckpt_path = os.path.join(CKPT_DIR, latest_ckpt)
        print(f"Resuming from {latest_ckpt}")
    else:
        print("Resume timestamp given, but no checkpoints found. Starting from scratch.")
else:
    print("Starting fresh training run")

## 11. Setup Trainer

In [None]:
# Determine accelerator and devices
if torch.cuda.is_available():
    accelerator = "gpu"
    devices = torch.cuda.device_count()  # Use all available GPUs
    print(f"Using GPU acceleration with {devices} device(s)")
else:
    accelerator = "cpu"
    devices = 1  # This notebook trains on a single CPU process
    print("Using CPU")

# Let Lightning pick the strategy. With one device it runs single-device
# training. With several GPUs it uses DDP when run as a script and a
# notebook-compatible DDP variant in an interactive session; forcing plain
# "ddp" is rejected by Lightning inside Jupyter.
strategy = "auto"

# Setup trainer
trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS,
    devices=devices,
    accelerator=accelerator,
    strategy=strategy,
    callbacks=[checkpoint_callback],
    logger=csv_logger,
    enable_progress_bar=True,
    enable_model_summary=True,
    deterministic=False,
    fast_dev_run=SMOKE_TEST  # smoke test: run only a minimal pass through the loops
)

print("Trainer configured")

## 12. Train the Model

In [None]:
# Runs the training/validation loop. ckpt_path is None for a fresh run;
# otherwise it is the checkpoint file selected in the resume cell.
print("Starting training...")
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
print("Training complete!")

## 13. Test the Model

In [None]:
# Evaluates the in-memory model on the held-out split of train.csv
# (test_loader), not on the unlabeled ../data/test.csv used for the submission.
print("Testing...")
test_results = trainer.test(model, test_loader)
print("Testing complete!")
print(f"\nTest Results: {test_results}")

## 14. View Training Metrics

In [None]:
# Load and display training metrics
import glob

# A resumed run can leave several metrics.csv files under lightning_logs
# (presumably one per logger version directory), and glob returns them in
# no particular order. Sort oldest -> newest so the latest run is the one shown.
metrics_file = sorted(
    glob.glob(os.path.join(CKPT_DIR, 'lightning_logs', '**', 'metrics.csv'), recursive=True),
    key=os.path.getmtime,
)
if metrics_file:
    metrics_df = pd.read_csv(metrics_file[-1])
    print("\nTraining Metrics:")
    display(metrics_df)
else:
    print("No metrics file found")

## 15. Plot Training History (Optional)

In [None]:
import matplotlib.pyplot as plt


def _plot_epoch_metric(ax, frame, column, label, marker):
    """Plot ``column`` against ``epoch`` using only rows where it was logged.

    The metrics CSV holds one row per logging event, so any given metric is
    NaN on most rows. Dropping NaNs from ``epoch`` and from the metric
    independently yields x/y series of different lengths, which matplotlib
    rejects; selecting rows by the metric keeps the two aligned. A column
    that was never logged is skipped.
    """
    if column not in frame.columns:
        return
    logged = frame.dropna(subset=[column])
    ax.plot(logged['epoch'], logged[column], label=label, marker=marker)


if metrics_file:
    # Side-by-side panels: loss on the left, accuracy on the right
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Loss plot
    _plot_epoch_metric(axes[0], metrics_df, 'train_loss_epoch', 'Train Loss', 'o')
    _plot_epoch_metric(axes[0], metrics_df, 'val_loss', 'Val Loss', 's')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Accuracy plot
    _plot_epoch_metric(axes[1], metrics_df, 'train_acc_epoch', 'Train Acc', 'o')
    _plot_epoch_metric(axes[1], metrics_df, 'val_acc', 'Val Acc', 's')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(CKPT_DIR, 'training_history.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Plot saved to {os.path.join(CKPT_DIR, 'training_history.png')}")

## Summary

This notebook provides an interactive way to train the LLM response classifier using PyTorch Lightning. You can:

- Adjust hyperparameters in cell 4
- Run cells individually to inspect intermediate results
- Visualize training progress
- Resume from checkpoints
- Use multiple GPUs automatically

All checkpoints and logs are saved to the checkpoint directory for later analysis.

## 16. Generate Submission File from Test Data

This cell loads the test.csv file, runs inference with the trained model, and creates a submission.csv file with probability predictions.

In [None]:
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

# Label-free dataset used to build the submission file
class TestDataset(Dataset):
    """Tokenized prompt/response pairs for inference.

    Rows of ``df`` must provide ``id``, ``prompt``, ``response_a`` and
    ``response_b``. Each item carries the row id plus input ids and attention
    masks for the "a" and "b" texts, padded/truncated to ``max_len``.
    """

    def __init__(self, df, tokenizer, max_len):
        self.df, self.tokenizer, self.max_len = df, tokenizer, max_len

    def __len__(self):
        return len(self.df)

    def _encode(self, text):
        """Tokenize ``text`` to a fixed length and return the tokenizer output."""
        return self.tokenizer(
            text,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        # The prompt is prepended to each response with no separator in between.
        enc_a = self._encode(record["prompt"] + record["response_a"])
        enc_b = self._encode(record["prompt"] + record["response_b"])

        # squeeze(0) drops the leading dimension of the "pt" tensors returned
        # by the tokenizer.
        item = {"id": record["id"]}
        item["input_ids_a"] = enc_a["input_ids"].squeeze(0)
        item["attention_mask_a"] = enc_a["attention_mask"].squeeze(0)
        item["input_ids_b"] = enc_b["input_ids"].squeeze(0)
        item["attention_mask_b"] = enc_b["attention_mask"].squeeze(0)
        return item

# Reads the unlabeled test CSV, runs the trained model over it, and writes
# per-class probabilities to ../submission.csv.
print("Loading test data...")
test_df = pd.read_csv('../data/test.csv')
print(f"Test samples: {len(test_df)}")

# Create test dataset and dataloader
test_dataset_submission = TestDataset(test_df, tokenizer, MAX_LEN)
test_loader_submission = DataLoader(
    test_dataset_submission,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep the rows in file order
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available()
)

# Set model to evaluation mode
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Run inference
predictions = []
ids = []

print("Running inference on test data...")
with torch.no_grad():
    for batch in tqdm(test_loader_submission, desc="Generating predictions"):
        batch_ids = batch["id"]
        # The default collate function stacks numeric ids into a tensor. Convert
        # them to plain Python values so the CSV holds e.g. 123 rather than
        # "tensor(123)". Non-numeric ids arrive as a list and pass through as is.
        if torch.is_tensor(batch_ids):
            batch_ids = batch_ids.tolist()

        ids_a = batch["input_ids_a"].to(device)
        mask_a = batch["attention_mask_a"].to(device)
        ids_b = batch["input_ids_b"].to(device)
        mask_b = batch["attention_mask_b"].to(device)

        # Get logits from model
        logits = model(ids_a, mask_a, ids_b, mask_b)

        # Convert logits to probabilities using softmax, then move the whole
        # batch to the CPU once instead of once per value
        probs = F.softmax(logits, dim=-1).cpu().tolist()

        # Store predictions
        for batch_id, row_probs in zip(batch_ids, probs):
            ids.append(batch_id)
            # Class order used here: [model_b_wins, tie, model_a_wins]
            # Index 0 = model_b wins, Index 1 = tie, Index 2 = model_a wins
            # NOTE(review): this must match the label encoding used by
            # ResponseDataset during training - confirm against core.
            predictions.append({
                'id': batch_id,
                'winner_model_a': row_probs[2],  # Probability that model_a wins
                'winner_model_b': row_probs[0],  # Probability that model_b wins
                'winner_tie': row_probs[1]       # Probability of tie
            })

# Create submission dataframe
submission_df = pd.DataFrame(predictions)

# Save to CSV
submission_path = '../submission.csv'
submission_df.to_csv(submission_path, index=False)

print(f"\nSubmission file saved to: {submission_path}")
print(f"Shape: {submission_df.shape}")
print(f"\nFirst few rows:")
print(submission_df.head())

# Verify probabilities sum to 1
print(f"\nVerifying probabilities sum to 1.0:")
prob_sums = submission_df[['winner_model_a', 'winner_model_b', 'winner_tie']].sum(axis=1)
print(f"Min sum: {prob_sums.min():.6f}")
print(f"Max sum: {prob_sums.max():.6f}")
print(f"Mean sum: {prob_sums.mean():.6f}")