# Diffusion on the Moons Toy Dataset – Basic MLP Proof of Concept

In [None]:
import torch
import torch.nn.functional as F
from torchinfo import summary
import matplotlib.pyplot as plt

from src.diffusion_playground.visualization.plot import show_denoising_steps_2d
from src.diffusion_playground.data_loader.toy_datasets import load_toy_dataset
from src.diffusion_playground.models.mlp_denoiser import MLPDenoiser
from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.training.denoiser_trainer import train_denoiser

## Train the Model

In [None]:
# Load the data.
# `torch.as_tensor` accepts both NumPy arrays and tensors (`torch.tensor` warns when handed an
# existing tensor). The explicit float32 dtype keeps the data in PyTorch's default parameter
# dtype even if the loader returns float64 values (assumes MLPDenoiser uses default float32
# weights - TODO confirm).
data = torch.as_tensor(load_toy_dataset("moons", n_samples=1_000), dtype=torch.float32)

# Create the model
model = MLPDenoiser()

# Create the noising schedule (100 diffusion time steps; the evaluation cells hard-code this value)
schedule = LinearNoiseSchedule(time_steps=100)

# Train the model
# Note: This should not take longer than approx. 1 minute on a normal laptop
train_denoiser(model, data, schedule, epochs=100_000)

## Evaluate the Model

In [None]:
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 256
time_steps = 100  # must match LinearNoiseSchedule(time_steps=...) used for training
visualize_ever = 10  # snapshot interval (in time steps) for the de-noising visualization

# The sampling loop feeds tensors that live on `device`, so make sure the weights live there too
# (a no-op if training already moved the model).
model.to(device)
model.eval()

# Get and show the raw data for comparison
idx = torch.randint(0, data.shape[0], (batch_size,))
x0_eval = data[idx]

plt.scatter(x0_eval[:, 0], x0_eval[:, 1], c="red")
plt.show()

### Denoising Process

In [None]:
# Create pure noise as a starting point (directly on the sampling device)
xt = torch.randn_like(x0_eval, device=device)
x_steps = []

# De-noising loop: DDPM ancestral sampling from t = time_steps down to t = 1
for t in reversed(range(1, time_steps + 1)):
    # The current time step for every datapoint (same for all datapoints)
    t_tensor = torch.full((xt.shape[0], 1), t, device=device, dtype=torch.float32)

    # Predict the noise that was "added" at this time step - Reverse diffusion process
    with torch.no_grad():
        pred_noise = model(xt, t_tensor)

    # Schedule values for this step (schedule entries are 0-indexed, time steps start at 1)
    beta_t = schedule.betas[t - 1]
    alpha_t = schedule.alphas[t - 1]
    alpha_bar_t = schedule.alpha_bars[t - 1]

    # Mean of the reverse step: remove the scaled predicted noise and rescale by 1 / sqrt(alpha_t)
    x_prev = (xt - beta_t / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)

    # DDPM (Ho et al., 2020, Algorithm 2) adds fresh Gaussian noise sigma_t * z on every step
    # except the last one; sigma_t^2 = beta_t is one of the two variance choices in the paper.
    # Without this term the reverse process is deterministic and samples tend to collapse.
    if t > 1:
        x_prev = x_prev + beta_t ** 0.5 * torch.randn_like(xt)

    # Update xt
    xt = x_prev

    # Track steps for visualization: every `visualize_ever` time steps plus the final step (t == 1)
    if t % visualize_ever == 0 or t == 1:
        x_steps.append((xt.clone(), pred_noise.clone()))

### Visualize the Intermediate Denoising Results

In [None]:
# x_steps[0] is recorded right after the first reverse step (t = 100), i.e. after step 1, not pure noise
show_denoising_steps_2d(x0_eval, *x_steps[0], title="Step 1 / 100")

In [None]:
# Snapshots are taken at t = 100, 90, ..., so x_steps[5] is recorded after the reverse step at
# t = 50, i.e. after 51 of the 100 steps
show_denoising_steps_2d(x0_eval, *x_steps[5], title="Step 51 / 100")

In [None]:
# The last snapshot is taken at t = 1 and holds the fully de-noised samples
# (x_steps[9] would be the t = 10 snapshot, i.e. only step 91 of 100)
show_denoising_steps_2d(x0_eval, *x_steps[-1], title="Step 100 / 100")

### Print the Final Loss Value and the Model Summary

In [None]:
# Sampling may have run on the GPU while x0_eval / data stay on the CPU; compare on the CPU
# with matching dtypes so the metrics work on either device.
xt_cpu = xt.detach().cpu()

# Note: the generated samples are not paired with x0_eval (both are independent draws), so this
# point-wise MSE is only a rough indication of quality.
mse_final = F.mse_loss(xt_cpu, x0_eval.cpu().to(xt_cpu.dtype))
print(f"Final MSE: {mse_final.item():.4f}")

# Pairing-free check: mean distance from each generated sample to its nearest real data point
nn_dist = torch.cdist(xt_cpu, data.cpu().to(xt_cpu.dtype)).min(dim=1).values.mean()
print(f"Mean nearest-neighbour distance to data: {nn_dist.item():.4f}")

summary(model)