# Using context to control diffusion models

#### Adding context to the diffusion model's denoising process lets us control what it generates

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader 
from torchvision import models, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Tuple 
from tqdm import tqdm
from matplotlib.animation import FuncAnimation, PillowWriter

### Sampling with context

In [None]:
def denoise_add_noise(x, t, pred_noise, z=None):
    """Take one DDPM reverse step: remove the predicted noise from ``x``, then add noise back.

    Computes
        x_{t-1} = (x_t - (1 - a_t[t]) / sqrt(1 - ab_t[t]) * pred_noise) / sqrt(a_t[t])
                  + sqrt(b_t[t]) * z
    using the schedule tensors ``a_t``, ``b_t`` and ``ab_t``, which are not defined in
    this cell (presumably the per-step alphas, betas and cumulative alphas — confirm).

    Args:
        x: current noisy samples at timestep ``t``.
        t: integer timestep used to index the schedule tensors.
        pred_noise: the model's noise prediction for ``x`` at step ``t``.
        z: noise to add back. ``None`` draws fresh standard-normal noise shaped like
            ``x``; passing ``0`` adds no noise (the noise term becomes zero).

    Returns:
        The samples for the previous timestep.
    """
    if z is None:
        z = torch.randn_like(x)
    noise = b_t[t].sqrt() * z
    # The square root applies to (1 - ab_t[t]), not to ab_t[t] alone: the DDPM update
    # scales the predicted noise by (1 - alpha_t) / sqrt(1 - alpha_bar_t).
    mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
    return mean + noise

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Both branches yield a plain device string, so torch.device is constructed once.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Total number of diffusion timesteps used by the sampler.
timesteps = 500

In [None]:
@torch.no_grad()
def sample_ddqm_context(n_sample, context, save_rate=20):
    """Generate ``n_sample`` images by DDPM sampling conditioned on ``context``.

    Starts from standard-normal noise and, for timesteps ``timesteps`` down to 1,
    asks ``nn_model`` for a noise prediction and applies ``denoise_add_noise``.

    Args:
        n_sample: number of images to generate.
        context: conditioning passed to ``nn_model`` as ``c=`` (e.g. a float one-hot
            class tensor with one row per sample — confirm against ``nn_model``).
        save_rate: keep a snapshot every ``save_rate`` steps, plus the first step and
            every step below 8, for visualization.

    Returns:
        ``(samples, intermediate)``: ``samples`` is the final batch on ``device`` with
        shape (n_sample, 3, height, height); ``intermediate`` is a NumPy array stacking
        the saved snapshots along a new leading axis.
    """
    samples = torch.randn(n_sample, 3, height, height).to(device)
    snapshots = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')
        # Normalized timestep reshaped to (1, 1, 1, 1), matching the rank of the image
        # batch (presumably the shape nn_model expects — confirm).
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)

        # Skip the added noise on the last step (i == 1): z = 0 zeroes the noise term.
        z = torch.randn_like(samples) if i > 1 else 0
        eps = nn_model(samples, t, c=context)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            snapshots.append(samples.detach().cpu().numpy())

    intermediate = np.stack(snapshots)
    return samples, intermediate

In [None]:
# Clear any figure left over from an earlier cell before building the animation.
plt.clf()

# 32 random class labels drawn from [0, 5), one-hot encoded as float context vectors.
ctx = F.one_hot(torch.randint(0, 5, (32,)), num_classes=5).float().to(device)

samples, intermediate = sample_ddqm_context(32, ctx)

# plot_sample and save_dir come from outside this section of the notebook.
animation_ddqm_context = plot_sample(
    intermediate, 32, 4, save_dir, "ani_run", None, save=False
)
# NOTE(review): HTML is not in the visible import cell; presumably
# IPython.display.HTML — confirm it is imported before running this cell.
HTML(animation_ddqm_context.to_jshtml())