In [1]:
import torch
import torchvision
import os
import torch.nn as nn
from ddpm.diffusion_model import DiffusionModel

import ddpm.config as _config
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from ddpm.data import CIFAR10_CLASSES

In [2]:
# Switch off the DEBUG flag on the ddpm.config module before the model is
# built and sampled in the cells below.
# NOTE(review): what DEBUG controls inside ddpm is not visible here — confirm.
_config.DEBUG = False

In [3]:
# Checkpoint file that the loading cell restores the model from.
cpt_path = 'epoch_24.pth'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Restore the checkpoint onto the selected device.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only load checkpoint files that come from a trusted source.
cpt = torch.load(cpt_path, map_location=device, weights_only=False)

# Build the network from the configuration stored in the checkpoint and load
# its weights before wrapping it, so the state dict is applied to the bare model.
base_model = DiffusionModel(cpt['config'])
base_model.load_state_dict(cpt['model'])

# Wrap in DataParallel, move to the device and switch to evaluation mode.
# The underlying model stays reachable as `model.module`.
model = nn.DataParallel(base_model).to(device).eval()
print("model loaded")

# Sample: generate and save images for each CIFAR-10 class

In [5]:
# Number of images generated per class label; used as the batch size of each
# sampling call in the per-class loop below.
image_num = 2

In [None]:
# Generate `image_num` samples for each of the 10 class labels and write every
# sample to its own PNG file inside `save_dir`.
save_dir = 'cifar10_samples'
# save_image does not create missing directories, so make sure the target
# folder exists before the first write (no-op if it is already there).
os.makedirs(save_dir, exist_ok=True)

for label in range(10):
    print(CIFAR10_CLASSES[label])
    # One conditioning label per requested image, all set to the current class.
    labels = torch.full((image_num,), label, dtype=torch.long, device=device)
    samples = model.module.sample(shape=(image_num, 3, 32, 32), device=device, y=labels)

    # Enumerate so that every image of a class gets a distinct index in its
    # filename instead of relying on an undefined / stale `i`.
    for i, sample in enumerate(samples):
        # Map values with (x + 1) / 2 and clamp to [0, 1] for saving.
        # NOTE(review): assumes the sampler outputs values in [-1, 1] — confirm.
        processed_sample = ((sample + 1) / 2).clip(0, 1)
        # Save each image with a filename indicating class and image number
        filename = f"{save_dir}/{CIFAR10_CLASSES[label]}_{i}.png"

        # save_image accepts the tensor as-is and writes it to `filename`.
        torchvision.utils.save_image(processed_sample, filename)
        print(f"Saved {filename}")

# Sample intermediate: plot the intermediate results returned by `sample_intermediate`

In [None]:
# Request 16 samples that all share class index 1 (Automobile) and collect the
# intermediate results the sampler returns (called with save_every=200).
labels = torch.full((16,), 1, dtype=torch.long, device=device)
samples = model.module.sample_intermediate(
    shape=(16, 3, 32, 32),
    device=device,
    y=labels,
    save_every=200,
)

# Show each returned batch as a grid with 4 images per row, one figure per batch.
# NOTE(review): unlike the per-class saving cell, values are not rescaled to
# [0, 1] before display — confirm the value range sample_intermediate returns.
for snapshot in samples:
    grid = make_grid(snapshot, nrow=4)
    # imshow expects channels last, so move the channel axis to the end.
    grid_image = grid.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(8, 8))
    plt.imshow(grid_image)
    plt.axis('off')
    plt.show()