# Diffusion on the Moons 🌗 Toy Dataset - Basic MLP PoC

This notebook demonstrates a **proof-of-concept diffusion model** trained on the classic scikit-learn "moons" toy dataset. Diffusion models work by learning to reverse a gradual noising process: we start with clean data, progressively add Gaussian noise over multiple time steps, and then train a neural network to predict and remove that noise step-by-step.

In this experiment, we use a simple **Multi-Layer Perceptron (MLP)** as our denoising model to learn the reverse diffusion process on 2D point clouds. This toy example helps visualize how diffusion models generate new samples by starting from pure noise and iteratively denoising to recover the data distribution.
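To make the forward (noising) half concrete, here is a minimal sketch of the closed-form noising step commonly used in DDPM-style models, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$. The beta endpoints and the helper name `add_noise` are illustrative assumptions and are not taken from this notebook's `src` package.

```python
import torch

# Assumed linear beta schedule; the endpoints are illustrative, not the notebook's values.
time_steps = 100
betas = torch.linspace(1e-4, 0.02, time_steps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # cumulative fraction of signal kept


def add_noise(x0: torch.Tensor, t: int, noise: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: jump from clean data x0 directly to step t (1-indexed)."""
    a_bar = alpha_bars[t - 1]
    return torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise


x0 = torch.tensor([[1.0, 0.5], [-0.5, 0.25]])  # two toy 2D points
eps = torch.randn_like(x0)
x_early = add_noise(x0, t=1, noise=eps)    # stays close to x0
x_late = add_noise(x0, t=100, noise=eps)   # noticeably noisier
```

The denoiser is then trained to recover `eps` from the noisy input and the time step.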

In [None]:
import torch
import torch.nn.functional as F
from torchinfo import summary
import matplotlib.pyplot as plt

from src.diffusion_playground.visualization.plot import show_denoising_steps_2d
from src.diffusion_playground.data_loader.toy_datasets import load_toy_dataset
from src.diffusion_playground.models.mlp_denoiser import MLPDenoiser
from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.training.denoiser_trainer import train_denoiser

## Train the Model 🏋️

Here we set up and train our MLP denoiser on the moons dataset. The training process teaches the model to predict the noise that was added at each time step during the forward diffusion process.

**Key components:**
- **Dataset**: 1,000 samples from the moons toy dataset (two interleaving half-circles)
- **Model**: A simple MLP that takes noisy data points and the current time step as input
- **Noise Schedule**: Linear schedule over 100 time steps controlling how noise is added
- **Training**: 100,000 epochs to learn the noise prediction task (takes ~1 minute on a typical laptop's CPU)
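
A single training step for the noise-prediction objective might look roughly like the sketch below. It is an illustration under stated assumptions, not the implementation of `train_denoiser`: `TinyDenoiser`, the beta endpoints, the time-step scaling, and the optimizer settings are all hypothetical stand-ins.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed linear schedule (illustrative endpoints, not the notebook's values).
time_steps = 100
betas = torch.linspace(1e-4, 0.02, time_steps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)


class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for MLPDenoiser: input (x, y, t), output predicted noise."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Scale t into (0, 1] so it is comparable to the coordinates (an assumption).
        return self.net(torch.cat([x, t / time_steps], dim=1))


toy_model = TinyDenoiser()
optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
x0 = torch.randn(128, 2)  # placeholder batch; the notebook would use the moons data

# One step: draw a random time step per sample, noise the data, predict the noise.
t = torch.randint(1, time_steps + 1, (x0.shape[0], 1))
noise = torch.randn_like(x0)
a_bar = alpha_bars[t - 1]  # shape (128, 1), broadcasts over both coordinates
xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise

loss = F.mse_loss(toy_model(xt, t.float()), noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Repeating this step many times with fresh `t` and `noise` draws is what lets one network cover every noise level.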

In [None]:
# Load the data
data = load_toy_dataset("moons", n_samples=1_000)
data = torch.tensor(data, dtype=torch.float32)

# Create the model
model = MLPDenoiser()

# Create the noising schedule
schedule = LinearNoiseSchedule(time_steps=100)

# Train the model
# Note: This should not take longer than approx. 1 minute on a normal laptop
train_denoiser(model, data, schedule, epochs=100_000)

## Evaluate the Model 📈

Now that our model is trained, let's test its ability to generate new samples! We'll start from pure Gaussian noise and iteratively denoise it using our trained MLP. If successful, the final denoised samples should resemble the characteristic moon shapes from our training data.

We'll first plot the original data (in red) for reference, then track the denoising process at various time steps to see how the model gradually recovers the data distribution from random noise.

In [None]:
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 256
time_steps = 100
visualize_ever = 10

# Make sure the model lives on the same device as the tensors we feed it
model.to(device)
model.eval()

# Get and show the original moons data for comparison with our generated samples
idx = torch.randint(0, data.shape[0], (batch_size,))
x0_eval = data[idx]

plt.figure(figsize=(8, 6))
plt.scatter(x0_eval[:, 0], x0_eval[:, 1], c="red", alpha=0.6)
plt.title("Sample from the original Moons Dataset (Ground Truth)", fontsize=14, fontweight='bold')
plt.xlabel("x₀")
plt.ylabel("x₁")
plt.grid(True, alpha=0.3)
plt.show()

### De-Noising Process ⚙️

This is the **reverse diffusion process** - the heart of generation! We start with pure random noise and step backwards through our 100 time steps, using our trained model to predict and remove noise at each step.

At each iteration:
1. The model predicts what noise was added at time step `t`
2. We use the noise schedule parameters (α, β) to calculate the cleaner version of the data
3. We step backwards in time, gradually revealing the underlying moon pattern

We save a snapshot every 10 steps, plus one at the final step (`t = 1`), to visualize how the structure emerges from chaos.
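
The three steps above can be sketched as a single function. This is a minimal illustration rather than the notebook's own code: `reverse_step` is a hypothetical helper and the beta endpoints are assumed. It also includes the optional noise term of the commonly described DDPM ancestral sampler (with $\sigma_t^2 = \beta_t$ as one common choice). With `stochastic=False` it reduces to the mean-only update used in this notebook's loop.

```python
import torch

# Assumed linear schedule (illustrative endpoints, not the notebook's values).
time_steps = 100
betas = torch.linspace(1e-4, 0.02, time_steps)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


def reverse_step(xt: torch.Tensor, pred_noise: torch.Tensor, t: int,
                 stochastic: bool = True) -> torch.Tensor:
    """Hypothetical helper: one reverse step from x_t to x_{t-1} (t is 1-indexed)."""
    beta_t, alpha_t, a_bar_t = betas[t - 1], alphas[t - 1], alpha_bars[t - 1]
    # Mean of the reverse transition, given the predicted noise.
    mean = (xt - beta_t / torch.sqrt(1.0 - a_bar_t) * pred_noise) / torch.sqrt(alpha_t)
    if stochastic and t > 1:
        # sigma_t^2 = beta_t is one common choice; no noise is added at the last step.
        return mean + torch.sqrt(beta_t) * torch.randn_like(xt)
    return mean


xt = torch.randn(4, 2)
fake_pred = torch.zeros_like(xt)  # placeholder for model(xt, t_tensor)
x_prev = reverse_step(xt, fake_pred, t=1)
```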

In [None]:
# Create pure noise as a starting point
xt = torch.randn_like(x0_eval).to(device)
x_steps = []

# De-noising loop
for t in reversed(range(1, time_steps + 1)):
    # The current time step for every datapoint (same for all datapoints)
    t_tensor = torch.full((batch_size, 1), t, device=device, dtype=torch.float32)

    # Predict the noise that was "added" at this time step - Reverse diffusion process
    with torch.no_grad():
        pred_noise = model(xt, t_tensor)

    # Look up the schedule parameters for step t (beta, alpha and alpha_bar; the tensors are 0-indexed)
    beta_t = schedule.betas[t - 1]
    alpha_t = schedule.alphas[t - 1]
    alpha_bar_t = schedule.alpha_bars[t - 1]

    # Step back
    x_prev = (xt - beta_t / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)

    # Update xt
    xt = x_prev

    # Track steps for visualization
    if t % visualize_ever == 0 or t == 1:
        x_steps.append((xt.clone(), pred_noise.clone()))

### Visualize the Interim De-Noising Results 🏞️

Watch the magic happen! These visualizations show how our generated samples evolve during the reverse diffusion process:

- **Step 0/100** (t=100): Starting point - the state after the first de-noising step, still essentially random noise with no discernible structure
- **Step 50/100** (t=50): Midway through - some structure begins to emerge
- **Step 100/100** (t=1): Final result - clean samples that should closely match the moon distribution
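
Which list index holds which time step follows from the snapshot condition in the loop (`t % visualize_ever == 0 or t == 1`). The short sketch below replays only that condition, without the model, to list the captured time steps.

```python
time_steps = 100
visualize_every = 10  # same value as `visualize_ever` in the setup cell

# Replay only the snapshot condition of the de-noising loop.
snapshot_ts = [
    t for t in reversed(range(1, time_steps + 1))
    if t % visualize_every == 0 or t == 1
]

print(snapshot_ts)       # [100, 90, 80, 70, 60, 50, 40, 30, 20, 10, 1]
print(len(snapshot_ts))  # 11
```

So index 0 corresponds to t=100, index 5 to t=50, index 9 to t=10, and the last index (10, or -1) to t=1.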

In [None]:
show_denoising_steps_2d(x0_eval, *x_steps[0], title="Step 0 / 100")

In [None]:
show_denoising_steps_2d(x0_eval, *x_steps[5], title="Step 50 / 100")

In [None]:
show_denoising_steps_2d(x0_eval, *x_steps[-1], title="Step 100 / 100")

### Print the Model Summary and Final Loss Value 💯

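Let's run a rough sanity check by computing the Mean Squared Error (MSE) between the final denoised samples and the sampled original data. Note that the generated points start from independent random noise and are not paired with specific training points, so this value depends on the arbitrary ordering of both batches. Treat it as a coarse indicator and rely on the scatter plots above for the actual quality judgment.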

We also display the model architecture summary to see the simplicity of our MLP denoiser - proof that even simple models can learn diffusion dynamics on toy datasets!
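
Because the generated points are not paired one-to-one with the reference points, an order-independent comparison can complement the MSE. One possible choice, added here for illustration and not part of the original notebook, is a symmetric nearest-neighbor distance (often called the Chamfer distance), sketched with `torch.cdist`. The function name `chamfer_distance` and the placeholder tensors are hypothetical.

```python
import torch


def chamfer_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean nearest-neighbor distance from a to b, plus b to a."""
    d = torch.cdist(a, b)  # pairwise Euclidean distances, shape (len(a), len(b))
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()


# Placeholder point sets; in the notebook these would be `xt` and `x0_eval`.
reference = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
shuffled = reference[[2, 0, 1]]                  # same points, different order
shifted = reference + torch.tensor([0.0, 1.0])   # every point moved up by 1

same = chamfer_distance(reference, shuffled)  # close to 0: ordering does not matter
moved = chamfer_distance(reference, shifted)  # close to 2: 1 in each direction
print(same.item(), moved.item())
```

A small value means every generated point has a nearby reference point and vice versa, regardless of how the two batches are ordered.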

In [None]:
mse_final = F.mse_loss(xt, x0_eval.to(xt.device))
print(f"Final MSE: {mse_final.item():.4f}")

summary(model)