
# 视野预测可解释性工具（融合层 / 骨干 / Patch-PCA）— 中文版

本 Notebook 统一在顶部集中**配置**，并提供三类可视化：  
1. **融合层可解释性**：`gated-sum` 的门控 `g`、`attn` 的 2x2 注意力；  
2. **骨干（ViT）可解释性**：Attention Rollout 与 CLS→patch 显著图；  
3. **Patch 级 PCA**：对 rnflt 与 slab 使用**共享 PCA 空间**，用于“模态互补”论证。

> 说明：请将 `配置区` 中的路径改为你本机环境；模型与数据加载逻辑按你的训练脚本对齐。


## 0. 配置区（请按需修改）

In [1]:

# ====== 全局配置（集中写在这里）======
import os, math, warnings, types, json, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA


from utils import get_vit_transform, get_cnn_transform, get_albumentations_transform, get_imagenet_transform
from train import get_dataloader_Salsa, SalsaHGFAlignedDataset
from dino_model import DualDinoV3LateFusion52  

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device selection: use CUDA when torch reports it as available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Paths and hyper-parameters (keep these in sync with the training script).
# Every later cell reads its settings from this single dict.
CFG = dict(
    
    # model
    backbone      = "dinov3",                                 # only dinov3 is supported
    dinov3_model  = "/mnt/sda/sijiali/GlaucomaCode/pretrained_weight/dinov3-vitb16-pretrain-lvd1689m",
    fusion        = "concat",                                  # "concat" / "gated-sum" / "sum" / "attn"
    vit_pool      = "cls",                                     # "cls" or "mean_patch"
    train_scope   = "all",                                    # "head" or "all" (only controls freezing after load; does not affect the visualizations)
    ckpt_path     = "/mnt/sda/sijiali/GlaucomaCode/Results_SALSA_multi/dinov3/dinov3_imagenet_all_cls_concat_lr5e-5/ckpts/best_model_epoch_302_rmse_1.0466.pth",   # best checkpoint to visualize
    
    # data
    transform     = "imagenet",                                # "vit"/"cnn"/"albumentations"/"none"/"imagenet"
    modality_type = "rnflt+slab",                              # "rnflt" or "rnflt+slab"
    img_size      = 224,                                       # input size
    data_root     = "/mnt/sda/sijiali/DataSet/harvardGF_unpacked",
    hgf_test_root = "/mnt/sda/sijiali/DataSet/Harvard-GF/Dataset/Test",
    
    # visualization
    batch_size    = 1,                                         # 1 is enough for visualization
    num_workers   = 2,                                         # forwarded as the DataLoader num_workers argument
    n_components  = 3,                                         # number of PCA components
    save_root     = "./Results_visualization/SALSA_multi/_vis_exports_cn"      # output directory
)

# Build the same preprocessing transform that was used for training.
# `mean` / `std` are the channel statistics handed to the "cnn" transform.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# One lazy factory per supported CFG["transform"] value; only the selected
# factory is ever invoked.
_transform_factories = {
    "vit": lambda: get_vit_transform(CFG["img_size"]),
    "cnn": lambda: get_cnn_transform(CFG["img_size"], mean, std),
    "albumentations": lambda: get_albumentations_transform(CFG["img_size"]),
    "none": lambda: None,
    "imagenet": lambda: get_imagenet_transform(
        CFG["img_size"], with_slab=(CFG["modality_type"] == 'rnflt+slab')
    ),
}
if CFG["transform"] not in _transform_factories:
    raise ValueError("未知的 transform: %s" % CFG["transform"])
base_transform = _transform_factories[CFG["transform"]]()

# Make sure the export directory exists before any cell writes into it.
os.makedirs(CFG["save_root"], exist_ok=True)
print("配置完成。")

Device: cuda
配置完成。


## 1. 通用工具函数

In [3]:

def imshow_gray(img2d, title=None):
    """Show a 2-D array (H x W) as a grayscale image with the axes hidden.

    The title is drawn only when ``title`` is truthy.
    """
    _, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(img2d, cmap="gray")
    if title:
        ax.set_title(title)
    ax.axis("off")
    plt.show()

def save_pca_maps(pil_image: Image.Image, pca_img: np.ndarray, save_dir: str,
                  save_prefix: str, last_components_rgb: bool = True, resize=True):
    """Write PCA heat maps to ``save_dir`` as PNG files.

    Files written (all prefixed with ``save_prefix``):
      * ``_orig_img.png``  - ``pil_image`` itself;
      * ``_{i}.png``       - component ``i`` as a grayscale image, min-max
        normalized over that single map;
      * ``_{C-3}_{C-2}_{C-1}_rgb.png`` - the last three components stacked as
        RGB, each channel min-max normalized independently (only when
        ``last_components_rgb`` is true and there are at least 3 components).

    Args:
        pil_image: reference image; its size is the resize target.
        pca_img: array of shape ``[Ht, Wt, C]``.
        save_dir: output directory, created if missing.
        save_prefix: file-name prefix (typically the sample id).
        last_components_rgb: whether to export the RGB composite.
        resize: when true, each map is resized to ``pil_image.size`` with
            nearest-neighbour resampling and then sharpened.
    """
    os.makedirs(save_dir, exist_ok=True)
    pil_image.save(os.path.join(save_dir, f"{save_prefix}_orig_img.png"))

    def _export(unit_map, filename):
        # unit_map holds values in [0, 1]; quantize to 8 bit before saving.
        im = Image.fromarray((unit_map * 255).astype(np.uint8))
        if resize:
            im = im.resize(pil_image.size, resample=Image.NEAREST).filter(ImageFilter.SHARPEN)
        im.save(os.path.join(save_dir, filename))

    # Unpacking three values keeps the original requirement of a 3-D input.
    _, _, C = pca_img.shape
    for i in range(C):
        comp = pca_img[:, :, i]
        # Per-map min-max normalization; the epsilon guards constant maps.
        comp = (comp - comp.min()) / (comp.max() - comp.min() + 1e-8)
        _export(comp, f"{save_prefix}_{i}.png")

    if last_components_rgb and C >= 3:
        comp = pca_img[:, :, -3:]
        lo = comp.min(axis=(0, 1), keepdims=True)
        # np.ptp() instead of the ndarray.ptp() method, which NumPy 2.0 removed.
        span = np.ptp(comp, axis=(0, 1), keepdims=True)
        _export((comp - lo) / (span + 1e-8), f"{save_prefix}_{C-3}_{C-2}_{C-1}_rgb.png")


