# Conditional Diffusion models

We’ll see how we can steer our model outputs towards specific types or classes of images. We can use a method called conditioning, where the idea is to ask the model to generate not just any image but an image belonging to a pre-defined class.

First, rather than using the butterflies dataset, we’ll switch to a dataset that has classes. We’ll use Fashion MNIST, a dataset with thousands of images of clothes associated with a label from 10 different classes. Then, crucially, we’ll run two inputs through the model. Instead of just showing it what real images look like, we’ll also tell it the class every image belongs to. We expect the model to learn to associate images and labels to understand the distinctive features of sweaters, boots, and the like.

Note that we are not interested in solving a classification problem – we don’t want the model to tell us which class the image belongs to.

## Preparing the data

Fashion MNIST is very similar to the classic MNIST dataset: a compact size, black-and-white images, and ten classes. The main difference is that classes correspond to different types of clothing instead of being digits, and the images contain more detail than simple handwritten digits.

In [None]:
!pip install datasets diffusers

In [None]:
from datasets import load_dataset

from genaibook.core import show_images

fashion_mnist = load_dataset("fashion_mnist")
first_examples = fashion_mnist["train"][:8]  # Slice rows first so only 8 images get decoded
clothes = first_examples["image"]
classes = first_examples["label"]
show_images(clothes, titles=classes, figsize=(4, 2.5))

Instead of resizing, we’ll pad our image inputs (28 × 28 pixels) to 32 × 32. Padding keeps every original pixel intact, whereas resizing would interpolate them, and this will help the UNet make higher-quality predictions.
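Why 32 and not 28? Here is the arithmetic as a minimal sketch, assuming the UNet halves the spatial resolution once between consecutive down blocks, which gives three halvings for a model with four block_out_channels entries. The helper names are hypothetical, not part of diffusers. A side length that stops being an integer partway down typically leaves encoder and decoder feature maps with mismatched sizes at the skip connections.

```python
def fits_unet(side: int, num_downsamples: int) -> bool:
    """Hypothetical helper: True if `side` can be halved `num_downsamples`
    times while staying an integer."""
    return side % (2 ** num_downsamples) == 0


def halvings(side: int, num_downsamples: int) -> list:
    """Spatial sizes after each halving (floats expose non-integer sizes)."""
    sizes = [float(side)]
    for _ in range(num_downsamples):
        sizes.append(sizes[-1] / 2)
    return sizes


print(halvings(28, 3))  # [28.0, 14.0, 7.0, 3.5]
print(halvings(32, 3))  # [32.0, 16.0, 8.0, 4.0]
print(fits_unet(28, 3), fits_unet(32, 3))  # False True
```

Adding 2 pixels on each side is the smallest padding that turns 28 into a multiple of 8.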

In [None]:
import torch
from torchvision import transforms

preprocess = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Pad(2),  # Add 2 pixels on all sides
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)


def transform(examples):
    images = [preprocess(image) for image in examples["image"]]
    return {"images": images, "labels": examples["label"]}


train_dataset = fashion_mnist["train"].with_transform(transform)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=True
)
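To see what the padding and normalization steps do to pixel values, this sketch mimics them with plain tensor operations on a random stand-in image (not real Fashion MNIST data). It assumes transforms.Pad uses its default constant fill of 0, so after normalization the border becomes -1, the same value as a black background pixel.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
image = torch.rand(1, 28, 28)  # Stand-in for ToTensor() output: 1 channel, values in [0, 1)

# Like transforms.Pad(2): add 2 zero-valued pixels on every side
padded = F.pad(image, (2, 2, 2, 2), value=0.0)

# Like transforms.Normalize([0.5], [0.5]): (x - mean) / std maps [0, 1] to [-1, 1]
normalized = (padded - 0.5) / 0.5

print(padded.shape)  # torch.Size([1, 32, 32])
print(normalized[0, 0, 0].item())  # -1.0: the padded border maps to black
```

The inverse mapping, x * 0.5 + 0.5, is what brings tensors back to [0, 1] before displaying them.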

## Creating a class-conditioned model

The UNet2DModel from the diffusers library allows providing custom conditioning information. To enable class conditioning, we add a num_class_embeds argument to the UNet constructor. This argument tells the model we’d like to use class labels as additional conditioning. We’ll set it to 10, as that’s the number of classes in Fashion MNIST.

In [None]:
from diffusers import UNet2DModel

model = UNet2DModel(
    in_channels=1,  # 1 channel for grayscale images
    out_channels=1,
    sample_size=32,
    block_out_channels=(32, 64, 128, 256),
    num_class_embeds=10,  # Enable class conditioning
)

To make predictions with this model, we must pass in the class labels as additional inputs to the forward() method:

In [None]:
x = torch.randn((1, 1, 32, 32))
with torch.no_grad():
    out = model(x, timestep=7, class_labels=torch.tensor([2])).sample
out.shape

We also pass something else to the model as conditioning: the timestep! That’s right, even the model from Chapter 4 can be considered a conditional diffusion model. We condition it on the timestep, expecting that knowing how far we are in the diffusion process will help it generate more realistic images.

Internally, the timestep and the class label are turned into embeddings that the model uses during its forward pass. At multiple stages throughout the UNet, these embeddings are projected to a dimension that matches the number of channels in a given layer. The projected embeddings are then added to the outputs of that layer. This means the conditioning information is fed to every block of the UNet.
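Here is a deliberately simplified sketch of that mechanism in plain PyTorch. It is not the diffusers implementation: ToyConditionedBlock is a hypothetical module, and the timestep is embedded with a learned nn.Embedding instead of the sinusoidal embedding plus MLP that UNet2DModel uses. It only illustrates the idea: combine the timestep and class embeddings, project them to the layer's channel count, and add the result at every spatial position of the layer's output.

```python
import torch
from torch import nn


class ToyConditionedBlock(nn.Module):
    """Hypothetical block: a conv layer whose output is shifted by a projected embedding."""

    def __init__(self, in_channels: int, out_channels: int, emb_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Project the conditioning embedding to one value per output channel
        self.emb_proj = nn.Linear(emb_dim, out_channels)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = self.conv(x)  # (B, out_channels, H, W)
        shift = self.emb_proj(emb)[:, :, None, None]  # (B, out_channels, 1, 1)
        return h + shift  # The same shift is broadcast over every pixel


emb_dim = 32
time_embedding = nn.Embedding(1000, emb_dim)  # Simplification: learned, not sinusoidal
class_embedding = nn.Embedding(10, emb_dim)  # One vector per Fashion MNIST class

torch.manual_seed(0)
x = torch.randn(4, 1, 32, 32)
timesteps = torch.tensor([7, 7, 500, 999])
labels = torch.tensor([2, 0, 7, 9])

# Combine both conditioning signals into a single embedding per image
emb = time_embedding(timesteps) + class_embedding(labels)  # (4, emb_dim)

block = ToyConditionedBlock(in_channels=1, out_channels=16, emb_dim=emb_dim)
out = block(x, emb)
print(out.shape)  # torch.Size([4, 16, 32, 32])
```

Because the shift is constant across height and width, the embedding tells the whole feature map which class to produce and how noisy the input is, while the convolutions handle spatial structure.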



## Training the model

Adding noise works just as well on grayscale images as on the butterflies from Chapter 4. Let’s look at how noise affects an image at increasingly later timesteps.

In [None]:
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02
)
timesteps = torch.linspace(0, 999, 8).long()
batch = next(iter(train_dataloader))
x = batch["images"][0].expand([8, 1, 32, 32])
noise = torch.randn_like(x)  # Gaussian noise, as used during training
noised_x = scheduler.add_noise(x, noise, timesteps)
show_images((noised_x * 0.5 + 0.5).clip(0, 1))
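scheduler.add_noise applies the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where alpha_bar_t is the cumulative product of (1 - beta) up to step t. This sketch recomputes it in plain PyTorch, assuming the linear beta schedule that DDPMScheduler uses by default with the beta_start and beta_end values above; add_noise_manual is a hypothetical name, not a diffusers API. It shows why the images in the grid fade: the weight on the clean image drops from almost 1 to almost 0.

```python
import torch

num_train_timesteps = 1000
betas = torch.linspace(0.0001, 0.02, num_train_timesteps)  # Linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t for every t


def add_noise_manual(x0, noise, timesteps):
    """x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise, per sample."""
    alpha_bar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)  # Broadcast over C, H, W
    return alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * noise


signal_weight = alphas_cumprod.sqrt()  # Weight on the clean image at each timestep
timesteps = torch.linspace(0, 999, 8).long()
print([round(w, 3) for w in signal_weight[timesteps].tolist()])

x0 = torch.zeros(8, 1, 32, 32)
noise = torch.randn_like(x0)
noisy = add_noise_manual(x0, noise, timesteps)
```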

Our training loop is also almost the same as in Chapter 4, except that we now pass the class labels for conditioning. Note that this is just additional information for the model, but it doesn’t affect how we define our loss function in any way.

We’ll also display some progress during training using the Python package tqdm. tqdm means “progress” in Arabic (taqadum, تقدّم) and is an abbreviation for “I love you so much” in Spanish (te quiero demasiado).

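In each iteration, the training loop:

1. Loads a batch of images and their corresponding labels.
2. Adds noise to the images, using a randomly sampled timestep for each image.
3. Feeds the noisy images into the model, alongside the timesteps and the class labels for conditioning.
4. Calculates the loss by comparing the model’s noise prediction with the noise that was actually added.
5. Backpropagates the loss and updates the model weights with the optimizer.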

In [None]:
from torch.nn import functional as F
from tqdm import tqdm

scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02
)

num_epochs = 25
lr = 3e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-5)
losses = []  # Somewhere to store the loss values for later plotting

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Train the model (this takes a while!)
for epoch in (progress := tqdm(range(num_epochs))):
    for step, batch in (
        inner := tqdm(
            enumerate(train_dataloader),
            position=0,
            leave=True,
            total=len(train_dataloader),
        )
    ):
        # Load the input images and classes
        clean_images = batch["images"].to(device)
        class_labels = batch["labels"].to(device)

        # Sample noise to add to the images
        noise = torch.randn(clean_images.shape).to(device)

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (clean_images.shape[0],),
            device=device,
        ).long()

        # Add noise to the clean images according to the timestep
        noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction for the noise - note the use of class_labels
        noise_pred = model(
            noisy_images,
            timesteps,
            class_labels=class_labels,
            return_dict=False,
        )[0]

        # Compare the prediction with the actual noise:
        loss = F.mse_loss(noise_pred, noise)

        # Display loss
        inner.set_postfix(loss=f"{loss.cpu().item():.3f}")

        # Store the loss for later plotting
        losses.append(loss.item())

        # Update the model parameters with the optimizer based on this loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


In [None]:
import matplotlib.pyplot as plt

plt.plot(losses)

## Sampling

We now have a model that expects two inputs when making predictions: the image and the class label. We can create samples by beginning with random noise and then iteratively denoising, passing in whatever class label we’d like to generate:

In [None]:
def generate_from_class(class_to_generate, n_samples=8):
    sample = torch.randn(n_samples, 1, 32, 32).to(device)
    class_labels = [class_to_generate] * n_samples
    class_labels = torch.tensor(class_labels).to(device)

    for t in tqdm(scheduler.timesteps):
        # Get model pred
        with torch.no_grad():
            noise_pred = model(sample, t, class_labels=class_labels).sample

        # Update sample with step
        sample = scheduler.step(noise_pred, t, sample).prev_sample

    return sample.clip(-1, 1) * 0.5 + 0.5
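scheduler.step hides the actual update rule. As a rough sketch, assuming the textbook DDPM update with the same linear beta schedule and the "fixed small" variance that DDPMScheduler uses by default, one step first estimates the clean image from the predicted noise, then moves to the previous timestep by mixing that estimate with the current sample and adding a little fresh noise. predict_x0 and ddpm_step are hypothetical helpers; diffusers also clips the estimated clean image to [-1, 1] by default, so its outputs can differ slightly.

```python
import torch

betas = torch.linspace(0.0001, 0.02, 1000)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)


def predict_x0(noise_pred, t, sample):
    """Invert x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps for x0 (t is a Python int)."""
    alpha_bar_t = alphas_cumprod[t]
    return (sample - (1 - alpha_bar_t).sqrt() * noise_pred) / alpha_bar_t.sqrt()


def ddpm_step(noise_pred, t, sample, generator=None):
    """One reverse DDPM step from timestep t to t - 1."""
    alpha_bar_t = alphas_cumprod[t]
    alpha_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.tensor(1.0)
    beta_t = betas[t]

    # 1. Estimate the clean image from the predicted noise
    x0_pred = predict_x0(noise_pred, t, sample)

    # 2. Posterior mean: a weighted mix of the clean estimate and the current sample
    coef_x0 = alpha_bar_prev.sqrt() * beta_t / (1 - alpha_bar_t)
    coef_xt = alphas[t].sqrt() * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    mean = coef_x0 * x0_pred + coef_xt * sample

    # 3. Add fresh noise; the variance is 0 at t = 0, so the last step is deterministic
    variance = (1 - alpha_bar_prev) / (1 - alpha_bar_t) * beta_t
    z = torch.randn(sample.shape, generator=generator)
    return mean + variance.sqrt() * z
```

Applying this for t = 999, 998, ..., 0 with a trained noise predictor follows the same idea as the loop in generate_from_class.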

In [None]:
# Generate t-shirts (class 0)
images = generate_from_class(0)
show_images(images, nrows=2)

In [None]:
# Now generate some sneakers (class 7)
images = generate_from_class(7)
show_images(images, nrows=2)

In [None]:
# ...or boots (class 9)
images = generate_from_class(9)
show_images(images, nrows=2)


# Improving efficiency with Latent Diffusion Models

As image size grows, so does the computational power required to work with those images. This is especially pronounced in self-attention, where the number of operations grows quadratically with the number of inputs. A 128px square image has four times as many pixels as a 64px square image, so it requires 16 times the memory and compute in a self-attention layer. This is a problem for anyone who’d like to generate high-resolution images.
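The 16× figure follows from counting entries in the attention matrix. This sketch assumes one token per pixel and a single attention head, and ignores the channel dimension; attention_matrix_entries is a hypothetical helper.

```python
def attention_matrix_entries(side: int) -> int:
    """Entries in a full self-attention matrix over a side x side feature map."""
    num_tokens = side * side  # One token per pixel
    return num_tokens ** 2  # Every token attends to every token


for side in (64, 128):
    print(side, side * side, attention_matrix_entries(side))

ratio = attention_matrix_entries(128) // attention_matrix_entries(64)
print(ratio)  # 16
```

Doubling the side length multiplies the token count by 4 and the attention matrix by 4² = 16.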

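Latent diffusion mitigates this issue by using a separate Variational AutoEncoder (VAE): the diffusion model operates on the VAE’s compressed latent representations rather than on raw pixels. As we saw in Chapter 2, VAEs can compress images to a smaller spatial dimension.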

The VAE used in Stable Diffusion takes in 3-channel images and produces a 4-channel latent representation with a reduction factor of 8 for each spatial dimension. A 512px square input image (3 × 512 × 512 = 786,432 values) will be compressed down to a 4 × 64 × 64 latent (16,384 values), 48 times fewer values.
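The same arithmetic as a short script, which also connects it back to self-attention cost. It assumes the 8× spatial reduction and 4 latent channels stated above, and one attention token per spatial position.

```python
import math

downsample_factor = 8
image_shape = (3, 512, 512)
latent_shape = (4, 512 // downsample_factor, 512 // downsample_factor)

print(latent_shape, math.prod(image_shape), math.prod(latent_shape))
print(math.prod(image_shape) // math.prod(latent_shape))  # 48

# Self-attention over spatial positions: 512 * 512 pixels vs 64 * 64 latent positions
token_ratio = (512 * 512) // (latent_shape[1] * latent_shape[2])
print(token_ratio, token_ratio ** 2)  # 64 4096
```

Working in latent space gives 64 times fewer positions, so a full attention matrix over them is 4,096 times smaller.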

In [None]:
from diffusers import AutoencoderKL, StableDiffusionPipeline

vae = AutoencoderKL.from_pretrained(
    "stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16
).to(device)
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # Current home of the former runwayml/stable-diffusion-v1-5 weights
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

In [None]:
pipe("Watercolor illustration of a rose").images[0]