In [48]:
import json
import os
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
from collections import defaultdict

def build_image(
        steps,
        groups,
        ours_alpha,
        title,
        ylabel,
        metric,
        file_name,
        x_range=[0, 50000],
        ticks=[0, 10000, 20000, 30000, 40000],
        tick_labels=["0", "10k", "20k", "30k", "40k"],
        otick=49900,
        otick_label="50k",
        width=800,
        height=600,
        showline=True,
        save_html=False,
        show_pic=False
    ):
    '''
    Plot mean ± std curves for the baseline (alpha = 0.5) and for ``ours_alpha``
    and write the figure to ``save/<file_name>.pdf``.

    Args:
        steps: x values, shape (n_steps,); a plain list or a NumPy array.
        groups: mapping ``alpha -> runs``. ``groups[alpha]`` is converted with
            ``np.array`` and reduced over axis 0, i.e. it is expected to be
            shaped (n_runs, n_steps). Looked up for 0.5 and ``ours_alpha``.
        ours_alpha: alpha whose runs are drawn in red next to the green baseline.
        title: figure title (rendered in bold).
        ylabel: y-axis title.
        metric: not used by the body; kept so existing call sites keep working.
        file_name: output file stem inside the ``save`` directory.
        x_range: x-axis range; the dashed stage line sits at 0.8 * x_range[1].
        ticks, tick_labels: explicit x tick positions and their labels.
        otick, otick_label: position and text of one extra label drawn just
            below the x axis as an annotation.
        width, height: figure size, also passed to the PDF export.
        showline: whether the axis lines are drawn.
        save_html: also write ``save/<file_name>.html`` when true.
        show_pic: call ``fig.show()`` when true.

    The list-valued defaults are only read here, never mutated.
    '''
    # Accept plain lists as well as arrays: the band trace below needs to
    # reverse the x values and call ``.tolist()`` on them.
    steps = np.asarray(steps)

    # === Create the figure ===
    fig = go.Figure()
    # === Draw mean and spread for the baseline and for our alpha ===
    for alpha in [0.5, ours_alpha]:
        is_baseline = alpha == 0.5
        runs = np.array(groups[alpha])  # shape: (n_runs, n_steps)
        mean_curve = runs.mean(axis=0)
        std_curve = runs.std(axis=0)

        # Shaded band (mean ± std): upper edge left-to-right, then the lower
        # edge right-to-left, so "toself" fills a closed polygon.
        fig.add_trace(
            go.Scatter(
                x=steps.tolist() + steps[::-1].tolist(),
                y=(mean_curve + std_curve).tolist() + (mean_curve - std_curve)[::-1].tolist(),
                fill="toself",
                fillcolor="rgba(0,100,80,0.2)" if is_baseline else "rgba(200,30,30,0.2)",
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False
            )
        )

        # Mean curve
        fig.add_trace(
            go.Scatter(
                x=steps,
                y=mean_curve,
                mode="lines",
                name=("Baseline(α=0.5)" if is_baseline else f"Ours(α={ours_alpha})"),
                line=dict(width=2, color="green" if is_baseline else "red")
            )
        )

    # === Dashed vertical line at 80% of the x-axis upper bound ===
    # (40k for the default x_range.)
    line_pos = 0.8 * x_range[1]
    fig.add_vline(
        x=line_pos,
        line=dict(color="black", dash="dash", width=2),
        annotation=dict(text="", showarrow=False)  # force an empty annotation text
    )

    # Extra tick label below the x axis, drawn as an annotation
    fig.add_annotation(
        x=otick,
        y=-0.01,
        yref="paper",    # paper coordinates: y=0 is the bottom edge of the x axis
        text=otick_label,
        showarrow=False,
        font=dict(size=32, color="black"),
        xanchor="center",
        yanchor="top"
    )

    # Stage labels, centred on each side of the dashed line
    fig.add_annotation(
        x=int(line_pos/2), yref="paper", y=0,  # left of the line
        text="Stage 1",
        showarrow=False,
        font=dict(size=28, color="black")
    )
    fig.add_annotation(
        x=int((line_pos+x_range[1])/2), yref="paper", y=0,  # right of the line
        text="Stage 2",
        showarrow=False,
        font=dict(size=28, color="black")
    )

    legend_dict = dict(
            x=0.01, y=0.99,
            bgcolor="rgba(255,255,255,0.5)",  # semi-transparent white background
            xanchor="left", yanchor="top",
            font=dict(size=32, color="black")   # enlarged legend font
        )

    # === Layout ===
    fig.update_layout(
        width=width,
        height=height,
        title=dict(
            text=f"<b>{title}</b>",
            font=dict(size=40, color="black"),
            x=0.5,
            xanchor='center',
            y=0.99,
            yanchor='top'
        ),
        xaxis=dict(
            title="Steps",
            title_font=dict(size=40, color="black"),
            tickfont=dict(size=32, color="black"),
            range=x_range,
            tickmode="array",
            tickvals=ticks,
            ticktext=tick_labels,
            showline=showline,
            linecolor="black",
            linewidth=2,
            mirror=True
        ),
        yaxis=dict(
            title=ylabel,
            title_font=dict(size=40, color="black"),
            tickfont=dict(size=32, color="black"),
            showline=showline,
            linecolor="black",
            linewidth=2,
            mirror=True
        ),
        legend=legend_dict,
        template="plotly_white",
        margin=dict(t=42, b=0, l=30, r=28)
    )

    # === Export: the PDF is always written, the HTML only on request ===
    os.makedirs('save', exist_ok=True)
    if save_html:
        fig.write_html(os.path.join('save', f"{file_name}.html"), include_mathjax="cdn")
    pio.write_image(fig, os.path.join('save', f"{file_name}.pdf"), width=width, height=height, scale=2)
    if show_pic:
        fig.show()