## 2. 融合层可解释性（gated-sum / attn）

In [4]:

@torch.no_grad()
def get_branch_features(model, rnflt, slab):
    """Run both backbones and return their outputs as ``(fr, fs)``.

    Each input is moved to the global ``device`` before being passed to
    ``model.backbone_r`` / ``model.backbone_s`` respectively.
    NOTE(review): the original note states the outputs are [B, Dr] and
    [B, Ds] global features - confirm against the model definition.
    """
    rnflt_on_device = rnflt.to(device)
    fr = model.backbone_r(rnflt_on_device)
    slab_on_device = slab.to(device)
    fs = model.backbone_s(slab_on_device)
    return fr, fs

@torch.no_grad()
def fusion_introspect(model, fr, fs):
    """Recompute the key quantities of the fusion head outside the model.

    The branch taken depends on ``model.head.fusion``:
      * ``"gated-sum"``: returns the gate ``g``, the gated mixture ``fused``
        and both projected inputs ``fr_p`` / ``fs_p``;
      * ``"attn"``: returns the 2x2 cross-modality attention ``attn_2x2`` and
        the token-averaged result ``fused``;
      * ``"concat"`` / ``"sum"``: returns the pre-merge representation
        ``pre_merge`` (raw concatenation, or sum of the projected features).
    Any other value yields an empty dict. All values are NumPy arrays.

    ``head.proj_r`` / ``head.proj_s`` are applied when the head has them;
    otherwise the features pass through unchanged.
    NOTE(review): the original comment says ``head.gate`` already applies a
    sigmoid - confirm against the head implementation.
    """
    def _as_numpy(t):
        return t.detach().cpu().numpy()

    head = model.head
    mode = getattr(head, "fusion", "")

    # Align feature dimensions (identity when the head has no projection).
    to_common_r = getattr(head, "proj_r", nn.Identity())
    to_common_s = getattr(head, "proj_s", nn.Identity())
    fr_p = to_common_r(fr)
    fs_p = to_common_s(fs)

    if mode == "gated-sum":
        # The gate reads the raw (unprojected) concatenation.
        gate = head.gate(torch.cat([fr, fs], dim=1))
        mixture = gate * fr_p + (1.0 - gate) * fs_p
        return {
            "g": _as_numpy(gate),
            "fused": _as_numpy(mixture),
            "fr_p": _as_numpy(fr_p),
            "fs_p": _as_numpy(fs_p),
        }

    if mode == "attn":
        # Two tokens per sample: index 0 = rnflt, index 1 = slab.
        tokens = torch.stack([fr_p, fs_p], dim=1)
        Q = head.q(tokens)
        K = head.k(tokens)
        V = head.v(tokens)
        logits = (Q @ K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        weights = torch.softmax(logits, dim=-1)
        pooled = (weights @ V).mean(dim=1)
        return {"attn_2x2": _as_numpy(weights), "fused": _as_numpy(pooled)}

    if mode in ("concat", "sum"):
        if mode == "concat":
            merged = torch.cat([fr, fs], dim=1)
        else:
            merged = fr_p + fs_p
        return {"pre_merge": _as_numpy(merged)}

    return {}

def plot_g_hist(g_values, title="门控 g 分布"):
    """Plot a 30-bin histogram of the gate values (flattened to 1-D)."""
    flat = np.ravel(np.asarray(g_values))
    _, ax = plt.subplots(figsize=(4, 3))
    ax.hist(flat, bins=30)
    ax.set_title(title)
    ax.set_xlabel("g")
    ax.set_ylabel("数量")
    plt.show()

def show_attn_matrix(attn_2x2, title="融合头 2x2 注意力矩阵"):
    """Render the fusion-head attention matrix of the first batch item.

    Only index 0 along the leading axis is shown; both axes are labelled
    rnflt / slab and the colour scale is fixed to [0, 1].
    """
    first = np.asarray(attn_2x2)[0]
    modality_names = ["rnflt", "slab"]
    tick_positions = [0, 1]

    fig, ax = plt.subplots(figsize=(3, 3))
    heat = ax.imshow(first, vmin=0, vmax=1)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(modality_names)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(modality_names)
    ax.set_title(title)
    fig.colorbar(heat, ax=ax)
    plt.show()


## 3. 骨干（ViT）可解释性：Attention Rollout / CLS→patch

In [5]:
# ===== 稳健零侵入适配器（兼容 sdpa / eager）=====
import math, torch
import torch.nn.functional as F
from contextlib import contextmanager

@contextmanager
def _enable_outputs(hf_model, attn=True, hidd=True):
    """
    Context manager that temporarily enables output_attentions /
    output_hidden_states on ``hf_model.config`` and tries to switch the
    attention implementation to "eager" while the block runs.

    Yields a dict ``{"attn_ok": bool}``. ``attn_ok`` is True only when
    ``attn`` was requested, a config exists and the switch to "eager"
    succeeded; in that case ``cfg.output_attentions`` has been set to True.
    When the switch fails, ``output_attentions`` is forced to False so the
    caller can fall back to hidden states instead of hitting an error.

    On exit the recorded attention implementation and both output flags are
    restored (best effort: failures while restoring the implementation are
    swallowed).
    """
    cfg = getattr(hf_model, "config", None)
    state = {"attn_ok": False}  # whether attentions were safely enabled

    # Remember the original flag values
    orig_attn = getattr(cfg, "output_attentions", False) if cfg else False
    orig_hidd = getattr(cfg, "output_hidden_states", False) if cfg else False
    # Remember the original attention implementation (may live on the model or on the config)
    orig_impl_model = getattr(hf_model, "_attn_implementation", None)
    orig_impl_cfg   = getattr(cfg, "_attn_implementation", None) if cfg else None
    orig_impl_cfg2  = getattr(cfg, "attn_implementation", None) if cfg else None

    try:
        # Enable hidden states first (per the original note this does not conflict with sdpa)
        if cfg and hidd:
            cfg.output_hidden_states = True

        # Try to enable attentions
        if attn and cfg:
            # Switch the implementation to eager first (prefer the model's own setter)
            switched = False
            if hasattr(hf_model, "set_attn_implementation"):
                try:
                    hf_model.set_attn_implementation("eager")
                    switched = True
                except Exception:
                    switched = False
            # Without a working setter, try editing the config directly (some models read the config)
            if not switched and cfg is not None:
                try:
                    if hasattr(cfg, "_attn_implementation"):
                        cfg._attn_implementation = "eager"
                        switched = True
                    if hasattr(cfg, "attn_implementation"):
                        cfg.attn_implementation = "eager"
                        switched = True
                except Exception:
                    switched = False

            # Only turn attentions on after a successful switch; otherwise keep them off so callers use the fallback
            if switched:
                cfg.output_attentions = True
                state["attn_ok"] = True
            else:
                # Make sure the sdpa restriction is not triggered
                if hasattr(cfg, "output_attentions"):
                    cfg.output_attentions = False
                state["attn_ok"] = False

        yield state

    finally:
        # Restore the attention implementation via the model setter
        try:
            if hasattr(hf_model, "set_attn_implementation") and orig_impl_model is not None:
                hf_model.set_attn_implementation(orig_impl_model)
        except Exception:
            pass
        # Restore the implementation fields stored on the config
        if cfg is not None:
            try:
                if orig_impl_cfg is not None and hasattr(cfg, "_attn_implementation"):
                    cfg._attn_implementation = orig_impl_cfg
                if orig_impl_cfg2 is not None and hasattr(cfg, "attn_implementation"):
                    cfg.attn_implementation  = orig_impl_cfg2
            except Exception:
                pass
            # Restore the output switches
            cfg.output_attentions     = orig_attn
            cfg.output_hidden_states  = orig_hidd


def _get_reg(backbone):
    """读取 DINOv3 的 register token 数（没有则为 0）。"""
    return int(getattr(backbone.net.config, "num_register_tokens", 0))

def _maybe_norm(backbone, x):
    """与训练 forward 对齐：仅在 apply_imagenet_norm=True 时做归一化。"""
    if getattr(backbone, "apply_imagenet_norm", False):
        return backbone._imagenet_norm(x)
    return x


In [6]:
# ===== 同名：ViT 多层注意力 Rollout（自动兼容 sdpa / 无注意力兜底）=====
from typing import Union

@torch.no_grad()
def vit_attention_rollout(
    backbone, x: torch.Tensor,
    layers: Union[str, list[int]] = "lastk:6",   # "all" / "lastk:6" / [layer indices...]
    head_fuse: str = "mean",                     # "mean" / "max"
    add_residual: bool = True,
    residual_alpha: float = 0.5,
    upsample_to_input: bool = True,
    return_grid: bool = False,
    eps: float = 1e-8
):
    """Multi-layer attention rollout map for a ViT backbone (CLS -> patches).

    The backbone is run with attentions enabled when possible. If they are
    not available, the map falls back to the L2 norm of the last layer's
    patch tokens.

    Args:
        backbone: wrapper exposing ``net`` (HF-style model).
        x: input batch; the last two dims are used as the upsample target.
        layers: ``"all"``, ``"lastk:K"`` (the last K layers, clamped to the
            number of layers) or an explicit list of layer indices. A list
            is expected in ascending depth order, since later entries are
            multiplied on the left.
        head_fuse: how heads are merged, ``"mean"`` or ``"max"``.
        add_residual: mix the identity into each layer's attention.
        residual_alpha: identity weight; the result is renormalized by
            ``1 + residual_alpha``.
        upsample_to_input: bilinearly resize the map to the input size.
        return_grid: also return the patch-grid map. On the attention path
            this is the normalized grid; on the fallback path it is the raw
            score grid.
        eps: numerical-stability constant.

    Returns:
        A ``[B, 1, H, W]`` map (min-max normalized per sample), or a
        ``(map, grid)`` tuple whenever ``return_grid`` is true.

    Raises:
        ValueError: unsupported ``layers`` / ``head_fuse`` value, or an empty
            layer selection.

    NOTE(review): attentions are assumed to be ``[B, heads, T, T]`` with the
    CLS token at index 0 followed by the register tokens - confirm for the
    model in use.
    """
    assert not getattr(backbone, "is_convnext", False), "此可视化仅适用于 ViT 分支（ConvNeXt 不支持）。"

    x = _maybe_norm(backbone, x)
    B, H_in, W_in = x.shape[0], x.shape[-2], x.shape[-1]

    def _square_grid(scores):
        # [B, N] -> [B, 1, S, S]; patches must form a square grid.
        n = scores.shape[-1]
        side = int(math.sqrt(n))
        assert side * side == n, f"N={n} 不是方形网格"
        return scores.reshape(B, 1, side, side)

    def _minmax(grid):
        lo = grid.amin((-2, -1), True)
        hi = grid.amax((-2, -1), True)
        return (grid - lo) / (hi - lo + eps)

    def _finish(norm_grid, grid_out):
        # Always honour return_grid, with or without upsampling.
        result = norm_grid
        if upsample_to_input:
            result = F.interpolate(norm_grid, size=(H_in, W_in), mode="bilinear", align_corners=False)
        return (result, grid_out) if return_grid else result

    with _enable_outputs(backbone.net, attn=True, hidd=True) as st:
        out = backbone.net(
            pixel_values=x,
            output_attentions=bool(st["attn_ok"]),
            output_hidden_states=True,
            return_dict=True
        )
        atts = out.attentions if st["attn_ok"] else None

    reg = _get_reg(backbone)

    if atts is None:
        # Fallback: L2 norm of the last layer's patch tokens.
        hs = out.hidden_states[-1]        # [B, T, D]
        raw = _square_grid(hs[:, 1 + reg:, :].norm(dim=-1))
        return _finish(_minmax(raw), raw)

    # Select the layers to roll out.
    n_layers = len(atts)
    if isinstance(layers, str):
        if layers == "all":
            use = range(n_layers)
        elif layers.startswith("lastk:"):
            k = int(layers.split(":")[1])
            # Clamp so K > n_layers does not wrap around to negative indices.
            use = range(max(0, n_layers - k), n_layers)
        else:
            raise ValueError("layers 仅支持 'all' / 'lastk:K' / [list]")
    else:
        use = layers

    A_roll = None
    for li in use:
        A = atts[li]                      # [B, heads, T, T]
        if head_fuse == "mean":
            A = A.mean(dim=1)             # [B, T, T]
        elif head_fuse == "max":
            A, _ = A.max(dim=1)
        else:
            raise ValueError("head_fuse 仅支持 mean/max")
        A = A / (A.sum(dim=-1, keepdim=True) + eps)
        if add_residual:
            I = torch.eye(A.shape[-1], device=A.device, dtype=A.dtype).unsqueeze(0)
            A = (A + residual_alpha * I) / (1.0 + residual_alpha)
        # Rollout composes layers as A_L @ ... @ A_1: the deeper layer goes on
        # the left so that row 0 describes the final CLS token.
        A_roll = A if A_roll is None else A @ A_roll

    if A_roll is None:
        raise ValueError("layers selects no attention layer")

    M = _minmax(_square_grid(A_roll[:, 0, 1 + reg:]))
    return _finish(M, M)


# ===== 同名：最后一层 CLS→patch 显著图（自动兼容 sdpa / 无注意力兜底）=====
@torch.no_grad()
def vit_cls_saliency(
    backbone, x: torch.Tensor,
    head_fuse: str = "mean",          # "mean" / "max"
    upsample_to_input: bool = True,
    return_grid: bool = False,
    eps: float = 1e-8
):
    """Last-layer CLS -> patch saliency map for a ViT backbone.

    Uses the last layer's attention row of the CLS token when attentions are
    available. Otherwise it falls back to the cosine similarity between the
    CLS token and each patch token of the last hidden state.

    Args:
        backbone: wrapper exposing ``net`` (HF-style model).
        x: input batch; the last two dims are used as the upsample target.
        head_fuse: how heads are merged, ``"mean"`` or ``"max"``.
        upsample_to_input: bilinearly resize the map to the input size.
        return_grid: also return the patch-grid map. On the attention path
            this is the normalized grid; on the fallback path it is the raw
            similarity grid.
        eps: numerical-stability constant.

    Returns:
        A ``[B, 1, H, W]`` map (min-max normalized per sample), or a
        ``(map, grid)`` tuple whenever ``return_grid`` is true.

    Raises:
        ValueError: unsupported ``head_fuse`` (attention path only).
    """
    assert not getattr(backbone, "is_convnext", False), "此可视化仅适用于 ViT 分支。"

    x = _maybe_norm(backbone, x)
    B, H_in, W_in = x.shape[0], x.shape[-2], x.shape[-1]

    def _square_grid(scores):
        # [B, N] -> [B, 1, S, S]; patches must form a square grid.
        n = scores.shape[-1]
        side = int(math.sqrt(n))
        assert side * side == n
        return scores.reshape(B, 1, side, side)

    def _minmax(grid):
        lo = grid.amin((-2, -1), True)
        hi = grid.amax((-2, -1), True)
        return (grid - lo) / (hi - lo + eps)

    def _finish(norm_grid, grid_out):
        # Always honour return_grid, with or without upsampling.
        result = norm_grid
        if upsample_to_input:
            result = F.interpolate(norm_grid, size=(H_in, W_in), mode="bilinear", align_corners=False)
        return (result, grid_out) if return_grid else result

    with _enable_outputs(backbone.net, attn=True, hidd=True) as st:
        out = backbone.net(
            pixel_values=x,
            output_attentions=bool(st["attn_ok"]),
            output_hidden_states=True,
            return_dict=True
        )
        have_attn = bool(st["attn_ok"]) and out.attentions is not None

    reg = _get_reg(backbone)

    if not have_attn:
        # Fallback: cosine similarity between CLS and every patch token.
        hs = out.hidden_states[-1]           # [B, T, D]
        cls = F.normalize(hs[:, 0:1, :], dim=-1)
        patch = F.normalize(hs[:, 1 + reg:, :], dim=-1)
        raw = _square_grid((patch * cls).sum(dim=-1))
        return _finish(_minmax(raw), raw)

    # Regular path: attention of the last layer.
    att = out.attentions[-1]                 # [B, heads, T, T]
    if head_fuse == "mean":
        A = att.mean(dim=1)
    elif head_fuse == "max":
        A, _ = att.max(dim=1)
    else:
        raise ValueError("head_fuse 仅支持 mean/max")
    A = A / (A.sum(dim=-1, keepdim=True) + eps)

    M = _minmax(_square_grid(A[:, 0, 1 + reg:]))
    return _finish(M, M)


## 4. Patch 级 PCA（共享 PCA 空间）

In [7]:

@torch.no_grad()
def extract_patch_tokens(backbone, x: torch.Tensor, layer_idx: int = -1):
    """Return the patch tokens of one hidden layer plus the patch-grid size.

    The leading CLS token and any register tokens (count read from
    ``backbone.net.config.num_register_tokens``, default 0) are dropped.
    The remaining token count must be a perfect square.

    Returns:
        ``(tokens, S, S)`` where ``tokens`` is ``[B, N, D]`` and ``N == S*S``.
    """
    wants_norm = getattr(backbone, "apply_imagenet_norm", False)
    pixels = backbone._imagenet_norm(x) if wants_norm else x

    outputs = backbone.net(pixel_values=pixels, output_hidden_states=True, return_dict=True)
    hidden = outputs.hidden_states[layer_idx]          # [B, T, D]

    # Skip CLS (1 token) and the register tokens that follow it.
    n_prefix = 1 + int(getattr(backbone.net.config, "num_register_tokens", 0))
    tokens = hidden[:, n_prefix:, :]

    _, n_patches, _ = tokens.shape
    side = int(math.sqrt(n_patches))
    assert side * side == n_patches, f"patch 网格不是方形：N={n_patches}"
    return tokens, side, side

def tokens_to_pca_maps(tokens_list, grids, n_components=3, pca_fit=None):
    """Project patch tokens into one shared PCA space and reshape them to maps.

    When ``pca_fit`` is None, a single PCA with ``n_components`` components
    is fitted on the tokens of all entries stacked together (the shared
    space). Otherwise the given, already fitted object is used as is and no
    fitting happens.

    Args:
        tokens_list: tensors whose last dim is the feature dim; each one is
            flattened to ``[N_i, D]``.
        grids: one ``(Ht, Wt)`` pair per entry with ``Ht * Wt == N_i``.
        n_components: number of components for a freshly fitted PCA.
        pca_fit: optional fitted object exposing ``transform``.

    Returns:
        ``(pca, maps)`` where each map is an array ``[Ht, Wt, C]`` and ``C``
        is the number of components produced by ``pca.transform``.
    """
    # Flatten once and reuse for both fitting and transforming.
    mats = [t.reshape(-1, t.shape[-1]).detach().cpu().numpy() for t in tokens_list]

    if pca_fit is None:
        pca = PCA(n_components=n_components).fit(np.concatenate(mats, axis=0))
    else:
        pca = pca_fit

    outs = []
    for mat, (Ht, Wt) in zip(mats, grids):
        Z = pca.transform(mat)                     # [N, C]
        # Use the actual component count: a supplied pca_fit may differ from
        # the n_components argument.
        outs.append(Z.reshape(Ht, Wt, Z.shape[-1]))
    return pca, outs


## 5. 数据与模型加载

In [8]:

# 5.1 Build the validation DataLoader (used for visualization only, no training)
try:
    val_loader = get_dataloader_Salsa(
        image_root=CFG["data_root"],
        hgf_test_root=CFG["hgf_test_root"],
        modality_type=CFG["modality_type"],
        batch_size=CFG["batch_size"],
        shuffle=False,
        num_workers=CFG["num_workers"],
        transform=base_transform,
        ids=None,   # per the original note: an ID subset may be given; None means the full set
    )
except Exception as e:
    # Any exception raised by the helper call ends up here (the printed
    # message mentions an import failure, but the import itself happens in
    # the first cell). The fallback instantiates the Dataset directly.
    print("get_dataloader_Salsa 导入失败，回退到直接实例化 Dataset：", e)
    ds = SalsaHGFAlignedDataset(
        image_root=CFG["data_root"],
        hgf_test_root=CFG["hgf_test_root"],
        modality_type=CFG["modality_type"],
        transform=base_transform
    )
    val_loader = DataLoader(ds, batch_size=CFG["batch_size"], shuffle=False, num_workers=CFG["num_workers"], pin_memory=True)

# NOTE(review): for a torch DataLoader, len() counts batches; it equals the
# sample count only because batch_size is 1 here.
print("验证集长度：", len(val_loader))

# 5.2 Build the model and load the weights
# Both branches use the same pretrained DINOv3 path and pooling mode.
model = DualDinoV3LateFusion52(
    rnflt_model_name=CFG["dinov3_model"],
    rnflt_vit_pool=CFG["vit_pool"],
    slab_model_name=CFG["dinov3_model"],
    slab_vit_pool=CFG["vit_pool"],
    fusion=CFG["fusion"],
    head_hidden_dim=512,
    head_dropout=0.1,
    out_dim=52,
).to(device)

def set_train_scope(m, scope="head"):
    """Set ``requires_grad`` on the model's parameters like the training script.

    ``"head"`` freezes everything except ``m.head``; ``"all"`` makes every
    parameter trainable. Any other value raises ``ValueError`` before a
    parameter is touched. This only affects freezing, not the visualizations.
    """
    if scope not in ("head", "all"):
        raise ValueError("未知的 train_scope")

    train_everything = scope == "all"
    for param in m.parameters():
        param.requires_grad_(train_everything)

    if not train_everything:
        # Re-enable just the head after the global freeze.
        for param in m.head.parameters():
            param.requires_grad_(True)

# Apply the freezing policy configured in CFG["train_scope"] to the model.
set_train_scope(model, CFG["train_scope"])

def load_ckpt_safely(m, path):
    """Load a checkpoint into ``m``, tolerating DataParallel key prefixes.

    Behaviour:
      * nothing is loaded (a hint is printed) when ``path`` is empty or does
        not exist;
      * a ``{"state_dict": ...}`` wrapper is unwrapped;
      * a leading ``"module."`` is stripped from every key;
      * weights are loaded with ``strict=False``. The numbers of missing and
        unexpected keys are printed, followed by up to 10 key names of each
        kind when there are any, so a silent mismatch cannot go unnoticed.
    """
    if not (path and os.path.exists(path)):
        print("未加载权重，请在 CFG['ckpt_path'] 中填写权重路径。")
        return

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    sd = torch.load(path, map_location="cpu")
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]

    new_sd = {k.removeprefix("module."): v for k, v in sd.items()}
    missing, unexpected = m.load_state_dict(new_sd, strict=False)

    print("已加载权重：", path)
    print("缺失键：", len(missing), " 额外键：", len(unexpected))
    # strict=False hides mismatches; list some offending keys explicitly.
    for label, keys in (("缺失键示例：", missing), ("额外键示例：", unexpected)):
        if keys:
            print(label, list(keys)[:10])

