<a href="https://colab.research.google.com/github/Evil-Tux/Diffusion-Models/blob/main/Article_4_Guidance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install diffusers==0.16.1 accelerate open_clip_torch transformers

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision.transforms import Compose, Resize, ToTensor, ToPILImage

from diffusers import DDPMScheduler, DDIMScheduler, DDPMPipeline, DDIMPipeline

from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm
import numpy as np

def plot_images(images, n=8, axs=None):
    if axs is None:
        fig, axs = plt.subplots(1, n, figsize=(10, 3))
    assert len(axs) == len(images)
    for i, img in enumerate(images):
        axs[i].axis('off')
        if isinstance(img, torch.Tensor):
            img = ToPILImage()((img/2+0.5).clamp(0, 1))
        axs[i].imshow(img.resize((64, 64), resample=Image.NEAREST), cmap='gray_r', vmin=0, vmax=255)

## Guidance

In the previous post in the Diffusion Models 101 series, we fine-tuned a pretrained model to generate MNIST digit images. When we generated images using the manual loop as opposed to the training loop, the generated digits resembled characters from “The Matrix”.

What if we wanted our digits to be as green as the falling characters from "The Matrix"? They kind of look like them already, so why not?

The question is, how can we tell the model we like our digits green?

We need to guide our model into the right path or, as Morpheus said to Neo, "there's a difference between knowing the path, and walking the path."

Deep, right?

With guidance, we can steer the properties of the output images by directing the generative process towards a specified outcome. In this post, we will demonstrate the flexibility of diffusion models by using guidance to change the color of the generated images.

In [None]:
#########
## NEW ##
#########
from diffusers import DDPMPipeline, DDIMScheduler
image_pipe = DDPMPipeline.from_pretrained('dvgodoy/ddpm-cifar10-32-mnist')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

image_pipe.to(device)
noise_scheduler = image_pipe.scheduler
noise_scheduler.set_timesteps(40)
model = image_pipe.unet

torch.manual_seed(33)
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    # Ensures schedulers are interchangeable
    model_input = noise_scheduler.scale_model_input(sample, t)

    with torch.no_grad():
        epsilon = model(model_input, t).sample

    sample = noise_scheduler.step(epsilon, t, sample).prev_sample

In [None]:
#########
## NEW ##
#########
plot_images(sample)

### Color

Let’s start by defining the color of "The Matrix", and generating a single pixel colored like that, expanded to the expected four dimensions of a PyTorch mini-batch of images: the number of data samples, image channels, image height, and image width (NCHW).

In [None]:
# The Matrix
color = (28/255, 161/255, 82/255)

colored_pixel = torch.tensor(color)[None, :, None, None]
colored_pixel, colored_pixel.shape # NCHW
ToPILImage()(colored_pixel[0]).resize((64, 64))

That's green for sure! Now, let's center its values at zero:

In [None]:
colored_pixel = colored_pixel * 2 - 1
colored_pixel

#### Color Loss

Once our guiding pixel is on the same footing as our noisy samples (centered at zero), we can take the mean absolute difference between them. We only have ONE green pixel to compare each image to, but broadcasting has our backs: it effectively compares that lone pixel to every pixel in the image.

In [None]:
# Mean pixel difference between an image and our (broadcast) colored pixel
torch.abs(sample[0] - colored_pixel.to(device)).mean()
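If broadcasting feels a bit like magic, here is a minimal sketch of what happens to the shapes. The tensors below (`images`, `pixel`) are hypothetical stand-ins with made-up values, not the notebook's actual samples:

```python
import torch

# Hypothetical stand-ins for the notebook's tensors
images = torch.zeros(8, 3, 32, 32)                          # NCHW mini-batch
pixel = torch.tensor([0.1, 0.6, 0.3])[None, :, None, None]  # shape (1, 3, 1, 1)

# Broadcasting stretches the single pixel across N, H, and W
diff = images - pixel
print(diff.shape)

# Same result as subtracting a full batch of solid-color images
explicit = pixel.expand(8, 3, 32, 32)
same = torch.equal(diff, images - explicit)
print(same)
```

The lone pixel is never copied in memory; PyTorch only behaves as if it had been repeated along the dimensions of size one.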

Let's organize this code into a function that takes a mini-batch of images and returns the corresponding loss:

In [None]:
def color_loss(images, color):
    colored_pixel = torch.tensor(color).to(images.device)[None, :, None, None]
    colored_pixel = colored_pixel * 2 - 1
    errors = torch.abs(images - colored_pixel)
    loss = errors.mean()
    return loss

In [None]:
color_loss(sample, color)

That's a loss value, but aren't you **missing** something?

We usually call the `backward()` method on losses, but it only makes sense to do so if the loss itself is a gradient-requiring tensor, right? Let's fix that by **making our samples require gradient**.

In [None]:
sample_with_grad = sample.detach().requires_grad_()
loss = color_loss(sample_with_grad, color)
loss
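To see the difference for yourself, you can inspect the `requires_grad` and `grad_fn` attributes of the loss. The sketch below uses a small random tensor (`plain`) as a hypothetical stand-in for our samples:

```python
import torch

# The guiding color, centered at zero, in NCHW shape
target = torch.tensor((28/255, 161/255, 82/255))[None, :, None, None] * 2 - 1

# Hypothetical stand-in for a mini-batch of samples
plain = torch.randn(2, 3, 4, 4)
loss_plain = torch.abs(plain - target).mean()
print(loss_plain.requires_grad, loss_plain.grad_fn)      # nothing to backpropagate through

# Same values, but now the tensor is tracked by autograd
tracked = plain.detach().requires_grad_()
loss_tracked = torch.abs(tracked - target).mean()
print(loss_tracked.requires_grad, loss_tracked.grad_fn)  # the loss is part of a graph
```

Only the second loss carries a computation graph, so only the second one can produce gradients with respect to the sample.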

Just like we update parameters using their gradients (based on the MSE loss, for example) and a learning rate, we'll update samples using their gradients based on the color loss we used as guidance and a learning rate-equivalent called `guidance_loss_scale`.

In [None]:
guidance_loss_scale = 40

grad = torch.autograd.grad(loss, sample_with_grad)[0]

sample = sample_with_grad.detach()
sample = sample - guidance_loss_scale * grad  # analogous to w = w - lr * grad

Notice that we're detaching the sample (i.e., removing the gradient requirement), because it will be an input to the scheduler's `step()` method.
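Does a single update actually move the sample toward green? Here is a toy check on random noise only (no model or scheduler involved). The names `toy_sample` and `step_size` are introduced just for this illustration:

```python
import torch

def color_loss(images, color):
    # Same loss as the one defined earlier in this post
    pixel = torch.tensor(color).to(images.device)[None, :, None, None] * 2 - 1
    return torch.abs(images - pixel).mean()

color = (28/255, 161/255, 82/255)
guidance_loss_scale = 40

torch.manual_seed(0)
toy_sample = torch.randn(8, 3, 32, 32)  # hypothetical stand-in for a noisy sample

# Gradient of the color loss with respect to every pixel
x = toy_sample.detach().requires_grad_()
loss_before = color_loss(x, color)
grad = torch.autograd.grad(loss_before, x)[0]

# One guidance update, analogous to w = w - lr * grad
updated = x.detach() - guidance_loss_scale * grad
loss_after = color_loss(updated, color)

# For a mean absolute error, each gradient entry is +1/numel or -1/numel,
# so every pixel moves by the same tiny amount: scale / numel
step_size = guidance_loss_scale / x.numel()
print(loss_before.item(), loss_after.item(), step_size)
```

Each pixel only moves by a small fixed amount per update, which is one reason why the guidance scale needs to be fairly large to have a visible effect.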

In the generation loop, however, we won't compute the loss on the noisy sample itself. Instead, we'll compare the guiding green pixel to the **predicted denoised image**: that's the image we're actually interested in, and the one we'd like to be green.
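Where does a predicted denoised image come from? For a noise-predicting model, the standard DDPM relation between a clean image and its noisy version can be inverted once we have an estimate of the noise. The sketch below assumes that relation and ignores any clipping a scheduler may apply; the `alpha_bar` value and the `predict_x0` helper are hypothetical, made up for illustration:

```python
import torch

torch.manual_seed(0)
x0 = torch.rand(1, 3, 4, 4) * 2 - 1   # a "clean" image in [-1, 1]
eps = torch.randn_like(x0)            # the noise that gets added to it
alpha_bar = torch.tensor(0.3)         # hypothetical cumulative alpha at step t

# Forward process: mix the clean image with noise
x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps

def predict_x0(x_t, eps_pred, alpha_bar):
    # Invert the forward process using the predicted noise
    return (x_t - (1 - alpha_bar).sqrt() * eps_pred) / alpha_bar.sqrt()

# With a perfect noise prediction, we recover the clean image
recovered = predict_x0(x_t, eps, alpha_bar)
print(torch.allclose(recovered, x0, atol=1e-4))
```

In practice the model's noise prediction is imperfect, so the predicted denoised image is only an estimate, but it is a differentiable function of the noisy sample, and that's all guidance needs.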

### Generating Images

Let's update our generation loop to include all the steps from the guidance code:
1. Make samples require gradient
2. Compute the color loss between the predicted denoised image and the guiding color
3. Compute gradients of the color loss with respect to the image's pixels
4. Detach the samples and update them using the gradients and the guidance loss scale (the "learning rate")

The following code also includes a call to the scheduler's `scale_model_input()` method, which ensures the schedulers are interchangeable (this was copied from the pipeline's generation method, but it has no impact in this example).

In [None]:
noise_scheduler = image_pipe.scheduler
model = image_pipe.unet

torch.manual_seed(33)
sample = torch.randn(8, 3, 32, 32).to(device)
guidance_loss_scale = 40

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    # Ensures schedulers are interchangeable
    model_input = noise_scheduler.scale_model_input(sample, t)

    with torch.no_grad():
        epsilon = model(model_input, t).sample

    ## GUIDANCE ##
    # Step 1
    sample_with_grad = sample.detach().requires_grad_()

    # Step 2
    # What does the denoised image look like at this point?
    pred_x0 = noise_scheduler.step(epsilon, t, sample_with_grad).pred_original_sample
    # Does it have the right color?
    loss = color_loss(pred_x0, color)
    if i % 10 == 0:
        print(i, "loss:", loss.item())

    # Step 3
    # Compute gradient
    grad = torch.autograd.grad(loss, sample_with_grad)[0]

    # Step 4
    # Detach the sample so it is a regular tensor again
    sample = sample_with_grad.detach()
    # Update the sample
    sample = sample - guidance_loss_scale * grad
    ##############

    # Uses the updated sample in the next step
    sample = noise_scheduler.step(epsilon, t, sample).prev_sample

What do the resulting images look like?

In [None]:
plot_images(sample)

WOW! They are definitely green like "The Matrix", so we guided them well into being green. But they apparently got lost in the digits department. The generation process became biased towards complying with the guidance, at the expense of the underlying task: generating digits.

You can try adjusting the `guidance_loss_scale` variable to see if you can make them more digit-ish and less green.

But there's also a different way of incorporating the guidance into the generation loop: gradients all the way!

We're moving Step 1 to the very top, making samples gradient-requiring from the get-go, and ditching the `no_grad()` context manager altogether.

Notice that the loss is computed using `pred_x0`, which is computed using the sample that is requiring gradients now. This means that the denoising process itself is part of the dynamic computation graph now, thus affecting the gradients used to update the sample in the guidance process.
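To get a feel for why this matters, here is a toy comparison of the two approaches. It swaps the UNet for a single convolution (`toy_model`) and the scheduler for a hand-written denoising formula with a made-up `alpha_bar`, so everything below is a hypothetical stand-in rather than the real pipeline:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # hypothetical stand-in for the UNet
alpha_bar = torch.tensor(0.3)                          # hypothetical cumulative alpha
target = torch.tensor((28/255, 161/255, 82/255))[None, :, None, None] * 2 - 1

def guidance_grad(sample, track_model):
    x = sample.detach().requires_grad_()
    if track_model:
        # Noise prediction is part of the computation graph
        eps = toy_model(x)
    else:
        # Noise prediction is treated as a constant
        with torch.no_grad():
            eps = toy_model(x)
    pred_x0 = (x - (1 - alpha_bar).sqrt() * eps) / alpha_bar.sqrt()
    loss = torch.abs(pred_x0 - target).mean()
    return torch.autograd.grad(loss, x)[0]

sample = torch.randn(2, 3, 8, 8)
grad_shortcut = guidance_grad(sample, track_model=False)
grad_full = guidance_grad(sample, track_model=True)
differ = not torch.allclose(grad_shortcut, grad_full)
print(differ)
```

Both gradients have the same shape as the sample, but the second one also accounts for how changing a pixel changes the model's own noise prediction. The price is a backward pass through the model at every step, which is slower and uses more memory.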

In [None]:
noise_scheduler = image_pipe.scheduler
model = image_pipe.unet

torch.manual_seed(33)
sample = torch.randn(8, 3, 32, 32).to(device)
guidance_loss_scale = 150

for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    # Step 1
    sample_with_grad = sample.detach().requires_grad_()

    # Ensures schedulers are interchangeable
    model_input = noise_scheduler.scale_model_input(sample_with_grad, t)

    # with torch.no_grad():
    epsilon = model(model_input, t).sample

    ## GUIDANCE ##
    # Step 2
    # What does the denoised image look like at this point?
    pred_x0 = noise_scheduler.step(epsilon, t, sample_with_grad).pred_original_sample

    # Does it have the right color?
    loss = color_loss(pred_x0, color)
    if i % 10 == 0:
        print(i, "loss:", loss.item())

    # Step 3
    # Compute gradient
    grad = torch.autograd.grad(loss, sample_with_grad)[0]

    # Step 4
    # Detach the sample so it is a regular tensor again
    sample = sample_with_grad.detach()
    # Update the sample
    sample = sample - guidance_loss_scale * grad

    # Uses the updated sample in the next step
    sample = noise_scheduler.step(epsilon, t, sample).prev_sample

Let's take a look at the images:

In [None]:
plot_images(sample)

They are more digit-ish now, that's for sure, and some of them are green-ish. Again, you can try tweaking the `guidance_loss_scale` to generate more balanced images.

But let's be honest, turning images green isn't that impressive, and it's also quite some work since we had to define a custom loss function just for that.

What if we could use words to guide our model instead?

In our next post in this series, we’ll take guidance to the next level and use Contrastive Language-Image Pre-training (CLIP) to direct our model to generate digits we specify through a text prompt. Stay tuned!