# Lab 2.5.4: Fine-Tuning with the Trainer API

**Module:** 2.5 - Hugging Face Ecosystem  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate-Advanced)

---

## Learning Objectives

By the end of this lab, you will:
- [ ] Configure TrainingArguments for optimal DGX Spark training
- [ ] Implement custom metrics for evaluation
- [ ] Use callbacks for monitoring training
- [ ] Fine-tune a model for text classification
- [ ] Evaluate and save your trained model

---

## Prerequisites

- Completed: Labs 2.5.1 through 2.5.3
- Knowledge of: Training loops, loss functions, optimizers

---

## Real-World Context

**From General to Specialist**: Imagine you run a hospital. BERT is like a general doctor who knows a bit about everything, but you need a specialist who understands medical terminology, diagnoses, and patient sentiment.

**Fine-tuning** teaches this general doctor to become a specialist - using YOUR data for YOUR specific task. The Hugging Face **Trainer** makes this process as simple as:

```python
trainer = Trainer(model=model, args=args, train_dataset=train_data, eval_dataset=eval_data)
trainer.train()  # That's it!
```

---

## ELI5: What is Fine-Tuning?

> **Imagine you're learning to play tennis...**
>
> Pre-training (what BERT learned): General coordination, how to hold a racket, basic movement
>
> Fine-tuning (what we'll do): Practice YOUR favorite shots, learn YOUR opponent's weaknesses
>
> The key insight: You don't start from scratch! You keep all the general knowledge and just add specialized skills.
>
> **In AI terms:**
> - Pre-trained BERT knows language structure, grammar, context
> - We fine-tune it on sentiment data so it learns "words like 'amazing' = positive"
> - Training is FAST because we only adjust weights slightly, not learn everything from scratch

---

## Part 1: Setup

In [None]:
import torch
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    EarlyStoppingCallback
)
import evaluate
import time
import gc

# Check environment
print("Environment Check")
print("=" * 50)
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

### 1.1 Load and Prepare Dataset

In [None]:
# Load IMDB dataset
print("Loading IMDB dataset...")
dataset = load_dataset("imdb")

# Create validation split from training data
split = dataset['train'].train_test_split(test_size=0.1, seed=42)

train_dataset = split['train']
eval_dataset = split['test']
test_dataset = dataset['test']

print(f"\nDataset splits:")
print(f"  Train: {len(train_dataset):,}")
print(f"  Eval:  {len(eval_dataset):,}")
print(f"  Test:  {len(test_dataset):,}")

# Check label balance
from collections import Counter
print(f"\nLabel distribution (train): {Counter(train_dataset['label'])}")

In [None]:
# Load tokenizer and model
model_name = "distilbert-base-uncased"

print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
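model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    # Keep the default FP32 weights: bf16=True in TrainingArguments applies
    # BF16 autocast, while the optimizer still updates full-precision weights
)

# Check model size
num_params = sum(p.numel() for p in model.parameters())
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Model parameters: {num_params:,}")
print(f"Estimated size: {param_bytes / 1e9:.2f} GB ({next(model.parameters()).dtype})")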
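
The parameter count covers only the weights. As a rough sketch of what full fine-tuning needs, the estimate below adds gradients and AdamW state on top. The byte counts and the parameter count are assumptions (FP32 weights and gradients, two FP32 optimizer moments per parameter, roughly 67 million parameters), and activations are not included.

```python
# Rough training-memory estimate for full fine-tuning with AdamW.
# All byte counts are assumptions for a standard mixed-precision setup.

def training_memory_gb(num_params, weight_bytes=4, grad_bytes=4, optimizer_bytes=8):
    """Return approximate GB needed for weights, gradients, and Adam state."""
    bytes_per_param = weight_bytes + grad_bytes + optimizer_bytes
    return num_params * bytes_per_param / 1e9

# Hypothetical parameter count, close to DistilBERT with a 2-class head
approx_params = 67_000_000
estimate = training_memory_gb(approx_params)
print(f"Weights + grads + optimizer state: {estimate:.2f} GB")
```

Activation memory grows with batch size and sequence length, so treat this figure as a floor rather than a budget.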

In [None]:
# Tokenize datasets
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=256
    )

print("Tokenizing datasets...")
tokenized_train = train_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=['text']
)

tokenized_eval = eval_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=['text']
)

tokenized_test = test_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=['text']
)

# Rename label to labels (the argument name the model's forward() expects)
tokenized_train = tokenized_train.rename_column("label", "labels")
tokenized_eval = tokenized_eval.rename_column("label", "labels")
tokenized_test = tokenized_test.rename_column("label", "labels")

print("\nTokenized dataset columns:", tokenized_train.column_names)
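
To make `padding='max_length'` and `truncation=True` concrete, here is a minimal pure-Python sketch of what those options do to one sequence of token ids. The helper `pad_or_truncate` is hypothetical and the ids are illustrative. The real tokenizer also manages special tokens when it truncates, which this sketch ignores, and a pad id of 0 is assumed.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Clip to max_length, then right-pad; return (input_ids, attention_mask)."""
    clipped = token_ids[:max_length]
    num_pad = max_length - len(clipped)
    input_ids = clipped + [pad_id] * num_pad
    # 1 marks a real token, 0 marks padding the model should ignore
    attention_mask = [1] * len(clipped) + [0] * num_pad
    return input_ids, attention_mask

# Illustrative ids only: one short sequence and one that is too long
short_ids, short_mask = pad_or_truncate([101, 2023, 3185, 102], max_length=8)
long_ids, long_mask = pad_or_truncate(list(range(1, 13)), max_length=8)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Every example ends up with the same length, which is simple but means short reviews carry padding through every forward pass.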

---

## Part 2: Understanding TrainingArguments

This is where you configure everything about your training run.

In [None]:
# Let's explore TrainingArguments step by step
print("TrainingArguments - Key Parameters")
print("=" * 60)

config_explanation = """
OUTPUT:
  output_dir         : Where to save checkpoints and results
  overwrite_output_dir: Overwrite existing output directory

TRAINING DURATION:
  num_train_epochs   : Number of training epochs (full passes through data)
  max_steps          : Override epochs with exact step count (-1 to disable)

BATCH SIZE:
  per_device_train_batch_size: Batch size per GPU for training
  per_device_eval_batch_size : Batch size per GPU for evaluation
  gradient_accumulation_steps: Accumulate gradients over N steps (simulates larger batch)

OPTIMIZATION:
  learning_rate      : Initial learning rate
  weight_decay       : Weight decay strength (decoupled decay in AdamW, similar in spirit to L2)
  warmup_steps       : Steps for learning rate warmup (or warmup_ratio)
  lr_scheduler_type  : "linear", "cosine", "polynomial", etc.

EVALUATION & SAVING:
  eval_strategy      : "no", "steps", "epoch" - when to evaluate
  save_strategy      : "no", "steps", "epoch" - when to save
  load_best_model_at_end: Load best checkpoint at end of training
  metric_for_best_model : Which metric determines "best"

DGX SPARK OPTIMIZATION:
  bf16               : Use bfloat16 mixed precision (True for Blackwell GPU!)
  dataloader_num_workers: Parallel data loading
  dataloader_pin_memory : Pin memory for faster GPU transfer
"""
print(config_explanation)

In [None]:
# Create optimized training arguments for DGX Spark
training_args = TrainingArguments(
    # Output
    output_dir="./results/imdb_classifier",
    overwrite_output_dir=True,
    
    # Training duration
    num_train_epochs=3,
    
    # Batch size (with 128GB, we can use larger batches!)
    per_device_train_batch_size=32,  # Adjust based on GPU memory
    per_device_eval_batch_size=64,
    
    # Optimization
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,  # 10% of training for warmup
    lr_scheduler_type="linear",
    
    # Evaluation & Saving
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,  # Cap checkpoints on disk; the best one is always retained
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    
    # DGX Spark Optimizations
    bf16=True,  # Blackwell native BF16
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    
    # Logging
    logging_strategy="steps",
    logging_steps=100,
    report_to="none",  # Disable wandb/tensorboard for now
    
    # Reproducibility
    seed=42,
)

print("Training Arguments Created!")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  BF16: {training_args.bf16}")
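
To see what `warmup_ratio=0.1` with a linear scheduler means in practice, the sketch below computes the learning rate at a few steps. It is a simplified stand-in, not the library's scheduler code. The step counts assume 22,500 training examples, a batch size of 32, 3 epochs, one GPU, and no gradient accumulation.

```python
import math

def linear_warmup_lr(step, total_steps, base_lr=2e-5, warmup_ratio=0.1):
    """Learning rate at `step` for linear warmup followed by linear decay."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the warmup phase
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr down to 0 at the final step
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

# Assumed run: 22,500 examples, batch size 32, 3 epochs, one GPU
steps_per_epoch = math.ceil(22_500 / 32)
total_steps = steps_per_epoch * 3
for step in (0, 100, math.ceil(total_steps * 0.1), total_steps // 2, total_steps):
    print(f"step {step:>5}: lr = {linear_warmup_lr(step, total_steps):.2e}")
```

The rate peaks at `learning_rate` when warmup ends and reaches zero at the last step, so the logged learning rate should rise briefly and then fall for the rest of the run.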

---

## Part 3: Custom Metrics

In [None]:
# Load evaluation metrics
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")

def compute_metrics(eval_pred):
    """
    Compute multiple metrics for evaluation.
    
    Args:
        eval_pred: EvalPrediction containing predictions and labels
        
    Returns:
        Dictionary of metric names and values
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    
    # Compute all metrics
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='weighted')
    precision = precision_metric.compute(predictions=predictions, references=labels, average='weighted')
    recall = recall_metric.compute(predictions=predictions, references=labels, average='weighted')
    
    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"],
        "precision": precision["precision"],
        "recall": recall["recall"]
    }

print("Custom metrics function created!")
print("Will compute: accuracy, f1, precision, recall")
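
The `average='weighted'` option can be checked by hand on a tiny example. The sketch below is plain Python rather than the `evaluate` implementation: it takes the argmax of made-up logits, computes F1 per class, and weights each class by its share of the true labels. All names and values here are illustrative.

```python
def per_class_f1(labels, preds, cls):
    """F1 score for one class, treating that class as the positive label."""
    tp = sum(1 for y, p in zip(labels, preds) if y == cls and p == cls)
    fp = sum(1 for y, p in zip(labels, preds) if y != cls and p == cls)
    fn = sum(1 for y, p in zip(labels, preds) if y == cls and p != cls)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def weighted_f1(labels, preds):
    """Average per-class F1, weighted by how often each class truly occurs."""
    total = len(labels)
    return sum(
        per_class_f1(labels, preds, c) * labels.count(c) / total
        for c in sorted(set(labels))
    )

# Made-up logits for four examples: [negative, positive]
toy_logits = [[2.0, 0.1], [1.5, 0.3], [0.2, 0.9], [0.1, 1.2]]
toy_labels = [0, 0, 0, 1]
toy_preds = [max(range(len(row)), key=row.__getitem__) for row in toy_logits]
print(f"predictions: {toy_preds}")
print(f"weighted F1: {weighted_f1(toy_labels, toy_preds):.4f}")
```

On a balanced dataset like IMDB the weighted and macro averages land close together; the weighting matters more when classes are skewed.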

---

## Part 4: Custom Callbacks

In [None]:
class MemoryCallback(TrainerCallback):
    """
    Callback to monitor GPU memory usage during training.
    """
    def __init__(self, log_every=100):
        self.log_every = log_every
        self.memory_log = []
        
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.log_every == 0:
            if torch.cuda.is_available():
                allocated = torch.cuda.memory_allocated() / 1e9
                reserved = torch.cuda.memory_reserved() / 1e9
                self.memory_log.append({
                    "step": state.global_step,
                    "allocated_gb": allocated,
                    "reserved_gb": reserved
                })
                
    def on_train_end(self, args, state, control, **kwargs):
        if self.memory_log:
            max_allocated = max(m["allocated_gb"] for m in self.memory_log)
            print(f"\n[Memory] Peak GPU usage: {max_allocated:.2f} GB")


class TimingCallback(TrainerCallback):
    """
    Callback to track training timing.
    """
    def __init__(self):
        self.start_time = None
        self.epoch_times = []
        
    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        print(f"\n[Timing] Training started...")
        
    def on_epoch_begin(self, args, state, control, **kwargs):
        self.epoch_start = time.time()
        
    def on_epoch_end(self, args, state, control, **kwargs):
        epoch_time = time.time() - self.epoch_start
        self.epoch_times.append(epoch_time)
        print(f"[Timing] Epoch completed in {epoch_time:.1f}s")
        
    def on_train_end(self, args, state, control, **kwargs):
        total_time = time.time() - self.start_time
        print(f"\n[Timing] Total training time: {total_time/60:.1f} minutes")
        print(f"[Timing] Average epoch time: {np.mean(self.epoch_times):.1f}s")


print("Custom callbacks created: MemoryCallback, TimingCallback")
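
If the callback mechanism feels opaque, the toy loop below shows the general pattern: the training loop calls a named hook on every registered callback at fixed points. This is a simplified sketch, not the Trainer's code. The real hooks receive `args`, `state`, and `control`, while the hypothetical `StepCounter` here receives only the step number.

```python
class StepCounter:
    """Toy callback: records the steps at which it chose to log."""
    def __init__(self, log_every=2):
        self.log_every = log_every
        self.logged_steps = []

    def on_step_end(self, step):
        if step % self.log_every == 0:
            self.logged_steps.append(step)

def toy_training_loop(num_steps, callbacks):
    """Stand-in for a training loop: call each hook a callback defines."""
    for cb in callbacks:
        getattr(cb, "on_train_begin", lambda: None)()
    for step in range(1, num_steps + 1):
        # ... forward, backward, and optimizer step would happen here ...
        for cb in callbacks:
            getattr(cb, "on_step_end", lambda step: None)(step)
    for cb in callbacks:
        getattr(cb, "on_train_end", lambda: None)()

counter = StepCounter(log_every=2)
toy_training_loop(num_steps=6, callbacks=[counter])
print(counter.logged_steps)
```

A callback only implements the hooks it cares about, which is why `MemoryCallback` and `TimingCallback` above define different methods.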

---

## Part 5: Training!

In [None]:
# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    tokenizer=tokenizer,  # Recent transformers releases deprecate this in favor of processing_class
    compute_metrics=compute_metrics,
    callbacks=[
        MemoryCallback(log_every=100),
        TimingCallback(),
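        # Patience counts evaluations; with 3 epochs and one eval per epoch,
        # a patience of 3 could never trigger, so use 1 here
        EarlyStoppingCallback(early_stopping_patience=1)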
    ]
)

print("Trainer created!")
print(f"\nTraining dataset: {len(tokenized_train):,} examples")
print(f"Evaluation dataset: {len(tokenized_eval):,} examples")
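
Early stopping patience is counted in evaluations, not steps. The sketch below is a simplified model of that logic, not the `EarlyStoppingCallback` source: it counts consecutive evaluations without improvement and reports when training would stop. The accuracy values are made up.

```python
def early_stop_index(metrics, patience, threshold=0.0):
    """Return the 1-based evaluation at which training stops, or None."""
    best = None
    bad_evals = 0
    for i, value in enumerate(metrics, start=1):
        if best is None or value - best > threshold:
            best = value      # new best: reset the counter
            bad_evals = 0
        else:
            bad_evals += 1    # no improvement at this evaluation
        if bad_evals >= patience:
            return i
    return None

# Made-up accuracies from three per-epoch evaluations
accuracies = [0.90, 0.89, 0.88]
for patience in (1, 2, 3):
    print(f"patience={patience}: stops at eval {early_stop_index(accuracies, patience)}")
```

With only three evaluations in a run, a patience of 3 can never be reached, so the patience value should be chosen relative to how many evaluations the run performs.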

In [None]:
# Estimate training steps (assumes one GPU and no gradient accumulation)
import math
steps_per_epoch = math.ceil(len(tokenized_train) / training_args.per_device_train_batch_size)
total_steps = steps_per_epoch * int(training_args.num_train_epochs)
print(f"Total training steps: {total_steps:,}")
print(f"Steps per epoch: {steps_per_epoch:,}")

# Clear memory before training
gc.collect()
torch.cuda.empty_cache()
print(f"\nGPU memory before training: {torch.cuda.memory_allocated()/1e9:.2f} GB")

In [None]:
# Train!
print("="*60)
print("STARTING TRAINING")
print("="*60)

train_result = trainer.train()

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)

In [None]:
# Training summary
print("\nTraining Results:")
print("-" * 40)
print(f"Final training loss: {train_result.training_loss:.4f}")
print(f"Training runtime: {train_result.metrics['train_runtime']:.1f}s")
print(f"Samples/second: {train_result.metrics['train_samples_per_second']:.1f}")
print(f"Steps/second: {train_result.metrics['train_steps_per_second']:.2f}")

---

## Part 6: Evaluation

In [None]:
# Evaluate on validation set
print("Evaluating on validation set...")
eval_results = trainer.evaluate()

print("\nValidation Results:")
print("-" * 40)
for key, value in eval_results.items():
    if "eval_" in key:
        metric_name = key.replace("eval_", "")
        if isinstance(value, float):
            print(f"{metric_name:15}: {value:.4f}")
        else:
            print(f"{metric_name:15}: {value}")

In [None]:
# Evaluate on test set
print("\nEvaluating on test set...")
test_results = trainer.evaluate(tokenized_test)

print("\nTest Results:")
print("-" * 40)
for key, value in test_results.items():
    if "eval_" in key:
        metric_name = key.replace("eval_", "")
        if isinstance(value, float):
            print(f"{metric_name:15}: {value:.4f}")

In [None]:
# Get predictions for confusion matrix
from sklearn.metrics import confusion_matrix, classification_report

predictions = trainer.predict(tokenized_test)
preds = np.argmax(predictions.predictions, axis=1)
labels = predictions.label_ids

# Confusion matrix
print("\nConfusion Matrix:")
print("-" * 40)
cm = confusion_matrix(labels, preds)
print(f"              Predicted")
print(f"              Neg    Pos")
print(f"Actual Neg   {cm[0][0]:5}  {cm[0][1]:5}")
print(f"       Pos   {cm[1][0]:5}  {cm[1][1]:5}")

# Classification report
print("\nClassification Report:")
print("-" * 40)
print(classification_report(labels, preds, target_names=['Negative', 'Positive']))
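
The layout of the confusion matrix is easy to misread, so here is a minimal hand-rolled version on made-up data. Rows are actual classes and columns are predicted classes, matching the printout above. The helper `confusion_counts` is hypothetical and only illustrates the bookkeeping.

```python
def confusion_counts(labels, preds, num_classes=2):
    """Build a matrix where rows are actual classes and columns are predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(labels, preds):
        matrix[actual][predicted] += 1
    return matrix

# Made-up labels and predictions, for illustration only
toy_labels = [0, 0, 0, 1, 1, 1, 1, 0]
toy_preds = [0, 1, 0, 1, 1, 0, 1, 0]
toy_cm = confusion_counts(toy_labels, toy_preds)

# Unpack the 2x2 matrix: row 0 is actual negative, row 1 is actual positive
tn, fp = toy_cm[0]
fn, tp = toy_cm[1]
print(toy_cm)
print(f"precision={tp / (tp + fp):.2f}  recall={tp / (tp + fn):.2f}")
```

Reading the off-diagonal cells tells you which kind of error dominates: the top-right cell holds negative reviews predicted positive, the bottom-left the reverse.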

---

## Part 7: Saving and Loading the Model

In [None]:
# Save the model
save_path = "./results/imdb_classifier_final"
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

print(f"Model saved to {save_path}")

# Check what was saved
import os
print("\nSaved files:")
for f in os.listdir(save_path):
    size = os.path.getsize(os.path.join(save_path, f)) / 1e6
    print(f"  {f}: {size:.2f} MB")

In [None]:
# Load the saved model for inference
from transformers import pipeline

# Create pipeline from saved model
classifier = pipeline(
    "text-classification",
    model=save_path,
    tokenizer=save_path,
    device=0 if torch.cuda.is_available() else -1
)

# Test on new examples
test_texts = [
    "This movie was absolutely fantastic! I loved every minute of it.",
    "Worst movie I've ever seen. Complete waste of time.",
    "It was okay, nothing special but watchable."
]

print("\nInference on new examples:")
print("-" * 60)
for text in test_texts:
    result = classifier(text)[0]
    sentiment = "POSITIVE" if result['label'] == 'LABEL_1' else "NEGATIVE"
    print(f"{sentiment} ({result['score']:.1%}): {text[:50]}...")
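
The `score` the pipeline reports is a probability derived from the model's raw logits. As a sketch of that last step, assuming the usual softmax for a two-label classifier, the code below turns made-up logits into a label and a score.

```python
import math

def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]

# Made-up logits for one review: [negative, positive]
toy_logits = [-1.2, 2.3]
probs = softmax(toy_logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(f"LABEL_{best} with score {probs[best]:.3f}")
```

A review like the third example above tends to produce logits that sit close together, which shows up as a score nearer 50% than 100%.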

---

## Part 8: Hyperparameter Tips for DGX Spark

### Recommended Settings

In [None]:
print("DGX SPARK TRAINING RECOMMENDATIONS")
print("=" * 60)

recommendations = """
PRECISION:
  - Use bf16=True (native Blackwell support)
  - Avoid fp16 on Blackwell - BF16 is more stable and equally fast

BATCH SIZE:
  - With 128GB unified memory, you can use larger batches
  - DistilBERT: batch_size=64-128
  - BERT-base: batch_size=32-64
  - Large models (7B+): Use gradient accumulation

LEARNING RATE:
  - Start with 2e-5 for most transformer fine-tuning
  - Larger batches can use slightly higher LR (3e-5)
  - Use warmup (10-20% of training)

DATA LOADING:
  - dataloader_num_workers=4-8
  - dataloader_pin_memory=True
  - Use batched tokenization with num_proc=4+

GRADIENT ACCUMULATION:
  - Use when batch_size is limited by memory
  - Effective batch = batch_size * gradient_accumulation_steps
  - Example: batch=8, accum=4 -> effective batch=32

CHECKPOINTING:
  - save_total_limit=2-3 (save disk space)
  - load_best_model_at_end=True
  - Use early stopping for efficiency
"""
print(recommendations)
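
The gradient accumulation rule can be checked numerically. The sketch below uses a one-weight linear model with made-up data, not the Trainer itself, and shows that averaging the gradients of 4 equal micro-batches of 2 reproduces the gradient of a single batch of 8.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean squared error for the model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Made-up data: 8 examples, split into 4 micro-batches of 2
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5
micro_batch, accum_steps = 2, 4

full_grad = mse_grad(w, xs, ys)

accumulated = 0.0
for i in range(accum_steps):
    lo, hi = i * micro_batch, (i + 1) * micro_batch
    # Each micro-batch gradient is scaled by 1/accum_steps before summing
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(f"effective batch = {micro_batch * accum_steps}")
print(f"full-batch grad = {full_grad:.6f}, accumulated grad = {accumulated:.6f}")
```

The two gradients agree up to floating-point rounding, which is why accumulation lets a memory-limited run behave like one with a larger batch, at the cost of more forward and backward passes per optimizer step.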

---

## Try It Yourself: Train on AG News

In [None]:
# YOUR CODE HERE
# Fine-tune a model on the AG News dataset (4-class classification)
# 
# Steps:
# 1. Load AG News dataset
# 2. Prepare train/eval splits
# 3. Tokenize with a transformer tokenizer
# 4. Create TrainingArguments
# 5. Create Trainer
# 6. Train and evaluate
#
# Hint: AG News has 4 classes, so use num_labels=4

# Your code:
# ag_news = load_dataset("ag_news")
# ...

---

## Common Mistakes

### Mistake 1: Forgetting to Rename Label Column

```python
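# Risky: the model's forward() takes 'labels', not 'label'
tokenized = dataset.map(tokenize_fn, batched=True)
# The default collators rename 'label' for you, but a custom
# collator may not, and training then fails with no loss returned

# Right: Rename the column explicitly
tokenized = tokenized.rename_column("label", "labels")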
```

### Mistake 2: Not Setting load_best_model_at_end

```python
# Wrong: Final model might be overfit
args = TrainingArguments(
    eval_strategy="epoch",
    save_strategy="epoch"
)

# Right: Load best checkpoint at end
args = TrainingArguments(
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy"
)
```

### Mistake 3: Mismatched Strategies

```python
# Wrong: Can't load best model if save strategy doesn't match eval
args = TrainingArguments(
    eval_strategy="epoch",
    save_strategy="steps",  # Mismatch!
    load_best_model_at_end=True
)

# Right: Match strategies
args = TrainingArguments(
    eval_strategy="epoch",
    save_strategy="epoch",  # Match!
    load_best_model_at_end=True
)
```

---

## Checkpoint

You've learned:
- How to configure TrainingArguments for DGX Spark
- How to implement custom metrics
- How to use callbacks for monitoring
- How to train and evaluate a text classifier
- How to save and load fine-tuned models

---

## Further Reading

- [Trainer Documentation](https://huggingface.co/docs/transformers/main_classes/trainer)
- [TrainingArguments Reference](https://huggingface.co/docs/transformers/main_classes/trainer#trainingarguments)
- [Custom Training Loops](https://huggingface.co/docs/transformers/training)

---

## Cleanup

In [None]:
# Clean up
del model, trainer, classifier
gc.collect()
torch.cuda.empty_cache()

print(f"GPU memory after cleanup: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print("\nLab 2.5.4 complete!")