# Load the checkpoint (when the path exists) and switch the model to eval mode.
load_ckpt_safely(model, CFG["ckpt_path"])
model.eval()
print("模型就绪。")


验证集长度： 696


已加载权重： /mnt/sda/sijiali/GlaucomaCode/Results_SALSA_multi/dinov3/dinov3_imagenet_all_cls_concat_lr5e-5/ckpts/best_model_epoch_302_rmse_1.0466.pth
缺失键： 0  额外键： 0
模型就绪。


## 6. 端到端示例：对一个样本生成三类可视化

In [9]:
# Fetch a single batch
batch = next(iter(val_loader))
sid = batch.get("id", ["unknown"])[0]
print("样本 ID：", sid)

# Pull rnflt / slab out of the batch (keys must match the transform output):
# the rnflt image is stored under 'image', the optional slab under 'slab'.
x_r = batch.get("image")
x_s = batch.get("slab")

if x_s is None and CFG["modality_type"] == "rnflt+slab":
    raise RuntimeError("期望 batch 中存在 'slab' 键，请检查 transform / 数据管线。")

x_r = x_r.to(device).float()
x_s = x_s.to(device).float() if x_s is not None else None

# ========== 6.1 Fusion-layer interpretability ==========
# Inference only: run the backbones without autograd, because the parameters
# may have requires_grad=True (train_scope="all") and would otherwise build a
# graph that is never used.
with torch.no_grad():
    fr = model.backbone_r(x_r)
    fs = model.backbone_s(x_s) if x_s is not None else torch.zeros_like(fr)