In [49]:
def calculate_objective(data:dict):
    """Add a display label under the "Objective" key to each run record, in place.

    For every record the label is derived from ``record["objective"]`` or,
    when that is missing or falsy, from ``record["method"]``:
    a value containing 'db' maps to "DB", 'subtb' to "SubTB(λ)" and 'tb' to
    "TB". The label gets an "FL-" prefix when the record has a truthy "fl"
    entry. Records whose source value is absent or matches none of the three
    substrings are left without an "Objective" key.

    Returns the same ``data`` mapping that was passed in.
    """
    # Checked in order; 'subtb' must come before 'tb' because it contains it.
    substring_to_label = (
        ("db", "DB"),
        ("subtb", "SubTB(λ)"),
        ("tb", "TB"),
    )
    for record in data.values():
        prefix = "FL-" if record.get("fl") else ""
        source = record.get("objective", None) or record.get("method", None)
        if not source:
            continue
        for needle, label in substring_to_label:
            if needle in source:
                record["Objective"] = prefix + label
                break
    return data

In [54]:
# Molecule ("mols") experiments

# === Load the run-summary JSON; calculate_objective labels each record ===
with open("mols_run_summary.json", "r") as f:
    data = calculate_objective(json.load(f))

# === Metric configuration ===
# One figure is produced per (metric, objective) pair from this list.
# NOTE(review): each "title" is read by the plotting loop but replaced by the
# objective name before the figure is built.
metrics_config = [
    {"metric": "num_modes_eval", "title": r"$\text{Num Modes}$", "ylabel": "Number of Modes"},
    {"metric": "top_100_avg_reward_eval", "title": r"$\text{Top-100 Avg Reward}$", "ylabel": "Top-100 Average Reward"},
    {"metric": "top_100_avg_similarity_eval", "title": r"$\text{Top-100 Avg Similarity}$", "ylabel": "Top-100 Average Similarity"},
    {"metric": "current_loss", "title": r"$\text{Loss}$", "ylabel": "Loss"},
    {"metric": "forward_policy_entropy_eval", "title": r"$\text{Policy Entropy}$", "ylabel": "Policy Entropy"},

    {"metric": "all_samples_avg_length_eval", "title": r"$\text{Avg Length}$", "ylabel": "Average Length"},
]

# Metrics plotted only for the objectives2 / best_alpha2 pairs below.
metrics_config2 = [
    {"metric": "spearman_corr_test", "title": r"$\text{Spearman Corr}$", "ylabel": "Spearman Correlation"},
]

# NOTE(review): `filters` is not referenced by any visible code in this
# notebook — confirm whether it is still needed.
filters = {
    'step': 49999,
    'random_action_prob': 0.05,
    'use_exp_weight_decay': False
}
# objectives[i] is plotted against the baseline with best_alpha[i] as "ours";
# the two lists are paired element-wise with zip in the plotting loop.
objectives = ["DB", "FL-DB", "SubTB(λ)", "FL-SubTB(λ)", "TB"]
best_alpha = [0.6,   0.9,     0.6,        0.9,           0.6]
# Same pairing, used for metrics_config2.
objectives2 = ["DB", "SubTB(λ)", "TB"]
best_alpha2 = [0.9,  0.9,        0.6]

