In [None]:
import torch
from unet import UNet
from vdm import VDM
from train import get_cifar10_dataset
from ema_pytorch import EMA
from torch.utils.data import DataLoader
from torch import nn
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


MODEL_PATH = "./outputs/2025-12-02T21:53:47.100815/model.pt"


def show_image(img):
    """Render one image tensor inline with matplotlib.

    The tensor is detached, moved to the CPU and mapped through
    ``x / 2 + 0.5`` before plotting; its axes are then permuted with
    ``(1, 2, 0)`` for ``imshow``.

    Args:
        img: tensor to display.
            NOTE(review): presumably a channel-first (C, H, W) image
            normalized to [-1, 1] — confirm against the dataset transform.
    """
    rescaled = img.detach().cpu() / 2 + 0.5  # unnormalize
    pixels = rescaled.permute((1, 2, 0)).numpy()

    # The seaborn "white" style only applies while this figure is drawn.
    with sns.axes_style("white"):
        plt.figure(figsize=(8, 8))
        plt.imshow(pixels)
        plt.axis("off")
        plt.show()


def eval_model(model, loader, device):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is switched to eval mode (and left in eval mode) and every
    forward pass runs under ``torch.no_grad()``.

    Args:
        model: module called as ``model(x)``. Its output may be a dict
            containing a ``"loss"`` entry, a tuple/list whose first element
            is the loss, or the loss itself. The loss must support
            ``.item()``.
        loader: iterable yielding ``(x, _)`` pairs; the second element is
            ignored.
        device: target passed to ``x.to(...)`` for each batch.

    Returns:
        float: the sum of the per-batch losses divided by the number of
        batches. Every batch carries equal weight regardless of its size,
        so a smaller final batch counts as much as a full one.

    Raises:
        ValueError: if ``loader`` yields no batches (previously this
            surfaced as a bare ``ZeroDivisionError``).
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            out = model(x)

            # Adapt this to your VDM's API:
            if isinstance(out, dict) and "loss" in out:
                loss = out["loss"]
            elif isinstance(out, (tuple, list)):
                loss = out[0]
            else:
                loss = out  # scalar tensor

            total_loss += loss.item()
            n_batches += 1

    # Fail with an actionable message instead of dividing by zero.
    if n_batches == 0:
        raise ValueError("eval_model received an empty loader; cannot average the loss")

    return total_loss / n_batches


# Load the checkpoint, the CIFAR-10 splits, and both model variants
# (the raw VDM and the one wrapped by EMA). The names bound here
# (device, data, train_set, test_set, unet, vdm, ema) are module-level
# and are used by the cells below.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file, so only point MODEL_PATH at
# checkpoints you trust. If the checkpoint holds nothing but state dicts,
# weights_only=True is probably a safer option — confirm before changing.
data = torch.load(MODEL_PATH, map_location=device)

# download=False: presumably expects the dataset to already be on disk —
# verify in train.get_cifar10_dataset.
train_set = get_cifar10_dataset(train=True, download=False)
test_set = get_cifar10_dataset(train=False, download=False)

unet = UNet()
# image_shape is taken from the first training sample's tensor.
vdm = VDM(unet, image_shape=train_set[0][0].shape, device=device)
vdm = vdm.eval().to(device)

# The EMA wrapper is built from the freshly constructed (not yet loaded)
# vdm; its weights are filled in by ema.load_state_dict below.
ema = EMA(vdm).to(device)
# Guard: the sampling code below passes ema.ema_model where a VDM is expected.
if not isinstance(ema.ema_model, VDM):
    raise ValueError("EMA model is not a VDM")
ema.ema_model.eval()


# Restore the trained weights from the "model" and "ema" checkpoint entries.
# The second call is the cell's last expression, so its return value is what
# the notebook echoes as output.
vdm.load_state_dict(data["model"])
ema.load_state_dict(data["ema"])

<All keys matched successfully>