info = fusion_introspect(model, fr, fs)
if "g" in info:
    print("g 示例：", info["g"].ravel()[:5])
    plot_g_hist(info["g"], title=f"门控 g（{sid}）")
if "attn_2x2" in info:
    show_attn_matrix(info["attn_2x2"], title=f"融合 2x2 注意力（{sid}）")

# ========== 6.2 Backbone (ViT) interpretability ==========
rollout_r = vit_attention_rollout(model.backbone_r, x_r)
imshow_gray(rollout_r[0,0].detach().cpu().numpy(), f"RNFLT rollout（{sid}）")
if x_s is not None:
    rollout_s = vit_attention_rollout(model.backbone_s, x_s)
    imshow_gray(rollout_s[0,0].detach().cpu().numpy(), f"SLAB rollout（{sid}）")

# ========== 6.3 Patch-level PCA (shared space) ==========
tok_r, Hr, Wr = extract_patch_tokens(model.backbone_r, x_r, layer_idx=-1)
tokens_list = [tok_r.squeeze(0)]
grids = [(Hr,Wr)]
if x_s is not None:
    tok_s, Hs, Ws = extract_patch_tokens(model.backbone_s, x_s, layer_idx=-1)
    tokens_list.append(tok_s.squeeze(0))
    grids.append((Hs,Ws))

