# Monte Carlo Dropout Tutorial
Hands-on walk-through of uncertainty estimation using stochastic dropout at inference time.

## Goal
Learn how to wrap a standard neural network with Monte Carlo Dropout to obtain predictive uncertainty measures. We'll train a small multilayer perceptron on MNIST for only a handful of mini-batches to keep runtime short.

## Prerequisites
Make sure `torch`, `torchvision`, and `deepuq` are installed. To install `deepuq` in editable mode, run `pip install -e .` from the repository root.
> Tip: Run cells sequentially; inline comments describe the reasoning behind each operation.
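
To confirm the three packages are visible before running anything else, the standard library can look them up without importing them. This is a small convenience sketch, not part of `deepuq`. If you rely on the local `src` directory instead of an installed package, `deepuq` is only found after the path-configuration cell has run.

```python
import importlib.util

# find_spec returns None for a top-level package that cannot be located
required = ['torch', 'torchvision', 'deepuq']
missing = [name for name in required if importlib.util.find_spec(name) is None]
print('Missing packages:', ', '.join(missing) if missing else 'none')
```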

In [None]:
# Configure Python path so the notebook sees the local deepuq package
import os
import sys
from pathlib import Path

PROJECT_ROOT = Path(os.getcwd())
if not (PROJECT_ROOT / 'src').exists():
    PROJECT_ROOT = PROJECT_ROOT.parent

SRC_PATH = str(PROJECT_ROOT / 'src')
if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

In [None]:
# Core scientific stack
import torch
from torch import nn, optim
from torchvision import datasets, transforms

# deepuq provides reusable architectures and uncertainty wrappers
from deepuq.models import MLP
from deepuq.methods import MCDropoutWrapper

# Utility to keep experiments reproducible across runs
from deepuq.utils import set_seed

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Running on {DEVICE}')

## Data Pipeline
We'll use MNIST for familiarity. The transformation flattens each 28×28 image into a 784-dimensional vector so it fits the dense MLP.

In [None]:
set_seed(42)  # lock in deterministic initialisation and data order
transform = transforms.Compose([
    transforms.ToTensor(),  # converts PIL images to [0,1] tensors
    transforms.Lambda(lambda x: x.view(-1))  # flatten so the MLP can ingest the vector
])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# For notebook speed, keep loaders light. Adjust batch sizes as your hardware allows.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=False)

## Build the Deterministic Network
The `MLP` from `deepuq.models` stacks linear layers, ReLU activations, and dropout. The dropout probability `p_drop` controls how much the stochastic forward passes disagree once MC sampling is enabled: higher values generally widen the predictive spread.

In [None]:
model = MLP(input_dim=28*28, hidden_dims=[512, 256], output_dim=10, p_drop=0.2)
model.to(DEVICE)

# Adam usually converges quickly on MNIST; feel free to experiment with SGD or different learning rates.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
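
The exact layer ordering inside `deepuq.models.MLP` is not shown in this notebook. As a hypothetical plain-PyTorch equivalent, assuming each hidden block is Linear, ReLU, Dropout and the output layer has no dropout, the sketch below shows where `p_drop` enters the network. `build_mlp` is an illustrative helper, not the `deepuq` implementation.

```python
import torch
from torch import nn

def build_mlp(input_dim, hidden_dims, output_dim, p_drop):
    """Hypothetical stand-in for deepuq's MLP: (Linear -> ReLU -> Dropout) per hidden layer."""
    layers = []
    prev = input_dim
    for width in hidden_dims:
        layers += [nn.Linear(prev, width), nn.ReLU(), nn.Dropout(p_drop)]
        prev = width
    layers.append(nn.Linear(prev, output_dim))  # raw logits; no dropout on the output layer
    return nn.Sequential(*layers)

sketch = build_mlp(input_dim=28 * 28, hidden_dims=[512, 256], output_dim=10, p_drop=0.2)
n_dropout = sum(isinstance(m, nn.Dropout) for m in sketch.modules())
logits = sketch(torch.randn(4, 28 * 28))
print(n_dropout, logits.shape)  # 2 torch.Size([4, 10])
```

With two hidden layers there are two dropout sites, and each one resamples its mask on every stochastic forward pass.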

## Quick Training Loop
We only iterate over a handful of mini-batches to illustrate the workflow. For serious experiments, increase the number of epochs and monitor validation metrics.

In [None]:
model.train()
loss_fn = nn.CrossEntropyLoss()
num_batches = 5  # keep training snappy inside the notebook
for step, (features, labels) in enumerate(train_loader, start=1):
    features, labels = features.to(DEVICE), labels.to(DEVICE)
    optimizer.zero_grad()
    logits = model(features)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    print(f'Batch {step}/{num_batches} - training loss: {loss.item():.3f}')
    if step >= num_batches:
        break
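
Monitoring a validation metric only needs a deterministic pass over held-out data, ideally a validation split carved from the training set so the test set stays untouched. Below is a minimal sketch with a hypothetical `evaluate_accuracy` helper (not part of `deepuq`), run on random data so it works standalone. In the notebook you would call it as `evaluate_accuracy(model, val_loader, DEVICE)` with a loader of your choosing.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader, device='cpu'):
    """Hypothetical helper: top-1 accuracy with dropout disabled."""
    was_training = model.training
    model.eval()  # dropout off for a deterministic evaluation pass
    correct, total = 0, 0
    with torch.inference_mode():
        for features, labels in loader:
            features, labels = features.to(device), labels.to(device)
            preds = model(features).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    model.train(was_training)  # restore whichever mode the model was in
    return correct / total

# Standalone demo on random data, so expect roughly chance-level accuracy
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(784, 10))
toy_data = TensorDataset(torch.randn(64, 784), torch.randint(0, 10, (64,)))
toy_loader = DataLoader(toy_data, batch_size=16)
acc = evaluate_accuracy(toy_model, toy_loader)
print(f'accuracy: {acc:.3f}')
```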

## Wrap with Monte Carlo Dropout
`MCDropoutWrapper` keeps dropout layers active during evaluation and runs repeated stochastic forward passes to build a predictive distribution.

In [None]:
model.eval()  # base model enters eval mode before wrapping
uq_model = MCDropoutWrapper(model=model, n_mc=50, apply_softmax=True)
uq_model.to(DEVICE)

# Grab a batch of test images to interrogate uncertainty.
sample_batch, sample_labels = next(iter(test_loader))
sample_batch = sample_batch.to(DEVICE)

with torch.inference_mode():
    mean_probs, var_probs = uq_model.predict(sample_batch)

print('Predictive mean shape:', mean_probs.shape)
print('Predictive variance shape:', var_probs.shape)
print('Example mean probs for first sample:', mean_probs[0])
print('Example predictive variance for first sample:', var_probs[0])
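
`MCDropoutWrapper`'s internals are not shown here. As a hypothetical sketch of the mechanism described above, `mc_dropout_predict` (not a `deepuq` API) puts the model in eval mode, switches only the `nn.Dropout` submodules back to training mode, and summarises `n_mc` softmax outputs by their mean and per-class variance. It assumes dropout is implemented as `nn.Dropout` modules. Functional `F.dropout(..., training=self.training)` calls would stay off, and classes such as `nn.Dropout2d` would need to be added to the type check.

```python
import torch
from torch import nn

def mc_dropout_predict(model, x, n_mc=50):
    """Hypothetical MC Dropout sketch: mean and variance of softmax outputs over n_mc passes."""
    model.eval()  # non-dropout layers (e.g. batch norm) behave deterministically
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # re-enable random dropout masks at inference time
    with torch.inference_mode():
        # Stack n_mc stochastic passes: shape (n_mc, batch, classes)
        probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(n_mc)])
        mean = probs.mean(dim=0)
        var = ((probs - mean) ** 2).mean(dim=0)  # population variance across passes
    model.eval()  # switch dropout off again once sampling is done
    return mean, var

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Dropout(0.2), nn.Linear(64, 10))
mean_p, var_p = mc_dropout_predict(net, torch.randn(8, 784), n_mc=20)
print(mean_p.shape, var_p.shape)  # torch.Size([8, 10]) torch.Size([8, 10])
```

A network without dropout yields zero variance under this scheme, which is a quick sanity check that the spread really comes from the dropout masks.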

## Interpreting the Numbers
- `mean_probs`: class probabilities averaged over the `n_mc=50` stochastic forward passes.
- `var_probs`: per-class variance of those probabilities across the passes, a proxy for the epistemic (model) uncertainty exposed by dropout.
You can visualise variance scores or aggregate them (e.g., maximum variance) to flag uncertain predictions.
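
One way to collapse the per-class outputs into a single score per sample is sketched below on hand-made tensors, so the numbers are easy to check. It takes the maximum variance across classes, as suggested above, and as a commonly used alternative computes the predictive entropy of the mean probabilities. The tensors are illustrative, not model outputs. The threshold is an arbitrary placeholder, not a recommendation.

```python
import math
import torch

def uncertainty_scores(mean_probs, var_probs, eps=1e-12):
    """Collapse per-class outputs into one score per sample."""
    max_var = var_probs.max(dim=1).values  # largest per-class variance
    # Predictive entropy of the averaged distribution, in nats (eps avoids log(0))
    entropy = -(mean_probs * torch.log(mean_probs + eps)).sum(dim=1)
    return max_var, entropy

# Two synthetic samples: row 0 is confident and stable, row 1 is uniform and unstable
demo_mean = torch.tensor([[0.99] + [0.01 / 9] * 9,
                          [0.1] * 10])
demo_var = torch.tensor([[1e-4] * 10,
                         [2e-2] * 10])

max_var, entropy = uncertainty_scores(demo_mean, demo_var)
THRESHOLD = 1e-2  # hypothetical cut-off; tune it on held-out data in practice
flagged = max_var > THRESHOLD
print('entropy (nats):', entropy.tolist())
print('flagged:', flagged.tolist())  # [False, True]
print(math.log(10))  # entropy upper bound for 10 classes, about 2.303
```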

## Next Steps
1. Increase `n_mc` to tighten Monte Carlo estimates at the cost of runtime.
2. Calibrate probabilities with temperature scaling using a held-out validation set.
3. Swap in your own architecture: `MCDropoutWrapper` works with any `nn.Module` containing dropout layers.
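
Temperature scaling (step 2) fits a single scalar T > 0 that divides the logits, chosen to minimise negative log-likelihood on held-out data. The sketch below assumes you have already collected validation logits and labels as tensors. `fit_temperature` is a hypothetical helper, and it optimises log T with `torch.optim.LBFGS` so that T stays positive. How to combine T with MC sampling is a design choice. One option is to divide the logits by T inside each stochastic pass before the softmax.

```python
import torch
from torch import nn, optim

def fit_temperature(val_logits, val_labels, max_iter=100):
    """Hypothetical helper: find T > 0 minimising the NLL of softmax(val_logits / T)."""
    val_logits = val_logits.detach()  # only the temperature is learned
    log_t = torch.zeros(1, requires_grad=True)  # T = exp(log_t) starts at 1 and stays positive
    nll = nn.CrossEntropyLoss()
    optimizer = optim.LBFGS([log_t], lr=1.0, max_iter=max_iter, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = nll(val_logits / log_t.exp(), val_labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().item()

# Synthetic check: labels are sampled from softmax(z), but the "model" reports 5 * z,
# i.e. the right ranking with over-confident probabilities
torch.manual_seed(0)
z = torch.randn(2000, 10)
labels = torch.multinomial(torch.softmax(z, dim=1), num_samples=1).squeeze(1)
T = fit_temperature(5.0 * z, labels)
print(f'fitted temperature: {T:.2f}')  # expected to land near 5 for this setup

# Calibrated probabilities are then softmax(logits / T)
calibrated = torch.softmax(5.0 * z / T, dim=1)
```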