In [None]:
from utils_visualization import *

# Process results

In [49]:
# Uncomment to (re)build ./results_aggregated from the raw ./results folders
# (presumably applies compute_metrics to every subfolder). Left disabled so the
# cells below reuse the aggregated CSVs already present in ./results_aggregated.
# process_all_subfolders('./results', compute_metrics, './results_aggregated')

# Global Visualization Plots

In [52]:
# Environments and result subsets; both appear in the aggregated CSV names
# as ./results_aggregated/{env}_{filter}_{subset}.csv
envs = ['ST', 'SCW', 'MO']
div = ['all', 'success', 'failure']

# Uncertainty-filter ids and their plot labels, written as explicit pairs so
# the positional correspondence between `uf` and `labels` is easy to audit.
uf, labels = map(list, zip(
    ('mcd_world_model', 'MCD UE'),
    ('prob_world_model', 'prob. UE'),
    ('qnet_ensemble', 'ens. UE'),
    ('rnd', 'RND UE'),
    ('random', 'rand. UE'),
))

# Resolution used for every saved figure
dpi = 300

In [None]:
# One single-metric plot per (environment, result subset, metric), comparing
# every uncertainty filter in `uf` against the policy / policy+CBF baselines.
metrics = [
    'success_rate', "reward_mean", "length_mean", "collisions_mean",
    "velocity_mean", "goal_approach_mean", "ema_divergence_mean",
    "true_positive_rate", "false_positive_rate", "true_negative_rate", "false_negative_rate",
    "uf_activation_prop", "cbf_activation_prop"
]

for e in envs:
    for d in div:
        # Result files for each uncertainty filter, in the same order as `labels`
        path_list = [f'./results_aggregated/{e}_{u}_{d}.csv' for u in uf]

        # Baselines: row 0 = policy (upper bound), row 1 = policy + CBF (lower bound).
        # Read the file once and take both rows from it.
        base = pd.read_csv(f"./results_aggregated/{e}_base_{d}.csv")
        upper = base.loc[0].to_dict()
        lower = base.loc[1].to_dict()

        out_dir = f"./img_2/{e}_{d}"
        os.makedirs(out_dir, exist_ok=True)

        for m in metrics:
            fig = plot_metrics(
                path_list,
                labels,
                cbf_config=1,
                upper_bound=upper,
                lower_bound=lower,
                ncols=1,
                plot_std=False,
                metrics=[m],
                title=f"{e} environment"
            )
            fig.savefig(f"{out_dir}/1_{e}_{d}_{m}", dpi=dpi, bbox_inches="tight")
            plt.close(fig)  # avoid accumulating open figures across the loop


In [None]:
# Legend-free single-metric plots, one per (environment, result subset, metric).
# File names match the legend versions, so they go into a separate "no_legend"
# subfolder to avoid overwriting them.
metrics = [
    'success_rate', "reward_mean", "length_mean", "collisions_mean",
    "velocity_mean", "goal_approach_mean", "ema_divergence_mean",
    "true_positive_rate", "false_positive_rate", "true_negative_rate", "false_negative_rate",
    "uf_activation_prop", "cbf_activation_prop"
]

for e in envs:
    for d in div:
        # Result files for each uncertainty filter, in the same order as `labels`
        path_list = [f'./results_aggregated/{e}_{u}_{d}.csv' for u in uf]

        # Baselines: row 0 = policy (upper bound), row 1 = policy + CBF (lower bound).
        base = pd.read_csv(f"./results_aggregated/{e}_base_{d}.csv")
        upper = base.loc[0].to_dict()
        lower = base.loc[1].to_dict()

        out_dir = f"./img_2/{e}_{d}/no_legend"
        os.makedirs(out_dir, exist_ok=True)

        for m in metrics:
            fig = plot_metrics(
                path_list,
                labels,
                cbf_config=1,
                upper_bound=upper,
                lower_bound=lower,
                ncols=1,
                plot_std=False,
                metrics=[m],
                title=f"{e} environment",
                legend=False  # hide legend for cleaner plots
            )
            fig.savefig(f"{out_dir}/1_{e}_{d}_{m}", dpi=dpi, bbox_inches="tight")
            plt.close(fig)  # avoid accumulating open figures across the loop