# One PCA is fitted on the tokens of all modalities together.
pca, Zs = tokens_to_pca_maps(tokens_list=tokens_list, grids=grids, n_components=CFG["n_components"])

# Save to disk. A PIL image is rebuilt from the batch tensor (min-max scaled
# to 8 bit); it only serves as a preview and as the resize target.
os.makedirs(os.path.join(CFG["save_root"], "rnflt"), exist_ok=True)
rnflt_np = x_r[0].detach().cpu().permute(1,2,0).numpy()
rnflt_np = (rnflt_np - rnflt_np.min()) / (rnflt_np.max() - rnflt_np.min() + 1e-8)
rnflt_pil = Image.fromarray((rnflt_np*255).astype("uint8"))
save_pca_maps(rnflt_pil, Zs[0], os.path.join(CFG["save_root"], "rnflt"), save_prefix=str(sid))

if x_s is not None:
    os.makedirs(os.path.join(CFG["save_root"], "slab"), exist_ok=True)
    slab_np = x_s[0].detach().cpu().permute(1,2,0).numpy()
    slab_np = (slab_np - slab_np.min()) / (slab_np.max() - slab_np.min() + 1e-8)
    slab_pil = Image.fromarray((slab_np*255).astype("uint8"))
    save_pca_maps(slab_pil, Zs[1], os.path.join(CFG["save_root"], "slab"), save_prefix=str(sid))

