In [4]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T

from transformers import AutoProcessor, AutoModelForSeq2SeqLM, AutoTokenizer
from captum.attr import IntegratedGradients

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model name and load components
model_name = "OpenGVLab/InternVL2-1B"  # Model repository URL on Hugging Face
# The following components assume that InternVL2-1B supports a processor (to jointly handle images and text)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True).to(device)
model.eval()

ValueError: Unrecognized configuration class <class 'transformers_modules.OpenGVLab.InternVL2-1B.0d75ccd166b1d0b79446ae6c5d1a4a667f1e6187.configuration_internvl_chat.InternVLChatConfig'> for this kind of AutoModel: AutoModelForSeq2SeqLM.
Model type should be one of BartConfig, BigBirdPegasusConfig, BlenderbotConfig, BlenderbotSmallConfig, EncoderDecoderConfig, FSMTConfig, GPTSanJapaneseConfig, LEDConfig, LongT5Config, M2M100Config, MarianConfig, MBartConfig, MT5Config, MvpConfig, NllbMoeConfig, PegasusConfig, PegasusXConfig, PLBartConfig, ProphetNetConfig, Qwen2AudioConfig, SeamlessM4TConfig, SeamlessM4Tv2Config, SwitchTransformersConfig, T5Config, UMT5Config, XLMProphetNetConfig.

In [None]:
def load_images_from_folder(folder_path):
    """
    Load all images (jpg, jpeg, png) from the specified folder.
    Returns a list of tuples: (filename, PIL.Image).
    """
    images = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.JPEG')):
            img_path = os.path.join(folder_path, filename)
            try:
                img = Image.open(img_path).convert('RGB')
                images.append((filename, img))
            except Exception as e:
                print(f"Could not load image {img_path}: {e}")
    return images

def apply_transformations(image):
    """
    Given a PIL Image, apply several symmetry transformations.
    Returns a dictionary mapping transformation names to transformed images.
    """
    transformations = {
        "original": image,
        "rotate_90": image.rotate(90, expand=True),
        "rotate_180": image.rotate(180, expand=True),
        "rotate_270": image.rotate(270, expand=True),
        "flip_horizontal": image.transpose(Image.FLIP_LEFT_RIGHT),
        "flip_vertical": image.transpose(Image.FLIP_TOP_BOTTOM),
        # Add more custom transformations if needed.
    }
    return transformations

In [None]:
def classify_image(image, prompt="What is in the image?"):
    """
    Given a PIL Image and a text prompt, processes the image using the model's processor,
    and generates a textual response from the model.
    """
    # The processor handles both image and text (if supported by the model)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=50)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer


In [None]:
def compute_image_saliency(image, prompt="What is in the image?", target_token_position=0, baseline=None):
    """
    Computes a saliency map (pixel attributions) for a chosen output token of the generated answer,
    using Captum's Integrated Gradients.
    
    Parameters:
      image: PIL Image.
      prompt: Textual prompt provided to the model.
      target_token_position: The token position in the output whose logit is used for attribution.
      baseline: Baseline tensor for integrated gradients (if None, uses a zero tensor).
      
    Returns:
      attributions: Tensor of attributions (same shape as input image tensor).
      delta: Convergence delta from IntegratedGradients.
      
    Note: This example assumes that the model's processor produces a tensor called 'pixel_values'.
    """
    # Process the image and prompt to obtain model inputs.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    # Extract the pixel tensor (shape: [1, C, H, W])
    image_tensor = inputs["pixel_values"]
    
    # Ensure gradients are tracked
    image_tensor.requires_grad = True

    # Define a forward function that takes an image tensor and returns the chosen token logit.
    def forward_func(img_tensor):
        new_inputs = inputs.copy()
        new_inputs["pixel_values"] = img_tensor
        # Run the forward pass; here we use the logits output.
        outputs = model(**new_inputs)
        logits = outputs.logits  # shape: [batch, seq_len, vocab_size]
        # For simplicity, we choose the maximum logit value at the target token position.
        # (In practice, you might choose the logit for a specific token id.)
        target_logit = logits[0, target_token_position, :].max()
        return target_logit

    # Define a baseline (a black image). If provided, use it; otherwise, use zeros.
    if baseline is None:
        baseline = torch.zeros_like(image_tensor)
    
    # Initialize Integrated Gradients and compute attributions.
    ig = IntegratedGradients(forward_func)
    attributions, delta = ig.attribute(image_tensor, baseline, return_convergence_delta=True)
    return attributions, delta


In [None]:
def extract_cross_attention(image, prompt="What is in the image?"):
    """
    Registers forward hooks on cross-attention modules in the model’s decoder to extract the attention weights.
    Returns a list of attention outputs from all layers that have cross-attention.
    """
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    cross_attn_weights = []  # List to store attention outputs

    def hook_fn(module, input, output):
        # For this example, we append the output directly.
        # (Depending on the model, you might need to extract specific tensors,
        # e.g., the raw attention weights.)
        cross_attn_weights.append(output)

    hooks = []
    # Register hooks on modules whose names indicate cross-attention.
    for name, module in model.named_modules():
        if "cross_attn" in name or "crossattention" in name:
            hooks.append(module.register_forward_hook(hook_fn))

    # Run a forward pass to collect attention weights.
    _ = model(**inputs)

    # Remove hooks
    for h in hooks:
        h.remove()
        
    return cross_attn_weights


In [None]:
def run_experiment(folder_path, prompt="What is in the image?"):
    """
    For each image in a folder, applies a set of symmetry transformations, classifies
    each transformed image using InternVL2-1B, computes a saliency map on the transformed images,
    and extracts cross-attention weights for the original image.
    """
    images = load_images_from_folder(folder_path)
    all_results = {}

    for filename, img in images:
        print(f"\n=== Processing Image: {filename} ===")
        transformation_dict = apply_transformations(img)
        img_results = {}
        
        for trans_name, timg in transformation_dict.items():
            print(f"\nTransformation: {trans_name}")
            # Classification (text generation) from the model:
            classification = classify_image(timg, prompt)
            print(f"Classification: {classification}")
            
            # Compute and visualize the saliency map (Integrated Gradients)
            attributions, delta = compute_image_saliency(timg, prompt, target_token_position=0)
            # Average the attributions across color channels to create a 2D saliency map.
            attr_np = attributions.squeeze().detach().cpu().numpy()  # shape: [C, H, W]
            if attr_np.ndim == 3:
                attr_np = np.mean(attr_np, axis=0)  # shape: [H, W]
            
            plt.figure(figsize=(5, 4))
            plt.imshow(attr_np, cmap="hot", interpolation="nearest")
            plt.title(f"Saliency Map - {trans_name}")
            plt.colorbar()
            plt.tight_layout()
            plt.show()
            
            # Save results for this transformation
            img_results[trans_name] = {
                "classification": classification,
                "saliency_delta": delta.item() if isinstance(delta, torch.Tensor) else delta
            }
        all_results[filename] = img_results
        
        # (Optional) Extract cross-attention weights for the original image
        ca_weights = extract_cross_attention(img, prompt)
        print("\nCross-attention weights (shapes) from decoder layers:")
        for i, weight in enumerate(ca_weights):
            # If weight is a tensor, print its shape; otherwise, show type info.
            if torch.is_tensor(weight):
                print(f" Layer {i}: shape {tuple(weight.shape)}")
            else:
                print(f" Layer {i}: type {type(weight)}")
    
    return all_results


In [None]:
if __name__ == "__main__":
    # Change this to your local folder containing images.
    folder_path = "images"  
    results = run_experiment(folder_path, prompt="What is in the image?")