In [None]:
# Load IPython's autoreload extension and switch it to mode 2 so that edited
# modules (e.g. the project's plotting helpers) are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from panda.utils.plot_utils import apply_custom_style, make_box_plot

In [None]:
# Apply the project's plotting style, passing the path of the plotting YAML config.
# NOTE(review): the path is relative, so this presumably expects the kernel's
# working directory to be the notebook's own folder — confirm.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Output directory "../figs/eval_metrics"; it is created here if it does not exist yet.
# NOTE(review): the plotting cells visible in this notebook save to "scaled_figs/..."
# and never reference figs_save_dir — confirm which location is intended.
figs_save_dir = os.path.join("../figs", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

In [None]:
# Root directory taken from the WORK environment variable. When WORK is unset
# this is the empty string, which makes every path derived from it relative to
# the current working directory.
WORK_DIR = os.environ.get("WORK", "")
# Data directory located directly under the root.
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Evaluation split whose metric CSVs are read; it is the last component of
# every run directory built below.
# data_split = "final_skew40/test_zeroshot"
data_split = "test_zeroshot"


def _eval_results_dir(model_family, run_name):
    """Return WORK_DIR/eval_results/<model_family>/<run_name>/<data_split>."""
    return os.path.join(WORK_DIR, "eval_results", model_family, run_name, data_split)


# Each dict maps a display label to the directory holding that run's metric
# CSVs. Commented-out entries are alternative runs kept for reference.
run_names_panda = {
    # "Panda 72M": _eval_results_dir("patchtst", "panda_nh12_dmodel768_mixedp-4"),
    # "Panda 42M": _eval_results_dir("patchtst", "panda_nh10_dmodel640-1"),
    # "Panda 21M": _eval_results_dir("patchtst", "polyembed_21M_iter400k_dataimproved-2"),
    "Panda 21M": _eval_results_dir("patchtst", "pft_chattn_emb_w_poly-0"),
}

run_names_chronos_zs = {
    # "Chronos 20M": _eval_results_dir("chronos", "chronos_mini_zeroshot"),
    # "Chronos 46M": _eval_results_dir("chronos", "chronos_small_zeroshot"),
    #     (alternative family directory: "chronos_nondeterministic")
    # Alternative family directory for the entry below: "chronos"
    "Chronos 200M": _eval_results_dir("chronos_nondeterministic", "chronos_base_zeroshot"),
}

run_names_chronos_sft = {
    # "Chronos 46M SFT": _eval_results_dir("chronos", "chronos_small_ft-4"),
    #     (alternative family directory: "chronos_nondeterministic")
    # Alternative family directory for the entry below: "chronos"
    "Chronos 20M SFT": _eval_results_dir("chronos_nondeterministic", "chronos_t5_mini_ft-0"),
}

# Group key -> {display label -> metrics directory}.
run_metrics_dirs_all_groups = {
    "panda": run_names_panda,
    "chronos_zs": run_names_chronos_zs,
    "chronos_sft": run_names_chronos_sft,
}

In [None]:
# Display the configured group -> label -> directory mapping for a visual check.
run_metrics_dirs_all_groups

In [None]:
def _prediction_length_from_csv_name(file_name):
    """Return the integer that follows "_pred" in "<prefix>_pred<N>.csv".

    Returns None when the name has no "_pred" part or when the text between
    "_pred" and ".csv" is not an integer, so unrelated CSV files can be
    skipped instead of aborting the whole load.
    """
    try:
        return int(file_name.split("_pred")[1].split(".csv")[0])
    except (IndexError, ValueError):
        return None


# metrics_all[run_group][run_label][prediction_length] holds
# pd.read_csv(<csv>).to_dict() for that run and prediction length.
metrics_all = defaultdict(lambda: defaultdict(dict))
for run_group, run_metrics_dir_dict in run_metrics_dirs_all_groups.items():
    print(f"Run group: {run_group}")
    for run_abbrv, run_metrics_dir in run_metrics_dir_dict.items():
        run_abbrv = str(run_abbrv)
        if not os.path.isdir(run_metrics_dir):
            # Say so explicitly: a silently skipped run simply disappears from
            # every later plot, which is easy to miss.
            print(f"{run_abbrv}: SKIPPED, metrics directory not found: {run_metrics_dir}")
            continue
        print(f"{run_abbrv}: {run_metrics_dir}")

        # Parse each file name once and keep only CSVs that carry a prediction length.
        csv_file_by_prediction_length = {}
        for file in os.listdir(run_metrics_dir):
            if not file.endswith(".csv"):
                continue
            prediction_length = _prediction_length_from_csv_name(file)
            if prediction_length is None:
                print(f"  ignoring {file}: no '_pred<N>.csv' suffix")
                continue
            csv_file_by_prediction_length[prediction_length] = file

        # Insert in ascending prediction-length order; later cells rely on the
        # dict's insertion order when listing prediction lengths.
        for prediction_length, file in sorted(csv_file_by_prediction_length.items()):
            metrics = pd.read_csv(os.path.join(run_metrics_dir, file)).to_dict()
            metrics_all[run_group][run_abbrv][prediction_length] = metrics

In [None]:
# Show which run groups have at least one metrics CSV loaded (a group key is
# only created when a CSV is stored for it).
metrics_all.keys()

In [None]:
# unrolled_metrics_all_groups[run_group][run_label][prediction_length] maps each
# metric column name to the plain list of its values (one entry per CSV row).
#
# The "system" column is excluded from the result WITHOUT being removed from
# metrics_all. The earlier in-place pop made this cell fail with KeyError on a
# second run and left metrics_all silently altered for every other cell.
unrolled_metrics_all_groups = defaultdict(lambda: defaultdict(dict))
for run_group, all_metrics_of_run_group in metrics_all.items():
    print(run_group)
    for run_abbrv, all_metrics_of_run_abbrv in all_metrics_of_run_group.items():
        print(run_abbrv)
        # Keys at this level are prediction lengths (ints parsed from file names).
        for prediction_length, metrics in all_metrics_of_run_abbrv.items():
            print(prediction_length)
            # Kept as a read-only lookup so the name stays available to any
            # cell that inspects it; None when the CSV has no "system" column.
            systems = metrics.get("system")
            metrics_unrolled = {
                column: list(values_by_row.values())
                for column, values_by_row in metrics.items()
                if column != "system"
            }
            print(metrics_unrolled.keys())
            unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length] = metrics_unrolled

