This is notebook 1 on my way to implementing https://github.com/huggingface/diffusers/issues/2121.<br/>
As the nb runs into OOMs, I have started a 2nd nb.

___

In [1]:
!pip install -Uqq fastcore

In [2]:
import numpy as np
import torch
from torch import nn, tensor
from fastcore.all import *

import diffusers

In [3]:
torch.__version__

'2.0.1'

The goal is to implement a prompt-to-prompt (P2P) pipeline.

What is P2P?
1. Takes a text input
2. Creates an image
3. Takes another text input, which should be a slight variation of the first one
4. Creates an adapted image that is similar to the first image but depicts the adapted text

How?<br/>
My understanding:
- In every denoising step, we run the unet, which contains many attention blocks
- When generating the first image, we save the attention maps (for every attention block and denoising step)
- When generating the second image, we substitute the saved attention maps, so:
    - we need to save all attention maps
    - we need to match the attention maps to the correct attention blocks
    - we need a way to substitute attention maps
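To make the save-then-substitute idea concrete, here is a toy, framework-free sketch in numpy (not diffusers code): one scaled dot-product attention function that either records its attention map under a layer name or replays a previously recorded one. The layer name `"block0.attn2"` and the `mode` argument are made up for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, name, store, mode="plain"):
    """Scaled dot-product attention that can record or replay its attention map.

    mode="save":   compute the map from q and k and keep it in store[name]
    mode="inject": ignore q and k and reuse store[name] instead
    mode="plain":  ordinary attention, nothing is stored
    """
    if mode == "inject":
        probs = store[name]
    else:
        probs = softmax(q @ k.T / np.sqrt(q.shape[-1]))
        if mode == "save":
            store[name] = probs
    return probs @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))                                  # 4 "pixels" (queries)
k_a, v_a = rng.normal(size=(3, 8)), rng.normal(size=(3, 8))  # keys/values from prompt A
k_b, v_b = rng.normal(size=(3, 8)), rng.normal(size=(3, 8))  # keys/values from prompt B

store = {}
out_a = attention(q, k_a, v_a, "block0.attn2", store, mode="save")    # 1st image: save the map
out_b = attention(q, k_b, v_b, "block0.attn2", store, mode="inject")  # 2nd image: reuse it
```

In the second pass, the map (which pixel attends to which token) comes from prompt A, while the values that get mixed come from prompt B.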

Steps:<br/>
1. In a minimal unet, save the attention maps
2. In a minimal unet, substitute the attention (e.g. use 2*attention)
3. Find out how to do the housekeeping
4. Find out how to write above using AttentionProcessor class
5. Scale up & try out

Also, check out `Pix2Pix Zero`, a concept related to P2P that is already implemented in diffusers.

In [4]:
from diffusers.models import UNet2DConditionModel
from PIL import Image
from torchvision.transforms import ToTensor
from torch import tensor

In [5]:
im_shape = (20,20)
nc = 3  # number of channels

To keep things as simple as possible, let's first work with a tiny, unconditional unet.

Edit: The unconditional UNet2DModel doesn't give access to its attention blocks. This is probably because you'd only modify the attention to steer the output towards a condition (e.g. a text prompt), meaning you would always use a conditional unet. => Let's use UNet2DConditionModel

In [6]:
from transformers import CLIPTextModel, CLIPTokenizer

Normally, we'd use `torch_dtype=torch.float16`, i.e. half-precision floats.
To keep things simple, we'll let the model live on the CPU (instead of the GPU). As not all ops are implemented for half-precision floats on the CPU, we use full-precision floats instead (`torch_dtype=torch.float32`).

In [7]:
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float32)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float32)

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.13.self_attn.k_proj.weight', 'vision_model.encoder.layers.2.layer_norm2.weight', 'vision_model.encoder.layers.16.self_attn.q_proj.weight', 'vision_model.encoder.layers.1.mlp.fc1.weight', 'vision_model.encoder.layers.9.mlp.fc1.weight', 'vision_model.encoder.layers.2.mlp.fc1.bias', 'vision_model.encoder.layers.4.self_attn.v_proj.bias', 'text_projection.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.weight', 'vision_model.encoder.layers.7.mlp.fc2.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.bias', 'vision_model.encoder.layers.5.self_attn.k_proj.weight', 'vision_model.encoder.layers.0.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.layer_norm1.weight', 'vision_model.encoder.layers.0.mlp.fc1.weight', 'vision_model.encoder.layers.22.layer_norm2.bias', 'vision_model.encoder.layers.5.self_attn.out_proj.bias', 'vision_

In [8]:
some_text = 'bla'
txt_tok = tokenizer(some_text, return_tensors="pt", padding='max_length')
txt_emb = text_encoder(txt_tok.input_ids)[0]

In [9]:
txt_emb.shape

torch.Size([1, 77, 768])

Okay, let's create a tiny (conditional) unet.

In [10]:
baby_unet = UNet2DConditionModel(
    sample_size=im_shape,  # the target image resolution
    in_channels=nc,  # the number of input channels, 3 for RGB images
    out_channels=nc,  # the number of output channels
    layers_per_block=1,  # how many ResNet layers to use per UNet block
    block_out_channels=(32, 64, 128),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
    ),
)

In [11]:
baby_unet

UNet2DConditionModel(
  (conv_in): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=32, out_features=128, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0): ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(32, 32, 

In [12]:
img = torch.rand(1, nc,*im_shape)
img.shape

torch.Size([1, 3, 20, 20])

Running
`with torch.no_grad(): outp = baby_unet(img, 2, encoder_hidden_states=txt_emb)`
gives
![image.png](attachment:db53aada-ed3b-4d9f-bbc0-46f745666989.png)

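The problem is that the text embedding's shape doesn't match what our unet expects (most likely the cross-attention dimension: `UNet2DConditionModel` defaults to `cross_attention_dim=1280`, while our CLIP embeddings are 768-dimensional).<br/>
So we have 2 options:
- Figure out what shape the unet expects, and create a fitting (random) text embedding
- Use a working model (e.g. Stable Diffusion), which is not minimal, but works out of the box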

=> I'll use a larger model

In [13]:
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14', torch_dtype=torch.float16).to('cuda')
unet = UNet2DConditionModel.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='unet', torch_dtype=torch.float16).to('cuda')

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.13.self_attn.k_proj.weight', 'vision_model.encoder.layers.2.layer_norm2.weight', 'vision_model.encoder.layers.16.self_attn.q_proj.weight', 'vision_model.encoder.layers.1.mlp.fc1.weight', 'vision_model.encoder.layers.9.mlp.fc1.weight', 'vision_model.encoder.layers.2.mlp.fc1.bias', 'vision_model.encoder.layers.4.self_attn.v_proj.bias', 'text_projection.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.weight', 'vision_model.encoder.layers.7.mlp.fc2.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.bias', 'vision_model.encoder.layers.5.self_attn.k_proj.weight', 'vision_model.encoder.layers.0.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.layer_norm1.weight', 'vision_model.encoder.layers.0.mlp.fc1.weight', 'vision_model.encoder.layers.22.layer_norm2.bias', 'vision_model.encoder.layers.5.self_attn.out_proj.bias', 'vision_

In [14]:
len(unet.down_blocks), len(unet.up_blocks)

(4, 4)

In [15]:
L(unet.down_blocks[0].named_children()).itemgot(0)

(#3) ['attentions','resnets','downsamplers']

In [16]:
unet.down_blocks[0].attentions[0].transformer_blocks[0].attn1

Attention(
  (to_q): Linear(in_features=320, out_features=320, bias=False)
  (to_k): Linear(in_features=320, out_features=320, bias=False)
  (to_v): Linear(in_features=320, out_features=320, bias=False)
  (to_out): ModuleList(
    (0): Linear(in_features=320, out_features=320, bias=True)
    (1): Dropout(p=0.0, inplace=False)
  )
)

In [17]:
unet.down_blocks[0].attentions[0].transformer_blocks[0].attn2

Attention(
  (to_q): Linear(in_features=320, out_features=320, bias=False)
  (to_k): Linear(in_features=768, out_features=320, bias=False)
  (to_v): Linear(in_features=768, out_features=320, bias=False)
  (to_out): ModuleList(
    (0): Linear(in_features=320, out_features=320, bias=True)
    (1): Dropout(p=0.0, inplace=False)
  )
)

Our text embedding has size `768`. It seems to be used only in `attn2` (the cross-attention), for the `k` and `v` projections.
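A shape-only sketch in numpy of what `attn2` in `down_blocks[0]` computes for a 64×64 latent. Random matrices stand in for the trained `to_q`/`to_k`/`to_v` weights, and a single attention head is used for simplicity (the real block splits the 320 channels across several heads).

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, d_model = 64 * 64, 320  # image tokens entering down_blocks[0] (64x64 latent)
n_tokens, d_text = 77, 768        # CLIP text embedding, as in txt_emb.shape

hidden_states = rng.normal(size=(n_pixels, d_model))
text_emb = rng.normal(size=(n_tokens, d_text))

# Random stand-ins for attn2's to_q / to_k / to_v (bias=False, see the printout above)
W_q = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
W_k = rng.normal(size=(d_text, d_model)) / np.sqrt(d_text)
W_v = rng.normal(size=(d_text, d_model)) / np.sqrt(d_text)

q = hidden_states @ W_q  # (4096, 320): queries come from the image
k = text_emb @ W_k       # (77, 320): keys come from the text
v = text_emb @ W_v       # (77, 320): values come from the text

scores = q @ k.T / np.sqrt(d_model)
scores -= scores.max(axis=-1, keepdims=True)
probs = np.exp(scores)
probs /= probs.sum(axis=-1, keepdims=True)  # (4096, 77): one row per pixel
out = probs @ v                             # (4096, 320): back to the image width
```

Each row of `probs` says how strongly one pixel attends to each of the 77 text tokens; these cross-attention maps are what P2P saves and swaps. For `attn1` (self-attention), keys and values come from `hidden_states` itself, so its map would be 4096 × 4096.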

In [18]:
rand_lat = torch.rand((1,4,64,64)).to('cuda').half()
txt_emb = txt_emb.to('cuda').half()

In [19]:
with torch.no_grad(): outp = unet(rand_lat, 2, encoder_hidden_states=txt_emb)[0]

In [20]:
outp.shape

torch.Size([1, 4, 64, 64])

Okay, we've got a working unet.<br/>
Now, let's figure out how we can change it from the outside.

The following is adapted from `pipeline_stable_diffusion_pix2pix_zero.prepare_unet`:

In [21]:
def methods(obj): return [o for o in dir(obj) if not o.startswith('_')]

In [22]:
methods(unet)

['T_destination',
 'add_module',
 'apply',
 'attn_processors',
 'bfloat16',
 'buffers',
 'call_super_init',
 'children',
 'class_embedding',
 'config',
 'config_name',
 'conv_act',
 'conv_in',
 'conv_norm_out',
 'conv_out',
 'cpu',
 'cuda',
 'device',
 'disable_gradient_checkpointing',
 'disable_xformers_memory_efficient_attention',
 'double',
 'down_blocks',
 'dtype',
 'dump_patches',
 'enable_gradient_checkpointing',
 'enable_xformers_memory_efficient_attention',
 'encoder_hid_proj',
 'eval',
 'extra_repr',
 'extract_init_dict',
 'float',
 'forward',
 'from_config',
 'from_pretrained',
 'get_buffer',
 'get_config_dict',
 'get_extra_state',
 'get_parameter',
 'get_submodule',
 'half',
 'has_compatibles',
 'ignore_for_config',
 'ipu',
 'is_gradient_checkpointing',
 'load_attn_procs',
 'load_config',
 'load_state_dict',
 'mid_block',
 'modules',
 'named_buffers',
 'named_children',
 'named_modules',
 'named_parameters',
 'num_parameters',
 'num_upsamplers',
 'parameters',
 'register_bac

In [23]:
unet.attn_processors

{'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575b99f60>,
 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575b9bc70>,
 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575badbd0>,
 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575baf790>,
 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575c31900>,
 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575c32830>,
 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': <diffusers.models.attention_processor.AttnProcessor2_0 at 0x7f9575c326b0>,

In [24]:
unet.set_attn_processor??

Signature:
unet.set_attn_processor(
    processor: Union[diffusers.models.attention_processor.AttnProcessor, diffusers.models.attention_processor.AttnProcessor2_0, diffusers.models.attention_processor.XFormersAttnProcessor, diffusers.models.attention_processor.SlicedAttnProcessor, diffusers.models.attention_processor.AttnAddedKVProcessor, diffusers.models.attention_processor.SlicedAttnAddedKVProcessor, diffusers.

In [25]:
from diffusers.models.attention_processor import AttnProcessor2_0

In [26]:
AttnProcessor2_0??

Init signature: AttnProcessor2_0()
Docstring:      <no docstring>
Source:        
class AttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=

In [27]:
class ExcitedAttnProcessor2_0(AttnProcessor2_0):
    def __call__(self, *args, **kwargs):
        print('Oh boi! I\'m being called!')
        return super().__call__(*args, **kwargs)

Let's first replace **all** attention processors with our custom one

In [28]:
unet.set_attn_processor(ExcitedAttnProcessor2_0())

In [29]:
with torch.no_grad(): outp = unet(rand_lat, 2, encoder_hidden_states=txt_emb)[0]
outp.shape

Oh boi! I'm being called!
(printed 32 times in total: once per attention processor)


torch.Size([1, 4, 64, 64])

Works!

Next, let's replace each attention processor individually

In [30]:
class KnowledgeableAttnProcessor2_0(AttnProcessor2_0):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def __call__(self, *args, **kwargs):
        print(f'My name? It is {self.name}. Imma do some attending now.')
        return super().__call__(*args, **kwargs)

In [31]:
attns = {
    k: KnowledgeableAttnProcessor2_0(name=k)
    for k,v in unet.attn_processors.items()
}

In [32]:
attns

{'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f9575b98460>,
 'down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f9648f58520>,
 'down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f95752417b0>,
 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f9575242500>,
 'down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f95752437f0>,
 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f9575241c60>,
 'down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor': <__main__.KnowledgeableAttnProcessor2_0 at 0x7f95752425f0>,
 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': <__main__.KnowledgeableAttnProcessor

In [33]:
unet.set_attn_processor(attns)

In [34]:
with torch.no_grad(): outp = unet(rand_lat, 2, encoder_hidden_states=txt_emb)

My name? It is down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor. Imma do some attending now.
My name? It is down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor. Imma do some attending now.
My name? It is down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor. Imma do some attending now.
My name? It is down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor. Imma do some attending now.
My name? It is down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor. Imma do some attending now.
My name? It is down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor. Imma do some attending now.
My name? It is down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor. Imma do some attending now.
My name? It is down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor. Imma do some attending now.
My name? It is down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor. Imma do some attending now.
My name? It is down_blocks.2

Cool! I can now change attention processors for each attention block individually!

___

Let's now implement P2P.<br/>
Luckily, Weifeng Chen already implemented it [here](https://github.com/Weifeng-Chen/prompt2prompt/). Let's understand his code.

In [35]:
import abc

class AttentionControl(abc.ABC):
    
    def step_callback(self, x_t):
        return x_t
    
    def between_steps(self):
        return
    
    @property
    def num_uncond_att_layers(self):
        return 0
    
    @abc.abstractmethod
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            h = attn.shape[0]   
            attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
    
    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
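`AttentionControl.__call__` does two bits of bookkeeping. It only edits the second half of the attention batch: with classifier-free guidance the unet is fed `[unconditional, text-conditioned]` inputs, so that half belongs to the text-conditioned pass. And it counts calls to know when one denoising step is over. The hypothetical `CountingControl` below strips that logic down (ignoring `num_uncond_att_layers`, which is 0 here anyway) and uses `2 * attention` as the edit, as in step 2 of the plan.

```python
import numpy as np

class CountingControl:
    """Hypothetical, stripped-down version of AttentionControl's bookkeeping."""

    def __init__(self, num_att_layers):
        self.num_att_layers = num_att_layers  # attention layers visited per unet call
        self.cur_att_layer = 0
        self.cur_step = 0

    def forward(self, attn):
        return 2 * attn  # stand-in edit: double the attention

    def __call__(self, attn):
        h = attn.shape[0]
        # first half: unconditional pass, second half: text-conditioned pass
        attn[h // 2:] = self.forward(attn[h // 2:])
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:  # every layer seen once -> step done
            self.cur_att_layer = 0
            self.cur_step += 1
        return attn

control = CountingControl(num_att_layers=32)  # 32 processors in the SD v1.4 unet
for _ in range(3):                            # 3 denoising steps
    for _ in range(32):
        attn = control(np.ones((2, 4, 4)))    # batch of [uncond, cond] maps
```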

In [36]:
class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [],  "mid_self": [],  "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
        return average_attention


    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
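The storage logic boils down to three things: skip maps with more than 32² query positions to save memory, sum the per-step maps across denoising steps, and divide by the number of steps at the end. Here is a numpy sketch of that arithmetic using hypothetical helper functions (`store_if_small`, `accumulate`) rather than the class itself.

```python
import numpy as np

def store_if_small(step_store, key, attn, max_pixels=32 ** 2):
    """Like AttentionStore.forward: keep only maps with at most max_pixels query positions."""
    if attn.shape[1] <= max_pixels:
        step_store.setdefault(key, []).append(attn)

def accumulate(attention_store, step_store):
    """Like AttentionStore.between_steps: add this step's maps to the running sum."""
    if not attention_store:
        return {key: [a.copy() for a in maps] for key, maps in step_store.items()}
    for key in attention_store:
        for i in range(len(attention_store[key])):
            attention_store[key][i] += step_store[key][i]
    return attention_store

attention_store, num_steps = {}, 2
for step in range(num_steps):
    step_store = {}
    store_if_small(step_store, "up_cross", np.full((2, 16 * 16, 77), 1.0 + 2 * step))  # kept
    store_if_small(step_store, "up_cross", np.ones((2, 64 * 64, 77)))  # too large, skipped
    attention_store = accumulate(attention_store, step_store)

# Like get_average_attention: divide the sums by the number of steps
average = {key: [a / num_steps for a in maps] for key, maps in attention_store.items()}
```

With a 64×64 latent, this means the maps from the highest-resolution blocks (4096 query positions) are never stored.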

In [37]:
from typing import Callable, Dict, List, Optional, Tuple, Union
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput

# todo:
# need to add negative prompting?
# and prompt_embeds (=pass already encoded prompt text, not text itself)?
# 
class Prompt2PromptPipeline(StableDiffusionPipeline):
    """todo: docs"""
    _optional_components = ["safety_checker", "feature_extractor"]

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        controller: AttentionStore = None,  # todo: don't pass in controller, but use cross_attention_kwargs
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
    ):
        """todo: docs"""

        self.register_attention_control(controller) # add attention controller

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        # following lines are missing:
        # text_encoder_lora_scale = (
        #   cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        # )
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt, num_channels_latents, height, width,
            text_embeddings.dtype, device, generator, latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # following lines are missing:
                # if do_classifier_free_guidance and guidance_rescale > 0.0:
                #     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                #     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # step callback
                latents = controller.step_callback(latents)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def register_attention_control(self, controller):
        attn_procs = {}
        cross_att_count = 0
        for name in self.unet.attn_processors.keys():
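            # comment Umer: attn1 is self-attention (cross_attention_dim=None), attn2 is cross-attention;
            # every processor gets replaced below, and cross_attention_dim / hidden_size are not used further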
            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim

            if name.startswith("mid_block"):
                hidden_size = self.unet.config.block_out_channels[-1]
                place_in_unet = "mid"
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
                place_in_unet = "up"
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.unet.config.block_out_channels[block_id]
                place_in_unet = "down"
            else:
                continue
            cross_att_count += 1
            attn_procs[name] = P2PCrossAttnProcessor(
                controller=controller, place_in_unet=place_in_unet
            )

        self.unet.set_attn_processor(attn_procs)
        controller.num_att_layers = cross_att_count
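`register_attention_control` only uses the processor *name* to decide where a layer sits in the unet (`place_in_unet`, plus the block index) and, via `cross_attention_dim`, whether it is self- (`attn1`) or cross-attention (`attn2`). A pure-Python sketch of that parsing as a hypothetical standalone helper; the example names follow the format of the `unet.attn_processors` keys.

```python
def parse_processor_name(name):
    """Hypothetical helper mirroring register_attention_control's name checks.

    Returns (place_in_unet, block_id, is_cross), or None for unrecognised names.
    """
    if name.startswith("mid_block"):
        place, block_id = "mid", None
    elif name.startswith("up_blocks"):
        place, block_id = "up", int(name[len("up_blocks.")])  # single-digit block index
    elif name.startswith("down_blocks"):
        place, block_id = "down", int(name[len("down_blocks.")])
    else:
        return None  # register_attention_control skips such names (`continue`)
    is_cross = not name.endswith("attn1.processor")  # attn1 = self-attn, attn2 = cross-attn
    return place, block_id, is_cross

names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor",
]
parsed = [parse_processor_name(n) for n in names]
```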

In [52]:
# the unet's attention modules are `Attention` instances (see the printout above);
# the older `diffusers.models.cross_attention.CrossAttention` import is deprecated
from diffusers.models.attention_processor import Attention

class P2PCrossAttnProcessor:
    def __init__(self, controller, place_in_unet):
        super().__init__()
        self.controller = controller
        self.place_in_unet = place_in_unet

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        is_cross = encoder_hidden_states is not None
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        
        # one line change: let the controller store/edit the attention map (use its return value)
        attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states) # linear proj
        hidden_states = attn.to_out[1](hidden_states) # dropout
        return hidden_states
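`attn.head_to_batch_dim` folds the attention heads into the batch axis, which is why the controller receives maps of shape `(batch * heads, pixels, tokens)` and why `AttentionControlEdit` computes `h = attn.shape[0] // self.batch_size`. A numpy re-creation of that reshape pattern, assuming diffusers' batch-major ordering (all heads of sample 0 first, then sample 1); the head count 8 and head size 40 are example values.

```python
import numpy as np

def head_to_batch_dim(x, heads):
    # (batch, seq, heads * dim) -> (batch * heads, seq, dim), heads grouped per sample
    b, s, d = x.shape
    return x.reshape(b, s, heads, d // heads).transpose(0, 2, 1, 3).reshape(b * heads, s, d // heads)

def batch_to_head_dim(x, heads):
    # inverse: (batch * heads, seq, dim) -> (batch, seq, heads * dim)
    bh, s, d = x.shape
    return x.reshape(bh // heads, heads, s, d).transpose(0, 2, 1, 3).reshape(bh // heads, s, heads * d)

batch, heads, seq, dim = 2, 8, 5, 40
x = np.arange(batch * seq * heads * dim, dtype=float).reshape(batch, seq, heads * dim)
xb = head_to_batch_dim(x, heads)

# Regrouping the first axis into (batch, heads), as AttentionControlEdit.forward does,
# gives one block of heads per prompt
per_prompt = xb.reshape(batch, heads, seq, dim)
```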

In [53]:
class AttentionControlEdit(AttentionStore, abc.ABC):
    def step_callback(self, x_t):
        if self.local_blend is not None: x_t = self.local_blend(x_t, self.attention_store)
        return x_t
        
    def replace_self_attention(self, attn_base, att_replace):
        if att_replace.shape[2] <= 16 ** 2: return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else: return att_replace
    
    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace): raise NotImplementedError
    
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        # FIXME not replace correctly
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn
    
    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend,
                 tokenizer,
                 device):
        super(AttentionControlEdit, self).__init__()
        # add tokenizer and device here

        self.tokenizer = tokenizer
        self.device = device

        self.batch_size = len(prompts)
        self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, self.tokenizer).to(self.device)
        if type(self_replace_steps) is float: self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend  # defined outside and passed in
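Inside `AttentionControlEdit.forward`, the cross-attention map of each edited prompt is a per-token mix of the (mapped) source map and its own map, weighted by `cross_replace_alpha` for the current step, while self-attention is only replaced inside the step window `num_self_replace`. A numpy sketch of both formulas, with hypothetical helper names and toy values.

```python
import numpy as np

def blend_cross_attention(attn_base_mapped, attn_replace, alpha_words):
    """Per-token mix as in AttentionControlEdit.forward:
    alpha 1 -> take the source prompt's (mapped) map, alpha 0 -> keep the edited prompt's own map."""
    return attn_base_mapped * alpha_words + (1 - alpha_words) * attn_replace

def self_replace_window(num_steps, self_replace_steps):
    """Step range [start, end) in which self-attention is replaced, as in __init__."""
    if isinstance(self_replace_steps, float):
        self_replace_steps = 0, self_replace_steps
    return int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])

base = np.full((1, 4, 3), 0.9)     # (edited prompts, pixels, tokens): source-prompt map
replace = np.full((1, 4, 3), 0.1)  # the edited prompt's own map
alpha = np.array([1.0, 0.0, 1.0])  # inject tokens 0 and 2, keep token 1
mixed = blend_cross_attention(base, replace, alpha)
```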

In [54]:
def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    if type(cross_replace_steps) is not dict: cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps: cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0: alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words

In [55]:
def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor]=None):
    if type(bounds) is float: bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None: word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha

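`update_alpha_time_word` scales the bounds by `alpha.shape[0]`, which is `num_steps + 1`, not `num_steps`. The small pure-Python sketch below lists which rows end up as 1 for the settings used later (20 steps, `cross_replace_steps=0.4`). `active_steps` is a hypothetical helper, not part of the notebook.

```python
def active_steps(num_steps, bounds):
    """Hypothetical helper: rows where update_alpha_time_word writes 1."""
    if isinstance(bounds, float):
        bounds = (0, bounds)
    n_rows = num_steps + 1  # alpha has num_steps + 1 rows
    start, end = int(bounds[0] * n_rows), int(bounds[1] * n_rows)
    return list(range(start, end))


print(active_steps(20, 0.4))              # [0, 1, 2, 3, 4, 5, 6, 7]
print(len(active_steps(20, (0.0, 1.0))))  # 21
```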
In [56]:
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    words_x = x.split(' ')
    words_y = y.split(' ')
    if len(words_x) != len(words_y):
        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
    mapper = np.zeros((max_len, max_len))
    i = j = 0
    cur_inds = 0
    while i < max_len and j < max_len:
        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
            if len(inds_source_) == len(inds_target_):
                mapper[inds_source_, inds_target_] = 1
            else:
                ratio = 1 / len(inds_target_)
                for i_t in inds_target_: mapper[inds_source_, i_t] = ratio
            cur_inds += 1
            i += len(inds_source_)
            j += len(inds_target_)
        elif cur_inds < len(inds_source):
            mapper[i, j] = 1
            i += 1
            j += 1
        else:
            mapper[i, j] = 1  # after the last replaced word, keep mapping source token i to target token j
            i += 1
            j += 1
    return torch.from_numpy(mapper).float()

def get_replacement_mapper(prompts, tokenizer, max_len=77):
    x_seq = prompts[0]
    mappers = []
    for i in range(1, len(prompts)):
        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
        mappers.append(mapper)
    return torch.stack(mappers)

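The sketch below shows what the mapper encodes without loading a tokenizer. It is a simplified version that takes the token positions of the replaced words directly. `toy_replacement_mapper` and its inputs are hypothetical; in the notebook the positions come from `get_word_inds`. Rows are source tokens and columns are target tokens. Every unchanged source token `i` maps to target token `j`. When the token counts differ, a replaced word is spread evenly over the target word's tokens.

```python
import numpy as np


def toy_replacement_mapper(src_spans, tgt_spans, max_len):
    """Hypothetical simplified mapper: src_spans/tgt_spans are token-index lists per replaced word."""
    mapper = np.zeros((max_len, max_len))
    i = j = 0
    k = 0  # index of the next replaced word
    while i < max_len and j < max_len:
        if k < len(src_spans) and src_spans[k][0] == i:
            src, tgt = src_spans[k], tgt_spans[k]
            if len(src) == len(tgt):
                mapper[src, tgt] = 1  # pair tokens one-to-one
            else:
                for t in tgt:
                    mapper[src, t] = 1 / len(tgt)  # spread source attention evenly
            i += len(src)
            j += len(tgt)
            k += 1
        else:
            mapper[i, j] = 1  # unchanged token: copy its attention across
            i += 1
            j += 1
    return mapper


# A one-token word at position 3 replaced by a two-token word at positions 3 and 4.
m = toy_replacement_mapper([[3]], [[3, 4]], max_len=6)
print(m)
```

In the printout, row 3 splits 0.5/0.5 over columns 3 and 4. Source token 4 then feeds target token 5, and source token 5 is dropped because the target side has run out of positions.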
In [57]:
def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    split_text = text.split(" ")
    if type(word_place) is str: word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int: word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place: out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)

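`get_word_inds` aligns whitespace-separated words with sub-word tokens. It accumulates decoded token lengths until a word's length is reached, then adds 1 to each index to skip the start-of-text token. Below is a tokenizer-free sketch that takes the decoded token strings as input. `toy_word_inds` is hypothetical, and the token split used here is made up for illustration; real CLIP splits may differ.

```python
def toy_word_inds(text, word_place, token_strings):
    """Hypothetical tokenizer-free version of get_word_inds."""
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word == word_place]
    elif isinstance(word_place, int):
        word_place = [word_place]
    out = []
    if not word_place:
        return out
    cur_len, ptr = 0, 0  # characters consumed in the current word, current word index
    for i, tok in enumerate(token_strings):
        cur_len += len(tok)
        if ptr in word_place:
            out.append(i + 1)  # +1 skips the start-of-text token
        if cur_len >= len(split_text[ptr]):
            ptr += 1
            cur_len = 0
    return out


print(toy_word_inds("a squirrel eating", "squirrel", ["a", "squir", "rel", "eating"]))  # [2, 3]
```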
In [58]:
class AttentionReplace(AttentionControlEdit):
    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend=None, tokenizer=None, device=None):
        super().__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device)
        # (n_edits, 77, 77) source->target token mapping, one per edited prompt
        self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)

    def replace_cross_attention(self, attn_base, att_replace):
        # Re-index the source prompt's attention (heads, pixels, words) onto each edited prompt's tokens
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)

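`replace_cross_attention` re-indexes the source prompt's cross-attention along the token axis. For each edited prompt `b`, target token `n` receives `sum_w attn_base[h, p, w] * mapper[b, w, n]`. The small NumPy check below exercises the `'hpw,bwn->bhpn'` contraction. NumPy's `einsum` uses the same subscript semantics as `torch.einsum`. The attention values are random placeholders and the mapper is hand-built for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
heads, pixels, words = 2, 3, 4
attn_base = rng.random((heads, pixels, words))  # source-prompt cross-attention (placeholder)

# One edited prompt: source token 1 is split over target tokens 1 and 2,
# source token 2 moves to target token 3, source token 3 is dropped.
mapper = np.zeros((1, words, words))
mapper[0, 0, 0] = 1.0
mapper[0, 1, 1] = mapper[0, 1, 2] = 0.5
mapper[0, 2, 3] = 1.0

out = np.einsum('hpw,bwn->bhpn', attn_base, mapper)
print(out.shape)  # (1, 2, 3, 4)
```

With an identity mapper, which is what aligned one-token replacements such as "squirrel" to "cat" should produce, the edited prompt simply reuses the source attention maps.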
In [59]:
device = 'cuda'
g_cpu = torch.Generator().manual_seed(2333)
prompts = ['A painting of a squirrel eating a burger',
           'A painting of a cat eating a burger']

NUM_DIFFUSION_STEPS = 20

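`AttentionReplace` only works for prompt pairs with the same number of words, as the `ValueError` in `get_replacement_mapper_` enforces. The quick pure-Python check below tests that condition for the prompts above and lists which word positions differ. `differing_word_positions` is a hypothetical helper that mirrors the `inds_replace` logic.

```python
def differing_word_positions(source, target):
    """Hypothetical helper mirroring the word-count check and inds_replace."""
    words_x, words_y = source.split(' '), target.split(' ')
    if len(words_x) != len(words_y):
        raise ValueError(f"{len(words_x)} vs {len(words_y)} words; AttentionReplace needs equal word counts")
    return [i for i, (wx, wy) in enumerate(zip(words_x, words_y)) if wx != wy]


prompts = ['A painting of a squirrel eating a burger',
           'A painting of a cat eating a burger']
print(differing_word_positions(*prompts))  # [4]
```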
In [60]:
pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [61]:
controller = AttentionReplace(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=0.4, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=pipe.device)

In [62]:
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=NUM_DIFFUSION_STEPS,
               controller=controller, generator=g_cpu)



  0%|          | 0/20 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; 15.75 GiB total capacity; 14.34 GiB already allocated; 326.44 MiB free; 14.47 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

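A back-of-the-envelope estimate shows that the failed 2.00 GiB allocation plausibly matches one self-attention score tensor at the UNet's highest resolution. The sketch assumes 8 heads in the 64x64 blocks (as in Stable Diffusion v1.x), a batch doubled by classifier-free guidance, float32 activations, and a VAE downsampling factor of 8. `attn_scores_gib` is a hypothetical helper.

```python
def attn_scores_gib(batch, heads, height, width, vae_scale=8, bytes_per_el=4):
    """Size in GiB of one (batch * heads, seq, seq) self-attention score tensor."""
    seq = (height // vae_scale) * (width // vae_scale)  # latent pixels act as tokens
    return batch * heads * seq * seq * bytes_per_el / 2**30


# 2 prompts x 2 (unconditional + conditional), 8 heads
print(attn_scores_gib(batch=2 * 2, heads=8, height=512, width=512))                  # 2.0
print(attn_scores_gib(batch=2 * 2, heads=8, height=512, width=512, bytes_per_el=2))  # 1.0
print(attn_scores_gib(batch=2 * 2, heads=8, height=256, width=256))                  # 0.125
```

Halving the image side divides this tensor by 16. The sequence length grows with the square of the side, and the score matrix grows with the square of the sequence length. Any attention maps the controller keeps between steps come on top of this.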
In [None]:
view_images([np.array(img) for img in outputs.images])

In [None]:
pipe.show_cross_attention(prompts, controller, res=16, from_where=("up", "down"), select=0)
pipe.show_cross_attention(prompts, controller, res=16, from_where=("up", "down"), select=1)