In [31]:
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import defaultdict
import plotly.io as pio

# === Load the per-run summary JSON ===
# Decode explicitly as UTF-8: JSON text is UTF-8 by specification, whereas
# open()'s default encoding depends on the OS locale (e.g. GBK on a
# Chinese-locale Windows machine).
with open("mols_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Panel configuration: one subplot per metric, left to right ===
# Each entry pairs a metric key from the run summary with a LaTeX label.
subplots_config = [
    {"metric": metric_key, "title": panel_title}
    for metric_key, panel_title in (
        ("spearman_corr_test", r"$\text{Spearman Corr}$"),
        ("num_modes_eval", r"$\text{Num Modes}$"),
        ("current_loss", r"$\text{Loss}$"),
    )
]

# Only runs whose alpha_init is one of these values are plotted.
target_alphas = [0.5, 0.9]

# === Build a 1x3 figure: one panel per metric, alpha groups overlaid ===
STEP_STRIDE = 4  # keep every 4th logged point to thin the curves

# Runs using the "db" objective with fl == False, restricted to the alphas
# being compared. Filtered once and reused for every panel.
db_runs = [
    record
    for record in data.values()
    if record["objective"] == "db"
    and record["fl"] == False
    and record.get("alpha_init") in target_alphas
]

fig = make_subplots(rows=1, cols=3)

for col_idx, cfg in enumerate(subplots_config, start=1):
    metric = cfg["metric"]
    groups = defaultdict(list)  # alpha -> list of per-run value sequences
    group_steps = {}            # alpha -> x-axis steps shared by that group

    # === Aggregate runs per alpha ===
    for record in db_runs:
        alpha = record.get("alpha_init")
        steps = list(record["step"][::STEP_STRIDE])
        # Every run in a group must share one x-axis; without this check the
        # steps of whichever record was read last would be used for all groups.
        if group_steps.setdefault(alpha, steps) != steps:
            raise ValueError(
                f"runs with alpha={alpha} log different steps; "
                f"cannot average {metric!r} point-wise"
            )
        groups[alpha].append(record[metric][::STEP_STRIDE])

    # === Mean curve with a +/-1 std band per alpha ===
    for alpha in sorted(groups):
        runs = np.array(groups[alpha])  # (n_runs, n_steps)
        steps = group_steps[alpha]
        print(f"Alpha: {alpha}, Metric: {metric}, Runs: {runs.shape[0]}, Steps: {runs.shape[1]}")
        mean_curve = runs.mean(axis=0)
        std_curve = runs.std(axis=0)
        is_baseline = alpha == 0.5  # alpha 0.5 is drawn green, any other alpha red

        # Shaded band: trace the upper edge forward and the lower edge backward,
        # then let fill="toself" close the polygon.
        fig.add_trace(
            go.Scatter(
                x=steps + steps[::-1],
                y=(mean_curve + std_curve).tolist() + (mean_curve - std_curve)[::-1].tolist(),
                fill="toself",
                fillcolor="rgba(0,100,80,0.2)" if is_baseline else "rgba(200,30,30,0.2)",
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
            ),
            row=1, col=col_idx,
        )

        # Mean curve; legend entries are emitted only for the first panel.
        fig.add_trace(
            go.Scatter(
                x=steps,
                y=mean_curve,
                mode="lines",
                name=fr"$\alpha={alpha}$" if col_idx == 1 else None,
                line=dict(width=2, color="green" if is_baseline else "red"),
                showlegend=(col_idx == 1),
            ),
            row=1, col=col_idx,
        )

# === Global layout ===
fig.update_layout(
    width=1800,
    height=500,
    title=dict(
        # Plotly typesets LaTeX only when the whole label is wrapped in $...$,
        # so a label mixing math with plain text is not rendered as math.
        # Plain Unicode avoids that ("对比" = "comparison").
        text="α=" + ", ".join(str(a) for a in target_alphas) + " 对比",
        font=dict(size=40, color="black"),
        x=0.5,
        xanchor="center",
    ),
    legend=dict(x=0.01, y=0.99, xanchor="left", yanchor="top"),
    template="plotly_white",
)

fig.write_html("mols_three_subplots.html", include_mathjax="cdn")
fig.show()
pio.write_image(fig, "mols_three_subplots.pdf", width=1800, height=500)


Alpha: 0.5, Metric: spearman_corr_test, Runs: 5, Steps: 26
Alpha: 0.9, Metric: spearman_corr_test, Runs: 5, Steps: 26
Alpha: 0.5, Metric: num_modes_eval, Runs: 5, Steps: 26
Alpha: 0.9, Metric: num_modes_eval, Runs: 5, Steps: 26
Alpha: 0.5, Metric: current_loss, Runs: 5, Steps: 26
Alpha: 0.9, Metric: current_loss, Runs: 5, Steps: 26


In [5]:
import json
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from collections import defaultdict

# === Load the per-run summary JSON ===
# Decode explicitly as UTF-8: JSON text is UTF-8 by specification, whereas
# open()'s default encoding depends on the OS locale (e.g. GBK on a
# Chinese-locale Windows machine).
with open("mols_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Per-metric configuration: one standalone figure per entry ===
# Each entry holds the metric key in the run summary, a LaTeX title,
# and the plain-text y-axis label.
metrics_config = [
    dict(metric=metric_key, title=latex_title, ylabel=axis_label)
    for metric_key, latex_title, axis_label in (
        ("spearman_corr_test", r"$\text{Spearman Corr}$", "Spearman Correlation"),
        ("num_modes_eval", r"$\text{Num Modes}$", "Number of Modes"),
        ("current_loss", r"$\text{Loss}$", "Loss"),
    )
]

# Only runs whose alpha_init is one of these values are plotted.
target_alphas = [0.5, 0.9]

# === Step-axis layout shared by every per-metric figure ===
X_MAX = 50_000           # right edge of the plotted step range
STAGE_BOUNDARY = 40_000  # step where Stage 1 ends and Stage 2 begins (dashed line)
STEP_STRIDE = 4          # keep every 4th logged point to thin the curves

# Runs using the "subtb" objective with fl == True, restricted to the alphas
# being compared. Filtered once and reused for every metric.
subtb_fl_runs = [
    record
    for record in data.values()
    if record["objective"] == "subtb"
    and record["fl"] == True
    and record.get("alpha_init") in target_alphas
]

for cfg in metrics_config:
    metric = cfg["metric"]
    groups = defaultdict(list)  # alpha -> list of per-run value sequences
    group_steps = {}            # alpha -> x-axis steps shared by that group

    # === Aggregate runs per alpha ===
    for record in subtb_fl_runs:
        alpha = record.get("alpha_init")
        steps = list(record["step"][::STEP_STRIDE])
        # Every run in a group must share one x-axis; without this check the
        # steps of whichever record was read last would be used for all groups.
        if group_steps.setdefault(alpha, steps) != steps:
            raise ValueError(
                f"runs with alpha={alpha} log different steps; "
                f"cannot average {metric!r} point-wise"
            )
        groups[alpha].append(record[metric][::STEP_STRIDE])

    fig = go.Figure()

    # === Mean curve with a +/-1 std band per alpha ===
    for alpha in sorted(groups):
        runs = np.array(groups[alpha])  # (n_runs, n_steps)
        steps = group_steps[alpha]
        mean_curve = runs.mean(axis=0)
        std_curve = runs.std(axis=0)
        is_baseline = alpha == 0.5  # alpha 0.5 is the baseline (green); others are red

        # Shaded band: trace the upper edge forward and the lower edge backward,
        # then let fill="toself" close the polygon.
        fig.add_trace(
            go.Scatter(
                x=steps + steps[::-1],
                y=(mean_curve + std_curve).tolist() + (mean_curve - std_curve)[::-1].tolist(),
                fill="toself",
                fillcolor="rgba(0,100,80,0.2)" if is_baseline else "rgba(200,30,30,0.2)",
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
            )
        )

        # Mean curve.
        fig.add_trace(
            go.Scatter(
                x=steps,
                y=mean_curve,
                mode="lines",
                name="Baseline(α=0.5)" if is_baseline else "Ours(α=0.9)",
                line=dict(width=2, color="green" if is_baseline else "red"),
            )
        )

    # === Stage boundary (dashed line) and stage labels ===
    # Only the line is requested; no annotation argument is passed.
    fig.add_vline(x=STAGE_BOUNDARY, line=dict(color="black", dash="dash", width=2))
    # Centre each label horizontally within its stage, at the bottom of the plot area.
    for stage_label, x_center in (
        ("Stage 1", STAGE_BOUNDARY // 2),
        ("Stage 2", (STAGE_BOUNDARY + X_MAX) // 2),
    ):
        fig.add_annotation(
            x=x_center,
            yref="paper",
            y=0,
            text=stage_label,
            showarrow=False,
            font=dict(size=28, color="black"),
        )

    # Legend: top-left with a large font by default; for the Spearman plot it is
    # moved to (0.3, 0.3) with a slightly smaller font (presumably so it does
    # not cover the curves).
    if metric == "spearman_corr_test":
        legend_layout = dict(
            x=0.3, y=0.3, xanchor="left", yanchor="top",
            font=dict(size=28, color="black"),
        )
    else:
        legend_layout = dict(
            x=0.05, y=0.99, xanchor="left", yanchor="top",
            font=dict(size=32, color="black"),
        )

    # Black frame drawn on all four sides of the plot area.
    axis_frame = dict(showline=True, linecolor="black", linewidth=2, mirror=True)

    # === Layout ===
    fig.update_layout(
        width=800,
        height=600,
        xaxis=dict(
            title="Steps",
            title_font=dict(size=40, color="black"),
            tickfont=dict(size=32, color="black"),
            range=[0, X_MAX + 1],
            tickvals=list(range(0, X_MAX + 1, 10_000)),
            **axis_frame,
        ),
        yaxis=dict(
            title=cfg["ylabel"],
            title_font=dict(size=40, color="black"),
            tickfont=dict(size=32, color="black"),
            **axis_frame,
        ),
        legend=legend_layout,
        template="plotly_white",
        margin=dict(t=4, b=0, l=28, r=28),
    )

    # === Save and display ===
    fig.write_html(f"{metric}_plot.html", include_mathjax="cdn")
    pio.write_image(fig, f"{metric}_plot.pdf", width=800, height=600, scale=2)
    fig.show()
