### Setup

In [1]:
from matplotlib import pyplot as plt

from src.models.diffusion import DiffusionModel
from src.models.rin import RINModel
from src.utils.noise import gamma_cosine
from src.utils.sample import sample_ddpm


def show_images(images, title):
    """Plot a batch of images side by side in a single-row figure.

    Args:
        images: Non-empty indexable batch of channel-first (C, H, W) tensors,
            e.g. an (N, C, H, W) tensor. Pixel values are assumed to lie in
            [-1, 1] (TODO confirm against the sampler); they are rescaled to
            [0, 1] and clamped before display.
        title: Title drawn above the whole row.

    Raises:
        ValueError: If ``images`` is empty.
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError("show_images() needs at least one image")

    # squeeze=False keeps `axes` a 2-D array even for a single image; with the
    # default squeeze=True, plt.subplots(1, 1) returns a bare Axes object and
    # iterating over it fails.
    _, axes = plt.subplots(1, num_images, figsize=(15, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # Detach/move to CPU so matplotlib can convert the tensor to an array,
        # then CHW -> HWC, the layout imshow expects for RGB data.
        img = images[i].detach().cpu().permute(1, 2, 0)
        # Map [-1, 1] to [0, 1]; clamp guards against values out of range.
        img = (img + 1) / 2.0
        ax.imshow(img.clamp(0, 1))
        ax.axis("off")
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

In [None]:
import torch

# Pick an available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA machines or on plain CPU. MPS is checked
# first to keep the original behavior on Apple hardware.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Single source of truth for the resolution passed to both the network and
# the diffusion wrapper below.
IMAGE_SIZE = 64

# NOTE(review): these hyperparameters presumably must match the ones the
# checkpoint was trained with, otherwise its weights will not load.
net = RINModel(
    image_size=IMAGE_SIZE,
    patch_size=8,
    latent_dim=256,
    interface_dim=128,
    num_latents=64,
    num_blocks=2,
    block_depth=1,
    num_heads=4,
)

# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; the model is moved to DEVICE afterwards.
model = DiffusionModel.load_from_checkpoint(
    "logs/checkpoints/epoch=9-step=25440.ckpt",
    net=net,
    image_size=IMAGE_SIZE,
    map_location="cpu",
)

model = model.to(DEVICE)
# Switch to inference mode; assigning to `_` suppresses the notebook echo.
_ = model.eval()

### Sample

In [None]:
import torch

NUM_SAMPLES = 8
NUM_STEPS = 100

# Sample on whatever device the model currently lives on rather than
# hard-coding "mps", so this cell follows the setup cell's device choice.
device = next(model.parameters()).device

# NOTE(review): sample_ddpm may already disable autograd internally; the outer
# no_grad is harmless either way and guarantees the result does not track
# gradients (saves memory, and matplotlib cannot convert grad-tracking tensors).
with torch.no_grad():
    samples = sample_ddpm(
        model=model,
        shape=(NUM_SAMPLES, 3, 64, 64),
        gamma_fn=gamma_cosine,
        num_steps=NUM_STEPS,
        device=device,
    )

show_images(samples.cpu(), f"DDPM samples with {NUM_STEPS} steps")