In [None]:
# Flatten the per-group dicts into one label -> {prediction_length -> metrics}
# mapping. Groups are merged in the order panda, chronos_zs, chronos_sft; if a
# label occurs in more than one group, the later group's entry replaces the
# earlier one.
unrolled_metrics_all_combined = {}
for group_key in ("panda", "chronos_zs", "chronos_sft"):
    unrolled_metrics_all_combined.update(unrolled_metrics_all_groups[group_key])

In [None]:
def get_summary_metrics_dict(unrolled_metrics, metric_name):
    """Summarise one metric per model and prediction length.

    Args:
        unrolled_metrics: mapping model name -> {prediction length ->
            {metric name -> sequence of values}}.
        metric_name: key of the metric to summarise in the innermost dicts.

    Returns:
        A defaultdict(dict) mapping each model name to a dict with the keys
        "prediction_lengths", "means", "medians" and "stds". The last three are
        lists aligned with "prediction_lengths" and are computed with the
        NaN-ignoring numpy reductions.

    Raises:
        KeyError: if metric_name is missing for any prediction length.
    """
    # Output key -> NaN-aware reduction, in the order the keys are inserted.
    reductions = (
        ("means", np.nanmean),
        ("medians", np.nanmedian),
        ("stds", np.nanstd),
    )

    summary = defaultdict(dict)
    for model_name, per_length in unrolled_metrics.items():
        print(model_name)
        horizons = list(per_length)
        samples = [per_length[horizon][metric_name] for horizon in horizons]

        entry = summary[model_name]
        entry["prediction_lengths"] = horizons
        for out_key, reduce_fn in reductions:
            entry[out_key] = [reduce_fn(values) for values in samples]
    return summary

