# Day 3: Training Pipeline & Loss Functions

**Goals:**
- Test and verify your loss function implementations
- Audit your training loop for best practices
- Implement training diagnostics

**Time:** 6 hours

**Approach:** Instructions only. Write all code yourself.

---

## Setup

Import PyTorch, torch.nn, torch.nn.functional (as F), numpy, and matplotlib. Also import your loss modules and trainer from `src/`.

In [None]:
# Your imports


---

# Part 1: Theory Questions

---

## Q3.1: SSIM Implementation Details

The SSIM formula includes stability constants C1 and C2:

```
SSIM = [(2μ_xμ_y + C1)(2σ_xy + C2)] / [(μ_x² + μ_y² + C1)(σ_x² + σ_y² + C2)]
```

**a)** What numerical problem do C1 and C2 prevent? What would happen if you set them to zero?

**b)** SSIM uses a Gaussian-weighted window for computing local statistics. Why Gaussian instead of a uniform (box) filter?

**c)** What would happen if you set window_size=1? What would SSIM measure then?

### Your Answer:

**a)**

**b)**

**c)**


## Q3.2: Loss Balancing

Your combined loss is: L = α × MSE + β × SSIM_loss

Suppose during training you observe:
- MSE ≈ 0.01
- SSIM_loss (which is 1 - SSIM) ≈ 0.15

You're using α = 1.0 and β = 0.1.

**a)** What is the total loss?

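**b)** What fraction of the total loss comes from MSE vs SSIM? Which term contributes more to the loss value, and does that alone tell you which term dominates the gradient updates?

**c)** If you wanted SSIM to contribute as much to the total loss as MSE does, how would you adjust the weights?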

### Your Answer:

**a)**

**b)**

**c)**


## Q3.3: Optimizer Choice

**a)** Why is Adam generally preferred over vanilla SGD for training autoencoders?

**b)** Adam has parameters (β1, β2). The default β1 is 0.9. What does this parameter control? (Hint: it relates to Adam's running average of past gradients, the first-moment estimate.)

**c)** When would you use AdamW instead of Adam? What's the difference?

### Your Answer:

**a)**

**b)**

**c)**


## Q3.4: Training Dynamics

You're training your autoencoder and observe:
- Training loss decreases steadily over 50 epochs
- Validation loss decreases until epoch 30, then starts increasing

**a)** What phenomenon is occurring after epoch 30?

**b)** What strategies could you employ to address this? List at least 3.

### Your Answer:

**a)**

**b)**
1. 
2. 
3. 


## Q3.5: Gradient Clipping

**a)** Why might you clip gradients to max_norm=1.0 during training?

**b)** What training symptoms might indicate you need gradient clipping? (Think about loss curves and what happens when gradients explode.)

**c)** What's the potential downside of clipping gradients too aggressively (e.g., max_norm=0.01)?

### Your Answer:

**a)**

**b)**

**c)**


---

# Part 2: Loss Function Testing (1.5 hours)

---

## Exercise 3.1: Create Test Data

**Your task:**

Create three test tensors:
1. `img1`: Random tensor of shape (4, 1, 64, 64) with values in [0, 1]
2. `img2`: A noisy version of img1 (add Gaussian noise with std=0.1, then clamp to [0, 1])
3. `img1_copy`: An exact clone of img1

These will be used to test all loss functions.
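A minimal reference sketch to check your own cell against. The fixed seed is an arbitrary addition for reproducibility, not part of the exercise spec.

```python
import torch

torch.manual_seed(0)  # arbitrary seed so the numbers below are reproducible

# 1. Random images in [0, 1): batch of 4, single channel, 64x64
img1 = torch.rand(4, 1, 64, 64)

# 2. Noisy version: add Gaussian noise with std=0.1, then clamp back to [0, 1]
img2 = (img1 + 0.1 * torch.randn_like(img1)).clamp(0.0, 1.0)

# 3. Exact clone: separate memory, identical values
img1_copy = img1.clone()

print(img1.shape, img2.min().item(), img2.max().item())
```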

In [None]:
# Create test tensors


## Exercise 3.2: Test MSE Loss

**Your task:**

1. Compute MSE between img1 and img1_copy using PyTorch's `F.mse_loss`. This should be exactly 0.

2. Compute MSE between img1 and img2. This should be positive and slightly below 0.01 (= 0.1², the noise variance), because clamping to [0, 1] trims the noise for pixels near 0 and 1.

3. Now test YOUR MSE loss implementation from `src/losses/mse.py`. Instantiate your loss class and compute the same two comparisons.

4. Verify that your implementation matches `F.mse_loss` (e.g., with `torch.allclose`).
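The comparison could look like the sketch below. `manual_mse` is a hypothetical stand-in for your class in `src/losses/mse.py`; the test tensors are recreated so the cell runs on its own.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img1 = torch.rand(4, 1, 64, 64)
img2 = (img1 + 0.1 * torch.randn_like(img1)).clamp(0.0, 1.0)
img1_copy = img1.clone()


def manual_mse(x, y):
    """Hypothetical stand-in: squared error averaged over every element (reduction='mean')."""
    return ((x - y) ** 2).mean()


mse_same = F.mse_loss(img1, img1_copy)
mse_noisy = F.mse_loss(img1, img2)
print(f"identical: {mse_same.item():.6f}  noisy: {mse_noisy.item():.6f}")

# Your implementation should agree with PyTorch up to floating-point error
assert torch.allclose(manual_mse(img1, img2), mse_noisy)
assert torch.allclose(manual_mse(img1, img1_copy), mse_same)
```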

In [None]:
# Test MSE loss


## Exercise 3.3: Implement Reference SSIM

Before testing your SSIM, implement a reference version to understand the algorithm.

**Your task:**

Write a function `ssim_reference(img1, img2, window_size=11)` that:

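1. Creates a 1D Gaussian kernel with the given window_size and σ=1.5:
   ```python
   coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
   g = torch.exp(-coords ** 2 / (2 * 1.5 ** 2))
   g = g / g.sum()
   ```

2. Creates a 2D window by outer product: `window = g.unsqueeze(1) @ g.unsqueeze(0)`, giving shape `(window_size, window_size)` (`torch.outer(g, g)` is equivalent). The order matters: `g.unsqueeze(0) @ g.unsqueeze(1)` is an inner product and yields a 1×1 result.

3. Reshapes window to (1, 1, window_size, window_size) for use with F.conv2d.

4. Computes local means μ1, μ2 by convolving each image with the window (use `padding=window_size // 2` if the output should keep the input's spatial size).

5. Computes local variances σ1², σ2² and covariance σ12 using:
   - σ1² = conv(img1²) - μ1²
   - σ2² = conv(img2²) - μ2²
   - σ12 = conv(img1 × img2) - μ1 × μ2

6. Applies the SSIM formula with C1=(0.01)² and C2=(0.03)². These values assume a data range of L=1 (images in [0, 1]); in general C1=(0.01·L)² and C2=(0.03·L)².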

7. Returns the mean SSIM over all pixels.
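One way to write the reference, as a sketch under stated assumptions: single-channel input of shape (B, 1, H, W), images in [0, 1], and zero padding at the borders (so edge values differ slightly from libraries that handle borders differently).

```python
import torch
import torch.nn.functional as F


def gaussian_window(window_size: int = 11, sigma: float = 1.5) -> torch.Tensor:
    """2D Gaussian window of shape (1, 1, window_size, window_size) that sums to 1."""
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    window = torch.outer(g, g)  # (W,) x (W,) -> (W, W)
    return window.view(1, 1, window_size, window_size)


def ssim_reference(img1, img2, window_size=11):
    """Mean SSIM over all pixels; assumes 1 channel and data range L = 1."""
    window = gaussian_window(window_size).to(device=img1.device, dtype=img1.dtype)
    pad = window_size // 2

    def filt(x):
        # Gaussian-weighted local average; zero padding keeps H x W
        return F.conv2d(x, window, padding=pad)

    mu1, mu2 = filt(img1), filt(img2)
    mu1_sq, mu2_sq, mu12 = mu1 * mu1, mu2 * mu2, mu1 * mu2
    sigma1_sq = filt(img1 * img1) - mu1_sq
    sigma2_sq = filt(img2 * img2) - mu2_sq
    sigma12 = filt(img1 * img2) - mu12

    C1, C2 = 0.01 ** 2, 0.03 ** 2
    num = (2 * mu12 + C1) * (2 * sigma12 + C2)
    den = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (num / den).mean()


torch.manual_seed(0)
img1 = torch.rand(4, 1, 64, 64)
img2 = (img1 + 0.1 * torch.randn_like(img1)).clamp(0.0, 1.0)
print(ssim_reference(img1, img1.clone()).item(), ssim_reference(img1, img2).item())
```

Using `window_size // 2` as padding mirrors step 4; dropping the padding ("valid" convolution) is an equally reasonable choice that avoids border effects at the cost of a smaller SSIM map.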

In [None]:
# Implement reference SSIM


## Exercise 3.4: Test SSIM Loss

**Your task:**

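1. Use your reference SSIM to compute:
   - SSIM(img1, img1_copy): should be 1.0 (up to floating-point error)
   - SSIM(img1, img2): should be less than 1.0

2. Test YOUR SSIM loss implementation from `src/losses/ssim.py`.

3. Note: Your loss most likely returns 1 - SSIM, since training minimizes the loss. Verify this convention: the loss should be ≈0 for identical images.

4. If scikit-image is available, validate against `skimage.metrics.structural_similarity`. To match the reference settings, pass `data_range=1.0`, `gaussian_weights=True`, `sigma=1.5`, and `use_sample_covariance=False`, and compare one 2D image at a time (e.g., `img1[0, 0].numpy()`). Expect small differences, because scikit-image handles image borders differently from zero-padded convolution.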

In [None]:
# Test SSIM loss


## Exercise 3.5: Test Combined Loss

**Your task:**

1. Import or instantiate your combined loss from `src/losses/combined.py`.

2. Compute the combined loss between img1 and img2.

3. Manually compute what the loss should be by combining MSE and SSIM with the weights your implementation uses.

4. Verify they match.
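A sketch of the manual check, assuming your combined loss has the form α × MSE + β × (1 - SSIM). `CombinedLoss` and the weights α=1.0, β=0.1 (taken from Q3.2) are hypothetical stand-ins for whatever `src/losses/combined.py` actually uses; a compact SSIM is repeated here so the cell is self-contained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def ssim(img1, img2, window_size=11, sigma=1.5):
    """Compact Gaussian-window SSIM for (B, 1, H, W) images in [0, 1]."""
    coords = torch.arange(window_size, dtype=img1.dtype, device=img1.device) - window_size // 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    w = torch.outer(g, g).view(1, 1, window_size, window_size)

    def filt(x):
        return F.conv2d(x, w, padding=window_size // 2)

    mu1, mu2 = filt(img1), filt(img2)
    s11 = filt(img1 * img1) - mu1 * mu1
    s22 = filt(img2 * img2) - mu2 * mu2
    s12 = filt(img1 * img2) - mu1 * mu2
    C1, C2 = 0.01 ** 2, 0.03 ** 2
    num = (2 * mu1 * mu2 + C1) * (2 * s12 + C2)
    den = (mu1 * mu1 + mu2 * mu2 + C1) * (s11 + s22 + C2)
    return (num / den).mean()


class CombinedLoss(nn.Module):
    """Hypothetical combined loss: alpha * MSE + beta * (1 - SSIM)."""

    def __init__(self, alpha=1.0, beta=0.1):
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, pred, target):
        return self.alpha * F.mse_loss(pred, target) + self.beta * (1 - ssim(pred, target))


torch.manual_seed(0)
img1 = torch.rand(4, 1, 64, 64)
img2 = (img1 + 0.1 * torch.randn_like(img1)).clamp(0.0, 1.0)

loss_fn = CombinedLoss(alpha=1.0, beta=0.1)
combined = loss_fn(img1, img2)

# Recompute by hand from the individual terms and the same weights
manual = 1.0 * F.mse_loss(img1, img2) + 0.1 * (1 - ssim(img1, img2))
print(f"combined={combined.item():.6f}  manual={manual.item():.6f}")
assert torch.allclose(combined, manual)
```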

In [None]:
# Test combined loss


## Exercise 3.6: Test Gradient Flow Through Loss

**Your task:**

1. Recreate img1 so that it tracks gradients, e.g., `img1 = img1.detach().clone().requires_grad_(True)`.

2. Compute your combined loss between img1 and img2 (which doesn't need gradients).

3. Call `loss.backward()`.

4. Verify that `img1.grad` is not None.

5. Check that the gradient has the same shape as img1.

6. Check that the gradient contains no NaN or Inf values.
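The checks can be bundled into a small helper so you can rerun them for every loss. `check_gradient_flow` is a hypothetical name, and `F.mse_loss` is only a placeholder: pass your combined loss in its place.

```python
import torch
import torch.nn.functional as F


def check_gradient_flow(loss_fn, pred, target):
    """Backpropagate through loss_fn and return basic gradient health checks."""
    pred = pred.detach().clone().requires_grad_(True)  # fresh leaf tensor that tracks gradients
    loss = loss_fn(pred, target)  # target does not need gradients
    loss.backward()
    grad = pred.grad
    return {
        "has_grad": grad is not None,
        "shape_ok": grad is not None and grad.shape == pred.shape,
        "finite": grad is not None and bool(torch.isfinite(grad).all()),
        "grad_norm": grad.norm().item() if grad is not None else float("nan"),
    }


torch.manual_seed(0)
img1 = torch.rand(4, 1, 64, 64)
img2 = (img1 + 0.1 * torch.randn_like(img1)).clamp(0.0, 1.0)

report = check_gradient_flow(F.mse_loss, img1, img2)  # swap in your combined loss here
print(report)
```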

In [None]:
# Test gradient flow


---

# Part 3: Trainer Audit (1.5 hours)

---

## Exercise 3.7: Trainer Checklist

Open your `src/training/trainer.py` file and carefully read through it.

**Your task:**

Check each item and note whether it's present in your implementation. If something is missing, you'll implement it later.

### Training Loop Essentials
- [ ] `model.train()` called before training epoch
- [ ] `optimizer.zero_grad()` called before each forward pass
- [ ] `loss.backward()` called after computing loss
- [ ] `optimizer.step()` called after backward
- [ ] Gradient clipping with `torch.nn.utils.clip_grad_norm_`, applied after `loss.backward()` and before `optimizer.step()`

### Validation
- [ ] `model.eval()` called before validation
- [ ] `torch.no_grad()` context used during validation
- [ ] No `optimizer.step()` during validation

### Logging
- [ ] Training loss recorded per epoch
- [ ] Validation loss recorded per epoch
- [ ] Learning rate tracked
- [ ] Progress indication (print statements or progress bar)

### Checkpointing
- [ ] Best model saved when validation loss improves
- [ ] Checkpoint includes model state dict
- [ ] Checkpoint includes optimizer state dict
- [ ] Checkpoint includes epoch number
- [ ] Checkpoint includes best loss value

### Scheduling & Stopping
- [ ] Learning rate scheduler implemented
- [ ] `scheduler.step()` called at the appropriate time (for `ReduceLROnPlateau`: once per epoch, after validation, with the validation loss)
- [ ] Early stopping implemented
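For comparison during the audit, here is a minimal sketch of one training epoch and one validation pass that covers the loop and validation items above. It assumes an autoencoder (the input is also the target) and a loader built from a single-tensor `TensorDataset`; `train_one_epoch` and `validate` are hypothetical names, and your trainer's data format may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, loss_fn, optimizer, device, max_norm=1.0):
    model.train()                       # enable dropout / batch-norm updates
    total = 0.0
    for (x,) in loader:                 # each batch is a one-element list [x]
        x = x.to(device)
        optimizer.zero_grad()           # clear gradients from the previous step
        loss = loss_fn(model(x), x)     # autoencoder: reconstruct the input
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()
        total += loss.item() * x.size(0)
    return total / len(loader.dataset)  # mean loss per sample


@torch.no_grad()                        # no autograd graph during validation
def validate(model, loader, loss_fn, device):
    model.eval()                        # disable dropout, use running BN stats
    total = 0.0
    for (x,) in loader:
        x = x.to(device)
        total += loss_fn(model(x), x).item() * x.size(0)
    return total / len(loader.dataset)


torch.manual_seed(0)
loader = DataLoader(TensorDataset(torch.rand(16, 1, 32, 32)), batch_size=4)
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_loss = train_one_epoch(model, loader, F.mse_loss, optimizer, "cpu")
val_loss = validate(model, loader, F.mse_loss, "cpu")
print(f"train={train_loss:.4f}  val={val_loss:.4f}")
```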

### Your Audit Results:

**Present:**

**Missing:**


## Exercise 3.8: Implement Training Diagnostics

Before running a full training, it's wise to run diagnostics to catch problems early.

**Your task:**

Write a function `diagnose_training(model, dataloader, loss_fn, device)` that:

1. **Input Statistics**: Get one batch, print shape, min, max, mean.

2. **Output Statistics**: Pass batch through model, print output range and mean.

3. **Gradient Statistics**: 
   - Compute loss on one batch
   - Call backward()
   - For each parameter, compute gradient norm
   - Print total gradient norm (sqrt of sum of squared norms)
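   - Flag any parameters whose gradient is None or entirely zero

4. **Overfit Test**:
   - Try to overfit a single batch for 100 iterations
   - Use a fresh optimizer with lr=1e-3, and run the test on a copy of the model (e.g., `copy.deepcopy(model)`) so the diagnostic doesn't alter the weights you will actually train
   - Record the loss before the first update and after the last one
   - If the loss does not drop by more than 50%, something is likely wrong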

Return a dict with all the diagnostic info.
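A sketch of how the four checks might fit together, assuming an autoencoder (the target is the input) and a loader whose batches are either a tensor or a list/tuple with the images first. The dictionary keys, the `overfit_iters` and `lr` parameters, and the toy model at the bottom are illustrative choices, not a required interface.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def diagnose_training(model, dataloader, loss_fn, device, overfit_iters=100, lr=1e-3):
    """Pre-training sanity checks for an autoencoder."""
    info = {}
    batch = next(iter(dataloader))
    x = (batch[0] if isinstance(batch, (list, tuple)) else batch).to(device)
    model = model.to(device)

    # 1. Input statistics
    info["input"] = {"shape": tuple(x.shape), "min": x.min().item(),
                     "max": x.max().item(), "mean": x.mean().item()}

    # 2. Output statistics (forward pass only)
    model.eval()
    with torch.no_grad():
        out = model(x)
    info["output"] = {"min": out.min().item(), "max": out.max().item(),
                      "mean": out.mean().item()}

    # 3. Gradient statistics from one backward pass
    model.train()
    model.zero_grad(set_to_none=True)
    loss_fn(model(x), x).backward()
    grad_norms, zero_grad = {}, []
    for name, p in model.named_parameters():
        if p.grad is None or not p.grad.any():
            zero_grad.append(name)      # unused layer or dead gradient path
        else:
            grad_norms[name] = p.grad.norm().item()
    info["grad_norms"] = grad_norms
    info["total_grad_norm"] = sum(n ** 2 for n in grad_norms.values()) ** 0.5
    info["zero_grad_params"] = zero_grad
    model.zero_grad(set_to_none=True)   # leave no stale gradients behind

    # 4. Overfit one batch on a copy so the caller's model is untouched
    probe = copy.deepcopy(model)
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    start = None
    for _ in range(overfit_iters):
        opt.zero_grad()
        loss = loss_fn(probe(x), x)
        if start is None:
            start = loss.item()         # loss before any update
        loss.backward()
        opt.step()
    with torch.no_grad():
        end = loss_fn(probe(x), x).item()  # loss after the last update
    info["overfit"] = {"start": start, "end": end, "passed": end < 0.5 * start}

    print(f"input  {info['input']}\noutput {info['output']}")
    print(f"total grad norm {info['total_grad_norm']:.4f}; zero-grad params: {zero_grad}")
    print(f"overfit {start:.4f} -> {end:.4f} ({'PASS' if info['overfit']['passed'] else 'FAIL'})")
    return info


# Toy usage with a placeholder conv model and random data
torch.manual_seed(0)
loader = DataLoader(TensorDataset(torch.rand(8, 1, 32, 32)), batch_size=4)
toy = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
before = [p.detach().clone() for p in toy.parameters()]
report = diagnose_training(toy, loader, F.mse_loss, "cpu")
```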

In [None]:
# Implement diagnose_training function


## Exercise 3.9: Test Diagnostics

**Your task:**

1. Create a simple DataLoader with synthetic data (or use your actual data if available):
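   - Create a TensorDataset from a random tensor of shape (N, 1, 256, 256) with values in [0, 1] (e.g., `torch.rand`)
   - Wrap in DataLoader with batch_size=4 (each batch is then a one-element list `[x]`, so unpack it before passing x to the model)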

2. Instantiate your model and loss function.

3. Run your diagnostic function and interpret the results.

4. Did the overfit test pass? If not, investigate why.
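Step 1 might look like the sketch below; N=32 is an arbitrary choice. The printout shows how a single-tensor `TensorDataset` is batched, which is why the images are taken from index 0.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
N = 32                                   # arbitrary number of synthetic images
images = torch.rand(N, 1, 256, 256)      # values in [0, 1), like normalized grayscale
dataset = TensorDataset(images)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
x = batch[0]                             # the images are the first (and only) element
print(type(batch).__name__, len(batch), x.shape)
```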

In [None]:
# Test diagnostics on your model


---

# Part 4: Implement Missing Training Features (1.5 hours)

Based on your audit, implement any missing features.

---

## Exercise 3.10: Implement Gradient Clipping

If your trainer doesn't have gradient clipping:

**Your task:**

Show where in the training loop you would add:
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

Write a small test demonstrating that it works:
1. Create a tensor with requires_grad=True
2. Create a "loss" that would produce large gradients (e.g., `loss = 1000 * tensor.sum()`)
3. Call backward
4. Print gradient norm before clipping
5. Apply clipping
6. Print gradient norm after clipping
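A minimal version of this test, assuming a 10-element tensor of ones. Every gradient entry is 1000, so the norm before clipping is 1000 × √10 ≈ 3162.28. Note that `clip_grad_norm_` rescales the gradients in place and returns the total norm measured *before* clipping.

```python
import torch

w = torch.ones(10, requires_grad=True)
loss = 1000 * w.sum()                   # d(loss)/dw = 1000 for every element
loss.backward()

norm_before = w.grad.norm().item()
returned = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)  # in-place rescale
norm_after = w.grad.norm().item()

print(f"before: {norm_before:.2f}  returned: {returned.item():.2f}  after: {norm_after:.4f}")
```

Clipping only rescales: every gradient entry is multiplied by the same factor, so the update direction is unchanged and only its length shrinks.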

In [None]:
# Demonstrate gradient clipping


## Exercise 3.11: Implement Learning Rate Scheduler

If your trainer doesn't have LR scheduling:

**Your task:**

1. Choose an appropriate scheduler. For autoencoders, `ReduceLROnPlateau` works well:
   - Reduces LR when validation loss plateaus
   - Patience of 5-10 epochs is typical

2. Show how to create the scheduler:
   ```python
   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
       optimizer, mode='min', factor=0.5, patience=10
   )
   ```

3. Show where in the training loop to call `scheduler.step(val_loss)`.

4. Write code to track and print the current learning rate.
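A sketch of where the scheduler sits and how to read the current LR, using a placeholder `nn.Linear` model and a made-up sequence of validation losses (5 improving epochs, then a plateau). `current_lr` is a hypothetical helper that reads the LR the optimizer is actually using.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # placeholder: only its parameters matter here
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10
)


def current_lr(opt):
    """LR of the first parameter group, as used by the next optimizer.step()."""
    return opt.param_groups[0]["lr"]


# Made-up validation losses: improve for 5 epochs, then plateau
val_losses = [1.0, 0.8, 0.6, 0.5, 0.45] + [0.45] * 20
lrs = []
for epoch, val_loss in enumerate(val_losses):
    # ... train_one_epoch(...) and validate(...) would run here ...
    scheduler.step(val_loss)            # once per epoch, after validation
    lrs.append(current_lr(optimizer))
    print(f"epoch {epoch:2d}  val_loss={val_loss:.3f}  lr={lrs[-1]:.2e}")
```

With `patience=10`, the LR is halved only once more than 10 consecutive epochs have passed without improvement, here at epoch 15 (the 11th non-improving epoch).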

In [None]:
# Demonstrate LR scheduler


## Exercise 3.12: Implement Early Stopping

If your trainer doesn't have early stopping:

**Your task:**

Implement an EarlyStopping class that:

1. Tracks the best validation loss seen so far.

2. Counts epochs since last improvement.

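3. Has a `patience` parameter (e.g., 20 epochs). Keep it larger than the LR scheduler's patience so the learning rate has a chance to drop before training stops.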

4. Has a `__call__(val_loss)` method that:
   - Updates best loss if this is an improvement
   - Resets counter on improvement
   - Increments counter if no improvement
   - Returns True if training should stop (counter >= patience)

5. Optionally: saves the best model when improvement happens.
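One possible shape for the class, as a sketch. The `min_delta` parameter (the minimum decrease that counts as an improvement) is an addition beyond the spec above, and keeping the best weights in memory via `state_dict()` is one of several options for step 5; saving to disk works just as well.

```python
import copy


class EarlyStopping:
    """Signal a stop when validation loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=20, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta       # assumption: extra knob, 0.0 = any decrease counts
        self.best_loss = float("inf")
        self.counter = 0
        self.best_state = None

    def __call__(self, val_loss, model=None):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0             # improvement: reset the counter
            if model is not None:
                # optional: keep an in-memory copy of the best weights
                self.best_state = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1            # no improvement this epoch
        return self.counter >= self.patience


stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.9, 0.95, 0.92, 0.91, 0.85]):
    if stopper(loss):
        stopped_at = epoch
        print(f"stopping at epoch {epoch}; best val loss {stopper.best_loss}")
        break
```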

In [None]:
# Implement EarlyStopping class


## Exercise 3.13: Implement Checkpointing

If your trainer doesn't save checkpoints properly:

**Your task:**

Write functions for saving and loading checkpoints:

```python
def save_checkpoint(path, model, optimizer, epoch, val_loss, config=None):
    # Save dict with all necessary info
    pass

def load_checkpoint(path, model, optimizer=None):
    # Load and restore state
    # Return epoch and val_loss so training can resume
    pass
```

Test by saving a checkpoint, creating fresh model/optimizer instances, and loading.
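A sketch of both functions plus the round-trip test, using a placeholder `nn.Linear` model and a temporary file. The dictionary keys and the extra `device` argument of `load_checkpoint` are illustrative choices. `torch.load` unpickles data, so only load checkpoints you created or trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch, val_loss, config=None):
    """Save everything needed to resume training or reload the weights."""
    torch.save({
        "epoch": epoch,
        "val_loss": val_loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "config": config,
    }, path)


def load_checkpoint(path, model, optimizer=None, device="cpu"):
    """Restore model (and optionally optimizer) state; return (epoch, val_loss)."""
    ckpt = torch.load(path, map_location=device)  # map_location lets GPU checkpoints load on CPU
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"], ckpt["val_loss"]


# Round-trip test
torch.manual_seed(0)
model = nn.Linear(8, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(4, 8)).sum().backward()
opt.step()  # populate Adam's moment buffers so the optimizer state is non-trivial

path = os.path.join(tempfile.mkdtemp(), "best.pt")
save_checkpoint(path, model, opt, epoch=7, val_loss=0.0123, config={"lr": 1e-3})

fresh_model = nn.Linear(8, 2)
fresh_opt = torch.optim.Adam(fresh_model.parameters(), lr=1e-3)
epoch, val_loss = load_checkpoint(path, fresh_model, fresh_opt)
print(epoch, val_loss)
```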

In [None]:
# Implement checkpointing functions


In [None]:
# Test save and load


---

# Part 5: Short Training Test

---

## Exercise 3.14: Run a Short Training

**Your task:**

Run a short training (5-10 epochs) to verify everything works together:

1. Load or create your training and validation data.

2. Instantiate model, optimizer, loss function, scheduler.

3. Run training for 5-10 epochs.

4. Plot training and validation loss curves.

5. Verify:
   - Losses decrease over time
   - No NaN or Inf losses
   - Validation loss follows the same downward trend as training loss, without diverging from it

If you encounter issues, debug them before Day 4.

In [None]:
# Short training test


In [None]:
# Plot loss curves


---

# Day 3 Checklist

- [ ] Answered all theory questions (Q3.1 - Q3.5)
- [ ] MSE loss tested and verified
- [ ] SSIM reference implementation written
- [ ] SSIM loss tested and verified
- [ ] Combined loss tested
- [ ] Gradient flow through loss verified
- [ ] Trainer audited against checklist
- [ ] Training diagnostics implemented and run
- [ ] Gradient clipping implemented/verified
- [ ] Learning rate scheduler implemented/verified
- [ ] Early stopping implemented/verified
- [ ] Checkpointing implemented/verified
- [ ] Short training test completed successfully

---

## Training Configuration Summary

*Fill in your current configuration:*

**Loss function:** MSE + SSIM (weights: α=__, β=__)

**Optimizer:** Adam (lr=__, weight_decay=__)

**Scheduler:** ReduceLROnPlateau (patience=__, factor=__)

**Gradient clipping:** max_norm=__

**Early stopping:** patience=__

---

## Notes and Issues

1. 

2. 

3. 