In [2]:
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image

model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
max_length = 14
num_beams = 1
gen_kwargs = {
  "max_length": max_length, 
  "num_beams": num_beams, 
  "output_attentions": True,
  "return_dict_in_generate": True,
}

def predict_step(image_paths):
  images = []
  for image_path in image_paths:
    i_image = Image.open(image_path)
    if i_image.mode != "RGB":
      i_image = i_image.convert(mode="RGB")

    images.append(i_image)

  pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
  pixel_values = pixel_values

  output = model.generate(pixel_values, **gen_kwargs)
  output_ids = output.sequences


  preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
  preds = [pred.strip() for pred in preds]
  return preds,output



In [15]:
print(output.sequences.shape)

torch.Size([1, 11])


In [6]:
import numpy as np
from timm.data import create_transform
from typing import List, Tuple, Dict
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

In [7]:
def apply_mask(image: np.ndarray, mask: np.ndarray, color: Tuple[float, float, float], alpha: float = 0.5) -> np.ndarray:
    # Ensure mask and image have the same shape
    mask = mask[:, :, np.newaxis]
    mask = np.repeat(mask, 3, axis=2)
    
    # Convert color to numpy array
    color = np.array(color)
    
    # Apply mask
    masked_image = image * (1 - alpha * mask) + alpha * mask * color[np.newaxis, np.newaxis, :] * 255
    return masked_image.astype(np.uint8)


def rollout(attentions, discard_ratio, head_fusion, num_prefix_tokens=1):
    # based on https://github.com/jacobgil/vit-explain/blob/main/vit_rollout.py
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion.startswith('mean'):
                # mean_std fusion doesn't appear to make sense with rollout
                attention_heads_fused = attention.mean(dim=0)
            elif head_fusion == "max":
                attention_heads_fused = attention.amax(dim=0)
            elif head_fusion == "min":
                attention_heads_fused = attention.amin(dim=0)
            else:
                raise ValueError("Attention head fusion type Not supported")

            # Discard the lowest attentions, but don't discard the prefix tokens
            flat = attention_heads_fused.view(-1)
            _, indices = flat.topk(int(flat.size(-1 )* discard_ratio), -1, False)
            indices = indices[indices >= num_prefix_tokens]
            flat[indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1)
            result = torch.matmul(a, result)
    
    # Look at the total attention between the prefix tokens (usually class tokens)
    # and the image patches    
    # FIXME this is token 0 vs non-prefix right now, need to cover other cases (> 1 prefix, no prefix, etc)
    mask = result[0, num_prefix_tokens:]
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask


def visualize_attention(
        attention_maps:any,
        image:Image.Image,
        head_fusion: str,
        discard_ratio: float,
) -> Tuple[List[Image.Image], Image.Image]:
    """Visualize attention maps and rollout for the given image and model."""
    
    
    # FIXME handle wider range of models that may not have num_prefix_tokens attr
    num_prefix_tokens = getattr(model, 'num_prefix_tokens', 1)  # Default to 1 class token if not specified
    # Convert PIL Image to numpy array
    image_np = np.array(image)

    # Create visualizations
    visualizations = []
    attentions_for_rollout = []
    for layer_idx, attn_map in enumerate(attention_maps):
        print(f"Attention map shape for {layer_idx}: {attn_map.shape}")
        attn_map = attn_map[0]  # Remove batch dimension

        attentions_for_rollout.append(attn_map)

        attn_map = attn_map[:, :, num_prefix_tokens:]  # Remove prefix tokens for visualization

        if head_fusion == 'mean_std':                
            attn_map = attn_map.mean(0) / attn_map.std(0)
        elif head_fusion == 'mean':
            attn_map = attn_map.mean(0)
        elif head_fusion == 'max':
            attn_map = attn_map.amax(0)
        elif head_fusion == 'min':
            attn_map = attn_map.amin(0)
        else:
            raise ValueError(f"Invalid head fusion method: {head_fusion}")

        # Use the first token's attention (usually the class token)
        # FIXME handle different prefix token scenarios
        attn_map = attn_map[0]

        # Reshape the attention map to 2D
        num_patches = int(attn_map.shape[-1] ** 0.5)
        attn_map = attn_map.reshape(num_patches, num_patches)

        # Interpolate to match image size
        attn_map = torch.tensor(attn_map).unsqueeze(0).unsqueeze(0)
        attn_map = F.interpolate(attn_map, size=(image_np.shape[0], image_np.shape[1]), mode='bilinear', align_corners=False)
        attn_map = attn_map.squeeze().cpu().numpy()

        # Normalize attention map
        attn_map = (attn_map - attn_map.min()) / (attn_map.max() - attn_map.min())

        # Create visualization
        fig, ax = plt.subplots(figsize=(20, 10))

        # Attention map overlay
        masked_image = apply_mask(image_np, attn_map, color=(1, 0, 0))  # Red mask
        ax.imshow(masked_image)
        ax.set_title(f'Attention Map for {layer_idx}')
        ax.axis('off')

        plt.tight_layout()

        # Convert plot to image
        fig.canvas.draw()
        vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
        visualizations.append(vis_image)
        plt.close(fig)

    # Calculate rollout
    rollout_mask = rollout(attentions_for_rollout, discard_ratio, head_fusion, num_prefix_tokens)

    # Create rollout visualization
    fig, ax = plt.subplots(figsize=(20, 10))

    # Rollout overlay
    rollout_mask_pil = Image.fromarray((rollout_mask * 255).astype(np.uint8))
    rollout_mask_resized = np.array(rollout_mask_pil.resize((image_np.shape[1], image_np.shape[0]), Image.BICUBIC)) / 255.0
    masked_image = apply_mask(image_np, rollout_mask_resized, color=(1, 0, 0))  # Red mask
    ax.imshow(masked_image)
    ax.set_title('Attention Rollout')
    ax.axis('off')

    plt.tight_layout()

    # Convert plot to image
    fig.canvas.draw()
    rollout_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    plt.close(fig)

    return visualizations, rollout_image

