In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib.patches import Patch
# Typography switch for every figure in this notebook: serif=True uses the
# generic serif family, otherwise Liberation Sans. `dir_postfix` selects the
# output directory (figures_{dir_postfix}) used when figures are saved.
# NOTE(review): with serif=True the directory becomes "figures_" (trailing
# underscore) — confirm that name is intended.
serif = False
render_format = "pdf"
dir_postfix = "" if serif else "sans"
plt.rcParams.update({
    "font.family": "serif" if serif else "Liberation Sans",
    "font.size": 10,
    # Embed fonts as TrueType (Type 42) instead of Type 3 in PDF/PS output.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

from results import *

# PushT Domain Analysis

## Main Result

In [2]:
# Experiment keys for the main PushT comparison; the key generators and their
# key formats come from `results`.
pusht_exp_keys = [
    # Temporal consistency: STAC distances plus the MSE baseline.
    *get_temporal_consistency_exp_keys(
        pred_horizons=[16],
        sample_sizes=[256],
        error_fns=["mmd_rbf_all_median", "kde_kl_all_rev_eig", "kde_kl_all_for_eig", "mse_all"],
        aggr_fns=["min"],
    ),
    # Loss-function baselines.
    *get_loss_function_exp_keys(
        loss_fns=["noise_pred_all", "temporal_noise_pred_all"],
        sample_sizes=[10],
    ),
    # Reconstruction baselines (generated by the same loss-function helper).
    *get_loss_function_exp_keys(
        loss_fns=["action_rec_all", "temporal_action_rec_all"],
        sample_sizes=[4],
    ),
    # Embedding-similarity baselines.
    *get_embedding_exp_keys(
        embeddings=["encoder_feat", "clip_feat", "resnet_feat"],
        score_fns=["mahal"],
    ),
    # Ensemble baselines (plotted as "Diffusion Output Variance").
    *get_ensemble_exp_keys(
        pred_horizons=[16],
        sample_sizes=[256],
        action_spaces=["all"],
    ),
]


def _compile_pusht_seed(domain, exp_keys):
    """Compile, aggregate and score one PushT run directory.

    Args:
        domain: Results domain name passed to ``compile_metrics``.
        exp_keys: Experiment keys to load.

    Returns:
        Tuple ``(metrics, aggregated_metrics, scores)`` where ``scores`` is the
        Balanced Accuracy mapping produced by ``extract_metric_dict``.
    """
    metrics = compile_metrics(
        domain=domain,
        splits=["na", "hh"],
        exp_keys=exp_keys,
        return_test_data=True,
    )
    aggr_metrics = aggregate_metrics(
        splits=["na", "hh"],
        exp_keys=exp_keys,
        data=metrics,
    )
    scores = extract_metric_dict(
        exp_keys=exp_keys,
        data=aggr_metrics,
        metric="Balanced Accuracy",
    )
    return metrics, aggr_metrics, scores


# NOTE(review): the _0/_1/_2 suffixes are run indices; the seed labels below are
# carried over from the original comments (they skip seed 1) — confirm.
# Seed 0.
pusht_8_metrics_0, pusht_8_aggr_metrics_0, pusht_scores_main_0 = _compile_pusht_seed("0525_pusht_8", pusht_exp_keys)
# Seed 2.
pusht_8_metrics_1, pusht_8_aggr_metrics_1, pusht_scores_main_1 = _compile_pusht_seed("0526_pusht_8", pusht_exp_keys)
# Seed 3.
pusht_8_metrics_2, pusht_8_aggr_metrics_2, pusht_scores_main_2 = _compile_pusht_seed("0527_pusht_8", pusht_exp_keys)

In [3]:
# Experiment keys referenced by the figures below, named once to avoid
# repeating (and mistyping) the long key strings.
_STAC_MMD = "pred_horizon_16_sample_size_256_error_fn_mmd_rbf_all_median"
_STAC_REV_KL = "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_rev_eig"
_STAC_FOR_KL = "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_for_eig"
_MSE_MIN = "pred_horizon_16_sample_size_256_error_fn_mse_all_aggr_fn_min"
_DDPM_LOSS = "loss_fn_noise_pred_all_sample_size_10"
_TEMPORAL_DDPM_LOSS = "loss_fn_temporal_noise_pred_all_sample_size_10"
_RECON = "loss_fn_action_rec_all_sample_size_4"
_TEMPORAL_RECON = "loss_fn_temporal_action_rec_all_sample_size_4"
_OUTPUT_VARIANCE = "pred_horizon_16_sample_size_256_action_space_all"
_ENCODER_EMB = "embedding_encoder_feat_score_fn_mahal"
_CLIP_EMB = "embedding_clip_feat_score_fn_mahal"
_RESNET_EMB = "embedding_resnet_feat_score_fn_mahal"

# Left panel: one representative per failure-detection family.
pusht_exp_keys_left = [_STAC_MMD, _DDPM_LOSS, _RECON, _OUTPUT_VARIANCE, _ENCODER_EMB]

# Middle panel: score variants in pairs, grouped by family.
pusht_exp_keys_mid = [
    _STAC_MMD, _MSE_MIN,
    _TEMPORAL_DDPM_LOSS, _DDPM_LOSS,
    _TEMPORAL_RECON, _RECON,
]

# Right panel: statistical distance functions.
pusht_exp_keys_right = [_STAC_MMD, _STAC_REV_KL, _STAC_FOR_KL, _MSE_MIN]

# Legend label for each experiment key.
pusht_exp_labels = {
    _STAC_MMD: "STAC MMD (Ours)",
    _STAC_REV_KL: "STAC Rev. KL (Ours)",
    _STAC_FOR_KL: "STAC For. KL (Ours)",
    _MSE_MIN: "Temporal Non-Distr. Min.",
    _DDPM_LOSS: "DDPM Loss",
    _TEMPORAL_DDPM_LOSS: "Temporal DDPM Loss",
    _RECON: "Diffusion Reconstruction",
    _TEMPORAL_RECON: "Temporal Diffusion Recon.",
    _OUTPUT_VARIANCE: "Diffusion Output Variance",
    _ENCODER_EMB: "Policy Encoder Embedding",
    _CLIP_EMB: "CLIP Embedding",
    _RESNET_EMB: "ResNet Embedding",
}

# Bar fill colors (color code 1). "white" is a sentinel: the plotting cells draw
# such bars with a black outline (hatched in some panels) so they stay visible.
# CLIP ("#41ab5d") and ResNet ("#74c476") embeddings are currently not plotted.
pusht_exp_colors = {
    _STAC_MMD: "#d94801",
    _STAC_REV_KL: "#f16913",
    _STAC_FOR_KL: "#fd8d3c",
    _MSE_MIN: "white",
    _DDPM_LOSS: "#6baed6",
    _TEMPORAL_DDPM_LOSS: "#2171b5",
    _RECON: "#9e9ac8",
    _TEMPORAL_RECON: "#6a51a3",
    _OUTPUT_VARIANCE: "#969696",
    _ENCODER_EMB: "#7fcdbb",
}

In [None]:
# PushT summary figure: failure-detection families (left), temporal vs.
# non-temporal score ablation (middle), statistical distance ablation (right).
# Each bar is the mean Balanced Accuracy over the three runs loaded above; the
# whisker spans min..max across those runs.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 3.5), sharey=True)

ymin = 0.5
ymax = 1.0

# Left: one bar per family (width 1.0) with 1.25 extra spacing between bars.
index_left = np.arange(len(pusht_exp_keys_left))
positions_left = index_left * (1.0 + 1.25)
# Middle: bars of width 0.5 in pairs, with a 1.0 gap after every pair.
index_mid = np.arange(len(pusht_exp_keys_mid))
positions_mid = index_mid * 0.5 + (index_mid // 2) * 1.0
# Right: one bar per distance function (width 0.5) with 1.0 extra spacing.
index_right = np.arange(len(pusht_exp_keys_right))
positions_right = index_right * (0.5 + 1.0)

# (axis, keys, bar positions, bar width, title, xtick positions, xtick labels)
panels = [
    (
        ax1, pusht_exp_keys_left, positions_left, 1.0,
        "PushT: Failure Detection Methods",
        positions_left,
        ["Temporal\nConsistency", "Empirical\nLoss", "Diffusion\nRecon.",
         "Diffusion\nVariance", "Embedding\nSimilarity"],
    ),
    (
        ax2, pusht_exp_keys_mid, positions_mid, 0.5,
        "Temporal Score Function Ablation",
        # One tick centred under each pair of bars.
        [positions_mid[i:i + 2].mean().item() for i in range(0, len(positions_mid), 2)],
        ["Temporal\nConsistency", "Empirical\nLoss", "Diffusion\nRecon."],
    ),
    (
        ax3, pusht_exp_keys_right, positions_right, 0.5,
        "Statistical Distance Function Ablation",
        positions_right,
        ["STAC\nMMD", "STAC\nRev. KL", "STAC\nFor. KL", "Non-Distr.\nMin."],
    ),
]

for ax, keys, positions, bar_width, title, xticks, xticklabels in panels:
    for pos, key in zip(positions, keys):
        runs = [pusht_scores_main_0[key], pusht_scores_main_1[key], pusht_scores_main_2[key]]
        mean = sum(runs) / len(runs)
        color = pusht_exp_colors[key]
        # "white" bars would be invisible on a white background: outline and hatch them.
        outline = {"edgecolor": "black", "linewidth": 1.5, "hatch": "\\"} if color == "white" else {}
        ax.bar(pos, mean, bar_width, label=pusht_exp_labels[key], color=color, **outline)
        ax.errorbar(
            pos, mean, yerr=[[mean - min(runs)], [max(runs) - mean]],
            fmt="none", ecolor="black", elinewidth=1.75, capsize=0,
        )

    ax.set_title(title, fontsize=17)
    ax.set_axisbelow(True)
    ax.set_ylim(ymin, ymax)
    ax.yaxis.grid(True, linestyle="-", linewidth=0.8, color="gray")
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=12)
    for side in ("top", "right", "left"):
        ax.spines[side].set_linewidth(1.5)
    ax.spines["bottom"].set_linewidth(2.5)

# y label / tick styling only on the leftmost axis (y is shared across panels).
ax1.set_ylabel("Balanced Accuracy", fontsize=16)
ax1.tick_params(axis="y", labelsize=12)

# One shared legend below the figure, in this order, one entry per label.
legend_exp_keys = [
    "pred_horizon_16_sample_size_256_error_fn_mmd_rbf_all_median",
    "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_rev_eig",
    "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_for_eig",
    "pred_horizon_16_sample_size_256_action_space_all",
    "loss_fn_temporal_noise_pred_all_sample_size_10",
    "loss_fn_noise_pred_all_sample_size_10",
    "loss_fn_temporal_action_rec_all_sample_size_4",
    "loss_fn_action_rec_all_sample_size_4",
    "embedding_encoder_feat_score_fn_mahal",
    "pred_horizon_16_sample_size_256_error_fn_mse_all_aggr_fn_min",
]
first_handle = {}
for ax in (ax1, ax2, ax3):
    for handle, label in zip(*ax.get_legend_handles_labels()):
        first_handle.setdefault(label, handle)

labels = []
handles = []
for key in legend_exp_keys:
    label = pusht_exp_labels[key]
    if label in first_handle and label not in labels:
        labels.append(label)
        handles.append(first_handle[label])

fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.025), ncol=5, fancybox=True, fontsize=13)
# tight_layout() recomputes the subplot spacing itself, so no subplots_adjust call is needed.
fig.tight_layout()

