In [5]:
import json
import re
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import plotly.io as pio
import os

# === Load the per-run summary: {experiment name: record dict} ===
# JSON is UTF-8 by specification; pin the encoding so a non-UTF-8 locale default
# (e.g. on Windows) cannot break decoding of non-ASCII content.
with open("set_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Parse hyper-parameters encoded in an experiment name ===
def parse_exp_name(name):
    """Extract the ``_m(...)``, ``_a(...)`` and ``_sz(...)`` fields from ``name``.

    Returns a dict with keys ``"m"`` (str), ``"alpha_init"`` (float) and
    ``"size"`` (str). A field whose pattern is absent is ``None``.
    Raises ValueError if the ``_a(...)`` payload is not a valid float.
    """
    def first_group(pattern):
        # Capture group 1 of the first (non-greedy) match, or None if no match.
        hit = re.search(pattern, name)
        return hit.group(1) if hit else None

    alpha_raw = first_group(r"_a\((.*?)\)")
    return {
        "m": first_group(r"_m\((.*?)\)"),
        "alpha_init": None if alpha_raw is None else float(alpha_raw),
        "size": first_group(r"_sz\((.*?)\)"),
    }


# === Metric to plot (swap the active pair to plot another metric) ===
# Alternative: y_key, y_name = "forward_policy_entropy_eval", "Avg Forward Policy Entropy"
y_key = "spearman_corr_test"
y_name = "Avg Spearman Correlation"

# Figures are written into a directory named after the metric.
# exist_ok=True replaces the check-then-create pattern (no race, one call).
os.makedirs(y_key, exist_ok=True)

# === Aggregate runs and draw one figure per (method, fl) configuration ===
LIM = 9000  # every curve (and the step axis) is truncated to steps <= LIM


def method_title(method, fl):
    """Return the short figure title for a (method, fl) configuration.

    Plain Unicode rather than MathJax so the title font size can be controlled.
    "tb_gfn" -> "TB"; "db_gfn" without FL -> "DB"; anything else -> "FL-DB".
    """
    if method == "tb_gfn":
        return "TB"
    elif method == "db_gfn" and not fl:
        return "DB"
    else:
        return "FL-DB"


for method, fl in [("tb_gfn", False), ("db_gfn", False), ("db_gfn", True)]:
    # alpha_init -> list of per-run metric curves
    groups = defaultdict(list)
    # Step grid taken from a run that matches this configuration. Reading
    # `record` after the loop would give the last entry of `data`, which may
    # belong to a different method/fl/size.
    ref_steps = None

    for exp_name, record in data.items():
        if record["method"] == method and record["size"] == 'large' and record["fl"] == fl:
            alpha = record["alpha_init"]
            if 0.4 <= alpha <= 0.9:
                run_steps = np.asarray(record["step"])
                if ref_steps is None:
                    ref_steps = run_steps
                elif not np.array_equal(run_steps, ref_steps):
                    # Element-wise averaging is only meaningful on a shared step grid.
                    raise ValueError(
                        f"{exp_name}: step grid differs from other {method} (fl={fl}) runs"
                    )
                groups[alpha].append(record[y_key])

    if not groups:
        print(f"No runs for method={method}, fl={fl}; skipping figure")
        continue

    # Truncate the shared step grid to LIM; every mean curve uses the same mask.
    mask = ref_steps <= LIM
    steps = ref_steps[mask]
    if steps.size == 0:
        print(f"No steps <= {LIM} for method={method}, fl={fl}; skipping figure")
        continue

    # === Mean curve per alpha (over runs), truncated to LIM ===
    mean_results = {}
    for alpha, runs in groups.items():
        runs = np.array(runs)  # shape: (n_runs, n_steps)
        mean_results[alpha] = runs.mean(axis=0)[mask]

    # === Plotly line chart: one curve per alpha ===
    fig = go.Figure()
    for alpha in sorted(mean_results.keys()):
        fig.add_trace(go.Scatter(
            x=steps,
            y=mean_results[alpha],
            mode="lines",
            name=f"α={alpha}"
        ))

    title_text = method_title(method, fl)

    # Tight but not cramped y range: 2% padding of the span (of 1.0 for flat curves).
    y_min = min(float(np.min(v)) for v in mean_results.values())
    y_max = max(float(np.max(v)) for v in mean_results.values())
    pad = 0.02 * (y_max - y_min if y_max > y_min else 1.0)
    y_range = [y_min - pad, y_max + pad]
    x_min, x_max = (min(steps), max(steps))

    fig.update_layout(
        width=800, height=600,
        # y=0.99 / yanchor="top" keeps the title close to the plotting area.
        title=dict(text=title_text, font=dict(size=42, color="black"),
                   x=0.5, xanchor="center", y=0.99, yanchor="top"),
        margin=dict(l=80, r=22, t=42, b=70),
        legend=dict(
            x=0.99, y=0.01, xanchor="right", yanchor="bottom",
            orientation="v",
            font=dict(size=38),              # legend font
            bgcolor="rgba(255,255,255,0.6)"  # semi-transparent for readability; drop for a tighter look
        ),
        template="plotly_white",
        paper_bgcolor="white",
        plot_bgcolor="white",
    )

    # X axis: a tick every 1000 steps, SI-formatted ("1k", "2k", ...), boxed frame.
    fig.update_xaxes(
        range=[x_min, x_max],
        title=dict(text="Step", font=dict(size=40, color="black"), standoff=12),
        tickfont=dict(size=32, color="black"),
        ticks="outside",
        tickmode="linear",
        tick0=0,
        dtick=1000,
        tickformat="~s",
        automargin=True,
        showline=True,
        linecolor="black",
        linewidth=1,
        mirror=True,
    )

    # Y axis: padded range, metric name as title, boxed frame.
    fig.update_yaxes(
        range=y_range,
        title=dict(text=y_name, font=dict(size=36, color="black"), standoff=12),
        tickfont=dict(size=32, color="black"),
        ticks="outside",
        automargin=True,
        showline=True,
        linecolor="black",
        linewidth=1,
        mirror=True,
    )

    pio.write_image(fig, f'{y_key}/{method}-fl_{int(fl)}.pdf',
                    format='pdf', width=800, height=600, scale=2)

In [None]:
import json
import re
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import plotly.io as pio

# === Load the per-run summary: {experiment name: record dict} ===
# JSON is UTF-8 by specification; pin the encoding so the locale default cannot break decoding.
with open("set_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Aggregate: tb_gfn, size "large", no FL, alpha_init in [0.4, 0.9] ===
groups = defaultdict(list)  # alpha_init -> list of per-run forward-policy-entropy curves
# Step grid taken from a matching run (assumed identical across matching runs).
# Reading `record` after the loop would give the last entry of `data`, which
# may be an unrelated method/size/fl.
steps = None

for exp_name, record in data.items():
    if record["method"] == "tb_gfn" and record["size"] == "large" and record["fl"] == False:
        alpha = record["alpha_init"]
        if 0.4 <= alpha <= 0.9:
            if steps is None:
                steps = record["step"]
            groups[alpha].append(record["forward_policy_entropy_eval"])

if not groups:
    print("No matching tb_gfn runs found; the figure below will be empty")

# === Mean curve per alpha over runs ===
mean_results = {}
for alpha, runs in groups.items():
    runs = np.array(runs)  # shape: (n_runs, n_steps)
    mean_results[alpha] = runs.mean(axis=0)

# === Plotly line chart: one curve per alpha_init ===
fig = go.Figure()

for alpha in sorted(mean_results.keys()):
    fig.add_trace(go.Scatter(
        x=steps,
        y=mean_results[alpha],
        mode="lines",
        name=fr"$\alpha={alpha}$"
    ))

fig.update_layout(
    width=800,   # canvas width
    height=500,  # canvas height
    title=dict(
        # The curves come from method == "tb_gfn", which is labelled "TB".
        text=r'$\alpha-\text{TB}$',
        font=dict(size=40, color="black"),   # large black title
        x=0.5,                               # centred
        xanchor='center'
    ),
    # Axis titles are set only inside these dicts. Extra xaxis_title/yaxis_title
    # kwargs would duplicate them and may replace the title object (dropping title_font).
    xaxis=dict(
        title="Steps",
        title_font=dict(size=20, color="black"),
        tickfont=dict(size=16, color="black")
    ),
    yaxis=dict(
        title="Average Forward Policy Entropy",
        title_font=dict(size=20, color="black"),
        tickfont=dict(size=16, color="black")
    ),
    legend=dict(
        x=0.79,   # horizontal position (0 = left, 1 = right)
        y=0.01,   # vertical position (0 = bottom, 1 = top)
        xanchor="left",
        yanchor="bottom"
    ),
    template="plotly_white"
)
fig.update_xaxes(range=[0, 9000])
fig.write_html("test.html", include_mathjax="cdn")  # or "require"
fig.show()
pio.write_image(fig, 'test4.pdf', width=800, height=500)

In [16]:
import json
import re
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import defaultdict

# === Load the per-run summary: {experiment name: record dict} ===
# JSON is UTF-8 by specification; pin the encoding so the locale default cannot break decoding.
with open("set_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Parse hyper-parameters encoded in an experiment name ===
def parse_exp_name(name):
    """Extract ``_m(...)``, ``_a(...)``, ``_sz(...)`` and an FL flag from ``name``.

    Returns a dict with keys ``"m"`` (str), ``"alpha_init"`` (float), ``"size"``
    (str) and ``"fl"`` (bool). Absent bracketed fields are ``None``.
    Raises ValueError if the ``_a(...)`` payload is not a valid float.
    """
    def first_group(pattern):
        # Capture group 1 of the first (non-greedy) match, or None if no match.
        hit = re.search(pattern, name)
        return hit.group(1) if hit else None

    alpha_raw = first_group(r"_a\((.*?)\)")
    return {
        "m": first_group(r"_m\((.*?)\)"),
        "alpha_init": None if alpha_raw is None else float(alpha_raw),
        "size": first_group(r"_sz\((.*?)\)"),
        # NOTE(review): matches "fl" anywhere in the name (e.g. inside "flow"),
        # not a delimited token; confirm against the naming scheme.
        "fl": re.search(r"fl", name) is not None,
    }

# === Three panels: one per (objective, fl) configuration ===
subplots_config = [
    {"objective": "db_gfn", "fl": False, "title": r"$\text{DB}$"},
    {"objective": "db_gfn", "fl": True,  "title": r"$\text{FL-DB}$"},
    {"objective": "tb_gfn", "fl": False, "title": r"$\text{TB}$"},
]

fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=[cfg["title"] for cfg in subplots_config]
)

for col_idx, cfg in enumerate(subplots_config, start=1):
    # === Aggregate: alpha_init -> list of per-run spearman_corr_test curves ===
    groups = defaultdict(list)
    # Reset per panel so a panel never inherits the previous panel's step grid
    # (and the first panel cannot hit a NameError when it has no matching runs).
    steps = None
    for exp_name, record in data.items():
        if record["method"] == cfg["objective"] and record["fl"] == cfg["fl"] and record["size"] == "large":
            alpha = record["alpha_init"]
            if 0.4 <= alpha <= 0.9:
                if steps is None:
                    steps = record["step"]  # assumed identical across matching runs
                groups[alpha].append(record["spearman_corr_test"])

    if not groups:
        print(f"No runs for objective={cfg['objective']}, fl={cfg['fl']}; panel {col_idx} left empty")
        continue

    # === Mean curve per alpha over runs ===
    mean_results = {
        alpha: np.array(runs).mean(axis=0)  # runs: (n_runs, n_steps)
        for alpha, runs in groups.items()
    }

    # === Add to this panel; legend entries only from the first panel ===
    for alpha in sorted(mean_results.keys()):
        fig.add_trace(
            go.Scatter(
                x=steps,
                y=mean_results[alpha],
                mode="lines",
                name=fr"$\alpha={alpha}$",
                showlegend=(col_idx == 1)
            ),
            row=1, col=col_idx
        )

# === Global layout ===
fig.update_layout(
    width=1800,   # total canvas width (three panels side by side)
    height=500,
    title=dict(
        # NOTE(review): the panels are DB, FL-DB and TB, so "α-DB" may not
        # describe the whole figure; confirm the intended title.
        text=r'$\alpha-\text{DB}$',
        font=dict(size=40, color="black"),
        x=0.5,
        xanchor='center'
    ),
    # `xaxis` / `yaxis` style the first panel only.
    xaxis=dict(
        title="Steps",
        title_font=dict(size=20, color="black"),
        tickfont=dict(size=16, color="black"),
    ),
    yaxis=dict(
        # The panels plot record["spearman_corr_test"], not the policy entropy.
        title="Average Spearman Correlation",
        title_font=dict(size=20, color="black"),
        tickfont=dict(size=16, color="black")
    ),
    legend=dict(
        x=0.66, y=0.01,
        xanchor="left", yanchor="bottom"
    ),
    template="plotly_white",
)

# Same x range on every panel (no row/col filter -> applies to all subplots).
fig.update_xaxes(range=[0, 9000])

fig.write_html("test_three_subplots.html", include_mathjax="cdn")
fig.show()
# Figure.write_image is the method form of plotly.io.write_image, so this cell does
# not rely on `pio` having been imported by an earlier cell. The file name is distinct
# because the single-panel cell already writes test4.pdf.
# NOTE(review): 800x300 overrides the 1800x500 layout size at export; confirm intended.
fig.write_image('test_three_subplots.pdf', width=800, height=300)

In [2]:
import json
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import plotly.io as pio

# === Load the per-run summary (dict whose values are per-run records) ===
# JSON is UTF-8 by specification; pin the encoding so the locale default cannot break decoding.
with open("mols_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Aggregate: per-alpha averages over "subtb" runs with FL enabled ===
groups_length = defaultdict(list)   # alpha -> per-run mean sample length
groups_entropy = defaultdict(list)  # alpha -> per-run mean forward-policy entropy

subtb_fl_records = (
    rec for rec in data.values()
    if rec["objective"] == "subtb" and rec["fl"] is True
)
for rec in subtb_fl_records:
    # Round to one decimal so float noise does not split an alpha into two buckets.
    alpha = round(float(rec["alpha_init"]), 1)
    if not (0.1 <= alpha <= 0.9):
        continue
    run_length = np.mean(rec["all_samples_avg_length_eval"])
    run_entropy = np.mean(rec["forward_policy_entropy_eval"])
    groups_length[alpha].append(run_length)
    groups_entropy[alpha].append(run_entropy)

# Mean over runs for each alpha, in ascending alpha order.
alpha_list = sorted(groups_length)
mean_length_list = [np.mean(groups_length[a]) for a in alpha_list]
mean_entropy_list = [np.mean(groups_entropy[a]) for a in alpha_list]

# === x-axis ticks: always show the nine values 0.1 .. 0.9 ===
x_ticks = [step / 10 for step in range(1, 10)]
x_ticktext = [f"{v:.1f}" for v in x_ticks]

# === Figure: both series share the alpha x axis; each has its own y axis ===
fig = go.Figure()

for values, label, colour, axis in [
    (mean_length_list, "Avg Sample Length", "blue", "y"),        # left axis
    (mean_entropy_list, "Forward Policy Entropy", "red", "y2"),  # right axis
]:
    fig.add_trace(go.Scatter(
        x=alpha_list,
        y=values,
        mode="lines+markers",
        name=label,
        line=dict(color=colour),
        yaxis=axis,
        cliponaxis=False,  # keep markers at the edges of the x range from being clipped
    ))

# === Layout (top/bottom margins squeezed to zero) ===
fig.update_layout(
    # Must equal the export width below: write_image's width/height override the
    # layout, so any other value only makes fig.show() disagree with the PDF.
    width=500,
    height=500,
    showlegend=False,
    template="plotly_white",
    margin=dict(t=0, b=0, l=28, r=28),  # no top/bottom margin; shrink left/right as needed
    yaxis=dict(
        title="Average Sample Length",
        title_font=dict(size=28, color="blue"),
        tickfont=dict(size=16, color="blue"),
        side="left",
        automargin=True
    ),
    yaxis2=dict(
        title="Average Forward Policy Entropy",
        title_font=dict(size=28, color="red"),
        tickfont=dict(size=16, color="red"),
        overlaying="y",
        side="right",
        automargin=True
    )
)

# === x axis: fixed ticks, range widened slightly so the end markers are not clipped ===
pad = 0.02  # raise to 0.03-0.05 if the end markers are still clipped
xmin = x_ticks[0] - pad
xmax = x_ticks[-1] + pad
fig.update_xaxes(
    title="α",
    title_font=dict(size=20, color="black"),
    tickfont=dict(size=16, color="black"),
    tickmode="array",
    tickvals=x_ticks,
    ticktext=x_ticktext,
    range=[xmin, xmax],
    automargin=True
)

# === Output ===
# NOTE(review): the next cell exports to the same file name and overwrites this PDF.
# fig.show()
pio.write_image(fig, "alpha_dual_axis_avg.pdf", width=500, height=500, scale=2)

In [9]:
import json
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import plotly.io as pio

# === Load the per-run summary (dict whose values are per-run records) ===
# JSON is UTF-8 by specification; pin the encoding so the locale default cannot break decoding.
with open("mols_run_summary.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# === Aggregate: per-alpha averages over "subtb" runs with FL enabled ===
groups_length = defaultdict(list)   # alpha -> per-run mean sample length
groups_entropy = defaultdict(list)  # alpha -> per-run mean forward-policy entropy

subtb_fl_records = (
    rec for rec in data.values()
    if rec["objective"] == "subtb" and rec["fl"] is True
)
for rec in subtb_fl_records:
    # Round to one decimal so float noise does not split an alpha into two buckets.
    alpha = round(float(rec["alpha_init"]), 1)
    if not (0.1 <= alpha <= 0.9):
        continue
    run_length = np.mean(rec["all_samples_avg_length_eval"])
    run_entropy = np.mean(rec["forward_policy_entropy_eval"])
    groups_length[alpha].append(run_length)
    groups_entropy[alpha].append(run_entropy)

# Mean over runs for each alpha, in ascending alpha order.
alpha_list = sorted(groups_length)
mean_length_list = [np.mean(groups_length[a]) for a in alpha_list]
mean_entropy_list = [np.mean(groups_entropy[a]) for a in alpha_list]

# === x-axis ticks: always show the nine values 0.1 .. 0.9 ===
x_ticks = [step / 10 for step in range(1, 10)]
x_ticktext = [f"{v:.1f}" for v in x_ticks]

# === Figure: both series share the alpha x axis; each has its own y axis ===
fig = go.Figure()

trace_specs = [
    # (values, legend label, line colour, marker symbol, y axis)
    (mean_length_list, "Avg Sample Length", "rgb(0, 76, 153)", "circle", "y"),             # left axis
    (mean_entropy_list, "Avg Forward Policy Entropy", "rgb(204, 51, 0)", "square", "y2"),  # right axis
]
for values, label, colour, symbol, axis in trace_specs:
    fig.add_trace(go.Scatter(
        x=alpha_list,
        y=values,
        mode="lines+markers",
        name=label,
        line=dict(color=colour, width=3),
        marker=dict(symbol=symbol, size=8),
        yaxis=axis,
        cliponaxis=False,  # keep markers at the edges of the x range from being clipped
    ))

# === Layout: colour-coded dual y axes, light grid, legend top-left ===
length_rgb = "rgb(0, 76, 153)"
entropy_rgb = "rgb(204, 51, 0)"

left_axis = dict(
    title="Average Sample Length",
    title_font=dict(size=22, color=length_rgb),
    tickfont=dict(size=16, color=length_rgb),
    side="left",
    showgrid=True,
    gridcolor="lightgray",  # horizontal grid lines come from the left axis only
    zeroline=False,
    automargin=True,
)
right_axis = dict(
    title="Average Forward Policy Entropy",
    title_font=dict(size=22, color=entropy_rgb),
    tickfont=dict(size=16, color=entropy_rgb),
    overlaying="y",
    side="right",
    showgrid=False,  # avoid drawing a second, misaligned grid
    zeroline=False,
    automargin=True,
)
alpha_axis = dict(
    title="α",
    title_font=dict(size=22, color="black"),
    tickfont=dict(size=16, color="black"),
    tickmode="array",
    tickvals=x_ticks,
    ticktext=x_ticktext,
    range=[x_ticks[0] - 0.02, x_ticks[-1] + 0.02],  # small pad so end markers stay visible
    showgrid=True,
    gridcolor="lightgray",  # vertical grid lines
    zeroline=False,
    automargin=True,
)
legend_top_left = dict(
    x=0.02, y=0.98,
    xanchor="left", yanchor="top",
    orientation="v",
    font=dict(size=18),
)

fig.update_layout(
    width=500,
    height=500,
    template="simple_white",
    margin=dict(t=20, b=20, l=60, r=60),
    legend=legend_top_left,
    yaxis=left_axis,
    yaxis2=right_axis,
    xaxis=alpha_axis,
)

# === Output ===
# fig.show()
fig.write_image("alpha_dual_axis_avg.pdf", width=500, height=500, scale=3)