# One figure per (metric, objective): baseline alpha=0.5 vs. that objective's best alpha.
for cfg in metrics_config:
    metric, title, ylabel = cfg["metric"], cfg["title"], cfg["ylabel"]
    print(f"Processing metric: {metric}")
    for balpha, obj in zip(best_alpha, objectives):
        # The figure is titled by objective, replacing the title from cfg.
        title = obj
        # alpha_init -> list of per-run value arrays, rebuilt for every objective
        groups = defaultdict(list)
        print(f"  Objective: {obj}")
        for exp_name, record in data.items():
            if record.get("Objective", None) == obj:
                alpha = record.get("alpha_init", None)
                steps = np.array(record["step"])
                values = np.array([np.nan if x is None else x for x in record[metric]])
                # Keep only the finite points, and the steps that go with them
                mask = np.isfinite(values)
                steps, values = steps[mask], values[mask]
                groups[alpha].append(values)

        obj_tag = obj.replace('-','_').removesuffix('(λ)')
        build_image(
            steps,
            groups,
            ours_alpha=balpha,
            title=title,
            ylabel=ylabel,
            metric=metric,
            file_name=f"plot_mols_{obj_tag}_alpha{balpha}_{metric}"
        )

# Metrics that are only plotted for the objectives2 / best_alpha2 pairs.
for cfg in metrics_config2:
    metric = cfg["metric"]
    ylabel = cfg["ylabel"]
    print(f"Processing metric: {metric}")
    # `balpha` is the best alpha for this objective. It must stay distinct from
    # the per-record `alpha` below, otherwise the value passed as `ours_alpha`
    # (and used in the file name) is whatever alpha_init the last run had.
    for balpha, obj in zip(best_alpha2, objectives2):
        # Fresh container per objective, so runs of a previous objective never
        # leak into this objective's mean/std curves.
        groups = defaultdict(list)
        # The figure is titled by objective.
        title = obj
        print(f"  Objective: {obj}")
        # === Aggregate runs by their initial alpha ===
        for exp_name, record in data.items():
            if record.get("Objective", None) != obj:
                continue
            alpha = record.get("alpha_init", None)
            steps = np.array(record["step"])
            values = np.array([x if x is not None else np.nan for x in record[metric]])
            # Drop non-finite points (None/NaN) and the matching steps
            mask = np.isfinite(values)
            steps = steps[mask]
            values = values[mask]
            groups[alpha].append(values)

        build_image(
            steps,
            groups,
            ours_alpha=balpha,
            title=title,
            ylabel=ylabel,
            metric=metric,
            file_name=f"plot_mols_{obj.replace('-','_').removesuffix('(λ)')}_alpha{balpha}_{metric}"
        )



Processing metric: num_modes_eval
  Objective: DB
  Objective: FL-DB
  Objective: SubTB(λ)
  Objective: FL-SubTB(λ)
  Objective: TB
Processing metric: top_100_avg_reward_eval
  Objective: DB
  Objective: FL-DB
  Objective: SubTB(λ)
  Objective: FL-SubTB(λ)
  Objective: TB
Processing metric: top_100_avg_similarity_eval
  Objective: DB
  Objective: FL-DB
  Objective: SubTB(λ)
  Objective: FL-SubTB(λ)
  Objective: TB
Processing metric: current_loss
  Objective: DB
  Objective: FL-DB
  Objective: SubTB(λ)
  Objective: FL-SubTB(λ)
  Objective: TB
Processing metric: forward_policy_entropy_eval
  Objective: DB
  Objective: FL-DB
  Objective: SubTB(λ)
  Objective: FL-SubTB(λ)
  Objective: TB
Processing metric: all_samples_avg_length_eval
  Objective: DB
  Objective: FL-DB
  Objective: SubTB(λ)
  Objective: FL-SubTB(λ)
  Objective: TB
Processing metric: spearman_corr_test
  Objective: DB
  Objective: SubTB(λ)
  Objective: TB


In [51]:
# Set-generation ("set") experiments

# === Load the run-summary JSON; calculate_objective labels each record ===
# NOTE: this rebinds `data`, replacing whatever an earlier cell loaded.
with open("set_run_summary.json", "r") as f:
    data = calculate_objective(json.load(f))

# === Metric configuration ===
# NOTE(review): each "title" is read by the plotting loop but replaced by the
# objective name before the figure is built.
metrics_config = [
    {"metric": "modes", "title": r"$\text{Num Modes}$", "ylabel": "Number of Modes"},
    {"metric": "mean_top_1000_R", "title": r"$\text{Top-1000 Avg Reward}$", "ylabel": "Top-1000 Average Reward"},
    {"metric": "spearman_corr_test", "title": r"$\text{Spearman Corr}$", "ylabel": "Spearman Correlation"},
    {"metric": "loss", "title": r"$\text{Loss}$", "ylabel": "Loss"},
    {"metric": "forward_policy_entropy_eval", "title": r"$\text{Policy Entropy}$", "ylabel": "Policy Entropy"}
]

