# MNIST Diffusion - Let's get to real images!

In [None]:
import torch
import matplotlib.pyplot as plt

from src.diffusion_playground.data_loader.mnist import load_mnist
from src.diffusion_playground.diffusion.noise_schedule import LinearNoiseSchedule
from src.diffusion_playground.diffusion.training_utils import sample_xt

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

## Load the MNIST Dataset

In [None]:
# Number of images per mini-batch requested from the MNIST loader.
batch_size = 128

# `load_mnist` is a project helper; it hands back the data loader together
# with the input shape it reports for a single sample.
data_loader, input_shape = load_mnist(
    batch_size=batch_size,
)
print("Input shape:", input_shape)

## Explore the Dataset

In [None]:
# Pull the first mini-batch from the loader: images `x0` and their `labels`.
x0, labels = next(iter(data_loader))

# Sanity-check the tensor shapes produced by the loader.
for caption, tensor in (("Batch shape:", x0), ("Labels shape:", labels)):
    print(caption, tensor.shape)

In [None]:
# Display one example from the batch (sample index 5, channel 0) in grayscale.
plt.imshow(x0[5, 0], cmap="gray")
plt.show()

## Forward Diffusion on MNIST

In [None]:
# Create the noise schedule. Its length lives in a single constant so the
# visualised time steps below stay in range if the schedule length changes.
num_time_steps = 1_000
schedule = LinearNoiseSchedule(time_steps=num_time_steps)

# Time steps to visualise, from t=0 up to the last index of the schedule
# (assumes the schedule is indexed 0 .. num_time_steps - 1, as the original
# hard-coded 999 for a 1_000-step schedule implied).
time_steps = [0, 50, 200, 500, num_time_steps - 1]

plt.figure(figsize=(15, 3))
for i, t in enumerate(time_steps):
    # A single-element integer time-step tensor is passed for the whole batch;
    # presumably sample_xt broadcasts it across all images — TODO confirm.
    t_tensor = torch.tensor([t], dtype=torch.long)
    xt, _, _ = sample_xt(x0, schedule, t=t_tensor)

    # Only the first image (channel 0) of the noised batch is plotted.
    plt.subplot(1, len(time_steps), i + 1)
    plt.imshow(xt[0, 0].cpu(), cmap="gray")
    plt.title(f"t={t}")
    plt.axis("off")

# Keep the per-panel titles from crowding each other.
plt.tight_layout()
plt.show()