In [None]:
import src
from src.utils import create_gif, load_model
from torchvision.transforms import Compose, Lambda, ToPILImage
import numpy as np
import torch

In [None]:
# Parameters handed to the linear beta scheduler.
beta_start = 0.0001
beta_end = 0.02
timesteps = 200

# Geometry of the generated images and how many are sampled per call.
image_size = 28
num_channels = 1
batch_size = 128

# Prefer the GPU, but fall back to the CPU so the notebook still runs
# top-to-bottom on machines without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Selects which checkpoint file is loaded: results/model_{epoch}.pt
epoch = 4

In [None]:
# Build the linear scheduler from the values set in the configuration cell;
# its attributes are reused by the diffusion and sampling objects below.
scheduler = src.LinearScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    timesteps=timesteps,
)

In [None]:
# Pipeline that turns a single CHW image tensor into a PIL image.
# NOTE(review): the first step assumes pixel values are nominally in [-1, 1] — confirm
# against the preprocessing used at training time.
reverse_transform = Compose([
    Lambda(lambda t: (t + 1) / 2),         # maps [-1, 1] onto [0, 1]
    # Noisy or over-shooting values fall outside [0, 1]; without clamping, the
    # float -> uint8 cast below wraps around (or is platform-dependent).
    Lambda(lambda t: t.clamp(0, 1)),
    Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    Lambda(lambda t: t * 255.),
    # detach()/cpu() make the conversion work for CUDA tensors and for tensors
    # that still carry autograd history; both are no-ops otherwise.
    Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    ToPILImage(),
])

# Forward-diffusion helper built from the scheduler's precomputed coefficients;
# it receives reverse_transform so it can hand back displayable images.
forward_diffusion = src.ForwardDiffusion(
    sqrt_alphas_cumprod=scheduler.sqrt_alphas_cumprod,
    sqrt_one_minus_alphas_cumprod=scheduler.sqrt_one_minus_alphas_cumprod,
    reverse_transform=reverse_transform,
)


In [None]:
# Rebind `device` from the name set in the configuration cell to a
# torch.device object; this is the value later passed to `model.to(...)`.
device = torch.device(device)

In [None]:
# Instantiate the network, pass it through load_model together with the
# checkpoint path for the configured epoch, then move it to the target device.
model = src.DDPM(
    n_features=image_size,
    in_channels=num_channels,
    channel_scale_factors=(1, 2, 4),
)
model = load_model(model, f"results/model_{epoch}.pt")
model.to(device)
# This notebook only samples from the model, so switch it to evaluation mode:
# layers such as dropout or batch normalisation (if the network has any) then
# behave deterministically. Harmless if load_model already did this.
model.eval()

# Sampler wired up with the scheduler's coefficients and the same number of
# timesteps as the scheduler.
sampler = src.Sampler(
    betas=scheduler.betas,
    sqrt_one_minus_alphas_cumprod=scheduler.sqrt_one_minus_alphas_cumprod,
    sqrt_one_by_alphas=scheduler.sqrt_one_by_alphas,
    posterior_variance=scheduler.posterior_variance,
    timesteps=timesteps,
)


In [None]:
# Draw a batch of samples from the loaded model via the sampler.
samples = sampler.sample(
    model=model,
    image_size=image_size,
    batch_size=batch_size,
    channels=num_channels,
)


In [None]:
# Hand the samples to create_gif along with the image geometry and step count.
# Kept as the cell's final expression so any return value is still displayed.
create_gif(
    samples,
    image_size,
    num_channels,
    timesteps,
)