In [None]:
# Score the raw (non-EMA) VDM on test_set, iterating in dataset order.
test_loader = DataLoader(dataset=test_set, shuffle=False, batch_size=64)
test_loss = eval_model(model=vdm, loader=test_loader, device=device)
print(f"Test loss (VDM): {test_loss:.4f}")

Test loss (VDM): 0.2800


In [None]:
# Sample images
import math


def sample_batched(model: "VDM", num_samples: int, batch_size: int, n_sample_steps: int, clip_samples: bool):
    """Draw ``num_samples`` samples from ``model`` in chunks of at most ``batch_size``.

    ``model.sample`` is called once per chunk with the positional arguments
    ``(chunk_size, n_sample_steps, clip_samples)``; the last chunk is
    shortened so the total is exactly ``num_samples``.

    Args:
        model: object exposing ``sample(batch, n_sample_steps, clip_samples)``
            that returns a tensor.
        num_samples: total number of samples to draw; must be >= 1.
        batch_size: maximum number of samples per ``model.sample`` call;
            must be >= 1.
        n_sample_steps: forwarded unchanged to ``model.sample``.
        clip_samples: forwarded unchanged to ``model.sample``.

    Returns:
        The per-chunk tensors concatenated along dimension 0.

    Raises:
        ValueError: if ``num_samples`` or ``batch_size`` is smaller than 1.
            Without this check those inputs ended in an opaque
            ``torch.cat`` / ``range`` error.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    for start in range(0, num_samples, batch_size):
        # The final chunk may be smaller than batch_size.
        chunk_size = min(batch_size, num_samples - start)
        chunks.append(model.sample(chunk_size, n_sample_steps, clip_samples))
    return torch.cat(chunks, dim=0)


# Sampling configuration read by sample_images.
NUM_SAMPLES = 64  # total images to draw; also sets the grid's nrow via int(sqrt(NUM_SAMPLES))
BATCH_SIZE = 64  # upper bound on images requested per model.sample call
N_SAMPLE_STEPS = 1000  # forwarded to model.sample; also embedded in the output file name
CLIP_SAMPLES = True  # forwarded to model.sample as its third positional argument


def sample_images(model: VDM, is_ema: bool):
    """Sample a grid of images from ``model`` and save it as a PNG.

    Uses the module-level NUM_SAMPLES, BATCH_SIZE, N_SAMPLE_STEPS and
    CLIP_SAMPLES settings. The file is written to
    ``./samples/sample[-ema]-<N_SAMPLE_STEPS>.png``.

    Args:
        model: model to sample from (passed to ``sample_batched``).
        is_ema: when true, ``-ema`` is inserted into the file name so the
            EMA grid does not overwrite the raw-model grid.
    """
    out_dir = Path("./samples")
    # Create the output directory before sampling: sampling is slow, and a
    # missing directory would otherwise only surface when save_image runs.
    out_dir.mkdir(parents=True, exist_ok=True)

    samples = sample_batched(
        model,
        NUM_SAMPLES,
        BATCH_SIZE,
        N_SAMPLE_STEPS,
        CLIP_SAMPLES,
    )

    path = out_dir / f"sample{'-ema' if is_ema else ''}-{N_SAMPLE_STEPS}.png"
    # NOTE(review): save_image is called without any rescaling. If
    # model.sample returns values in [-1, 1] rather than [0, 1], the saved
    # grid would need normalizing first — confirm against VDM.sample.
    save_image(samples, str(path), nrow=int(math.sqrt(NUM_SAMPLES)))


# Write one grid per weight set: the raw VDM and the model held by the EMA
# wrapper. The is_ema flag adds "-ema" to the second file name.
sample_images(vdm, is_ema=False)
sample_images(ema.ema_model, is_ema=True)

sampling: 100%|██████████| 1000/1000 [01:53<00:00,  8.82it/s]
sampling: 100%|██████████| 1000/1000 [01:52<00:00,  8.86it/s]
