# 🏋️ Model Training and Evaluation

## 📚 Overview

This is where the magic happens! You'll implement the **complete training loop** from scratch and learn:
- How neural networks actually learn (forward pass, loss, backward pass, optimization)
- Training vs. validation
- Monitoring training progress
- Evaluating model performance
- Preventing overfitting

## 🎯 Learning Objectives

By completing this notebook, you will:
1. **Implement a training loop** - the core of deep learning
2. **Understand backpropagation** - how gradients update weights
3. **Monitor training metrics** - loss, accuracy, validation performance
4. **Implement early stopping** - prevent overfitting
5. **Evaluate on validation set** - assess generalization
6. **Calculate classification metrics** - accuracy, precision, recall, F1
7. **Visualize training progress** - loss curves, accuracy plots

## 📋 Prerequisites

Before starting, ensure you've completed:
- ✅ `00_exploration.ipynb` - Data exploration
- ✅ `01_preprocessing.ipynb` - Text preprocessing
- ✅ `02_vocab_and_dataloader.ipynb` - Vocabulary and DataLoader
- ✅ `03_model_baseline.ipynb` - Model architecture

You should have:
- Model class defined
- DataLoaders ready (train + validation)
- Loss function and optimizer configured

---

## 🔄 The Training Loop Explained

### The Core Cycle:

```
For each epoch:
    For each batch in training data:
        1. Zero Gradients: optimizer.zero_grad()  # Clear gradients left over from the previous batch
        2. Forward Pass: predictions = model(inputs)
        3. Calculate Loss: loss = loss_fn(predictions, targets)
        4. Backward Pass: loss.backward()  # Compute gradients
        5. Optimizer Step: optimizer.step()  # Update weights

    Validate on validation set
    Save best model (if validation improved)
    Check for early stopping
```

### Key Concepts:
- **Epoch**: One complete pass through the training data
- **Batch**: A subset of data (e.g., 32 samples)
- **Forward Pass**: Input → Model → Output
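- **Loss**: A single number measuring how wrong the predictions are (lower is better)
- **Backward Pass**: Compute the gradient of the loss with respect to every parameter (which direction to nudge each weight)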
- **Optimizer Step**: Update model parameters using gradients
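
The cycle above can be tried on a toy problem before touching the real model. This is a minimal sketch on synthetic data: `nn.Linear` stands in for `DisasterTweetClassifier`, and `nn.BCEWithLogitsLoss` with plain SGD is an assumed loss/optimizer pairing, not necessarily what notebook 03 configured.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic data: 64 samples with 5 features; the label is 1 when the first feature is positive
X = torch.randn(64, 5)
y = (X[:, 0] > 0).float()

model = nn.Linear(5, 1)                  # stand-in for the real classifier
loss_fn = nn.BCEWithLogitsLoss()         # expects raw logits, applies sigmoid internally
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

losses = []
for epoch in range(20):                  # one epoch = one full pass over X
    for start in range(0, len(X), 16):   # batches of 16 samples
        xb, yb = X[start:start + 16], y[start:start + 16]
        optimizer.zero_grad()            # clear old gradients
        logits = model(xb).reshape(-1)   # forward pass: (16, 1) -> (16,)
        loss = loss_fn(logits, yb)       # how wrong are the predictions?
        loss.backward()                  # backward pass: compute gradients
        optimizer.step()                 # update weights
    with torch.no_grad():                # full-data loss after this epoch, for monitoring only
        losses.append(loss_fn(model(X).reshape(-1), y).item())

print(f"loss after epoch 1: {losses[0]:.3f}, after epoch 20: {losses[-1]:.3f}")
```

On this linearly separable toy task the monitored loss should fall steadily; on real tweets the curve is noisier.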

---

## 🚀 Let's Train!


## TODO 1: Setup - Import Everything from Previous Notebooks 📦

**Goal**: Load all necessary components from your previous work.

**What you need**:
1. **Libraries**: torch, pandas, numpy, matplotlib, sklearn metrics
2. **Model**: Your DisasterTweetClassifier class
3. **Data**: Vocabulary, DataLoaders (train and validation)
4. **Config**: Hyperparameters, device, random seed

**Creating Train/Val Split**:
```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Load cleaned data
df = pd.read_csv('../data/interim/train_cleaned.csv')

# Stratified split (e.g., 80/20 or 90/10) keeps the class ratio the same in both sets
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['target'])
```

**Important**: 
- Use `stratify=df['target']` to maintain class balance
- Create separate DataLoaders for train and validation
- Set `shuffle=True` for training, `shuffle=False` for validation

**Expected outcome**:
- Model, loss_fn, optimizer all initialized
- train_loader and val_loader ready
- Device set (CPU or GPU)
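
One possible setup, sketched under assumptions: the tweets have already been encoded into fixed-length tensors of token IDs (how notebook 02 actually does this may differ), so the random placeholder tensors `encoded` and `targets` stand in for them. The split is done on indices so that texts and labels stay aligned.

```python
import random

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# Reproducibility: seed every random number generator in play
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical placeholders: replace with your encoded tweets and df['target']
encoded = torch.randint(1, 1000, (200, 30))     # 200 tweets x 30 token IDs
targets = torch.randint(0, 2, (200,)).float()   # 0 = not disaster, 1 = disaster

# Stratified 80/20 split on row indices
idx = np.arange(len(targets))
train_idx, val_idx = train_test_split(
    idx, test_size=0.2, random_state=SEED, stratify=targets.numpy()
)
train_idx = torch.as_tensor(train_idx, dtype=torch.long)
val_idx = torch.as_tensor(val_idx, dtype=torch.long)

train_ds = TensorDataset(encoded[train_idx], targets[train_idx])
val_ds = TensorDataset(encoded[val_idx], targets[val_idx])

BATCH_SIZE = 32
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)   # reshuffled every epoch
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)      # fixed order for evaluation

print(f"device={device}, train={len(train_ds)}, val={len(val_ds)}")
```

With the real data, keep whatever `Dataset` notebook 02 produced and reuse only the split and `DataLoader` settings.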


In [None]:
# TODO 1: Your code here
# Setup and imports

# Standard libraries


# PyTorch


# Metrics and visualization


# Copy/import your model class from notebook 03


# Load vocabulary (from notebook 02)


# Create train/val split


# Create DataLoaders


# Initialize model, loss, optimizer


# Print summary


## TODO 2: Implement Training Function for One Epoch 🔄

**Goal**: Write a function that trains the model for one complete epoch.

**Function Structure**:
```python
def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    model.train()  # Set to training mode
    total_loss = 0.0
    correct = 0
    total = 0

    for texts, labels in dataloader:
        # Move to device
        # Zero gradients
        # Forward pass
        # Calculate loss
        # Backward pass
        # Optimizer step
        # Track metrics
        ...

    avg_loss = total_loss / len(dataloader)  # average loss per batch
    accuracy = correct / total               # fraction of correct predictions
    return avg_loss, accuracy
```

**Key Steps in Loop**:
1. `model.train()` - Enable dropout and switch batch norm layers to training mode
2. Get texts and labels from batch
3. Move tensors to device
4. Zero gradients: `optimizer.zero_grad()`
5. Forward pass: `outputs = model(texts)`
6. Calculate loss: `loss = loss_fn(outputs, labels)`
7. Backward pass: `loss.backward()`
8. Update weights: `optimizer.step()`
9. Track loss and accuracy

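**Accuracy Calculation** (assumes the model outputs raw logits, as with `nn.BCEWithLogitsLoss`):
```python
# Convert logits to predictions (0.0 or 1.0) by thresholding the sigmoid probability at 0.5
predictions = (torch.sigmoid(outputs) > 0.5).float()
correct += (predictions == labels).sum().item()
```
`outputs` and `labels` must have the same shape. If the model returns `(batch, 1)` while `labels` is `(batch,)`, flatten first with `outputs.reshape(-1)`; otherwise `predictions == labels` broadcasts to `(batch, batch)` and `correct` is wrong.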

**Hints**:
- Accumulate loss: `total_loss += loss.item()`
- Track number of samples: `total += labels.size(0)`
- Return average loss: `total_loss / len(dataloader)`
- Return accuracy: `correct / total`
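
For reference, here is one way the steps above could be filled in. It is a sketch that assumes each batch is a `(texts, labels)` pair and that the model returns logits for `nn.BCEWithLogitsLoss`; the toy `nn.Linear` model and random features exist only so the function can run in isolation.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Train model for one epoch; return (average batch loss, accuracy)."""
    model.train()  # dropout on, batch norm uses batch statistics
    total_loss, correct, total = 0.0, 0, 0

    for texts, labels in dataloader:
        texts, labels = texts.to(device), labels.to(device).float()

        optimizer.zero_grad()                  # clear gradients from the previous batch
        outputs = model(texts).reshape(-1)     # logits, flattened to shape (batch,)
        loss = loss_fn(outputs, labels)
        loss.backward()                        # compute gradients
        optimizer.step()                       # update weights

        total_loss += loss.item()
        predictions = (torch.sigmoid(outputs) > 0.5).float()
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(dataloader), correct / total

# Usage on synthetic data: the label is 1 when the second feature is positive
torch.manual_seed(0)
X = torch.randn(100, 4)
y = (X[:, 1] > 0).float()
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
for epoch in range(5):
    train_loss, train_acc = train_one_epoch(model, loader, nn.BCEWithLogitsLoss(), optimizer, torch.device("cpu"))
print(f"epoch 5: loss={train_loss:.4f}, acc={train_acc:.4f}")
```

Flattening with `reshape(-1)` handles models that return either `(batch,)` or `(batch, 1)`, including a final batch of size 1.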


In [None]:
# TODO 2: Your code here
def train_one_epoch(model, dataloader, loss_fn, optimizer, device):
    """Train model for one epoch."""
    # TODO: Set model to training mode
    pass
    
    # TODO: Initialize tracking variables
    
    # TODO: Loop through batches
    # for texts, labels in dataloader:
        # TODO: Move to device
        # TODO: Zero gradients
        # TODO: Forward pass
        # TODO: Calculate loss
        # TODO: Backward pass
        # TODO: Optimizer step
        # TODO: Track metrics
        
    # TODO: Calculate and return average metrics
    # return avg_loss, accuracy


## TODO 3: Implement Validation Function 📊

**Goal**: Evaluate model on validation set (no gradient updates).

**Key Difference from Training**:
- `model.eval()` - Disable dropout and make batch norm use its running statistics instead of per-batch statistics
- `torch.no_grad()` - Don't track gradients (saves memory, faster)
- No `backward()` or `optimizer.step()`

**Function Structure**:
```python
def validate(model, dataloader, loss_fn, device):
    model.eval()  # Set to evaluation mode
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():  # Don't track gradients
        for texts, labels in dataloader:
            # Move to device
            # Forward pass only
            # Calculate loss and accuracy
            ...

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy
```

**Why validation?**:
- Check if model generalizes to unseen data
- Detect overfitting (train acc high, val acc low)
- Decide when to stop training
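
A matching validation sketch under the same assumptions (`(texts, labels)` batches, logit outputs). The usage lines check the key property from the list above: evaluation leaves the weights untouched.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def validate(model, dataloader, loss_fn, device):
    """Evaluate without updating weights; return (average batch loss, accuracy)."""
    model.eval()  # dropout off, batch norm uses running statistics
    total_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():  # no autograd graph: less memory, faster
        for texts, labels in dataloader:
            texts, labels = texts.to(device), labels.to(device).float()
            outputs = model(texts).reshape(-1)   # logits, shape (batch,)
            total_loss += loss_fn(outputs, labels).item()
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(dataloader), correct / total

# Usage on synthetic data with a stand-in model
torch.manual_seed(0)
X = torch.randn(50, 4)
y = (X[:, 0] > 0).float()
val_loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=False)
model = nn.Linear(4, 1)
before = [p.clone() for p in model.parameters()]
val_loss, val_acc = validate(model, val_loader, nn.BCEWithLogitsLoss(), torch.device("cpu"))
unchanged = all(torch.equal(a, b) for a, b in zip(before, model.parameters()))
print(f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, weights unchanged: {unchanged}")
```

Remember to call `model.train()` again before the next training epoch; the `train_one_epoch` skeleton above already does this at its start.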

---

## TODO 4: Implement Complete Training Loop 🔁

**Goal**: Train for multiple epochs with validation and progress tracking.

**Structure**:
```python
def train_model(model, train_loader, val_loader, loss_fn, optimizer,
                num_epochs, device):
    history = {'train_loss': [], 'train_acc': [],
               'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        # Train for one epoch (updates weights)
        train_loss, train_acc = train_one_epoch(model, train_loader, loss_fn, optimizer, device)

        # Validate (no weight updates)
        val_loss, val_acc = validate(model, val_loader, loss_fn, device)

        # Store metrics for plotting later
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Print progress
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}')
        print(f'  Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}')

        # Optional: Save best model, early stopping

    return history
```

**Number of Epochs**: Start with 10-20, adjust based on convergence
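
The outline leaves best-model saving and early stopping as optional. Here is a minimal sketch of both, using a hypothetical `EarlyStopping` helper that watches validation loss and keeps an in-memory copy of the best weights; the patience value and the simulated losses are illustrative.

```python
import copy

import torch.nn as nn

class EarlyStopping:
    """Hypothetical helper: stop when val loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience        # epochs to wait after the last improvement
        self.min_delta = min_delta      # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.best_state = None
        self.counter = 0

    def step(self, val_loss, model):
        """Record this epoch's val loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())  # snapshot of the best weights
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Simulated validation losses: no improvement after epoch 3
model = nn.Linear(4, 1)  # stand-in model
stopper = EarlyStopping(patience=2)
val_losses = [0.70, 0.55, 0.48, 0.50, 0.49, 0.47]
for epoch, val_loss in enumerate(val_losses, start=1):
    if stopper.step(val_loss, model):
        print(f"Early stopping at epoch {epoch}; best val loss {stopper.best_loss:.2f}")
        break
```

Inside `train_model`, one would create the stopper before the epoch loop, call `if stopper.step(val_loss, model): break` right after `validate(...)`, and finish with `model.load_state_dict(stopper.best_state)`. Monitoring validation F1 instead of loss is an equally reasonable choice for this task; flip the comparison so that higher is better.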

---

## TODO 5: Train Your Model 🚀

Execute the training! Monitor for:
- **Loss decreasing** - Model is learning
- **Accuracy increasing** - Predictions improving  
- **Val metrics similar to train** - Good generalization
- **Val metrics much worse than train** - Overfitting! Stop or regularize
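
When reading the returned `history`, the signs listed above can also be checked programmatically. This is a small sketch with a hypothetical `summarize_history` helper; the example numbers are made up to show an overfitting pattern, not results from this model.

```python
def summarize_history(history):
    """Return (best epoch by val loss, train-val accuracy gap there, overfitting flag).

    Hypothetical helper: "overfitting" here means the best val loss came before the
    last epoch while train loss kept falling afterwards.
    """
    val_losses = history['val_loss']
    best = min(range(len(val_losses)), key=val_losses.__getitem__)  # index of the lowest val loss
    gap = history['train_acc'][best] - history['val_acc'][best]
    overfitting = best < len(val_losses) - 1 and history['train_loss'][-1] < history['train_loss'][best]
    return best + 1, gap, overfitting  # epochs reported 1-based, like the training printout

# Illustrative (made-up) history: train loss keeps dropping, val loss turns up after epoch 3
example_history = {
    'train_loss': [0.60, 0.45, 0.35, 0.25, 0.18],
    'train_acc':  [0.68, 0.79, 0.85, 0.90, 0.94],
    'val_loss':   [0.55, 0.47, 0.45, 0.48, 0.53],
    'val_acc':    [0.74, 0.79, 0.80, 0.80, 0.79],
}
best_epoch, gap, overfitting = summarize_history(example_history)
print(f"best epoch: {best_epoch}, acc gap there: {gap:.2f}, overfitting: {overfitting}")
```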

---

## TODO 6: Evaluate and Calculate Metrics 📈

Calculate detailed classification metrics:
- **Accuracy**: Overall correctness
- **Precision**: Of predicted disasters, how many are real?
- **Recall**: Of real disasters, how many did we catch?
- **F1-Score**: Harmonic mean of precision and recall (competition metric!)
- **Confusion Matrix**: Visualize true/false positives/negatives

```python
from sklearn.metrics import classification_report, confusion_matrix
# Get predictions on validation set
# Calculate metrics
```
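
Here is a sketch of the "get predictions, then calculate metrics" step. `get_predictions` is a hypothetical helper that assumes logit outputs and a 0.5 threshold; the metric calls are standard scikit-learn functions, shown on a small hand-made example.

```python
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from torch.utils.data import DataLoader, TensorDataset

def get_predictions(model, dataloader, device, threshold=0.5):
    """Hypothetical helper: collect true labels and 0/1 predictions for a whole DataLoader."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for texts, labels in dataloader:
            logits = model(texts.to(device)).reshape(-1)
            preds = (torch.sigmoid(logits) > threshold).long()  # probability above threshold -> 1
            y_true.extend(labels.long().tolist())
            y_pred.extend(preds.cpu().tolist())
    return np.array(y_true), np.array(y_pred)

# Smoke test of the helper with a stand-in model and random data
toy_loader = DataLoader(TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,))), batch_size=4)
toy_true, toy_pred = get_predictions(torch.nn.Linear(3, 1), toy_loader, torch.device("cpu"))

# Metrics on a small hand-made example; with real data use
# y_true, y_pred = get_predictions(model, val_loader, device)
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])
print(classification_report(y_true, y_pred, target_names=['not disaster', 'disaster']))
print(confusion_matrix(y_true, y_pred))  # rows = true class, columns = predicted class
print(f"F1 (disaster class): {f1_score(y_true, y_pred):.3f}")  # 0.800
```

In the example, precision is 1.0 (no false positives) and recall is 2/3 (one missed disaster), so F1 = 2 · 1.0 · (2/3) / (1.0 + 2/3) = 0.8.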

---

## TODO 7: Visualize Training Progress 📊

Create plots to understand training:
1. **Loss curves**: Train vs. Val loss over epochs
2. **Accuracy curves**: Train vs. Val accuracy over epochs
3. **Confusion matrix**: Heatmap of predictions
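
Here is one way to draw all three plots with matplotlib alone, assuming `history` has the four keys built by `train_model`. `plot_training` is a hypothetical helper, and the example values are illustrative only.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_training(history, y_true, y_pred):
    """Hypothetical helper: loss curves, accuracy curves and a confusion-matrix heatmap."""
    epochs = range(1, len(history['train_loss']) + 1)
    fig, (ax_loss, ax_acc, ax_cm) = plt.subplots(1, 3, figsize=(15, 4))

    ax_loss.plot(epochs, history['train_loss'], marker='o', label='train')
    ax_loss.plot(epochs, history['val_loss'], marker='o', label='val')
    ax_loss.set(title='Loss', xlabel='epoch', ylabel='loss')
    ax_loss.legend()

    ax_acc.plot(epochs, history['train_acc'], marker='o', label='train')
    ax_acc.plot(epochs, history['val_acc'], marker='o', label='val')
    ax_acc.set(title='Accuracy', xlabel='epoch', ylabel='accuracy')
    ax_acc.legend()

    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    ax_cm.imshow(cm, cmap='Blues')
    for i in range(2):
        for j in range(2):
            ax_cm.text(j, i, str(cm[i, j]), ha='center', va='center')  # count in each cell
    ax_cm.set(title='Confusion matrix', xlabel='predicted', ylabel='true',
              xticks=[0, 1], yticks=[0, 1])
    ax_cm.set_xticklabels(['not disaster', 'disaster'])
    ax_cm.set_yticklabels(['not disaster', 'disaster'])

    fig.tight_layout()
    return fig

# Illustrative values only
example_history = {
    'train_loss': [0.60, 0.45, 0.35], 'val_loss': [0.55, 0.47, 0.46],
    'train_acc': [0.68, 0.79, 0.85], 'val_acc': [0.74, 0.79, 0.80],
}
fig = plot_training(example_history, y_true=[0, 1, 1, 0, 1], y_pred=[0, 1, 0, 0, 1])
plt.show()
```

Train and validation curves that diverge (train loss still falling, val loss rising) are the visual signature of overfitting described in TODO 5.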

---

## TODO 8: Save Your Trained Model 💾

Save the model for inference and submission:
```python
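import os
import torch

os.makedirs('../models', exist_ok=True)  # torch.save fails if the directory doesn't exist
torch.save(model.state_dict(), '../models/disaster_classifier.pth')  # weights only, not the class
torch.save(history, '../models/training_history.pkl')  # torch.save pickles the dict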
```
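
Saving is only half of it: for inference, the architecture must be rebuilt and the state dict loaded into it. Here is a sketch of the round trip, using `nn.Linear` as a stand-in for `DisasterTweetClassifier` and a temporary directory instead of `../models` so it runs anywhere.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-ins: the real model must be re-created with the same constructor arguments
model = nn.Linear(4, 1)
history = {'train_loss': [0.6, 0.4], 'val_loss': [0.55, 0.5]}

save_dir = tempfile.mkdtemp()  # in the notebook: '../models'
weights_path = os.path.join(save_dir, 'disaster_classifier.pth')
history_path = os.path.join(save_dir, 'training_history.pkl')
torch.save(model.state_dict(), weights_path)
torch.save(history, history_path)

# Later (e.g. in an inference notebook): rebuild the architecture, then load the weights into it
restored = nn.Linear(4, 1)
restored.load_state_dict(torch.load(weights_path, map_location='cpu'))
restored.eval()  # switch to inference mode before predicting
restored_history = torch.load(history_path)
```

`map_location='cpu'` lets weights saved on a GPU machine load on a CPU-only one.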