save_path = CWD / ".." / f"figures_{dir_postfix}" / f"pusht-result.{render_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, format=render_format, dpi=300, bbox_inches="tight", transparent=True)
plt.show()

In [None]:
# Compact two-panel PushT figure with larger fonts: failure-detection families
# (left) and temporal vs. non-temporal score ablation (right). Each bar is the
# mean Balanced Accuracy over the three runs; the whisker spans min..max.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4), sharey=True)

ymin = 0.5
ymax = 1.0

# Left: one bar per family (width 1.0) with 1.25 extra spacing between bars.
index_left = np.arange(len(pusht_exp_keys_left))
positions_left = index_left * (1.0 + 1.25)
# Right: bars of width 0.5 in pairs, with a 1.0 gap after every pair.
index_mid = np.arange(len(pusht_exp_keys_mid))
positions_mid = index_mid * 0.5 + (index_mid // 2) * 1.0

# (axis, keys, bar positions, bar width, title, xtick positions, xtick labels)
panels = [
    (
        ax1, pusht_exp_keys_left, positions_left, 1.0,
        "PushT: Failure Detection Methods",
        positions_left,
        ["Temporal\nConsistency", "Empirical\nLoss", "Diffusion\nRecon.",
         "Diffusion\nVariance", "Embedding\nSimilarity"],
    ),
    (
        ax2, pusht_exp_keys_mid, positions_mid, 0.5,
        "Temporal Score Function Ablation",
        # One tick centred under each pair of bars.
        [positions_mid[i:i + 2].mean().item() for i in range(0, len(positions_mid), 2)],
        ["Temporal\nConsistency", "Empirical\nLoss", "Diffusion\nRecon."],
    ),
]

for ax, keys, positions, bar_width, title, xticks, xticklabels in panels:
    for pos, key in zip(positions, keys):
        runs = [pusht_scores_main_0[key], pusht_scores_main_1[key], pusht_scores_main_2[key]]
        mean = sum(runs) / len(runs)
        color = pusht_exp_colors[key]
        # "white" bars would be invisible on a white background: outline and hatch them.
        outline = {"edgecolor": "black", "linewidth": 1.5, "hatch": "\\"} if color == "white" else {}
        ax.bar(pos, mean, bar_width, label=pusht_exp_labels[key], color=color, **outline)
        ax.errorbar(
            pos, mean, yerr=[[mean - min(runs)], [max(runs) - mean]],
            fmt="none", ecolor="black", elinewidth=1.75, capsize=0,
        )

    ax.set_title(title, fontsize=24)
    ax.set_axisbelow(True)
    ax.set_ylim(ymin, ymax)
    ax.yaxis.grid(True, linestyle="-", linewidth=0.8, color="gray")
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, fontsize=14)
    for side in ("top", "right", "left"):
        ax.spines[side].set_linewidth(1.5)
    ax.spines["bottom"].set_linewidth(2.5)

