### Model

In [1]:
from diffusion.ddpm import Unet, GaussianDiffusion, Trainer

# Denoising network handed to GaussianDiffusion below.
# NOTE(review): presumably `dim` is the base channel width and `dim_mults`
# the per-stage multipliers -- confirm against diffusion.ddpm.Unet.
model = Unet(dim=64, dim_mults=(1, 2, 4), flash_attn=False)

# Diffusion wrapper around the network; `diffusion` is what the Trainer
# receives and what the sampling cells reload a checkpoint into.
# NOTE(review): sampling_timesteps (250) is smaller than timesteps (1000),
# which presumably selects a shortened sampling schedule -- confirm in
# diffusion.ddpm.GaussianDiffusion.
diffusion = GaussianDiffusion(
    model,
    image_size=28,
    timesteps=1000,
    sampling_timesteps=250,
)

In [None]:
# Configure the training run for `diffusion` on the images under ./data,
# then start it. Keyword arguments are grouped by concern.
# NOTE(review): 50000 steps with a save every 10000 gives 5 save points;
# the sampling cell loads "model-5.pt", presumably the last of them --
# confirm the checkpoint naming and output folder in diffusion.ddpm.Trainer.
trainer = Trainer(
    diffusion,
    './data',
    # optimisation
    train_batch_size=128,
    train_lr=5e-5,
    train_num_steps=50000,
    gradient_accumulate_every=1,
    # EMA / mixed precision
    ema_decay=0.995,
    amp=True,
    # checkpointing and evaluation
    save_and_sample_every=10000,
    calculate_fid=False,
    num_fid_samples=1000,
)

trainer.train()

### Training

### Sampling

In [6]:
import torch
from ema_pytorch import EMA
from torchvision.utils import save_image

# Checkpoint to restore; the dict is read below for its "model" and "ema" keys.
ckpt_path = "./ckpts/model-5.pt"

# Prefer the GPU, but fall back to the CPU so the checkpoint can still be
# loaded (and sampled from, slowly) on a machine without CUDA instead of
# failing inside torch.load / .to().
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: torch.load is pickle-based -- only load checkpoint files from a
# source you trust.
ckpt = torch.load(ckpt_path, map_location=device)
diffusion.load_state_dict(ckpt["model"])

# Rebuild the EMA wrapper around `diffusion` and restore its saved state.
# beta=0.995 mirrors `ema_decay` in the Trainer cell.
# NOTE(review): update_every=10 is presumably the Trainer's default EMA
# update interval -- confirm in diffusion.ddpm.Trainer.
ema = EMA(diffusion, beta=0.995, update_every=10).to(device)
ema.load_state_dict(ckpt["ema"])

# The averaged copy of the model; this is what the sampling cell uses.
ema_model = ema.ema_model

In [7]:
import os

# Where the sample grid is written.
samples_path = "results/samples_baseline.png"

# save_image does not create missing directories; make sure the target
# folder exists *before* the (slow) sampling run so the result is not lost
# to a FileNotFoundError at the very end.
os.makedirs(os.path.dirname(samples_path), exist_ok=True)

ema_model.eval()

# Draw 16 samples from the EMA model without tracking gradients and save
# them as a single image grid with 4 images per row.
with torch.no_grad():
    samples = ema_model.sample(batch_size=16)
    save_image(samples, samples_path, nrow=4)