In [None]:
# Pairwise comparison plots: each uncertainty filter vs. the random filter,
# drawn as x-metric vs. y-metric series for every environment and result subset.

# Filter id -> display label. Each plot shows exactly two filters, so it must
# receive those two filters' labels; slicing the 5-element `labels` list
# positionally would label e.g. the random filter as 'prob. UE'.
label_by_uf = dict(zip(uf, labels))

# (x, y) metric pairs drawn for every result subset
pair_metrics = [
    ["collisions_mean", "length_mean"],
    ["collisions_mean", "velocity_mean"],
    ["velocity_mean", "goal_approach_mean"],
    ["velocity_mean", "ema_divergence_mean"],
]

# Extra pairs for the "all" subset only: x is read from the "all" files,
# y from the "success" files.
success_pairs = [
    ["success_rate", "collisions_mean"],
    ["success_rate", "velocity_mean"],
]

# Detection-rate pair for the "all" subset only (both axes from "all" files)
rate_pairs = [['false_positive_rate', 'true_positive_rate']]


def _plot_and_save_pair(paths_x, paths_y, x_metric, y_metric, pair_labels,
                        upper_x, upper_y, lower_x, lower_y, env, out_path):
    """Draw one x-vs-y metric series plot for two filters and save it.

    Arguments are forwarded positionally to ``plot_metric_series_multi`` in the
    same order the notebook has always used; the figure is closed after saving.
    """
    fig = plot_metric_series_multi(
        paths_x,
        paths_y,
        x_metric,
        y_metric,
        pair_labels,
        'inferno',     # presumably the colormap name
        'percentile',  # presumably how the tau sweep is keyed — confirm in utils
        upper_x,
        upper_y,
        lower_x,
        lower_y,
        1,             # presumably the CBF config (cf. cbf_config=1 above)
        f"{env} environment"
    )
    fig.savefig(out_path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)


for pair in ([u, 'random'] for u in uf if u != 'random'):
    pair_labels = [label_by_uf[u] for u in pair]

    for e in envs:
        for d in div:
            # Result files for the two filters being compared
            path_list = [f'./results_aggregated/{e}_{u}_{d}.csv' for u in pair]

            # Baselines: row 0 = policy, row 1 = policy + CBF
            base = pd.read_csv(f"./results_aggregated/{e}_base_{d}.csv")
            upper = base.loc[0].to_dict()
            lower = base.loc[1].to_dict()

            out_dir = f"./img_2/{e}_{d}"
            os.makedirs(out_dir, exist_ok=True)

            for x_metric, y_metric in pair_metrics:
                _plot_and_save_pair(
                    path_list, path_list, x_metric, y_metric, pair_labels,
                    upper, upper, lower, lower, e,
                    f"{out_dir}/2_{pair[0]}_{e}_{d}_{x_metric}_{y_metric}",
                )

            if d != "all":
                continue

            # "all" subset: success rate (all episodes) vs. metrics of successful episodes
            path_list_y = [f'./results_aggregated/{e}_{u}_success.csv' for u in pair]
            base_y = pd.read_csv(f"./results_aggregated/{e}_base_success.csv")
            upper_y = base_y.loc[0].to_dict()
            lower_y = base_y.loc[1].to_dict()

            for x_metric, y_metric in success_pairs:
                _plot_and_save_pair(
                    path_list, path_list_y, x_metric, y_metric, pair_labels,
                    upper, upper_y, lower, lower_y, e,
                    f"{out_dir}/2_{pair[0]}_{e}_{d}_{x_metric}_{y_metric}",
                )

            # "all" subset: false positive rate vs. true positive rate
            for x_metric, y_metric in rate_pairs:
                _plot_and_save_pair(
                    path_list, path_list, x_metric, y_metric, pair_labels,
                    upper, upper, lower, lower, e,
                    f"{out_dir}/2_{pair[0]}_{e}_{d}_{x_metric}_{y_metric}",
                )


