In [7]:
import numpy as np
import matplotlib.pyplot as plt
import os
import json

# ========= 1. 数据读取 =========
def load_data(file_path: str):
    """Load the loss and label arrays stored in an ``.npz`` archive.

    Args:
        file_path: Path of the archive. It must contain the entries
            ``real_losses``, ``topk_losses``, ``real_label``, ``topk_label``
            and ``ground_truth_label``.

    Returns:
        A 5-tuple ``(real_losses, topk_losses, real_label, topk_label,
        ground_truth_label)`` of NumPy arrays, in that order.

    Raises:
        KeyError: If one of the expected entries is missing from the archive.
    """
    # np.load on an .npz returns a lazy NpzFile that keeps the archive open.
    # Each entry is materialized when indexed, so the arrays stay valid after
    # the context manager has closed the underlying file.
    with np.load(file_path) as data:
        real_losses = data["real_losses"]
        topk_losses = data["topk_losses"]
        real_label = data["real_label"]
        topk_label = data["topk_label"]
        ground_truth_label = data["ground_truth_label"]
    return real_losses, topk_losses, real_label, topk_label, ground_truth_label

# ========= 2. 指标计算 =========
def compute_error_metrics(file_path, topk=256, t_p=0.6, t_h=2.0, eps=1e-8):
    """Compute per-sample diagnostics for every misclassified sample.

    A sample is an error when its ``real_label`` differs from its
    ``ground_truth_label``. For each error, the first ``K`` columns of
    ``topk_label`` / ``topk_losses`` are analysed, where
    ``K = min(topk, topk_label.shape[1])``.
    NOTE(review): the rank metrics treat the column position as the rank, so
    the columns are presumably ordered best-first -- confirm with the producer
    of the archive.

    Args:
        file_path: Path of the ``.npz`` archive, forwarded to ``load_data``.
        topk: Upper bound on the number of leading columns to use.
        t_p: Purity threshold used for types "A" and "B".
        t_h: Hubness-score threshold; scores at or above it give type "D".
        eps: Small constant added to the MAD so the hubness ratio never
            divides by zero.

    Returns:
        A dict of plain Python lists (one entry per error sample):

        * ``err_indices``: indices of the error samples in the input arrays.
        * ``gt`` / ``pred``: ground-truth and predicted labels.
        * ``p_gt`` / ``p_pred``: fraction of the K labels equal to the
          ground-truth / predicted label.
        * ``p_max``: fraction taken by the most frequent label in the row.
        * ``r_gt`` / ``r_pred``: 1-based position of the first entry carrying
          that label, ``inf`` when it never appears.
        * ``m_gt`` / ``m_pred``: smallest loss among entries carrying that
          label, ``inf`` when it never appears.
        * ``g_gap``: ``m_pred - m_gt`` (NaN when both labels are absent).
        * ``h``: hubness score of the predicted label.
        * ``delta``: ``p_pred - p_gt``.
        * ``conf``: negated ``real_losses`` of the sample.
        * ``type``: one of "A", "B", "C", "D" (see the typing rules below).
        * ``tk_lab`` / ``tk_loss``: the K-column slices used above.

    Raises:
        ValueError: If no sample is misclassified.
    """
    real_losses, topk_losses, real_label, topk_label, gt = load_data(file_path)
    pred = real_label
    idx = np.where(pred != gt)[0]

    if idx.size == 0:
        raise ValueError("没有错误样本")

    K = min(topk, topk_label.shape[1])
    tk_lab = topk_label[idx, :K]
    tk_loss = topk_losses[idx, :K]
    gt_sub = gt[idx]
    pr_sub = pred[idx]

    # Purity: share of the K labels that agree with the ground truth / the
    # prediction.
    p_gt = (tk_lab == gt_sub[:, None]).mean(axis=1)
    p_pred = (tk_lab == pr_sub[:, None]).mean(axis=1)

    # Share of the row taken by its most frequent label. np.unique is used
    # instead of np.bincount so negative or non-integer labels are accepted.
    def mode_fraction(row):
        _, counts = np.unique(row, return_counts=True)
        return counts.max() / row.size
    p_max = np.array([mode_fraction(row) for row in tk_lab])

    # First-hit rank: 1-based position of the first entry with the target
    # label, or inf when the label does not occur in the row.
    def first_rank(row, target):
        pos = np.where(row == target)[0]
        return (pos[0] + 1) if pos.size else np.inf
    r_gt = np.array([first_rank(row, y) for row, y in zip(tk_lab, gt_sub)])
    r_pred = np.array([first_rank(row, y) for row, y in zip(tk_lab, pr_sub)])

    # Smallest loss among the entries of class c, or inf when there are none.
    def min_loss_for(row_lab, row_loss, c):
        mask = (row_lab == c)
        return row_loss[mask].min() if mask.any() else np.inf
    m_gt = np.array([min_loss_for(L, S, y) for L, S, y in zip(tk_lab, tk_loss, gt_sub)])
    m_pred = np.array([min_loss_for(L, S, y) for L, S, y in zip(tk_lab, tk_loss, pr_sub)])
    # When neither label occurs in the row both minima are inf and the gap is
    # undefined; keep the resulting NaN but silence the "invalid value"
    # RuntimeWarning that inf - inf would otherwise emit.
    with np.errstate(invalid="ignore"):
        g_gap = m_pred - m_gt

    # Hubness score: gap between the two smallest losses of class c, scaled by
    # the median absolute deviation of that class's losses. Rows with fewer
    # than two entries of class c score 0.0.
    def hubness(row_lab, row_loss, c):
        arr = row_loss[row_lab == c]
        if arr.size < 2:
            return 0.0
        arr = np.sort(arr)
        mad = np.median(np.abs(arr - np.median(arr))) + eps
        return (arr[1] - arr[0]) / mad
    h_score = np.array([hubness(L, S, y) for L, S, y in zip(tk_lab, tk_loss, pr_sub)])

    delta = p_pred - p_gt
    conf = -real_losses[idx]

    # Typing rules (A and B both require h < t_h, so they never override D):
    #   D: h >= t_h
    #   A: p_gt   >= t_p, delta < 0, h < t_h
    #   B: p_pred >= t_p, delta > 0, h < t_h
    #   C: everything else (the default)
    typ = np.full(idx.shape, "C", dtype=object)
    typ[h_score >= t_h] = "D"
    mask_A = (p_gt >= t_p) & (delta < 0) & (h_score < t_h)
    typ[mask_A] = "A"
    mask_B = (p_pred >= t_p) & (delta > 0) & (h_score < t_h)
    typ[mask_B] = "B"

    return {
        "err_indices": idx.tolist(),
        "gt": gt_sub.tolist(),
        "pred": pr_sub.tolist(),
        "p_gt": p_gt.tolist(),
        "p_pred": p_pred.tolist(),
        "p_max": p_max.tolist(),
        "r_gt": r_gt.tolist(),
        "r_pred": r_pred.tolist(),
        "m_gt": m_gt.tolist(),
        "m_pred": m_pred.tolist(),
        "g_gap": g_gap.tolist(),
        "h": h_score.tolist(),
        "delta": delta.tolist(),
        "conf": conf.tolist(),
        "type": typ.tolist(),
        # Kept so the top-k curve can be drawn later.
        "tk_lab": tk_lab.tolist(),
        "tk_loss": tk_loss.tolist()
    }