print("PCA 热图已保存至：", CFG["save_root"])


ERROR:tornado.general:SEND Error: Host unreachable


KeyboardInterrupt: 

## 7. 整合可视化图像

In [None]:
import os
from pathlib import Path
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# =========================
# Configuration (edit as needed)
# =========================
rnflt_dir = Path("./Results_visualization/SALSA_multi/_vis_exports_cn/rnflt")  # directory with the rnflt PCA exports
slab_dir  = Path("./Results_visualization/SALSA_multi/_vis_exports_cn/slab")   # directory with the slab PCA exports
out_dir   = Path("./Results_visualization/SALSA_multi/_vis_exports_cn/panels") # output directory for the panels
out_dir.mkdir(parents=True, exist_ok=True)

# Overlay visualization parameters
overlay_alpha_gray = 0.45   # blend alpha for the comp0/1/2 overlays
overlay_alpha_rgb  = 0.65   # blend alpha for the RGB-composite overlay
rgb_overlay_as_gray = False # True: convert the RGB composite to gray before overlaying; False: keep the colour overlay



In [None]:

# =========================
# 工具函数
# =========================
def _load_png(path, mode=None):
    """Open an image file; convert it to ``mode`` when one is given."""
    opened = Image.open(path)
    return opened if mode is None else opened.convert(mode)

