In [1]:
import torch
import torch.nn.functional as F
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
# from datasets import load_dataset

import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import logging
logger = logging.getLogger(__name__)

# Patch the decoding routine so that generate() keeps building the autograd graph,
# which the gradient-based relevancy cells further down rely on.
# NOTE(review): assumes `_sample` is the loop generate() uses for greedy decoding
# (do_sample=False below) and that generate() otherwise runs without grad — confirm
# for the installed transformers version. Re-running this cell wraps it a second time.
func_to_enable_grad = '_sample'
_sample_impl = getattr(LlavaForConditionalGeneration, func_to_enable_grad)
_sample_with_grad = torch.enable_grad(_sample_impl)
setattr(LlavaForConditionalGeneration, func_to_enable_grad, _sample_with_grad)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_id = "llava-hf/llava-1.5-7b-hf"

# 4-bit NF4 weights with nested (double) quantisation of the scales.
# `offload_state_dict` used to be passed here, but it is a `from_pretrained`
# argument, not a BitsAndBytesConfig field, so it was ignored; it is dropped.
# NOTE(review): compute dtype is bfloat16 while torch_dtype (non-quantised
# modules) and the inputs below are float16 — consider unifying.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantised models are placed on the GPU at load time via `device_map`
# (documented path) instead of being moved afterwards with `.to()`.
# Eager attention is used so attention layers can return their weights
# (presumably the fused SDPA/flash kernels do not).
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    quantization_config=quant_config,
    attn_implementation="eager",
    device_map={"": 0},
)

#--------------------------------------------------
# Buffer for the decoder self-attention weights (filled by forward_hook below).
model.enc_attn_weights = []

# Make the vision tower return its attention maps as well.
model.vision_tower.config.output_attentions = True
def forward_hook(module, inputs, output):
    """Forward hook for a decoder self-attention module; records its attention weights.

    ``output`` is the module's return tuple (attn_output, attn_weights[, past_key_value]);
    ``output[1]`` is the attention-weights tensor. A detached CPU copy of it is appended
    to ``model.enc_attn_weights``. ``output`` is always returned unchanged. If the module
    returned no weights, an error is logged instead.
    """
    attn_weights = output[1]
    if attn_weights is not None:
        attn_weights.requires_grad_(True)
        attn_weights.retain_grad()
        # NOTE(review): the stored copy is detached, so it never receives a .grad;
        # retain_grad() only affects the live tensor, which is not kept anywhere.
        model.enc_attn_weights.append(attn_weights.detach().cpu())
        return output

    logger.error(
        ("Attention weights were not returned for the encoder. "
        "To enable, set output_attentions=True in the forward pass of the model. ")
    )
    return output

# Hook every decoder layer's self-attention and keep the handles so the hooks can
# be removed after generation. `hooks_encoder` is unused but is deleted later.
hooks_encoder = []
hooks_pre_encoder = [
    decoder_layer.self_attn.register_forward_hook(forward_hook)
    for decoder_layer in model.language_model.layers
]

# Buffer for the vision-tower attention weights (filled by forward_hook_image_processor).
model.enc_attn_weights_vit = []

def forward_hook_image_processor(module, inputs, output):
    """Forward hook for a vision-tower self-attention module.

    Appends a detached CPU copy of ``output[1]`` (the attention weights) to
    ``model.enc_attn_weights_vit`` and returns ``output`` unchanged. Logs a warning
    instead when the module returned no attention weights.
    """
    vit_attn = output[1]
    if vit_attn is not None:
        vit_attn.requires_grad_(True)
        vit_attn.retain_grad()
        # NOTE(review): the stored copy is detached and will never carry a .grad.
        model.enc_attn_weights_vit.append(vit_attn.detach().cpu())
    else:
        logger.warning(
            ("Attention weights were not returned for the vision model. "
             "Relevancy maps will not be calculated for the vision model. "
             "To enable, set output_attentions=True in the forward pass of vision_tower. ")
        )
    return output

# Same instrumentation for every encoder layer of the vision tower.
hooks_pre_encoder_vit = [
    vit_layer.self_attn.register_forward_hook(forward_hook_image_processor)
    for vit_layer in model.vision_tower.vision_model.encoder.layers
]
#--------------------------------------------------

import os

processor = AutoProcessor.from_pretrained(model_id)

