# In [1]:
import json
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import plotly.io as pio

# === Load the run-summary JSON ===
# JSON text is UTF-8 by spec; pass the encoding explicitly so loading does not
# depend on the platform's locale default (e.g. GBK on Chinese-locale Windows).
with open("mols_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Aggregate: average all runs that share the same alpha ===
# Each dict maps a rounded alpha -> list of per-run means (one entry per run).
groups_length = defaultdict(list)
groups_entropy = defaultdict(list)

for record in data.values():
    # Keep only "subtb" runs whose "fl" flag is exactly True (`is True` rejects
    # truthy non-bool values).
    if record["objective"] != "subtb" or record["fl"] is not True:
        continue
    # Round to one decimal so float noise (e.g. 0.30000000000000004) falls into
    # the same bucket as 0.3.
    alpha = round(float(record["alpha_init"]), 1)
    if not 0.1 <= alpha <= 0.9:
        continue
    # Collapse each run's per-evaluation series to a single mean value.
    groups_length[alpha].append(np.mean(record["all_samples_avg_length_eval"]))
    groups_entropy[alpha].append(np.mean(record["forward_policy_entropy_eval"]))

# Mean over runs for each alpha. Both dicts are filled together above, so they
# share the same keys; alpha_list is sorted for a left-to-right line plot.
alpha_list = sorted(groups_length.keys())
mean_length_list = [np.mean(groups_length[a]) for a in alpha_list]
mean_entropy_list = [np.mean(groups_entropy[a]) for a in alpha_list]

# === x-axis ticks: always show the nine values 0.1 .. 0.9 ===
# i / 10 is the correctly rounded float for each tenth, so no round() is needed.
x_ticks = [tick / 10 for tick in range(1, 10)]
# Labels with exactly one decimal place ("0.1", ..., "0.9").
x_ticktext = list(map("{:.1f}".format, x_ticks))

# === Plot ===
fig = go.Figure()

# One line+marker trace per y-axis, added in this order:
#   1. average sample length on the left axis ("y"), blue
#   2. forward-policy entropy on the right axis ("y2"), red
_trace_specs = (
    ("Avg Sample Length", mean_length_list, "blue", "y"),
    ("Forward Policy Entropy", mean_entropy_list, "red", "y2"),
)
for _label, _y_values, _colour, _axis_id in _trace_specs:
    fig.add_trace(
        go.Scatter(
            x=alpha_list,
            y=_y_values,
            mode="lines+markers",
            name=_label,
            line=dict(color=_colour),
            yaxis=_axis_id,
            cliponaxis=False,  # let markers at the plot edge overflow instead of being clipped
        )
    )

# === Layout (top/bottom whitespace squeezed out) ===
fig.update_layout(
    # Kept equal to the size passed to pio.write_image at export time (its
    # explicit width/height override these), so an interactive preview has the
    # same geometry as the saved PDF. Previously width was 450 here but 500 at
    # export, so this value never reached the output file.
    width=500,
    height=500,
    # No legend: each series is identified by its axis colour (blue left, red right).
    showlegend=False,
    template="plotly_white",
    margin=dict(t=0, b=0, l=28, r=28),  # no top/bottom margin; shrink left/right further if needed
    yaxis=dict(
        title="Average Sample Length",
        title_font=dict(size=28, color="blue"),
        tickfont=dict(size=16, color="blue"),
        side="left",
        automargin=True  # grow the margin as needed so the title/ticks are not cut off
    ),
    yaxis2=dict(
        title="Average Forward Policy Entropy",
        title_font=dict(size=28, color="red"),
        tickfont=dict(size=16, color="red"),
        overlaying="y",  # draw in the same plot area as the left axis
        side="right",
        automargin=True
    )
)

# === x-axis: pin the tick positions/labels and widen the range slightly so the
# end points are not swallowed by the plot edge ===
pad = 0.02  # raise to about 0.03-0.05 if the end points are still clipped
xmin, xmax = x_ticks[0] - pad, x_ticks[-1] + pad  # roughly 0.08 and 0.92

xaxis_settings = dict(
    title="α",
    title_font=dict(size=20, color="black"),
    tickfont=dict(size=16, color="black"),
    # Explicit tick list: show exactly x_ticks, labelled by x_ticktext.
    tickmode="array",
    tickvals=x_ticks,
    ticktext=x_ticktext,
    range=[xmin, xmax],
    automargin=True,
)
fig.update_xaxes(**xaxis_settings)

# === Output ===
# Export a 500x500 (layout px) figure at scale 2 to PDF. For an interactive
# preview instead, call fig.show().
fig.write_image("alpha_dual_axis_avg.pdf", width=500, height=500, scale=2)