# NOTE(review): `filters` is not referenced by any visible code in this
# notebook — confirm whether it is still needed.
filters = {
    'step': 9999,
    'training_mode': 'online',
    'use_alpha_scheduler': True,
    'use_grad_clip': False,
    'use_exp_weight_decay': 1
}

# Objectives plotted for every set size.
objectives = ["DB", "FL-DB", "TB"]

def get_best_alpha(obj, size):
    """Return the hard-coded best alpha for an objective on a given set size.

    Args:
        obj: objective label, one of 'DB', 'FL-DB', 'TB'.
        size: set size, one of 'small', 'medium', 'large'.

    Raises:
        ValueError: if the (size, objective) combination is not in the table.
    """
    best_by_size = {
        'small':  {'DB': 0.9, 'FL-DB': 0.9, 'TB': 0.7},
        'medium': {'DB': 0.9, 'FL-DB': 0.9, 'TB': 0.8},
        'large':  {'DB': 0.9, 'FL-DB': 0.9, 'TB': 0.7},
    }
    try:
        return best_by_size[size][obj]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are unknown as well.
        raise ValueError(f"Unknown objective or size: {obj}, {size}") from None

# One figure per (set size, metric, objective).
for sz in ['small', 'medium', 'large']:
    for cfg in metrics_config:
        metric, title, ylabel = cfg["metric"], cfg["title"], cfg["ylabel"]
        print(f"Processing metric: {metric}")
        for obj in objectives:
            # alpha_init -> list of per-run value arrays, rebuilt for every objective
            groups = defaultdict(list)
            best_alpha = get_best_alpha(obj, sz)
            # The figure is titled by objective, replacing the title from cfg.
            title = obj
            print(f"  Objective: {obj}")
            for exp_name, record in data.items():
                # Only runs of this objective on this set size
                if record.get("Objective", None) != obj or record.get("size", None) != sz:
                    continue
                alpha = record.get("alpha_init", None)
                steps = np.array(record["step"])
                # A non-list metric is repeated once per recorded step
                run_metric = record[metric]
                if not isinstance(run_metric, list):
                    run_metric = [run_metric] * steps.shape[0]
                values = np.array([np.nan if x is None else x for x in run_metric])
                # Keep only the finite points, and the steps that go with them
                mask = np.isfinite(values)
                steps, values = steps[mask], values[mask]
                groups[alpha].append(values)

            obj_tag = obj.replace('-','_').removesuffix('(λ)')
            build_image(
                steps,
                groups,
                ours_alpha=best_alpha,
                title=title,
                ylabel=ylabel,
                metric=metric,
                x_range=[0, 10000],
                ticks=[0, 2000, 4000, 6000, 8000],
                tick_labels=["0", "2k", "4k", "6k", "8k"],
                otick=9900,
                otick_label="10k",
                file_name=f"plot_set_{obj_tag}_alpha{best_alpha}_{sz}_{metric}"
            )


Processing metric: modes
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: mean_top_1000_R
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: spearman_corr_test
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: loss
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: forward_policy_entropy_eval
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: modes
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: mean_top_1000_R
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: spearman_corr_test
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: loss
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: forward_policy_entropy_eval
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: modes
  Objective: DB
  Objective: FL-DB
  Objective: TB
Processing metric: mean_top_1000_R
  Objective: DB
  Objective: FL-DB
  Obje

In [52]:
# Bit-sequence ("bit") experiments

# === Load the run-summary JSON; calculate_objective labels each record ===
# NOTE: this rebinds `data`, replacing whatever an earlier cell loaded.
with open("bit_run_summary.json", "r") as f:
    data = calculate_objective(json.load(f))

# === Metric configuration ===
# NOTE(review): each "title" is read by the plotting loop but replaced by the
# objective name before the figure is built.
metrics_config = [
    {"metric": "modes", "title": r"$\text{Num Modes}$", "ylabel": "Number of Modes"},
    {"metric": "spearman_corr_test", "title": r"$\text{Spearman Corr}$", "ylabel": "Spearman Correlation"},
    {"metric": "loss", "title": r"$\text{Loss}$", "ylabel": "Loss"},
    {"metric": "forward_policy_entropy_eval", "title": r"$\text{Policy Entropy}$", "ylabel": "Policy Entropy"}
]