# ========= 3. 绘图 =========
def plot_all(metrics, output_dir="plots"):
    """Draw the diagnostic figures for ``metrics`` and save them as PNG files.

    Args:
        metrics: Mapping with the keys ``p_gt``, ``p_pred``, ``delta``,
            ``conf``, ``r_gt``, ``r_pred``, ``g_gap``, ``h``, ``gt``, ``pred``,
            ``tk_lab`` and ``tk_loss`` (the dict returned by
            ``compute_error_metrics``).
        output_dir: Directory receiving the images; created when missing.

    Files written to ``output_dir``: ``scatter_purity.png``, ``hist_rank.png``,
    ``hist_gap.png``, ``scatter_hubness.png``, ``heatmap_confusion.png``,
    ``scatter_confidence.png`` and ``curve_topk.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    def save_current(file_name):
        # Write the active pyplot figure into output_dir, then release it.
        plt.savefig(os.path.join(output_dir, file_name))
        plt.close()

    def finite_only(values):
        # Drop inf / NaN entries, which a histogram cannot bin.
        return values[np.isfinite(values)]

    (purity_gt, purity_pred, purity_delta, confidence,
     rank_gt, rank_pred, loss_gap, hub_score) = (
        np.array(metrics[key])
        for key in ("p_gt", "p_pred", "delta", "conf", "r_gt", "r_pred", "g_gap", "h")
    )

    # 1. Purity scatter, coloured by delta.
    plt.figure(figsize=(6, 6))
    plt.scatter(purity_gt, purity_pred, c=purity_delta, cmap="coolwarm", alpha=0.7)
    plt.xlabel("p_gt")
    plt.ylabel("p_pred")
    plt.title("Neighbor Purity Scatter")
    plt.colorbar(label="delta=p_pred-p_gt")
    save_current("scatter_purity.png")

    # 2. First-hit rank histograms (log-scaled counts).
    plt.figure(figsize=(6, 4))
    plt.hist(finite_only(rank_gt), bins=50, alpha=0.6, label="r_gt")
    plt.hist(finite_only(rank_pred), bins=50, alpha=0.6, label="r_pred")
    plt.yscale("log")
    plt.legend()
    plt.xlabel("Rank")
    plt.ylabel("Count")
    plt.title("First-hit rank")
    save_current("hist_rank.png")

    # 3. Distribution of the min-loss gap.
    plt.figure(figsize=(6, 4))
    plt.hist(finite_only(loss_gap), bins=50, alpha=0.7)
    plt.xlabel("g = m_pred - m_gt")
    plt.ylabel("Count")
    plt.title("Min-loss gap distribution")
    save_current("hist_gap.png")

    # 4. Hubness score against delta, coloured by the confidence proxy.
    plt.figure(figsize=(6, 4))
    plt.scatter(hub_score, purity_delta, c=confidence, cmap="viridis", alpha=0.7)
    plt.xlabel("hubness score h")
    plt.ylabel("delta")
    plt.title("Hubness vs Delta")
    plt.colorbar(label="Confidence proxy")
    save_current("scatter_hubness.png")

    # 5. Heatmap counting each (ground truth, prediction) pair among errors.
    from collections import Counter
    true_labels = metrics["gt"]
    predicted_labels = metrics["pred"]
    pair_counts = Counter(zip(true_labels, predicted_labels))
    row_labels = sorted(set(true_labels))
    col_labels = sorted(set(predicted_labels))
    row_of = {label: i for i, label in enumerate(row_labels)}
    col_of = {label: j for j, label in enumerate(col_labels)}
    heat = np.zeros((len(row_labels), len(col_labels)))
    for (true_label, predicted_label), count in pair_counts.items():
        heat[row_of[true_label], col_of[predicted_label]] = count
    plt.figure(figsize=(6, 6))
    plt.imshow(heat, cmap="Blues")
    plt.colorbar()
    plt.xticks(range(len(col_labels)), col_labels)
    plt.yticks(range(len(row_labels)), row_labels)
    plt.xlabel("Pred")
    plt.ylabel("GT")
    plt.title("Confusion Heatmap on Errors")
    save_current("heatmap_confusion.png")

    # 6. Confidence against |delta|, coloured by p_gt.
    plt.figure(figsize=(6, 4))
    plt.scatter(confidence, np.abs(purity_delta), c=purity_gt, cmap="plasma", alpha=0.7)
    plt.xlabel("Confidence (-loss)")
    plt.ylabel("|delta|")
    plt.title("Confidence vs Consistency gap")
    plt.colorbar(label="p_gt")
    save_current("scatter_confidence.png")

    # 7. For k' = 1..kmax: mean, over the samples whose first k' entries
    #    contain the ground-truth label, of the smallest loss among those
    #    entries. NaN when no sample qualifies.
    label_rows = np.array(metrics["tk_lab"])
    loss_rows = np.array(metrics["tk_loss"])
    truths = np.array(metrics["gt"])
    k_limit = min(20, label_rows.shape[1])
    k_values = range(1, k_limit + 1)

    def best_truth_losses(k):
        for label_row, loss_row, truth in zip(label_rows, loss_rows, truths):
            hit = label_row[:k] == truth
            if hit.any():
                yield loss_row[:k][hit].min()

    curve = []
    for k in k_values:
        found = list(best_truth_losses(k))
        curve.append(np.mean(found) if found else np.nan)

    plt.figure(figsize=(6, 4))
    plt.plot(k_values, curve, marker="o")
    plt.xlabel("k'")
    plt.ylabel("avg min loss of gt neighbor")
    plt.title("Top-k' best ground-truth neighbor curve")
    save_current("curve_topk.png")

# ========= 4. 主函数 =========
if __name__ == "__main__":
    # Replace with the path to your own data file.
    file_path = "./EU_N_train_4096_N_valid_4096_losses.npz"
    output_dir = "plots"

    metrics = compute_error_metrics(file_path)
    # plot_all creates output_dir, so it must run before the JSON export.
    plot_all(metrics, output_dir=output_dir)

    # Also export the metrics as JSON for later analysis.
    # NOTE: inf / NaN values are written as Infinity / NaN tokens (Python's
    # json default), which strict JSON parsers may reject.
    with open(os.path.join(output_dir, "metrics.json"), "w") as f:
        f.write(json.dumps(metrics, indent=2))
    print(f"分析结果已保存到 {output_dir}/metrics.json 和图像文件")


  g_gap  = m_pred - m_gt


分析结果已保存到 plots/metrics.json 和图像文件
