In [1]:
import os, sys, json, torch, random
os.environ['CUDA_VISIBLE_DEVICES'] = '5'
import numpy as np
from pathlib import Path
from PIL import Image

# 固定随机种子
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(42)

sys.path.append("/home/czj/llava15_test/LLaVA")

from llava.model.builder import load_pretrained_model
from llava.data.gqa_loader import GQALoader
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN




  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


[2025-10-28 22:15:06,332] [INFO] [real_accelerator.py:161:get_accelerator] Setting ds_accelerator to cuda (auto detect)




In [2]:
from gqa_eval.config import get_args
from gqa_eval.seed_utils import set_seed
from gqa_eval.model_loader import load_llava_model
from llava.data.gqa_loader import GQALoader

In [3]:
import numpy as np
from tqdm import tqdm
from collections import Counter
from gqa_eval.pred_gt_match import compute_match_batch
from gqa_eval.prompt import build_multimodal_batch_inputs

In [4]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
import torch
import numpy as np
import os

In [14]:
import argparse

args = argparse.Namespace(
    model_path="/home/czj/llava15_test/llava-v1.5-7b",
    model_base=None,
    image_file="/home/czj/kunkun.png",
    query="What is this image?",
    conv_mode=None,
    sep=",",
    temperature=0.0,
    top_p=None,
    num_beams=1,
    max_new_tokens=64,
    layers="0-31"
)
args


Namespace(model_path='/home/czj/llava15_test/llava-v1.5-7b', model_base=None, image_file='/home/czj/kunkun.png', query='What is this image?', conv_mode=None, sep=',', temperature=0.0, top_p=None, num_beams=1, max_new_tokens=64, layers='0-31')

In [10]:
from numpy.random import f
import argparse, re, os, torch
from io import BytesIO
from PIL import Image
import requests
from llava.constants import (
    IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_PLACEHOLDER
)
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path


def disable_torch_init():
    import torch
    import math
    torch.set_grad_enabled(False)
    torch.manual_seed(1234)
    torch.cuda.manual_seed_all(1234)

# ============== 基础工具 ==============

def load_image(image_file):
    if image_file.startswith("http"):
        resp = requests.get(image_file)
        return Image.open(BytesIO(resp.content)).convert("RGB")
    return Image.open(image_file).convert("RGB")

def build_prompt(query, model, conv_mode=None):
    model_name = model.config._name_or_path if hasattr(model.config, "_name_or_path") else ""
    if conv_mode is None:
        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "mistral" in model_name.lower():
            conv_mode = "mistral_instruct"
        elif "v1.6-34b" in model_name.lower():
            conv_mode = "chatml_direct"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        else:
            conv_mode = "llava_v0"

    qs = query
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        qs = re.sub(IMAGE_PLACEHOLDER, image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN, qs)
    else:
        qs = (image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN) + "\n" + qs

    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    return prompt, conv_mode

def encode_images_for_len(model, images):
    with torch.inference_mode():
        feats = model.encode_images(
            process_images(images, model.get_vision_tower().image_processor, model.config)
            .to(model.device, dtype=torch.float16)
        )
    # 兼容不同返回结构
    if isinstance(feats, list):
        feats = feats[0]
    if feats.dim() == 4:  # [B, N, H, D]
        num_tokens = feats.shape[1] * feats.shape[2]
    elif feats.dim() == 3:  # [B, N, D]
        num_tokens = feats.shape[1]
    elif feats.dim() == 2:  # [N, D]
        num_tokens = feats.shape[0]
    else:
        raise ValueError(f"Unexpected feats shape: {feats.shape}")
    return num_tokens


