In [1]:
import torch
from unet import UNet
from vdm import VDM
from train import get_cifar10_dataset
from ema_pytorch import EMA
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path


def show_image(img):
    """Render one image tensor inline with matplotlib.

    Args:
        img: A 3-D tensor that is permuted with ``(1, 2, 0)`` before
            plotting, i.e. channels are expected first. Values are mapped
            with ``img / 2 + 0.5`` before display (presumably undoing a
            [-1, 1] normalization — confirm against the dataset transform).

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    # Detach from autograd, move to host memory, unnormalize, and put the
    # channel axis last, which is the layout imshow expects.
    pixels = (img.detach().cpu() / 2 + 0.5).permute(1, 2, 0).numpy()

    # The "white" seaborn style only applies inside this context.
    with sns.axes_style("white"):
        _, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(pixels)
        ax.axis("off")
        plt.show()


def eval_model(model, loader, device):
    """Return the average per-batch loss of ``model`` over ``loader``.

    The model is switched to eval mode (and left in eval mode afterwards)
    and every forward pass runs under ``torch.no_grad()``. Each batch
    contributes equally to the average, so a smaller final batch carries
    the same weight as a full one.

    Args:
        model: Module called as ``model(x)``. Its output must be either a
            dict containing a ``"loss"`` entry, a tuple/list whose first
            element is the loss, or the loss itself. The loss must support
            ``.item()`` (a single-element tensor).
        loader: Iterable yielding ``(x, _)`` pairs; the second element is
            ignored.
        device: Device each ``x`` is moved to before the forward pass.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``loader`` yields no batches, since an average over
            zero batches is undefined.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for x, _ in loader:
            out = model(x.to(device))

            # Accept the common output conventions; adapt this to your VDM's API.
            if isinstance(out, dict) and "loss" in out:
                loss = out["loss"]
            elif isinstance(out, (tuple, list)):
                loss = out[0]
            else:
                loss = out  # scalar tensor

            total_loss += loss.item()
            n_batches += 1

    # Fail with a clear message instead of an opaque ZeroDivisionError.
    if n_batches == 0:
        raise ValueError("eval_model: loader yielded no batches; cannot compute a mean loss")

    return total_loss / n_batches

# Load model: rebuild the VDM and its EMA wrapper on CPU, then restore both
# sets of weights from the checkpoint at ./model.pt.

device = torch.device("cpu")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. If the checkpoint holds nothing but state dicts, passing
# weights_only=True would be safer — TODO confirm against the training script.
data = torch.load("./model.pt", map_location=device)

# download=False: presumably CIFAR-10 must already be on disk — confirm
# against train.get_cifar10_dataset.
train_set = get_cifar10_dataset(train=True, download=False)
test_set = get_cifar10_dataset(train=False, download=False)

unet = UNet()
# image_shape is read from the first training example's image tensor.
vdm = VDM(unet, image_shape=train_set[0][0].shape, device=device)
vdm = vdm.eval().to(device)

# EMA wrapper around the VDM; ema.ema_model is what the sampling cell uses.
ema = EMA(vdm).to(device)
ema.ema_model.eval()


# The checkpoint is indexed by "model" (VDM weights) and "ema" (EMA wrapper
# state). As the cell's last expression, the result of the final
# load_state_dict call is echoed as the cell output.
vdm.load_state_dict(data["model"])
ema.load_state_dict(data["ema"])


<All keys matched successfully>

In [None]:
# Evaluate the non-EMA VDM weights on the test split, iterating in a fixed
# (unshuffled) order with batches of 128.
test_loader = DataLoader(dataset=test_set, batch_size=128, shuffle=False)
test_loss = eval_model(model=vdm, loader=test_loader, device=device)
print(f"Test loss (VDM): {test_loss:.4f}")

In [None]:
# Sample images
import math
def sample_batched(model, num_samples, batch_size, n_sample_steps, clip_samples):
    """Draw ``num_samples`` samples from ``model`` in chunks of ``batch_size``.

    Args:
        model: Object exposing ``sample(batch_size, n_sample_steps, clip_samples)``
            that returns a tensor with the batch on dimension 0.
        num_samples: Total number of samples to draw.
        batch_size: Maximum number of samples requested per ``model.sample`` call.
        n_sample_steps: Forwarded unchanged to ``model.sample``.
        clip_samples: Forwarded unchanged to ``model.sample``.

    Returns:
        The per-call results concatenated along dimension 0.
    """
    # Size of each chunk; the last one shrinks so the total is exactly num_samples.
    chunk_sizes = [
        min(batch_size, num_samples - offset)
        for offset in range(0, num_samples, batch_size)
    ]
    chunks = [
        model.sample(chunk_size, n_sample_steps, clip_samples)
        for chunk_size in chunk_sizes
    ]
    return torch.cat(chunks, dim=0)

def sample_images(
    model,
    *,
    is_ema,
    num_samples=1,
    batch_size=1,
    n_sample_steps=250,
    clip_samples=True,
    out_dir="./samples",
):
    """Sample images from ``model`` and save them as one PNG grid.

    The file is written to ``<out_dir>/sample.png``, or
    ``<out_dir>/sample-ema.png`` when ``is_ema`` is true. The output
    directory is created if it does not exist yet.

    Args:
        model: Model passed to ``sample_batched`` (must expose ``sample``).
        is_ema: Whether ``model`` holds EMA weights; only affects the file name.
        num_samples: Number of images to draw (default 1).
        batch_size: Maximum images per sampling call (default 1).
        n_sample_steps: Number of sampling steps (default 250).
        clip_samples: Forwarded to the sampler (default True).
        out_dir: Directory the PNG is written to (default ``"./samples"``).

    Returns:
        None.
    """
    samples = sample_batched(
        model,
        num_samples,
        batch_size,
        n_sample_steps,
        clip_samples,
    )
    path = Path(out_dir) / f"sample{'-ema' if is_ema else ''}.png"
    # save_image does not create missing directories; do it here so a fresh
    # checkout does not fail after the (slow) sampling has already finished.
    path.parent.mkdir(parents=True, exist_ok=True)
    # Roughly square grid; never fewer than one image per row.
    nrow = max(1, int(math.sqrt(num_samples)))
    save_image(samples, str(path), nrow=nrow)

# Sample from the EMA copy of the model; is_ema=True selects the
# "sample-ema.png" file name under ./samples.
sample_images(ema.ema_model, is_ema=True)

sampling: 100%|██████████| 250/250 [00:26<00:00,  9.41it/s]
