In [6]:
import os
import math
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from typing import List

import kagglehub
import shutil

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, IterableDataset
from torchvision import transforms

from diffusion_model import DeviceManager, LearningRateScheduler, \
    CatDataset, NoiseScheduler, Block, SinusidalPositionEmbeddings, SimpleUnet

In [7]:
def load_model(filename):
    """Return the path to the model checkpoint, downloading it on first use.

    When ``filename`` does not exist yet, the
    ``danildolgov/cat-diffusion/pyTorch/64x64`` model is fetched through
    ``kagglehub`` and one file from the returned directory is copied to
    ``filename``. When ``filename`` already exists nothing is downloaded.

    Args:
        filename: Destination path of the checkpoint file. May be a bare
            file name (no directory part) or a nested path; missing parent
            directories are created.

    Returns:
        ``filename``, unchanged.

    Raises:
        FileNotFoundError: If the directory returned by kagglehub contains
            no files.
    """
    if os.path.exists(filename):
        return filename

    # os.path.dirname() is "" for a bare file name, and os.makedirs("") raises,
    # so only create the parent directory when there actually is one.
    folder_path = os.path.dirname(filename)
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    path = kagglehub.model_download("danildolgov/cat-diffusion/pyTorch/64x64")

    # os.listdir() order is arbitrary: sort for a deterministic choice and
    # prefer a file whose extension matches the requested destination, so an
    # accompanying file (README, licence, ...) is not picked up by accident.
    candidates = sorted(
        name for name in os.listdir(path)
        if os.path.isfile(os.path.join(path, name))
    )
    if not candidates:
        raise FileNotFoundError(f"No model file found in downloaded directory: {path}")

    wanted_ext = os.path.splitext(filename)[1].lower()
    same_ext = [
        name for name in candidates
        if os.path.splitext(name)[1].lower() == wanted_ext
    ]
    cache_file = (same_ext or candidates)[0]

    # Copy instead of move: moving would empty the directory kagglehub handed
    # back, so if that same directory is returned again later (e.g. after the
    # local copy is deleted) there would be nothing left to pick up.
    shutil.copy2(os.path.join(path, cache_file), filename)

    return filename

In [8]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the meaning of the positional arguments (0, 1, True) is defined
# by DeviceManager in diffusion_model and is not visible here — confirm.
device_manager = DeviceManager(0, 1, True, device)

# Built with os.path.join so the path is valid on every OS. A hard-coded
# backslash path is only split into directories on Windows; elsewhere it is a
# single file name containing backslashes.
model_path = os.path.join("models", "cat-diffusion-64px", "model.pt")

In [9]:
# map_location remaps the stored tensors onto the active device, so a
# checkpoint that was saved from a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; consider weights_only=True once
# it is confirmed the checkpoint holds only tensors and plain containers.
checkpoint = torch.load(load_model(model_path), map_location=device_manager.device)

model = SimpleUnet()
model.to(device_manager.device)
# Kept as the last expression so the notebook displays the key-matching report.
model.load_state_dict(checkpoint["model"])

  checkpoint = torch.load(load_model(model_path))


<All keys matched successfully>

### Inference code

In [10]:
# Sampling configuration.
T = 300  # number of diffusion steps; NOTE(review): presumably must match the value used in training — confirm
image_size = 64
count_x = 5  # grid columns
count_y = 5  # grid rows
seed = None  # set to an int to make the generated grid reproducible

if seed is not None:
    torch.manual_seed(seed)

noise_scheduler = NoiseScheduler(T, device_manager)
noises = torch.randn((count_x*count_y, 3, image_size, image_size), device=device_manager.device)

# Sampling needs no gradients; no_grad avoids building an autograd graph in
# case generate_images does not already disable it.
with torch.no_grad():
    images = model.generate_images(noises, T, noise_scheduler, device_manager).cpu()

# squeeze=False always returns a 2-D array of axes, so flatten() also works
# for a 1x1 grid (where plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(count_y, count_x, figsize=(8, 8), squeeze=False)
axes = axes.flatten()

reverse_transform = transforms.Compose([
    # Shift by 1 and halve, then clamp to [0, 1]: ToPILImage scales floats by
    # 255 and casts to uint8, so any value outside [0, 1] would wrap around
    # and show up as speckle artefacts.
    transforms.Lambda(lambda x: ((x + 1) / 2).clamp(0, 1)),
    transforms.ToPILImage()
])

for ax, image in zip(axes, images):
    pil_image = reverse_transform(image)
    ax.imshow(pil_image)
    ax.axis('off')

plt.tight_layout()
plt.show()