# PyTorch Diffusion Lab
Execute this notebook end-to-end to train the Fashion-MNIST DDPM baseline, inspect the logged metrics, and sample new images without leaving the lab environment.

In [None]:
import sys
from pathlib import Path

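# Assumes the notebook sits in a folder that is a sibling of `src/`.
NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / 'src'
if str(SRC_DIR) not in sys.path:
    # Prepend so the lab modules take priority over same-named installed packages.
    sys.path.insert(0, str(SRC_DIR))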

from config import CONFIG
from train import train
from inference import generate_samples

In [None]:
metrics = train(CONFIG)
metrics

## Review training and validation trends
Convert the metrics dictionary into a quick report so you can compare experiments or log summaries downstream.

In [None]:
# pandas and matplotlib are optional; without pandas the cell falls back to pprint.
try:
    import pandas as pd
except ImportError:
    pd = None
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None

if pd is not None:
    df = pd.DataFrame(metrics)
    display(df.tail())
    if plt is not None and not df.empty:
        ax = df.plot(figsize=(6, 3), title="Diffusion losses")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        plt.tight_layout()
        plt.show()
else:
    from pprint import pprint
    pprint(metrics)

### Interpret the metrics
- `train_loss` and `val_loss` track the MSE between predicted and true noise.
- A plateau in both curves suggests adjusting the learning rate, UNet depth, or number of training epochs.
- Use `pandas.DataFrame(metrics)` to plot curves and compare experiments.
- Store the metrics JSON with the checkpoint for reproducibility.
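
Storing the metrics next to a checkpoint can be sketched as follows. This is a minimal example, assuming `metrics` is a dictionary of per-epoch lists; the helper name `save_metrics` and the `<checkpoint stem>.metrics.json` naming are hypothetical and not part of the lab's `src` modules.

```python
import json
from pathlib import Path


def save_metrics(metrics, checkpoint_path):
    """Write metrics next to a checkpoint file (hypothetical helper)."""
    checkpoint_path = Path(checkpoint_path)
    # Assumed naming: <checkpoint stem>.metrics.json in the same directory.
    out_path = checkpoint_path.with_name(checkpoint_path.stem + ".metrics.json")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as fh:
        # default=float converts scalar types json cannot encode (e.g. NumPy floats).
        json.dump(metrics, fh, indent=2, default=float)
    return out_path
```

Calling `save_metrics(metrics, some_checkpoint_path)` after training returns the path of the written file, so it can be logged together with the checkpoint.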

In [None]:
samples = generate_samples(
    CONFIG,
    num_images=16,
    output_path=CONFIG.artifact_dir / 'notebook_samples.png',
)
samples.shape

## Visualise the diffusion samples
View the generated batch inline to compare noise schedules, checkpoints, or sampler tweaks across runs.

In [None]:
import math
try:
    import matplotlib.pyplot as plt
    from torchvision.utils import make_grid
except ImportError as exc:
    print(f"Visualization skipped: {exc}")
else:
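    # Near-square grid with at least one image per row; samples are assumed to lie in [-1, 1].
    nrow = max(1, math.isqrt(len(samples)))
    grid = make_grid(samples, nrow=nrow, normalize=True, value_range=(-1, 1))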
    plt.figure(figsize=(6, 6))
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis("off")
    plt.title("Diffusion samples")
    plt.show()

### Next experiments
- Compare the saved grid (`notebook_samples.png` in `CONFIG.artifact_dir`) with grids from earlier runs to spot improvements.
- Reduce `CONFIG.sample_steps` for faster inference and check how much sample quality degrades.
- Sweep different beta schedules (linear vs cosine) inside `engine.py` and compare metrics.
- Log generated samples to TensorBoard or wandb for automated experiment tracking.
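
To make the schedule sweep concrete, here is a minimal, dependency-free sketch of the linear and cosine schedules as they are commonly defined. It is not taken from `engine.py`: the function names, the default endpoints (`1e-4` to `0.02`), the offset `s = 0.008`, and the clipping value are assumptions based on common DDPM practice, so check them against the lab's actual implementation.

```python
import math


def linear_betas(num_steps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced betas from beta_start to beta_end (assumed defaults)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cosine_betas(num_steps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine cumulative alpha curve."""
    def alpha_bar(t):
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    # beta_t = 1 - alpha_bar(t + 1) / alpha_bar(t), clipped for stability.
    return [
        min(1 - alpha_bar(t + 1) / alpha_bar(t), max_beta)
        for t in range(num_steps)
    ]


def alphas_cumprod(betas):
    """Running product of (1 - beta): the share of signal left at each step."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1 - beta
        out.append(running)
    return out
```

Comparing `alphas_cumprod(linear_betas(1000))` with `alphas_cumprod(cosine_betas(1000))` shows that the cosine schedule keeps more of the signal at intermediate timesteps, which is one reason the two schedules can yield different loss curves and sample quality.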