In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

from dystformer.utils import (
    apply_custom_style,
    make_box_plot,
    plot_all_metrics_by_prediction_length,
)

# Presumably configures matplotlib from the project's YAML style file (verify in
# dystformer.utils). The path is relative, so it is resolved against the
# directory the kernel is running in.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Colors of the active matplotlib property cycle, in cycle order.
_prop_cycle_by_key = plt.rcParams["axes.prop_cycle"].by_key()
DEFAULT_COLORS = list(_prop_cycle_by_key["color"])

# Every figure saved by this notebook goes below this directory, which is
# created on first use.
figs_save_dir = os.path.join("../figures", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

In [None]:
# Root of the working tree, read from the WORK environment variable. When WORK
# is unset this is "", so every path derived from it is relative to the cwd.
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Evaluation split whose metrics are compared in this notebook.
data_split = "test_zeroshot"

# Every run directory follows the layout
#   <WORK_DIR>/eval_results/<model_family>/<run_name>/<data_split>
_eval_results_root = os.path.join(WORK_DIR, "eval_results")


def _run_metrics_dir(model_family, run_name):
    """Return the metrics directory of one evaluation run for `data_split`."""
    return os.path.join(_eval_results_root, model_family, run_name, data_split)


# Display name -> directory holding that run's metrics CSVs. Insertion order
# is the order in which the runs are loaded further down.
run_metrics_dir_dict = {
    # Alternative run names (disabled in favour of the one used below):
    #   "pft_stand_rff_only_pretrained-0"
    #   "pft_chattn_noembed_pretrained_correct-0"
    #   "pft_linattnpolyemb_from_scratch-0"
    "Panda": _run_metrics_dir("patchtst", "pft_chattn_emb_w_poly-0"),
    # Alternative run names (disabled in favour of the one used below):
    #   "chronos_bolt_mini-12"
    #   "chronos_finetune_stand_updated-0"
    "Chronos 20M SFT": _run_metrics_dir("chronos", "chronos_mini_ft-0"),
    "Chronos 20M": _run_metrics_dir("chronos", "chronos_mini_zeroshot"),
    "Time MOE 50M": _run_metrics_dir("timemoe", "timemoe-50m"),
    "TimesFM 200M": _run_metrics_dir("timesfm", "timesfm-200m"),
    # Baselines, currently disabled:
    # "Mean": _run_metrics_dir("baselines", "mean"),
    # "Fourier": _run_metrics_dir("baselines", "fourier"),
}

In [None]:
# Load every per-prediction-length metrics CSV of every run into
#   metrics_all_runs[model_name][prediction_length] -> DataFrame.to_dict() mapping
# (column name -> {row index -> value}). File names are expected to look like
# "<prefix>_pred<prediction_length>.csv".
metrics_all_runs = defaultdict(dict)
for model_name, run_metrics_dir in run_metrics_dir_dict.items():
    if not os.path.exists(run_metrics_dir):
        print(f"Run metrics dir does not exist: {run_metrics_dir}")
        continue

    # Filter and parse BEFORE sorting: the sort key used to be evaluated on every
    # directory entry, so a single stray file (e.g. ".DS_Store", a log, or a CSV
    # without "_pred") aborted the whole load with IndexError/ValueError.
    metrics_files = []
    for file in os.listdir(run_metrics_dir):
        if not file.endswith(".csv"):
            continue
        try:
            prediction_length = int(file.split("_pred")[1].split(".csv")[0])
        except (IndexError, ValueError):
            print(f"Skipping CSV without a parsable prediction length: {file}")
            continue
        metrics_files.append((prediction_length, file))

    # Insert in increasing prediction length so downstream code can rely on the
    # dict order; the sort is stable, so ties keep their listing order.
    for prediction_length, file in sorted(metrics_files, key=lambda item: item[0]):
        metrics = pd.read_csv(os.path.join(run_metrics_dir, file)).to_dict()
        metrics_all_runs[model_name][prediction_length] = metrics

In [None]:
# Models for which at least one metrics CSV was loaded (entries are only
# created when a file is stored).
metrics_all_runs.keys()

In [None]:
# Flatten each loaded table from {column -> {row index -> value}} to
# {column -> [values]}, leaving out the "system" column.
unrolled_metrics = defaultdict(dict)
for model_name, metrics_by_length in metrics_all_runs.items():
    for prediction_length, metrics in metrics_by_length.items():
        # Kept as a lookup so a table without a "system" column fails loudly.
        systems = metrics["system"]
        unrolled_metrics[model_name][prediction_length] = {
            column: list(values.values())
            for column, values in metrics.items()
            if column != "system"
        }

In [None]:
# Number of models that ended up with loaded metrics.
n_runs = len(unrolled_metrics)
print(n_runs)

In [None]:
# Take the first n_runs + 1 cycle colors and leave out the one at index 3,
# which yields n_runs colors whenever n_runs >= 3.
_palette_candidates = DEFAULT_COLORS[: n_runs + 1]
default_colors = [
    color for position, color in enumerate(_palette_candidates) if position != 3
]
print(default_colors)

In [None]:
# Box plot of one metric across models at a single prediction length.
selected_metric = "smape"
# Single source of truth for the prediction length: it is used both to select
# the data and to name the output file, so the two can no longer drift apart.
box_plot_prediction_length = 128
legend_handles = make_box_plot(
    unrolled_metrics=unrolled_metrics,
    prediction_length=box_plot_prediction_length,
    metric_to_plot=selected_metric,  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=default_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    use_inv_spearman=True,
    order_by_metric="smape",
    save_path=f"{figs_save_dir}/{selected_metric}_{box_plot_prediction_length}.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(25, 75),
    whisker_percentile_range=(5, 95),
    alpha_val=0.8,
    fig_kwargs={"figsize": (2, 4)},
    box_width=1.0,
)

In [None]:
# Render the legend on its own in a small, otherwise empty figure so it can be
# saved as a separate PDF.
legend_fig = plt.figure(figsize=(6, 1))
legend_ax = legend_fig.gca()

# Legend built from the handles returned by the box-plot cell.
legend = legend_ax.legend(
    handles=legend_handles,
    loc="upper center",
    frameon=True,
    ncol=5,
    framealpha=1.0,
    fontsize=16,
)

# The axes only hosts the legend: remove the ticks and all padding.
legend_ax.set_xticks([])
legend_ax.set_yticks([])
legend_fig.tight_layout(pad=0)
legend_fig.savefig(
    f"{figs_save_dir}/baselines_legend_horizontal_patches.pdf", bbox_inches="tight"
)
plt.show()
plt.close(legend_fig)

In [None]:
def get_summary_metrics_dict(
    unrolled_metrics: dict, metric_name: str
) -> tuple[dict[str, dict[str, list[float]]], dict[str, bool]]:
    """
    Summarize one metric for every model and prediction length.

    Args:
        unrolled_metrics: nested mapping
            ``model name -> prediction length -> metric name -> list of values``.
        metric_name: key of the metric to summarize (e.g. "smape").

    Returns:
        A tuple ``(summary_metrics_dict, has_nans)``:

        - ``summary_metrics_dict[model_name]`` holds
          ``"prediction_lengths"`` (list, in the input's order),
          ``"num_vals"`` (number of values at the first prediction length,
          0 if the model has none), and the per-prediction-length lists
          ``"means"``, ``"medians"``, ``"stds"``, ``"stes"`` (all NaN-ignoring)
          and ``"all_vals"`` (the raw value arrays, NaNs included).
        - ``has_nans[model_name]`` is True if any value of the metric was NaN
          at any prediction length.

    Raises:
        KeyError: if ``metric_name`` is missing for some model/prediction length.
    """
    summary_metrics_dict = defaultdict(dict)
    has_nans = defaultdict(bool)
    for model_name, metrics_dict in unrolled_metrics.items():
        prediction_lengths = list(metrics_dict.keys())
        summary = summary_metrics_dict[model_name]
        summary["prediction_lengths"] = prediction_lengths
        # Guard against a model without any prediction length instead of
        # failing with an IndexError on prediction_lengths[0].
        summary["num_vals"] = (
            len(metrics_dict[prediction_lengths[0]][metric_name])
            if prediction_lengths
            else 0
        )

        means = []
        medians = []
        stds = []
        stes = []
        all_vals = []
        for prediction_length in tqdm(
            prediction_lengths, desc=f"Computing {metric_name} for {model_name}"
        ):
            metric_vals = np.array(metrics_dict[prediction_length][metric_name])
            nan_mask = np.isnan(metric_vals)
            if nan_mask.any():
                has_nans[model_name] = True

            mean = np.nanmean(metric_vals)
            median = np.nanmedian(metric_vals)
            std = np.nanstd(metric_vals)
            # The std ignores NaNs, so the standard error must be normalized by
            # the number of values that actually entered it, not by the raw
            # length (which would understate the error when NaNs are present).
            n_valid = int((~nan_mask).sum())
            ste = std / np.sqrt(n_valid) if n_valid > 0 else np.nan

            means.append(mean)
            medians.append(median)
            stds.append(std)
            stes.append(ste)
            all_vals.append(metric_vals)

        # Reading the flag also registers the model with an explicit False.
        if has_nans[model_name]:
            print(f"NaNs in {model_name} for {metric_name}")

        summary["means"] = means
        summary["medians"] = medians
        summary["stds"] = stds
        summary["stes"] = stes
        summary["all_vals"] = all_vals

    return summary_metrics_dict, has_nans

In [None]:
# Summary statistics for sMAPE only.
# NOTE(review): `has_nans` is reassigned by the next cell and
# `smape_metrics_dict` is not referenced again in the visible notebook, so this
# cell looks redundant with the multi-metric computation that follows.
smape_metrics_dict, has_nans = get_summary_metrics_dict(unrolled_metrics, "smape")

In [None]:
# Compute the summary statistics once per metric, then key both results
# (summaries and NaN flags) by metric name.
metrics = ["mse", "mae", "smape"]
_per_metric_results = [
    get_summary_metrics_dict(unrolled_metrics, metric) for metric in metrics
]
metrics_dicts, has_nans = zip(*_per_metric_results)
all_metrics_dict = dict(zip(metrics, metrics_dicts))
has_nans_dict = dict(zip(metrics, has_nans))

In [None]:
# Metric name -> {model name -> whether any value of that metric was NaN}.
has_nans_dict

Order model names by their median sMAPE at prediction length 128

In [None]:
# Rank the models by their median sMAPE at this prediction length (ascending).
ordering_prediction_length = 128

model_names_ordering = []  # sorted by median smape at ordering_prediction_length
for model_name, data in all_metrics_dict["smape"].items():
    prediction_lengths = list(data["prediction_lengths"])
    if ordering_prediction_length not in prediction_lengths:
        raise ValueError(
            f"Model {model_name} has no metrics for prediction length "
            f"{ordering_prediction_length}"
        )
    # Look the entry up by prediction length instead of a fixed list position:
    # the position depends on which CSV files happened to exist for the run.
    position = prediction_lengths.index(ordering_prediction_length)
    median_metrics_128 = data["medians"][position]
    model_names_ordering.append((model_name, median_metrics_128))
# NaN medians sort last; comparing against NaN directly would leave the order
# dependent on where the NaN sits in the list.
model_names_ordering = sorted(
    model_names_ordering, key=lambda x: (bool(np.isnan(x[1])), x[1])
)
model_names_ordering = [x[0] for x in model_names_ordering]
print(model_names_ordering)

# Reorder all_metrics_dict according to model_names_ordering for each metric
reordered_metrics_dict = {}
for metric_name, metric_data in all_metrics_dict.items():
    reordered_metric_data = {}

    # Add models in the order specified by model_names_ordering; every metric
    # must contain every ranked model.
    for model_name in model_names_ordering:
        if model_name not in metric_data:
            raise ValueError(f"Model {model_name} not found in {metric_name}")
        reordered_metric_data[model_name] = metric_data[model_name]

    reordered_metrics_dict[metric_name] = reordered_metric_data
all_metrics_dict = reordered_metrics_dict

In [None]:
# Plot MSE, MAE and sMAPE against prediction length for all models. The legend
# is suppressed here; the returned handles feed the stand-alone legend figure.
_rollout_metric_names = ["mse", "mae", "smape"]
_rollout_save_path = (
    f"{figs_save_dir}/zeroshot_metrics_autoregressive_rollout_metrics.pdf"
)
_rollout_legend_kwargs = {"loc": "upper left", "frameon": True, "fontsize": 10}

legend_handles = plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    _rollout_metric_names,
    metrics_to_show_envelope=["mae", "smape"],
    n_cols=4,
    n_rows=1,
    save_path=_rollout_save_path,
    show_legend=False,
    legend_kwargs=_rollout_legend_kwargs,
    colors=default_colors,
    use_inv_spearman=True,
    percentile_range=(40, 60),
    has_nans=has_nans_dict,
)

