# 画54点的网格

In [None]:
# 54-point board positions as [column, row] pairs.
# Points are listed row by row (row 0 first) and left to right inside a row.
# Each (first, last) pair below is the inclusive column range of one row:
# 4 + 6 + 8 + 9 + 9 + 8 + 6 + 4 = 54 points.
RECT_POS = [
    [col, row]
    for row, (first, last) in enumerate(
        ((3, 6), (2, 7), (1, 8), (1, 9), (1, 9), (1, 8), (2, 7), (3, 6))
    )
    for col in range(first, last + 1)
]

In [6]:
import numpy as np

def map_52_to_54(vec52, missing_idx_54=(26, 27)):
    """
    Expand a length-52 threshold vector to length 54 by inserting NaN at the
    positions listed in ``missing_idx_54``.

    Args:
        vec52: array-like holding exactly 52 values (flattened first). The
            values are written, in order, into the 54 output slots, skipping
            the missing positions.
        missing_idx_54: 0-based indices, in 54-point order, that receive NaN.
            Must be exactly two distinct indices in ``[0, 54)`` so that the 52
            values fill the remaining slots.
            NOTE(review): the default (26, 27) does not match the blind-spot
            indices used elsewhere in this notebook ((25, 34) / (26, 35));
            confirm the intended layout before relying on the default.

    Returns:
        np.ndarray of shape (54,), dtype float32, with NaN at ``missing_idx_54``.

    Raises:
        ValueError: if ``vec52`` does not hold 52 values, or ``missing_idx_54``
            is not two distinct in-range indices.
    """
    n_full, n_short = 54, 52
    vec52 = np.asarray(vec52).reshape(-1)
    if vec52.size != n_short:
        raise ValueError(f"expected len=52, got {vec52.size}")

    requested = tuple(missing_idx_54)
    miss = {int(i) for i in requested}
    # Duplicates or out-of-range indices would leave the wrong number of slots
    # (the old loop then failed with an opaque IndexError).
    if len(miss) != n_full - n_short or not all(0 <= i < n_full for i in miss):
        raise ValueError(
            f"missing_idx_54 must be {n_full - n_short} distinct indices "
            f"in [0, {n_full}), got {requested}"
        )

    keep = np.ones(n_full, dtype=bool)
    keep[sorted(miss)] = False
    out = np.full(n_full, np.nan, dtype=np.float32)
    # Boolean-mask assignment fills the kept slots in ascending index order,
    # i.e. walking the 54 positions and skipping the missing ones.
    out[keep] = vec52
    return out


In [7]:
import matplotlib.pyplot as plt

