# Visualize Attention on Images with a Heat Map
In this tutorial, we show step by step how to use `ICLTestbed` to visualize the attention of an autoregressive model.

Let's take LLaVa 1.5 as an example. LLaVa is an autoregressive multimodal model that takes arbitrary sequences of text and images as input and generates text responses. More details about LLaVa can be found here:

[paper](https://arxiv.org/abs/2304.08485) [hf model](https://huggingface.co/llava-hf/llava-1.5-7b-hf) [official-code](https://github.com/haotian-liu/LLaVA)

## Step 1. Prepare Model and Inputs
A model in ICLTestbed can roughly be regarded as a combination of a processor and a specific model. You can access the underlying processor and model via `model.processor` and `model.model`, respectively.

The model input should be a [conversation-like object](https://huggingface.co/docs/transformers/main/en/conversations), which is an easy-to-read format. The conversation used in this tutorial is:

<ul>
  <li>
    User: <span style="vertical-align: top;">What is in this image?</span>
    <img src="./images/idefix.jpg" alt="Idefix image" style="vertical-align: top;">
  </li>
  <li>
    Assistant: This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. Idefix is running on the ground.
  </li>
  <li>
    User: 
    <img src="./images/caesar.png" alt="Caesar image" style="vertical-align: top;">
    <span style="vertical-align: top;">And who is that?</span>
  </li>
  <li>
    Assistant:
  </li>
</ul>
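Each `{"type": "image"}` entry in the conversation is a placeholder that one of the input images fills. The sketch below is a minimal sanity check. It assumes images are matched to placeholders in order of appearance, as Hugging Face processors do, so the number of placeholders must equal the number of images. `count_image_placeholders` is a hypothetical helper and is not part of `testbed`.

```python
def count_image_placeholders(conversation):
    """Hypothetical helper: count {"type": "image"} entries across all turns."""
    count = 0
    for turn in conversation:
        # the final assistant turn may have no "content" key
        for item in turn.get("content", []):
            if item.get("type") == "image":
                count += 1
    return count


conversation = [
    {"role": "user", "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}]},
    {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. Idefix is running on the ground.",
            }
        ],
    },
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "And who is that?"}]},
    {"role": "assistant"},
]
image_files = ["idefix.jpg", "caesar.png"]  # stand-ins for the PIL images, in placeholder order

assert count_image_placeholders(conversation) == len(image_files)
```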



```python
import os

import PIL.Image
import torch
from testbed.models import LLaVa

# path to the llava-1.5-7b-hf weights; adjust to your setup
model_path = "/data/share/model_weight/llava/llava-1.5-7b-hf"
device = torch.device("cuda")

model = LLaVa(model_path, torch_dtype=torch.float16).to(device)

# each {"type": "image"} entry is filled by one of `images`, in order
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image"},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "This picture depicts Idefix, the dog of Obelix in Asterix and Obelix. Idefix is running on the ground.",
            }
        ],
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {
                "type": "text",
                "text": "And who is that?",
            },
        ],
    },
    {"role": "assistant"},
]

images = [
    PIL.Image.open(os.path.abspath(os.path.join("images", "idefix.jpg"))),
    PIL.Image.open(os.path.abspath(os.path.join("images", "caesar.png"))),
]
```


## Step 2. Set a Tracker
To visualize attention, we need the attention weights of every layer. These can be recomputed from the key and query states of each layer.

`testbed` provides several trackers for extracting intermediate variables during the forward pass. A tracker is an observer that records the inputs and outputs of a module.

For example, you can attach a `testbed.utils.tracker.ForwardTracker` to the model to collect the outputs of the key and query projections (`k_proj` and `q_proj`) during each forward pass. This is useful when a HF model does not support `output_attentions`.
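The first argument of `model.add_tracker` in the cells below is a regular expression over module names. The sketch below shows which names such patterns select. It assumes the tracker is attached to every module for which `re.search(pattern, name)` succeeds; the exact matching rule in `testbed` may differ. The module names are hypothetical, modeled on the patterns used in this tutorial.

```python
import re

# hypothetical module names, modeled on the patterns used in this tutorial
module_names = [
    "language_model.model.layers.0.self_attn",
    "language_model.model.layers.0.self_attn.k_proj",
    "language_model.model.layers.0.self_attn.q_proj",
    "language_model.model.layers.31.self_attn.k_proj",
    "language_model.model.layers.31.mlp.up_proj",
    "vision_tower.vision_model.encoder.layers.0.self_attn.k_proj",
]

k_pattern = r"language_model.model.layers.\d+.self_attn.k_proj$"
attn_pattern = r"language_model.model.layers.\d+.self_attn$"


def matching_modules(pattern, names):
    # assumption: a module is tracked if re.search finds the pattern in its name
    return [name for name in names if re.search(pattern, name)]


print(matching_modules(k_pattern, module_names))
# ['language_model.model.layers.0.self_attn.k_proj', 'language_model.model.layers.31.self_attn.k_proj']
print(matching_modules(attn_pattern, module_names))
# ['language_model.model.layers.0.self_attn']
```

`\d+` matches any layer index, and `$` keeps the attention pattern from also matching the projections nested inside `self_attn`. The unescaped `.` matches any character; escaping it as `\.` is stricter but selects the same modules here.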

```python
import copy

from testbed.utils.tracker import ForwardTracker

with ForwardTracker() as k_proj, ForwardTracker() as q_proj:
    # track the key/query projections of every decoder layer
    model.add_tracker(r"language_model.model.layers.\d+.self_attn.k_proj$", k_proj)
    model.add_tracker(r"language_model.model.layers.\d+.self_attn.q_proj$", q_proj)
    results = model.generate(
        images,
        conversation,
        max_new_tokens=20,
        return_generated_ids=True,
        return_inputs=True,
    )
    inputs, generated_text, generated_ids = (
        results["inputs"],
        results["outputs"][0],
        results["generated_ids"][0],
    )

    # one forward pass per generated token
    assert len(generated_ids) == len(k_proj.outputs)
    key_states = copy.deepcopy(k_proj.outputs)
    query_states = copy.deepcopy(q_proj.outputs)

    # manually split heads: [batch_size, seq_len, d_model] -> [batch_size, num_head, seq_len, head_dim]
    # the number of heads may vary across models; in LLaVa 1.5, keys and queries
    # have the same number of heads (no grouped-query attention)
    # note: these states are taken before the rotary position embedding is applied
    num_head = model.config.text_config.num_attention_heads
    for token_idx, (token_key, token_query) in enumerate(zip(key_states, query_states)):
        for layer_idx, (layer_key, layer_query) in enumerate(zip(token_key, token_query)):
            batch_size, seq_len, d_model = layer_key.shape
            head_dim = d_model // num_head
            key_states[token_idx][layer_idx] = layer_key.view(
                batch_size, seq_len, num_head, head_dim
            ).transpose(1, 2)
            query_states[token_idx][layer_idx] = layer_query.view(
                batch_size, seq_len, num_head, head_dim
            ).transpose(1, 2)

# btw, the correct answer is Caesar
generated_text
```

```
'That is Idefix, the dog of Obelix in Asterix and Obelix'
```
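The manual head split in the cell above only reshapes each projection output. The standalone sketch below uses random data and arbitrary sizes. It assumes keys and queries have the same number of heads, which holds for LLaVa 1.5. Models with grouped-query attention are different: there, `k_proj` outputs `num_key_value_heads * head_dim` features.

```python
import torch

batch_size, seq_len, num_head, head_dim = 1, 5, 4, 8
d_model = num_head * head_dim

# stand-in for a k_proj / q_proj output: [batch_size, seq_len, d_model]
proj_out = torch.randn(batch_size, seq_len, d_model)

# split the last dim into heads, then move the head axis in front of the sequence axis
per_head = proj_out.view(batch_size, seq_len, num_head, head_dim).transpose(1, 2)
print(per_head.shape)  # torch.Size([1, 4, 5, 8])

# head h holds features h * head_dim ... (h + 1) * head_dim - 1 of every token
h = 2
assert torch.equal(per_head[0, h], proj_out[0, :, h * head_dim : (h + 1) * head_dim])
```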

In some models, such as Idefics, a rotary position embedding is applied to the keys and queries after `k_proj` and `q_proj`. The same holds for LLaMA, the language model inside LLaVa 1.5. In this case, `ForwardTracker` cannot capture the final keys and queries. Therefore, `testbed` provides `testbed.utils.tracker.LocalsTracker`, which tracks the local variables of a module method.

If you use `LocalsTracker`, check the source code of the module to find the names of the local variables you want to extract.
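The sketch below illustrates how locals can be captured at all, using Python's profiling hook `sys.setprofile`. Its `"return"` event exposes a frame's final local variables. This is one plausible mechanism, not necessarily how `LocalsTracker` is implemented. `capture_locals` and `toy_attention` are hypothetical.

```python
import sys
from contextlib import contextmanager


@contextmanager
def capture_locals(func, names, records):
    """Hypothetical sketch: append the final values of `names` each time `func` returns."""
    target_code = func.__code__

    def profiler(frame, event, arg):
        # "return" fires just before a Python frame exits, so f_locals holds final values
        if event == "return" and frame.f_code is target_code:
            local_vars = frame.f_locals
            records.append({n: local_vars[n] for n in names if n in local_vars})

    previous = sys.getprofile()
    sys.setprofile(profiler)
    try:
        yield records
    finally:
        sys.setprofile(previous)


def toy_attention(x, cache):
    key_states = [v * 2 for v in x]  # stand-in for k_proj
    key_states = cache + key_states  # concatenated with the cache, like a KV cache
    query_states = [v * 3 for v in x]  # stand-in for q_proj
    return len(key_states)


records = []
with capture_locals(toy_attention, ["key_states", "query_states"], records):
    toy_attention([1, 2], cache=[])
    toy_attention([3], cache=[2, 4])

print(records)
# [{'key_states': [2, 4], 'query_states': [3, 6]}, {'key_states': [2, 4, 6], 'query_states': [9]}]
```

Because the locals are read when the function returns, a reassigned variable such as `key_states` shows its final value, after concatenation with the cache.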

```python
import copy

from testbed.utils.tracker import LocalsTracker

# record the final values of the local variables `key_states` and `query_states`
# in the `forward` method of every self-attention module
with LocalsTracker("forward", ["key_states", "query_states"]) as kq_tracker:
    model.add_tracker(r"language_model.model.layers.\d+.self_attn$", kq_tracker)

    results = model.generate(
        images,
        conversation,
        max_new_tokens=20,
        return_generated_ids=True,
        return_inputs=True,
    )
    inputs, generated_text, generated_ids = (
        results["inputs"],
        results["outputs"][0],
        results["generated_ids"][0],
    )

    # in forward, key states are concatenated with the KV cache,
    # so the key states of the last step already cover the whole sequence
    key_states = copy.deepcopy(kq_tracker.get("key_states")[-1])
    query_states = copy.deepcopy(kq_tracker.get("query_states"))

    # wrap in a list so key_states has the same [step][layer] nesting as query_states
    key_states = [key_states]

generated_text
```

```
'That is Idefix, the dog of Obelix in Asterix and Obelix'
```

```python
# shapes: [batch_size, num_head, seq_len, head_dim]
key_states[0][0].shape, query_states[0][0].shape
```

```
(torch.Size([1, 32, 1239, 128]), torch.Size([1, 32, 1220, 128]))
```
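These shapes follow from KV caching. The bookkeeping sketch below assumes a 1220-token prompt (including the expanded image tokens) and 20 generated tokens, as the shapes above suggest. The first forward pass handles the whole prompt. Each later pass adds a single token, whose key is appended to the cache. `simulate_decode` is a hypothetical helper.

```python
def simulate_decode(prompt_len, num_generated):
    """Hypothetical helper: query/key lengths per forward pass of a KV-cached decoder."""
    cache_len = 0
    query_lens, key_lens = [], []
    # the prompt is processed in one pass; every generated token except the last is fed back once
    new_tokens_per_step = [prompt_len] + [1] * (num_generated - 1)
    for new_tokens in new_tokens_per_step:
        cache_len += new_tokens  # new keys are appended to the cache
        query_lens.append(new_tokens)  # queries exist only for the new tokens
        key_lens.append(cache_len)  # keys cover the whole sequence so far
    return query_lens, key_lens


query_lens, key_lens = simulate_decode(prompt_len=1220, num_generated=20)
print(len(query_lens), query_lens[0], sum(query_lens), key_lens[-1])  # 20 1220 1239 1239
```

The first query block has 1220 rows, the last key block has 1239 rows, and concatenating all query blocks also gives 1239 rows.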

## Step 3. Calculate Attention and Visualize
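Now, let's calculate the attention weights, i.e., $\textrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)$, where $d_k$ is the head dimension and $M$ is the causal mask: $M_{ij} = 0$ if token $i$ may attend to token $j$ (i.e., $j \le i$) and $M_{ij} = -\infty$ otherwise. The mask is needed because the model is autoregressive: each token only sees itself and earlier tokens.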
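The toy sketch below makes the causal masking concrete on random tensors, with arbitrary sizes and no model involved. It builds the mask explicitly and cross-checks the resulting weights against PyTorch's `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`, which applies the same lower-triangular mask.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_head, seq_len, head_dim = 2, 6, 8
q = torch.randn(num_head, seq_len, head_dim)
k = torch.randn(num_head, seq_len, head_dim)
v = torch.randn(num_head, seq_len, head_dim)

scores = q @ k.transpose(-2, -1) / head_dim**0.5
# True above the diagonal: token i must not attend to later tokens j > i
future = torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1)
attn_weights = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

# each row is a probability distribution over the current and earlier positions
assert torch.allclose(attn_weights.sum(-1), torch.ones(num_head, seq_len))
assert torch.all(attn_weights[:, future] == 0)

# cross-check: the same weights reproduce PyTorch's causal attention output
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(attn_weights @ v, reference, atol=1e-5)
```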

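```python
max_new_tokens = len(query_states)  # number of forward passes
num_layer = len(query_states[0])  # the 0th forward
# the 0th forward, the 0th layer: [batch_size, num_head, prompt_len, head_dim]
batch_size, num_head, prompt_len, head_dim = query_states[0][0].shape
assert batch_size == 1, "batch_size should be 1"


def build_qk(outputs):
    # concatenate layers along dim 0 (possible because batch_size == 1),
    # then concatenate forward passes along the sequence dim
    # returns: [num_layer, num_head, seq_len, head_dim]
    return torch.cat(
        [torch.cat(token_states) for token_states in outputs],
        dim=2,
    ).to(dtype=torch.float32)


q = build_qk(query_states)
k = build_qk(key_states)
q_len, k_len = q.shape[2], k.shape[2]
scores = torch.matmul(q, k.transpose(-2, -1)) / head_dim**0.5
# causal mask: query i may only attend to keys at positions <= i
causal_mask = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).triu(
    diagonal=k_len - q_len + 1
)
attn_weights = torch.softmax(scores.masked_fill_(causal_mask, float("-inf")), dim=-1)
# the last token of generated_ids is never fed back into the model,
# so it has no query or key and is left out of seq_ids
seq_ids = torch.cat([inputs.input_ids.squeeze(), generated_ids[:-1]])
# we need to do some compute-intensive operations later
seq_ids, attn_weights = seq_ids.to(device), attn_weights.to(device)
seq_ids.shape, attn_weights.shape
```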

```
(torch.Size([1239]), torch.Size([32, 32, 1239, 1239]))
```
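A heat map over an image can be read off one row of `attn_weights`: pick a query token, keep the columns that belong to image placeholder tokens, and fold them back into the patch grid. The sketch below relies on three assumptions. First, LLaVa 1.5's vision encoder (CLIP ViT-L/14 at 336 px) yields 24 × 24 = 576 patch tokens per image. Second, each image occupies that many consecutive placeholder tokens in `seq_ids`, in row-major patch order. Third, averaging over layers and heads is an acceptable summary. `image_heatmaps` is hypothetical, and the tutorial's `AttnVisualizer` may aggregate differently.

```python
import torch


def image_heatmaps(attn_weights, seq_ids, image_token_id, query_pos, grid=24):
    """Hypothetical helper: attention from one query position to each image's patch grid.

    attn_weights: [num_layer, num_head, seq_len, seq_len]
    seq_ids:      [seq_len] token ids aligned with the attention matrix
    returns:      [num_images, grid, grid]
    """
    # average over layers and heads, then keep the row of the chosen query token
    row = attn_weights.mean(dim=(0, 1))[query_pos]
    # positions of the image placeholder tokens, in sequence order
    image_pos = (seq_ids == image_token_id).nonzero(as_tuple=True)[0]
    # assumption: each image occupies grid * grid consecutive placeholders, row-major
    return row[image_pos].view(-1, grid, grid)


# toy sequence: 99 plays the role of the image token id; two images of 2 x 2 patches each
seq_ids = torch.tensor([1, 2, 99, 99, 99, 99, 5, 99, 99, 99, 99, 7, 8])
attn = torch.zeros(3, 4, len(seq_ids), len(seq_ids))  # [num_layer, num_head, seq_len, seq_len]
attn[:, :, 12, 8] = 1.0  # the last token attends only to the 2nd patch of the 2nd image
heat = image_heatmaps(attn, seq_ids, image_token_id=99, query_pos=12, grid=2)
print(heat)
```

With the real tensors, `image_token_id` could be looked up with `model.processor.tokenizer.convert_tokens_to_ids("<image>")`, and `query_pos` set to the position of a generated token. To overlay the result on the picture, the grid can be upsampled to the image resolution, e.g. with `torch.nn.functional.interpolate`.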

```python
from IPython.display import display
from utils import AttnVisualizer

# interactive heat map; "<image>" is the image placeholder token in seq_ids
display(AttnVisualizer(attn_weights, seq_ids, model.processor.tokenizer, images, "<image>"))
```
