In [None]:
import os
import re
import subprocess
import sys
import tempfile
import uuid

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Configuration constants; adjust these at the top of the script.
# draw()/draw1() combine them into the summary path:
#   <BASE_OUTPUT_DIR>/<DATA_NAME>_<IMG_SIZE>/hav_<setting_suffix>/results/all_eval_results.txt
BASE_OUTPUT_DIR = "/home/cy/nuist-lab/CcGAN-AVAR/output"
DATA_NAME = "UTKFace"
IMG_SIZE = 128
METRIC_NAME = "SFID"  # or "FID", "IS", etc.

from matplotlib.lines import Line2D  # repeats the import at the top of the cell; Line2D is not used by the code visible here, so this is harmless

def _collect_metric_curves(experiments, verbose, keep_empty):
    """
    Parse the eval summary of every experiment into ``{legend_name: {iter: value}}``.

    Args:
        experiments: dict mapping legend name -> (setting_suffix, server_idx).
        verbose: when True, print the resolved path of every experiment and
            forward the flag to ``parse_all_eval_results``.
        keep_empty: when True, an experiment without any parsed value is kept
            with an empty dict (so it still shows up in the legend); when
            False it is dropped from the result.

    Returns:
        dict in the iteration order of ``experiments``.
    """
    curves = {}
    for legend_name, (setting_suffix, server_idx) in experiments.items():
        summary_path = os.path.join(
            BASE_OUTPUT_DIR,
            f"{DATA_NAME}_{IMG_SIZE}",
            f"hav_{setting_suffix}",
            "results",
            "all_eval_results.txt",
        )

        if verbose:
            print(f"[Info] 解析实验: {legend_name}")
            print(f"       setting_suffix = {setting_suffix}, server_idx = {server_idx}")
            print(f"       summary_path   = {summary_path}")

        iter_to_val = parse_all_eval_results(
            summary_path, METRIC_NAME, s_idx=server_idx, verbose=verbose
        )
        if iter_to_val:
            curves[legend_name] = iter_to_val
        elif keep_empty:
            print(f"  -> 没有从 {summary_path} 中解析到任何 {METRIC_NAME} 数据，该实验只在图例中显示，不画点。")
            # Keep a legend-only placeholder for this experiment.
            curves[legend_name] = {}
        else:
            print(f"  -> 没有从 {summary_path} 中解析到任何 {METRIC_NAME} 数据，跳过该实验。")
    return curves


def _label_y_positions(points, y_range, base_offset_factor=0.03):
    """
    Compute the y coordinate of the value label of every plotted point.

    Points sharing the same x are sorted by y and split into groups: a point
    joins the current group while it lies within 10% of the group's first
    (lowest) value. Labels inside a multi-point group are spread symmetrically
    around their own y in steps of ``base_offset_factor * y_range``.

    Args:
        points: list of ``(x, y)`` tuples.
        y_range: overall value range of the plot, used to scale the offsets.
        base_offset_factor: spacing between stacked labels relative to ``y_range``.

    Returns:
        list of label y positions, aligned index-by-index with ``points``.
    """
    indices_by_x = {}
    for idx, (x, _) in enumerate(points):
        indices_by_x.setdefault(x, []).append(idx)

    positions = [None] * len(points)
    step = base_offset_factor * y_range

    for indices in indices_by_x.values():
        if len(indices) == 1:
            # The only point at this x: the label sits directly on the point.
            only = indices[0]
            positions[only] = points[only][1]
            continue

        ordered = sorted(indices, key=lambda i: points[i][1])

        # Split into groups of values within 10% of the group's anchor value.
        groups = [[ordered[0]]]
        anchor = points[ordered[0]][1]
        for i in ordered[1:]:
            y_i = points[i][1]
            if abs(y_i - anchor) <= 0.1 * max(abs(anchor), 1e-8):
                groups[-1].append(i)
            else:
                groups.append([i])
                anchor = y_i

        for group in groups:
            if len(group) == 1:
                # Isolated value at a shared x: nudge the label slightly upwards.
                j = group[0]
                positions[j] = points[j][1] + y_range * 0.02
                continue

            # Ranks start at 1; e.g. 7 points -> center 4, 8 points -> center 4.5,
            # so offsets are negative below the center and positive above it.
            center = (len(group) + 1) / 2.0
            for rank, j in enumerate(group, start=1):
                positions[j] = points[j][1] + (rank - center) * step

    return positions