In [None]:
def plot_metrics_by_prediction_length(metrics_dict, metric_name, show_std_envelope=False):
    """Plot each model's median metric value against prediction length.

    A new 5x4 inch pyplot figure is created and left as the current figure;
    nothing is returned, shown or saved here.

    Args:
        metrics_dict: mapping model name -> dict with "prediction_lengths" and
            "medians" (and, when show_std_envelope is True, "means" and
            "stds"), all of equal length.
        metric_name: used as the figure title.
        show_std_envelope: when True, shade the band mean - std .. mean + std
            for every model.

    NOTE(review): the line shows medians while the band is centred on means,
    so the band need not be centred on the line — confirm this is intended.
    """
    plt.figure(figsize=(5, 4))
    for model_name, metrics in metrics_dict.items():
        prediction_lengths = metrics["prediction_lengths"]
        median_line = plt.plot(
            prediction_lengths,
            metrics["medians"],
            marker="o",
            label=model_name,
        )[0]
        if show_std_envelope:
            # Explicit arrays: the summaries are plain lists, and list - array
            # only worked through numpy's reflected operators.
            means = np.asarray(metrics["means"], dtype=float)
            stds = np.asarray(metrics["stds"], dtype=float)
            plt.fill_between(
                prediction_lengths,
                means - stds,
                means + stds,
                # Tie the band to its line instead of relying on the separate
                # fill colour cycle staying in step with the line cycle.
                color=median_line.get_color(),
                alpha=0.2,
            )
    plt.legend(loc="lower right")
    plt.xlabel("Prediction Length")
    plt.title(metric_name, fontweight="bold")

In [None]:
# List the configured run-group names.
run_metrics_dirs_all_groups.keys()

In [None]:
# Names of the metrics to summarise; "mse" is currently disabled.
metric_names_chosen = [
    # "mse",
    "mae",
    "smape",
    "spearman",
]

In [None]:
# all_metrics_dict[run_group][metric name] is the per-model summary returned by
# get_summary_metrics_dict for that group and metric.
all_metrics_dict = defaultdict(dict)

for run_group in run_metrics_dirs_all_groups:
    summaries_for_group = {}
    for metrics_name in metric_names_chosen:
        summaries_for_group[metrics_name] = get_summary_metrics_dict(
            unrolled_metrics_all_groups[run_group], metrics_name
        )
    all_metrics_dict[run_group] = summaries_for_group

In [None]:
# Color list of matplotlib's "tab10" colormap.
# NOTE(review): not referenced by any cell visible in this notebook.
default_colors = plt.cm.tab10.colors

In [None]:
# Labels of the runs present in the merged metrics dict.
unrolled_metrics_all_combined.keys()

In [None]:
# Number of distinct run labels in the merged metrics dict.
n_runs = len(unrolled_metrics_all_combined)
print(n_runs)

In [None]:
# Number of configured runs in each model group.
n_runs_panda, n_runs_chronos_zs, n_runs_chronos_sft = map(
    len, (run_names_panda, run_names_chronos_zs, run_names_chronos_sft)
)

# One "tab20" color per group (indices 6, 8 and 0). Earlier two-run palettes:
# panda_colors = [plt.cm.tab20.colors[6], plt.cm.tab20.colors[7]]
# chronos_zs_colors = [plt.cm.tab20.colors[8], plt.cm.tab20.colors[9]]
# chronos_sft_colors = [plt.cm.tab20.colors[0]]
panda_colors, chronos_zs_colors, chronos_sft_colors = (
    [plt.cm.tab20.colors[color_index]] for color_index in (6, 8, 0)
)

# NOTE(review): the SFT color comes before the zero-shot color here, while the
# merged metrics dict is built in the order panda, chronos_zs, chronos_sft —
# confirm how make_box_plot pairs colors with runs.
bar_colors = [*panda_colors, *chronos_sft_colors, *chronos_zs_colors]