# y label / tick styling only on the leftmost axis (y is shared across panels).
ax1.set_ylabel("Balanced Accuracy", fontsize=22)
ax1.tick_params(axis="y", labelsize=14)

# One shared legend below the figure, in this order; keys with no bar in either
# panel (e.g. the KL variants) are skipped.
legend_exp_keys = [
    "pred_horizon_16_sample_size_256_error_fn_mmd_rbf_all_median",
    "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_rev_eig",
    "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_for_eig",
    "pred_horizon_16_sample_size_256_action_space_all",
    "loss_fn_temporal_noise_pred_all_sample_size_10",
    "loss_fn_noise_pred_all_sample_size_10",
    "loss_fn_temporal_action_rec_all_sample_size_4",
    "loss_fn_action_rec_all_sample_size_4",
    "embedding_encoder_feat_score_fn_mahal",
    "pred_horizon_16_sample_size_256_error_fn_mse_all_aggr_fn_min",
]
first_handle = {}
for ax in (ax1, ax2):
    for handle, label in zip(*ax.get_legend_handles_labels()):
        first_handle.setdefault(label, handle)

labels = []
handles = []
for key in legend_exp_keys:
    label = pusht_exp_labels[key]
    if label in first_handle and label not in labels:
        labels.append(label)
        handles.append(first_handle[label])

fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.025), ncol=4, fancybox=True, fontsize=13)
# tight_layout() recomputes the subplot spacing itself, so no subplots_adjust call is needed.
fig.tight_layout()

save_path = CWD / ".." / f"figures_{dir_postfix}" / f"pusht-result-small.{render_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, format=render_format, dpi=300, bbox_inches="tight", transparent=True)
plt.show()

In [None]:
# Single-panel statistical distance ablation on PushT. Each bar is the mean
# Balanced Accuracy over the three runs; the whisker spans min..max.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))

ymin = 0.5
ymax = 1.0

# One bar per distance function (width 0.5) with 1.0 extra spacing between bars.
bar_width = 0.5
index = np.arange(len(pusht_exp_keys_right))
positions = index * (bar_width + 1.0)

for pos, key in zip(positions, pusht_exp_keys_right):
    runs = [pusht_scores_main_0[key], pusht_scores_main_1[key], pusht_scores_main_2[key]]
    mean = sum(runs) / len(runs)
    color = pusht_exp_colors[key]
    # "white" bars would be invisible on a white background: outline and hatch them.
    outline = {"edgecolor": "black", "linewidth": 1.5, "hatch": "\\"} if color == "white" else {}
    ax.bar(pos, mean, bar_width, label=pusht_exp_labels[key], color=color, **outline)
    ax.errorbar(
        pos, mean, yerr=[[mean - min(runs)], [max(runs) - mean]],
        fmt="none", ecolor="black", elinewidth=1.75, capsize=0,
    )

