In [None]:
import argparse
import io
import math
import os
import pathlib
import sys
from typing import List, Tuple

import torch
import requests
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from transformers import AutoImageProcessor, AutoConfig, ViTForImageClassification

### Using a Pre-trained ViT for Image Classification

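#### Loading the model

Load the DeiT-small checkpoint together with its image processor, and move the model to the GPU when one is available.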

In [None]:
MODEL_NAME = "facebook/deit-small-patch16-224"

config = AutoConfig.from_pretrained(MODEL_NAME)
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
# Request the eager attention implementation (as the script cell at the end does):
# fused attention kernels may not expose the attention weights that the
# visualization cells need. Older transformers releases may reject the keyword.
try:
    model = ViTForImageClassification.from_pretrained(MODEL_NAME, attn_implementation="eager")
except TypeError:
    model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.eval()

# Always define `device`. Previously it was only assigned when CUDA was present,
# so the preprocessing cell raised NameError on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

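#### Loading images

Load the example images (URLs or local paths) and preprocess them into a single batch for the model.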

In [None]:
# Example images for this notebook: Unsplash photo URLs (no query parameters),
# loaded further below via `sources = DEFAULT_IMAGE_URLS`.
DEFAULT_IMAGE_URLS = [
    "https://images.unsplash.com/" + photo_id
    for photo_id in (
        "photo-1518791841217-8f162f1e1131",  # cat
        "photo-1518020382113-a7e8fc38eac9",  # dog
        "photo-1519681393784-d120267933ba",  # mountains
    )
]

def is_url(s: str) -> bool:
    """Return True when ``s`` starts with an http:// or https:// scheme prefix (case-sensitive)."""
    return s.startswith(("http://", "https://"))

def load_image_from_url(url: str, timeout: int = 10) -> Image.Image:
    """
    Download ``url`` and decode the response body as an RGB PIL image.

    Raises the ``requests`` HTTP error for non-2xx responses; ``timeout`` is
    forwarded to ``requests.get``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    payload = io.BytesIO(response.content)
    return Image.open(payload).convert("RGB")

def load_image_from_path(p: str) -> Image.Image:
    """
    Open a local image file and return it as an RGB PIL image.

    ``Image.open`` is lazy and may keep the file handle open (e.g. for multi-frame
    formats). The context manager closes it once ``convert`` has produced an
    independent in-memory copy.
    """
    with Image.open(p) as img:
        return img.convert("RGB")

def load_images(sources: List[str]) -> List[Tuple[str, Image.Image]]:
    """
    Load each source (http(s) URL or local path) as an RGB PIL image, best effort.

    Args:
        sources: URLs and/or filesystem paths.

    Returns:
        ``(name, image)`` pairs, in input order, for the sources that loaded.
        ``name`` is the last segment of the URL *path* (query string and fragment
        excluded, ``"image"`` if empty) or the file stem for local paths. It is
        later used to build output filenames.

    Sources that fail to load are reported on stderr and skipped instead of raising.
    """
    from urllib.parse import urlparse

    images = []
    for src in sources:
        try:
            if is_url(src):
                img = load_image_from_url(src)
                # Use only the URL path: a query such as "?w=800&fit=crop" would
                # otherwise leak '?'/'&' into the output filename.
                name = urlparse(src).path.split("/")[-1] or "image"
            else:
                img = load_image_from_path(src)
                name = pathlib.Path(src).stem
            images.append((name, img))
        except Exception as e:
            print(f"[WARN] Failed to load '{src}': {e}", file=sys.stderr)
    return images


sources = DEFAULT_IMAGE_URLS
items = load_images(sources)
# load_images() skips failures; fail loudly here rather than deep inside the processor.
if not items:
    raise RuntimeError("No images could be loaded; check the URLs/paths in `sources`.")

pil_images = [img for (_, img) in items]

# Send the batch to wherever the model's weights live. This does not depend on a
# `device` variable that may be undefined on CPU-only machines.
model_device = next(model.parameters()).device
inputs = processor(images=pil_images, return_tensors="pt")
inputs = {k: v.to(model_device) for k, v in inputs.items()}

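#### Inference

Run one forward pass with `output_attentions=True` so the per-layer attention maps are returned together with the logits, then take the top-1 class for each image.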

In [None]:
# Single forward pass without autograd bookkeeping; attentions are requested explicitly.
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, return_dict=True)
logits, attentions = outputs.logits, outputs.attentions

# Class probabilities per image, then the index of the most probable class.
probs = torch.softmax(logits, dim=-1)
top1_ids = torch.argmax(probs, dim=-1).tolist()

### Visualizing Patch Attention

In [None]:
def normalize_to_01(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Min-max rescale ``x`` into [0, 1) using its global minimum and maximum.

    ``eps`` is added to the value range so a constant tensor maps to zeros instead
    of dividing by zero. As a consequence, the maximum lands just below 1.
    """
    lo = x.min()
    span = x.max() - lo + eps
    return (x - lo) / span

def get_cls_to_patches_attention(
    attentions: List[torch.Tensor]
) -> torch.Tensor:
    """
    Extract CLS -> patch attentions from the last layer and average across heads.

    Args:
        attentions: sequence with one tensor per layer, each shaped (B, H, L, L),
            where token 0 is the CLS token. Only the last entry is used.

    Returns:
        Tensor of shape (B, L - 1): head-averaged attention from CLS to every other
        token, min-max normalized to [0, 1] independently for each image. An empty
        batch (B == 0) yields an empty (0, L - 1) tensor instead of raising.

    NOTE(review): every token after CLS is treated as a patch. Checkpoints with extra
    special tokens would need a different slice. TODO confirm when switching models.
    """
    last = attentions[-1]                     # (B, H, L, L)
    # Row 0 = queries issued by CLS; columns 1: = keys for the non-CLS tokens.
    cls_avg = last[:, :, 0, 1:].mean(dim=1)   # (B, L-1) after averaging heads

    # Per-sample min-max normalization, vectorized over the batch (replaces a
    # Python loop + torch.stack, which failed for an empty batch).
    lo = cls_avg.amin(dim=-1, keepdim=True)
    hi = cls_avg.amax(dim=-1, keepdim=True)
    return (cls_avg - lo) / (hi - lo + 1e-6)


cls_patch_attn = get_cls_to_patches_attention(attentions=attentions)  # (B, num_patches), each row scaled to [0, 1]

In [None]:
def ensure_dir(p: str):
    """Create the parent directory of file path ``p`` (and any missing ancestors); no-op if it exists."""
    parent = pathlib.Path(p).parent
    parent.mkdir(exist_ok=True, parents=True)
    
def attention_map_to_heatmap(
    attn_vec: torch.Tensor,
    ref_img_size: Tuple[int, int]
) -> np.ndarray:
    """
    Convert a flat attention vector (num_patches,) to an upsampled heatmap.

    Assumes a square patch grid (e.g., 14x14 for 224 with 16x16 patches).

    Args:
        attn_vec: per-patch scores, expected in [0, 1]. Values outside that range
            are clipped; previously they wrapped around in the uint8 cast.
        ref_img_size: target size as (width, height), the order of PIL ``Image.size``.

    Returns:
        float32 array of shape (height, width) with values in [0, 1].

    Raises:
        AssertionError: if the number of patches is not a perfect square.
    """
    num_patches = attn_vec.numel()
    grid = int(math.sqrt(num_patches))
    assert grid * grid == num_patches, f"num_patches={num_patches} is not a perfect square."

    # Cast to float32 first: .numpy() cannot convert bfloat16 tensors.
    attn_grid = attn_vec.detach().float().cpu().reshape(grid, grid).numpy()
    attn_u8 = (np.clip(attn_grid, 0.0, 1.0) * 255).astype(np.uint8)

    # A 2-D uint8 array becomes a mode "L" image automatically; passing mode=
    # explicitly is deprecated in newer Pillow releases.
    attn_img = Image.fromarray(attn_u8)
    attn_img = attn_img.resize(ref_img_size, resample=Image.BICUBIC)
    return np.asarray(attn_img, dtype=np.float32) / 255.0  # back to [0,1]


def overlay_heatmap_on_image(
    image: Image.Image,
    heatmap: np.ndarray,
    alpha: float = 0.45,
    cmap: str = "jet",
    save_path: str = None,
    title: str = None,
    show: bool = True,
):
    """
    Show and optionally save an overlay of the heatmap on the image.

    Args:
        image: base image, drawn first.
        heatmap: 2-D array drawn on top of ``image`` using ``cmap``.
        alpha: opacity of the heatmap layer (matplotlib ``alpha``).
        cmap: matplotlib colormap name for the heatmap.
        save_path: if given, the figure is saved there (parent dirs are created).
        title: optional figure title.
        show: display the figure (e.g. inline in the notebook) before closing it.
            The previous version closed the figure without showing it, so nothing
            was rendered despite the "Show" in this docstring.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)
    ax.imshow(heatmap, cmap=cmap, alpha=alpha, interpolation="bilinear")
    if title:
        ax.set_title(title)
    ax.axis("off")

    # Save before showing: some backends discard the figure once it is shown.
    if save_path:
        ensure_dir(save_path)
        fig.savefig(save_path, bbox_inches="tight", dpi=200)
        print(f"[INFO] Saved attention overlay → {save_path}")
    if show:
        plt.show()
    # Close explicitly so figures produced in a loop do not accumulate in memory.
    plt.close(fig)


In [None]:
# Map class indices to readable labels; keys are normalized to int.
id2label = {int(k): v for k, v in model.config.id2label.items()}
# NOTE(review): `target_size` is not used below; each overlay uses its own image's size.
target_size = pil_images[0].size  # (W, H) for overlay sizing

for i, ((name, pil_img), idx) in enumerate(zip(items, top1_ids)):
    label = id2label.get(idx, f"class_{idx}")
    conf = probs[i, idx].item()
    print(f"[RESULT] {name}: Top-1 = {label} (p={conf:.3f})")

    # Per-patch CLS attention -> heatmap at this image's own (W, H) size.
    heatmap = attention_map_to_heatmap(cls_patch_attn[i], pil_img.size)

    out_name = f"{name}__attn.png"
    overlay_title = f"{label} (p={conf:.2f}) — CLS attention"
    overlay_heatmap_on_image(
        pil_img,
        heatmap,
        alpha=0.45,
        cmap="jet",
        save_path=out_name,
        title=overlay_title,
    )



In [None]:
#!/usr/bin/env python3
# vit_attn_from_medium.py
# ------------------------------------------------------------
# Implements:
# (1) Top-1 classification with a small ImageNet-pretrained ViT-family model
# (2) Patch attention visualization using ATTENTION ROLLOUT, following the Medium post this script is based on (link not included here):
#     - fuse heads by mean
#     - add identity (A + I)
#     - row-normalize
#     - multiply across layers
#     - extract a CLS↔patch vector, reshape to grid, normalize, blur, and overlay
#
# Usage:
#   python vit_attn_from_medium.py --images cat.jpg https://.../dog.png
#   python vit_attn_from_medium.py                  # uses 3 default URLs
#
# Output:
#   - Prints Top-1 predictions (class + probability)
#   - Saves `<name>__attn.png` heatmap overlays
# ------------------------------------------------------------

import argparse
import io
import math
import os
import pathlib
import sys
from typing import List, Tuple

import numpy as np
import requests
import torch
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt

from transformers import (
    AutoConfig,
    AutoImageProcessor,
    ViTForImageClassification,
)

# You can switch to "google/vit-base-patch16-224" if you prefer canonical ViT.
MODEL_NAME = "facebook/deit-small-patch16-224"

# Used when no --images are passed on the command line.
DEFAULT_IMAGE_URLS = [
    "https://images.unsplash.com/photo-1518791841217-8f162f1e1131",  # cat
    "https://images.unsplash.com/photo-1518020382113-a7e8fc38eac9",  # dog
    "https://images.unsplash.com/photo-1519681393784-d120267933ba",  # mountains
]


# ----------------------- IO helpers -----------------------

def is_url(s: str) -> bool:
    """True if ``s`` begins with the literal prefix "http://" or "https://" (case-sensitive)."""
    return s.startswith(("http://", "https://"))

def load_image_from_url(url: str, timeout: int = 12) -> Image.Image:
    """
    Fetch ``url`` over HTTP and decode the body as an RGB PIL image.

    Non-2xx responses raise via ``raise_for_status``; ``timeout`` (seconds) is
    passed straight to ``requests.get``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    payload = io.BytesIO(response.content)
    return Image.open(payload).convert("RGB")

def load_image_from_path(p: str) -> Image.Image:
    """
    Open a local image file and return it as an RGB PIL image.

    ``Image.open`` is lazy and may keep the file handle open (e.g. for multi-frame
    formats). The context manager closes it once ``convert`` has produced an
    independent in-memory copy.
    """
    with Image.open(p) as img:
        return img.convert("RGB")

def load_images(sources: List[str]) -> List[Tuple[str, Image.Image]]:
    """
    Load each source (http(s) URL or local path) as an RGB PIL image, best effort.

    Args:
        sources: URLs and/or filesystem paths.

    Returns:
        ``(name, image)`` pairs, in input order, for the sources that loaded.
        ``name`` is the last segment of the URL *path* (query string and fragment
        excluded, ``"image"`` if empty) or the file stem for local paths. It becomes
        part of the ``<name>__attn.png`` output filename.

    Sources that fail to load are reported on stderr and skipped instead of raising.
    """
    from urllib.parse import urlparse

    images = []
    for src in sources:
        try:
            if is_url(src):
                img = load_image_from_url(src)
                # Name from the URL path only, so "?w=800&fit=crop"-style queries
                # do not end up in output filenames.
                name = urlparse(src).path.split("/")[-1] or "image"
            else:
                img = load_image_from_path(src)
                name = pathlib.Path(src).stem
            images.append((name, img))
        except Exception as e:
            print(f"[WARN] Failed to load '{src}': {e}", file=sys.stderr)
    return images

def ensure_dir(p: str):
    """Make sure the directory that will contain file ``p`` exists, creating ancestors as needed."""
    parent = pathlib.Path(p).parent
    parent.mkdir(exist_ok=True, parents=True)


# ----------------------- Attention rollout (from the article) -----------------------

@torch.no_grad()
def attention_rollout(attentions: List[torch.Tensor]) -> torch.Tensor:
    """
    Combine per-layer attention maps into one "rolled-out" token-to-token map.

    For each layer, in list order:
      1. average the heads,
      2. add the identity (accounts for the residual connection),
      3. re-normalize every row to sum to 1,
      4. right-multiply into the running product (start: identity).
    The result is A_1 @ A_2 @ ... @ A_L, where A_k are the normalized layer maps.

    NOTE(review): common attention-rollout implementations accumulate in the
    opposite order (A_L @ ... @ A_1). This order is kept deliberately to match the
    article this script follows. Verify before relying on it.

    attentions: list of length L_layers, each tensor (B, H, T, T)
    returns: rollout (B, T, T)
    """
    num_tokens = attentions[0].size(-1)
    identity = torch.eye(num_tokens, device=attentions[0].device).unsqueeze(0)  # (1, T, T)

    rollout = identity
    for layer_attn in attentions:
        augmented = layer_attn.mean(dim=1) + identity               # head mean + residual
        normalized = augmented / augmented.sum(dim=-1, keepdim=True)  # row-stochastic
        rollout = rollout @ normalized                              # broadcasts (1,T,T) -> (B,T,T)
    return rollout


def cls_to_patches_from_rollout(rollout: torch.Tensor, follow_medium_indexing: bool = True) -> torch.Tensor:
    """
    Pull a per-patch score vector out of a (B, T, T) rollout, with T = 1 + num_patches.

    Two readings are supported:
      * ``follow_medium_indexing=True`` (default, as in the Medium post):
        column 0 for the patch rows, ``rollout[:, 1:, 0]``, inverted as ``1 - value``.
      * ``follow_medium_indexing=False``: row 0 (CLS) toward the patch columns,
        ``rollout[:, 0, 1:]``.

    Each sample is then min-max normalized to [0, 1] (with a 1e-6 guard in the
    denominator). Returns a (B, num_patches) tensor.
    """
    if not follow_medium_indexing:
        scores = rollout[:, 0, 1:]
    else:
        scores = 1.0 - rollout[:, 1:, 0]

    lo = scores.amin(dim=1, keepdim=True)
    hi = scores.amax(dim=1, keepdim=True)
    return (scores - lo) / (hi - lo + 1e-6)


# ----------------------- Viz helpers -----------------------

def upsample_attention_to_image(attn_vec: torch.Tensor, target_size: Tuple[int, int], blur_radius: float = 2.0) -> Image.Image:
    """
    Turn a flat per-patch score vector into a smoothed grayscale heatmap image.

    attn_vec: (num_patches,) torch tensor, expected in [0,1]. Out-of-range values
        are clipped instead of wrapping around in the uint8 cast.
    target_size: (W, H)
    blur_radius: Gaussian blur radius (pixels) applied after resizing; the default
        of 2 preserves the previous behavior, and 0 disables smoothing.
    returns a PIL Image (mode 'L') resized to target_size

    Raises AssertionError if num_patches is not a perfect square.
    """
    num_patches = attn_vec.numel()
    grid = int(math.sqrt(num_patches))
    assert grid * grid == num_patches, f"num_patches={num_patches} is not a square"
    # float32 first: .numpy() cannot convert bfloat16 tensors.
    scores = attn_vec.detach().float().cpu().numpy().reshape(grid, grid)
    arr = (np.clip(scores, 0.0, 1.0) * 255).astype(np.uint8)
    # 2-D uint8 arrays become mode "L" automatically (explicit mode= is deprecated in newer Pillow).
    heat = Image.fromarray(arr)
    heat = heat.resize(target_size, resample=Image.BICUBIC)
    # Optional smoothing as in the article
    if blur_radius > 0:
        heat = heat.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    return heat  # grayscale (L)

def overlay_heatmap_rgba(image: Image.Image, heat_gray: Image.Image, alpha: int = 100, save_path: str = None, title: str = None):
    """
    Blend a grayscale heatmap over ``image`` and render the result with matplotlib.

    The overlay's color is the heat intensity (R = G = B = heat), and the whole
    overlay is composited with a *uniform* opacity ``alpha``: high-heat pixels
    lighten the image and low-heat pixels darken it. (Earlier comments claimed the
    alpha channel followed the heat intensity, but ``putalpha(alpha)`` has always
    set one constant opacity.)

    Args:
        image: base image; converted to RGBA for compositing.
        heat_gray: heatmap image; converted to mode "L" and, if its size differs
            from ``image.size``, resized to match (``Image.alpha_composite``
            raises on mismatched sizes).
        alpha: opacity applied to every overlay pixel, 0 (invisible) to 255 (opaque).
        save_path: if given, the rendered figure is saved there (parent dirs created).
        title: optional figure title.
    """
    heat = heat_gray.convert("L")
    if heat.size != image.size:
        heat = heat.resize(image.size, resample=Image.BICUBIC)

    # Gray RGB layer from the heat values; putalpha() promotes it to RGBA with a constant alpha.
    heat_rgba = Image.merge("RGB", (heat, heat, heat))
    heat_rgba.putalpha(alpha)

    # Compose
    base = image.convert("RGBA")
    composed = Image.alpha_composite(base, heat_rgba)

    # Render with an explicit figure handle so it can be closed reliably.
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(composed)
    if title:
        ax.set_title(title)
    ax.axis("off")
    if save_path:
        ensure_dir(save_path)
        fig.savefig(save_path, bbox_inches="tight", dpi=200)
        print(f"[INFO] Saved → {save_path}")
    plt.close(fig)


# ----------------------- Main -----------------------

@torch.no_grad()
def main(args):
    """
    Classify images and save attention-rollout overlays for each one.

    Args:
        args: parsed CLI namespace with ``images`` (paths/URLs; None or empty means
            DEFAULT_IMAGE_URLS) and ``cpu`` (force CPU even if CUDA is available).

    Side effects:
        Loads MODEL_NAME, prints a ``[RESULT]`` line per image, writes
        ``<name>__attn.png`` overlays to the current directory, and calls
        ``sys.exit(1)`` when no image could be loaded.

    Raises:
        RuntimeError: if the forward pass returns no attention tensors.
    """
    # Load model + processor
    print(f"[INFO] Loading model: {MODEL_NAME}")
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    try:
        # Eager attention: fused attention kernels may not expose attention weights.
        model = ViTForImageClassification.from_pretrained(MODEL_NAME, attn_implementation="eager")
    except TypeError:
        # Older transformers versions may not accept this kwarg
        model = ViTForImageClassification.from_pretrained(MODEL_NAME)

    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    model.to(device).eval()

    # Prepare images
    sources = args.images if args.images else DEFAULT_IMAGE_URLS
    items = load_images(sources)
    if not items:
        print("[ERROR] No images loaded.", file=sys.stderr)
        sys.exit(1)

    pil_images = [img for (_, img) in items]
    inputs = processor(images=pil_images, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Forward with attentions
    outputs = model(**inputs, output_attentions=True, return_dict=True)
    logits = outputs.logits                         # (B, num_labels)
    attentions = outputs.attentions                 # one (B, H, T, T) tensor per layer
    # Fail with an actionable message rather than a TypeError inside attention_rollout.
    if not attentions:
        raise RuntimeError(
            "Model returned no attention weights (outputs.attentions is empty or None); "
            "load it with attn_implementation='eager' so output_attentions=True is honored."
        )

    # Top-1 predictions
    probs = logits.softmax(dim=-1)
    top1_ids = probs.argmax(dim=-1).tolist()
    id2label = {int(k): v for k, v in model.config.id2label.items()}

    for i, ((name, img), cls_id) in enumerate(zip(items, top1_ids)):
        label = id2label.get(cls_id, f"class_{cls_id}")
        conf = probs[i, cls_id].item()
        print(f"[RESULT] {name}: Top-1 = {label} (p={conf:.3f})")

    # Attention rollout (per article)
    rollout = attention_rollout(attentions)         # (B, T, T)

    # CLS↔patch vector (follow the article’s indexing by default)
    cls_vecs = cls_to_patches_from_rollout(rollout, follow_medium_indexing=True)  # (B, num_patches)

    # Build overlays
    for i, (name, img) in enumerate(items):
        heat_L = upsample_attention_to_image(cls_vecs[i], (img.size[0], img.size[1]))  # grayscale 'L'
        title = f"{id2label.get(top1_ids[i], str(top1_ids[i]))} (p={probs[i, top1_ids[i]].item():.2f}) — attention rollout"
        out_path = f"{name}__attn.png"
        overlay_heatmap_rgba(img, heat_L, alpha=100, save_path=out_path, title=title)

    print("\n[NOTE] Check the saved *_attn.png files; do the highlighted regions match the object?")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="ViT classification + attention visualization (Medium-style rollout).")
    parser.add_argument("--images", nargs="*", help="Paths or URLs to 1–3 images.")
    parser.add_argument("--cpu", action="store_true", help="Force CPU.")
    # Inside a Jupyter/IPython kernel __name__ is also "__main__", but sys.argv holds
    # the kernel's own arguments (e.g. "-f <connection-file>"), which argparse would
    # reject with SystemExit. Parse an empty list there so the defaults are used.
    in_ipython_kernel = "ipykernel" in sys.modules
    args = parser.parse_args([] if in_ipython_kernel else None)
    main(args)