In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import zscore

# Presentation switch: poster styling uses seaborn's "poster" context and 1.75x larger
# figures; otherwise the "paper" context at the base size is used.
poster = True
figscale = 1.75 if poster else 1.0

if poster:
    _context = "poster"
else:
    _context = "paper"

# Square figures: 3.5 in per side, multiplied by the poster scale factor.
_side = 3.5 * figscale

sns.set_theme(
    context=_context,
    style="whitegrid",
    rc={
        "figure.figsize": (_side, _side),
        "figure.dpi": 150,  # resolution for inline/on-screen rendering
        "savefig.dpi": 1000,  # resolution used by savefig (affects raster output such as PNG)
        # Constrained layout is left disabled:
        # "figure.constrained_layout.use": True,
        "pdf.fonttype": 42,  # embed TrueType (Type 42) fonts in PDFs instead of Type 3
    },
)

In [None]:
# Load per-run results for the CNN and the two model-based filters (LMB, GLMB) and stack
# them into one long frame. Only the CNN file gets a "filter" label added here; the LMB/GLMB
# summaries presumably already carry a "filter" column — TODO confirm against the CSVs.
df_cnn = pd.read_csv("../data/out/generalization/e7ivqipk.csv").assign(filter="cnn")
df_lmb = pd.read_csv("../data/out/lmb_summary.csv")
df_glmb = pd.read_csv("../data/out/glmb_summary.csv")
df_combined = pd.concat([df_cnn, df_lmb, df_glmb], ignore_index=True)

In [None]:
# Mean and standard deviation of OSPA for every (window scale, filter) pair,
# shown with no decimal places.
ospa_by_scale_filter = df_combined.groupby(["scale", "filter"])["ospa"].agg(["mean", "std"])
ospa_by_scale_filter.style.format("{:.0f}")

In [None]:
# Display names for the filter labels used in the raw results.
FILTER_LABELS = {"cnn": "CNN", "lmb": "LMB", "glmb": "GLMB"}

# Average each simulation's rows so there is one row per (filter, scale, simulation_idx).
# seaborn's error bars below are then computed across simulations.
# numeric_only=True avoids a TypeError in pandas >= 2.0 if any non-numeric,
# non-key column is present.
data = (
    df_combined.groupby(["filter", "scale", "simulation_idx"], as_index=False)
    .mean(numeric_only=True)
    # Relabel only the "filter" column rather than matching these strings in every column.
    .replace({"filter": FILTER_LABELS})
)


def plot_metric(frame, metric, ylabel, stem, out_dir="../figures"):
    """Bar-plot `metric` per window scale, grouped by filter, and save it as PDF and PNG.

    Parameters
    ----------
    frame : pd.DataFrame
        Must contain "scale", "filter" and `metric` columns (e.g. `data` above).
    metric : str
        Column plotted on the y axis.
    ylabel : str
        Y-axis label.
    stem : str
        Output file name without extension.
    out_dir : str
        Directory the `.pdf` and `.png` files are written to.
    """
    fig, ax = plt.subplots()
    sns.barplot(data=frame, x="scale", y=metric, hue="filter", errorbar="ci", ax=ax)
    # Centre a single-row legend below the axes (axes coordinates). This is the same
    # placement as the former bbox (1.1, -0.3, -1.2, 0.1), whose centre is (0.5, -0.25).
    ax.legend(
        loc="center",
        bbox_to_anchor=(0.5, -0.25),
        ncols=3,
        title=None,
        fontsize="x-small" if poster else None,
    )
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Window width (km)")
    if poster:
        ax.set_title("MTT Transfer Learning")
    for ext in ("pdf", "png"):
        # bbox_inches="tight" keeps the below-axes legend inside the saved image.
        fig.savefig(f"{out_dir}/{stem}.{ext}", bbox_inches="tight")
    plt.show()


plot_metric(data, "mse", "MSE" if poster else "Mean Squared Error", "mtt_mse")
plot_metric(data, "ospa", "OSPA (m)" if poster else "Optimal Sub-pattern Assignment", "mtt_ospa")

# OSPA with p=1 for the model-based filters only (formerly a commented-out cell):
# plot_metric(data.query("filter != 'CNN'"), "ospa1", "Optimal Sub-pattern Assignment (p=1)", "ospa1")

# CNN Runtime

In [None]:
# Per-run CNN forward-pass timings; `zscore` is imported in the top import cell.
runtime_raw = pd.read_csv("../data/out/generalization/e7ivqipk_runtime.csv")

# Drop *rows* (individual runs) whose forward time lies 3 or more standard deviations
# from the mean of their scale group; rows with a missing timing are dropped as well.
# nan_policy="omit" stops one missing value from turning every z-score in its group
# into NaN. With the default "propagate", the entire group was silently discarded.
forward_z = runtime_raw.groupby("scale")["forward"].transform(zscore, nan_policy="omit")
df = runtime_raw[forward_z.abs() < 3]

# Mean/std per scale, multiplied by 1000 (presumably seconds -> milliseconds;
# TODO confirm the CSV's units) and rounded to 2 decimals.
(df.groupby("scale").agg(["mean", "std"]) * 1000).round(2)