# Sampling from a diffusion model

In this notebook we will:
* Sample from the conditional diffusion model trained in the previous notebook (see [1_wandb_training.ipynb](/notebooks/1_wandb_training.ipynb)) using both the DDPM and DDIM sampling schemes.
* Compare the DDPM and DDIM samples side by side in a W&B table.

In [1]:
from pathlib import Path
from types import SimpleNamespace
import torch
import torch.nn.functional as F
import numpy as np
from utils.wandb_utils import *

import wandb

  from .autonotebook import tqdm as notebook_tqdm


# Setup

In [2]:
# Weights & Biases parameters
# - The best performing model from the previous notebook was saved as an
#   artifact in the W&B model registry; this is its registry path.
MODEL_ARTIFACT = "doc93/model-registry/Diffusion-Model-Sprite:latest"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sampling configuration (also attached to the W&B run that logs the results).
config = SimpleNamespace(**{
    # number of initial-noise tensors drawn, i.e. images produced per sampler
    "num_samples": 30,
    # DDPM sampler: number of timesteps plus the beta1/beta2 values handed to
    # the sampler setup functions
    "timesteps": 500,
    "beta1": 1e-4,
    "beta2": 0.02,
    # DDIM sampler: the `n` argument controlling its step size
    "ddim_n": 25,
    # network / image size: noise tensors are height x height pixels
    "height": 16,
})

In [3]:
def load_model(model_artifact_name, weights_filename="context_model.pth"):
    """Rebuild a trained ContextUnet from a model artifact in the W&B registry.

    The artifact is downloaded locally. The architecture hyperparameters
    (``n_feat``, ``n_cfeat``, ``height``) are read from the config of the run
    that logged the artifact. The saved state dict is then loaded into a
    freshly constructed network.

    Args:
        model_artifact_name: W&B artifact path, e.g.
            "entity/model-registry/Name:alias".
        weights_filename: Name of the state-dict file inside the artifact.
            Defaults to "context_model.pth".

    Returns:
        The ContextUnet in eval mode, moved to ``DEVICE``.
    """
    # Fetch the artifact through the public API, which does not start a run.
    # Trade-off: this download is therefore not recorded as an input of any
    # run. If lineage tracking is wanted, fetch it inside a run instead with
    # `wandb.init().use_artifact(model_artifact_name, type="model")`.
    api = wandb.Api()
    artifact = api.artifact(model_artifact_name, type="model")
    model_path = Path(artifact.download())

    # The run that produced the artifact holds the hyperparameters the saved
    # weights were trained with.
    producer_run = artifact.logged_by()

    # Load onto the CPU first so GPU-saved weights also load on CPU-only hosts.
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code, so only point this at trusted artifacts. On torch >= 1.13 consider
    # passing weights_only=True.
    model_weights = torch.load(model_path / weights_filename,
                               map_location="cpu")

    # Recreate the network with the same shape parameters as the original run
    run_config = producer_run.config
    model = ContextUnet(in_channels=3,
                        n_feat=run_config["n_feat"],
                        n_cfeat=run_config["n_cfeat"],
                        height=run_config["height"])

    model.load_state_dict(model_weights)

    # Inference only: switch to eval mode, then move to the target device
    model.eval()
    return model.to(DEVICE)

In [4]:
# Download the registry model and rebuild it in eval mode on DEVICE
nn_model = load_model(model_artifact_name=MODEL_ARTIFACT)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


# Sampling

In [5]:
# Rebuild the DDPM sampler used during training in the previous notebook.
# Only the sampling function is needed here; the first return value of
# setup_ddpm is discarded.
_, sample_ddpm_context = setup_ddpm(config.beta1, config.beta2,
                                    config.timesteps, DEVICE)

In [7]:
# Build the DDIM sampler from the same beta/timestep settings as DDPM.
# DDIM was not used during training; it is faster at the cost of some
# output quality.
sample_ddim_context = setup_ddim(config.beta1, config.beta2,
                                 config.timesteps, DEVICE)

In [8]:
# Define a set of fixed noises and a context vector (like during training).
# Both samplers below start from exactly these tensors, so differences in
# their outputs come from the sampling scheme rather than the inputs.

# Seed torch (all devices) so the "fixed" noise is reproducible across kernel
# restarts. Any later draws from torch's global RNG (e.g. inside the samplers,
# if they use it) also become repeatable.
NOISE_SEED = 10
torch.manual_seed(NOISE_SEED)

# Noise vector
# x_T ~ N(0, 1), sample initial noise (drawn on the CPU, then moved)
noises = torch.randn(config.num_samples, 3,
                     config.height, config.height).to(DEVICE)

# A fixed context vector to sample from: six samples for each of five classes
ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0,   # hero
                                     1,1,1,1,1,1,   # non-hero
                                     2,2,2,2,2,2,   # food
                                     3,3,3,3,3,3,   # spell
                                     4,4,4,4,4,4]), # side-facing
                       5).to(DEVICE).float()

# The context rows are hard-coded, so they must be kept in sync with
# config.num_samples: every noise tensor needs exactly one context row.
assert ctx_vector.shape[0] == noises.shape[0], (
    f"ctx_vector has {ctx_vector.shape[0]} rows but config.num_samples is "
    f"{config.num_samples}"
)

The goal is to compare the outputs of the two samplers (DDPM and DDIM), both started from the same initial noise and context vector.

In [9]:
# Sample with DDPM from the shared noise and context vector.
# The second return value is not needed for the comparison.
ddpm_samples, _ = sample_ddpm_context(nn_model,
                                      noises,
                                      ctx_vector)

                                                                         

In [10]:
# Sample with DDIM from the same noise and context vector.
# `n` sets the DDIM step size (configured as config.ddim_n).
ddim_samples, _ = sample_ddim_context(nn_model, noises, ctx_vector, n=config.ddim_n)

                                                             

# Compare results

Create a `wandb.Table` to store the diffusion model outputs. A W&B table behaves much like a dataframe and can be rendered in the project's W&B workspace.

In [17]:
# One row per sample: starting noise, DDPM output, DDIM output, class label
table = wandb.Table(
    columns=["input_noise", "ddpm", "ddim", "class"],
)

In [18]:
# Construct the table row by row.
# - Start from an empty table with the same columns so re-running this cell
#   does not append duplicate rows to a table built in an earlier run.
# - We cast images to wandb.Image so they render correctly in the UI.
# - zip() stops at its shortest input, so check the lengths first instead of
#   silently logging a table with missing samples.
table = wandb.Table(columns=table.columns)

class_labels = list(to_classes(ctx_vector))
assert len(noises) == len(ddpm_samples) == len(ddim_samples) == len(class_labels), (
    f"length mismatch: {len(noises)} noises, {len(ddpm_samples)} DDPM samples, "
    f"{len(ddim_samples)} DDIM samples, {len(class_labels)} class labels"
)

for noise, ddpm_s, ddim_s, c in zip(noises,
                                    ddpm_samples,
                                    ddim_samples,
                                    class_labels):

    # add data row by row to the Table
    table.add_data(wandb.Image(noise),
                   wandb.Image(ddpm_s),
                   wandb.Image(ddim_s),
                   c)

In [19]:
# Log the comparison table to W&B as its own run:
# - same project as the training runs, so results sit next to the model
# - a distinct job_type makes this run easy to filter for
# - the context manager finishes the run when the block exits
with wandb.init(project="diff_model_sprite",
                job_type="samplers_battle",
                config=config) as run:
    run.log({"samplers_table": table})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33md-oliver-cort[0m ([33mdoc93[0m). Use [1m`wandb login --relogin`[0m to force relogin