In [None]:
# Prediction length used by the smape box plot in the next cell (also part of its file name).
selected_pred_length = 128

In [None]:
# Box plot of "smape" across runs at selected_pred_length. The returned value
# is kept as legend_handles and reused by the stand-alone legend figure.
#
# The notebook only creates "../figs/eval_metrics", never "scaled_figs", so the
# target directory is created here. NOTE(review): make_box_plot may already do
# this itself; creating it up front is harmless either way.
os.makedirs("scaled_figs", exist_ok=True)

legend_handles = make_box_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot="smape",  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=bar_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    save_path=f"scaled_figs/smape_{selected_pred_length}.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(25, 75),
    alpha_val=0.8,
    fig_kwargs={"figsize": (2.5, 5)},
    box_width=0.8,
    verbose=True,
)

In [None]:
# Stand-alone horizontal legend: an otherwise empty figure that only carries
# the handles returned by the smape box plot, saved as its own PDF.
legend_save_path = "scaled_figs/ablations_legend_horizontal.pdf"
# savefig does not create missing directories, and "scaled_figs" is not created
# anywhere else in the notebook, so make sure it exists before saving.
os.makedirs(os.path.dirname(legend_save_path), exist_ok=True)

# plt.figure(figsize=(4, 0.6))
plt.figure(figsize=(5, 1))

# Add the legend
plt.legend(
    handles=legend_handles,
    loc="center",
    frameon=True,
    ncol=5,
    framealpha=1.0,
    fontsize=12,
)
# Hide the ticks so only the legend box is visible.
plt.xticks([])
plt.yticks([])
plt.tight_layout(pad=0)
plt.savefig(legend_save_path, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Prediction length used by the spearman and mae box plots below; this rebinds
# the value used for the earlier smape plot.
selected_pred_length = 512

In [None]:
# Box plot of the "spearman" metric at selected_pred_length, saved under scaled_figs/.
# NOTE(review): use_inv_spearman=True presumably makes the helper plot an
# inverted correlation (e.g. 1 - spearman), and order_by_metric="smape"
# presumably orders the runs by their smape values — confirm both in make_box_plot.
# The return value is not captured, so the notebook displays it as the cell output.
make_box_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot="spearman",  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=bar_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    use_inv_spearman=True,
    order_by_metric="smape",
    save_path=f"scaled_figs/spearman_{selected_pred_length}.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(25, 75),
    alpha_val=0.8,
    fig_kwargs={"figsize": (2.5, 5)},
    box_width=0.8,
    verbose=True,
)

In [None]:
# Box plot of the metric named by metric_name ("mae") at selected_pred_length;
# the metric name is also used in the output file name.
metric_name = "mae"

# NOTE(review): use_inv_spearman=True is passed although the plotted metric is
# "mae"; it looks carried over from the spearman cell — confirm it has no effect here.
# NOTE(review): legend_kwargs is supplied together with show_legend=False, so it
# is presumably unused — confirm in make_box_plot.
make_box_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot=metric_name,  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=bar_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    use_inv_spearman=True,
    order_by_metric="smape",
    save_path=f"scaled_figs/{metric_name}_{selected_pred_length}.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    show_legend=False,
    legend_kwargs={
        "loc": "upper left",
        "frameon": True,
        "ncol": 1,
        "framealpha": 0.8,
        # "prop": {"weight": "bold", "size": 5},
        "prop": {"size": 6.8},
    },
    box_percentile_range=(40, 60),
    whisker_percentile_range=(25, 75),
    alpha_val=0.8,
    fig_kwargs={"figsize": (2.5, 5)},
    box_width=0.8,
)

In [None]:
# Group keys currently present in the unrolled metrics (a defaultdict, so keys
# also appear for groups that were merely indexed).
unrolled_metrics_all_groups.keys()

In [None]:
# Group keys for which summary statistics were computed.
all_metrics_dict.keys()