In [22]:
import json
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import plotly.io as pio
import plotly.express as px


def calculate_objective(data:dict):
    for _, record in data.items():
        if "fl" in record.keys() and record["fl"]:
            if obj := record.get("objective", None):
                if 'db' in obj:
                    record["Objective"] = "FL-DB"
                elif 'subtb' in obj:
                    record["Objective"] = "FL-SubTB(λ)"
                elif 'tb' in obj:
                    record["Objective"] = "FL-TB"
            elif obj := record.get("method", None):
                if 'db' in obj:
                    record["Objective"] = "FL-DB"
                elif 'subtb' in obj:
                    record["Objective"] = "FL-SubTB(λ)"
                elif 'tb' in obj:
                    record["Objective"] = "FL-TB"
        else:
            if obj := record.get("objective", None):
                if 'db' in obj:
                    record["Objective"] = "DB"
                elif 'subtb' in obj:
                    record["Objective"] = "SubTB(λ)"
                elif 'tb' in obj:
                    record["Objective"] = "TB"
            elif obj := record.get("method", None):
                if 'db' in obj:
                    record["Objective"] = "DB"
                elif 'subtb' in obj:
                    record["Objective"] = "SubTB(λ)"
                elif 'tb' in obj:
                    record["Objective"] = "TB"
    return data

# === 读取 json ===
with open("mols_run_summary.json", "r") as f:
    data = calculate_objective(json.load(f))

# === 聚合数据：objective → alpha → avg_length ===
groups = defaultdict(lambda: defaultdict(list))

for record in data.values():
    obj = record["Objective"]
    alpha = round(float(record["alpha_init"]), 1)
    if 0.1 <= alpha <= 0.9:
        avg_length = record["all_samples_avg_length_eval"][-1]
        print(f"Objective: {obj}, alpha: {alpha}, Avg Length: {avg_length}")
        groups[obj][alpha].append(avg_length)

# === x 轴刻度 ===
x_ticks = [round(0.1 * i, 1) for i in range(1, 10)]
x_ticktext = [f"{v:.1f}" for v in x_ticks]

# === 针对每个 objective 画图 ===
for obj, alpha_dict in groups.items():
    alpha_list = sorted(alpha_dict.keys())
    mean_length_list = [np.mean(alpha_dict[a]) for a in alpha_list]
    std_length_list = [np.std(alpha_dict[a]) for a in alpha_list]

    # === y 范围 = (min(mean - std), max(mean + std))，再扩展 10% margin ===
    y_min = min(m - s for m, s in zip(mean_length_list, std_length_list))
    y_max = max(m + s for m, s in zip(mean_length_list, std_length_list))
    y_range = y_max - y_min
    ymin = y_min - 0.1 * y_range
    ymax = y_max + 0.1 * y_range

    # === x 范围：在 0.1 和 0.9 两端留余量 ===
    pad = 0.05
    xmin = x_ticks[0] - pad
    xmax = x_ticks[-1] + pad

    # === 不同 alpha 染不同色 ===
    colorscale = px.colors.qualitative.Set1
    color_map = {a: colorscale[i % len(colorscale)] for i, a in enumerate(alpha_list)}

    fig = go.Figure()

    for a, mean, std in zip(alpha_list, mean_length_list, std_length_list):
        fig.add_trace(go.Bar(
            x=[a],
            y=[mean],
            error_y=dict(type="data", array=[std], visible=True),
            marker=dict(color=color_map[a]),
            name=f"α={a:.1f}"
        ))

    # if obj == 'tb':
    #     title = "<b>TB</b>"
    # elif obj == 'subtb':
    #     title = "<b>SubTB(λ)</b>"
    # elif obj == 'db':
    #     title = "<b>DB</b>"
    title = f"<b>{obj}</b>"

    # === 布局 ===
    fig.update_layout(
        width=600,
        height=500,
        template="plotly_white",
        margin=dict(t=32, b=30, l=28, r=28),
        showlegend=False,
        title=dict(
            # text=f"Objective: {obj}",
            text=title,
            font=dict(size=40, color="black"),
            x=0.5, xanchor="center",
            y=0.99, yanchor="top"
        ),
        yaxis=dict(
            title="Average Sample Length",
            title_font=dict(size=35, color="black"),
            tickfont=dict(size=27, color="black"),
            range=[ymin, ymax],
            automargin=True
        ),
        xaxis=dict(
            title="α",
            title_font=dict(size=40, color="black"),
            tickfont=dict(size=27, color="black"),
            tickmode="array",
            tickvals=x_ticks,
            ticktext=x_ticktext,
            range=[xmin, xmax],
            automargin=True
        )
    )

    # === 输出为 PDF ===
    outname = f"figs/Length/length_vs_alpha_{obj}.pdf"
    pio.write_image(fig, outname, width=600, height=500, scale=2)
    print(f"Saved {outname}")
    fig.show()


