# Making an Attribute-Conditioned Diffusion Model

In this notebook we will implement one way to add conditioning information to a diffusion model. Specifically, we'll train an attribute-conditioned diffusion model on the CelebA dataset, building on the [huggingface example](https://github.com/huggingface/diffusion-models-class/blob/unit2/unit1/02_diffusion_models_from_scratch.ipynb), with several improvements to the model architecture and the dataset.

This is one of many ways we could add additional conditioning information to a diffusion model, and has been chosen for its relative simplicity.

In [3]:
# Mount Google Drive at /content/drive. Checkpoints and example images used later
# in this notebook are read from / written to drive/MyDrive/Colab Notebooks/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Setup and Data Prep

In [4]:
# Install the Hugging Face libraries used below: diffusers (UNet2DModel, DDPMScheduler)
# and accelerate (Accelerator). Versions are not pinned - consider pinning them for
# reproducibility.
%pip install -q diffusers accelerate

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.8 MB[0m [31m5.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m28.8 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m23.3 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/265.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m265.7/265.7 kB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [5]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

import os

# Pick the fastest available backend: Apple-silicon MPS first, then CUDA,
# falling back to the CPU when neither accelerator is present.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

Using device: cuda


In [None]:
import torchvision.transforms as T

img_size = 64

# Preprocessing pipeline: square resize -> float tensor in [0, 1] ->
# per-channel (v - 0.5) / 0.5, so every pixel the model sees lies in [-1, 1].
transforms = T.Compose(
    [
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CelebA with the per-image "attr" (attribute) labels as targets; these labels
# are what the diffusion model is conditioned on.
dataset = torchvision.datasets.CelebA(
    root="celeba/",
    split="Train",
    target_type="attr",
    transform=transforms,
    download=True,
)

# Tiny batch size so the preview in the next cell stays readable.
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=8)


In [None]:
# View some examples
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Sample Label for the first datapoint:', y[0])
# The dataset transform normalises pixels to [-1, 1]; map them back to [0, 1]
# before plotting, otherwise imshow clips every negative value and the faces
# appear dark and washed out.
plt.imshow(torchvision.utils.make_grid((x + 1) / 2).clamp(0, 1).numpy().transpose(1, 2, 0));

## Creating a Class-Conditioned UNet

The way we'll feed in the class conditioning is as follows:
- Create a standard `UNet2DModel` with some additional input channels  
- Map the attribute vector to a learned vector of shape `(class_emb_size,)` via a linear layer (instead of the `nn.Embedding` lookup that many diffusion models use for a single class label)
- Concatenate this information as extra channels for the internal UNet input with `net_input = torch.cat((x, class_cond), 1)`
- Feed this `net_input` (which has (`3 + class_emb_size`) channels in total) into the UNet to get the final prediction

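In this case I've set `class_emb_size` to 5. The original motivation was that 40 binary attributes could be summarised in about 5 bits. Note, however, that 40 *independent* binary attributes can carry up to 40 bits of information, so 5 channels is an aggressive compression. The value is experimental, and other embedding sizes are worth exploring.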

In [None]:
class ClassConditionedUnet(nn.Module):
  """UNet2DModel conditioned on a vector of (CelebA) attributes.

  The attribute vector is projected to ``class_emb_size`` values by a linear
  layer. Each value is broadcast to a constant spatial map and concatenated to
  the noisy image as extra input channels.

  Args:
    class_attr: length of the attribute vector fed to ``forward``.
    class_emb_size: number of conditioning channels appended to the image.
    sample_size: spatial resolution the UNet is configured for.
  """

  def __init__(self, class_attr=40, class_emb_size=5, sample_size=64):
    super().__init__()

    # Learned projection: (bs, class_attr) -> (bs, class_emb_size).
    self.class_emb = nn.Linear(class_attr, class_emb_size)

    # self.model is an unconditional UNet with extra input channels that receive
    # the class embedding; it still predicts a 3-channel output.
    self.model = UNet2DModel(
        sample_size=sample_size,  # the target image resolution
        in_channels=3 + class_emb_size, # Additional input channels for class cond.
        out_channels=3,           # the number of output channels
        layers_per_block=2,       # how many ResNet layers to use per UNet block
        block_out_channels=(32, 32, 32, 64),
        down_block_types=(
            "DownBlock2D",        # a regular ResNet downsampling block
            "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
          ),
    )

  def forward(self, x, t, class_attr):
    """Predict the noise in ``x`` at timestep ``t`` given attributes ``class_attr``.

    Args:
      x: noisy images of shape (bs, 3, H, W).
      t: timestep(s), passed through unchanged to the UNet.
      class_attr: attributes of shape (bs, class_attr), or a single 1-D vector of
        shape (class_attr,) that is shared by every image in the batch. Any
        dtype/device is accepted; it is moved to the embedding layer's device.

    Returns:
      The UNet prediction, shape (bs, 3, H, W).
    """
    bs, _, h, w = x.shape

    if class_attr.dim() == 1:
      # One attribute vector conditions the whole batch.
      class_attr = class_attr.unsqueeze(0).expand(bs, -1)

    # Cast to the embedding layer's own device/dtype rather than a global
    # `device`, so the module works wherever it has been moved.
    weight = self.class_emb.weight
    class_cond = self.class_emb(class_attr.to(device=weight.device, dtype=weight.dtype))

    # (bs, E) -> (bs, E, 1, 1) -> broadcast to (bs, E, H, W).
    emb_size = class_cond.shape[1]
    class_cond = class_cond.view(bs, emb_size, 1, 1).expand(bs, emb_size, h, w)

    # Image and conditioning stacked along the channel axis: (bs, 3 + E, H, W).
    net_input = torch.cat((x, class_cond), 1)

    # Feed this to the UNet alongside the timestep; output is (bs, 3, H, W).
    return self.model(net_input, t).sample

## Training and Sampling

We'll now pass the attribute labels as a third argument (`prediction = unet(x, t, y)`) during training. At inference we can pass whatever attributes we want, and the model should generate images that match them. Here `y` holds the 40 binary attributes of each CelebA face. torchvision's `CelebA` dataset maps the raw {-1, 1} annotations to {0, 1}, which is also the encoding used for the attribute vectors at sampling time.

We predict the noise (rather than the denoised image) to match the objective expected by the default DDPMScheduler which we're using to add noise during training and to generate samples at inference time. Training takes a while - speeding this up could be another project.

In [None]:
# Accelerate is used below to prepare the training dataloader and to run the
# backward pass (accelerator.backward).
from accelerate import Accelerator

accelerator = Accelerator()

In [None]:
# Redefine the dataloader with a training-sized batch (the preview above used 8)
# and let Accelerate prepare it so batches are placed on its device.
train_dataloader = accelerator.prepare(DataLoader(dataset, batch_size=128, shuffle=True))

# How many runs through the data should we do?
n_epochs = 10

# Our network, moved to the selected `device`.
net = ClassConditionedUnet().to(device)

# Count only the parameters the optimizer will update.
trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)

print("total no. of parameters in unet model = ", trainable_params)


In [None]:
# DDPM noise scheduler with 1000 diffusion steps and the squared-cosine
# ("squaredcos_cap_v2") beta schedule; used both to add noise during training
# and to step backwards during sampling.
noise_scheduler = DDPMScheduler(
    beta_schedule='squaredcos_cap_v2',
    num_train_timesteps=1000,
)

### Training from a Previous checkpoint

In [None]:
# Optional: resume from a saved checkpoint. Run this cell to continue training
# from the checkpoint written after epoch `epoch_no`; otherwise skip it.
epoch_no = 10
path = os.path.join("drive/MyDrive/Colab Notebooks/", f"{epoch_no}_ckpt.pt")
# Load onto the CPU first; load_state_dict copies the tensors into `net`'s
# existing parameters on whatever device `net` lives on.
net.load_state_dict(torch.load(path, map_location=torch.device('cpu')))

### The Core Training Loop

For every tuple `(x, y)`, we randomly choose a timestep `t` and add the corresponding amount of noise to the original image with `noise_scheduler.add_noise(x, noise, timesteps)`.

Then the UNet predicts the noise that was added to the image. The mean-squared error between the predicted and the true noise is backpropagated to update the model. Note that we also pass in the labels/attributes, so the model learns to use this conditioning information while denoising the image.

In [None]:
# Our loss function: MSE between the predicted and the true noise.
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Checkpoint numbering continues from `epoch_no` when the optional resume cell
# was run; otherwise it starts from 0 (avoids a NameError when that cell is skipped).
start_epoch = globals().get("epoch_no", 0)

# torch.randint's upper bound is exclusive, so using the scheduler's step count
# samples every timestep 0 .. T-1 (the old bound of 999 never drew t = 999).
num_train_timesteps = noise_scheduler.config.num_train_timesteps

# Keeping a record of the losses for later viewing
losses = []

net.train()
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):

        # The dataset transform already normalises images to [-1, 1]; rescaling
        # again (x * 2 - 1) would push the data into [-3, 1].
        x = x.to(device)
        y = y.to(device)
        noise = torch.randn_like(x)
        timesteps = torch.randint(0, num_train_timesteps, (x.shape[0],), device=x.device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # Get the model prediction, conditioned on the attribute labels y
        pred = net(noisy_x, timesteps, y)

        # How close is the predicted noise to the true noise?
        loss = loss_fn(pred, noise)

        # Backprop and update the params:
        opt.zero_grad()
        accelerator.backward(loss)
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Average of (up to) the last 100 loss values; guards against short epochs.
    recent = losses[-100:]
    avg_loss = sum(recent) / max(len(recent), 1)
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')
    torch.save(
        net.state_dict(),
        os.path.join("./drive/MyDrive/Colab Notebooks/", f"{epoch + start_epoch + 1}_ckpt.pt"),
    )

# View the loss curve
plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("MSE loss");

### Prediction

This part is pretty straightforward. We start with random noise and step backwards through the scheduler's 1000 timesteps, removing a little predicted noise at each step, to obtain a generated image.

We also pass in the attribute labels so that the generated images are conditioned on them.

Note: In case the outputs are not visible, refer to the final project report in the github repo.

In [None]:
# Start from pure Gaussian noise for a batch of `num_img` images.
num_img = 24
x = torch.randn(num_img, 3, 64, 64).to(device)

# Example attribute vectors (40 entries of 0/1, one per CelebA attribute).
attr1 = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1]

attr2 = [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1]

attr00012 = [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1]

# Every image in the batch is conditioned on the same attribute vector.
y = torch.tensor([attr00012 for _ in range(num_img)]).to(device)

# Reverse diffusion: at each scheduler timestep predict the noise (conditioned
# on y) and let the scheduler produce the slightly less noisy sample.
with torch.no_grad():
    for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
        residual = net(x, t, y)
        x = noise_scheduler.step(residual, t, x).prev_sample


In [None]:
# Sanity check on the sampled batch; expected (num_img, 3, 64, 64).
print(x.size())

In [None]:
# Map samples from [-1, 1] back to [0, 255]. Clamp first: the sampler can
# overshoot that range, and a bare float -> uint8 cast does not saturate
# out-of-range values (they come out as garbage pixels).
out = ((x.clamp(-1, 1) + 1) / 2.0 * 255).to(torch.uint8)
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(out.detach().cpu()).permute(1, 2, 0))

### Some more visualization

In [None]:
from PIL import Image
import numpy as np

# Load CelebA image 000012, force 3-channel RGB and resize to the 64x64 model input.
with Image.open("./drive/MyDrive/Colab Notebooks/000012.jpg") as src:
    im = np.array(src.convert("RGB").resize((64, 64)))
# Scale pixels to [-1, 1] (the range the training transform produces) and
# reshape (H, W, C) -> (1, C, H, W).
im000012 = (torch.from_numpy(im).float() / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)

print(im000012.shape)

# Perturb the real image with unit Gaussian noise and run the full reverse chain from it.
x = im000012 + torch.randn(im000012.shape).to(device)

attr000012 = [0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1]

y = torch.tensor([attr000012]).to(device)

with torch.no_grad():
    for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
        residual = net(x, t, y)  # Again, note that we pass in our labels y
        x = noise_scheduler.step(residual, t, x).prev_sample

# Back to uint8 [0, 255]; clamp so out-of-range values saturate instead of wrapping.
out = ((x.clamp(-1, 1) + 1) / 2.0 * 255).to(torch.uint8)
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.imshow(torchvision.utils.make_grid(out.detach().cpu()).permute(1, 2, 0))

### Testing the model on other tasks

In [None]:
from PIL import Image
import numpy as np

# Load an out-of-domain image, force 3-channel RGB and resize to the 64x64 model input.
with Image.open("./drive/MyDrive/Colab Notebooks/hoi/0368.jpg") as src:
    im = np.array(src.convert("RGB").resize((64, 64)))
# Same [-1, 1] range as the training data; (H, W, C) -> (1, C, H, W).
im368 = (torch.from_numpy(im).float() / 127.5 - 1).permute(2, 0, 1).unsqueeze(0).to(device)

print(im368.shape)

# Start the reverse chain directly from the clean image (no noise added).
x = im368

# All-zero attribute vector: no CelebA attribute is switched on. A distinct
# name keeps the previous cell's `attr000012` intact.
attr_none = [0] * 40

y = torch.tensor([attr_none]).to(device)

# Record every intermediate sample so the denoising trajectory can be inspected.
image_viewer = []

with torch.no_grad():
    for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
        residual = net(x, t, y)  # Again, note that we pass in our labels y
        x = noise_scheduler.step(residual, t, x).prev_sample
        image_viewer.append(x)

# Back to uint8 [0, 255]; clamp so out-of-range values saturate instead of wrapping.
out = ((x.clamp(-1, 1) + 1) / 2.0 * 255).to(torch.uint8)
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.imshow(torchvision.utils.make_grid(out.detach().cpu()).permute(1, 2, 0))

In [None]:
# Keep every 50th intermediate sample of the recorded trajectory, converted to
# uint8 images (clamped so out-of-range values saturate). A dedicated loop
# variable avoids clobbering the global `x` (the final sample).
out = [((frame.clamp(-1, 1) + 1) / 2.0 * 255).to(torch.uint8) for frame in image_viewer[::50]]



In [None]:
# Display the snapshot at index 5 of the subsampled trajectory.
snapshot = out[5].detach().cpu()
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.imshow(torchvision.utils.make_grid(snapshot).permute(1, 2, 0))

In [None]:
import os, numpy, PIL
from PIL import Image

# Average every JPEG in the folder into one 256x256 RGB image.
img_dir = "./drive/MyDrive/Colab Notebooks/hoi"

# Access all JPEG files in the directory (case-insensitive extension match).
allfiles = os.listdir(img_dir)
imlist = [filename for filename in allfiles if filename.lower().endswith(".jpg")]
if not imlist:
    raise FileNotFoundError(f"No .jpg images found in {img_dir}")

# Every image is resized to this size, so they need not share dimensions.
w, h = 256, 256
N = len(imlist)

# Running mean of pixel intensities as float64 (the removed `numpy.float` alias
# raised AttributeError on NumPy >= 1.24).
arr = numpy.zeros((h, w, 3), dtype=numpy.float64)
for im in imlist:
    # Close each file promptly; convert to RGB so grayscale/RGBA inputs still
    # produce an (h, w, 3) array.
    with Image.open(os.path.join(img_dir, im)) as src:
        imarr = numpy.asarray(src.convert("RGB").resize((w, h)), dtype=numpy.float64)
    arr = arr + imarr / N

# Round, clip to the valid range and cast to 8-bit integers.
arr = numpy.clip(numpy.round(arr), 0, 255).astype(numpy.uint8)

# Generate, save and preview the final image (uint8 (h, w, 3) -> RGB).
out = Image.fromarray(arr)
out.save("Average.png")
out.show()

In [None]:
# Save and preview a single example image from the folder, resized to 256x256.
with Image.open("./drive/MyDrive/Colab Notebooks/hoi/0038.jpg") as source:
    imarr = source.resize((256, 256))

imarr.save("oneimage.png")
imarr.show()