In [None]:
import os

import torch
from PIL import Image

from accelerate import Accelerator
from diffusers import UNet2DModel, DDPMPipeline, DDIMPipeline, DDPMScheduler, DDIMScheduler

In [None]:
from dataclasses import dataclass

@dataclass
class SamplingConfig:
    """Settings for loading a trained diffusion model and sampling an image grid.

    Every setting carries a type annotation so that `@dataclass` registers it
    as a real field: without annotations the decorator sees no fields at all,
    which means no keyword overrides (`SamplingConfig(seed=3)` fails) and an
    empty `repr`/`==`.
    """

    # Path to model folder (must contain `unet` and `scheduler` subfolders)
    model_path: str = "models/CIFAR10_Noise"

    # Number of denoising steps
    num_inference_steps: int = 50
    # 0.0 (Default DDIM), 1.0 (Equiv. to DDPM)
    eta: float = 0.0

    # Seed used to build the random generator handed to the pipeline
    seed: int = 10
    # Number of images to sample in a batch
    batch_size: int = 16
    # Grid layout used when tiling the sampled images
    rows: int = 4
    cols: int = 4

    # Forwarded to `Accelerator(mixed_precision=...)`
    mixed_precision: str = "fp16"
    # NOTE(review): not read by any cell visible here; presumably mirrors the
    # training-time diffusion length -- confirm before relying on it.
    num_train_timesteps: int = 1000


config = SamplingConfig()

In [None]:
# Restore the UNet from the `unet` subfolder of the configured model directory
unet_dir = os.path.join(config.model_path, "unet")
model = UNet2DModel.from_pretrained(unet_dir)

In [None]:
# Hand the model to Accelerate (device placement, configured mixed-precision mode)
accelerator = Accelerator(mixed_precision=config.mixed_precision)
model = accelerator.prepare(model)

# Restore the scheduler from the `scheduler` subfolder of the model directory
scheduler_dir = os.path.join(config.model_path, "scheduler")
scheduler = DDIMScheduler.from_pretrained(scheduler_dir)

# Create the pipeline around the unwrapped model and the loaded scheduler
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=scheduler)

In [None]:
def make_grid(images, rows, cols):
    """Tile images into one RGB canvas, filling each row left to right.

    Args:
        images: Non-empty sequence of images. The size of the first one is
            used as the cell size for every position.
        rows: Number of grid rows; sets the canvas height.
        cols: Number of grid columns; sets the canvas width and the point at
            which placement wraps to the next row.

    Returns:
        A new RGB image of size (cols * cell_width, rows * cell_height).

    Note:
        There is no check that len(images) fits within rows * cols.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for position, tile in enumerate(images):
        # divmod yields (row, column) for a row-major layout
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas


def sample(config, pipeline):
    """Run the pipeline once with a seeded generator and tile the batch into a grid.

    Args:
        config: Object providing `batch_size`, `seed`, `num_inference_steps`,
            `eta`, `rows` and `cols`.
        pipeline: Callable accepting `batch_size`, `generator`,
            `num_inference_steps` and `eta` keyword arguments and returning an
            object with an `images` attribute.

    Returns:
        Whatever `make_grid` builds from the sampled images.
    """
    # torch.manual_seed reseeds the global RNG and returns that generator,
    # so repeated calls with the same seed start from the same state.
    call_kwargs = {
        "batch_size": config.batch_size,
        "generator": torch.manual_seed(config.seed),
        "num_inference_steps": config.num_inference_steps,
        "eta": config.eta,
    }
    sampled = pipeline(**call_kwargs).images
    return make_grid(sampled, rows=config.rows, cols=config.cols)

In [None]:
# Sample one batch; the returned grid is the cell's last expression, so the notebook shows it
sample(config=config, pipeline=pipeline)