# NOTE(review): `filters` is not referenced by any visible code in this
# notebook — confirm whether it is still needed.
filters = {
    'step': 49999,
    'grad_clip_norm': 20,
    'use_exp_weight_decay': False
}

# Disabled wider selection that also included the FL variants:
# objectives = ["DB", "FL-DB", "SubTB(λ)", "FL-SubTB(λ)", "TB"]
objectives = ["DB", "SubTB(λ)", "TB"]

def get_bit_best_alpha(obj, k):
    """Return the hard-coded best alpha for an objective at a given ``k``.

    Only ``k == 4`` has entries, for 'DB', 'SubTB(λ)' and 'TB'.

    Raises:
        ValueError: if the (k, objective) combination is not in the table.
    """
    best_by_k = {
        4: {'DB': 0.9, 'SubTB(λ)': 0.2, 'TB': 0.8},
    }
    try:
        return best_by_k[k][obj]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are unknown as well.
        raise ValueError(f"Unknown objective or size: {obj}, k={k}") from None

# One figure per (k, metric, objective); build_image's default x-axis settings are used.
for k in [4]:
    for cfg in metrics_config:
        metric, title, ylabel = cfg["metric"], cfg["title"], cfg["ylabel"]
        print(f"Processing metric: {metric}")
        for obj in objectives:
            # alpha_init -> list of per-run value arrays, rebuilt for every objective
            groups = defaultdict(list)
            best_alpha = get_bit_best_alpha(obj, k)
            # The figure is titled by objective, replacing the title from cfg.
            title = obj
            print(f"  Objective: {obj}")
            for exp_name, record in data.items():
                # Only runs of this objective with this k
                if record.get("Objective", None) != obj or record.get("k", None) != k:
                    continue
                alpha = record.get("alpha_init", None)
                steps = np.array(record["step"])
                # A non-list metric is repeated once per recorded step
                run_metric = record[metric]
                if not isinstance(run_metric, list):
                    run_metric = [run_metric] * steps.shape[0]
                values = np.array([np.nan if x is None else x for x in run_metric])
                # Keep only the finite points, and the steps that go with them
                mask = np.isfinite(values)
                steps, values = steps[mask], values[mask]
                groups[alpha].append(values)

            obj_tag = obj.replace('-','_').removesuffix('(λ)')
            build_image(
                steps,
                groups,
                ours_alpha=best_alpha,
                title=title,
                ylabel=ylabel,
                metric=metric,
                file_name=f"plot_bit_{obj_tag}_alpha{best_alpha}_k{k}_{metric}"
            )


Processing metric: modes
  Objective: DB
  Objective: SubTB(λ)
  Objective: TB
Processing metric: spearman_corr_test
  Objective: DB
  Objective: SubTB(λ)
  Objective: TB
Processing metric: loss
  Objective: DB
  Objective: SubTB(λ)
  Objective: TB
Processing metric: forward_policy_entropy_eval
  Objective: DB
  Objective: SubTB(λ)
  Objective: TB


In [53]:
# Not the last expression of the cell, so its value is not displayed.
len(data)

# Print the modes and step series of every SubTB run with alpha_init=0.2 and k=4.
for _, run in data.items():
    if run['alpha_init'] == 0.2 and run['k'] == 4 and run['objective'] == 'subtb':
        print(run['modes'])
        print(run['step'])

[0, 0, 0, 0, 0, 0, 0, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 7, 7, 7, 8, 9, 10, 12, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 17, 17, 18, 20, 21, 21, 21, 23, 23, 23, 23, 23, 24, 24, 25, 25, 25, 26, 26, 26, 26, 28, 28, 28, 28, 30, 32, 32, 33, 33, 33, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 37, 37, 37, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 40, 40, 40, 40, 41, 41]
[0, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000, 10500, 11000, 11500, 12000, 12500, 13000, 13500, 14000, 14500, 15000, 15500, 16000, 16500, 17000, 17500, 18000, 18500, 19000, 19500, 20000, 20500, 21000, 21500, 22000, 22500, 23000, 23500, 24000, 24500, 25000, 25500, 26000, 26500, 27000, 27500, 28000, 28500, 29000, 29500, 30000, 30500, 31000, 31500, 32000, 32500, 33000, 33500, 34000, 34500, 35000, 35500, 36000, 36500, 37000, 37500, 38000, 38500, 39000, 39500, 40000, 40500, 41000, 41500, 42000, 42500, 43000, 43500, 44000, 44500, 45000, 4550