ax.set_title("Statistical Distance Function Ablation", fontsize=18)
ax.set_axisbelow(True)
ax.set_ylim(ymin, ymax)
ax.set_ylabel("Balanced Accuracy", fontsize=16)
ax.yaxis.grid(True, linestyle="-", linewidth=0.8, color="gray")
ax.tick_params(axis="y", labelsize=13)

ax.set_xticks(positions)
ax.set_xticklabels(
    ["STAC\nMMD", "STAC\nRev. KL", "STAC\nFor. KL", "Non-Distr.\nMin."],
    fontsize=13,
)

for side in ("top", "right", "left"):
    ax.spines[side].set_linewidth(1.5)
ax.spines["bottom"].set_linewidth(2.5)

# Legend below the figure, one entry per label, in this order.
legend_exp_keys = [
    "pred_horizon_16_sample_size_256_error_fn_mmd_rbf_all_median",
    "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_rev_eig",
    "pred_horizon_16_sample_size_256_error_fn_kde_kl_all_for_eig",
    "pred_horizon_16_sample_size_256_error_fn_mse_all_aggr_fn_min",
]
first_handle = {}
for handle, label in zip(*ax.get_legend_handles_labels()):
    first_handle.setdefault(label, handle)

labels = []
handles = []
for key in legend_exp_keys:
    label = pusht_exp_labels[key]
    if label in first_handle and label not in labels:
        labels.append(label)
        handles.append(first_handle[label])

fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.025), ncol=2, fancybox=True, fontsize=13)
# tight_layout() recomputes the subplot margins itself, so no subplots_adjust call is needed.
fig.tight_layout()

save_path = CWD / ".." / f"figures_{dir_postfix}" / f"pusht-result-functions.{render_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, format=render_format, dpi=300, bbox_inches="tight", transparent=True)
plt.show()

## Ablation Result: Prediction & Execution Horizon

In [7]:
# Prediction-horizon ablation: temporal-consistency keys for h in {8, 12, 16},
# using only the mmd_rbf_all_median error function.
pusht_exp_keys_abl = [
    *get_temporal_consistency_exp_keys(
        pred_horizons=[8, 12, 16],
        sample_sizes=[256],
        error_fns=["mmd_rbf_all_median"],
        aggr_fns=[None],
    )
]


def _compile_pusht_abl(domain, exp_keys):
    """Compile and aggregate metrics for one PushT execution-horizon domain.

    Args:
        domain: Results domain name passed to ``compile_metrics``.
        exp_keys: Experiment keys to load.

    Returns:
        Tuple ``(metrics, aggregated_metrics)``.
    """
    metrics = compile_metrics(
        domain=domain,
        splits=["na", "hh"],
        exp_keys=exp_keys,
    )
    aggr_metrics = aggregate_metrics(
        splits=["na", "hh"],
        exp_keys=exp_keys,
        data=metrics,
    )
    return metrics, aggr_metrics


# Execution horizon = 2
pusht_2_metrics_abl, pusht_2_aggr_metrics_abl = _compile_pusht_abl("0525_pusht_2", pusht_exp_keys_abl)
# Execution horizon = 4
pusht_4_metrics_abl, pusht_4_aggr_metrics_abl = _compile_pusht_abl("0525_pusht_4", pusht_exp_keys_abl)
# Execution horizon = 8
pusht_8_metrics_abl, pusht_8_aggr_metrics_abl = _compile_pusht_abl("0525_pusht_8", pusht_exp_keys_abl)

# Flat list of Balanced Accuracy scores, ordered by execution horizon, then key.
# NOTE(review): this collects len(pusht_exp_keys_abl) scores per execution
# horizon, but the ablation plot labels only 8 bars (no h = 8 bar at execution
# horizon 8) — confirm what extract_metric_list returns for that domain.
pusht_scores_abl = []
for data in (pusht_2_aggr_metrics_abl, pusht_4_aggr_metrics_abl, pusht_8_aggr_metrics_abl):
    pusht_scores_abl.extend(
        extract_metric_list(
            exp_keys=pusht_exp_keys_abl,
            data=data,
            metric="Balanced Accuracy"
        ).tolist()
    )

