# Lab 2.4.4: Fine-tuning with the Trainer API

**Module:** 2.4 - Hugging Face Ecosystem  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate)

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Configure TrainingArguments for optimal training
- [ ] Use the Trainer API for fine-tuning
- [ ] Implement custom metrics with compute_metrics
- [ ] Use callbacks for custom behavior
- [ ] Leverage DGX Spark's capabilities (bf16, large batches)
- [ ] Evaluate and compare models

---

## Prerequisites

- Completed: Lab 2.4.3 (Dataset Processing)
- Knowledge of: Tokenization, basic neural network training concepts

---

## Real-World Context

You've been hired by a movie streaming company to build a review sentiment classifier. They have thousands of movie reviews and want to automatically categorize them as positive or negative.

You could train a model from scratch, but that would take weeks and millions of examples. Instead, you'll **fine-tune** a pre-trained model in just minutes!

**Fine-tuning success stories:**
- Twitter trained a toxicity classifier in hours (vs. months from scratch)
- Healthcare companies fine-tune for medical text understanding
- Banks fine-tune for fraud detection in transaction descriptions
- E-commerce uses fine-tuned models for product categorization

---

## ELI5: What is Fine-tuning?

> **Imagine you're learning to play piano.** You could:
> - Option A: Start from scratch, learn music theory, practice scales for years
> - Option B: Find someone who already plays guitar (similar skill!), teach them piano differences
>
> **Fine-tuning is Option B for AI.** You take a model that already knows language (like BERT) and teach it YOUR specific task.
>
> **What the pre-trained model already knows:**
> - Grammar and sentence structure
> - Word meanings and relationships
> - Common sense about the world
>
> **What fine-tuning teaches it:**
> - "This specific task: positive vs negative reviews"
> - "Your specific data patterns"
>
> **The Trainer API** is like a personal coach that handles all the training logistics:
> - Schedules practice sessions (epochs)
> - Tracks progress (metrics)
> - Adjusts difficulty (learning rate)
> - Saves best performances (checkpoints)

---

## Part 1: Setup and Data Preparation

In [None]:
# Install required packages
# Note: These packages are pre-installed in the NGC PyTorch container.
# Running pip install ensures you have compatible versions.
# If NOT using the NGC container, ensure you have ARM64-compatible packages for DGX Spark.
# transformers>=4.41.0 is required because the `eval_strategy` argument used below
# was called `evaluation_strategy` in earlier releases.

!pip install -q "transformers>=4.41.0" "datasets>=2.14.0" "evaluate>=0.4.0" "accelerate>=0.24.0" scikit-learn

import torch
import numpy as np
from typing import List, Dict, Any, Union
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    PreTrainedModel,
    PreTrainedTokenizer
)
import evaluate
import warnings
warnings.filterwarnings('ignore')

# Check hardware
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Load and prepare the IMDB dataset
print("Loading IMDB dataset...")
raw_dataset = load_dataset("imdb")

# Create train/validation split
train_val = raw_dataset['train'].train_test_split(
    test_size=0.1,
    seed=42,
    stratify_by_column='label'
)

dataset = DatasetDict({
    'train': train_val['train'],
    'validation': train_val['test'],
    'test': raw_dataset['test']
})

print(f"\nDataset splits:")
print(f"  Train: {len(dataset['train']):,}")
print(f"  Validation: {len(dataset['validation']):,}")
print(f"  Test: {len(dataset['test']):,}")

In [None]:
# Load tokenizer and model
model_name = "distilbert-base-uncased"

print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load model for sequence classification
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,  # Binary classification
    id2label={0: "NEGATIVE", 1: "POSITIVE"},
    label2id={"NEGATIVE": 0, "POSITIVE": 1}
)

print(f"Model parameters: {model.num_parameters():,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

In [None]:
# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=256
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    batch_size=1000,
    num_proc=4,
    remove_columns=['text'],
    desc="Tokenizing"
)

# Rename label to labels (the argument name the model's forward() expects)
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')

print(f"\nTokenized columns: {tokenized_dataset['train'].column_names}")

---

## Part 2: Understanding TrainingArguments

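`TrainingArguments` controls nearly every aspect of the training loop: duration, batch size, learning-rate schedule, evaluation, checkpointing, precision, and logging. Let's walk through the key parameters.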

In [None]:
# Create TrainingArguments with detailed comments
training_args = TrainingArguments(
    # === Output Settings ===
    output_dir="./results/imdb-sentiment",  # Where to save checkpoints
    overwrite_output_dir=True,              # Overwrite existing dir
    
    # === Training Duration ===
    num_train_epochs=3,                     # Total passes through data
    # max_steps=-1,                         # Alternative: stop after N steps
    
    # === Batch Size (DGX Spark can handle larger!) ===
    per_device_train_batch_size=16,         # Batch size per GPU
    per_device_eval_batch_size=32,          # Larger for eval (no gradients)
    # gradient_accumulation_steps=2,        # Effective batch = 16*2 = 32
    
    # === Learning Rate ===
    learning_rate=2e-5,                     # Standard for fine-tuning
    weight_decay=0.01,                      # Decoupled weight decay (AdamW)
    warmup_ratio=0.1,                       # 10% of training for warmup
    # warmup_steps=500,                     # Alternative: fixed warmup steps
    lr_scheduler_type="linear",             # Linear decay after warmup
    
    # === Evaluation ===
    eval_strategy="epoch",                  # Evaluate after each epoch
    # eval_steps=500,                       # Alternative: every N steps
    
    # === Checkpointing ===
    save_strategy="epoch",                  # Save after each epoch
    save_total_limit=2,                     # Keep at most 2 checkpoints (the best one is always retained)
    load_best_model_at_end=True,            # Load best model after training
    metric_for_best_model="accuracy",       # Which metric to track
    greater_is_better=True,                 # Higher accuracy = better
    
    # === Precision (DGX Spark optimization!) ===
    bf16=True,                              # Use bfloat16 (Blackwell native!)
    # fp16=False,                           # fp16 needs loss scaling; bf16 is the better fit here
    
    # === Logging ===
    logging_dir="./logs",                   # TensorBoard log dir (unused while report_to="none")
    logging_strategy="steps",
    logging_steps=100,                      # Log every 100 steps
    report_to="none",                       # Disable W&B/etc reporting
    
    # === Performance ===
    dataloader_num_workers=4,               # Parallel data loading
    dataloader_pin_memory=True,             # Pin memory for faster GPU transfer
    
    # === Reproducibility ===
    seed=42,
)

print("TrainingArguments configured!")
print(f"\nKey settings:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  BF16: {training_args.bf16}")
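
The settings above determine how many optimizer steps the run takes, which in turn sets the length of the warmup. Here is a minimal sketch of that arithmetic, assuming a single device, no gradient accumulation, and the 22,500-example training split created earlier (90% of IMDB's 25,000 training reviews). `plan_steps` is a hypothetical helper written for this illustration, not part of the Trainer API:

```python
import math

def plan_steps(num_examples, batch_size, epochs, warmup_ratio):
    """Estimate optimizer steps and warmup steps for a single-device run."""
    # The last batch of each epoch may be partial, so round up.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    total_steps = steps_per_epoch * epochs
    # Warmup covers the first `warmup_ratio` fraction of all steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps

steps_per_epoch, total_steps, warmup_steps = plan_steps(
    num_examples=22_500, batch_size=16, epochs=3, warmup_ratio=0.1
)
print(f"Steps per epoch: {steps_per_epoch}")
print(f"Total steps:     {total_steps}")
print(f"Warmup steps:    {warmup_steps}")
```

Doubling the batch size roughly halves every number here, which is worth keeping in mind when choosing `logging_steps` or a fixed `warmup_steps`.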

### Understanding Key Parameters

| Parameter | What it does | DGX Spark recommendation |
|-----------|--------------|---------------------------|
| `per_device_train_batch_size` | Samples per forward pass | 16-64 (you have memory!) |
| `learning_rate` | Step size for updates | 1e-5 to 5e-5 for fine-tuning |
| `warmup_ratio` | Gradual LR increase period | 0.1 (10% of training) |
| `bf16` | Use bfloat16 precision | **True** (native Blackwell support) |
| `gradient_accumulation_steps` | Simulate larger batches | Use if batch doesn't fit |
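
To make `learning_rate`, `warmup_ratio`, and `lr_scheduler_type="linear"` concrete, here is a small sketch of how a linear warmup followed by linear decay behaves. It is a plain-Python illustration of the shape of the schedule, not the library's scheduler. The function name and the round step counts (100 warmup steps out of 1,000) are assumptions chosen for readability:

```python
def linear_schedule_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        # Warmup: ramp from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr at the end of warmup to 0 at the last step.
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

base_lr = 2e-5
for step in (0, 50, 100, 550, 1000):
    lr = linear_schedule_lr(step, base_lr, warmup_steps=100, total_steps=1000)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

The peak learning rate is only reached at the end of warmup; for most of the run the model trains at a lower rate than the configured `learning_rate`.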

---

## Part 3: Custom Metrics with compute_metrics

In [None]:
# Load evaluation metrics
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")

def compute_metrics(eval_pred):
    """
    Compute multiple metrics for evaluation.
    
    Args:
        eval_pred: EvalPrediction with predictions and label_ids
        
    Returns:
        Dictionary of metric names and values
    """
    predictions, labels = eval_pred
    
    # Convert logits to predictions
    predictions = np.argmax(predictions, axis=1)
    
    # Calculate metrics
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels, average='weighted')
    precision = precision_metric.compute(predictions=predictions, references=labels, average='weighted')
    recall = recall_metric.compute(predictions=predictions, references=labels, average='weighted')
    
    return {
        'accuracy': accuracy['accuracy'],
        'f1': f1['f1'],
        'precision': precision['precision'],
        'recall': recall['recall']
    }

print("Metrics function ready!")
print("Will compute: accuracy, f1, precision, recall")

### What Each Metric Means

| Metric | Meaning | When it matters |
|--------|---------|------------------|
| **Accuracy** | % correct predictions | Balanced datasets |
| **Precision** | Of predicted positives, % correct | Cost of false positives high |
| **Recall** | Of actual positives, % found | Cost of missing positives high |
| **F1** | Harmonic mean of precision & recall | Imbalanced datasets |
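
The `compute_metrics` function above passes `average='weighted'`, which averages the per-class scores weighted by how many true examples each class has. The sketch below shows that calculation for F1 in plain Python on a tiny made-up label set. The helper names and the labels are assumptions for illustration only; in the lab itself the `evaluate` metrics do this work:

```python
def f1_per_class(y_true, y_pred, num_classes):
    """Return (f1, support) for each class, treating that class as 'positive'."""
    results = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        results.append((f1, tp + fn))  # support = number of true examples of class c
    return results

def weighted_f1(y_true, y_pred, num_classes):
    """Average per-class F1, weighting each class by its number of true examples."""
    per_class = f1_per_class(y_true, y_pred, num_classes)
    return sum(f1 * support for f1, support in per_class) / len(y_true)

# Made-up labels: 4 negatives, 2 positives, one positive misclassified.
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 1]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print("Per-class (f1, support):", f1_per_class(y_true, y_pred, 2))
print(f"Weighted F1: {weighted_f1(y_true, y_pred, 2):.4f}")
print(f"Accuracy:    {accuracy:.4f}")
```

On a balanced dataset such as IMDB, weighted F1 and accuracy tend to land close together; the gap grows as classes become imbalanced.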

---

## Part 4: Creating and Running the Trainer

In [None]:
# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    tokenizer=tokenizer,                    # Saved alongside checkpoints (newer releases call this processing_class)
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]  # Stop after 2 evaluations without improvement
)

print("Trainer created!")
print(f"\nTraining samples: {len(trainer.train_dataset):,}")
print(f"Validation samples: {len(trainer.eval_dataset):,}")
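
To see what `early_stopping_patience=2` means in practice, here is a simplified model of a patience counter. It is a sketch of the idea, not the library's implementation. It ignores the callback's improvement threshold, and `first_stop` and the accuracy values are hypothetical:

```python
def first_stop(metric_history, patience, greater_is_better=True):
    """Index of the evaluation at which patience runs out, or None if it never does."""
    best = None
    bad_evals = 0
    for i, value in enumerate(metric_history):
        if best is None:
            improved = True
        elif greater_is_better:
            improved = value > best
        else:
            improved = value < best
        if improved:
            best = value
            bad_evals = 0      # any improvement resets the counter
        else:
            bad_evals += 1     # another evaluation without a new best
            if bad_evals >= patience:
                return i
    return None

# Hypothetical validation accuracies, one per evaluation.
print(first_stop([0.90, 0.91, 0.905, 0.908, 0.92], patience=2))
print(first_stop([0.90, 0.91, 0.905], patience=2))
```

With one evaluation per epoch and `num_train_epochs=3`, a patience of 2 can only run out on the final evaluation, so early stopping has little effect in this lab's configuration. It becomes useful on longer runs.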

In [None]:
# Evaluate before training (baseline)
print("Evaluating baseline (before training)...")
baseline_results = trainer.evaluate()

print("\nBaseline Results:")
for key, value in baseline_results.items():
    if 'loss' in key or 'accuracy' in key or 'f1' in key:
        print(f"  {key}: {value:.4f}")

In [None]:
# Train the model!
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

# This is where the magic happens!
train_result = trainer.train()

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)

In [None]:
# Training statistics
print("\nTraining Statistics:")
print(f"  Total steps: {train_result.global_step}")
print(f"  Training loss: {train_result.training_loss:.4f}")
print(f"  Training time: {train_result.metrics['train_runtime']:.1f}s")
print(f"  Samples/second: {train_result.metrics['train_samples_per_second']:.1f}")

---

## Part 5: Evaluation on Test Set

In [None]:
# Evaluate on test set
print("Evaluating on test set...")
test_results = trainer.evaluate(tokenized_dataset['test'])

print("\n" + "="*60)
print("TEST SET RESULTS")
print("="*60)

for key, value in test_results.items():
    if not key.startswith('eval_runtime'):
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")

In [None]:
# Get predictions for detailed analysis
predictions = trainer.predict(tokenized_dataset['test'])

# Convert to class predictions
pred_labels = np.argmax(predictions.predictions, axis=1)
true_labels = predictions.label_ids

# Confusion matrix
from sklearn.metrics import confusion_matrix, classification_report

cm = confusion_matrix(true_labels, pred_labels)
print("\nConfusion Matrix:")
print(f"               Predicted")
print(f"              NEG    POS")
print(f"Actual NEG  {cm[0,0]:5d}  {cm[0,1]:5d}")
print(f"       POS  {cm[1,0]:5d}  {cm[1,1]:5d}")

print("\n" + classification_report(
    true_labels, 
    pred_labels, 
    target_names=['NEGATIVE', 'POSITIVE']
))
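
The layout printed above follows the usual convention: rows are actual classes, columns are predicted classes. As a rough sketch of what `confusion_matrix` computes, here is the same tally in plain Python on a handful of made-up labels (`confusion_counts` is a hypothetical helper, not a scikit-learn function):

```python
def confusion_counts(y_true, y_pred, num_classes=2):
    """Build a matrix where rows are actual classes and columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(y_true, y_pred):
        matrix[actual][predicted] += 1
    return matrix

# Made-up labels for illustration: 0 = NEGATIVE, 1 = POSITIVE.
cm = confusion_counts([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])
tn, fp = cm[0]   # actual NEGATIVE row
fn, tp = cm[1]   # actual POSITIVE row

print("Confusion matrix:", cm)
print(f"Precision (POSITIVE): {tp / (tp + fp):.3f}")
print(f"Recall (POSITIVE):    {tp / (tp + fn):.3f}")
```

Reading the off-diagonal cells tells you which kind of error dominates: `cm[0][1]` counts negative reviews flagged as positive, `cm[1][0]` the reverse.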

---

## Part 6: Making Predictions

In [None]:
# Function to predict sentiment for new texts
def predict_sentiment(
    texts: Union[str, List[str]], 
    model: PreTrainedModel, 
    tokenizer: PreTrainedTokenizer
) -> List[Dict[str, Any]]:
    """
    Predict sentiment for a list of texts.
    
    Args:
        texts: List of strings or single string to classify
        model: Trained sequence classification model
        tokenizer: Tokenizer corresponding to the model
        
    Returns:
        List of predictions with 'label' and 'confidence' keys
    """
    if isinstance(texts, str):
        texts = [texts]
    
    # Tokenize
    inputs = tokenizer(
        texts,
        truncation=True,
        padding=True,
        max_length=256,
        return_tensors="pt"
    )
    
    # Move to same device as model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Predict
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Process outputs
    probs = torch.softmax(outputs.logits, dim=1)
    predictions = torch.argmax(probs, dim=1)
    confidences = probs.max(dim=1).values
    
    results: List[Dict[str, Any]] = []
    for pred, conf in zip(predictions, confidences):
        label = model.config.id2label[pred.item()]
        results.append({
            'label': label,
            'confidence': conf.item()
        })
    
    return results

print("Prediction function ready!")
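
The "confidence" returned by `predict_sentiment` is the largest softmax probability. The sketch below reproduces that step in plain Python so the relationship between logits and confidence is easier to see. The logit values and the helper names are made up for illustration:

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    # Subtracting the max keeps exp() from overflowing without changing the result.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def label_and_confidence(logits, id2label):
    """Pick the most probable class and report its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]

id2label = {0: "NEGATIVE", 1: "POSITIVE"}

# A wide gap between logits gives high confidence; a near tie gives about 50%.
for logits in ([-1.0, 1.0], [0.1, 0.0]):
    label, confidence = label_and_confidence(logits, id2label)
    print(f"{logits} -> {label} ({confidence:.2%})")
```

Only the difference between the two logits matters, so a mixed review can legitimately come back with a confidence close to 50%. Softmax confidence is also not a calibrated probability, so treat it as a relative signal.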

In [None]:
# Test on new reviews
test_reviews = [
    "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
    "Waste of time. Terrible acting, boring plot, would not recommend to anyone.",
    "It was okay. Not great, not terrible. Just average.",
    "One of the best films I've seen this year! Definitely Oscar-worthy.",
    "I walked out after 30 minutes. Couldn't stand it.",
    "The special effects were amazing but the story was lacking."
]

print("Testing on new reviews:\n")
print(f"{'Review':<70} {'Prediction':<12} {'Confidence'}")
print("=" * 100)

predictions = predict_sentiment(test_reviews, trainer.model, tokenizer)

for review, pred in zip(test_reviews, predictions):
    display_review = review[:67] + "..." if len(review) > 70 else review
    print(f"{display_review:<70} {pred['label']:<12} {pred['confidence']:.2%}")

---

## Part 7: Saving and Loading the Model

In [None]:
# Save the fine-tuned model
save_path = "./saved_models/imdb-sentiment"

print(f"Saving model to {save_path}...")
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

# Check what was saved
import os
print("\nSaved files:")
for f in os.listdir(save_path):
    size = os.path.getsize(os.path.join(save_path, f)) / 1e6
    print(f"  {f}: {size:.1f} MB")

In [None]:
# Load the saved model
print("\nLoading saved model...")

loaded_model = AutoModelForSequenceClassification.from_pretrained(save_path)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_path)

# Move to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded_model = loaded_model.to(device)

# Test loaded model
test_text = "This movie was incredible!"
result = predict_sentiment(test_text, loaded_model, loaded_tokenizer)
print(f"\nLoaded model prediction: '{test_text}' → {result[0]['label']} ({result[0]['confidence']:.2%})")

---

## Part 8: Comparing Training Configurations

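Let's compare different training setups to see their impact. Each run below uses a 2,000-example training subset and a single epoch, so treat the numbers as a rough signal rather than a benchmark.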

In [None]:
# Function to train with different configs and compare
def quick_train(config_name, learning_rate, batch_size, epochs=1):
    """Train with specific config and return results."""
    print(f"\nTraining: {config_name}")
    print(f"  LR: {learning_rate}, Batch: {batch_size}, Epochs: {epochs}")
    
    # Fresh model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=2
    )
    
    # Use smaller subset for speed
    small_train = tokenized_dataset['train'].select(range(2000))
    small_val = tokenized_dataset['validation'].select(range(500))
    
    args = TrainingArguments(
        output_dir=f"./results/{config_name}",
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=learning_rate,
        eval_strategy="epoch",
        save_strategy="no",
        bf16=True,
        logging_strategy="no",
        report_to="none",
        seed=42
    )
    
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=small_train,
        eval_dataset=small_val,
        compute_metrics=compute_metrics
    )
    
    import time
    start = time.time()
    trainer.train()
    train_time = time.time() - start
    
    results = trainer.evaluate()
    
    return {
        'config': config_name,
        'accuracy': results['eval_accuracy'],
        'f1': results['eval_f1'],
        'time': train_time
    }

print("Comparison function ready!")

In [None]:
# Compare different configurations
configs = [
    ("small_lr", 1e-5, 16),
    ("medium_lr", 2e-5, 16),
    ("large_lr", 5e-5, 16),
    ("large_batch", 2e-5, 32),
]

comparison_results = []
for name, lr, bs in configs:
    result = quick_train(name, lr, bs)
    comparison_results.append(result)
    
# Display comparison
print("\n" + "="*70)
print("CONFIGURATION COMPARISON")
print("="*70)
print(f"{'Config':<15} {'Accuracy':<12} {'F1':<12} {'Time (s)':<10}")
print("-"*50)
for r in comparison_results:
    print(f"{r['config']:<15} {r['accuracy']:.4f}       {r['f1']:.4f}       {r['time']:.1f}")
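
Once the table is printed, you may want to pick a winner programmatically. One possible approach, sketched here with placeholder numbers that are NOT real results from this lab, is to rank by accuracy and break ties by training time. The dictionaries mirror the structure returned by `quick_train`:

```python
# Placeholder values for illustration only; substitute the real `comparison_results`.
example_results = [
    {"config": "config_a", "accuracy": 0.86, "f1": 0.86, "time": 40.0},
    {"config": "config_b", "accuracy": 0.88, "f1": 0.88, "time": 42.0},
    {"config": "config_c", "accuracy": 0.88, "f1": 0.87, "time": 35.0},
]

# Highest accuracy wins; among equals, the shorter training time wins.
best = max(example_results, key=lambda r: (r["accuracy"], -r["time"]))
print(f"Best config: {best['config']} "
      f"(accuracy={best['accuracy']:.4f}, time={best['time']:.1f}s)")
```

Because each quick run is small and uses one seed, differences of a point or two in accuracy are within noise; rerunning with a few seeds gives a more trustworthy ranking.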

---

## Try It Yourself: Train on AG News

Fine-tune a model on the AG News dataset (news category classification):
1. Load the `ag_news` dataset
2. Create appropriate splits
3. Tokenize with a model of your choice
4. Configure TrainingArguments
5. Train and evaluate

<details>
<summary>Hint</summary>

```python
# AG News has 4 classes: World, Sports, Business, Sci/Tech
ag_news = load_dataset("ag_news")

# Configure model for 4 classes
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=4,
    id2label={0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"},
    label2id={"World": 0, "Sports": 1, "Business": 2, "Sci/Tech": 3}
)
```
</details>
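
Writing `id2label` and `label2id` by hand invites mismatches between the two. A small sketch of deriving both from a single list of class names follows; the list is typed in here, on the assumption that its order matches the dataset's integer labels:

```python
# Class names in label-index order (assumed to match the dataset's integer labels).
label_names = ["World", "Sports", "Business", "Sci/Tech"]

# Build both mappings from the same list so they cannot drift apart.
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

print(id2label)
print(label2id)
```

If the dataset's label column is a `ClassLabel` feature, the same list is usually available from `ag_news["train"].features["label"].names`, which avoids typing the names at all.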

In [None]:
# YOUR CODE HERE
# Train a news classifier on AG News




---

## Common Mistakes

### Mistake 1: Wrong Label Column Name

In [None]:
# RISKY: Relying on a 'label' column
# dataset.column_names = ['input_ids', 'attention_mask', 'label']
# The default data collators rename 'label' to 'labels' for you,
# but a custom collator may not, and training then fails.

# SAFE: Rename label to labels explicitly
# dataset = dataset.rename_column('label', 'labels')

print("The model's forward() expects 'labels' (plural) as the target argument!")

### Mistake 2: Learning Rate Too High

In [None]:
# WRONG: LR too high for fine-tuning
# learning_rate=1e-3  # Will destroy pre-trained knowledge!

# CORRECT: Use small LR for fine-tuning
# learning_rate=2e-5  # Gentle updates to pre-trained weights

print("Fine-tuning LR guide:")
print("  1e-5 to 5e-5: Safe range for most tasks")
print("  1e-4: Can work with warmup")
print("  1e-3+: Too high, will forget pre-training!")

### Mistake 3: Not Using eval_strategy

In [None]:
# WRONG: No evaluation during training
# TrainingArguments(eval_strategy="no")  # Can't track progress!

# CORRECT: Evaluate regularly
# TrainingArguments(
#     eval_strategy="epoch",  # or "steps" with eval_steps=500
#     load_best_model_at_end=True,
#     metric_for_best_model="accuracy"
# )

print("Always enable evaluation to track training progress!")

---

## Checkpoint

You've learned:
- ✅ How to configure TrainingArguments for optimal training
- ✅ How to use the Trainer API for fine-tuning
- ✅ How to implement custom metrics
- ✅ How to use callbacks for early stopping
- ✅ How to evaluate and compare models
- ✅ How to save and load fine-tuned models

---

## Challenge: Multi-class Emotion Detection

Train an emotion classifier using the `emotion` dataset:
1. Load and explore the dataset (6 emotions)
2. Fine-tune with custom metrics including per-class F1
3. Implement a confusion matrix callback
4. Achieve >90% accuracy

In [None]:
# YOUR CHALLENGE CODE HERE
# emotion dataset has: sadness, joy, love, anger, fear, surprise




---

## Further Reading

- [Trainer Documentation](https://huggingface.co/docs/transformers/main_classes/trainer)
- [TrainingArguments Reference](https://huggingface.co/docs/transformers/main_classes/trainer#trainingarguments)
- [Fine-tuning Guide](https://huggingface.co/docs/transformers/training)
- [Evaluate Library](https://huggingface.co/docs/evaluate)

---

## Cleanup

In [None]:
# Clean up
import os
import shutil
import gc

# Remove saved models and results
for path in ["./results", "./logs", "./saved_models"]:
    if os.path.exists(path):
        shutil.rmtree(path)
        print(f"Removed {path}")

# Clear memory
del model, trainer
gc.collect()
torch.cuda.empty_cache()

print("\nCleanup complete!")

---

## Next Steps

In the next notebook, **05-lora-introduction.ipynb**, we'll learn about Parameter-Efficient Fine-Tuning (PEFT) with LoRA - how to fine-tune large models using just a fraction of the parameters!

Great job completing Lab 2.4.4! You now know how to fine-tune transformer models like a pro!