# Multi-Region Visual Cropping for Qwen2.5-VL (Improved)

This notebook extends the single-region ViCrop method to support multiple regions.

## Key Improvements over Naive Approach
1. **Peak Detection**: Uses `peak_local_max` instead of greedy masking
2. **Gaussian Suppression**: Smooth region separation instead of binary masks
3. **IoU-based Overlap Control**: Limits the overlap between selected regions to a maximum IoU (`max_overlap_iou`)
4. **Attention Threshold**: Filters out low-quality regions
5. **Quality Scoring**: Combined metric of attention density + sharpness

In [None]:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
import torch
import matplotlib.pyplot as plt
from qwen_vl_utils import process_vision_info
import numpy as np
from PIL import Image
from io import BytesIO
import base64

# Import improved multi-region utilities
from multi_region_utils import (
    bbox_from_att_multi_region,
    crop_multi_regions,
    adaptive_num_regions,
    filter_regions_by_quality,
    visualize_regions,
    is_relational_or_counting_question
)

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

## Load Model

In [None]:
model_path = 'Qwen/Qwen2.5-VL-3B-Instruct'

# Eager attention is requested because compute_relative_attention() runs the
# model with output_attentions=True; presumably the fused attention kernels do
# not return attention weights (see the transformers docs) - confirm if changed.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, attn_implementation="eager"
)
# nn.Module.to / .eval operate in place, so no rebinding is needed.
model.to(device)
model.eval()

# Left padding keeps the generation position at the end of every sequence.
processor = AutoProcessor.from_pretrained(
    model_path, trust_remote_code=True, padding_side='left', use_fast=True
)

print("Model loaded successfully!")

## Utility Functions

In [None]:
def encode_base64(image):
    """
    Serialize an image as PNG and return the bytes as a base64 text string.

    Args:
        image: object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        str: base64-encoded PNG data (no ``data:`` prefix).
    """
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode()

def prepare_qwen2_5_input(messages, processor):
    """
    Turn a chat-format message list into batched model inputs for Qwen2.5-VL.

    The chat template is rendered as text with the generation prompt appended,
    the images/videos referenced in ``messages`` are extracted with
    ``process_vision_info``, and everything is passed through ``processor``
    as a single-sample, padded PyTorch batch.
    """
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(messages)
    return processor(
        text=[prompt],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )

## Relative Attention Computation

In [None]:
def _vision_token_span(input_ids, processor):
    """
    Return ``(start, end)``, the half-open range of positions strictly between
    the first ``<|vision_start|>`` and ``<|vision_end|>`` tokens of sample 0.
    """
    ids = input_ids[0].tolist()
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_start|>')
    vision_end_token_id = processor.tokenizer.convert_tokens_to_ids('<|vision_end|>')
    return ids.index(vision_start_token_id) + 1, ids.index(vision_end_token_id)


def _last_token_image_attention(model, inputs, att_layer, start, end):
    """
    Run one forward pass and return the head-averaged attention from the last
    prompt token to positions ``[start, end)`` at layer ``att_layer``, as a
    float32 numpy vector.

    Only the sliced vector is returned, so the full per-layer attention tuple
    of this pass can be released before the next forward pass.
    """
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    att = outputs['attentions'][att_layer][0, :, -1, start:end].mean(dim=0)
    return att.to(torch.float32).detach().cpu().numpy()


def compute_relative_attention(image, question, model, processor, att_layer=22):
    """
    Compute relative attention for a given image and question.

    The relative attention normalizes question-specific attention by
    general description attention to highlight semantically relevant regions:
    the last-token -> image-token attention (averaged over heads at layer
    ``att_layer``) for the question prompt is divided element-wise by the same
    quantity for a generic "describe the image" prompt.

    Args:
        image: PIL image.
        question: question text.
        model: Qwen2.5-VL model; expected to be loaded with
            ``attn_implementation="eager"`` as in this notebook so that
            ``output_attentions=True`` returns weights.
        processor: matching Qwen2.5-VL processor.
        att_layer: index into the per-layer attention tuple.

    Returns:
        np.ndarray: 2-D map with shape ``image_grid_thw[0, 1:] / 2``.

    Raises:
        ValueError: if the image-token span of either prompt does not match
            the number of cells in the attention map.
    """
    image_str = encode_base64(image)
    image_url = f'data:image;base64,{image_str}'

    # Question-specific attention
    messages = [
        {"role": "user",
         "content": [
             {"type": "image", "image": image_url},
             {"type": "text", "text": f"{question} Answer the question using a single word or phrase."}
         ]}
    ]

    # General description for normalization
    general_messages = [
        {"role": "user",
         "content": [
             {"type": "image", "image": image_url},
             {"type": "text", "text": "Write a general description of the image. Answer the question using a single word or phrase."}
         ]}
    ]

    inputs = prepare_qwen2_5_input(messages, processor).to(model.device, torch.bfloat16)
    general_inputs = prepare_qwen2_5_input(general_messages, processor).to(model.device, torch.bfloat16)

    # Grid (h, w) in patches; each image token merges a 2x2 patch block.
    att_shape = (inputs['image_grid_thw'][0, 1:] / 2).cpu().numpy().astype(int).tolist()
    num_image_tokens = int(np.prod(att_shape))

    # Locate the image tokens in each prompt independently instead of reusing
    # the question prompt's indices for the general prompt.
    pos, pos_end = _vision_token_span(inputs['input_ids'], processor)
    gen_pos, gen_pos_end = _vision_token_span(general_inputs['input_ids'], processor)
    if pos_end - pos != num_image_tokens or gen_pos_end - gen_pos != num_image_tokens:
        raise ValueError(
            f"Expected {num_image_tokens} image tokens, found {pos_end - pos} "
            f"(question prompt) and {gen_pos_end - gen_pos} (general prompt)"
        )

    # One forward pass at a time keeps only one attention tuple alive.
    att = _last_token_image_attention(model, inputs, att_layer, pos, pos_end)
    general_att = _last_token_image_attention(model, general_inputs, att_layer, gen_pos, gen_pos_end)

    # Relative attention with epsilon to avoid division by zero
    att_map = att / (general_att + 1e-8)
    return att_map.reshape(att_shape)