def plot_index_grid(save_path="vf_index_grid.png", rect_size=54):
    """Save a reference image labelling every cell of the 54-point grid with its index.

    Each entry of ``RECT_POS`` is drawn as a light-grey square with its 0-based
    position in ``RECT_POS`` written at the centre.

    Args:
        save_path: output image path (saved at 200 dpi).
        rect_size: side length of one cell in data units.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 5))
    for i, (x, y) in enumerate(RECT_POS):
        # Lower-left corner; rows are flipped so grid row 0 is drawn at the top.
        X = x * rect_size
        Y = (7 - y) * rect_size
        # Vertices: bottom-left, bottom-right, top-right, top-left. The last y
        # must be Y + rect_size, otherwise the cell degenerates into a triangle.
        ax.fill([X, X + rect_size, X + rect_size, X],
                [Y, Y, Y + rect_size, Y + rect_size],
                color=(0.9, 0.9, 0.9), edgecolor='k', linewidth=0.5)
        ax.text(X + rect_size / 2, Y + rect_size / 2, str(i),
                ha='center', va='center', fontsize=9)
    ax.axis('scaled'); ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path, dpi=200)
    plt.close(fig)
    print(f"[index grid] saved -> {save_path}")


In [8]:
# Render the index reference grid once so cell numbers can be looked up visually.
plot_index_grid(save_path="vf_index_grid.png")

[index grid] saved -> vf_index_grid.png


# 52点可视化

In [None]:
# ==================== hgf_viz_52.py ====================
import os
from typing import Sequence, Tuple, Optional
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# ---------- Fixed 54-point grid of the 24-2 pattern ----------
# Each entry is [column, row]; listed row by row from row 0, left to right.
RECT_POS_54 = [
    [3,0],[4,0],[5,0],[6,0],
    [2,1],[3,1],[4,1],[5,1],[6,1],[7,1],
    [1,2],[2,2],[3,2],[4,2],[5,2],[6,2],[7,2],[8,2],
    [1,3],[2,3],[3,3],[4,3],[5,3],[6,3],[7,3],[8,3],[9,3],
    [1,4],[2,4],[3,4],[4,4],[5,4],[6,4],[7,4],[8,4],[9,4],
    [1,5],[2,5],[3,5],[4,5],[5,5],[6,5],[7,5],[8,5],
    [2,6],[3,6],[4,6],[5,6],[6,6],[7,6],
    [3,7],[4,7],[5,7],[6,7]
]

def get_rect_pos_52(missing_idx: Tuple[int,int]=(25,34)):
    """Remove the given (blind-spot) cells from the 54-point grid.

    Args:
        missing_idx: 0-based indices into ``RECT_POS_54`` to leave out.

    Returns:
        ``(pos_52, keep)``: ``pos_52`` holds fresh ``[x, y]`` lists for the
        remaining cells in their original order, and ``keep`` holds those
        cells' indices in ``RECT_POS_54``. Both have length 52 when
        ``missing_idx`` names two distinct in-range indices.
    """
    excluded = set(missing_idx)
    keep = [idx for idx in range(len(RECT_POS_54)) if idx not in excluded]
    pos_52 = [list(RECT_POS_54[idx]) for idx in keep]
    return pos_52, keep

# ------------------ 基础绘制：52 点灰阶网格 ------------------
def draw_vf_52(
    ax,
    values_52: Sequence[float],
    title: Optional[str]=None,
    rect_size: int = 54,
    vmin: float = -50, vmax: float = 50.0,
    text: bool = True,
    missing_idx: Tuple[int,int]=(26,35),
):
    """Draw a length-52 threshold vector on the 24-2 grid with the two blind-spot cells removed.

    Each value becomes a grey square. ``vmin`` maps to black and ``vmax`` to
    white; values outside that range are clipped. The cells in ``missing_idx``
    are painted as near-black placeholders.

    Args:
        ax: matplotlib Axes to draw into.
        values_52: 52 values (flattened), ordered like the cells that remain
            after removing ``missing_idx`` from ``RECT_POS_54``.
        title: optional axes title.
        rect_size: side length of one cell in data units.
        vmin, vmax: value range mapped onto the grey scale; must differ.
        text: if True, print each value (one decimal) at the cell centre.
        missing_idx: 0-based indices into ``RECT_POS_54`` not filled from
            ``values_52``.
            NOTE(review): this default (26, 35) differs from the (25, 34)
            default of get_rect_pos_52 / panel_A_hgf_52 / heatmap_grid_52.
            Callers in this notebook pass it explicitly; confirm which
            layout is correct.

    Raises:
        ValueError: if ``values_52`` does not hold 52 values, if
            ``vmin == vmax``, or if ``missing_idx`` does not leave exactly
            52 cells.
    """
    vals = np.asarray(values_52).reshape(-1)
    if len(vals) != 52:
        raise ValueError(f"expect 52 values, got {len(vals)}")
    span = vmax - vmin
    if span == 0:
        raise ValueError(f"vmax must differ from vmin (both are {vmin})")
    pos_52, _ = get_rect_pos_52(missing_idx)
    if len(pos_52) != len(vals):
        raise ValueError(
            f"missing_idx={tuple(missing_idx)} leaves {len(pos_52)} cells, expected 52"
        )

    def _square(x, y):
        """Return (X, Y, xs, ys): lower-left corner and the 4 vertices of cell (x, y)."""
        X = x * rect_size
        Y = (7 - y) * rect_size  # flip rows so grid row 0 is drawn at the top
        # Vertices: bottom-left, bottom-right, top-right, top-left. The last y
        # must be Y + rect_size, otherwise the cell degenerates into a triangle.
        xs = [X, X + rect_size, X + rect_size, X]
        ys = [Y, Y, Y + rect_size, Y + rect_size]
        return X, Y, xs, ys

    for i, (x, y) in enumerate(pos_52):
        X, Y, xs, ys = _square(x, y)
        vf = float(vals[i])
        bg = (vf - vmin) / span
        bg = max(0.0, min(1.0, bg))  # clip to the valid grey range
        txt_color = 'white' if bg < 0.5 else 'black'  # keep labels readable
        ax.fill(xs, ys, color=(bg, bg, bg), edgecolor='none')
        if text:
            ax.text(X + rect_size/2, Y + rect_size/2, f"{vf:.1f}",
                    ha='center', va='center', color=txt_color, fontsize=9)

    # Paint the blind-spot cells as dark placeholders (could be left blank instead).
    for j in missing_idx:
        _, _, xs, ys = _square(*RECT_POS_54[j])
        ax.fill(xs, ys, color='black', edgecolor='none', alpha=0.9)

    ax.axis('scaled'); ax.axis('off')
    if title:
        ax.set_title(title, fontsize=12)

# ---------------------- Panel A：单样本三联图 ----------------------
def panel_A_hgf_52(
    truth_52: Sequence[float],
    pred_52:  Sequence[float],
    *,
    rnfl_map: Optional[np.ndarray] = None,   # Harvard-GF npz['rnflt'] matrix (optional)
    rnfl_path: Optional[str] = None,         # or a path to an image file, e.g. a SALSA jpg (optional)
    save_path: Optional[str] = None,
    suptitle: Optional[str] = None,
    vmin: float = -50.0, vmax: float = 50.0,
    rect_size: int = 54,
    missing_idx: Tuple[int,int] = (25,34),
    rnfl_cmap: str = "jet"
):
    """Paper Panel A for one sample: [input image] + ground-truth VF + predicted VF.

    Args:
        truth_52, pred_52: 52 threshold values each, drawn with draw_vf_52.
        rnfl_map: optional array shown in the left column. A 2-D array is
            colour-mapped with ``rnfl_cmap``; any other shape goes to
            ``imshow`` unchanged. Takes precedence over ``rnfl_path``.
        rnfl_path: optional image file shown in the left column. It is
            ignored when the path does not exist.
        save_path: if given, the figure is saved there at 300 dpi (parent
            directories are created) and closed; otherwise ``plt.show()``.
        suptitle: optional figure title.
        vmin, vmax, rect_size, missing_idx: forwarded to draw_vf_52.
        rnfl_cmap: colormap used for a 2-D ``rnfl_map``.
    """
    # Decide whether the left column shows an input image (array wins over path).
    use_map = rnfl_map is not None
    use_path = (not use_map) and bool(rnfl_path) and os.path.exists(rnfl_path)
    has_img = use_map or use_path

    ncols = 3 if has_img else 2
    fig, axes = plt.subplots(1, ncols, figsize=(4*ncols, 4))
    col = 0

    # Left: input image
    if has_img:
        ax_in = axes[col]
        if use_map:
            if np.ndim(rnfl_map) == 2:
                ax_in.imshow(rnfl_map, cmap=rnfl_cmap)  # single channel -> colormap
            else:
                ax_in.imshow(rnfl_map)  # assumed already RGB/RGBA
        else:
            # Image.open is lazy and keeps the file open. imshow converts the
            # PIL image to an array right away, so draw inside the with-block
            # and let it close the handle.
            with Image.open(rnfl_path) as pil_img:
                ax_in.imshow(pil_img)
        ax_in.set_title("Input (RNFLT)", fontsize=12)
        ax_in.axis('off')
        col += 1

    # Middle & right: ground truth / prediction
    draw_vf_52(axes[col], truth_52, title="Ground truth",
               rect_size=rect_size, vmin=vmin, vmax=vmax, missing_idx=missing_idx); col += 1
    draw_vf_52(axes[col], pred_52,  title="Prediction",
               rect_size=rect_size, vmin=vmin, vmax=vmax, missing_idx=missing_idx)

    if suptitle:
        fig.suptitle(suptitle, fontsize=13)
    plt.tight_layout()
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=300)
        print(f"[Panel A] saved -> {save_path}")
        plt.close(fig)
    else:
        plt.show()

# ---------------------- Panel B：逐点 MAE 与改进 ----------------------
def pointwise_mae_52(truth_mat_52: np.ndarray, pred_mat_52: np.ndarray) -> np.ndarray:
    """
    Per-point mean absolute error, averaged over samples.

    Args:
      truth_mat_52: (N, 52) ground-truth values; a single (52,) sample is
        also accepted and treated as N = 1.
      pred_mat_52:  predictions with the same shape as ``truth_mat_52``.
    Returns:
      mae_52: (52,) mean of |pred - truth| over the N samples at each point.
    Raises:
      ValueError: if the shapes differ or the inputs are not (N, 52).
    """
    # atleast_2d turns a lone (52,) sample into (1, 52); (N, 52) is unchanged.
    truth = np.atleast_2d(np.asarray(truth_mat_52))
    pred = np.atleast_2d(np.asarray(pred_mat_52))
    if truth.shape != pred.shape:
        raise ValueError(f"shape mismatch: truth {truth.shape} vs pred {pred.shape}")
    if truth.ndim != 2 or truth.shape[1] != 52:
        raise ValueError(f"expected (N, 52) inputs, got {truth.shape}")
    return np.mean(np.abs(pred - truth), axis=0)

def heatmap_grid_52(
    ax, vec_52: Sequence[float], title: str,
    cmap='viridis', vmin=None, vmax=None,
    marker_size_scale: float = 2200.0,
    missing_idx: Tuple[int,int]=(25,34)
):
    """Render a 52-point vector as a grid of coloured square markers with a colorbar.

    Args:
        ax: matplotlib Axes to draw into.
        vec_52: one value per remaining grid cell, ordered as returned by
            ``get_rect_pos_52(missing_idx)``.
        title: axes title.
        cmap, vmin, vmax: forwarded to ``ax.scatter`` for colour mapping
            (``None`` limits let matplotlib autoscale).
        marker_size_scale: scatter marker area (``s``), in points squared.
        missing_idx: grid indices left out of the layout.
    """
    cells, _ = get_rect_pos_52(missing_idx)
    cols = [cx for cx, _ in cells]
    rows = [cy for _, cy in cells]
    mappable = ax.scatter(
        cols, rows,
        c=vec_52, s=marker_size_scale, marker='s',
        cmap=cmap, vmin=vmin, vmax=vmax,
    )
    # Grid row 0 is the top row, so flip the y axis.
    ax.invert_yaxis()
    ax.set_aspect('equal')
    # Positions are grid indices rather than meaningful coordinates: hide ticks.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title, fontsize=12)
    plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.02)

def panel_B_hgf_52(
    truth_mat_52: np.ndarray,
    pred_mat_52:  np.ndarray,
    *,
    baseline_pred_mat_52: Optional[np.ndarray] = None,  # if given, also draw the improvement map
    save_path: Optional[str] = None,
    missing_idx: Tuple[int,int]=(25,34),
    mae_cmap='magma', improve_cmap='coolwarm'
):
    """
    Paper Panel B: pointwise MAE map, plus an MAE-improvement map when a baseline is given.

    Layout:
      - without a baseline: one panel with the pointwise MAE of ``pred_mat_52``;
      - with ``baseline_pred_mat_52``: left = pointwise MAE, right = baseline
        MAE minus model MAE (positive = the model is better at that point).
        The improvement map uses symmetric colour limits so 0 sits at the
        centre of the diverging ``improve_cmap``.

    Args:
        truth_mat_52, pred_mat_52: (N, 52) ground truth and model predictions.
        baseline_pred_mat_52: optional (N, 52) baseline predictions.
        save_path: if given, the figure is saved there at 300 dpi (parent
            directories are created) and closed; otherwise ``plt.show()``.
        missing_idx: grid indices left out of the layout (see heatmap_grid_52).
        mae_cmap, improve_cmap: colormaps for the two maps.
    """
    mae = pointwise_mae_52(truth_mat_52, pred_mat_52)

    if baseline_pred_mat_52 is not None:
        base_mae = pointwise_mae_52(truth_mat_52, baseline_pred_mat_52)
        improve = np.asarray(base_mae - mae, dtype=float)  # positive = improvement
        # Without explicit limits matplotlib autoscales to [min, max], which
        # moves the neutral colour of a diverging map away from 0.
        finite = np.abs(improve[np.isfinite(improve)])
        lim = float(finite.max()) if finite.size else 0.0
        ivmin, ivmax = (-lim, lim) if lim > 0 else (None, None)
        fig, axes = plt.subplots(1, 2, figsize=(8,4))
        heatmap_grid_52(axes[0], mae,     "Pointwise MAE (52)", cmap=mae_cmap,    missing_idx=missing_idx)
        heatmap_grid_52(axes[1], improve, "MAE improvement (52)", cmap=improve_cmap,
                        vmin=ivmin, vmax=ivmax, missing_idx=missing_idx)
    else:
        fig, ax = plt.subplots(1, 1, figsize=(4,4))
        heatmap_grid_52(ax, mae, "Pointwise MAE (52)", cmap=mae_cmap, missing_idx=missing_idx)

    plt.tight_layout()
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save_path, dpi=300)
        print(f"[Panel B] saved -> {save_path}")
        plt.close(fig)
    else:
        plt.show()
# ==================== END ====================


## A

In [None]:
import numpy as np
import pandas as pd

# Load the inference outputs (predictions and targets are stored in separate files).
dfp = pd.read_excel("/mnt/sda/sijiali/Results_infer_HGF/test_pred.xlsx")
dft = pd.read_excel("/mnt/sda/sijiali/Results_infer_HGF/test_tar.xlsx")

row = 0  # which sample to display
# NOTE(review): prediction and target rows are paired by position only; this
# assumes both files share the same row order. Confirm, e.g. by comparing ids.
pred_52 = dfp.loc[row, [c for c in dfp.columns if c.startswith("pred_")]].to_numpy()
tar_52  = dft.loc[row, [c for c in dft.columns  if c.startswith("tar_")]].to_numpy()

# Load this sample's RNFLT matrix from Harvard-GF (optional input panel).
sid = str(dfp.loc[row, "test_id"])          # id column written by the inference script
npz_path = f"/mnt/sda/sijiali/DataSet/Harvard-GF/Dataset/Test/{sid}.npz"
# np.load on an .npz returns an NpzFile that keeps the file open; the with-block
# closes it (the "rnflt" array is read into memory before closing).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays; only use
# it on trusted dataset files.
with np.load(npz_path, allow_pickle=True) as npz:
    rnflt = npz["rnflt"]  # presumably (H, W); confirm against the dataset docs

panel_A_hgf_52(
    truth_52=tar_52,
    pred_52=pred_52,
    rnfl_map=rnflt,
    save_path=f"./figs/hgf_panelA_{sid}.png",
    suptitle=f"Sample {sid}",
    # NOTE(review): (26, 35) differs from the (25, 34) default of panel_A_hgf_52;
    # confirm which pair matches the blind-spot positions in this data.
    missing_idx=(26,35)
)


[Panel A] saved -> ./figs/hgf_panelA_data_2401.png


## B

In [16]:
# import numpy as np
# import pandas as pd

# # 读你的推理结果/真值（全体样本）
# dfp = pd.read_excel("./Results_HGF/test_pred.xlsx")
# dft = pd.read_excel("./Results_HGF/test_tar.xlsx")

# pred_mat = dfp[[c for c in dfp.columns if c.startswith("pred_")]].to_numpy()  # (N,52)
# tar_mat  = dft[[c for c in dft.columns  if c.startswith("tar_")]].to_numpy()  # (N,52)

# # （可选）读基线模型的预测
# # dfb = pd.read_excel("./Results_HGF/baseline_pred.xlsx")
# # base_mat = dfb[[c for c in dfb.columns if c.startswith("pred_")]].to_numpy()

# panel_B_hgf_52(
#     truth_mat_52=tar_mat,
#     pred_mat_52=pred_mat,
#     # baseline_pred_mat_52=base_mat,  # 若有基线，解注释
#     save_path="./figs/hgf_panelB.png",
#     missing_idx=(26,35)
# )
