In [1]:
# This project requires Python 3.10 or above:
import sys
# str.format (rather than an f-string) keeps this cell parseable on very old
# interpreters, so the user sees this assertion instead of a SyntaxError.
assert sys.version_info >= (3, 10), (
    "Python >= 3.10 is required, found {}".format(sys.version.split()[0])
)

# We also need PyTorch >= 2.6.0:
from packaging.version import Version
import torch
assert Version(torch.__version__) >= Version("2.6.0"), (
    "PyTorch >= 2.6.0 is required, found {}".format(torch.__version__)
)

import matplotlib.pyplot as plt
import torch.nn as nn
import torchmetrics
from torch.utils.data import DataLoader
# from collections import namedtuple

from ldm_ludo import diff_model as dm
from ldm_ludo import data
from ldm_ludo import plots
from ldm_ludo import utils
from ldm_ludo import training

In [2]:
# Use a hardware accelerator when one is present: CUDA first, then the
# MPS (Apple Metal) backend, otherwise fall back to the CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [None]:
# Fix torch's global RNG so the following cells are reproducible
torch.manual_seed(42)

# Diffusion hyper-parameters: number of noising steps T and the variance
# schedule derived from it (alpha_bars presumably holds the cumulative
# products of alphas — confirm in dm.variance_schedule).
T = 10
# NOTE(review): embed_dim is not referenced by any other visible cell.
embed_dim = 64  # TODO change this for time embedding
alphas, betas, alpha_bars = dm.variance_schedule(T)

In [None]:
# Fetch MNIST already split into train / validation / test subsets.
# NOTE(review): test_data is not used by any visible cell.
train_data, valid_data, test_data = data.loadDataset("mnist")

# Wrap the train and validation splits with the diffusion schedule (T, alpha_bars).
train_set = data.DiffusionDataset(train_data, T, alpha_bars)
valid_set = data.DiffusionDataset(valid_data, T, alpha_bars)

# Batch both splits; only the training loader shuffles its samples.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=32)

In [None]:
# Train the diffusion model. Re-seeding right before construction puts
# torch's global RNG in a fixed state here, regardless of earlier cells.
torch.manual_seed(42)
diffusion_model = dm.DiffusionModel(T).to(device)

# Huber loss, NAdam optimizer and an RMSE metric, all handed to training.train
huber = nn.HuberLoss()
optimizer = torch.optim.NAdam(diffusion_model.parameters(), lr=3e-3)
rmse = torchmetrics.MeanSquaredError(squared=False).to(device)
history = training.train(
    diffusion_model, optimizer, huber, rmse,
    train_loader, valid_loader,
    device=device, n_epochs=1,
)

# Persist the trained weights
utils.save_model(diffusion_model)

In [None]:
# Sample new images from the trained model with the DDPM sampler and plot them.
# no_grad() stops autograd from recording the sampling loop (harmless if
# generate_ddpm already disables gradients itself) and returns tensors that do
# not require grad, so they can be converted for plotting.
with torch.no_grad():
    X_gen = dm.generate_ddpm(diffusion_model)  # generated images
# NOTE(review): 8 is presumably the number of images per row — confirm
# against utils.plot_multiple_images.
utils.plot_multiple_images(X_gen, 8)
# Requires `import matplotlib.pyplot as plt` in the imports cell.
plt.show()

In [None]:
# Sample images with the (faster, deterministic-leaning) DDIM sampler and plot them.
# NOTE(review): num_steps=500 is larger than the T=10 diffusion steps set above;
# depending on how dm.generate_ddim spaces its timesteps this is either a no-op
# cap or ill-defined — confirm, and consider num_steps <= T.
with torch.no_grad():
    X_gen_ddim = dm.generate_ddim(diffusion_model, num_steps=500)
# NOTE(review): 8 is presumably the number of images per row — confirm
# against utils.plot_multiple_images.
utils.plot_multiple_images(X_gen_ddim, 8)
# Requires `import matplotlib.pyplot as plt` in the imports cell.
plt.show()