In [None]:
# Color settings.
# orange_plot=True: every execution-horizon group uses the same light-to-dark
# orange ramp (ordered by h) and the legend is built from the plotted bars.
# Otherwise each group gets its own ramp and the legend uses a gray ramp, so
# the h encoding is shown independently of the group color.
orange_plot = False

# One label per bar; execution horizon 8 has no h = 8 bar.
pusht_exp_labels_abl = [
    "$h = 8$",
    "$h = 12$",
    "$h = 16$",
    "$h = 8$",
    "$h = 12$",
    "$h = 16$",
    "$h = 12$",
    "$h = 16$",
]

if orange_plot:
    pusht_exp_colors_abl = [
        "#fdd0a2",
        "#fd8d3c",
        "#d94801",
        "#fdd0a2",
        "#fd8d3c",
        "#d94801",
        "#fd8d3c",
        "#d94801",
    ]
else:
    pusht_exp_colors_abl = [
        "#c6dbef",
        "#6baed6",
        "#2171b5",
        "#dadaeb",
        "#9e9ac8",
        "#6a51a3",
        "#fd8d3c",
        "#d94801",
    ]
    grayscale_colors = [
        "#cccccc",
        "#969696",
        "#636363",
    ]

# Plot settings
ymin = 0.5
ymax = 1.0
bar_width = 1.0
method_spacing_factor = 1.0
category_spacing = 2.5

# Bars are packed side by side; a `category_spacing` gap separates the
# execution-horizon groups (bars 0-2, 3-5, 6-7).
index = np.arange(len(pusht_exp_labels_abl))
positions = index * bar_width * method_spacing_factor
positions[3:] += category_spacing
positions[6:] += category_spacing

# Scores are matched to bars purely by position. If the score count differs from
# the label count, bars would be silently mislabeled (extra scores were dropped).
assert len(pusht_scores_abl) == len(pusht_exp_labels_abl), (
    f"{len(pusht_scores_abl)} ablation scores for {len(pusht_exp_labels_abl)} bars; "
    "check which prediction horizons exist for each execution horizon"
)

fig, ax = plt.subplots(figsize=(7, 5))
for pos, score, label, color in zip(positions, pusht_scores_abl, pusht_exp_labels_abl, pusht_exp_colors_abl):
    ax.bar(pos, score, bar_width, label=label, color=color)

ax.set_ylim(ymin, ymax)
ax.set_title("Pred. and Exec. Horizon Ablation", fontsize=24)
ax.set_ylabel("Balanced Accuracy", fontsize=20)
ax.tick_params(axis="y", labelsize=16)
ax.set_axisbelow(True)
ax.yaxis.grid(True, linestyle='-', linewidth=0.8, color="gray")

# One tick centred under each execution-horizon group.
xticks = [
    positions[0:3].mean().item(),
    positions[3:6].mean().item(),
    positions[6:8].mean().item(),
]
xticklabels = [
    "Execution\nHorizon 2",
    "Execution\nHorizon 4",
    "Execution\nHorizon 8",
]
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, fontsize=20)

ax.spines['top'].set_linewidth(1.5)
ax.spines['right'].set_linewidth(1.5)
ax.spines['left'].set_linewidth(1.5)
ax.spines['bottom'].set_linewidth(2)

if orange_plot:
    # Colors already encode h, so deduplicate the plotted bars by label.
    labels = []
    handles = []
    for _h, _l in zip(*ax.get_legend_handles_labels()):
        if _l not in labels:
            labels.append(_l)
            handles.append(_h)
else:
    labels = ["$h = 8$", "$h = 12$", "$h = 16$"]
    # Legend-only proxy artists: unlike plt.bar([0], [0], ...), these add
    # nothing to `ax`.
    handles = [Patch(facecolor=gray) for gray in grayscale_colors]

fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.54, 0.025), ncol=3, fancybox=True, fontsize=18)

fig.tight_layout()
save_path = CWD / ".." / f"figures_{dir_postfix}" / f"pusht-abl-result.{render_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, format=render_format, dpi=300, bbox_inches='tight', transparent=True)
plt.show()