def _output_filename(title):
    """
    Build the PNG file name: a sanitized ``title`` when one is given, otherwise
    ``<METRIC_NAME>_curve_<8 random hex chars>.png``.
    """
    if title and title.strip() != "":
        safe_title = title.strip()
        safe_title = re.sub(r"\s+", "_", safe_title)
        safe_title = re.sub(r"[^0-9A-Za-z_-]", "_", safe_title)
        safe_title = re.sub(r"_+", "_", safe_title)
        return f"{safe_title}.png"
    return f"{METRIC_NAME}_curve_{uuid.uuid4().hex[:8]}.png"


def _plot_metric_curves(curves, n_experiments, title):
    """
    Plot one line per entry of ``curves``, annotate every point with its value,
    save the figure to ``<cwd>/temp/<name>.png`` and show it.

    Args:
        curves: dict legend_name -> {iter: value}; an empty inner dict yields a
            legend entry without any points.
        n_experiments: number of requested experiments; drives the figure height.
        title: optional plot title, also used to derive the output file name.
    """
    all_iters = sorted({it for iter_to_val in curves.values() for it in iter_to_val})
    all_values = [v for iter_to_val in curves.values() for v in iter_to_val.values()]

    if all_values:
        y_min = min(all_values)
        y_max = max(all_values)
        y_range = y_max - y_min if y_max > y_min else 1.0
        margin = y_range * 0.1
    else:
        # No numeric data at all (pure placeholder plot).
        y_range = 1.0

    plt.figure(figsize=(7, min(2 * n_experiments, 7)))

    # Draw the lines first and remember every point; labels are placed afterwards
    # so that overlapping ones can be staggered.
    points = []
    for legend_name, iter_to_val in curves.items():
        iters_sorted = sorted(iter_to_val)
        vals_sorted = [iter_to_val[i] for i in iters_sorted]
        plt.plot(
            iters_sorted,
            vals_sorted,
            marker="o",
            linestyle="-",
            linewidth=1.5,
            markersize=4,
            label=legend_name,
        )
        points.extend(zip(iters_sorted, vals_sorted))

    for (x, y), y_text in zip(points, _label_y_positions(points, y_range)):
        plt.text(x, y_text, f"{y:.3f}", ha="center", va="bottom", fontsize=10)

    # Axis labels and appearance.
    plt.xlabel("Iteration", fontsize=12)
    plt.ylabel(METRIC_NAME, fontsize=12)
    if title and title.strip() != "":
        plt.title(title, fontsize=14)

    if all_values:
        # Twice the margin below so labels shifted downwards stay inside the axes.
        plt.ylim(y_min - margin * 2, y_max + margin)

    plt.xticks(all_iters, rotation=45, fontsize=10)
    plt.yticks(fontsize=10)
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.legend(fontsize=10)
    plt.tight_layout()

    # Save the image.
    out_dir = os.path.join(os.getcwd(), "temp")
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, _output_filename(title))
    plt.savefig(out_path, dpi=200)
    print(f"\n[完成] 图像已保存为: {out_path}")

    plt.show()


def draw(experiments, title=None, verbose=False):
    """
    Plot METRIC_NAME against iteration for several experiments.

    Experiments whose summary yields no data are still listed in the legend
    (without points).

    Args:
        experiments: dict mapping legend name -> (setting_suffix, server_idx), e.g.
            {
                "128 GPUx1": ("cy_1_1", 1),
                "128 GPUx4": ("cy_1_0123", 1),
                # "128 GPUx4 (server83)": ("cy_1_0123", 3),
            }
        title: optional plot title; also determines the output file name.
        verbose: print parsing details.

    Returns:
        None. With an empty ``experiments`` a hint is printed and the function
        returns without plotting (it no longer calls ``sys.exit()``, which
        raises ``SystemExit`` inside a notebook kernel).
    """
    if not experiments:
        print("请先在脚本里填写 e_map 再调用 draw。")
        return

    curves = _collect_metric_curves(experiments, verbose, keep_empty=True)
    _plot_metric_curves(curves, len(experiments), title)


