# TensorFlow Diffusion Lab
This notebook mirrors the PyTorch DDPM experiment using Keras and `tf.GradientTape`, covering setup, training, metrics review, and sampling in one place for quick iteration.

In [None]:
import sys
from pathlib import Path

# Make the project's `src/` modules (config, train, inference) importable.
# Path() is the current working directory; this assumes the kernel was started
# in the notebook's folder (Jupyter's default) so that `src/` is its sibling.
NOTEBOOK_DIR = Path().resolve()
SRC_DIR = NOTEBOOK_DIR.parent / 'src'
if str(SRC_DIR) not in sys.path:
    # Prepend rather than append: `config`, `train` and `inference` are generic
    # module names, and a same-named module earlier on sys.path (the working
    # directory, site-packages) would otherwise shadow the project's files.
    sys.path.insert(0, str(SRC_DIR))

from config import CONFIG
from train import train
from inference import generate_samples

In [None]:
# Run training with the shared CONFIG; `train` returns the tracked metrics,
# which the review cell below turns into a DataFrame (or pretty-prints).
metrics = train(CONFIG)
# Bare expression as the cell's last line so Jupyter renders `metrics` as output.
metrics

## Review training and validation trends
Display the tracked losses inline and, when pandas and matplotlib are installed, plot them for a quick sanity check.

In [None]:
# Tabulate and plot the training metrics. pandas and matplotlib are optional:
# without pandas the raw metrics are pretty-printed; without matplotlib the
# table is still shown and the plot is skipped with a notice.
try:
    import pandas as pd
except ImportError:
    pd = None
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
try:
    # IPython/Jupyter provide `display` as a builtin; fall back to print so the
    # cell still works when the notebook is exported and run as a plain script.
    display
except NameError:
    display = print

if pd is not None:
    df = pd.DataFrame(metrics)
    display(df.tail())
    if plt is not None and not df.empty:
        # One line per metric column, indexed by row position.
        ax = df.plot(figsize=(6, 3), title="Diffusion losses")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        plt.tight_layout()
        plt.show()
    elif plt is None:
        print("Plotting skipped: matplotlib is not installed.")
else:
    from pprint import pprint
    pprint(metrics)

### Interpret the metrics
- `train_loss` captures the MSE between predicted and true noise; expect a gradual decline.
- `val_loss` helps detect overfitting or divergence when experimenting with architectures.
- Convert to a pandas DataFrame for quick plotting and experiment comparison.
- Pair the metrics JSON with checkpoints for reproducibility.

In [None]:
# Sample a batch of 16 images from the trained model. `output_path` is handed to
# generate_samples, which presumably writes an image grid there — confirm in
# the `inference` module.
output_file = CONFIG.artifact_dir / 'notebook_samples.png'
samples = generate_samples(CONFIG, num_images=16, output_path=output_file)
# Trailing expression: Jupyter echoes the batch shape as the cell output.
samples.shape

## Visualise the diffusion samples
Plot the generated batch to compare checkpoints, schedules, or strategy tweaks at a glance.

In [None]:
# Show the generated batch as a near-square grid of images. Skipped with a
# message if matplotlib is unavailable or the batch is empty.
import math
import numpy as np
try:
    import matplotlib.pyplot as plt
except ImportError as exc:
    print(f"Visualization skipped: {exc}")
else:
    # Eager tf.Tensors expose .numpy(); fall back to np.asarray for array-likes.
    images = samples.numpy() if hasattr(samples, "numpy") else np.asarray(samples)
    # Assumes the sampler emits values in [-1, 1] (TODO confirm against
    # generate_samples). Map to [0, 1] and clip any overshoot so imshow receives
    # a valid range.
    images = np.clip((images + 1.0) / 2.0, 0.0, 1.0)
    num_images = images.shape[0]
    if num_images == 0:
        # plt.subplots(0, n) raises, so bail out explicitly.
        print("Visualization skipped: no samples to plot.")
    else:
        # floor(sqrt(n)) columns, as many rows as needed to fit every image.
        n_cols = int(math.sqrt(num_images)) or 1
        n_rows = math.ceil(num_images / n_cols)
        # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 2, n_rows * 2), squeeze=False
        )
        # axes.flat walks the grid row-major, matching sample order.
        for idx, ax in enumerate(axes.flat):
            ax.axis("off")
            if idx >= num_images:
                continue  # leave surplus grid cells blank
            image = images[idx].squeeze()
            if image.ndim == 2:
                # Pin the colour range: without vmin/vmax imshow rescales each
                # grayscale image to its own min/max, hiding real brightness
                # differences between samples.
                ax.imshow(image, cmap="gray", vmin=0.0, vmax=1.0)
            else:
                ax.imshow(image)
        fig.suptitle("Diffusion samples", y=0.95)
        fig.tight_layout()
        plt.show()

### Next experiments
- Enable mixed precision and benchmark speed gains.
- Reduce `CONFIG.sample_steps` to see how few steps still produce recognisable digits.
- Use TensorBoard callbacks to stream metrics during training.
- Compare generated grids with the PyTorch notebook outputs.