In [18]:
preds,output=predict_step(['val2017/000000463527.jpg'])

In [19]:
image=Image.open("val2017/000000463527.jpg")

In [20]:
print(preds)

['a tray of food with a sandwich and a cup of coffee']


In [None]:
for i in range(len(output.cross_attentions)):
    _ ,roll= visualize_attention(output.cross_attentions[i],image=image,head_fusion="mean",discard_ratio=0.8)
    roll.save(f"my_token_attn_{i}.png")

Attention map shape for 0: torch.Size([1, 12, 1, 197])
Attention map shape for 1: torch.Size([1, 12, 1, 197])
Attention map shape for 2: torch.Size([1, 12, 1, 197])
Attention map shape for 3: torch.Size([1, 12, 1, 197])


  attn_map = torch.tensor(attn_map).unsqueeze(0).unsqueeze(0)
  vis_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())


Attention map shape for 4: torch.Size([1, 12, 1, 197])
Attention map shape for 5: torch.Size([1, 12, 1, 197])
Attention map shape for 6: torch.Size([1, 12, 1, 197])
Attention map shape for 7: torch.Size([1, 12, 1, 197])
Attention map shape for 8: torch.Size([1, 12, 1, 197])
Attention map shape for 9: torch.Size([1, 12, 1, 197])
Attention map shape for 10: torch.Size([1, 12, 1, 197])
Attention map shape for 11: torch.Size([1, 12, 1, 197])


  rollout_image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())


Attention map shape for 0: torch.Size([1, 12, 1, 197])
Attention map shape for 1: torch.Size([1, 12, 1, 197])
Attention map shape for 2: torch.Size([1, 12, 1, 197])
Attention map shape for 3: torch.Size([1, 12, 1, 197])
Attention map shape for 4: torch.Size([1, 12, 1, 197])
Attention map shape for 5: torch.Size([1, 12, 1, 197])
Attention map shape for 6: torch.Size([1, 12, 1, 197])
Attention map shape for 7: torch.Size([1, 12, 1, 197])
Attention map shape for 8: torch.Size([1, 12, 1, 197])
Attention map shape for 9: torch.Size([1, 12, 1, 197])
Attention map shape for 10: torch.Size([1, 12, 1, 197])
Attention map shape for 11: torch.Size([1, 12, 1, 197])
Attention map shape for 0: torch.Size([1, 12, 1, 197])
Attention map shape for 1: torch.Size([1, 12, 1, 197])
Attention map shape for 2: torch.Size([1, 12, 1, 197])
Attention map shape for 3: torch.Size([1, 12, 1, 197])
Attention map shape for 4: torch.Size([1, 12, 1, 197])
Attention map shape for 5: torch.Size([1, 12, 1, 197])
Attentio