# Lesson 4: MLflow with PyTorch

**Module 2: Reproducibility & Versioning**  
**Estimated Time**: 2-3 hours  
**Difficulty**: Intermediate

---

## 🎯 Learning Objectives

By the end of this lesson, you will:

✅ Learn how to track deep learning training loops  
✅ Compare `mlflow.pytorch.autolog()` with manual logging and implement both  
✅ Track loss curves over epochs  
✅ Save and load PyTorch models with MLflow  
✅ Answer interview questions on DL experiment tracking  

---

## 📚 Table of Contents

1. [Deep Learning Tracking Challenges](#1-deep-learning-tracking-challenges)
2. [The Easy Way: Autologging](#2-the-easy-way-autologging)
3. [The Custom Way: Manual Training Loop](#3-the-custom-way-manual-training-loop)
4. [Hands-On: PyTorch MNIST Example](#4-hands-on-pytorch-mnist-example)
5. [Interview Preparation](#5-interview-preparation)

---

## 1. Deep Learning Tracking Challenges

Comparing sklearn (Lesson 3) vs PyTorch:

**Sklearn**:
- 1 `fit()` call.
- Metrics are usually calculated once at the end (Accuracy, RMSE).

**PyTorch / Deep Learning**:
- Iterative `for` loops (Epochs, Batches).
- Metrics change constantly (Loss decreases, Accuracy increases).
- Need to visualize **curves**, not just final numbers.
 
**Goal**: We want to see a chart of Training Loss vs Validation Loss in MLflow.

## 2. The Easy Way: Autologging

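MLflow's PyTorch autologging is built mainly around PyTorch Lightning, where it hooks into the `Trainer.fit()` call. For raw PyTorch (a plain `nn.Module` with a hand-written loop), support is limited: per the MLflow documentation, it only picks up values written through `torch.utils.tensorboard.SummaryWriter`, so the loop itself is not instrumented.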

```python
import mlflow.pytorch

mlflow.pytorch.autolog()

# ... Training code ...
```

**Pros**: Almost no code changes (one extra line). Captures common metrics and parameters automatically.

**Cons**: Less control over exactly what is logged and when. A poor fit for raw PyTorch loops.

## 3. The Custom Way: Manual Training Loop

This is the standard for raw PyTorch. You inject `mlflow.log_metric()` inside your loop.

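```python
import mlflow

epochs = 5  # example value

with mlflow.start_run():
    for epoch in range(epochs):
        # train_one_epoch() is a placeholder for your own function;
        # it should return the epoch's average loss as a float.
        loss = train_one_epoch()
        # The key: pass the 'step' argument!
        mlflow.log_metric("train_loss", loss, step=epoch)
```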

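Passing `step=epoch` gives each value an x-axis position, so the MLflow UI can plot the metric against the epoch number. If you omit it, `step` defaults to 0 for every value and the curve can only be charted against wall-clock time.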

## 4. Hands-On: PyTorch MNIST Example

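Let's simulate a simplified training loop. To keep the focus on logging, this example trains a tiny linear model on random dummy data rather than the real MNIST dataset.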

```python
import mlflow
import mlflow.pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# 1. Define a Simple Model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 2. Training Function with MLflow
def train_model(lr=0.01, epochs=5):
    mlflow.set_experiment("pytorch_demo")

    with mlflow.start_run():
        # Log Params
        mlflow.log_param("lr", lr)
        mlflow.log_param("epochs", epochs)

        model = SimpleNet()
        optimizer = optim.SGD(model.parameters(), lr=lr)
        criterion = nn.MSELoss()

        # Dummy Data (5 random samples with 10 features each)
        inputs = torch.randn(5, 10)
        targets = torch.randn(5, 1)

        print(f"Starting training with lr={lr}...")
        for epoch in range(epochs):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Log Metric PER EPOCH
            current_loss = loss.item()
            mlflow.log_metric("loss", current_loss, step=epoch)
            print(f"Epoch {epoch}: Loss={current_loss:.4f}")

        # Log Final Model
        # By default MLflow pickles the whole model object (not just the
        # state_dict) and writes the MLmodel and environment files next to it
        mlflow.pytorch.log_model(model, "model")
        print("Run Complete. Model saved to MLflow.")

# 3. Run It (two runs with different learning rates, to compare in the UI)
train_model(lr=0.1, epochs=10)
train_model(lr=0.01, epochs=10)
```

## 5. Interview Preparation

### Common Questions

#### Q1: "How do you visualize overfitting in MLflow?"
**Answer**: I log both `train_loss` and `val_loss` metrics at each epoch. In the MLflow UI, I can plot both lines on the same chart. If `train_loss` keeps going down while `val_loss` starts going up, that divergence indicates overfitting.
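The divergence described in this answer can also be checked programmatically once the two curves are available as lists. The helper below is a hypothetical illustration (the function name, the `patience` parameter, and the sample values are all assumptions, and it is not an MLflow API): it returns the first epoch that begins a streak of epochs in which training loss falls while validation loss rises.

```python
def first_divergence_epoch(train_loss, val_loss, patience=2):
    """Return the first epoch of a streak of `patience` consecutive epochs
    in which train loss falls while validation loss rises, or None."""
    streak = 0
    for epoch in range(1, len(train_loss)):
        train_falling = train_loss[epoch] < train_loss[epoch - 1]
        val_rising = val_loss[epoch] > val_loss[epoch - 1]
        streak = streak + 1 if (train_falling and val_rising) else 0
        if streak == patience:
            # Report the epoch where the streak started.
            return epoch - patience + 1
    return None


# Made-up curves: validation loss turns upward from epoch 3 onward.
train_curve = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3]
val_curve = [1.1, 0.9, 0.8, 0.85, 0.9, 1.0]

divergence_epoch = first_divergence_epoch(train_curve, val_curve)
print(divergence_epoch)
```

Requiring a streak rather than a single bad epoch keeps one noisy validation value from being reported as overfitting.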

#### Q2: "When saving a PyTorch model in MLflow, what actually gets saved?"
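**Answer**: By default, `mlflow.pytorch.log_model` serializes the entire model object with `torch.save` (pickle format), not just the `state_dict`. If you only want the weights, `mlflow.pytorch.log_state_dict` logs those separately. Alongside the serialized model, MLflow writes an `MLmodel` configuration file and environment files such as `conda.yaml` that record the Python and package versions (including torch) needed to load it. This makes the model much easier to reproduce.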

#### Q3: "Can I log images (like generated GAN outputs) to MLflow?"
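**Answer**: Yes, using `mlflow.log_artifact()`. For example, every 10 epochs I can save a generated image to a local file such as `output_epoch_10.png` and then log it. It will appear in the artifact viewer in the UI. For images that are already in memory (a NumPy array or a PIL image), `mlflow.log_image()` logs them without a temporary file.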