def _paths_for_id_with_rgb(mod_dir, sid):
    """返回：原图、comp0/1/2、rgb合成 的路径"""
    p_orig = mod_dir / f"{sid}_orig_img.png"
    p0 = mod_dir / f"{sid}_0.png"
    p1 = mod_dir / f"{sid}_1.png"
    p2 = mod_dir / f"{sid}_2.png"
    prgb = mod_dir / f"{sid}_0_1_2_rgb.png"
    return p_orig, [p0, p1, p2], prgb

def _collect_ids(rnflt_dir, slab_dir):
    ids_r = {p.name.replace("_orig_img.png", "") for p in rnflt_dir.glob("*_orig_img.png")}
    ids_s = {p.name.replace("_orig_img.png", "") for p in slab_dir.glob("*_orig_img.png")}
    return sorted(list(ids_r & ids_s))

def _resize_like(img, ref, resample=Image.NEAREST):
    """Return ``img`` resized to ``ref.size``; unchanged when sizes already match."""
    target_size = ref.size
    return img if img.size == target_size else img.resize(target_size, resample=resample)

def _overlay_colormap(base_pil, comp_gray_pil, alpha=0.5, cmap="jet"):
    """Blend a pseudo-coloured PCA component onto a base image.

    The component is converted to 8-bit gray, resized to the base image,
    min-max normalized to [0, 1] and mapped through ``cmap``. The result is
    alpha-blended over ``base_pil`` (``alpha`` is the overlay weight).

    Args:
        base_pil: background image (converted to RGB for blending).
        comp_gray_pil: PCA component image.
        alpha: blend factor passed to ``Image.blend``.
        cmap: Matplotlib colormap name (or colormap object).
    """
    comp_gray_pil = _resize_like(comp_gray_pil.convert("L"), base_pil)
    g = np.array(comp_gray_pil, dtype=np.float32)
    g = (g - g.min()) / (g.max() - g.min() + 1e-8)  # normalize to 0-1
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9.
    rgba = plt.get_cmap(cmap)(g)  # (H, W, 4)
    rgb = (rgba[..., :3] * 255).astype(np.uint8)
    # The mode is inferred from the uint8 (H, W, 3) array, so no explicit
    # mode argument is needed.
    comp_color = Image.fromarray(rgb)
    return Image.blend(base_pil.convert("RGB"), comp_color, alpha=alpha)

def _overlay_rgb(orig_rgb_pil, rgb_pil, alpha=0.45, as_gray=False):
    """Blend the RGB composite over the original image.

    With ``as_gray=True`` the composite is first reduced to gray (replicated
    on three channels) before blending.
    """
    overlay = _resize_like(rgb_pil.convert("RGB"), orig_rgb_pil)
    if as_gray:
        luminance = overlay.convert("L")
        overlay = Image.merge("RGB", (luminance,) * 3)
    background = orig_rgb_pil.convert("RGB")
    return Image.blend(background, overlay, alpha=alpha)

def _to_gray_rgb(img_rgb_pil: Image.Image) -> Image.Image:
    """Return a 3-channel gray version of the image (used as an overlay base)."""
    luminance = img_rgb_pil.convert("L")
    return Image.merge("RGB", (luminance,) * 3)

