In [2]:
import pdb;
import os
import re
import time
import torch
import PIL.Image as Image

# Switch off the memory-efficient and flash scaled-dot-product-attention kernels
# process-wide before any model is loaded.
# NOTE(review): presumably done so attention runs through a path that can expose
# attention weights — confirm this is still needed.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)

In [None]:
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor

# 0. Common setups
# Pixel budget handed to the processor (and repeated per image in the chat
# message further down) as `min_pixels` / `max_pixels`.
min_pixels = 256*28*28
max_pixels = 1344*28*28

# Checkpoint location. Set QWEN2_VL_MODEL_PATH to run this notebook on another
# machine; when the variable is unset the original local path is used.
model_path = os.environ.get(
    "QWEN2_VL_MODEL_PATH",
    "/data/data1/syc/intern/wanshan/models/Qwen2-VL-2B-Instruct",
)
processor = Qwen2VLProcessor.from_pretrained(
    model_path,
    min_pixels=min_pixels,
    max_pixels=max_pixels,
)

In [None]:
from transformers import Qwen2VLConfig
# Load the checkpoint's configuration and print the attention backend it reports.
config = Qwen2VLConfig.from_pretrained(model_path)
print(config._attn_implementation) # only the eager attention implementation can return attention weights

In [None]:
lm_qwen_layer = 28 # LLM Decoder layers
device = 'cuda:3'

# Request the eager attention backend explicitly. The attention maps plotted
# later in this notebook come from `output_attentions=True`, and only the eager
# implementation materialises the attention weights; relying on the library's
# automatic backend choice may yield a fallback warning or missing attentions.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
    config=config,
    attn_implementation="eager",
)

In [None]:
# Image fed to the model. `vis_dir` is defined here but is not referenced by
# any of the cells below.
img_url = 'examples/test.jpg'
vis_dir = 'examples'

# Single-turn conversation: a text instruction followed by one image entry
# carrying its own pixel budget.
messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Please describe this image."},
        {
            "type": "image",
            "image": img_url,
            "min_pixels": min_pixels,
            "max_pixels": max_pixels,
        },
    ],
}]

# Render the conversation to a prompt string (not tokenised yet), ending with
# the generation prompt.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)

# Extract the vision inputs referenced by the messages, then build the model
# inputs as PyTorch tensors and move them to the model's device in one go.
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(device)

In [None]:
# Show the shape of every tensor the processor produced.
for key, value in inputs.items():
    print(key, value.shape)

In [None]:
# Locate the image span inside the prompt: the positions of the vision start
# and vision end marker tokens in the (single) input sequence.
VISION_START_ID = 151652   # <|vision_start|> in Qwen2VL vocabulary
VISION_END_ID = 151653     # <|vision_end|> in Qwen2VL vocabulary

# assume here is 1 x L, so only the first row is inspected
prompt_ids = inputs['input_ids'][0]

# Vectorised search: one comparison per marker instead of a Python-level loop
# that compares (and synchronises with the device on) every single token.
start_positions = (prompt_ids == VISION_START_ID).nonzero(as_tuple=True)[0].tolist()
end_positions = (prompt_ids == VISION_END_ID).nonzero(as_tuple=True)[0].tolist()

for pos in start_positions:
    print(f'<vision_start> at {pos}')
for pos in end_positions:
    print(f'<vision_end> at {pos}')

if not start_positions or not end_positions:
    # Keep the historical 0/0 fallback, but make it visible instead of silent.
    print('warning: vision start/end token not found; vision_idx falls back to 0')

# When a marker occurs several times the last occurrence wins, matching the
# behaviour of the original overwrite-in-a-loop implementation.
vision_idx = {
    'start': start_positions[-1] if start_positions else 0,
    'end': end_positions[-1] if end_positions else 0,
}

vision_idx

In [None]:
# Greedy decoding of up to 128 new tokens. The call returns a dict-like output
# (despite the variable name) that also carries per-step attentions and scores.
generated_ids = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=128,
    use_cache=True,
    output_attentions=True,
    output_scores=True,
    return_dict_in_generate=True,
)

# Strip the prompt tokens from each returned sequence so only newly generated
# tokens are decoded.
generated_ids_trimmed = [
    full_row[len(prompt_row):]
    for prompt_row, full_row in zip(inputs.input_ids, generated_ids['sequences'])
]
output_text = processor.batch_decode(generated_ids_trimmed)[0]
print(output_text)

# List the fields present in the generation output together with their types.
for key, value in generated_ids.items():
    print(key, type(value))

In [None]:
# Token sequences returned by generate; display their shape.
output_seqs = generated_ids.sequences
output_seqs.shape