def compute_spans(prompt, tokenizer, img_feat_len):
    """
    计算 system / image / question 的 token 区间（在“插入图像特征之后”的序列上）。
    做法：先找出 <image> 占位符位置前后的文本 token 数，再把 image 特征长度插进去。
    """
    # 让 tokenizer_image_token 在文本里保留 IMAGE_TOKEN_INDEX 占位，便于找到它的位置
    toks = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=None, padding_side="right")
    ids = toks[0] if isinstance(toks, list) else toks
    # 找到占位符下标
    try:
        image_pos = ids.index(IMAGE_TOKEN_INDEX)
    except ValueError:
        # 没有图片占位符，就当没有图像（极少见）
        image_pos = len(ids)

    sys_len = image_pos               # image 占位符之前的 token 数
    q_len  = len(ids) - (image_pos + 1)  # image 占位符之后的文本 token 数
    # 插入图像特征后，最终序列分段：
    # [ 0, sys_len-1 ] -> system/text-before-image
    # [ sys_len, sys_len+img_feat_len-1 ] -> image patch tokens（实际是特征）
    # [ sys_len+img_feat_len, sys_len+img_feat_len+q_len-1 ] -> question/text-after-image
    sys_span = (0, sys_len-1) if sys_len>0 else None
    img_span = (sys_len, sys_len+img_feat_len-1) if img_feat_len>0 else None
    q_start  = sys_len + img_feat_len
    q_span   = (q_start, q_start+q_len-1) if q_len>0 else None
    base_len = sys_len + img_feat_len + q_len
    spans = {
        "sys": sys_span,
        "image": img_span,
        "question": q_span,
        "base_len": base_len  # 生成开始的下标
    }
    return spans

def parse_layers_arg(layers_str, num_layers):
    if not layers_str:
        return list(range(num_layers))
    out = []
    for part in layers_str.split(","):
        part = part.strip()
        if "-" in part:
            a,b = part.split("-")
            out.extend(range(int(a), int(b)+1))
        else:
            out.append(int(part))
    out = sorted(set([i for i in out if 0 <= i < num_layers]))
    return out

# ============== Hook ==============

def register_selected_layer_hooks(model, selected_layers):
    """
    仅在 LLaMA 文本层的指定层注册 hook，抓取多头注意力。
    结果结构：attn_trace[step][layer] = [Tensor_of_shape(B,H,S,S)]
    """
    attn_trace = {}
    step_ref = {"step": -1}  # 可变引用，循环里更新

    def save_attn_hook(module, inp, out):
        if not (isinstance(out, tuple) and len(out)>1 and out[1] is not None):
            return
        layer_idx = module.layer_idx
        attn = out[1].detach().cpu()  # [B, num_heads, S, S]
        if step_ref["step"] >= 0:
            attn_trace.setdefault(step_ref["step"], {})[layer_idx] = attn

    num_hooks = 0
    for name, m in model.model.named_modules():
        mname = re.match(r"layers\.(\d+)\.self_attn$", name)
        if mname:
            lid = int(mname.group(1))
            if lid in selected_layers:
                m.layer_idx = lid
                m.register_forward_hook(save_attn_hook)
                num_hooks += 1

    return attn_trace, step_ref, num_hooks

# ============== 主逻辑：手写解码 + 收集注意力 + 分段 ==============