## Multi-Region VQA Inference (Improved)

In [None]:
def multi_region_vqa(
    image,
    question,
    model,
    processor,
    num_regions=None,  # None = auto-detect based on question
    max_overlap_iou=0.3,
    min_attention_ratio=0.15,
    min_score_ratio=0.4,
    visualize=True,
    max_new_tokens=128,
):
    """
    Perform VQA with improved multi-region cropping.

    The model receives the original image followed by one crop per selected
    region, then the question.

    Args:
        image: PIL Image
        question: Question string
        model: Qwen2.5-VL model
        processor: Qwen2.5-VL processor
        num_regions: Number of regions (None = auto-detect)
        max_overlap_iou: Maximum IoU between regions
        min_attention_ratio: Minimum attention to consider a region
        min_score_ratio: Minimum score ratio for region filtering
        visualize: Whether to visualize results
        max_new_tokens: Generation budget for the answer (default 128, the
            previously hard-coded value)

    Returns:
        dict: keys 'answer', 'bboxes', 'crops', 'scores', 'attention_map',
        'num_regions_used'
    """
    # Auto-detect number of regions based on question type
    if num_regions is None:
        num_regions = adaptive_num_regions(question, default=1)
        is_rel, is_count, keywords = is_relational_or_counting_question(question)
        if is_rel or is_count:
            print(f"Detected {'counting' if is_count else 'relational'} question (keywords: {keywords})")
            print(f"Using {num_regions} regions")

    # Compute attention map
    att_map = compute_relative_attention(image, question, model, processor)

    # Get multiple bounding boxes with scores
    bboxes, scores = bbox_from_att_multi_region(
        att_map,
        image.size,
        num_regions=num_regions,
        max_overlap_iou=max_overlap_iou,
        min_attention_ratio=min_attention_ratio
    )

    # Filter by quality score
    bboxes, scores = filter_regions_by_quality(
        bboxes, scores,
        min_score_ratio=min_score_ratio
    )

    print(f"Selected {len(bboxes)} regions with scores: {[f'{s:.3f}' for s in scores]}")

    # Crop regions
    crops = crop_multi_regions(image, bboxes)

    if visualize:
        # The return value is not needed; plt.show() displays the open figure(s).
        visualize_regions(image, bboxes, att_map)
        plt.show()

    # Multi-image input: [original, crop1, crop2, ...] followed by the question.
    content = [
        {"type": "image", "image": f'data:image;base64,{encode_base64(img)}'}
        for img in [image, *crops]
    ]
    content.append({"type": "text", "text": f"{question} Answer the question using a single word or phrase."})

    messages = [{"role": "user", "content": content}]

    inputs = prepare_qwen2_5_input(messages, processor).to(model.device, torch.bfloat16)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Drop the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    answer = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return {
        'answer': answer,
        'bboxes': bboxes,
        'crops': crops,
        'scores': scores,
        'attention_map': att_map,
        'num_regions_used': len(bboxes)
    }

## Baseline VQA (for comparison)

In [None]:
def baseline_vqa(image, question, model, processor, max_new_tokens=128):
    """
    Perform VQA without any cropping (baseline).

    Args:
        image: PIL image passed to the model unmodified.
        question: question text.
        model: Qwen2.5-VL model.
        processor: matching Qwen2.5-VL processor.
        max_new_tokens: generation budget (default 128, the previously
            hard-coded value). Use the same value as the multi-region run
            for a fair comparison.

    Returns:
        str: decoded answer for the single prompt.
    """
    image_str = encode_base64(image)

    messages = [
        {"role": "user",
         "content": [
             {"type": "image", "image": f'data:image;base64,{image_str}'},
             {"type": "text", "text": f"{question} Answer the question using a single word or phrase."}
         ]}
    ]

    inputs = prepare_qwen2_5_input(messages, processor).to(model.device, torch.bfloat16)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Drop the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    answer = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    return answer

## Test Examples