# Local Visualization Plots

In [None]:
# Uncertainty thresholds (tau) of each estimator's 85th-percentile ST run;
# drawn as vertical reference lines in the local plots.
rnd_th = 0.9697459936141968
prob_th = 3.568255662918091
qens_th = 0.21788763999938965
mcd_th = 1.091325044631958

# Recorded statistics of those ST runs (run names suggest CBF config 1).
rnd = load_stats('ST_rnd_85pctl_cbf1_6423448', './results/ST_rnd')['stats']
prob = load_stats('ST_prob_world_model_85pctl_cbf1_6350506', './results/ST_prob_world_model')['stats']
qens = load_stats('ST_qnet_ensemble_85pctl_cbf1_6401558', './results/ST_qnet_ensemble')['stats']
mcd = load_stats('ST_mcd_world_model_85pctl_cbf1_6375200', './results/ST_mcd_world_model')['stats']

In [None]:
# Variables plotted against the uncertainty key 'u_e' for every estimator
pairs = [
    ('u_e', 'dist_goal'),
    ('u_e', 'angle_goal'),
    ('u_e', 'dist_ema'),
    ('u_e', 'angle_ema'),
    ('u_e', 'f_action'),
    ('u_e', 'r_action'),
    ('u_e', 'cbf_mean_change'),
    ('u_e', 'f_velocity'),
    ('u_e', 'r_velocity'),
    ('u_e', 'ray_mean'),
    ('u_e', 'ray_std'),
]

for mod, th, name, title in zip(
    [qens, mcd, rnd, prob],
    [qens_th, mcd_th, rnd_th, prob_th],
    ['qens', 'mcd', 'rnd', 'prob'],
    ['Ensemble', 'MCD', 'RND', 'Probabilistic']
):
    out_dir = f"./img_2/additional/{name}"
    os.makedirs(out_dir, exist_ok=True)

    for key_x, key_y in pairs:
        ax = plot_stats(
            mod,
            (key_x, key_y),
            filter_fn=lambda x: x['total_success'] == 1,  # only successful episodes
            v_line=th  # vertical line at the tau threshold
        )
        fig = ax.figure  # retrieve figure from axis

        # Include the estimator name; otherwise all four estimators' plots
        # carry an identical title and cannot be told apart.
        ax.set_title(f'{title} - ST environment - tau = 85 percentile', fontsize=14)

        fig.savefig(f"{out_dir}/{name}_{key_x}_{key_y}.png", dpi=dpi, bbox_inches="tight")
        plt.close(fig)  # avoid accumulating open figures across the loop


# Tabular Results

In [82]:
# Build the four summary tables from the aggregated results folder
(
    summary_global,
    summary_by_env,
    summary_best_per_env,
    summary_st_as_ref,
) = analyze_folder('./results_aggregated')

In [None]:
# Round all summaries to 5 decimal places for consistency
# Round every summary table to a common precision and export each one as a
# markdown table (human-readable) to <SUMMARY_DIR>/<name>.txt. Precision and
# output directory are defined once instead of being repeated per table.
SUMMARY_DECIMALS = 5
SUMMARY_DIR = './results_aggregated'

# The source tuple is evaluated before rebinding, so each name gets its own rounded copy
summary_global, summary_by_env, summary_best_per_env, summary_st_as_ref = (
    table.round(SUMMARY_DECIMALS)
    for table in (summary_global, summary_by_env, summary_best_per_env, summary_st_as_ref)
)

for summary_name, summary_table in {
    'summary_global': summary_global,
    'summary_by_env': summary_by_env,
    'summary_best_per_env': summary_best_per_env,
    'summary_st_as_ref': summary_st_as_ref,
}.items():
    summary_table.to_markdown(f'{SUMMARY_DIR}/{summary_name}.txt', index=False)