Objective: SubTB(λ), alpha: 0.7, Avg Length: 7.978755
Objective: DB, alpha: 0.8, Avg Length: 7.981265
Objective: SubTB(λ), alpha: 0.7, Avg Length: 7.98716
Objective: SubTB(λ), alpha: 0.1, Avg Length: 7.987755
Objective: FL-SubTB(λ), alpha: 0.2, Avg Length: 2.263975
Objective: FL-DB, alpha: 0.1, Avg Length: 3.025695
Objective: FL-SubTB(λ), alpha: 0.2, Avg Length: 2.7601
Objective: DB, alpha: 0.4, Avg Length: 7.98673
Objective: FL-DB, alpha: 0.1, Avg Length: 3.022645
Objective: SubTB(λ), alpha: 0.9, Avg Length: 7.983835
Objective: SubTB(λ), alpha: 0.1, Avg Length: 7.986785
Objective: DB, alpha: 0.2, Avg Length: 7.98744
Objective: DB, alpha: 0.7, Avg Length: 7.984615
Objective: SubTB(λ), alpha: 0.3, Avg Length: 7.98741
Objective: DB, alpha: 0.1, Avg Length: 7.98743
Objective: DB, alpha: 0.8, Avg Length: 7.980865
Objective: SubTB(λ), alpha: 0.8, Avg Length: 7.986275
Objective: FL-DB, alpha: 0.6, Avg Length: 3.30086
Objective: FL-DB, alpha: 0.4, Avg Length: 3.090235
Objective: SubTB(λ), alp

Saved figs/Length/length_vs_alpha_DB.pdf


Saved figs/Length/length_vs_alpha_FL-SubTB(λ).pdf


Saved figs/Length/length_vs_alpha_FL-DB.pdf


Saved figs/Length/length_vs_alpha_TB.pdf


In [4]:
data['m(subtb)_a(0.7)_s(0)_v(False)']

{'step': [0,
  500,
  1000,
  1500,
  2000,
  2500,
  3000,
  3500,
  4000,
  4500,
  5000,
  5500,
  6000,
  6500,
  7000,
  7500,
  8000,
  8500,
  9000,
  9500,
  10000,
  10500,
  11000,
  11500,
  12000,
  12500,
  13000,
  13500,
  14000,
  14500,
  15000,
  15500,
  16000,
  16500,
  17000,
  17500,
  18000,
  18500,
  19000,
  19500,
  20000,
  20500,
  21000,
  21500,
  22000,
  22500,
  23000,
  23500,
  24000,
  24500,
  25000,
  25500,
  26000,
  26500,
  27000,
  27500,
  28000,
  28500,
  29000,
  29500,
  30000,
  30500,
  31000,
  31500,
  32000,
  32500,
  33000,
  33500,
  34000,
  34500,
  35000,
  35500,
  36000,
  36500,
  37000,
  37500,
  38000,
  38500,
  39000,
  39500,
  40000,
  40500,
  41000,
  41500,
  42000,
  42500,
  43000,
  43500,
  44000,
  44500,
  45000,
  45500,
  46000,
  46500,
  47000,
  47500,
  48000,
  48500,
  49000,
  49500,
  49999],
 'num_modes_eval': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1