
# PyTorch CNN + MLflow (MNIST) — End‑to‑End Demo

This notebook trains a simple CNN on MNIST, logs metrics and the model to **MLflow** (with an **input/output signature**), and then loads the logged model back via `mlflow.pyfunc` for inference.

> **Tracking URI:** `http://127.0.0.1:8080` (adjust if your MLflow server runs elsewhere).  
> If your MLflow **Model Registry** is not configured, the notebook will **gracefully fall back** to logging without registering the model.


In [None]:

# %% [markdown]
# ## 0) Install (if needed)
# If you're running in a fresh environment, uncomment and run:
# !pip install torch torchvision mlflow numpy


In [None]:

# %%
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import mlflow
import mlflow.pytorch
from mlflow.models import infer_signature
import numpy as np

# Reproducibility: seed the PyTorch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)

# Point to your MLflow Tracking Server.
# The standard MLFLOW_TRACKING_URI environment variable takes precedence when
# it is set (and non-empty), so the notebook can target another server without
# editing code; otherwise the local default below is used.
mlflow.set_tracking_uri(
    os.environ.get("MLFLOW_TRACKING_URI") or "http://127.0.0.1:8080"
)
print("MLflow Tracking URI ->", mlflow.get_tracking_uri())

# (Optional) Choose/declare an experiment name for clarity
mlflow.set_experiment("MNIST_CNN_MLflow_Demo")


## 1) Define a simple CNN model

In [None]:

# %%
class SimpleCNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Architecture: two (3x3 conv -> ReLU -> 2x2 max-pool) stages followed by
    a 128-unit hidden layer with dropout and a final linear layer that
    produces one raw score (logit) per class.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the spatial size unchanged, so
        # only the pooling layers shrink the feature maps.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 64 channels * 7 * 7 spatial positions remain after two poolings of
        # a 28x28 input.
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for input (N, 1, 28, 28).

        Raises:
            RuntimeError: If the input's spatial size does not produce the
                64 * 7 * 7 features that ``fc1`` expects.
        """
        x = self.pool(self.relu(self.conv1(x)))  # (N, 32, 14, 14)
        x = self.pool(self.relu(self.conv2(x)))  # (N, 64, 7, 7)
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64 * 7 * 7), this never folds surplus elements into the
        # batch axis, so a wrongly sized input fails loudly in fc1 instead of
        # silently returning more rows than samples.
        x = x.flatten(start_dim=1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


## 2) Data preparation (MNIST)

In [None]:

# %%
def prepare_data(batch_size=64, data_dir='./data', num_workers=2, pin_memory=None):
    """Build the MNIST training and test data loaders.

    The datasets are downloaded into ``data_dir`` if they are not already
    there. Every image is converted to a tensor and normalised with mean
    0.1307 and standard deviation 0.3081 (the values commonly quoted for
    MNIST).

    Args:
        batch_size: Number of samples per batch for both loaders.
        data_dir: Directory where the MNIST files are stored/downloaded.
        num_workers: Number of worker processes used by each loader.
        pin_memory: Whether the loaders should pin host memory. ``None``
            (the default) enables pinning only when CUDA is available, which
            avoids pointless work and a warning on CPU-only machines.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles
        its samples; the test loader does not.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(
        root=data_dir,
        train=True,
        download=True,
        transform=transform
    )

    test_dataset = datasets.MNIST(
        root=data_dir,
        train=False,
        download=True,
        transform=transform
    )

    # Settings shared by both loaders; only shuffling differs.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


## 3) Training & Evaluation helpers

In [None]:

# %%
def train_model(model, train_loader, criterion, optimizer, device, epochs=2, log_every=100):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Every ``log_every`` batches the mean loss of that window is printed and
    logged to MLflow as ``train_loss``, using the global batch count as the
    step. The window restarts at the beginning of each epoch, so batches
    after the last full window of an epoch are trained on but not reported.

    Args:
        model: Module to optimise; switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer bound to the model's parameters.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``train_loader``.
        log_every: Window size, in batches, for loss reporting.
    """
    model.train()
    for epoch_idx in range(epochs):
        window_loss = 0.0
        # Count batches from 1 so the counter doubles as the in-epoch step.
        for batch_no, (batch_x, batch_y) in enumerate(train_loader, start=1):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_x), batch_y)
            batch_loss.backward()
            optimizer.step()

            window_loss += batch_loss.item()

            # Nothing to report until a full window has been accumulated.
            if batch_no % log_every != 0:
                continue

            mean_loss = window_loss / log_every
            print(f'Epoch [{epoch_idx+1}/{epochs}], Step [{batch_no}], Loss: {mean_loss:.4f}')
            # Log iterative loss to MLflow
            global_step = epoch_idx * len(train_loader) + batch_no
            mlflow.log_metric("train_loss", mean_loss, step=global_step)
            window_loss = 0.0

def evaluate_model(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader``, in percent.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled while predicting.

    Args:
        model: Module producing per-class scores of shape (N, num_classes).
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a float in ``[0.0, 100.0]``. An empty ``test_loader``
        yields ``0.0`` rather than raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # The predicted class is the index of the largest score per row.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # No samples were seen; accuracy is undefined, report 0.0.
        return 0.0
    return 100.0 * correct / total


## 4) Train, Log to MLflow (with Signature), and Reload

In [None]:

# %%
def _log_model_with_fallback(model, signature, input_example):
    """Log ``model`` to the active MLflow run, registering it when possible.

    The first attempt also asks MLflow to register the model as
    'pytorch_mnist_classifier'. If that attempt raises for any reason, the
    error is printed and the model is logged again without registration.

    Returns:
        The object returned by the ``mlflow.pytorch.log_model`` call that
        succeeded.
    """
    shared_kwargs = {
        "pytorch_model": model,
        "artifact_path": "model",
        "signature": signature,
        "input_example": input_example,
    }
    try:
        logged = mlflow.pytorch.log_model(
            registered_model_name="pytorch_mnist_classifier",
            **shared_kwargs,
        )
        print("✔ Model logged and registered as 'pytorch_mnist_classifier'.")
    except Exception as err:
        print("⚠ Registry not available or other issue:", err)
        print("→ Falling back to logging without registration...")
        logged = mlflow.pytorch.log_model(**shared_kwargs)
        print("✔ Model logged (unregistered).")
    return logged


def main():
    """Run the demo end to end inside a single MLflow run.

    Steps: pick a device, build the data loaders, train and evaluate a
    ``SimpleCNN``, log parameters/metrics, log the model with an inferred
    input/output signature, then reload it through ``mlflow.pyfunc`` and
    run one prediction as a smoke test.
    """
    # Device
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Data
    print("Preparing data...")
    train_loader, test_loader = prepare_data(batch_size=64)

    with mlflow.start_run(run_name="pytorch_image_classification"):
        # Hyperparameters
        lr, epochs, num_classes = 1e-3, 2, 10

        # Log parameters one by one, in a fixed order.
        run_params = {
            "learning_rate": lr,
            "epochs": epochs,
            "model_type": "SimpleCNN",
            "optimizer": "Adam",
        }
        for param_name, param_value in run_params.items():
            mlflow.log_param(param_name, param_value)

        # Model / loss / optimizer
        model = SimpleCNN(num_classes=num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Train
        print("Training model...")
        train_model(
            model,
            train_loader,
            criterion,
            optimizer,
            device,
            epochs=epochs,
            log_every=100,
        )

        # Evaluate
        print("Evaluating model...")
        accuracy = evaluate_model(model, test_loader, device)
        print(f"Test Accuracy: {accuracy:.2f}%")
        mlflow.log_metric("test_accuracy", accuracy)

        # Signature: take the first test image as a single-sample input and
        # run it through the model to get the matching output.
        first_images = next(iter(test_loader))[0]
        sample_input = first_images[:1].numpy()  # shape (1, 1, 28, 28)

        model.eval()
        with torch.no_grad():
            sample_logits = model(torch.tensor(sample_input).to(device))
            sample_output = sample_logits.cpu().numpy()  # shape (1, 10)

        signature = infer_signature(sample_input, sample_output)

        # Try to register; if registry not configured, fall back to plain log
        print("Logging model to MLflow...")
        model_info = _log_model_with_fallback(model, signature, sample_input)

        run_id = mlflow.active_run().info.run_id
        print(f"Run ID: {run_id}")
        print(f"Model URI: {model_info.model_uri}")

        # Test loading via PyFunc
        banner = "=" * 50
        print("\n" + banner)
        print("Testing model loading...")
        print(banner)

        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

        # Reuse the signature sample as the smoke-test input.
        prediction = loaded_model.predict(sample_input)  # (1, 10)
        first_row = prediction[0]

        print(f"Test input shape: {sample_input.shape}")
        print(f"Prediction shape: {prediction.shape}")
        print(f"Predicted class: {int(np.argmax(first_row))}")
        print(f"Prediction logits (first 5): {first_row[:5]}")

        print("\n✓ Model successfully logged and loaded!")
        print(f"✓ MLflow UI: {mlflow.get_tracking_uri()}")

# Run the end-to-end demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



## 5) Tips & Troubleshooting

- **Model Registry not configured?**  
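  If your MLflow server does not have a backend store/registry set up, the registration step will fail; this notebook automatically falls back to logging the model without registering it. Note that the fallback is triggered by *any* error raised while logging the model, so read the printed warning to see the actual cause.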

- **Can't reach MLflow at `http://127.0.0.1:8080`?**  
  Make sure your MLflow server is running, e.g.:
  ```bash
  pkill -f "mlflow server" || true
  mlflow server --host 127.0.0.1 --port 8080 --backend-store-uri sqlite:///mlflow.db --default-artifact-root ./mlruns
  ```

- **GPU not available?**  
  The code will automatically run on CPU. Training MNIST still finishes quickly.

- **Dataset download blocked?**  
  Ensure your environment has internet access to download MNIST on first run. Subsequent runs will use cached data in `./data`.