In [89]:
# Global summary table returned by analyze_folder
summary_global

Unnamed: 0,model,$\tau$,score,success rate,collisions,velocity,ep. length
0,MCD UE,85.0,0.747,0.855,0.564,0.637,273.391
1,Random UE,50.0,0.678,0.841,0.714,0.629,261.174
2,Prob. UE,90.0,0.657,0.838,0.866,0.65,251.807
3,Ens. UE,65.0,0.616,0.838,0.659,0.624,279.779
4,RND UE,80.0,0.582,0.856,0.731,0.617,286.687
5,Policy,-1.0,0.0,0.827,1.095,0.673,252.118
6,Policy + CBF,-1.0,0.0,0.823,0.317,0.572,354.587


In [90]:
# Per-environment summary table returned by analyze_folder
summary_by_env

Unnamed: 0,model,env.,$\tau$,score,success rate,collisions,velocity,ep. length
0,MCD UE,MO,85.0,0.746,0.631,0.51,0.652,159.241
1,MCD UE,SCW,85.0,0.669,0.955,0.881,0.562,405.011
2,MCD UE,ST,85.0,0.825,0.98,0.302,0.696,255.922
3,Policy,MO,-1.0,0.0,0.561,0.835,0.661,152.49
4,Policy,SCW,-1.0,0.0,0.942,1.491,0.643,340.733
5,Policy,ST,-1.0,0.0,0.977,0.958,0.716,263.13
6,Policy + CBF,MO,-1.0,0.0,0.579,0.447,0.611,181.038
7,Policy + CBF,SCW,-1.0,0.0,0.933,0.451,0.474,534.404
8,Policy + CBF,ST,-1.0,0.0,0.958,0.053,0.63,348.32
9,Prob. UE,MO,90.0,0.609,0.575,0.665,0.652,152.552


In [91]:
# summary_best_per_env table returned by analyze_folder
summary_best_per_env

Unnamed: 0,model,env.,$\tau$,score,success rate,collisions,velocity,ep. length
10,MCD UE,MO,80.0,0.75,0.594,0.48,0.654,160.941
15,Policy,MO,-1.0,0.0,0.561,0.835,0.661,152.49
16,Policy + CBF,MO,-1.0,0.0,0.579,0.447,0.611,181.038
28,Prob. UE,MO,85.0,0.616,0.596,0.567,0.643,156.282
39,Ens. UE,MO,65.0,0.677,0.577,0.569,0.651,155.888
52,Random UE,MO,50.0,0.594,0.566,0.623,0.639,156.45
70,RND UE,MO,70.0,0.573,0.618,0.551,0.638,166.706
88,MCD UE,SCW,85.0,0.669,0.955,0.881,0.562,405.011
92,Policy,SCW,-1.0,0.0,0.942,1.491,0.643,340.733
93,Policy + CBF,SCW,-1.0,0.0,0.933,0.451,0.474,534.404


In [92]:
# summary_st_as_ref table returned by analyze_folder
summary_st_as_ref

Unnamed: 0,model,env.,$\tau$,score,success rate,collisions,velocity,ep. length
0,MCD UE,MO,85.0,0.746,0.631,0.51,0.652,159.241
1,MCD UE,SCW,85.0,0.669,0.955,0.881,0.562,405.011
2,MCD UE,ST,85.0,0.825,0.98,0.302,0.696,255.922
3,MCD UE,ALL,85.0,0.747,0.855,0.564,0.637,273.391
4,MCD UE,OOD,85.0,0.708,0.793,0.696,0.607,282.126
5,Policy,MO,-1.0,0.0,0.561,0.835,0.661,152.49
6,Policy,SCW,-1.0,0.0,0.942,1.491,0.643,340.733
7,Policy,ST,-1.0,0.0,0.977,0.958,0.716,263.13
8,Policy,ALL,-1.0,0.0,0.827,1.095,0.673,252.118
9,Policy,OOD,-1.0,0.0,0.752,1.163,0.652,246.612
