# DINOv2 Attention Rollout Demo (qkv-hook, reg tokens fixed)

一键运行：只需执行下方单个 Code Cell。包含：
- 在线图片下载
- timm 的 DINOv2 ViT-B/14 reg4 模型
- 在 qkv 上 hook 并复原注意力
- 最后 K 层 attention rollout + 叠加可视化


In [None]:
# ================== 一键依赖 ==================
!pip -q install timm requests pillow matplotlib

# ================== 导入 ==================
import io, math, requests, torch, timm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as T

# Inference only: switch autograd off globally for the whole demo.
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ================== Online images (editable) ==================
# Candidate image URLs, tried in order; the first one that downloads and decodes is used.
URLS = [
    "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&w=1280",
    "https://upload.wikimedia.org/wikipedia/commons/2/25/Universal_Robots_UR5.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/9/9f/KUKA_Industrial_Robot.jpg",
]
def load_first_ok(urls):
    """Download and decode the first reachable image among ``urls``.

    URLs are tried in order with a browser-like User-Agent and a 20 s timeout.
    Any failure (network error, HTTP error status, undecodable image) is printed
    and the next URL is tried.

    Returns:
        (image converted to RGB, the URL it was fetched from).

    Raises:
        RuntimeError: if no URL produced a usable image (including an empty ``urls``).
    """
    headers = {"User-Agent": "Mozilla/5.0"}
    for url in urls:
        try:
            resp = requests.get(url, headers=headers, timeout=20)
            resp.raise_for_status()
            image = Image.open(io.BytesIO(resp.content)).convert("RGB")
        except Exception as err:
            # Best effort: report and fall through to the next candidate.
            print(f"[跳过] {url} -> {err}")
            continue
        return image, url
    raise RuntimeError("所有 URL 都不可用")

img_pil, used_url = load_first_ok(URLS)
print("[OK] 使用图片：", used_url)

# ================== Model (timm DINOv2 ViT-B/14 with 4 register tokens) ==================
# Input resolution: 518 = 37 * 14, i.e. a 37x37 grid of 14-pixel patches.
# If you change it, keep it a multiple of the 14-pixel patch size.
IMG_SIZE = 518
model_id = 'vit_base_patch14_reg4_dinov2'
# Build the model for IMG_SIZE explicitly so that changing IMG_SIZE does not make the forward
# pass reject the input size (518 presumably matches the checkpoint default, in which case this
# changes nothing). NOTE(review): relies on timm resampling the pretrained position embeddings
# to the requested img_size at load time - confirm on your timm version.
model = timm.create_model(model_id, pretrained=True, img_size=IMG_SIZE).to(device).eval()

# Preprocessing: squash-resize to IMG_SIZE x IMG_SIZE (aspect ratio is not preserved),
# convert to a [0, 1] tensor, then normalize with the ImageNet mean/std.
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])
x = transform(img_pil).unsqueeze(0).to(device)   # batch of one: [1, 3, IMG_SIZE, IMG_SIZE]

# ================== Hook qkv and run the forward pass ==================
# Each block's attn.qkv output is recorded (detached, on CPU) so the attention matrices can be
# recomputed from q and k afterwards.
qkv_per_block = []
def qkv_hook(module, inp, out):
    """Forward hook: append one block's qkv projection output (detached, moved to CPU)."""
    qkv_per_block.append(out.detach().cpu())

handles = [blk.attn.qkv.register_forward_hook(qkv_hook) for blk in model.blocks]
try:
    with torch.no_grad():
        _ = model(x)
finally:
    # Always detach the hooks, even if the forward pass fails, so they never pile up on the model.
    for h in handles:
        h.remove()

# Explicit check (not assert) so it still runs under `python -O`.
if len(qkv_per_block) != len(model.blocks):
    raise RuntimeError("没有捕获到 qkv，重跑单元试试")

# ================== Recover attention from qkv ==================
num_heads = model.blocks[0].attn.num_heads
embed_dim = model.blocks[0].attn.qkv.out_features // 3
head_dim = embed_dim // num_heads
scale = head_dim ** -0.5   # 1/sqrt(head_dim) dot-product scaling