def run(args):
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name
    )
    model.eval().to(model.device)
    model.config.output_attentions = True  # 让 attention 参与计算

    # 构建 prompt
    prompt, conv_mode = build_prompt(args.query, model, args.conv_mode)

    # 图像
    image = load_image(args.image_file)
    images = [image]
    image_sizes = [image.size]
    image_tensors = process_images(images, image_processor, model.config).to(model.device, dtype=torch.float16)

    # 计算图像特征长度（用于分段）
    img_feat_len = encode_images_for_len(model, images)
    spans = compute_spans(prompt, tokenizer, img_feat_len)

    # 文本 token
    tok = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt", padding_side="right")
    input_ids = tok["input_ids"].to(model.device)
    attention_mask = tok.get("attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    # 只在「指定层」注册 hook
    num_layers = model.config.num_hidden_layers
    selected_layers = parse_layers_arg(args.layers, num_layers)
    attn_trace, step_ref, num_hooks = register_selected_layer_hooks(model, selected_layers)
    print(f"✅ Registered hooks on layers: {selected_layers} (total {num_hooks})")
    print(f"Spans: {spans}  (base_len={spans['base_len']})")

    # 自回归解码
    max_new = args.max_new_tokens
    eos_token_id = getattr(tokenizer, "eos_token_id", 2)
    past_key_values = None

    generated = input_ids.clone()      # 文本 token 序列（不含图像 patch）
    current_input = input_ids          # 第 0 步喂完整 prompt
    base_len_text = input_ids.shape[1] # ✅ 文本解码起点（用于最终 decode）

    with torch.inference_mode():
        for t in range(max_new):
            step_ref["step"] = t  # 给 hook 标记时间步

            outputs = model(
                input_ids=current_input,
                images=image_tensors if t == 0 else None,   # 只在第 0 步喂图像
                image_sizes=image_sizes if t == 0 else None,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=True,
                return_dict=True,
            )

            # 限定到 tokenizer 词表长度，避免越界 token
            logits = outputs.logits[:, -1, :tokenizer.vocab_size]
            next_token = torch.argmax(logits, dim=-1)   # [B]

            # 追加到生成序列（文本 token 侧）
            generated = torch.cat([generated, next_token.unsqueeze(-1)], dim=-1)

            # 终止条件：至少生成 1 个 token 后才允许停
            if eos_token_id is not None and t > 0 and (next_token == eos_token_id).all():
                break

            # 增量解码：下一步只喂“上一步生成”的 token，并带上 KV cache
            current_input = next_token.unsqueeze(-1)    # 形状 [B, 1]
            past_key_values = outputs.past_key_values

    print("\n=== Attn ")
    for step, layer_dict in attn_trace.items():
        print(f"\nStep {step}:")
        for layer, attn in layer_dict.items():
            print(f"Layer {layer}:")
            print(f"Shape: {attn.shape}")
            # print(f"attn: {attn}")

    # 输出文本与索引范围
    gen_only = generated[:, base_len_text:]    # ✅ 用文本长度作为分界

    final_text = tokenizer.batch_decode(gen_only, skip_special_tokens=True)[0].strip()

    base_len = spans["base_len"]
    gen_positions = (spans["base_len"], spans["base_len"] + gen_only.shape[1] - 1)
    print("\n=== Decode Done ===")
    print(f"generated.shape: {generated.shape}")
    print("Prediction:", final_text)
    print("System span:", spans["sys"], "Image span:", spans["image"], "Question span:", spans["question"])
    print("Gen positions:", gen_positions)
    print(f"Collected steps: {len(attn_trace)} (each has layers {sorted(attn_trace.get(0, {}).keys())})")

    # 你可以在此把注意力落盘
    # torch.save({"attn_trace": attn_trace, "spans": spans, "gen_positions": gen_positions}, "attn_trace.pt")


    return final_text, attn_trace, spans, gen_positions, selected_layers

final_text, attn_trace, spans, gen_positions, selected_layers = run(args)


You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


✅ Registered hooks on layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] (total 32)
Spans: {'sys': (0, 34), 'image': (35, 610), 'question': (611, 622), 'base_len': 623}  (base_len=623)

=== Attn 

Step 0:
Layer 0:
Shape: torch.Size([1, 32, 623, 623])
Layer 1:
Shape: torch.Size([1, 32, 623, 623])
Layer 2:
Shape: torch.Size([1, 32, 623, 623])
Layer 3:
Shape: torch.Size([1, 32, 623, 623])
Layer 4:
Shape: torch.Size([1, 32, 623, 623])
Layer 5:
Shape: torch.Size([1, 32, 623, 623])
Layer 6:
Shape: torch.Size([1, 32, 623, 623])
Layer 7:
Shape: torch.Size([1, 32, 623, 623])
Layer 8:
Shape: torch.Size([1, 32, 623, 623])
Layer 9:
Shape: torch.Size([1, 32, 623, 623])
Layer 10:
Shape: torch.Size([1, 32, 623, 623])
Layer 11:
Shape: torch.Size([1, 32, 623, 623])
Layer 12:
Shape: torch.Size([1, 32, 623, 623])
Layer 13:
Shape: torch.Size([1, 32, 623, 623])
Layer 14:
Shape: torch.Size([1, 32, 623, 623])
Layer 15:
Shape: torch.S

In [None]:

final_text, spans, gen_positions, selected_layers

('The image features a cartoon character with a distinctive appearance, resembling a cross between a cat and a human. The character is wearing a black and white outfit, and it appears to be a young man with a beard. The character is sitting down, possibly on a chair, and looking at',
 {'sys': (0, 34), 'image': (35, 610), 'question': (611, 622), 'base_len': 623},
 (623, 686),
 [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31])