# Gemma-based checkpoints end a turn with <end_of_turn> instead of the tokenizer's EOS.
if model.language_model.config.model_type == "gemma":
    eos_token_id = processor.tokenizer('<end_of_turn>', add_special_tokens=False).input_ids[0]
else:
    eos_token_id = processor.tokenizer.eos_token_id

# Chat history for `apply_chat_template`; each "content" value is a list of
# dicts typed "text" or "image".
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is this animal?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# Image to analyse. Set LLAVA_IMAGE_FILE to use another image instead of editing
# a machine-specific absolute path.
# (Alternative: http://images.cocodataset.org/val2017/000000039769.jpg via requests.)
image_file = os.environ.get(
    "LLAVA_IMAGE_FILE",
    r"C:\Users\PRYth\OneDrive\Pictures\Screenshots\Screenshot 2025-09-24 103058.png",
)
# The context manager closes the file once the RGB copy has been made.
with Image.open(image_file) as img:
    raw_image = img.convert("RGB")
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

# Greedy decoding of a few tokens, returning attentions, hidden states and per-step
# scores. Hooks are removed in `finally` so they do not stay registered on the
# model if generation fails (e.g. CUDA OOM).
try:
    output = model.generate(
        **inputs,
        max_new_tokens=8,
        do_sample=False,
        use_cache=True,
        output_attentions=True,
        output_hidden_states=True,
        return_dict_in_generate=True,
        output_scores=True,
        eos_token_id=eos_token_id,
    )
finally:
    for handle in (*hooks_pre_encoder, *hooks_pre_encoder_vit):
        handle.remove()

print(processor.decode(output.sequences[0], skip_special_tokens=True))

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 3/3 [00:27<00:00,  9.24s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


USER:  
What is this animal? ASSISTANT: This animal is a dog, specifically a


In [None]:
# Full sequence (prompt ids followed by generated ids) of the single batch item.
full_sequence = output.sequences[0]
torch.save(full_sequence, "full_generation.pt")
tokens = full_sequence.tolist()

# Decode each id separately so every position can be inspected and indexed later.
decoded_tokens = [processor.decode([token_id]) for token_id in tokens]

for position, token_text in enumerate(decoded_tokens):
    print(f"{position}: {token_text!r}")

In [None]:
# Extend the conversation with the first few generated tokens as a partial
# assistant answer. Generated ids start right after the prompt, so slice from
# the prompt length instead of a hard-coded 596.
N_ASSISTANT_TOKENS = 5
prompt_len = inputs.input_ids.shape[-1]
assistant_text = processor.decode(tokens[prompt_len:prompt_len + N_ASSISTANT_TOKENS])
conv_step = conversation + [
    {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]}
]
conv_step

In [None]:
import gc

# Drop the quantised model and everything produced with it before the fp16 reload.
# NOTE(review): the relevancy and hidden-state cells below still read `inputs`
# and `output`; they raise NameError once this cell has run.
del model, processor, hooks_pre_encoder, hooks_encoder, prompt, inputs, output
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()

In [None]:
# Report allocated, then reserved, CUDA memory after the clean-up.
for cuda_memory_stat in (torch.cuda.memory_allocated, torch.cuda.memory_reserved):
    print(cuda_memory_stat())

In [None]:
# Reload the model unquantised in fp16, again with eager attention so the
# attention layers can return their weights.
model_id = "llava-hf/llava-1.5-7b-hf"
load_kwargs = dict(
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="eager",
)
model = LlavaForConditionalGeneration.from_pretrained(model_id, **load_kwargs).to(0)

#--------------------------------------------------
# Fresh buffer for the decoder self-attention weights of the reloaded model.
model.enc_attn_weights = []
# Make the vision tower return its attention maps too.
model.vision_tower.config.output_attentions = True
def forward_hook(module, inputs, output):
    """Record the attention weights returned by a decoder self-attention module.

    ``output`` is (attn_output, attn_weights[, past_key_value]). A detached CPU copy of
    ``output[1]`` is appended to ``model.enc_attn_weights``; ``output`` is returned
    unchanged. Logs an error when no weights were returned.
    NOTE(review): identical to the definition in the first cell (it reads the global
    ``model`` at call time), so this redefinition is redundant.
    """
    weights = output[1]
    if weights is None:
        logger.error(
            ("Attention weights were not returned for the encoder. "
            "To enable, set output_attentions=True in the forward pass of the model. ")
        )
    else:
        weights.requires_grad_(True)
        weights.retain_grad()
        # Detached copy: it will never receive a gradient.
        model.enc_attn_weights.append(weights.detach().cpu())
    return output