In [None]:
# Load test image
image_path = './images/demo1.png'  # Replace with your image path
# Open inside a context manager so the file handle is closed once convert()
# has copied the pixel data into a new, independent image.
with Image.open(image_path) as img:
    image = img.convert('RGB')

# Display original image
plt.figure(figsize=(8, 8))
plt.imshow(image)
plt.title("Test Image")
plt.axis('off')
plt.show()

In [None]:
# Test 1: Relational question
question = "What is to the left of the person?"

# Header: separator, question, separator (one per line)
print("=" * 80, f"Question: {question}", "=" * 80, sep="\n")

# Baseline: the full image only
print("\n[Baseline - No Cropping]")
baseline_answer = baseline_vqa(image, question, model, processor)
print(f"Answer: {baseline_answer}")

# Multi-region, with the region count auto-detected from the question
print("\n[Multi-Region Cropping - Auto]")
result = multi_region_vqa(
    image, question, model, processor, visualize=True
)
print(f"Answer: {result['answer']}")

In [None]:
# Test 2: Counting question
question = "How many people are in the image?"

# Header: separator, question, separator (one per line)
print("=" * 80, f"Question: {question}", "=" * 80, sep="\n")

# Baseline: the full image only
print("\n[Baseline - No Cropping]")
baseline_answer = baseline_vqa(image, question, model, processor)
print(f"Answer: {baseline_answer}")

# Multi-region, with the region count auto-detected from the question
print("\n[Multi-Region Cropping - Auto]")
result = multi_region_vqa(
    image, question, model, processor, visualize=True
)
print(f"Answer: {result['answer']}")

In [None]:
# Test 3: Detail question (should use single region)
question = "What text is visible on the sign?"

# Header: separator, question, separator (one per line)
print("=" * 80, f"Question: {question}", "=" * 80, sep="\n")

# Baseline: the full image only
print("\n[Baseline - No Cropping]")
baseline_answer = baseline_vqa(image, question, model, processor)
print(f"Answer: {baseline_answer}")

# Multi-region with auto-detection (expected to pick 1 region here)
print("\n[Multi-Region Cropping - Auto]")
result = multi_region_vqa(
    image, question, model, processor, visualize=True
)
print(f"Answer: {result['answer']}")

## Parameter Tuning

In [None]:
# Sweep the IoU overlap threshold with a fixed budget of 3 regions
question = "What is between the two people?"

print("Testing different overlap thresholds...", "=" * 80, sep="\n")

for max_iou in (0.1, 0.3, 0.5):
    print(f"\nmax_overlap_iou = {max_iou}")
    result = multi_region_vqa(
        image,
        question,
        model,
        processor,
        num_regions=3,
        max_overlap_iou=max_iou,
        visualize=False,
    )
    print(f"  Regions: {len(result['bboxes'])}, Answer: {result['answer']}")

## Comparison Function

In [None]:
def compare_methods(image, questions, visualize=False):
    """
    Run baseline and multi-region VQA side by side for each question.

    Uses the notebook-level ``model`` and ``processor`` globals and prints
    both answers per question.

    Args:
        image: PIL image shared by all questions.
        questions: iterable of question strings.
        visualize: forwarded to ``multi_region_vqa``.

    Returns:
        list[dict]: one record per question with keys 'question', 'baseline',
        'multi_region' and 'num_regions'.
    """
    records = []

    for q in questions:
        print(f"\nQuestion: {q}")
        print("-" * 60)

        base_answer = baseline_vqa(image, q, model, processor)
        multi_out = multi_region_vqa(image, q, model, processor, visualize=visualize)
        multi_answer = multi_out['answer']
        regions_used = multi_out['num_regions_used']

        print(f"Baseline:     {base_answer}")
        print(f"Multi-Region: {multi_answer} ({regions_used} regions)")

        records.append(dict(
            question=q,
            baseline=base_answer,
            multi_region=multi_answer,
            num_regions=regions_used,
        ))

    return records

In [None]:
# Run comparison on multiple question types. The groups mirror the question
# types multi_region_vqa auto-detects (relational / counting / detail).
test_questions = (
    [  # Relational
        "What is to the left of the person?",
        "What is to the right of the person?",
        "What is between the objects?",
    ]
    + [  # Counting
        "How many people are in the image?",
        "How many objects are on the table?",
    ]
    + [  # Detail (single region)
        "What color is the shirt?",
        "What text is shown?",
    ]
)

results = compare_methods(image, test_questions, visualize=False)

## Summary

The improved multi-region implementation addresses the following issues:

1. **Peak Detection**: Uses `skimage.feature.peak_local_max` to find true local maxima in the attention map, rather than greedily masking regions.

2. **Gaussian Suppression**: After selecting a region, applies smooth Gaussian suppression instead of binary masking, allowing nearby regions to still contribute.

3. **IoU Control**: Implements overlap checking using Intersection over Union (IoU); the original code defined `min_overlap` but never used it.

4. **Quality Thresholding**: Filters out low-quality regions based on both an attention threshold (`min_attention_ratio`) and a score ratio (`min_score_ratio`).

5. **Adaptive Regions**: Automatically detects question type (counting/relational/detail) and adjusts the number of regions accordingly.