# The recomputation below uses raw q and k. If an attention module applied a non-trivial q/k
# normalization, the recovered maps would silently be wrong, so fail loudly instead.
for blk_idx, blk in enumerate(model.blocks):
    for norm_name in ("q_norm", "k_norm"):
        norm = getattr(blk.attn, norm_name, None)
        if norm is not None and not isinstance(norm, torch.nn.Identity):
            raise RuntimeError(f"第 {blk_idx} 层的 {norm_name} 不是 Identity，从 qkv 复原的注意力不准确")

attn_per_block = []
for qkv in qkv_per_block:
    B, N, _ = qkv.shape
    # [B, N, 3*C] -> [3, B, H, N, Dh]; assumes the qkv Linear output is laid out as
    # (3, heads, head_dim) along the last axis - TODO confirm for other timm versions.
    q, k, _v = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4).unbind(0)
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)      # [B,H,N,N]
    attn_per_block.append(attn)   # already on CPU: the hook stored CPU tensors

print(f"[Info] 捕获到 {len(attn_per_block)} 层注意力；heads={num_heads}, tokens(N)={attn_per_block[0].shape[-1]}")

# ================== Patch grid and number of "extra" tokens (register tokens) ==================
# Token layout relied on below: [CLS, extra/register tokens..., Hp*Wp patch tokens].
N = attn_per_block[0].shape[-1]
# CLS + register token count when the model exposes it (timm VisionTransformer does - TODO
# confirm for other models); None otherwise.
num_prefix = getattr(model, "num_prefix_tokens", None)
if hasattr(model.patch_embed, "grid_size"):
    Hp, Wp = model.patch_embed.grid_size
else:
    # Fallback: assume a square patch grid. Use the model's prefix-token count when known;
    # otherwise take the largest square that fits after the CLS token. Floor (isqrt) rather than
    # round, so the extra-token count below can never go negative.
    grid_tokens = N - num_prefix if num_prefix is not None else N - 1
    Hp = Wp = math.isqrt(grid_tokens)
patch_tokens = Hp * Wp

extra_tokens = N - (1 + patch_tokens)           # e.g. 4 for a reg4 model
if extra_tokens < 0:
    raise RuntimeError("token 数异常")
if num_prefix is not None and num_prefix != 1 + extra_tokens:
    # The CLS-row slicing used later assumes CLS at index 0 followed by the extra tokens.
    raise RuntimeError(f"token 数异常: num_prefix_tokens={num_prefix}, 推算 1+extra={1 + extra_tokens}")
print(f"[Info] Patch网格: {Hp}×{Wp}, 额外token数: {extra_tokens}")

# Helper: take the "CLS -> patches" vector, skipping the register tokens.
def cls_to_patches_vec(R, num_extra=None, num_patches=None, batch_index=0):
    """Return the CLS row of a rollout matrix restricted to the patch tokens.

    Token order assumed: [CLS, extra (register) tokens..., patch tokens...].

    Args:
        R: tensor [B, N, N]; row 0 of each batch element is CLS -> every token.
        num_extra: number of tokens between CLS and the first patch. Defaults to the
            module-level ``extra_tokens``.
        num_patches: number of patch tokens. Defaults to the module-level ``patch_tokens``.
        batch_index: batch element to read (0 by default, as before).

    Returns:
        1-D tensor of length ``num_patches``.

    Raises:
        ValueError: if R does not contain ``1 + num_extra + num_patches`` tokens.
    """
    if num_extra is None:
        num_extra = extra_tokens
    if num_patches is None:
        num_patches = patch_tokens
    start = 1 + num_extra                      # skip CLS and the register tokens
    v = R[batch_index, 0, start:start + num_patches]
    # Explicit check (not assert) so a short slice is caught even under `python -O`.
    if v.numel() != num_patches:
        raise ValueError(
            f"expected {num_patches} patch scores after {start} prefix tokens, "
            f"got {v.numel()} (R has {R.shape[-1]} tokens)"
        )
    return v