# Re-attach the decoder-layer hooks to the freshly loaded model.
hooks_encoder = []
hooks_pre_encoder = [
    decoder_layer.self_attn.register_forward_hook(forward_hook)
    for decoder_layer in model.language_model.layers
]

# Buffer for the vision-tower attention maps.
model.enc_attn_weights_vit = []

def forward_hook_image_processor(module, inputs, output):
    """Record the attention weights returned by a vision-tower self-attention module.

    Appends a detached CPU copy of ``output[1]`` to ``model.enc_attn_weights_vit``
    and returns ``output`` unchanged; logs a warning when no weights were returned.
    NOTE(review): duplicate of the definition in the first cell.
    """
    weights = output[1]
    if weights is None:
        logger.warning(
            ("Attention weights were not returned for the vision model. "
             "Relevancy maps will not be calculated for the vision model. "
             "To enable, set output_attentions=True in the forward pass of vision_tower. ")
        )
        return output

    weights.requires_grad_(True)
    weights.retain_grad()
    model.enc_attn_weights_vit.append(weights.detach().cpu())
    return output

# Hook every encoder layer of the reloaded vision tower.
hooks_pre_encoder_vit = [
    encoder_layer.self_attn.register_forward_hook(forward_hook_image_processor)
    for encoder_layer in model.vision_tower.vision_model.encoder.layers
]
#--------------------------------------------------

processor = AutoProcessor.from_pretrained(model_id)

# Prompt = user turn + partial assistant answer from `conv_step`, without a new
# generation prompt. The result is passed as `text=` to the processor below,
# which tokenises it together with the image.
prompt_appended = processor.apply_chat_template(conv_step, add_generation_prompt=False, return_tensors="pt")
inputs_appended = processor(images=raw_image, text=prompt_appended, return_tensors='pt').to(0, torch.float16)

# A single forward pass over the whole sequence. Only forward() arguments are
# passed: `return_dict_in_generate`, `output_scores` and `eos_token_id` are
# generate() options, not forward() parameters. Hooks are removed afterwards
# (even on failure) so later model calls do not keep appending weights.
print(len(model.enc_attn_weights))
try:
    output_appended = model(
        **inputs_appended,
        use_cache=True,
        output_attentions=True,
        output_hidden_states=True,
    )
finally:
    for handle in (*hooks_pre_encoder, *hooks_pre_encoder_vit):
        handle.remove()
print(len(model.enc_attn_weights))

In [None]:
print(output_appended.logits.shape)

# Most likely token predicted at the second-to-last position, for each batch item.
best = torch.topk(output_appended.logits[:, -2], k=1, dim=-1)
for row_ids in best.indices:
    print(processor.tokenizer.batch_decode(row_ids))

# Visualise ViT attention

In [None]:
# Shape of the 5th entry captured by the vision-tower hook (one entry is appended
# per encoder layer per forward) — presumably [batch, heads, tokens, tokens]; confirm.
model.enc_attn_weights_vit[4].shape

In [None]:
import os
# Tolerate a duplicate OpenMP runtime when plotting.
# NOTE(review): only needed on some setups.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

VIT_LAYER = 3

# Vision-tower attention of the first batch item, averaged over heads
# (assumes entries are [batch, heads, tokens, tokens]).
attn_avg = model.enc_attn_weights_vit[VIT_LAYER][0].mean(dim=0)
print(attn_avg.shape)

# Row 0 is the CLS query; drop column 0 so only CLS -> patch attention remains.
cls_attn = attn_avg[0, 1:]
# Patches form a square grid (24 x 24 = 336 / 14 per side for llava-1.5).
H = W = int(round(cls_attn.numel() ** 0.5))
heatmap = cls_attn.reshape(H, W).detach().cpu().numpy()

# NOTE(review): assumes the image processor resizes the short side and
# center-crops a square (llava-1.5 defaults; verify
# processor.image_processor.do_center_crop). The patch grid then covers only
# the central square of `raw_image`, so overlay on that crop instead of
# stretching the map over the full, possibly non-square, image.
side = min(raw_image.size)
left, top = (raw_image.width - side) // 2, (raw_image.height - side) // 2
shown_image = raw_image.crop((left, top, left + side, top + side))

