<span style="color: red;">Requirement when running in Google Colab</span>

In [None]:
# !pip install diffusers

#  Chapter 5 - Replace Attention Layers

So far we have learnt how to generate images and how to break the model down all the way to its attention layers. As mentioned before, the U-Net architecture of Stable Diffusion has self-attention and cross-attention layers, with ResNet blocks on top of them that decrease or increase the dimensionality depending on whether they are in the encoder or the decoder stage, so we are dealing with attention maps of different dimensions as well.

The idea in this chapter is to generate a synthetic image from a prompt and from an edited version of that prompt, and, during generation, to replace the self-attention maps at some dimensions, look at the results, and see whether we can obtain a similar image that contains the required edit.

This part, as before, has been copied from the previous chapter; you can simply run it and move on to the next cell.

In [None]:
import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline
import torch
import matplotlib.pyplot as plt
from typing import Optional
from tqdm import tqdm


# Checkpoint identifier passed to from_pretrained.
model_id = "stabilityai/stable-diffusion-2-1-base"

pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cuda")

prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

# Classifier-free guidance is enabled with an empty negative prompt, so the
# result carries both the prompt embedding and the negative-prompt embedding.
prompt_embeds = pipe.encode_prompt(prompt=prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt="")

# The noise is drawn from a seeded default (CPU) generator and then moved to
# the GPU. The channel count is read from the UNet config, like sample_size,
# because direct attribute access (pipe.unet.in_channels) is deprecated in
# diffusers.
initial_latents = torch.randn((1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(220)).to("cuda")

In [None]:
# Edited prompt used for the second image of the batch.
new_prompt = "A photo of a woman, curly hair, light blonde and pink hair, smiling expression, grey background"

new_prompt_embeds = pipe.encode_prompt(
    prompt=new_prompt,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt="",
)

# Concatenate, for each prompt in turn, element [1] followed by element [0]
# of its encode_prompt result: original prompt first, edited prompt second.
prompt_embeds_combined = torch.cat(
    [embeds[index] for embeds in (prompt_embeds, new_prompt_embeds) for index in (1, 0)]
)

A modification of `contextual_forward` from the previous section, with the addition of replacing the target's self-attention maps with those of the source. We use a specific configuration that gave the best results for our task at hand.

In [None]:

def contextual_forward(self, source_batch_index, target_batch_index, skip_dimension, should_replace=False):
    """Build a replacement ``forward`` for an attention module that copies self-attention maps.

    The returned function computes attention from the module's own
    projections and, for self-attention only, overwrites the attention
    scores of the target prompt with those of the source prompt before the
    softmax.

    Args:
        self: The attention module being patched. Its ``to_q``/``to_k``/
            ``to_v``/``to_out`` projections, ``heads``, ``scale``,
            ``group_norm``, ``norm_cross``, ``residual_connection`` and
            ``rescale_output_factor`` attributes are used.
        source_batch_index: Index of the prompt whose scores are copied.
        target_batch_index: Index of the prompt whose scores are overwritten.
        skip_dimension: Side length of the attention resolution that is left
            untouched unless ``should_replace`` is set; the comparison is
            made against ``skip_dimension * skip_dimension`` tokens.
        should_replace: When true, self-attention scores are replaced at
            every resolution, including the skipped one.

    Returns:
        A function with the signature
        ``(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None)``
        returning the attended hidden states.
    """

    def forward_modified(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        # NOTE: attention_mask and temb are accepted so the signature matches
        # the call site, but neither is used in this implementation.
        residual = hidden_states

        # Cross-attention is recognised by the presence of an external context.
        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token axis: (B, C, H, W) -> (B, H*W, C).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: keys and values come from the hidden states.
            encoder_hidden_states = hidden_states
        elif self.norm_cross:
            encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        query = self.head_to_batch_dim(query)
        key = self.head_to_batch_dim(key)
        value = self.head_to_batch_dim(value)

        attention_scores = self.scale * torch.bmm(query, key.transpose(-1, -2))

        #############################################################
        ### The replacing process of attention maps happens here  ###
        #############################################################

        # The leading axis of attention_scores folds heads and batch together.
        # Every prompt contributes two batch entries (unconditional and
        # conditional, as required for CFG), so one prompt owns
        # heads * 2 consecutive rows.
        # Assumes head_to_batch_dim keeps all heads of one batch entry
        # contiguous - TODO confirm against the attention implementation.
        rows_per_prompt = self.heads * 2
        source_rows = slice(
            rows_per_prompt * source_batch_index,
            rows_per_prompt * (source_batch_index + 1),
        )
        target_rows = slice(
            rows_per_prompt * target_batch_index,
            rows_per_prompt * (target_batch_index + 1),
        )

        token_count = hidden_states.shape[1]

        # Our experiments showed that this combination gave the best results
        # for face-related edits: only self-attention is touched, and the
        # skip_dimension resolution is left alone unless should_replace is set.
        if not is_cross and (should_replace or token_count != skip_dimension * skip_dimension):
            attention_scores[target_rows] = attention_scores[source_rows]
        #############################################################
        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = self.batch_to_head_dim(hidden_states)
        del attention_probs

        # Only the linear output projection is applied; to_out[1:] is not used here.
        hidden_states = self.to_out[0](hidden_states)

        if input_ndim == 4:
            # Restore the original (B, C, H, W) layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if self.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / self.rescale_output_factor

        return hidden_states

    return forward_modified

To differentiate between the encoding stage (Down) and the decoding stage (Up) of the U-Net, checks based on the class name need to be added:

In [None]:
# source_batch_index and target_batch_index added to copy the relevant attention maps
def apply_forward_function(unet, child=None, source_batch_index=-1, target_batch_index=-1, should_replace=False):
    """Recursively replace ``forward`` on every module whose class is named ``Attention``.

    Called without ``child``, the function starts from the direct children of
    ``unet``. While descending, ``should_replace`` is updated from the class
    name of each container module: names containing ``"Down"`` or ``"Mid"``
    turn it on, names containing ``"Up"`` turn it off, and any other name
    leaves the inherited value untouched.

    Args:
        unet: Root model; its ``config.sample_size`` is passed to
            ``contextual_forward`` as the skip dimension.
        child: Module currently being visited, or ``None`` to start at the
            root.
        source_batch_index: Forwarded to ``contextual_forward``.
        target_batch_index: Forwarded to ``contextual_forward``.
        should_replace: Flag inherited from the enclosing module.
    """
    if child is None:
        for _, top_level_module in unet.named_children():
            apply_forward_function(unet, top_level_module, source_batch_index, target_batch_index, should_replace)
        return

    block_name = child.__class__.__name__

    if block_name == 'Attention':
        # Read the skip dimension from the unet that was passed in rather than
        # from a global pipeline object.
        child.forward = contextual_forward(child, source_batch_index, target_batch_index, unet.config.sample_size, should_replace)
        return

    if not hasattr(child, 'children'):
        return

    # The class name is the same for every sub-child, so decide once.
    if "Down" in block_name:
        should_replace = True
    elif "Up" in block_name:
        should_replace = False
    elif "Mid" in block_name:
        should_replace = True

    for sub_child in child.children():
        apply_forward_function(unet, sub_child, source_batch_index, target_batch_index, should_replace)

A wrapper for displaying the latents, considering we are dealing with 2 images this time

In [None]:
def display_latents(latents):
    """Decode the first two latents with the pipeline VAE and plot them side by side.

    Each latent is decoded on its own, moved to the CPU as a
    height x width x channel array and min-max normalised to [0, 1] before
    it is shown.

    Args:
        latents: Indexable collection of latents; entries 0 and 1 are used.
    """
    with torch.no_grad():
        images = []
        for index in range(2):
            # Undo the VAE scaling factor before decoding.
            decoded = pipe.vae.decode(latents[index].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
            image_np = decoded.squeeze(0).float().permute(1, 2, 0).detach().cpu().numpy()
            # Rescale to [0, 1] so imshow renders the float image correctly.
            images.append((image_np - image_np.min()) / (image_np.max() - image_np.min()))

        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        for index, (axis, image_np) in enumerate(zip(axes, images)):
            axis.imshow(image_np)
            axis.axis('off')
            axis.set_title(f'Latent {index}')

        plt.show()

# Displaying the sampling results WITHOUT replacing the attention layer

In [None]:
# Sampling configuration.
num_inference_steps = 50
guidance_scale = 7.5
pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = pipe.scheduler.timesteps

# Both images start from the same initial noise.
latents = torch.cat([initial_latents, initial_latents])

with torch.no_grad():

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Duplicate each latent so the batch lines up with the four stacked
        # prompt embeddings: two entries per image.
        latent_model_input = latents.repeat_interleave(2, dim=0)

        unet_output = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )
        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = unet_output[0].chunk(4)

        # Classifier-free guidance, applied to each image separately.
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)

        guided_noise = torch.cat([noise_pred, new_noise_pred])
        latents = pipe.scheduler.step(guided_noise, t, latents, return_dict=False)[0]

    display_latents(latents)


# Displaying the sampling results WITH replacing the attention layer

In [None]:
# Sampling configuration.
num_inference_steps = 50
guidance_scale = 7.5
pipe.scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = pipe.scheduler.timesteps

# Both images start from the same initial noise.
latents = torch.cat([initial_latents, initial_latents])

with torch.no_grad():

    # We want the image generated from the new prompt (batch 1) to reuse the
    # self-attention maps of the image generated from the old prompt (batch 0),
    # therefore batch 0 is the source and batch 1 is the target.
    # NOTE: this reassigns `forward` on the UNet's Attention modules in place.
    apply_forward_function(pipe.unet, source_batch_index=0, target_batch_index=1)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Duplicate each latent so the batch lines up with the four stacked
        # prompt embeddings: two entries per image.
        latent_model_input = latents.repeat_interleave(2, dim=0)

        unet_output = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )
        noise_pred_uncond, noise_pred_text, new_noise_pred_uncond, new_noise_pred_text = unet_output[0].chunk(4)

        # Classifier-free guidance, applied to each image separately.
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        new_noise_pred = new_noise_pred_uncond + guidance_scale * (new_noise_pred_text - new_noise_pred_uncond)

        guided_noise = torch.cat([noise_pred, new_noise_pred])
        latents = pipe.scheduler.step(guided_noise, t, latents, return_dict=False)[0]

    display_latents(latents)