: 

In [11]:
## save png
import os
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

def visualize_full_attn_logscale(
    attn_trace,
    spans,
    selected_layers,
    save_dir="attn_vis_logscale",
    avg_heads=True,
    vmin_mode="1e-4",   # 可选："1e-4" 或 "1e-3"
    cmap="magma"
):
    """
    注意力可视化（Log 版）
    - 拼接完整注意力三角矩阵
    - LogNorm 色标，vmin 可选 1e-3 或 1e-4
    - 坐标轴标注语义区块（sys / image / question / gen）
    """

    os.makedirs(save_dir, exist_ok=True)
    total_steps = len(attn_trace)
    if total_steps == 0:
        print("⚠️ attn_trace 为空")
        return

    def find_first_with_layer(layer_idx):
        for s in range(total_steps):
            if layer_idx in attn_trace.get(s, {}):
                return s
        return None

    # === 分区标签 ===
    sections = []
    for name in ["sys", "image", "question"]:
        if name in spans:
            start, end = spans[name]
            sections.append((name, start, end))
    if "base_len" in spans:
        sections.append(("gen", spans["base_len"], None))  # 生成部分

    vmin = 1e-4 if vmin_mode == "1e-4" else 1e-3
    print(f"🎨 使用 LogNorm 可视化, vmin={vmin}")

    for layer_idx in selected_layers:
        s0 = find_first_with_layer(layer_idx)
        if s0 is None:
            print(f"⚠️ Layer {layer_idx} 无注意力数据，跳过。")
            continue

        attn0 = attn_trace[s0][layer_idx][0]
        attn0_mean = attn0.mean(0) if avg_heads else attn0[0]
        q0, k0 = attn0_mean.shape
        gen_steps = total_steps - (s0 + 1)
        total_len = k0 + gen_steps

        full_mat = torch.zeros((total_len, total_len), dtype=attn0_mean.dtype)
        q_init = min(q0, k0)
        full_mat[:q_init, :k0] = attn0_mean[:q_init, :k0]

        row_cursor = k0
        for s in range(s0 + 1, total_steps):
            if layer_idx not in attn_trace[s]:
                continue
            attn_s = attn_trace[s][layer_idx][0]
            attn_s_mean = attn_s.mean(0) if avg_heads else attn_s[0]
            last_row = attn_s_mean[-1]
            k_s = last_row.shape[0]
            full_mat[row_cursor, :k_s] = last_row
            row_cursor += 1
            if row_cursor >= total_len:
                break

        fm = full_mat.numpy()
        fm /= (fm.max() + 1e-8)

        # ---- Log 范围动态调节 ----
        vmax = np.percentile(fm, 99.5)
        norm = LogNorm(vmin=vmin, vmax=max(vmax, vmin * 1e3))

        # ---- 绘图 ----
        fig, ax = plt.subplots(figsize=(7, 7))
        sns.heatmap(fm, cmap=cmap, norm=norm, ax=ax, square=True, cbar=True)
        ax.set_title(f"Layer {layer_idx} | LogNorm(vmin={vmin})", fontsize=11)
        ax.set_xlabel("Key tokens")
        ax.set_ylabel("Query tokens")

        # === 添加语义刻度 ===
        xticks, xticklabels = [], []
        yticks, yticklabels = [], []

        for name, start, end in sections:
            if end is None:
                end = total_len - 1
            mid = (start + end) // 2
            label = f"{name} [{start}-{end}]"
            xticks.append(mid)
            yticks.append(mid)
            xticklabels.append(label)
            yticklabels.append(label)

            # 区间边界线
            ax.axvline(start, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
            ax.axvline(end, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
            ax.axhline(start, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
            ax.axhline(end, color="white", linestyle="--", linewidth=0.7, alpha=0.7)

        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, rotation=45, ha="right", fontsize=7)
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels, rotation=0, fontsize=7)

        out_path = os.path.join(save_dir, f"layer{layer_idx}_log_vmin{vmin_mode}.png")
        plt.tight_layout()
        plt.savefig(out_path, bbox_inches="tight", dpi=300)
        plt.close(fig)
        print(f"✅ 保存注意力矩阵: {out_path}")


# 示例调用
visualize_full_attn_logscale(
    attn_trace=attn_trace,
    spans=spans,
    selected_layers=selected_layers,
    save_dir="kunkun_png",
    avg_heads=True,
    vmin_mode="1e-4",   # 或 "1e-3"
)


🎨 使用 LogNorm 可视化, vmin=0.0001
✅ 保存注意力矩阵: kunkun_png/layer0_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer1_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer2_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer3_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer4_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer5_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer6_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer7_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer8_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer9_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer10_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer11_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer12_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer13_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer14_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer15_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer16_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer17_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer18_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer19_log_vmin1e-4.png
✅ 保存注意力矩阵: kunkun_png/layer20_log_vmin1e

In [12]:
## save every step per layer as mp4
import os
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import imageio.v2 as imageio

def visualize_full_attn_logscale_animation(
    attn_trace,
    spans,
    selected_layers,
    save_dir="attn_vis_logscale_anim",
    avg_heads=True,
    vmin_mode="1e-4",   # 可选："1e-4" 或 "1e-3"
    cmap="magma",
    fps=4,
    as_mp4=True,
):
    """
    动态注意力可视化（保留原版布局）
    - 每步拼接注意力行生成一帧
    - 与原版 LogNorm 一致的布局 / 标签 / 分区
    - 输出 MP4 或 GIF
    """

    os.makedirs(save_dir, exist_ok=True)
    total_steps = len(attn_trace)
    if total_steps == 0:
        print("⚠️ attn_trace 为空")
        return

    def find_first_with_layer(layer_idx):
        for s in range(total_steps):
            if layer_idx in attn_trace.get(s, {}):
                return s
        return None

    sections = []
    for name in ["sys", "image", "question"]:
        if name in spans:
            start, end = spans[name]
            sections.append((name, start, end))
    if "base_len" in spans:
        sections.append(("gen", spans["base_len"], None))

    vmin = 1e-4 if vmin_mode == "1e-4" else 1e-3
    print(f"🎬 使用 LogNorm 可视化, vmin={vmin}")

    for layer_idx in selected_layers:
        s0 = find_first_with_layer(layer_idx)
        if s0 is None:
            print(f"⚠️ Layer {layer_idx} 无注意力数据，跳过。")
            continue

        attn0 = attn_trace[s0][layer_idx][0]
        attn0_mean = attn0.mean(0) if avg_heads else attn0[0]
        q0, k0 = attn0_mean.shape
        gen_steps = total_steps - (s0 + 1)
        total_len = k0 + gen_steps

        full_mat = torch.zeros((total_len, total_len), dtype=attn0_mean.dtype)
        q_init = min(q0, k0)
        full_mat[:q_init, :k0] = attn0_mean[:q_init, :k0]

        frames = []

        # === 动态帧生成 ===
        row_cursor = k0
        for s in range(s0 + 1, total_steps + 1):
            if s < total_steps and layer_idx in attn_trace[s]:
                attn_s = attn_trace[s][layer_idx][0]
                attn_s_mean = attn_s.mean(0) if avg_heads else attn_s[0]
                last_row = attn_s_mean[-1]
                k_s = last_row.shape[0]
                full_mat[row_cursor, :k_s] = last_row
                row_cursor += 1
                if row_cursor >= total_len:
                    break

            # === 每一步都画一帧 ===
            fm = full_mat.numpy()
            fm /= (fm.max() + 1e-8)
            vmax = np.percentile(fm, 99.5)
            norm = LogNorm(vmin=vmin, vmax=max(vmax, vmin * 1e3))

            fig, ax = plt.subplots(figsize=(7, 7))
            sns.heatmap(fm, cmap=cmap, norm=norm, ax=ax, square=True, cbar=True)
            ax.set_title(f"Layer {layer_idx} | Step {s} | LogNorm(vmin={vmin})", fontsize=11)
            ax.set_xlabel("Key tokens")
            ax.set_ylabel("Query tokens")

            # === 添加语义刻度 ===
            xticks, xticklabels = [], []
            yticks, yticklabels = [], []
            for name, start, end in sections:
                if end is None:
                    end = total_len - 1
                mid = (start + end) // 2
                label = f"{name} [{start}-{end}]"
                xticks.append(mid)
                yticks.append(mid)
                xticklabels.append(label)
                yticklabels.append(label)
                ax.axvline(start, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
                ax.axvline(end, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
                ax.axhline(start, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
                ax.axhline(end, color="white", linestyle="--", linewidth=0.7, alpha=0.7)

            ax.set_xticks(xticks)
            ax.set_xticklabels(xticklabels, rotation=45, ha="right", fontsize=7)
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels, rotation=0, fontsize=7)

            plt.tight_layout()
            frame_path = os.path.join(save_dir, f"tmp_layer{layer_idx}_step{s}.png")
            plt.savefig(frame_path, bbox_inches="tight", dpi=200)
            plt.close(fig)
            frames.append(frame_path)

        # === 合成为动画 ===
        if not frames:
            print(f"⚠️ Layer {layer_idx} 没有帧数据")
            continue

        gif_path = os.path.join(save_dir, f"layer{layer_idx}_log_vmin{vmin_mode}.{ 'mp4' if as_mp4 else 'gif' }")
        imgs = [imageio.imread(f) for f in frames]
        if as_mp4:
            imageio.mimwrite(gif_path, imgs, fps=fps, codec='libx264', quality=8)
        else:
            imageio.mimsave(gif_path, imgs, fps=fps)

        print(f"✅ 动态注意力动画保存: {gif_path}  ({len(frames)} 帧)")

        # 清理临时帧
        for f in frames:
            os.remove(f)
visualize_full_attn_logscale_animation(
    attn_trace=attn_trace,
    spans=spans,
    selected_layers=[0, 15, 31],
    save_dir="kunkun_step_wise",
    avg_heads=True,
    vmin_mode="1e-4",
    fps=3,
    as_mp4=True
)



🎬 使用 LogNorm 可视化, vmin=0.0001


[rawvideo @ 0x3820b240] Stream #0: not enough frames to estimate rate; consider increasing probesize


✅ 动态注意力动画保存: kunkun_step_wise/layer0_log_vmin1e-4.mp4  (62 帧)


[rawvideo @ 0x15774240] Stream #0: not enough frames to estimate rate; consider increasing probesize


✅ 动态注意力动画保存: kunkun_step_wise/layer15_log_vmin1e-4.mp4  (62 帧)


[rawvideo @ 0x2aaf8240] Stream #0: not enough frames to estimate rate; consider increasing probesize


✅ 动态注意力动画保存: kunkun_step_wise/layer31_log_vmin1e-4.mp4  (62 帧)


In [13]:
## save every layer per step as mp4
import os
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import imageio.v2 as imageio
from PIL import Image

def _clamp(x, lo, hi):
    return max(lo, min(hi, x))

def _find_first_step_with_layer(attn_trace, layer_idx):
    for s in range(len(attn_trace)):
        if layer_idx in attn_trace.get(s, {}):
            return s
    return None

def _build_full_matrix_upto_step(attn_trace, layer_idx, upto_step, avg_heads=True):
    """
    把某层在 step<=upto_step 的注意力行都“攒起来”，重建一个当前时刻的完整方阵：
    - 取首次出现该层的 step=s0 的 (q0 x k0)；(通常 q0≈k0=base_len，或 q0<=k0)
    - s>s0 时每步只有最后一行(q_len=1,k_len增长)，把每一步的最后一行依次拼到下三角。
    - 返回 full_mat (total_len x total_len)，以及 total_len, k0。
    """
    s0 = _find_first_step_with_layer(attn_trace, layer_idx)
    if s0 is None or s0 > upto_step:
        return None, None, None  # 本层在 upto_step 之前都没有数据

    # 第一次有该层注意力的矩阵（通常是完整多行）：
    attn0 = attn_trace[s0][layer_idx][0]            # (1, H, Q0, K0)
    attn0_mean = attn0.mean(0) if avg_heads else attn0[0]  # (Q0, K0)
    q0, k0 = attn0_mean.shape

    # 到当前 step 的生成步数：
    gen_steps = max(0, upto_step - s0)   # 包含 s0+1,...,upto_step 这些“单行”
    total_len = k0 + gen_steps           # 当前 keys 的长度（也是我们要构造的方阵边长）

    # 初始化方阵
    full_mat = torch.zeros((total_len, total_len), dtype=attn0_mean.dtype)

    # 先放下第 0 步已有的“多行块”
    q_init = min(q0, k0)                 # 只放到方阵可见范围
    full_mat[:q_init, :k0] = attn0_mean[:q_init, :k0]

    # 逐步补齐 s0+1 .. upto_step 的“最后一行”
    row_cursor = k0
    for s in range(s0 + 1, upto_step + 1):
        if layer_idx not in attn_trace[s]:
            # 某些步可能没存到（例如你只抓了指定 layer 的步），跳过
            continue
        attn_s = attn_trace[s][layer_idx][0]               # (1, H, 1, Ks)
        attn_s_mean = attn_s.mean(0) if avg_heads else attn_s[0]  # (1, Ks)
        last_row = attn_s_mean[-1]                         # (Ks,)
        Ks = last_row.shape[0]
        # 写入当前行（注意：列只能写到 Ks）
        if row_cursor < total_len:
            full_mat[row_cursor, :Ks] = last_row
            row_cursor += 1
        else:
            break

    return full_mat, total_len, k0

def visualize_layerwise_per_step_multi_fixed(
    attn_trace,
    spans,
    selected_layers,
    selected_steps,            # list[int] 或 int
    save_dir="attn_layerwise_anim",
    avg_heads=True,
    vmin_mode="1e-4",          # "1e-4" / "1e-3"
    cmap="magma",
    fps=3,
    as_mp4=True,
    highlight_gen=True,
    figsize=(7,7),
    dpi=300
):
    """
    按“固定 step、layer 为时间轴”输出动画。
    与此前不同：先把该 step 之前的所有注意力“聚合成方阵”，再绘图。
    这样不会出现单行图像导致“整张空白”的问题。
    """
    os.makedirs(save_dir, exist_ok=True)
    if isinstance(selected_steps, int):
        selected_steps = [selected_steps]

    vmin = 1e-4 if vmin_mode == "1e-4" else 1e-3

    for step in selected_steps:
        if step not in attn_trace:
            print(f"⚠️ step={step} 不存在（max step={len(attn_trace)-1}），跳过。")
            continue

        # 过滤出该 step 你关心且确实存在的层
        layers = [L for L in selected_layers if L in attn_trace[step]]
        if not layers:
            print(f"⚠️ step={step} 无匹配层数据，跳过。")
            continue

        # 先为这个 step 里的所有层，分别构建出“完整方阵”，并顺便统计全局 vmax（保证同一步内颜色范围一致）
        mats = {}
        vals_for_vmax = []
        for Lidx in layers:
            full_mat, total_len, k0 = _build_full_matrix_upto_step(attn_trace, Lidx, step, avg_heads=avg_heads)
            if full_mat is None:
                print(f"  └─ Layer {Lidx}: 在 step<={step} 无数据，跳过。")
                continue
            fm = full_mat.detach().cpu().numpy()
            # 归一化后再汇总
            fm = fm / (fm.max() + 1e-8)
            mats[Lidx] = (fm, total_len, k0)
            vals_for_vmax.append(fm.flatten())

        if not mats:
            print(f"⚠️ step={step} 所有指定层都无可用矩阵。")
            continue

        all_vals = np.concatenate(vals_for_vmax)
        vmax = np.percentile(all_vals, 99.5)
        vmax = max(vmax, vmin * 1e3)  # 颜色范围至少拉开 3 个数量级

        # —— 逐层绘图，生成帧 —— 
        frame_paths = []
        for Lidx in layers:
            fm, total_len, k0 = mats[Lidx]
            norm = LogNorm(vmin=vmin, vmax=vmax)

            fig, ax = plt.subplots(figsize=figsize)
            sns.heatmap(fm, cmap=cmap, norm=norm, ax=ax, square=True, cbar=True)

            ax.set_title(f"Step {step} | Layer {Lidx} | LogNorm(vmin={vmin})", fontsize=11)
            ax.set_xlabel("Key tokens")
            ax.set_ylabel("Query tokens")

            # —— 分区线：对“当前 total_len”裁剪 —— 
            sections = []
            for name in ["sys", "image", "question"]:
                if name in spans:
                    s, e = spans[name]
                    s = _clamp(s, 0, total_len-1)
                    e = _clamp(e, 0, total_len-1)
                    if s <= e:
                        sections.append((name, s, e))
            # gen = [base_len, total_len-1]
            if "base_len" in spans:
                s = _clamp(spans["base_len"], 0, total_len-1)
                e = total_len-1
                if s <= e:
                    sections.append(("gen", s, e))

            # 语义刻度
            xticks, xticklabels, yticks, yticklabels = [], [], [], []
            for name, s, e in sections:
                mid = (s + e) // 2
                lab = f"{name} [{s}-{e}]"
                xticks.append(mid); yticks.append(mid)
                xticklabels.append(lab); yticklabels.append(lab)

                ax.axvline(s, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
                ax.axvline(e, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
                ax.axhline(s, color="white", linestyle="--", linewidth=0.7, alpha=0.7)
                ax.axhline(e, color="white", linestyle="--", linewidth=0.7, alpha=0.7)

            ax.set_xticks(xticks)
            ax.set_xticklabels(xticklabels, rotation=45, ha="right", fontsize=7)
            ax.set_yticks(yticks)
            ax.set_yticklabels(yticklabels, rotation=0, fontsize=7)

            # 高亮当前生成 token（裁剪）
            if highlight_gen and "base_len" in spans:
                gen_pos = _clamp(spans["base_len"] + step, 0, total_len-1)
                ax.axhline(gen_pos, color='cyan', linewidth=1.0, alpha=0.85)
                ax.axvline(gen_pos, color='cyan', linewidth=1.0, alpha=0.85)

            # 固定边距，避免像素抖动
            fig.subplots_adjust(left=0.12, right=0.98, bottom=0.18, top=0.90)

            frame_path = os.path.join(save_dir, f"tmp_step{step}_layer{Lidx}.png")
            fig.savefig(frame_path, dpi=dpi)
            plt.close(fig)
            frame_paths.append(frame_path)

        # —— 合成：强制统一分辨率（兜底） —— 
        from PIL import Image
        imgs = [Image.open(p).convert("RGB") for p in frame_paths]
        W = max(im.width for im in imgs)
        H = max(im.height for im in imgs)
        imgs_resized = [im.resize((W, H), Image.Resampling.LANCZOS) for im in imgs]

        out_path = os.path.join(save_dir, f"step{step}_layersweep.{'mp4' if as_mp4 else 'gif'}")
        if as_mp4:
            imageio.mimwrite(out_path, imgs_resized, fps=3, codec="libx264", quality=8)
        else:
            imageio.mimsave(out_path, imgs_resized, fps=3)

        print(f"✅ 保存动画：{out_path}（{len(imgs_resized)} 帧）")

        for p in frame_paths:
            try:
                os.remove(p)
            except:
                pass
            
visualize_layerwise_per_step_multi_fixed(
    attn_trace=attn_trace,
    spans=spans,                         # 你构建的 {sys:(s,e), image:(s,e), question:(s,e), base_len:int}
    selected_layers=selected_layers,  # 只看这些层
    selected_steps=[0, 30, 60],          # 这些步各出一段视频
    save_dir="kunkun_layer_wise",
    avg_heads=True,
    vmin_mode="1e-4",
    as_mp4=True
)



[rawvideo @ 0x690d240] Stream #0: not enough frames to estimate rate; consider increasing probesize


✅ 保存动画：kunkun_layer_wise/step0_layersweep.mp4（32 帧）


[rawvideo @ 0x39901240] Stream #0: not enough frames to estimate rate; consider increasing probesize


✅ 保存动画：kunkun_layer_wise/step30_layersweep.mp4（32 帧）


[rawvideo @ 0x8337240] Stream #0: not enough frames to estimate rate; consider increasing probesize


✅ 保存动画：kunkun_layer_wise/step60_layersweep.mp4（32 帧）