heatmap_tensor = torch.tensor(heatmap[None, None], dtype=torch.float32)
heatmap_full = F.interpolate(heatmap_tensor, size=(side, side), mode='bilinear')[0, 0].numpy()

fig, ax = plt.subplots()
ax.imshow(shown_image)
overlay = ax.imshow(heatmap_full, cmap='jet', alpha=0.5)
ax.set_title(f"ViT layer {VIT_LAYER}: CLS attention to patches")
ax.axis('off')
fig.colorbar(overlay, ax=ax)
plt.show()

# Visualise LLM attention
## First forward pass (over the whole prompt), where the attention weights are square: input_token_len × input_token_len

In [None]:
import os

# Tolerate a duplicate OpenMP runtime (same workaround as in the ViT cell).
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

In [None]:
# One entry is appended per hooked decoder layer per forward pass, so the list
# length divided by the layer count is the number of recorded forward passes.
n_layers = len(model.language_model.layers)
n_entries = len(model.enc_attn_weights)
print(f"attn_weights has length: {n_entries}")
print(f"no. of forward passes = {n_entries}/{n_layers} = {n_entries // n_layers}")

# First forward pass, first layer: [batch, head, no. of tokens, no. of tokens]
print(model.enc_attn_weights[0].shape)

# Entry 51 belongs to the second forward pass (with 32 layers). It exists only
# when several passes were recorded (e.g. by generate()), not after a single forward.
if n_entries > 51:
    print(model.enc_attn_weights[51].shape)
else:
    print(f"only {n_entries // n_layers} forward pass(es) recorded; no entry 51")

In [None]:
# Attention from one query position to every image token, averaged over heads,
# for one decoder layer of the first forward pass.
LLM_LAYER = 11
QUERY_POS = 592  # query position; the model starts generating after "...ASSISTANT:"

attn_avg = model.enc_attn_weights[LLM_LAYER][0].mean(dim=0)
print(attn_avg.shape)

# Image-token positions from the ids of the forward over `inputs_appended`,
# instead of a hard-coded 5:581 slice. (The LLM sequence has no CLS token; the
# slice only selects the image tokens.) `.cpu()` because the stored weights are on CPU.
image_positions = (
    inputs_appended["input_ids"][0] == model.config.image_token_index
).nonzero().squeeze(-1).cpu()
image_attn = attn_avg[QUERY_POS, image_positions]
H = W = int(round(image_positions.numel() ** 0.5))  # 24 x 24 = 576 image tokens
heatmap = image_attn.reshape(H, W).detach().cpu().numpy()

# NOTE(review): assumes the processor center-crops a square (llava-1.5 defaults),
# so the image-token grid covers only the central square of `raw_image`.
side = min(raw_image.size)
left, top = (raw_image.width - side) // 2, (raw_image.height - side) // 2
shown_image = raw_image.crop((left, top, left + side, top + side))

heatmap_tensor = torch.tensor(heatmap[None, None], dtype=torch.float32)
heatmap_full = F.interpolate(heatmap_tensor, size=(side, side), mode='bilinear')[0, 0].numpy()

fig, ax = plt.subplots()
ax.imshow(shown_image)
overlay = ax.imshow(heatmap_full, cmap='jet', alpha=0.75)
ax.set_title(f"LLM layer {LLM_LAYER}: position {QUERY_POS} -> image tokens")
ax.axis('off')
fig.colorbar(overlay, ax=ax)
plt.show()

# Attention Rollout
## Computed on the first forward pass, where the attention weights are square: input_token_len × input_token_len

