In [None]:
# NOTE: utils.py must be importable (e.g. placed next to this notebook); it provides the helpers imported below.
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import cv2
from PIL import Image

import torch
import torch.nn.functional as F
from utils import (
    load_image,
    aggregate_llm_attention, aggregate_vit_attention,
    heterogenous_stack,
    show_mask_on_image
)

In [None]:
import os
import warnings
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
import torch


def load_pretrained_model(model_path, model_base=None, model_name="chartgemma", load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
    """
    Load a ChartGemma model (PaliGemma architecture) together with its processor.

    Args:
        model_path: Path or HuggingFace model ID (e.g., "ahmed-masry/chartgemma").
            When ``model_base`` is given, this is treated as the adapter to load
            on top of the base model.
        model_base: Optional base model path; if set, the base model is loaded
            first and the weights at ``model_path`` are applied via PEFT and merged.
        model_name: Model name identifier (accepted for interface compatibility;
            not used inside this function).
        load_8bit: Load the model with 8-bit bitsandbytes quantization.
        load_4bit: Load the model with 4-bit (nf4, double-quant) bitsandbytes
            quantization. Ignored when ``load_8bit`` is True.
        device_map: Device mapping strategy forwarded to ``from_pretrained``.
        device: Target device; any value other than "cuda" pins the whole model
            to that device and overrides ``device_map``.
        use_flash_attn: If True request ``flash_attention_2``; otherwise ``eager``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        processor, model, context_len
    """
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        # Pin every module to the requested device instead of auto-sharding.
        kwargs['device_map'] = {"": device}

    # Configure quantization. Only `quantization_config` is passed: the bare
    # `load_in_8bit` / `load_in_4bit` kwargs are deprecated, and supplying one of
    # them together with `quantization_config` is rejected by transformers.
    if load_8bit:
        kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_4bit:
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    # Attention backend. Previously 'eager' was assigned unconditionally after
    # the flash-attention branch, so `use_flash_attn` never had any effect.
    # NOTE(review): callers that need `output_attentions=True` should keep the
    # default (eager) - confirm against the transformers docs for other backends.
    kwargs['attn_implementation'] = 'flash_attention_2' if use_flash_attn else 'eager'

    print(f'Loading ChartGemma model from {model_path}...')

    # Load processor (combines tokenizer and image processor for PaliGemma)
    processor = AutoProcessor.from_pretrained(model_path)

    # Load model
    if model_base is not None:
        # Adapter checkpoint: load the base model, apply the adapter, then merge.
        # peft is imported lazily so it is only required on this path.
        from peft import PeftModel
        print(f'Loading base model from {model_base}...')
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_base,
            low_cpu_mem_usage=True,
            **kwargs
        )
        print(f"Loading adapter weights from {model_path}")
        model = PeftModel.from_pretrained(model, model_path)
        print(f"Merging weights...")
        model = model.merge_and_unload()
        print('Model loaded and merged successfully')
    else:
        # Load full model directly
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            **kwargs
        )
        print('Model loaded successfully')

    # Context length: prefer the top-level config, then the nested text config,
    # and finally fall back to a fixed default.
    if hasattr(model.config, "max_position_embeddings"):
        context_len = model.config.max_position_embeddings
    elif hasattr(model.config, "text_config") and hasattr(model.config.text_config, "max_position_embeddings"):
        context_len = model.config.text_config.max_position_embeddings
    else:
        context_len = 2048  # Default fallback

    return processor, model, context_len




In [None]:
import torch
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from questions import mini_VLAT_questions
import os

# HuggingFace model ID of the checkpoint whose attention is visualized
model_path = "ahmed-masry/chartgemma"

# Quantization switches; with both off the loader takes its float16 branch
load_8bit = False
load_4bit = False

# Run on the GPU when torch can see one, otherwise on the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load processor + model (no separate base model / adapter)
processor, model, context_len = load_pretrained_model(
    model_path,
    model_base=None,
    model_name="chartgemma",
    load_8bit=load_8bit,
    load_4bit=load_4bit,
    device=device,
)

# The vision encoder is exposed under one of two attribute names
if hasattr(model, 'vision_tower'):
    vision_model = model.vision_tower
else:
    vision_model = model.vision_model

