<span style="color: red;">Requirement when running in Google Colab</span>

In [None]:
!pip install diffusers

#  Chapter 5 - Replace Attention Layers

So far we have learned how to generate images and how to break the model down all the way to its attention layers. As mentioned, the U-Net in Stable Diffusion contains self-attention and cross-attention layers, surrounded by ResNet blocks and down/up-sampling layers that decrease or increase the spatial resolution depending on whether we are in the encoder (down) or decoder (up) stage. We are therefore also dealing with attention maps of different dimensions.

The idea in this chapter is to generate a synthetic image from a prompt and, in the same batch, an image from an edited version of that prompt. During generation, we replace some of the self-attention maps of the edited image with those of the original image. We then check whether this gives us an image that stays similar to the original while still reflecting the requested edit.

As before, this part has been copied from the previous chapter; you can simply run it and move on to the next cell.

In [None]:
import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch
import matplotlib.pyplot as plt
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

model_id = "stabilityai/stable-diffusion-2-1-base"

pipe = StableDiffusionPipeline.from_pretrained(model_id)
# A separate DDIM scheduler drives the hand-written sampling loops in later cells;
# it is not attached to ``pipe``.
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = pipe.to("cuda")

prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

# encode_prompt returns a tuple; element 0 is the prompt embedding. Classifier-free
# guidance duplication is disabled because the sampling loops assemble the
# (unconditional, conditional) batch themselves.
org_prompt_embeds = pipe.encode_prompt(prompt=prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

# Channel count and spatial size are read from the U-Net config; accessing
# ``pipe.unet.in_channels`` directly is deprecated in diffusers (its warning is hidden
# by the ``filterwarnings("ignore")`` above). The noise is drawn on the CPU with a fixed
# seed for reproducibility and then moved to the GPU.
initial_latents = torch.randn(
    (1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size),
    generator=torch.Generator().manual_seed(220),
).to("cuda")

In [None]:
new_prompt = "A photo of a woman, curly hair, light blonde and pink hair, smiling expression, grey background"

# Encode the edited prompt exactly like the original one (first tuple element only).
new_prompt_embeds = pipe.encode_prompt(
    prompt=new_prompt,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=False,
)[0]

# Batch layout consumed by AttnReplaceProcessor's ``chunk(4)``:
#   [uncond source, uncond destination, source prompt, destination prompt]
prompt_embeds_combined = torch.cat([uncond_prompt_embeds] * 2 + [org_prompt_embeds, new_prompt_embeds])

A modification of the AttnBreakdownProcessor from the previous section that additionally replaces the destination self-attention maps with those from the source.
We use a specific configuration that empirically gave the best results.

In [None]:
class AttnReplaceProcessor(AttnProcessor2_0):
    """Self-attention processor that copies the source attention scores onto the destination.

    The U-Net must be called with a batch of exactly four samples laid out as
    ``[uncond source, uncond destination, cond source, cond destination]``.
    For self-attention layers whose sequence length is not ``skip_dimension ** 2``,
    the pre-softmax attention scores of each destination sample are overwritten with
    those of the matching source sample. Any other batch size falls back to plain
    attention.

    Args:
        replace_all: When False, no scores are replaced and the processor behaves
            like ordinary attention (used for the mid block).
        skip_dimension: Spatial side length whose self-attention layers are left
            untouched. If None, ``pipe.unet.config.sample_size`` is read from the
            global ``pipe`` at call time.
    """

    def __init__(self, replace_all, skip_dimension: Optional[int] = None):
        super().__init__()
        self.replace_all = replace_all
        self.skip_dimension = skip_dimension

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:

        residual = hidden_states

        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # Honour the attention mask the same way diffusers' AttnProcessor does
        # (prepare_attention_mask returns None when no mask is given).
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # (batch, seq, dim) -> (batch * heads, seq, dim // heads), batch-major, so
        # chunking along dim 0 splits by sample.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        #############################################################
        ### The replacing process of attention maps happens here  ###
        #############################################################

        # Number of query tokens, i.e. height * width of this layer's feature map.
        seq_len = hidden_states.shape[1]

        # Our experiments showed that this combination gave the best results for
        # face-related changes, hence this specific configuration for the replacement:
        # the layers at full latent resolution are skipped.
        if not is_cross and self.replace_all and batch_size == 4:
            skip_dimension = (
                self.skip_dimension if self.skip_dimension is not None else pipe.unet.config.sample_size
            )
            if seq_len != skip_dimension * skip_dimension:
                # chunk() returns views, so copy_ edits attention_scores in place.
                ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
                attn_scores_dst.copy_(attn_scores_src)
                ucond_attn_scores_dst.copy_(ucond_attn_scores_src)
        #############################################################

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        # Linear projection followed by dropout, matching diffusers' processors
        # (dropout is the identity in eval mode).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

To distinguish between the encoding (down), middle (mid) and decoding (up) stages of the U-Net, we check the prefix of each attention module's name:

In [None]:
def replace_attention_processor(unet, use_default = False):
    """Install attention processors on every self-attention (``attn1``) layer of ``unet``.

    The U-Net stage is taken from the first component of the module name
    (``down_blocks``, ``mid_block``, ``up_blocks``). Down and up layers get
    ``AttnReplaceProcessor(True)``; the mid block gets ``AttnReplaceProcessor(False)``.

    Args:
        unet: The U-Net whose ``attn1`` layers are updated in place.
        use_default: If True, install a fresh ``AttnProcessor2_0`` on the same layers
            instead, restoring plain attention.
    """
    # Per-stage value of AttnReplaceProcessor's ``replace_all`` flag.
    replace_by_stage = {'down': True, 'mid': False, 'up': True}

    for name, module in unet.named_modules():
        # Match the Attention modules themselves, not their children
        # (to_q, to_k, to_v, to_out, ...).
        if not (isinstance(module, Attention) and name.endswith('attn1')):
            continue
        stage = name.split('.')[0].split('_')[0]
        if stage not in replace_by_stage:
            continue
        module.processor = AttnProcessor2_0() if use_default else AttnReplaceProcessor(replace_by_stage[stage])

A wrapper for displaying the latents, since this time we are dealing with two images.

In [None]:
def display_latents(latents):
    """Decode each latent with the pipeline's VAE and show the images side by side.

    Args:
        latents: Batch of latents; sample ``i`` is decoded on its own and shown with
            the title ``"Latent i"``. In this chapter, latent 0 is the original prompt
            and latent 1 the edited one.

    Raises:
        ValueError: If ``latents`` is empty.
    """
    num_images = len(latents)
    if num_images == 0:
        raise ValueError("display_latents() needs at least one latent")

    images = []
    with torch.no_grad():
        for latent in latents:
            # Undo the VAE latent scaling before decoding one image at a time.
            image = pipe.vae.decode(latent.unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
            image_np = image.squeeze(0).float().permute(1, 2, 0).detach().cpu()
            # Map the decoder output from [-1, 1] to [0, 1] for imshow.
            images.append((image_np / 2 + 0.5).clamp(0, 1))

    # squeeze=False keeps ``axes`` 2-D even for a single image.
    fig, axes = plt.subplots(1, num_images, figsize=(6 * num_images, 6), squeeze=False)

    for index, (ax, image_np) in enumerate(zip(axes[0], images)):
        ax.imshow(image_np)
        ax.axis('off')
        ax.set_title(f'Latent {index}')

    plt.show()

# Displaying the sampling results WITHOUT replacing the attention layer

In [None]:
num_inference_steps = 50
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Make sure plain attention is used for this baseline. If the replacing processors were
# left installed (e.g. the "WITH" cell ran earlier), these images would silently be
# edited too.
replace_attention_processor(pipe.unet, use_default=True)

# Both images (original and edited prompt) start from the same noise.
latents = torch.cat([initial_latents] * 2)

with torch.no_grad():

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        # Duplicate for classifier-free guidance: first half unconditional, second conditional.
        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    display_latents(latents)


# Displaying the sampling results WITH replacing the attention layer

In [None]:
num_inference_steps = 50
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

# Both images (original and edited prompt) start from the same noise.
latents = torch.cat([initial_latents] * 2)

# Apply the attention replacer to the U-Net. It only needs to be applied once, since it
# redirects the processors to the custom version; they stay installed until reset.
replace_attention_processor(pipe.unet)

try:
    with torch.no_grad():

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

            # Duplicate for classifier-free guidance: first half unconditional, second conditional.
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        display_latents(latents)
finally:
    # Go back to the default attention processors when done, even if sampling failed.
    # use_default=True is required: without it the replacing processors are re-installed.
    replace_attention_processor(pipe.unet, use_default=True)

