# SGLD Sampling Tutorial
Collect posterior samples with stochastic-gradient Langevin dynamics and aggregate predictions.

## Goal
Run stochastic-gradient Langevin dynamics (SGLD) to draw approximate posterior samples of the network weights. Then aggregate the predictions made with each sample into a predictive mean and variance.

In [None]:
# Make the local `deepuq` package importable from this notebook.
# The project root is taken to be the working directory if it contains a
# `src` folder; otherwise its parent is used (e.g. when the notebook is
# launched from a sub-directory of the project).
import os
import sys
from pathlib import Path

PROJECT_ROOT = Path(os.getcwd())
PROJECT_ROOT = PROJECT_ROOT if (PROJECT_ROOT / 'src').exists() else PROJECT_ROOT.parent

# Prepend (only once) so the local sources shadow any installed copy.
SRC_PATH = str(PROJECT_ROOT / 'src')
if SRC_PATH not in sys.path:
    sys.path.insert(0, SRC_PATH)

In [None]:
import torch
from torchvision import datasets, transforms

from deepuq.models import MLP
from deepuq.methods import collect_posterior_samples, predict_with_samples
from deepuq.utils import set_seed

# Default to CPU and switch to the GPU only when CUDA is available.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
print(f'Running on {DEVICE}')

## Dataset
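We again use MNIST. Each image is converted to a tensor and flattened into a 784-dimensional vector so it matches the MLP's `28*28` input. Swap in your own dataset and loaders if needed.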

In [None]:
set_seed(99)

# ToTensor converts each image to a tensor; the Lambda then flattens it to a
# 1-D vector so it matches the MLP's 28*28 input size.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: img.view(-1)),
])


def _mnist_split(train, shuffle):
    """Return an (MNIST split, DataLoader over it) pair using the shared transform."""
    split = datasets.MNIST(root='./data', train=train, download=True, transform=transform)
    loader = torch.utils.data.DataLoader(split, batch_size=256, shuffle=shuffle)
    return split, loader


# Training batches are shuffled; the test order stays fixed for reproducible evaluation.
train_set, train_loader = _mnist_split(train=True, shuffle=True)
test_set, test_loader = _mnist_split(train=False, shuffle=False)

## Model Initialisation
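SGLD starts from random weights, then injects Gaussian noise into each update to explore the posterior. Dropout is disabled (`p_drop=0.0`) so that the only stochasticity in the weights comes from SGLD itself.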

In [None]:
# Flattened 28x28 input -> one hidden layer of width 128 -> 10 outputs;
# p_drop=0.0 so no dropout noise is mixed into the SGLD dynamics.
model = MLP(
    28 * 28,
    [128],
    10,
    p_drop=0.0,
)
model.to(DEVICE)

## Collect Posterior Samples
`collect_posterior_samples` performs a short SGLD run. Increase `n_steps` for more thorough exploration. `burn_in` is a fraction of `n_steps`, so it scales with the run length, but revisit it if the chain needs more steps to settle.

In [None]:
# SGLD hyperparameters for a deliberately short demonstration run.
sgld_kwargs = dict(
    n_steps=50,         # total optimiser steps, burn-in included
    lr=1e-4,            # small step size keeps the noisy updates controlled
    weight_decay=1e-4,  # weight decay plays the role of a Gaussian prior on the weights
    burn_in=0.4,        # fraction of the initial steps discarded before collecting samples
)

samples = collect_posterior_samples(
    model=model,
    data_loader=train_loader,
    device=DEVICE,
    **sgld_kwargs,
)

n_samples = len(samples)
print(f'Collected {n_samples} posterior samples')

## Predict with Sampled Weights
Load each sampled parameter dictionary into the model, perform a forward pass, and aggregate the results to estimate mean and variance.

In [None]:
# Evaluate one test batch under every stored SGLD sample; autograd is not
# needed, so everything runs under inference mode.
with torch.inference_mode():
    batch = next(iter(test_loader))[0].to(DEVICE)
    mean_probs, var_probs = predict_with_samples(
        model,
        samples,
        batch,
        apply_softmax=True,
        device=DEVICE,
    )

summary = (
    ('Predictive mean shape:', mean_probs.shape),
    ('Predictive variance shape:', var_probs.shape),
    ('First example mean probs:', mean_probs[0]),
    ('First example variance:', var_probs[0]),
)
for label, value in summary:
    print(label, value)

## Practical Advice
- SGLD needs long runs to produce well-mixed samples; monitor the training loss and the autocorrelation between samples.
- Store samples to disk if you need to pause and resume experiments.
- Combine with thinning (keep every k-th sample) to reduce correlation between stored states.