# Lesson 7: W&B with PyTorch

**Module 2: Reproducibility & Versioning**  
**Estimated Time**: 2-3 hours  
**Difficulty**: Intermediate

---

## 🎯 Learning Objectives

By the end of this lesson, you will:

✅ Track PyTorch gradients and weights with `wandb.watch`  
✅ Log custom media (images) to debug predictions  
✅ Configure hyperparameters properly with `wandb.config`  
✅ Answer interview questions on debugging neural networks  

---

## 📚 Table of Contents

1. [Why Watch Gradients?](#1-gradients)
2. [The `wandb.watch` Magic](#2-wandb-watch)
3. [Logging Predictions as Images](#3-logging-images)
4. [Hands-On: PyTorch Setup](#4-hands-on)
5. [Interview Preparation](#5-interview-questions)

---

## 1. Why Watch Gradients?

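In deep learning, **vanishing gradients** and **exploding gradients** are common failure modes. Vanishing gradients in particular fail silently: the loss curve looks flat and unremarkable, but the model isn't learning because the gradients reaching the early layers are close to zero. Exploding gradients are usually louder and show up as a loss that spikes or becomes `NaN`.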

Visualizing the distribution of gradients across layers helps you catch this early.
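To make the failure mode concrete, here is a minimal sketch in plain PyTorch (no W&B required). The helper name `layer_grad_norms` and the deliberately deep sigmoid network are illustrative assumptions, not part of this lesson's model. The idea is simply to run one backward pass and compare gradient magnitudes layer by layer.

```python
import torch
import torch.nn as nn


def layer_grad_norms(model):
    """Return {parameter name: L2 norm of its gradient} for every weight matrix."""
    return {
        name: param.grad.norm().item()
        for name, param in model.named_parameters()
        if param.grad is not None and name.endswith("weight")
    }


torch.manual_seed(0)

# Hypothetical toy network: 8 sigmoid layers, chosen to provoke vanishing gradients.
layers = []
for _ in range(8):
    layers += [nn.Linear(32, 32), nn.Sigmoid()]
deep_net = nn.Sequential(*layers, nn.Linear(32, 10))

x = torch.randn(64, 32)
y = torch.randint(0, 10, (64,))
loss = nn.CrossEntropyLoss()(deep_net(x), y)
loss.backward()

norms = layer_grad_norms(deep_net)
for name, value in norms.items():
    print(f"{name:>10s}  grad norm = {value:.2e}")
```

In a network like this, the norms printed for the earliest layers should be orders of magnitude smaller than those near the output. That gap is the pattern the gradient histograms in the dashboard are meant to expose.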

## 2. The `wandb.watch` Magic

```python
wandb.watch(model, log="all", log_freq=10)
```

With `log="all"`, this one line hooks into your model and logs histograms of:
- Weights (are they too large?)
- Gradients (are they zero?)

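Call it once, after `wandb.init()` and **before** the training loop. The `log_freq` argument controls how often histograms are recorded (every 10 batches in the snippet above).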
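For intuition about what "hooking into the model" means, the sketch below registers a gradient hook on every parameter and stores a histogram each time a backward pass runs. This is a simplified illustration written for this lesson, not W&B's actual implementation. The helper name `attach_grad_histograms` and the 64-bin choice are assumptions.

```python
import torch
import torch.nn as nn


def attach_grad_histograms(model, store, bins=64):
    """Register a hook on each parameter that stores a histogram of its gradient."""
    handles = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        def hook(grad, name=name):
            # With min=max=0 (the default), torch.histc spans the data's own range.
            store[name] = torch.histc(grad.detach().float(), bins=bins)

        handles.append(param.register_hook(hook))
    return handles


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
histograms = {}
handles = attach_grad_histograms(model, histograms)

loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()  # the hooks fire here, once per parameter

for name, hist in histograms.items():
    print(name, int(hist.sum().item()), "gradient values binned")

for handle in handles:
    handle.remove()  # detach the hooks when they are no longer needed
```

Because the hooks are attached to the parameters themselves, the training loop does not need to change. This is why a single call placed before the loop is enough.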

## 3. Logging Predictions as Images

Aggregate metrics such as accuracy don't tell the whole story. You want to see **where** the model failed.

```python
wandb.log({"examples": [wandb.Image(x, caption=f"Pred: {y_pred}")]})
```

Logging a list of `wandb.Image` objects under one key creates a browsable gallery in the dashboard.
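Since the goal is to see where the model failed, it often helps to log only the misclassified samples, with both the predicted and the true label in the caption. The sketch below is one way to do that. The helper names `pick_misclassified` and `log_misclassified` and the `"misclassified"` key are hypothetical, and the logging helper assumes `wandb.init()` has already been called and that images are single-channel tensors with values in [0, 1].

```python
import torch


def pick_misclassified(images, logits, labels, max_items=8):
    """Return (image, caption) pairs for samples the model got wrong."""
    preds = logits.argmax(dim=1)
    wrong = (preds != labels).nonzero(as_tuple=True)[0][:max_items]
    return [
        (images[i], f"Pred: {preds[i].item()} | True: {labels[i].item()}")
        for i in wrong.tolist()
    ]


def log_misclassified(images, logits, labels):
    """Send the failures to W&B (assumes an active run)."""
    import wandb  # imported here so pick_misclassified works without W&B installed

    gallery = [
        # Convert each 1xHxW tensor to an HxW NumPy array for a grayscale image.
        wandb.Image(img.squeeze(0).cpu().numpy(), caption=caption)
        for img, caption in pick_misclassified(images, logits, labels)
    ]
    wandb.log({"misclassified": gallery})


# Tiny demonstration with hand-made logits: sample 1 is the only mistake.
images = torch.rand(3, 1, 28, 28)
logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.2, 0.9]])
labels = torch.tensor([0, 0, 1])
pairs = pick_misclassified(images, logits, labels)
print([caption for _, caption in pairs])  # ['Pred: 1 | True: 0']
```

Capping the gallery with `max_items` keeps each logging step small, which matters if the helper is called every epoch.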

## 4. Hands-On: PyTorch Setup

The code below simulates a CNN training loop on a single batch of random data, so it runs without downloading a dataset.

```python
import wandb
import torch
import torch.nn as nn
import torch.optim as optim

print("Ensure 'wandb login' is done.")

# 1. Define Model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        # 28x28 input -> 24x24 after a 5x5 conv; 10 channels * 24 * 24 = 5760
        self.fc = nn.Linear(5760, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # flatten to (batch, 5760)
        x = self.fc(x)
        return x

# 2. Training Code
def train():
    # Standard config approach: define hyperparameters in one dict
    config = {
        "batch_size": 64,
        "learning_rate": 0.01,
        "epochs": 5
    }

    with wandb.init(project="pytorch-demo", config=config):
        # Read hyperparameters back from wandb.config so the run records them
        config = wandb.config

        model = SimpleCNN()
        optimizer = optim.SGD(model.parameters(), lr=config.learning_rate)
        criterion = nn.CrossEntropyLoss()

        # MAGIC LINE: Watch gradients and weights
        wandb.watch(model, log="all", log_freq=1)

        # Fake data (28x28 images)
        inputs = torch.randn(config.batch_size, 1, 28, 28)
        labels = torch.randint(0, 10, (config.batch_size,))

        print("Starting training...")
        # Each "epoch" here is a single step on the same fake batch
        for epoch in range(config.epochs):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Log metrics
            wandb.log({"loss": loss.item(), "epoch": epoch})

            # Log a sample image every other epoch
            if epoch % 2 == 0:
                # Log the first image of the batch as an example
                img_tensor = inputs[0]
                wandb.log({"example_img": wandb.Image(img_tensor, caption=f"Epoch {epoch}")})

    print("Finished. Check dashboard for Gradients and Images!")

train()
```
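The `5760` passed to `nn.Linear` in `SimpleCNN` comes from the convolution's output shape. A 5x5 kernel with stride 1 and no padding turns a 28x28 input into 24x24 (28 - 5 + 1 = 24), and 10 channels x 24 x 24 = 5760. A quick way to confirm a number like this, rather than working it out by hand, is to push a dummy tensor through the layer:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 10, kernel_size=5)  # same layer as SimpleCNN.conv1

with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28)   # one fake 28x28 grayscale image
    feature_map = conv(dummy)

flat_features = feature_map[0].numel()
print(tuple(feature_map.shape), flat_features)  # (1, 10, 24, 24) 5760
```

If the input size or kernel size changes, rerunning this check gives the new value for the first argument of the linear layer.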

## 5. Interview Preparation

### Common Questions

#### Q1: "How do you debug a neural network that isn't learning?"
**Answer**: "I look at the gradient distribution. I use `wandb.watch` to visualize histograms of gradients for each layer. If the histograms are concentrated at zero (vanishing) or extremely wide (exploding), I know I need to adjust initialization, activation functions, or learning rate."
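To back this answer with something concrete, the sketch below compares the first-layer gradient norm of a deep network under two setups: sigmoid activations with PyTorch's default initialization, and ReLU activations with He (Kaiming) initialization. The function `first_layer_grad_norm`, the depth of 8, and the layer width of 32 are illustrative assumptions. It is a toy experiment, not a general recipe.

```python
import torch
import torch.nn as nn


def first_layer_grad_norm(activation, init=None, depth=8, seed=0):
    """Build a deep MLP, run one backward pass, return the first layer's gradient norm."""
    torch.manual_seed(seed)
    layers = []
    for _ in range(depth):
        linear = nn.Linear(32, 32)
        if init is not None:
            init(linear.weight)
        layers += [linear, activation()]
    net = nn.Sequential(*layers, nn.Linear(32, 10))

    x = torch.randn(64, 32)
    y = torch.randint(0, 10, (64,))
    nn.CrossEntropyLoss()(net(x), y).backward()
    return net[0].weight.grad.norm().item()


def he_init(weight):
    # He/Kaiming initialization, designed for ReLU-style activations
    nn.init.kaiming_normal_(weight, nonlinearity="relu")


sigmoid_norm = first_layer_grad_norm(nn.Sigmoid)
relu_norm = first_layer_grad_norm(nn.ReLU, init=he_init)
print(f"sigmoid + default init: {sigmoid_norm:.2e}")
print(f"ReLU + He init:         {relu_norm:.2e}")
```

The ReLU variant should report a much larger first-layer gradient, which is the kind of before-and-after evidence the gradient histograms provide when you change activation or initialization in a real model.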

#### Q2: "How do you monitor data quality during training?"
**Answer**: "I log sample predictions using `wandb.log({'image': ...})`. By visually inspecting what the model thinks is a 'cat' vs 'dog' during training, I might discover that my preprocessing pipeline is corrupting images (e.g., wrong normalization or swapped color channels)."
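One practical detail when inspecting images this way: tensors that have been normalized for training look washed out or distorted if logged directly, which can be mistaken for a data bug. A small helper that reverses the normalization before logging avoids that confusion. The sketch below assumes per-channel `(img - mean) / std` normalization. The name `denormalize` and the `MEAN`/`STD` values are hypothetical placeholders for your dataset's statistics.

```python
import torch


def denormalize(img, mean, std):
    """Undo (img - mean) / std per channel and clamp to the displayable [0, 1] range."""
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0.0, 1.0)


# Hypothetical per-channel statistics for a 3-channel dataset.
MEAN, STD = (0.5, 0.4, 0.3), (0.2, 0.25, 0.3)

original = torch.rand(3, 8, 8)  # pixel values in [0, 1]
mean_t = torch.tensor(MEAN).view(-1, 1, 1)
std_t = torch.tensor(STD).view(-1, 1, 1)
normalized = (original - mean_t) / std_t  # what the model sees during training

restored = denormalize(normalized, MEAN, STD)
print(torch.allclose(restored, original, atol=1e-5))
```

If the restored image still looks wrong in the dashboard after this step, the problem is more likely in the pipeline itself (for example the channel order) than in how the image was displayed.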