<span style="color: red;">Requirement when running in Google Colab</span>

In [None]:
!pip install diffusers

#  Chapter 4 - Attention Layer Breakdown

Attention layers, a key component of the transformer architecture, are built into Stable Diffusion's U-Net to help it capture long-range dependencies and context during image generation. These layers, based on the seminal "Attention Is All You Need" (https://arxiv.org/abs/1706.03762) paper by Vaswani et al. (2017), allow the model to focus on the relevant parts of its input when generating each part of the output. In Stable Diffusion, the cross-attention layers in particular are what connect the text prompt to the visual elements of the image. Understanding and breaking down these attention layers is valuable for advanced users and researchers who want to modify or optimise the model's behaviour. By overriding and analysing the attention mechanism, one can gain insight into how the model interprets prompts and constructs images, paving the way for targeted improvements, custom behaviours, or even novel applications of the technology.

The original architecture diagram only shows the cross-attention, in which the keys and values come from the prompt condition and the queries come from the image latents. In practice, every cross-attention layer in Stable Diffusion is preceded by a self-attention layer, in which the queries, keys and values all come from the image latents.
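
To make the difference concrete, here is a minimal shape-level sketch in plain PyTorch. The sizes are assumptions chosen for illustration: a 16 × 16 grid of latent tokens with 320 channels, and a 77-token prompt embedding of width 1024 (the width of the text encoder used by Stable Diffusion 2.1). The projections are freshly initialised nn.Linear layers standing in for attn.to_q, attn.to_k and attn.to_v; they are not the model's trained weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not read from a real model)
batch, num_tokens, dim = 1, 16 * 16, 320   # latent tokens and channels
text_len, text_dim = 77, 1024              # prompt tokens and embedding width

latent_tokens = torch.randn(batch, num_tokens, dim)
prompt_embeds = torch.randn(batch, text_len, text_dim)


def attention(q, k, v):
    # Single-head scaled dot-product attention: softmax(QK^T / sqrt(d)) V
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ v, scores


# Self-attention (attn1): queries, keys and values all come from the latents
to_q1, to_k1, to_v1 = (nn.Linear(dim, dim, bias=False) for _ in range(3))
self_out, self_scores = attention(to_q1(latent_tokens), to_k1(latent_tokens), to_v1(latent_tokens))

# Cross-attention (attn2): queries from the latents, keys and values from the prompt
to_q2 = nn.Linear(dim, dim, bias=False)
to_k2 = nn.Linear(text_dim, dim, bias=False)
to_v2 = nn.Linear(text_dim, dim, bias=False)
cross_out, cross_scores = attention(to_q2(latent_tokens), to_k2(prompt_embeds), to_v2(prompt_embeds))

print(self_scores.shape)   # torch.Size([1, 256, 256]): each latent token scores every latent token
print(cross_scores.shape)  # torch.Size([1, 256, 77]): each latent token scores every prompt token
print(self_out.shape, cross_out.shape)  # both torch.Size([1, 256, 320])
```

Both layers return a tensor with the same shape as the latent tokens, which is what allows them to be stacked one after the other inside each transformer block.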

![image.png](https://github.com/OutofAi/StableFace/blob/main/assets/image5.png?raw=1)

This part is technically the same as in the previous chapter, but for simplicity we have removed the breakdown and shortened it. All individual components of the Stable Diffusion architecture can be accessed through the pipeline variable itself; for example, the U-Net is available as pipe.unet.

In [None]:
import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch
import matplotlib.pyplot as plt
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

model_id = "stabilityai/stable-diffusion-2-1-base"

pipe = StableDiffusionPipeline.from_pretrained(model_id)
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = pipe.to("cuda")

prompt = "A photo of a woman, straight hair, light blonde and pink hair, smiling expression, grey background"

org_prompt_embeds = pipe.encode_prompt(prompt=prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
prompt_embeds_combined = torch.cat([uncond_prompt_embeds, org_prompt_embeds])

latents = torch.randn((1, pipe.unet.config.in_channels, pipe.unet.config.sample_size, pipe.unet.config.sample_size), generator=torch.Generator().manual_seed(220)).to("cuda")

First, we write out an expanded version of the attention layer calculation, which we can then substitute for the default calculation in the model's layers.

If you are not familiar with the math behind self-attention or cross-attention layers, I highly recommend this 3Blue1Brown YouTube video https://youtu.be/eMlx5fFNoYc, which explains in detail how it works.
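
For reference, this is the scaled dot-product attention from Vaswani et al. (2017) that both layer types compute, ignoring bias terms. X denotes the image latent tokens and C the context: C = X for self-attention, and C is the prompt embedding for cross-attention. W_Q, W_K, W_V and W_O correspond to attn.to_q, attn.to_k, attn.to_v and attn.to_out[0], d is the per-head dimension, and the factor 1/sqrt(d) is the value held in attn.scale. Each of the h heads works on its own slice Q_i, K_i, V_i of the projections.

```latex
Q = X W_Q, \qquad K = C W_K, \qquad V = C W_V

\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q_i K_i^{\top}}{\sqrt{d}}\right) V_i, \qquad i = 1, \dots, h

\mathrm{Attention}(X, C) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O
```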

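The default processor, AttnProcessor2_0, calls PyTorch's fused torch.nn.functional.scaled_dot_product_attention, which runs optimised C++/CUDA kernels. Our expanded version computes the same result step by step from individual PyTorch operations and materialises the full attention matrix, so it adds a slight overhead to sampling time and memory use.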
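
A quick way to convince yourself that the expanded computation matches the fused kernel is to run both on the same random tensors. The shapes below are illustrative assumptions in the (batch * heads, tokens, head_dim) layout used by the breakdown processor; the fused call gets an extra leading dimension so that it sees the (N, heads, L, E) layout. This sketch runs on the CPU, and the printed timings depend entirely on your hardware (timing CUDA kernels would also require torch.cuda.synchronize()).

```python
import time

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes: (batch * heads, tokens, head_dim); assumptions, not model values
query = torch.randn(10, 1024, 64)
key = torch.randn(10, 1024, 64)
value = torch.randn(10, 1024, 64)
scale = query.shape[-1] ** -0.5  # plays the role of attn.scale


def manual_attention(q, k, v):
    # Same steps as the breakdown processor: scores, softmax, weighted sum of values
    scores = scale * torch.bmm(q, k.transpose(-1, -2))
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)


def fused_attention(q, k, v):
    # Fused PyTorch kernel (PyTorch 2.0+); its default scale is also 1 / sqrt(head_dim)
    return F.scaled_dot_product_attention(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)).squeeze(0)


manual = manual_attention(query, key, value)
fused = fused_attention(query, key, value)
print(torch.allclose(manual, fused, atol=1e-5))

for name, fn in [("manual", manual_attention), ("fused", fused_attention)]:
    start = time.perf_counter()
    for _ in range(10):
        fn(query, key, value)
    print(f"{name}: {(time.perf_counter() - start) / 10 * 1000:.2f} ms per call")
```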

In [None]:
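class AttnBreakdownProcessor(AttnProcessor2_0):
    """Attention processor that spells out the attention computation step by step
    instead of calling the fused F.scaled_dot_product_attention kernel."""

    def __init__(self):
        super().__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        # attention_mask and temb are accepted for interface compatibility but are not used here

        residual = hidden_states

        input_ndim = hidden_states.ndim

        # Flatten (B, C, H, W) feature maps into a sequence of H*W tokens: (B, H*W, C)
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        # Queries always come from the image latents
        query = attn.to_q(hidden_states)

        # Self-attention: keys and values also come from the latents.
        # Cross-attention: keys and values come from the prompt embeddings.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Move the heads into the batch dimension: (B, N, heads * d) -> (B * heads, N, d)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Scaled dot-product scores; attn.scale is 1 / sqrt(d) for Stable Diffusion's layers
        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))

        # Normalise each query's scores over all keys
        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        # Weighted sum of the values, then merge the heads back: (B * heads, N, d) -> (B, N, heads * d)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        # Output projection followed by dropout (a no-op at inference time)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the (B, C, H, W) layout if the input was a feature map
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states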
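
The helpers attn.head_to_batch_dim and attn.batch_to_head_dim only reshape tensors so that torch.bmm can treat every attention head as an independent batch entry. The sketch below re-implements that reshaping in plain PyTorch; split_heads, merge_heads and to_4d are hypothetical names of our own, the logic reflects our reading of the diffusers helpers rather than a copy of them, and the 5 heads of width 64 are illustrative values. It checks that per-head bmm attention matches multi-head scaled_dot_product_attention on a (batch, heads, tokens, head_dim) layout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

heads, head_dim = 5, 64  # illustrative values (assumption)


def split_heads(x, heads):
    # (batch, seq, heads * head_dim) -> (batch * heads, seq, head_dim)
    batch, seq, dim = x.shape
    x = x.reshape(batch, seq, heads, dim // heads).permute(0, 2, 1, 3)
    return x.reshape(batch * heads, seq, dim // heads)


def merge_heads(x, heads):
    # (batch * heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    batch_heads, seq, head_dim = x.shape
    x = x.reshape(batch_heads // heads, heads, seq, head_dim).permute(0, 2, 1, 3)
    return x.reshape(batch_heads // heads, seq, heads * head_dim)


def to_4d(t):
    # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
    return t.reshape(t.shape[0], t.shape[1], heads, head_dim).transpose(1, 2)


q = torch.randn(2, 256, heads * head_dim)  # batch of 2, e.g. unconditional + conditional
k = torch.randn(2, 77, heads * head_dim)
v = torch.randn(2, 77, heads * head_dim)

# Per-head attention with bmm, as in the breakdown processor
qh, kh, vh = (split_heads(t, heads) for t in (q, k, v))
probs = (head_dim ** -0.5 * torch.bmm(qh, kh.transpose(-1, -2))).softmax(dim=-1)
out_bmm = merge_heads(torch.bmm(probs, vh), heads)

# Reference: fused multi-head attention, then back to (batch, seq, heads * head_dim)
out_ref = F.scaled_dot_product_attention(to_4d(q), to_4d(k), to_4d(v))
out_ref = out_ref.transpose(1, 2).reshape(2, 256, heads * head_dim)

print(qh.shape)  # torch.Size([10, 256, 64])
print(torch.allclose(out_bmm, out_ref, atol=1e-5))
print(torch.equal(merge_heads(split_heads(q, heads), heads), q))
```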

We now replace the default attention processor of the self-attention (attn1) layers with our custom Python version.

In [None]:
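def replace_attention_processor(unet):
    # A single processor instance is shared by every matched layer (it holds no per-layer state)
    replace_processor = AttnBreakdownProcessor()

    for name, module in unet.named_modules():
        # Match the self-attention modules ("attn1") themselves and skip their
        # to_q / to_k / to_v / to_out sub-layers; cross-attention ("attn2") keeps its default processor
        if 'attn1' in name and 'to' not in name:
            module.processor = replace_processor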
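
To see what the substring filter actually selects, the sketch below applies the same condition to a hand-written sample of module names. These names are an assumption that follows the diffusers naming pattern for U-Net transformer blocks; printing the names from pipe.unet.named_modules() shows the real ones. Only the attn1 modules themselves pass, while their projection sub-layers and every attn2 (cross-attention) module are skipped.

```python
# Hand-written sample of module names in the diffusers U-Net style (illustrative only)
sample_names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k",
    "mid_block.attentions.0.transformer_blocks.0.attn1",
    "up_blocks.1.attentions.2.transformer_blocks.0.attn1.to_out.1",
]


def is_replaced(name):
    # Same condition as in replace_attention_processor
    return "attn1" in name and "to" not in name


selected = [name for name in sample_names if is_replaced(name)]
for name in selected:
    print(name)
```

Substring matching is somewhat fragile: a module whose name happened to contain "to" would be silently skipped. Checking name.endswith("attn1") together with isinstance(module, Attention) is a stricter alternative.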


Before running the sampling steps, we swap our custom attention implementation into the model.

In [None]:
num_inference_steps = 50
guidance_scale = 7.5
scheduler.set_timesteps(num_inference_steps, device="cuda")
timesteps = scheduler.timesteps

with torch.no_grad():

    # Before sampling, swap in our custom processor for every self-attention (attn1) layer in the U-Net
    replace_attention_processor(pipe.unet)

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

        latent_model_input = torch.cat([latents] * 2)

        noise_pred = pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds_combined,
            cross_attention_kwargs=None,
            return_dict=False,
        )[0]


        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]


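In theory, the result should look the same as in the previous chapter, since we have only overridden the forward computation of the self-attention layers with an equivalent expanded version; at most, tiny floating-point differences between the fused kernel and the explicit computation may appear.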

In [None]:
with torch.no_grad():
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image_np = image.squeeze(0).float().permute(1,2,0).detach().cpu()
    image_np = (image_np / 2 + 0.5).clamp(0, 1)

In [None]:
plt.imshow(image_np)
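
The VAE decoder returns pixel values roughly in the range [-1, 1], which the decoding cell maps to [0, 1] for matplotlib. If you would rather save the image to disk, the sketch below shows one way to turn such a tensor into an 8-bit (H, W, 3) NumPy array. to_uint8 is a hypothetical helper, and it is run here on a random stand-in tensor rather than the real decoder output.

```python
import torch


def to_uint8(image):
    # image: decoded VAE output of shape (1, 3, H, W) with values roughly in [-1, 1]
    image = (image / 2 + 0.5).clamp(0, 1)        # map to [0, 1], as in the decoding cell
    image = image.squeeze(0).permute(1, 2, 0)    # (H, W, 3) layout expected by image libraries
    return (image * 255).round().to(torch.uint8).cpu().numpy()


fake_decoded = torch.rand(1, 3, 512, 512) * 2 - 1  # stand-in for the pipe.vae.decode(...) output
array = to_uint8(fake_decoded)
print(array.shape, array.dtype)  # (512, 512, 3) uint8
```

With Pillow installed, PIL.Image.fromarray(array).save("output.png") would then write the array to disk as a PNG.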