def draw1(experiments, title=None, verbose=False):
    """
    Variant of ``draw`` that skips experiments without data instead of listing
    them in the legend.

    Args:
        experiments: dict mapping legend name -> (setting_suffix, server_idx), e.g.
            {
                "128 GPUx1": ("cy_1_1", 1),
                "128 GPUx4": ("cy_1_0123", 1),
                # "128 GPUx4 (server83)": ("cy_1_0123", 3),
            }
        title: optional plot title; also determines the output file name.
        verbose: print parsing details.

    Returns:
        None. If ``experiments`` is empty, or no experiment yields any data, a
        hint is printed and the function returns without plotting (it no longer
        calls ``sys.exit()``).
    """
    if not experiments:
        print("请先在脚本里填写 e_map 再调用 draw。")
        return

    curves = _collect_metric_curves(experiments, verbose, keep_empty=False)
    if not curves:
        print("所有实验都没有解析到数据，检查 e_map 配置、路径和 METRIC_NAME 是否正确。")
        return

    _plot_metric_curves(curves, len(experiments), title)


def parse_all_eval_results(summary_path, metric_name, s_idx=1, verbose=False):
    """
    Parse an ``all_eval_results.txt`` summary into ``{iteration: metric_value}``.

    The summary file is expected to look like::

        Checkpoint: ckpt_niter_5000.pth (iter=5000)
          SFID: 0.624 (0.165)
          LS: 10.124 (7.674)
          Diversity: 1.134 (0.093)
          FID: 0.624
          IS: 2.998 (0.021)

    Args:
        summary_path: path of the summary file (on the machine selected by ``s_idx``).
        metric_name: metric to extract, matched literally at the start of a line
            (e.g. "SFID"). Only the first number after the colon is used.
        s_idx: where the file lives:
            1 -> local file
            3 -> fetched from cy@192.168.100.83:summary_path
            4 -> fetched from cy@192.168.100.81:summary_path
            Any other value is treated like a local path.
        verbose: print where a remote file is copied to.

    Returns:
        dict mapping int iteration -> float value. Empty when the file cannot
        be fetched or found, or when it contains no value for ``metric_name``.
        If an iteration appears more than once, the last value wins.
    """
    remote_hosts = {3: "192.168.100.83", 4: "192.168.100.81"}

    # ========== Fetch the file first when it lives on a remote server ==========
    if s_idx in remote_hosts:
        remote_spec = f"cy@{remote_hosts[s_idx]}:{summary_path}"

        # Local download directory: <tmp>/remote_eval_cache
        cache_dir = os.path.join(tempfile.gettempdir(), "remote_eval_cache")
        os.makedirs(cache_dir, exist_ok=True)

        # Prefix with s_idx so equally named files from different servers do not clash.
        local_path = os.path.join(cache_dir, f"s{s_idx}_" + os.path.basename(summary_path))

        if verbose:
            print(f"[Info] 从远程 {remote_spec} 拷贝到本地缓存: {local_path}")

        # The password is handed to sshpass through the SSHPASS environment
        # variable (`sshpass -e`) instead of `-p`, so it does not show up in the
        # process list. An SSHPASS value already set by the user takes precedence.
        # NOTE(review): the literal fallback is kept only for backward
        # compatibility; a password in source is insecure and should be removed
        # once SSHPASS (or key-based auth) is configured.
        scp_env = dict(os.environ)
        scp_env.setdefault("SSHPASS", "dx666666")

        try:
            # Requires sshpass + scp:  sudo apt-get install sshpass
            # NOTE(review): StrictHostKeyChecking=no disables host key
            # verification; acceptable only on a trusted lab network.
            subprocess.run(
                [
                    "sshpass", "-e",
                    "scp",
                    "-o", "StrictHostKeyChecking=no",
                    remote_spec,
                    local_path,
                ],
                check=True,
                env=scp_env,
            )
        except FileNotFoundError:
            print("[错误] sshpass 或 scp 命令不存在，请先安装：sudo apt-get install sshpass openssh-client")
            return {}
        except subprocess.CalledProcessError as e:
            print(f"[错误] scp 远程文件失败：{e}")
            return {}

        # From here on, parse the downloaded copy.
        summary_path = local_path

    # ========== Parse the metric from the (local) summary file ==========
    if not os.path.isfile(summary_path):
        print(f"[警告] 未找到文件: {summary_path}")
        return {}

    checkpoint_re = re.compile(r"iter\s*=\s*(\d+)")
    # re.escape keeps metric names containing regex metacharacters literal; the
    # strict number pattern guarantees float() cannot fail on tokens like "-".
    value_re = re.compile(
        re.escape(metric_name) + r":\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    )

    iter_to_value = {}
    current_iter = None

    with open(summary_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line_stripped = line.strip()

            # Header line, e.g.: Checkpoint: ckpt_niter_5000.pth (iter=5000)
            if line_stripped.startswith("Checkpoint:"):
                m_iter = checkpoint_re.search(line_stripped)
                # A header without "iter=N" invalidates the metric lines below it.
                current_iter = int(m_iter.group(1)) if m_iter else None
                continue

            if current_iter is None:
                continue

            # Metric line, e.g. "SFID: 0.624 (0.165)" or "FID: 0.624".
            m_val = value_re.match(line_stripped)
            if m_val:
                iter_to_value[current_iter] = float(m_val.group(1))

    return iter_to_value


In [None]:
# Plot the SFID curves of the batch-size-128 runs.
# METRIC_NAME is the module-level setting that draw() reads at call time.
METRIC_NAME = "SFID"
title = "linux-1, batch_size = 128"
# Legend label -> (setting_suffix, server_idx); server_idx 1 means the summary
# file is read from the local disk (see parse_all_eval_results).
e_map = {
    "b128 GPUx1": ("1_0", 1),
    "b128 GPUx2 (todo)": ("1_01", 1),
    "b128 GPUx3": ("1_123", 1),
    "b128 GPUx4": ("1_0123", 1),
    # Disabled entries ("new code" runs):
    # "b128 GPUx1 new code": ("cy_1_1", 1),
    # "b128 GPUx4 new code": ("cy_1_0123", 1),
}
draw(e_map, title)

In [None]:
# Plot the SFID curves of the batch-size-64 runs.
# METRIC_NAME is the module-level setting that draw() reads at call time.
METRIC_NAME = "SFID"
title = "linux-3,4, batch_size = 64"
# Legend label -> (setting_suffix, server_idx); server_idx 3 and 4 make
# parse_all_eval_results fetch the summary file from a remote host via scp.
e_map = {
    "b64 GPUx1 accx2": ("3_4_acc2", 3),
    "b64 GPUx2": ("3_01", 3),
    "b64 GPUx2 sync_bn": ("3_23_sync", 3),
    "b64 GPUx4": ("3_0123", 3),
    "b64 GPUx4 sync_bn (todo)": ("3_0123_sync", 3),

    # Same run as the entry below, without the "(todo)" label:
    # "b64 GPUx5": ("4_01234", 4),
    "b64 GPUx5 (todo)": ("4_01234", 4),
    "b64 GPUx5 sync_bn": ("4_01234_sync", 4),

    # Disabled entries ("new code" runs):
    # "b64 GPUx1 new code": ("cy_3_4", 3),
    # "b64 GPUx1 accx2 new code": ("cy_3_2_acc2", 3),
    # "b64 GPUx1 accx8 new code": ("cy_3_3_acc8", 3),
    # "b64 GPUx2 new code": ("cy_3_01", 3),
    # "b64 GPUx2 sync_bn new code": ("cy_3_01_sync", 3),
}
draw(e_map, title)