def make_panels_for_id(
    sid, rnflt_dir, slab_dir, out_dir,
    overlay_alpha_gray=0.45, overlay_alpha_rgb=0.45,
    rgb_overlay_as_gray=False,        # True: RGB map -> gray before overlay
    base_gray_overlays_rnflt=True,    # True: RNFLT overlay uses gray base
    base_gray_overlays_slab=True,     # True: SLAB  overlay uses gray base
    dpi=200,
    cmap_comp="jet"                   # colormap for PC overlays
):
    """Compose and save two 2x5 summary panels for one sample id.

    Panel A shows, per modality (row 0 = RNFLT, row 1 = SLAB): the image,
    PC0, PC1, PC2 as gray maps and the RGB composite. Panel B shows the image
    followed by the colour-mapped PC overlays and the RGB overlay.

    If any required input file is missing, a warning naming the first
    missing file is printed and nothing is written. Otherwise both PNGs are
    saved into ``out_dir`` and a confirmation line is printed.
    """
    # ---------- locate inputs ----------
    r_orig_p, r_comp_ps, r_rgb_p = _paths_for_id_with_rgb(rnflt_dir, sid)
    s_orig_p, s_comp_ps, s_rgb_p = _paths_for_id_with_rgb(slab_dir, sid)

    required = [r_orig_p, s_orig_p, r_rgb_p, s_rgb_p, *r_comp_ps, *s_comp_ps]
    first_missing = next((p for p in required if not Path(p).exists()), None)
    if first_missing is not None:
        print(f"[WARN] missing file: {first_missing}")
        return

    # ---------- load ----------
    r_orig = _load_png(r_orig_p, "RGB")
    s_orig = _load_png(s_orig_p, "RGB")
    r_comps = [_load_png(p, "L") for p in r_comp_ps]
    s_comps = [_load_png(p, "L") for p in s_comp_ps]
    r_rgb = _load_png(r_rgb_p, "RGB")
    s_rgb = _load_png(s_rgb_p, "RGB")

    def _draw_row(axes_row, label, titles, images, middle_kwargs):
        # Columns 1..3 (the PC maps) get `middle_kwargs`; the first and last
        # columns are drawn with imshow defaults.
        last_col = len(images) - 1
        for col, (ax, image, title) in enumerate(zip(axes_row, images, titles)):
            extra = middle_kwargs if 0 < col < last_col else {}
            ax.imshow(image, **extra)
            ax.set_title(f"{label} {title}")
            ax.axis("off")

    # =========================================================
    # Panel A (2x5): Image + PC0 + PC1 + PC2 + PC0/1/2 RGB
    # =========================================================
    figA, axA = plt.subplots(2, 5, figsize=(16, 7))
    titlesA = ["Image", "PC0", "PC1", "PC2", "PC0/1/2 RGB"]
    gray_kwargs = dict(cmap="gray", vmin=0, vmax=255)

    _draw_row(axA[0], "RNFLT", titlesA, [r_orig, *r_comps, r_rgb], gray_kwargs)
    _draw_row(axA[1], "SLAB", titlesA, [s_orig, *s_comps, s_rgb], gray_kwargs)

    figA.suptitle(f"{sid} — PCA Panel A: Image + PC0/1/2 + PC0/1/2 RGB", fontsize=14, y=1.02)
    out_A = out_dir / f"{sid}_panelA_2x5_image_pc012_rgb.png"
    figA.savefig(out_A, dpi=dpi, bbox_inches="tight")
    plt.close(figA)

    # =========================================================
    # Panel B (2x5): Image + PC0/1/2 overlay (colored) + RGB overlay
    # =========================================================
    # Overlay bases: gray versions of the originals unless disabled.
    r_base = _to_gray_rgb(r_orig) if base_gray_overlays_rnflt else r_orig
    s_base = _to_gray_rgb(s_orig) if base_gray_overlays_slab else s_orig

    # Colour-mapped overlays for the three principal components.
    r_over_comp = [
        _overlay_colormap(r_base, comp, alpha=overlay_alpha_gray, cmap=cmap_comp)
        for comp in r_comps
    ]
    s_over_comp = [
        _overlay_colormap(s_base, comp, alpha=overlay_alpha_gray, cmap=cmap_comp)
        for comp in s_comps
    ]

    # RGB-composite overlays (optionally reduced to gray first).
    r_over_rgb = _overlay_rgb(r_base, r_rgb, alpha=overlay_alpha_rgb, as_gray=rgb_overlay_as_gray)
    s_over_rgb = _overlay_rgb(s_base, s_rgb, alpha=overlay_alpha_rgb, as_gray=rgb_overlay_as_gray)

    figB, axB = plt.subplots(2, 5, figsize=(18, 7))
    titlesB = ["Image", "PC0 overlay", "PC1 overlay", "PC2 overlay", "RGB overlay"]

    _draw_row(axB[0], "RNFLT", titlesB, [r_orig, *r_over_comp, r_over_rgb], {})
    _draw_row(axB[1], "SLAB", titlesB, [s_orig, *s_over_comp, s_over_rgb], {})

    figB.suptitle(f"{sid} — PCA Panel B: Image + PC overlays + RGB overlay", fontsize=14, y=1.02)
    figB.tight_layout(rect=[0, 0, 1, 0.96])
    out_B = out_dir / f"{sid}_panelB_2x5_image_pc_overlays_rgb.png"
    figB.savefig(out_B, dpi=dpi, bbox_inches="tight")
    plt.close(figB)

    print(f"[OK] saved: {out_A.name} | {out_B.name}")

In [None]:
# =========================
# Batch generation (for every id present in both rnflt and slab)
# =========================
ids = _collect_ids(rnflt_dir, slab_dir)
print(f"发现 {len(ids)} 个样本：", ids[:5], "..." if len(ids) > 5 else "")

for sid in ids:
    make_panels_for_id(
        sid=sid,
        rnflt_dir=rnflt_dir,
        slab_dir=slab_dir,
        out_dir=out_dir,
        overlay_alpha_gray=overlay_alpha_gray,    # blend alpha for the gray (PC) heat-map overlays
        overlay_alpha_rgb=overlay_alpha_rgb,      # blend alpha for the RGB overlay (previously received the gray alpha by mistake)
        rgb_overlay_as_gray=rgb_overlay_as_gray,  # True: convert the RGB heat map to gray before overlaying
    )