In [None]:
def attention_rollout_function(attn_maps):
    """Attention rollout (Abnar & Zuidema, 2020) through a stack of layers.

    Args:
        attn_maps: sequence of per-layer attention tensors, each shaped
            [batch, heads, seq_len, seq_len], ordered from first to last layer.

    Returns:
        List with one [batch, seq_len, seq_len] float32 tensor per layer. Entry ``i``
        is the rollout through layers ``0..i`` (row = query token, column = input
        token), with rows summing to 1. An empty input gives ``[]``.
    """
    if len(attn_maps) == 0:
        return []

    device = attn_maps[0].device
    batch_size, _, seq_len, _ = attn_maps[0].shape

    # Identity accounts for the residual connection around each attention block.
    identity = torch.eye(seq_len, device=device).unsqueeze(0).expand(batch_size, seq_len, seq_len)

    rollout = identity.clone()
    attn_rollout = []
    for attn_map in attn_maps:
        # Average over heads -> [batch, seq_len, seq_len], in float32 so fp16 maps
        # are not accumulated in reduced precision; add identity, renormalise rows.
        layer_attn = attn_map.float().mean(dim=1) + identity
        layer_attn = layer_attn / layer_attn.sum(dim=-1, keepdim=True)

        # A later layer acts on the output of the earlier ones, so its matrix
        # multiplies from the LEFT: R_l = A_l @ R_{l-1}. (The previous version
        # computed R_{l-1} @ A_l, i.e. the product in reverse layer order.)
        rollout = layer_attn @ rollout
        attn_rollout.append(rollout)

    return attn_rollout

In [None]:
# Entry 32 = first decoder layer of the second forward pass (32 entries are
# appended per pass, one per hooked layer).
# NOTE(review): raises IndexError when only one forward pass was recorded.
model.enc_attn_weights[32].shape

In [None]:
# Roll out the first forward pass: its entries are the first `n_layers` items
# (one per hooked decoder layer), rather than a hard-coded 32.
n_layers = len(model.language_model.layers)
rollout = attention_rollout_function(model.enc_attn_weights[:n_layers])

In [None]:
# Rollout through all layers (last entry), first batch item; [-1] instead of a
# hard-coded 31 so it does not depend on the layer count.
rollout[-1][0]

In [None]:
# Ten input positions with the highest rolled-out attention from query position 595.
topk_vals, topk_inds = torch.topk(rollout[-1][0, 595], 10)

In [None]:
# Positions of the most-attended input tokens in the rollout row.
topk_inds

In [None]:
# Show which decoded tokens received the most rolled-out attention.
for top_index in topk_inds.tolist():
    print(repr(decoded_tokens[top_index]))

# Visualise attention for subsequent forward passes

In [None]:
# Same check as above: first layer of the second forward pass.
# NOTE(review): raises IndexError when only one forward pass was recorded.
model.enc_attn_weights[32].shape

In [None]:
# Entry 243, first batch item, averaged over heads.
# NOTE(review): with 32 entries per pass this is pass 7 (0-based), layer 19 — it
# only exists when generate() recorded several passes.
heads_attn = model.enc_attn_weights[243][0]
attn_weight = heads_attn.mean(dim=0)
attn_weight.shape

In [None]:
# 25 highest-attention key positions for each query row (topk runs over the last dim).
topk_vals, topk_inds = torch.topk(attn_weight, 25)

In [None]:
# Key positions selected above.
topk_inds

In [None]:
# Corresponding head-averaged attention values.
topk_vals

In [None]:
# Binary patch map: 1 where an image token is among the top-k keys attended by
# query row 0 (topk_inds[0]), 0 elsewhere. Image-token positions are looked up
# in the prompt ids instead of assuming 576 tokens starting at index 5.
image_token_id = model.config.image_token_index
image_positions = [pos for pos, token_id in enumerate(tokens) if token_id == image_token_id]
patch_of = {pos: patch for patch, pos in enumerate(image_positions)}

heatmap = [0 for _ in range(len(image_positions))]
for pos in topk_inds[0].tolist():
    if pos in patch_of:
        heatmap[patch_of[pos]] = 1

H = W = int(round(len(image_positions) ** 0.5))  # 24 x 24 for llava-1.5
heatmap = torch.tensor(heatmap).reshape(H, W).numpy()

# NOTE(review): assumes the processor center-crops a square (llava-1.5 defaults),
# so the image-token grid covers only the central square of `raw_image`.
side = min(raw_image.size)
left, top = (raw_image.width - side) // 2, (raw_image.height - side) // 2
shown_image = raw_image.crop((left, top, left + side, top + side))

heatmap_tensor = torch.tensor(heatmap[None, None], dtype=torch.float32)
heatmap_full = F.interpolate(heatmap_tensor, size=(side, side), mode='nearest-exact')[0, 0].numpy()

