# Robustness Analysis

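Analyze policy robustness across environment perturbations (Noise, Delay, Combo). For each baseline and seed, robustness is measured as the change in final-phase mean return under a perturbation relative to the unperturbed "No Noise" environment (delta = perturbed − No Noise; values closer to or above zero are more robust).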

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's "whitegrid" theme globally so every plot below shares one look.
sns.set_theme(style="whitegrid")

In [None]:
# Input and output locations. The root looks like a Google Drive mount
# (presumably Colab) — adjust _DRIVE_ROOT when running elsewhere.
_DRIVE_ROOT = "/content/drive/MyDrive"

RESULTS_CSV = f"{_DRIVE_ROOT}/results_combined_new.csv"            # combined results (input)
ROBUST_PER_SEED = f"{_DRIVE_ROOT}/robustness_deltas_per_seed.csv"  # per-seed deltas (output)
ROBUST_AGG = f"{_DRIVE_ROOT}/robustness_deltas_agg.csv"            # per-baseline summary (output)

In [None]:
# Load the combined results table, keep only rows from the "final" phase, and
# add an `env_clean` column: the `env` label with surrounding whitespace
# stripped, so environments can be matched by exact name later on.
df = pd.read_csv(RESULTS_CSV)
df_final = df.loc[df["phase"] == "final"].assign(
    env_clean=lambda frame: frame["env"].str.strip()
)

In [None]:
# Per-seed robustness deltas: for every (baseline, seed), the change in final
# mean return under each perturbation relative to the reference environment.
# Written to ROBUST_PER_SEED with one column per environment plus delta_<env>.
REFERENCE_ENV = "No Noise"
PERTURBATIONS = ["Noise", "Delay", "Combo"]

# Collapse any duplicate rows for the same (baseline, env, seed) to their mean.
grouped = df_final.groupby(["baseline", "env_clean", "seed"], as_index=False).agg(
    mean_return=("mean_return", "mean")
)

# Reshape to one row per (baseline, seed) and one column per environment.
# A (baseline, seed) pair missing an environment gets NaN there, so its delta
# is NaN too (pandas' mean/std skip NaN when aggregating later).
pivot = grouped.pivot_table(
    index=["baseline", "seed"],
    columns="env_clean",
    values="mean_return",
).reset_index()
pivot.columns.name = None  # drop the leftover "env_clean" axis label

# Fail fast with a readable message instead of a bare KeyError from the loop
# below when an expected environment label is absent (typo, casing, no runs).
_required_envs = [REFERENCE_ENV, *PERTURBATIONS]
_missing_envs = [env for env in _required_envs if env not in pivot.columns]
if _missing_envs:
    _found_envs = [c for c in pivot.columns if c not in ("baseline", "seed")]
    raise KeyError(
        f"Missing environment(s) {_missing_envs} in final-phase results; "
        f"found: {_found_envs}"
    )

for env in PERTURBATIONS:
    pivot[f"delta_{env}"] = pivot[env] - pivot[REFERENCE_ENV]

pivot.to_csv(ROBUST_PER_SEED, index=False)

In [None]:
# Summarise the per-seed deltas for each baseline: the mean and the standard
# deviation (pandas default, ddof=1) across seeds. Output columns are named
# delta_<Perturbation>_<stat>, e.g. delta_Noise_mean, delta_Noise_std, ...
named_aggs = {
    f"{col}_{stat}": (col, stat)
    for col in ("delta_Noise", "delta_Delay", "delta_Combo")
    for stat in ("mean", "std")
}
agg = pivot.groupby("baseline").agg(**named_aggs).reset_index()
agg.to_csv(ROBUST_AGG, index=False)

print(agg.round(2))

In [None]:
# Long-format per-seed deltas: one row per (baseline, seed, perturbation).
delta_cols = ["delta_Noise", "delta_Delay", "delta_Combo"]
plot_df = pivot.melt(
    id_vars=["baseline", "seed"],
    value_vars=delta_cols,
    var_name="perturbation",
    value_name="delta",
)

fig, ax = plt.subplots(figsize=(12, 6))
# Bar height is the mean delta across seeds; seaborn draws its default error
# bars (a bootstrapped confidence interval of that mean).
sns.barplot(data=plot_df, x="baseline", y="delta", hue="perturbation", ax=ax)
ax.axhline(0, color="black", linestyle="--", alpha=0.5)  # zero = same as No Noise
ax.set_title("Return Delta by Perturbation (vs No Noise)")
ax.set_ylabel("Return Delta (vs No Noise)")
ax.set_xlabel("Baseline")
# Right-align rotated labels so each label ends under its own bar group.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
plt.show()

In [None]:
# Mean delta per (baseline, perturbation) as an annotated heatmap. The
# diverging red-yellow-green colormap is centred on zero, so cells below zero
# (return lost under the perturbation) shade red and cells above shade green.
mean_col_to_label = {f"delta_{name}_mean": name for name in ("Noise", "Delay", "Combo")}
heatmap_data = (
    agg.set_index("baseline")[list(mean_col_to_label)]
    .rename(columns=mean_col_to_label)
)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(heatmap_data, annot=True, fmt=".1f", cmap="RdYlGn", center=0, ax=ax)
ax.set_title("Robustness Delta (higher = more robust)")
fig.tight_layout()
plt.show()