# Assignment 1: Loss Curve Visualization

In this assignment, you'll learn to visualize and interpret loss curves during neural network training.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Import all functions from loss_visualization module
from loss_visualization import (
    SimpleNet,
    generate_synthetic_data,
    create_data_loaders,
    train_with_loss_tracking,
    plot_loss_curves,
    compare_learning_rates,
    identify_training_issues
)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Step 1: Generate Synthetic Data

Create a simple regression dataset.

In [None]:
# Generate synthetic data using the function from loss_visualization.py
n_samples = 1000
n_features = 10

X, y = generate_synthetic_data(n_samples=n_samples, n_features=n_features, noise=0.1)

print(f"Data shape: X={X.shape}, y={y.shape}")

## Step 2: Create Train/Val Split and DataLoaders

In [None]:
# Create train/val data loaders using the function from loss_visualization.py
train_loader, val_loader = create_data_loaders(X, y, train_ratio=0.8, batch_size=32, shuffle=True)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

## Step 3: Define a Simple Model

In [None]:
# Create model using SimpleNet class from loss_visualization.py
model = SimpleNet(input_dim=n_features, hidden_dim=64, output_dim=1)
print(model)

## Step 4: Training Loop with Loss Tracking

In [None]:
# Train model using the training function from loss_visualization.py
import torch.nn as nn
import torch.optim as optim

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

num_epochs = 100
history = train_with_loss_tracking(model, train_loader, val_loader, 
                                  criterion, optimizer, num_epochs=num_epochs)

## Step 5: Visualize Loss Curves

In [None]:
# Plot loss curves using the function from loss_visualization.py
plot_loss_curves(history, title="Training and Validation Loss")

## Step 6: Experiment with Different Learning Rates

In [None]:
# Compare different learning rates using the function from loss_visualization.py
learning_rates = [0.001, 0.01, 0.1, 1.0]

compare_learning_rates(
    learning_rates=learning_rates,
    model_fn=lambda: SimpleNet(n_features, 64, 1),
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=100
)

# Also analyze the training behavior
diagnosis = identify_training_issues(history)
print(f"\nTraining Diagnosis: {diagnosis}")

## Questions to Answer

---

### 1. What happens when the learning rate is too high?

**Symptoms:**
- **Loss diverges or explodes** – Loss values increase dramatically or become NaN (Not a Number)
- **Unstable training** – Loss oscillates wildly without converging
- **Overshooting** – The optimizer repeatedly jumps over the optimal solution
- **Failure to converge** – Training does not settle toward a minimum

**Why this happens:**
When the learning rate is too high, weight updates are excessively large. Imagine trying to find the bottom of a valley while taking giant steps—you keep jumping from one side to the other and never settle at the bottom.
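
The valley picture can be made concrete with a minimal sketch, assuming the one-dimensional toy loss f(w) = w² (chosen only for illustration; it is not the assignment's model). Each gradient step multiplies w by (1 - 2·lr), so on this loss any learning rate above 1.0 makes |w| grow while its sign flips every step. The function name `gradient_descent` is hypothetical.

```python
def gradient_descent(lr, w0=1.0, steps=10):
    """Run plain gradient descent on the toy loss f(w) = w**2."""
    w = w0
    trajectory = [w]
    for _ in range(steps):
        grad = 2.0 * w      # derivative of w**2
        w = w - lr * grad   # standard gradient step
        trajectory.append(w)
    return trajectory

stable = gradient_descent(lr=0.1)    # each step scales w by 0.8
unstable = gradient_descent(lr=1.1)  # each step scales w by -1.2

print("lr=0.1 final |w|:", abs(stable[-1]))
print("lr=1.1 final |w|:", abs(unstable[-1]))
```

The second trajectory jumps from one side of the minimum to the other with growing magnitude, which is the overshooting described above.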

**Visual pattern on loss curve:**

```
Loss
 ^
 |  /\    /\    /\
 | /  \  /  \  /  \
 |/    \/    \/    \
 +-------------------> Epochs
   Chaotic, oscillating
```


> Note: This behavior is especially pronounced with plain SGD. Adaptive optimizers like Adam or RMSProp rescale each update and are often less sensitive to the exact learning rate, but they will still diverge if the rate is too large.

---

### 2. What happens when the learning rate is too low?

**Symptoms:**
- **Very slow convergence** – Loss decreases extremely slowly
- **Training takes a very long time** – May require 10× or 100× more epochs
- **Gets stuck in plateaus or saddle points** – Updates are too small to escape flat regions of the loss surface
- **Apparent premature plateau** – Loss appears to stop improving

**Why this happens:**
With a learning rate that’s too low, weight updates are tiny. It’s like trying to cross a field by taking baby steps—you’ll eventually move forward, but progress is painfully slow.
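
As a rough illustration of the cost of small steps, the sketch below counts how many gradient steps are needed to get close to the minimum of the toy loss f(w) = w². The loss, the tolerance, and the helper name `steps_to_converge` are assumptions made for this example.

```python
def steps_to_converge(lr, w0=1.0, tol=1e-3, max_steps=100_000):
    """Count gradient steps on f(w) = w**2 until |w| drops below tol."""
    w = w0
    for step in range(1, max_steps + 1):
        w = w - lr * 2.0 * w   # gradient of w**2 is 2*w
        if abs(w) < tol:
            return step
    return max_steps  # did not converge within the budget

fast = steps_to_converge(lr=0.1)
slow = steps_to_converge(lr=0.001)
print(f"lr=0.1 needs {fast} steps, lr=0.001 needs {slow} steps")
```

On this toy problem, shrinking the learning rate by a factor of 100 increases the step count by roughly the same factor, which is why a too-small rate can look like a plateau within a fixed epoch budget.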

**Visual pattern on loss curve:**

```
Loss
 ^
 |___
 |    \_____
 |          \________
 |                  \___
 +-------------------> Epochs
   Slow, gradual descent
```

---

### 3. How can you identify overfitting from the loss curves?

**Key indicators:**

**A. The Gap Pattern** (most common):
- Training loss continues to **decrease**
- Validation loss **stops decreasing**, then plateaus or starts **increasing**
- Growing gap between training and validation loss

**Visual pattern:**
```
Loss
 ^
 |        Validation Loss
 |       /‾‾‾‾‾‾‾‾‾
 |      /
 |-----/-------------- <- Overfitting starts here
 |    /
 |   /  Training Loss
 |  /    \_____
 | /           \_____
 |/                  \_____
 +-------------------------> Epochs
```

**B. Other signs:**
- **Large generalization gap**: `val_loss >> train_loss` (e.g., train=0.1, val=0.8)
- **Validation loss stagnates or increases**: While training loss is still going down
- **Perfect training performance** with poor validation performance

**What's happening:**
The model is memorizing the training data instead of learning generalizable patterns. It's like a student who memorizes answers to practice problems but can't solve new ones.
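
One way to read these indicators off recorded losses is sketched below. It assumes the per-epoch losses are available as two plain Python lists; the helper name `find_overfitting_onset` and the example numbers are made up for illustration and are not part of `loss_visualization.py`.

```python
def find_overfitting_onset(train_losses, val_losses):
    """Return (best_epoch, final_gap) for two equal-length loss lists.

    best_epoch is the 0-based epoch with the lowest validation loss;
    final_gap is val - train at the last epoch.
    """
    best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
    final_gap = val_losses[-1] - train_losses[-1]
    return best_epoch, final_gap

# Made-up curves: training keeps improving, validation turns upward.
train = [1.00, 0.60, 0.40, 0.30, 0.22, 0.15, 0.10]
val = [1.05, 0.70, 0.55, 0.50, 0.54, 0.62, 0.80]

best_epoch, gap = find_overfitting_onset(train, val)
print(f"Validation loss was lowest at epoch {best_epoch}; final gap = {gap:.2f}")
```

If the best validation epoch lies well before the last epoch and the final gap is large, the later epochs were probably spent memorizing the training set.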

---

### 4. What techniques could you use to prevent overfitting?

Here are the main techniques, grouped by category:

#### **A. Regularization Techniques**

**1. Dropout** (most popular)
```python
self.dropout = nn.Dropout(p=0.5)  # Drop 50% of neurons randomly
```
- Randomly "turns off" neurons during training
- Forces the network to learn robust features
- Applied only during training, not inference
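
A minimal sketch of where the dropout layer could sit in a small regressor, and how `train()` and `eval()` switch it on and off. `DropoutNet` is a hypothetical class written for this example; it is not the `SimpleNet` from `loss_visualization.py`, whose internals are not shown here.

```python
import torch
import torch.nn as nn

class DropoutNet(nn.Module):
    """Hypothetical two-layer regressor with dropout after the hidden layer."""
    def __init__(self, input_dim=10, hidden_dim=64, output_dim=1, p=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(p=p)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)  # only active in train mode
        return self.fc2(x)

torch.manual_seed(0)
net = DropoutNet()
x = torch.randn(4, 10)

net.train()            # dropout on: random units are zeroed on each call
train_out = net(x)

net.eval()             # dropout off: the forward pass is deterministic
with torch.no_grad():
    eval_a, eval_b = net(x), net(x)
print(torch.equal(eval_a, eval_b))
```

Forgetting to call `eval()` before validation is a common source of noisy validation curves, since dropout would then still be perturbing the predictions.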

**2. Weight Decay (L2 Regularization)**
```python
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
```
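- Penalizes large weights by shrinking them toward zero at every update
- Encourages simpler models
- With Adam, the decoupled variant `optim.AdamW` is often preferred, because Adam's adaptive scaling changes the effect of a plain L2 penalty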
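
To see what the penalty does to a single update, here is a sketch of one plain SGD step with the weight-decay term added to the gradient, which corresponds to adding (weight_decay / 2)·‖w‖² to the loss. The helper `sgd_step` and the numbers are illustrative assumptions, written in pure Python rather than with the optimizer classes.

```python
def sgd_step(w, grad, lr=0.1, weight_decay=0.0):
    """One plain SGD update with an L2 penalty folded into the gradient."""
    return [wi - lr * (gi + weight_decay * wi) for wi, gi in zip(w, grad)]

weights = [2.0, -3.0]
zero_grad = [0.0, 0.0]  # pretend the data loss is flat at this point

no_decay = sgd_step(weights, zero_grad, weight_decay=0.0)
with_decay = sgd_step(weights, zero_grad, weight_decay=0.5)

print(no_decay)
print(with_decay)
```

Even with a zero data gradient, the decayed weights move toward zero by a factor of (1 - lr·weight_decay) per step, so a weight stays large only if the data keeps pushing it there.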

**3. L1 Regularization**
- Encourages sparse weights (many weights become exactly zero)
- More common in linear models; less used in deep networks due to optimization difficulty
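
The "exactly zero" behaviour can be illustrated with soft-thresholding, one common way of handling an L1 penalty in linear models. This is a sketch under that assumption, not something the assignment uses; `soft_threshold` and the threshold value are hypothetical.

```python
def soft_threshold(w, threshold):
    """Shrink w toward zero by `threshold`, clamping small values to exactly 0."""
    if w > threshold:
        return w - threshold
    if w < -threshold:
        return w + threshold
    return 0.0

weights = [0.8, -0.05, 0.02, -1.2]
sparse = [soft_threshold(w, threshold=0.1) for w in weights]
print(sparse)
```

Weights whose magnitude is below the threshold are set to zero outright, while an L2 penalty would only scale them down.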

#### **B. Data-Related Techniques**

**4. Get More Training Data**
- More (and more varied) training data usually improves generalization
- Use data augmentation if you can't collect more real data

**5. Data Augmentation**
```python
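from torchvision import transforms
# For images: random flips, rotations, crops, color jittering
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),     # mirror left/right with probability 0.5
    transforms.RandomCrop(32, padding=4),  # pad by 4 pixels, then crop back to 32x32
])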
```
- Creates realistic variations of existing data
- Often the most effective regularization method in computer vision

#### **C. Model Architecture Techniques**

**6. Reduce Model Complexity**
- Use fewer layers or neurons
- Simpler models are less likely to overfit

**7. Early Stopping**
```python
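# Inside the epoch loop: stop training when validation loss stops improving
if val_loss < best_val_loss:
    best_val_loss = val_loss
    patience_counter = 0      # improvement: reset the counter
else:
    patience_counter += 1     # no improvement this epoch
if patience_counter > max_patience:
    break  # Stop training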
```
- Monitor validation loss and stop before overfitting gets worse
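
A self-contained version of the same idea, as a sketch: the `EarlyStopping` class, its `patience` and `min_delta` parameters, and the simulated validation losses are all hypothetical and stand in for a real training loop.

```python
class EarlyStopping:
    """Track validation loss and signal when to stop (hypothetical helper)."""
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0      # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Simulated validation losses standing in for a real training loop.
val_losses = [0.90, 0.70, 0.55, 0.50, 0.52, 0.53, 0.56, 0.60]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.step(epoch, val_loss):
        stopped_at = epoch
        break

print(f"best epoch {stopper.best_epoch}, stopped at epoch {stopped_at}")
```

In a real loop you would typically also save the model weights whenever the validation loss improves, so that the best epoch can be restored after stopping.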

**8. Batch Normalization**
```python
self.bn1 = nn.BatchNorm1d(hidden_dim)
```
- Stabilizes and accelerates training
- Enables higher learning rates
- Only has a **mild** regularization effect
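
For intuition about what the layer computes in training mode, here is a pure-Python sketch of the normalization step for a single feature. It deliberately omits the learnable scale and shift and the running statistics used at inference; `batch_norm_feature` is a hypothetical helper, not the library implementation.

```python
def batch_norm_feature(values, eps=1e-5):
    """Normalize one feature across a batch to zero mean and unit variance."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased (population) variance
    return [(v - mean) / (var + eps) ** 0.5 for v in values]

activations = [10.0, 12.0, 14.0, 16.0]   # one hidden unit over a batch of 4
normalized = batch_norm_feature(activations)
print(normalized)
```

Because each sample is normalized with statistics from its own mini-batch, the output for one sample depends slightly on the others in the batch, which is a plausible source of the mild regularization effect.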

#### **D. Training Techniques**

**9. Cross-Validation**
- Use k-fold cross-validation to check that the model generalizes across different train/validation splits
- Helps detect whether the model only performs well on one specific validation set
- Common in classical ML and with small datasets; rarely used in large-scale deep learning because of the computational cost
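
The splitting step of k-fold cross-validation can be sketched as follows. The generator `k_fold_indices` is a hypothetical helper that uses contiguous folds for readability; in practice the indices are usually shuffled first.

```python
def k_fold_indices(n_samples, k):
    """Yield (train_idx, val_idx) pairs for k contiguous, near-equal folds."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, val_idx
        start += size

folds = list(k_fold_indices(n_samples=10, k=3))
for train_idx, val_idx in folds:
    print("val:", val_idx)
```

Each sample lands in the validation set exactly once, so training one model per fold and averaging the k validation losses gives an estimate that does not hinge on a single split.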

**10. Ensemble Methods**
- Train multiple models and average their predictions
- Reduces variance and improves generalization, because different models tend to make different errors
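
A toy sketch of prediction averaging, assuming each "model" is just a Python callable that returns the true value 2·x plus its own fixed random bias. The helpers `ensemble_predict` and `make_noisy_model` are hypothetical and only serve to show the averaging effect.

```python
import random

def ensemble_predict(models, x):
    """Average the predictions of several models (plain callables here)."""
    predictions = [model(x) for model in models]
    return sum(predictions) / len(predictions)

def make_noisy_model(rng, noise=0.5):
    """Hypothetical model: the true function 2*x plus a fixed random bias."""
    bias = rng.uniform(-noise, noise)
    return lambda x: 2.0 * x + bias

rng = random.Random(0)
models = [make_noisy_model(rng) for _ in range(25)]

x = 3.0
single_errors = [abs(m(x) - 2.0 * x) for m in models]
ensemble_error = abs(ensemble_predict(models, x) - 2.0 * x)
print(f"mean single-model error: {sum(single_errors) / len(single_errors):.3f}")
print(f"ensemble error:          {ensemble_error:.3f}")
```

The ensemble error can never exceed the average individual error here, because positive and negative biases partly cancel in the mean.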

---

### **Quick Reference Summary**

| Problem | Symptoms | Solution |
|---------|----------|----------|
| **LR too high** | Loss explodes, wild oscillations | Decrease learning rate (try 10x smaller) |
| **LR too low** | Slow training, plateau | Increase learning rate (try 10x larger) |
| **Overfitting** | Train loss ↓, Val loss ↑, large gap | Dropout, regularization, more data, augmentation, early stopping |
| **Underfitting** | Both losses high and not improving | Bigger model, better features, train longer, decrease regularization, adjust LR |
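
The table can be turned into a rough rule-of-thumb check, sketched below. The function `diagnose_curves` and its thresholds (`gap_ratio`, `high_loss`) are illustrative assumptions; it is not the `identify_training_issues` function from `loss_visualization.py`, and sensible thresholds depend on the loss scale of the task.

```python
def diagnose_curves(train_losses, val_losses, gap_ratio=2.0, high_loss=0.5):
    """Rule-of-thumb diagnosis from two equal-length loss lists.

    gap_ratio and high_loss are illustrative thresholds, not universal values.
    """
    final_train, final_val = train_losses[-1], val_losses[-1]
    # NaN is the only value that is not equal to itself.
    if final_train != final_train or final_train > train_losses[0]:
        return "lr_too_high"                  # NaN or worse than at the start
    if final_val > min(val_losses) and final_val > gap_ratio * final_train:
        return "overfitting"                  # val turned upward, large gap
    if final_train > high_loss and final_val > high_loss:
        return "underfitting_or_lr_too_low"   # both losses still high
    return "healthy"

print(diagnose_curves([1.0, 5.0, 50.0], [1.0, 5.0, 50.0]))
print(diagnose_curves([1.0, 0.4, 0.1], [1.0, 0.5, 0.8]))
print(diagnose_curves([1.0, 0.95, 0.9], [1.0, 0.96, 0.92]))
```

Underfitting and a too-low learning rate look alike from the final losses alone, so the sketch reports them together; the shape of the curve and a learning-rate sweep are needed to tell them apart.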