In [None]:
# Render the legend on its own in a small, otherwise empty figure so it can be
# saved as a separate PDF.
legend_fig = plt.figure(figsize=(6, 1))
legend_ax = legend_fig.gca()

# Legend built from the handles returned by the metrics-vs-length plot.
legend = legend_ax.legend(
    handles=legend_handles,
    loc="upper center",
    frameon=True,
    ncol=5,
    framealpha=1.0,
    fontsize=16,
)

# The axes only hosts the legend: remove the ticks and all padding.
legend_ax.set_xticks([])
legend_ax.set_yticks([])
legend_fig.tight_layout(pad=0)
legend_fig.savefig(
    f"{figs_save_dir}/baselines_legend_horizontal.pdf", bbox_inches="tight"
)
plt.show()
plt.close(legend_fig)

In [None]:
# sMAPE-only version of the metrics-vs-prediction-length figure, with its own
# legend and a wider percentile envelope.
_smape_save_path = f"{figs_save_dir}/zeroshot_smape_autoregressive_rollout_metrics.pdf"
_smape_legend_kwargs = {"frameon": True, "fontsize": 10, "loc": "lower right"}

plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    ["smape"],
    metrics_to_show_envelope=["smape"],
    n_cols=1,
    n_rows=1,
    individual_figsize=(4, 4.5),
    ylim=(10, None),
    save_path=_smape_save_path,
    legend_kwargs=_smape_legend_kwargs,
    colors=default_colors,
    percentile_range=(35, 65),
    has_nans=has_nans_dict,
)