### Model

In [1]:
from diffusion.ddpm import Unet, GaussianDiffusion, Trainer

# Denoising network. `dim` is the base width and `dim_mults` lists one
# multiplier per stage (presumably per-resolution channel widths -- confirm
# against the Unet implementation in diffusion.ddpm).
model = Unet(
    dim=64,
    dim_mults=(1, 2, 4),
    flash_attn=False,
)

# Diffusion wrapper around the network for 28x28 images.
# `sampling_timesteps` is smaller than `timesteps`, so sampling presumably
# uses fewer steps than training -- confirm against GaussianDiffusion.
diffusion = GaussianDiffusion(
    model,
    image_size=28,
    timesteps=1000,
    sampling_timesteps=250,
)

In [None]:
# Training run configuration; the training images are read from ./data.
trainer = Trainer(
    diffusion,
    './data',
    train_batch_size=128,
    train_lr=5e-5,
    train_num_steps=50000,
    gradient_accumulate_every=1,
    ema_decay=0.995,
    amp=True,
    # FID is not computed during training; `num_fid_samples` presumably only
    # matters when `calculate_fid` is enabled -- confirm against Trainer.
    calculate_fid=False,
    save_and_sample_every=10000,
    num_fid_samples=1000,
)

# Long-running: launches the full 50000-step training loop.
trainer.train()

### Training

### Sampling

In [6]:
import torch
from ema_pytorch import EMA
from torchvision.utils import save_image

# Checkpoint to restore for sampling.
ckpt_path = "./ckpts/model-5.pt"

# Use the GPU when one is present, otherwise fall back to the CPU so this
# cell still runs on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location=device)
diffusion.load_state_dict(ckpt["model"])

# Rebuild the EMA wrapper around the restored model, then overwrite its
# state (including the averaged weights) with the one stored in the checkpoint.
ema = EMA(diffusion, beta=0.995, update_every=10).to(device)

ema.load_state_dict(ckpt["ema"])

# The averaged copy of the model is the one used for sampling below.
ema_model = ema.ema_model

In [7]:
import os

# The image is written into results/; create the directory first so that
# saving does not fail when it is missing.
sample_grid_path = "results/samples_baseline.png"
os.makedirs(os.path.dirname(sample_grid_path), exist_ok=True)

ema_model.eval()

# Draw 16 samples from the EMA model and save them as a 4-per-row grid.
with torch.no_grad():
    samples = ema_model.sample(batch_size=16)
    save_image(samples, sample_grid_path, nrow=4)

Sample 1,000 images for the FID evaluation. The FID cell below reads the generated images from `./samples/baseline`.

### FID Calculation

In [12]:
import os
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image

from torchmetrics.image.fid import FrechetInceptionDistance


class ImageFolderDataset(Dataset):
    """Unlabelled dataset of every image file found under a directory.

    Files are discovered recursively by extension (.jpg, .jpeg, .png, .bmp,
    .webp) and ordered by sorted path. Each item is the image converted to
    RGB, resized to ``(image_size, image_size)`` and turned into a float32
    tensor of shape ``(3, image_size, image_size)`` with values in [0, 1]
    (the output convention of ``to_tensor``).

    Args:
        root: Directory that is searched recursively for images.
        image_size: Side length of the square output images. Defaults to 299.

    Raises:
        RuntimeError: If no image file is found under ``root``.
    """

    def __init__(self, root, image_size=299):
        self.root = Path(root)
        exts = [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
        # is_file() skips directories whose name merely ends in an image
        # extension; opening those would fail later in __getitem__.
        self.paths = sorted(
            p
            for ext in exts
            for p in self.root.rglob(f"*{ext}")
            if p.is_file()
        )

        if len(self.paths) == 0:
            raise RuntimeError(f"No images found in {self.root}")

        # Applied after to_tensor; a no-op for float32 input, it pins the
        # output dtype and is the hook for further tensor transforms.
        self.transform = T.Compose([
            T.ConvertImageDtype(torch.float32),
        ])
        self.resize = T.Resize((image_size, image_size))

    def __len__(self):
        """Return the number of image files discovered under ``root``."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load, resize and convert the image at position ``idx``."""
        path = self.paths[idx]
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be closed right away instead of lingering until GC.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        img = self.resize(img)
        img = T.functional.to_tensor(img)
        return self.transform(img)


@torch.no_grad()
def compute_fid(
    real_dir: str,
    fake_dir: str,
    batch_size: int = 64,
    device: str = None,
    num_workers: int = 4,
) -> float:
    """Compute the Frechet Inception Distance between two image folders.

    Both folders are read with ``ImageFolderDataset`` and all images are fed
    to a ``FrechetInceptionDistance`` metric using 2048-dimensional features.

    Args:
        real_dir: Directory containing the reference (real) images.
        fake_dir: Directory containing the generated images.
        batch_size: Number of images per batch passed to the metric.
        device: Device to run the metric on. When ``None``, ``"cuda"`` is used
            if available, otherwise ``"cpu"``.
        num_workers: Number of DataLoader worker processes per folder.

    Returns:
        The FID value as a Python float.

    Raises:
        RuntimeError: If either directory contains no images (raised by
            ``ImageFolderDataset``).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only run just triggers a warning.
    pin_memory = str(device).startswith("cuda")

    def make_loader(directory):
        # One deterministic (unshuffled) pass over every image in `directory`.
        return DataLoader(
            ImageFolderDataset(directory),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )

    real_loader = make_loader(real_dir)
    fake_loader = make_loader(fake_dir)

    # normalize=True is required here: the dataset yields float images in
    # [0, 1], whereas the metric's default (normalize=False) expects uint8
    # images in [0, 255].
    fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    fid.eval()

    for loader, is_real in ((real_loader, True), (fake_loader, False)):
        for imgs in loader:
            fid.update(imgs.to(device), real=is_real)

    return fid.compute().item()

In [None]:
# Folders compared by the FID evaluation: real images vs. generated samples.
real_path = "./data/val"
fake_path = "./samples/baseline"

fid_value = compute_fid(
    real_dir=real_path,
    fake_dir=fake_path,
    batch_size=64,
)
print(f"FID(real={real_path}, fake={fake_path}) = {fid_value:.4f}")