# Plot attention maps

In [None]:
from sdhelper import SD
from PIL import Image
import re
import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Create the SD helper with its default arguments.
sd = SD()

# Models whose name contains 'flux' (case-insensitive) are quantized with
# CPU offloading switched on — presumably to fit into memory; confirm in sdhelper.
model_name_lower = sd.model_name.lower()
if 'flux' in model_name_lower:
    sd.quantize(model_cpu_offload=True)

In [None]:
# Use the pipeline's `unet` when it has one, otherwise fall back to its `transformer`.
if hasattr(sd.pipeline, 'unet'):
    model = sd.pipeline.unet
else:
    model = sd.pipeline.transformer

# Print the qualified name of every submodule whose name contains 'attn'.
for module_name, _ in model.named_modules():
    if 'attn' not in module_name:
        continue
    print(module_name)

In [None]:
# Relative path of the image whose attention maps are plotted further down.
img_path = '../random_images_flux/0.jpg'
# Open the image with PIL. As the trailing expression of the cell, a notebook
# front-end is expected to render it as a preview.
Image.open(img_path)

In [None]:
# get attention blocks
# Maps the dotted path of each attention module (submodule name matching
# `attn<digit>`) to the module itself, for every extract position that exposes
# an `attentions` attribute.
attn_name_pattern = re.compile(r'^attn\d$')  # compiled once, reused in the inner loop
attentions = {}
for block_name in sd.available_extract_positions:
    # NOTE(review): eval is used to resolve the block path on `model`. The strings
    # come from sd.available_extract_positions; never pass untrusted input here.
    block = eval(f'model.{block_name}', {'model': model}, None)
    if not hasattr(block, 'attentions'):
        continue
    for j, attention in enumerate(block.attentions):
        for k, transformer_block in enumerate(attention.transformer_blocks):
            for module_name, module in transformer_block.named_modules():
                if not attn_name_pattern.search(module_name):
                    continue
                name = f'{block_name}.attentions[{j}].transformer_blocks[{k}].{module_name}'
                attentions[name] = module

# get attention q and k
# One extract position per projection (`to_q`, `to_k`) of every collected module.
extract_positions = [f'{attn_path}.{proj}' for attn_path in attentions for proj in ('to_q', 'to_k')]
# NOTE(review): 50 is img2repr's third positional argument — presumably the
# diffusion step/timestep; confirm against sdhelper.
reprs = sd.img2repr(img_path, extract_positions, 50)

# For every collected attention module, recompute the attention probabilities
# from the extracted q/k representations and plot a few slices as n x n images.
for name, attn in attentions.items():
    # from: diffusers.models.attention_processor.AttnProcessor.__call__
    query = attn.head_to_batch_dim(reprs[name + '.to_q'])
    key = attn.head_to_batch_dim(reprs[name + '.to_k'])
    attention_probs = attn.get_attention_scores(query, key)
    assert attention_probs.ndim == 3

    # Slices along dim 1 are reshaped to (n, n) below, so dim 1 is assumed to
    # hold a square number of positions.
    n = int(attention_probs.shape[1]**.5)
    print(name, tuple(attention_probs.shape), n)

    # Indices into dim 0 to visualise. They are clamped to the last valid index
    # so that modules with fewer than 8 entries in dim 0 do not raise IndexError.
    last_i0 = attention_probs.shape[0] - 1
    i0s = [min(i0, last_i0) for i0 in (0, 0, 1, 3, 7)]
    if attention_probs.shape[-1] == 77:
        # text attention: for fixed indices into the last dim, show the values
        # over all dim-1 positions
        i2s = [0, 10, 0, 42, 76]
        attn_probs = [(f'[{i0}, :, {i2}]', attention_probs[i0, :, i2]) for i0, i2 in zip(i0s, i2s)]
    else:
        # image attention: the same flat positions are used once as dim-1 index
        # (first row of plots) and once as last-dim index (second row of plots)
        positions = [0, n-1, 0, n//2*n+n//2, n//3*n+n//3]
        attn_probs = [(f'[{i0}, {i1}, :]', attention_probs[i0, i1, :]) for i0, i1 in zip(i0s, positions)]
        attn_probs += [(f'[{i0}, :, {i2}]', attention_probs[i0, :, i2]) for i0, i2 in zip(i0s, positions)]

    # Five plots per row; one row for text attention, two for image attention.
    n_cols = 5
    n_rows = len(attn_probs) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*3, 3*n_rows))
    for ax, (title, attn_prob) in zip(axes.flatten(), attn_probs):
        att_map = attn_prob.reshape(n, n).detach().cpu().numpy()  # works well
        ax.imshow(att_map, cmap='viridis')
        ax.set_title(title)
        ax.axis('off')

    fig.suptitle(name)
    fig.tight_layout()
    plt.show()
    # Release the figure explicitly so figures do not pile up across iterations
    # on backends where show() does not close them.
    plt.close(fig)