# Number of image tokens (patches): read it from the vision config when it is
# published there, otherwise derive it from the image and patch sizes
if not hasattr(vision_model.config, 'num_image_tokens'):
    image_size = vision_model.config.image_size
    patch_size = vision_model.config.patch_size
    num_image_tokens = (image_size // patch_size) ** 2
else:
    num_image_tokens = vision_model.config.num_image_tokens

# Side length of the patch grid used when reshaping attention into a 2-D map
grid_size = int(np.sqrt(num_image_tokens))

# Output directory for the saved figures (no error if it already exists)
os.makedirs("attention_outputs", exist_ok=True)

# Process each question: generate an answer, then draw one image-attention
# heatmap per generated token and save the grid as a PNG.

# --- Loop-invariant settings (previously re-created on every iteration) ---
num_image_per_row = 8           # columns in the per-token heatmap grid
vis_overlayed_with_attn = True  # True: heatmap blended over the image; False: raw heatmap
select_layer = -1               # vision layer used when `all_prev_layers` is False
all_prev_layers = True          # True (with select_layer < 0): average all vision layers

# Image tokens are taken to be the first `num_image_tokens` sequence positions
vision_token_start = 0
vision_token_end = num_image_tokens

# Patch grid side length, from the vision encoder configuration
patch_size = vision_model.config.patch_size
image_size_model = vision_model.config.image_size
grid_size = image_size_model // patch_size

for idx, question_data in enumerate(mini_VLAT_questions):
    chart_type = question_data[0]
    prompt_text = question_data[1]
    image_path = question_data[2]

    print(f"\n{'='*80}")
    print(f"Processing {idx + 1}/{len(mini_VLAT_questions)}: {chart_type}")
    print(f"Question: {prompt_text}")
    print(f"Image: {image_path}")
    print('='*80)

    # Load image; a broken image skips this question instead of aborting the run
    try:
        image = load_image(image_path)
    except Exception as e:
        print(f"Error loading image {image_path}: {e}")
        continue

    # Prepare inputs
    inputs = processor(text=prompt_text, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate response with attention
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=512,
            use_cache=True,
            return_dict_in_generate=True,
            output_attentions=True,
        )

    # Token positions: the returned sequence is prompt followed by new tokens
    input_prompt_len = inputs["input_ids"].shape[1]
    total_sequence_len = outputs["sequences"].shape[1]
    output_token_start = input_prompt_len
    output_token_end = total_sequence_len
    output_token_len = total_sequence_len - input_prompt_len

    # Decode only the newly generated tokens so the printed response does not
    # repeat the prompt text
    response = processor.decode(
        outputs["sequences"][0][input_prompt_len:], skip_special_tokens=True
    ).strip()
    print(f"\nResponse: {response}\n")

    # Validate before touching outputs["attentions"], which is indexed below
    if output_token_len <= 0:
        print(f"Warning: No new tokens generated for question {idx + 1}")
        continue

    # Prompt attention from the first generation step: average over heads, drop
    # the last query row (it is covered by `aggregate_llm_attention` below),
    # zero column 0 and re-normalize rows 1+.
    # NOTE(review): zeroing column 0 looks like a BOS-token convention, but
    # position 0 is treated as an image token here (vision_token_start = 0) -
    # confirm this is intended.
    aggregated_prompt_attention = []
    for layer in outputs["attentions"][0]:
        head_avg = layer.squeeze(0).mean(dim=0)
        # float32 copy on the CPU; clone() guarantees a private buffer for the
        # in-place edits that follow
        cur = head_avg[:-1].float().cpu().clone()
        cur[1:, 0] = 0.
        cur[1:] = cur[1:] / cur[1:].sum(-1, keepdim=True)
        aggregated_prompt_attention.append(cur)
    aggregated_prompt_attention = torch.stack(aggregated_prompt_attention).mean(dim=0)

    # Build LLM attention matrix: one leading row, the prompt rows, then one
    # aggregated row per generation step
    llm_attn_matrix = heterogenous_stack(
        [torch.tensor([1])]
        + list(aggregated_prompt_attention)
        + list(map(aggregate_llm_attention, outputs["attentions"]))
    )

    # Vision attention: reuse cached attentions if the encoder exposes them,
    # otherwise run the vision encoder once on this image
    if hasattr(vision_model, 'image_attentions'):
        vision_attentions = vision_model.image_attentions
    else:
        with torch.no_grad():
            vision_outputs = vision_model(
                pixel_values=inputs["pixel_values"],
                output_attentions=True
            )
        vision_attentions = vision_outputs.attentions

    # Aggregate vision attention over heads (and over layers when requested)
    if all_prev_layers and select_layer < 0:
        stacked_attentions = torch.stack([attn.mean(dim=1).squeeze(0) for attn in vision_attentions])
        vis_attn_matrix = stacked_attentions.mean(dim=0)
    else:
        layer_idx = select_layer if select_layer >= 0 else len(vision_attentions) + select_layer
        vis_attn_matrix = vision_attentions[layer_idx].mean(dim=1).squeeze(0)
    # The model is loaded in half precision; do the weighting below in float32
    vis_attn_matrix = vis_attn_matrix.float()

    # Remove CLS token attention if present (first token)
    if vis_attn_matrix.shape[0] > num_image_tokens:
        vis_attn_matrix = vis_attn_matrix[1:, 1:]

    # Create the figure: one axis per generated token, `num_image_per_row` per row
    image_ratio = image.size[0] / image.size[1]
    num_rows = (output_token_len + num_image_per_row - 1) // num_image_per_row  # ceil division, always >= 1 here

    fig, axes = plt.subplots(
        num_rows, num_image_per_row,
        figsize=(10, (10 / num_image_per_row) * image_ratio * num_rows),
        dpi=150,
        squeeze=False,  # always a 2-D array of axes, even for a single row
    )
    plt.subplots_adjust(wspace=0.05, hspace=0.2)

    output_token_inds = list(range(output_token_start, output_token_end))

    for i, ax in enumerate(axes.flatten()):
        # Unused trailing cells of the grid stay blank
        if i >= output_token_len:
            ax.axis("off")
            continue

        target_token_ind = output_token_inds[i]
        attn_weights_over_vis_tokens = llm_attn_matrix[target_token_ind][vision_token_start:vision_token_end]

        if len(attn_weights_over_vis_tokens) == 0 or attn_weights_over_vis_tokens.sum() == 0:
            ax.axis("off")
            continue

        attn_weights_over_vis_tokens = attn_weights_over_vis_tokens / attn_weights_over_vis_tokens.sum()

        # Weighted sum of the per-patch vision attention rows, computed in one
        # shot instead of a Python loop over every patch
        weights = attn_weights_over_vis_tokens.to(
            device=vis_attn_matrix.device, dtype=vis_attn_matrix.dtype
        )
        attn_over_image = (weights.unsqueeze(-1) * vis_attn_matrix).sum(dim=0)
        attn_over_image = attn_over_image.reshape(grid_size, grid_size)
        attn_over_image = attn_over_image / attn_over_image.max()

        # Upsample the patch grid to the image resolution (height, width)
        attn_over_image = F.interpolate(
            attn_over_image.unsqueeze(0).unsqueeze(0),
            size=(image.size[1], image.size[0]),
            mode='nearest',
        ).squeeze()

        # Reverse the channel axis before handing the image to the overlay helper
        np_img = np.array(image)[:, :, ::-1]
        img_with_attn, heatmap = show_mask_on_image(np_img, attn_over_image.cpu().numpy())
        ax.imshow(heatmap if not vis_overlayed_with_attn else img_with_attn)

        # Title each panel with the decoded token it belongs to
        token_id = outputs["sequences"][0][target_token_ind]
        ax.set_title(
            processor.decode([token_id], skip_special_tokens=False).strip(),
            fontsize=7,
            pad=1
        )
        ax.axis("off")

    # Save visualization and release this figure explicitly so figures do not
    # accumulate across questions
    output_filename = f"attention_outputs/{chart_type}_{idx+1}.png"
    fig.tight_layout()
    fig.savefig(output_filename, bbox_inches='tight')
    plt.close(fig)

    print(f"Saved visualization to: {output_filename}")

# Closing summary: separator rule, number of questions in the list, output folder
print("\n" + "=" * 80)
print("Completed processing {} questions".format(len(mini_VLAT_questions)))
print("Visualizations saved in: attention_outputs/")
print("=" * 80)