fig, ax = plt.subplots()
ax.imshow(shown_image)
overlay = ax.imshow(heatmap_full, cmap='jet', alpha=0.75)
ax.set_title(f"Image patches among the top-{topk_inds.shape[-1]} attended tokens")
ax.axis('off')
fig.colorbar(overlay, ax=ax)
plt.show()

In [None]:
# Head-averaged attention of a later forward pass over the image tokens. Only
# query row 0 is used (presumably the single new token of that pass).
attn_avg = attn_weight
print(attn_avg.shape)

# Image-token positions from the prompt ids instead of a hard-coded 5:581 slice
# (there is no CLS token here; the slice only selects the image tokens).
image_token_id = model.config.image_token_index
image_positions = torch.tensor([pos for pos, token_id in enumerate(tokens) if token_id == image_token_id])
image_attn = attn_avg[0, image_positions]
H = W = int(round(image_positions.numel() ** 0.5))  # 24 x 24 for llava-1.5
heatmap = image_attn.reshape(H, W).detach().cpu().numpy()

# NOTE(review): assumes the processor center-crops a square (llava-1.5 defaults),
# so the image-token grid covers only the central square of `raw_image`.
side = min(raw_image.size)
left, top = (raw_image.width - side) // 2, (raw_image.height - side) // 2
shown_image = raw_image.crop((left, top, left + side, top + side))

heatmap_tensor = torch.tensor(heatmap[None, None], dtype=torch.float32)
heatmap_full = F.interpolate(heatmap_tensor, size=(side, side), mode='bilinear')[0, 0].numpy()

fig, ax = plt.subplots()
ax.imshow(shown_image)
overlay = ax.imshow(heatmap_full, cmap='jet', alpha=0.75)
ax.set_title("Incremental pass: attention to image tokens")
ax.axis('off')
fig.colorbar(overlay, ax=ax)
plt.show()

# Attention relevancy testing

In [None]:
# Gradient of each generated token's score with respect to the recorded attention
# weights (currently first token / first entry only, while debugging).
# NOTE(review): needs `inputs` and `output` from the generate() cell; the
# clean-up cell deletes both.
device = "cuda"
input_ids = inputs.input_ids
# Generated ids are the part of the sequence after the prompt.
output_ids = output.sequences.reshape(-1)[input_ids.shape[-1]:].tolist()
torch.cuda.empty_cache()

for target_index in tqdm(range(len(output.scores)), desc="Building relevancy maps"):
    token_logits = output.scores[target_index]
    token_id = torch.tensor(output_ids[target_index], device=device)

    # Seed backward with a one-hot vector selecting the generated token's score.
    token_id_one_hot = F.one_hot(token_id, num_classes=token_logits.size(-1)).float().view(1, -1)

    model.zero_grad()
    token_logits.backward(gradient=token_id_one_hot)

    for i, blk in enumerate(model.enc_attn_weights):
        # The hooks store `.detach().cpu()` copies, which never receive gradients;
        # fail with an explanation instead of "'NoneType' object has no attribute 'float'".
        if blk.grad is None:
            raise RuntimeError(
                f"No gradient for attention-weight entry {i}: the forward hooks store "
                "detached CPU copies, so gradients never reach them. Keep references "
                "to the live attention tensors to compute relevancy."
            )
        grad = blk.grad.float().detach()
        print(grad)
        break

    break

# Hidden states/concepts

In [None]:
print(f"No. of forward passes: {len(output.hidden_states)}")
print(f"No. of hidden states in each forward pass: {len(output.hidden_states[0])}")
print(f"Shape of 1 hidden state in first forward pass: {output.hidden_states[0][24].shape}")

# "Logit lens": project an intermediate hidden state through the LM head.
STEP, LAYER = 3, 17
hidden = output.hidden_states[STEP][LAYER]
with torch.no_grad():  # inspection only; do not build an autograd graph
    # The LM head normally sees the decoder's final-norm output, so apply that norm
    # to intermediate states (the last entry is presumably already normalised).
    if LAYER < len(output.hidden_states[STEP]) - 1:
        hidden = model.language_model.norm(hidden)
    logits = model.lm_head(hidden)

# Top-10 candidate tokens at each position
topk = torch.topk(logits, k=10, dim=-1)
for ids in topk.indices:
    print(processor.tokenizer.batch_decode(ids))

# Finetuning