In [None]:
# Attention outputs collected during generation; print how many entries there are.
attention_outputs = generated_ids.attentions
print(len(attention_outputs))

In [None]:
# Scores collected during generation; print the length of the first entry.
scores = generated_ids.scores
print(len(scores[0]))

In [None]:
# Field names available on the generation output, one per line.
for key in generated_ids:
    print(key)

In [19]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as Colormap
from matplotlib.colors import LogNorm

# refer to fastv paper
def visualize_attention(multihead_attention,output_path="atten_map_1.png",title="Layer 5", vision_idx=None,
                        pool_stride=20, top_k=10):
    """Plot a head-averaged, average-pooled attention map and save it to disk.

    Args:
        multihead_attention: attention tensor laid out as
            (1, num_heads, n_tokens, n_tokens); only the first batch element
            is used.
        output_path: file the figure is written to. Missing parent
            directories are created.
        title: title drawn above the heat-map.
        vision_idx: optional dict with integer 'start' and 'end' token
            positions. When given, dashed red guide lines are drawn at the
            corresponding pooled coordinates; when None no lines are drawn.
        pool_stride: kernel size and stride of the average pooling applied to
            the attention map before plotting.
        top_k: maximum number of highest-scoring pooled columns reported per
            pooled row (clamped to the number of columns available).

    Returns:
        A tuple ``(top_attentions, pooled_attention)`` where
        ``top_attentions`` holds, for every pooled row, a list of
        ``(column_index, value)`` pairs sorted by decreasing value, and
        ``pooled_attention`` is the pooled 2-D attention tensor.
    """
    import os

    print(multihead_attention.shape)
    # Average over the heads, keep the first batch element, and promote to
    # float32 (the input may be a reduced-precision dtype).
    averaged_attention = torch.mean(multihead_attention, dim=1)[0].float()  # (n_tokens, n_tokens)

    # avg_pool2d needs a 4-D tensor (batch_size, channels, height, width), so
    # add two leading singleton dimensions and strip them again afterwards.
    averaged_attention = torch.nn.functional.avg_pool2d(
        averaged_attention.unsqueeze(0).unsqueeze(0),
        pool_stride,
        stride=pool_stride,
    ).squeeze(0).squeeze(0)

    cmap = plt.get_cmap("viridis")
    fig = plt.figure(figsize=(5, 5), dpi=400)

    # Log normalization of the colour scale
    log_norm = LogNorm(vmin=0.0007, vmax=0.1)

    # Hand seaborn a plain CPU array so the plot does not depend on which
    # device (or autograd state) the caller's tensor happens to have.
    ax = sns.heatmap(
        averaged_attention.detach().cpu().numpy(),
        cmap=cmap,
        norm=log_norm,
        cbar_kws={'label': 'Attention score'},
    )

    # The figure is rendered without axis ticks. (Earlier revisions built
    # tick labels only to erase them immediately afterwards.)
    ax.set_xticks([])
    ax.set_yticks([])

    # Mark the image-token span with dashed guide lines, if it is known.
    if vision_idx is not None:
        span_start = vision_idx['start'] // pool_stride + 1
        span_end = vision_idx['end'] // pool_stride
        for position in (span_start, span_end):
            ax.axvline(x=position, color='r', linestyle='--', linewidth=0.3)
            ax.axhline(y=position, color='r', linestyle='--', linewidth=0.3)

    ax.set_title(title)

    # savefig does not create directories, so make sure the target exists.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight')
    plt.show()
    # Release the figure explicitly; this function is typically called once
    # per layer and high-dpi figures are expensive to keep alive.
    plt.close(fig)

    # torch.topk raises when k exceeds the row length, so clamp it for short
    # sequences whose pooled map has fewer than `top_k` columns.
    k = min(top_k, averaged_attention.shape[-1])
    top_five_attentions = []
    for row in averaged_attention:
        top_values, top_indices = torch.topk(row, k)
        top_five_attentions.append(list(zip(top_indices.tolist(), top_values.tolist())))

    return top_five_attentions, averaged_attention

In [None]:
# Draw one pooled attention heat-map per tensor in `attention_outputs[0]`,
# numbering the figures and files from 1. After the loop, `top5_attention`
# and `average_attentions` hold the results for the last tensor only.
for i, attention in enumerate(attention_outputs[0]):
    layer_no = i + 1
    top5_attention, average_attentions = visualize_attention(
        attention.cpu(),
        output_path=f"./attn_maps/atten_map_{layer_no}.png",
        title=f"Layer {layer_no}",
        vision_idx=vision_idx,
    )