# ================== Rollout ==================
def attention_rollout(attn_list, last_k=None, head_fusion='mean', alpha=0.2):
    """Compose per-layer attention maps into a single token-to-token rollout matrix.

    For each layer (in list order, earliest first):
    - the heads are fused into one [B, N, N] map;
    - the map is mixed with the identity (the usual rollout treatment of the residual path);
    - rows are re-normalized;
    - the result is left-multiplied onto the running product, so R = A_L' @ ... @ A_1'.

    Args:
        attn_list: sequence of attention tensors, each [B, H, N, N], earliest layer first.
        last_k: if given, only the last ``last_k`` layers are used; must be >= 1.
        head_fusion: 'mean', 'max' or 'min' over the head dimension.
        alpha: weight of the identity term; the fused attention gets ``1 - alpha``.

    Returns:
        Tensor [B, N, N]; row i is token i's rollout over all tokens.

    Raises:
        ValueError: on an empty ``attn_list``, a non-positive ``last_k`` or an unknown
            ``head_fusion``.
    """
    if not attn_list:
        raise ValueError("attn_list is empty")
    if last_k is not None:
        if last_k < 1:
            # attn_list[-0:] would silently select every layer, so reject it explicitly.
            raise ValueError(f"last_k must be a positive int or None, got {last_k}")
        attn_list = attn_list[-last_k:]

    fusions = {
        'mean': lambda A: A.mean(dim=1),
        'max': lambda A: A.max(dim=1).values,
        'min': lambda A: A.min(dim=1).values,
    }
    if head_fusion not in fusions:
        raise ValueError(f"head_fusion must be one of {sorted(fusions)}, got {head_fusion!r}")
    fuse = fusions[head_fusion]

    first = attn_list[0]
    B, H, N, _ = first.shape
    # Match the inputs' dtype/device so the matmuls below never mix CPU/GPU or float widths.
    eye = torch.eye(N, dtype=first.dtype, device=first.device)[None].repeat(B, 1, 1)
    R = eye.clone()
    for A in attn_list:
        A_aug = (1 - alpha) * fuse(A) + alpha * eye
        # Needed for 'max'/'min' fusion, whose rows do not sum to 1.
        A_aug = A_aug / (A_aug.sum(dim=-1, keepdim=True) + 1e-8)
        R = A_aug @ R
    return R  # [B,N,N]

# ================== Visualization: per block (rollout over the most recent K layers) ==================
K = 6                                             # rollout window; also used by the overlay below
n_blocks = len(attn_per_block)
n_cols = 4
n_rows = math.ceil(n_blocks / n_cols)            # 3 rows for a 12-block model, as before
# squeeze=False keeps `axes` 2-D even for a single row, so indexing is uniform.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 3 * n_rows), squeeze=False)
fig.suptitle(f'Attention Rollout per Block (CLS→patches) [{model_id}]', fontsize=14)

for i, ax in enumerate(axes.flat):
    if i >= n_blocks:
        ax.axis('off')   # blank out unused panels when n_blocks is not a multiple of n_cols
        continue
    # Rollout over at most K layers ending at block i.
    R = attention_rollout(attn_per_block[:i+1], last_k=min(K, i+1), head_fusion='mean', alpha=0.2)
    heat = cls_to_patches_vec(R).reshape(Hp, Wp).numpy()
    ax.imshow(heat, cmap='magma')
    ax.set_title(f'Block: {i}')
    ax.axis('off')

plt.tight_layout(); plt.show()

# ================== Overlay onto the input image (last K layers) ==================
from matplotlib import cm

R_last = attention_rollout(attn_per_block, last_k=K, head_fusion='mean', alpha=0.2)
heat_last = cls_to_patches_vec(R_last).reshape(Hp, Wp).numpy()
lo, hi = heat_last.min(), heat_last.max()
heat_last = (heat_last - lo) / (hi - lo + 1e-6)     # min-max scale into [0, 1)

H = W = IMG_SIZE
img_resize = img_pil.resize((W,H))
# Upsample the patch-level map to image resolution as an 8-bit grayscale image.
heat_u8 = (heat_last*255).astype(np.uint8)
heat_big = Image.fromarray(heat_u8).resize((W,H), Image.BILINEAR)
# Colorize with viridis; drop the alpha channel returned by the colormap.
colored = cm.viridis(np.array(heat_big)/255.0)[:,:,:3]
heat_rgb = (colored*255).astype(np.uint8)
overlay = Image.blend(img_resize, Image.fromarray(heat_rgb), alpha=0.45)

panels = [("Input", img_resize), (f"Overlay (last {K} layers)", overlay)]
plt.figure(figsize=(10,5))
for idx, (title, im) in enumerate(panels, start=1):
    plt.subplot(1, 2, idx)
    plt.title(title)
    plt.imshow(im)
    plt.axis('off')
plt.tight_layout(); plt.show()
