In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from diffusion.models.ddpm import DDPM

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Training hyperparameters used by the cells below.
batch_size, epochs = 128, 5  # images per DataLoader batch, passes over the dataset
T = 1000  # forwarded to DDPM as `T` (presumably the number of diffusion timesteps)

In [None]:
# Preprocessing: resize each MNIST digit to 32x32 and map pixels to [-1, 1].
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 == 2x - 1, i.e. the
# same mapping a `Lambda(lambda x: x * 2.0 - 1.0)` would give. A lambda cannot
# be pickled, though, and the DataLoader below uses worker processes: on
# platforms that start workers with "spawn" (Windows, macOS) the dataset and
# its transforms are pickled for every worker, so a lambda makes iteration fail.
transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),  # [0,1]
        transforms.Normalize((0.5,), (0.5,)),  # [-1,1]
    ])

# MNIST training split, downloaded into ../data if it is not already there.
mnist = datasets.MNIST(
    root="../data", train=True, download=True, transform=transform
)
# Shuffled batches of `batch_size` images, loaded by 4 worker processes.
dl = DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=4)


In [None]:
# Build the DDPM for single-channel images with T diffusion steps, then move
# it onto the selected device.
ddpm = DDPM(img_channels=1, base_c=32, time_dim=128, T=T)
ddpm = ddpm.to(device)

In [None]:
# Adam over all of the model's parameters, with a fixed learning rate.
lr = 1e-5
optimizer = torch.optim.Adam(params=ddpm.parameters(), lr=lr)

In [None]:
# Training loop: `epochs` passes over the DataLoader. Calling the model on a
# batch returns the loss directly, which is then backpropagated.
ddpm.train()
for epoch in range(epochs):
    for step, (images, _) in enumerate(dl):
        images = images.to(device)  # (B,1,32,32) in [-1,1]

        # Reset gradients, compute the loss, backpropagate, update weights.
        optimizer.zero_grad()
        loss = ddpm(images)
        loss.backward()
        optimizer.step()

        # Only report the current batch loss on every 100th step.
        if step % 100:
            continue
        print(f"[Epoch {epoch} | Step {step}] Loss: {loss.item():.4f}")

In [None]:
# Draw 16 samples from the model and show them on a 4x4 grid.
ddpm.eval()
with torch.no_grad():  # sampling does not need gradients
    samples = ddpm.sample(n_samples=16, img_size=(1, 32, 32))  # [-1,1]

# Map [-1,1] back to [0,1] for display, clip anything that falls outside that
# range, and move the result to host memory for matplotlib.
samples_vis = (samples + 1) / 2
samples_vis = samples_vis.clamp(0, 1).cpu()

fig, axes = plt.subplots(4, 4, figsize=(6, 6))

for ax, img in zip(axes.flatten(), samples_vis):
    # Pin the colour scale to [0, 1]. Without vmin/vmax, imshow stretches each
    # image to its own min/max, which hides washed-out or low-contrast samples
    # and makes the panels incomparable with one another.
    ax.imshow(img[0].numpy(), cmap="gray", vmin=0, vmax=1)
    ax.axis("off")
plt.tight_layout()
plt.show()
