In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from dystformer.utils import (
    apply_custom_style,
    make_box_plot,
    plot_all_metrics_by_prediction_length,
)

In [None]:
# Apply matplotlib style from config
# The path is relative, so it resolves against the notebook's working directory.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Directory for evaluation-metric figures, relative to the notebook's working directory.
figs_save_dir = os.path.join("../figures", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

# The plotting cells in this notebook write to paths under "baselines_figs/" rather
# than figs_save_dir, and plt.savefig does not create missing parent directories.
# Create that folder up front so a fresh Restart & Run All does not fail on save.
baselines_figs_dir = "baselines_figs"
os.makedirs(baselines_figs_dir, exist_ok=True)

In [None]:
# Root of the working tree, taken from the WORK environment variable.
# When WORK is unset this falls back to "", so every path built from WORK_DIR
# becomes relative to the current working directory.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Colors of the active matplotlib property cycle (as set by the style applied above).
default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
data_split = "test_zeroshot"


def _eval_results_dir(model_family: str, run_name: str) -> str:
    """Return WORK_DIR/eval_results/<model_family>/<run_name>/<data_split>."""
    return os.path.join(WORK_DIR, "eval_results", model_family, run_name, data_split)


# Maps the label used in plots to the directory holding that run's metric CSVs.
run_metrics_dir_dict = {
    # Other patchtst runs that can be swapped in for "Our Model":
    #   "pft_stand_rff_only_pretrained-0"
    #   "pft_chattn_noembed_pretrained_correct-0"
    #   "pft_linattnpolyemb_from_scratch-0"
    "Our Model": _eval_results_dir("patchtst", "pft_chattn_emb_w_poly-0"),
    # Other chronos runs that can be swapped in for the finetuned baseline:
    #   "chronos_mini_ft-0"
    #   "chronos_finetune_stand_updated-0"
    "Chronos 20M Finetune": _eval_results_dir("chronos", "chronos_bolt_mini-12"),
    "Chronos 20M": _eval_results_dir("chronos", "chronos_mini_zeroshot"),
    "Time MOE 50M": _eval_results_dir("timemoe", "timemoe-50m"),
    "TimesFM 200M": _eval_results_dir("timesfm", "timesfm-200m"),
    # Simple baselines, currently disabled:
    # "Mean": _eval_results_dir("baselines", "mean"),
    # "Fourier": _eval_results_dir("baselines", "fourier"),
}

In [None]:
# Show the model labels that are configured for comparison.
run_metrics_dir_dict.keys()

In [None]:
def _prediction_length_from_filename(filename):
    """Return the integer N from a "*_pred<N>.csv" file name, or None if it does not match."""
    if not filename.endswith(".csv") or "_pred" not in filename:
        return None
    try:
        return int(filename.split("_pred")[1].split(".csv")[0])
    except ValueError:
        return None


# metrics_all_runs[model_name][prediction_length] -> metrics CSV as a dict of
# {column: {row_index: value}} (the default DataFrame.to_dict() layout).
metrics_all_runs = defaultdict(dict)
for model_name, run_metrics_dir in run_metrics_dir_dict.items():
    print(model_name)
    if not os.path.isdir(run_metrics_dir):
        print(f"Run metrics dir does not exist: {run_metrics_dir}")
        continue

    # Select the metric files BEFORE sorting: parsing the prediction length of a
    # stray entry (hidden files, logs, checkpoint folders) would otherwise raise
    # inside the sort key and abort the whole load.
    files_by_prediction_length = {}
    for file in os.listdir(run_metrics_dir):
        prediction_length = _prediction_length_from_filename(file)
        if prediction_length is not None:
            files_by_prediction_length[prediction_length] = file

    # Insert in ascending prediction-length order; later cells rely on this order.
    for prediction_length in sorted(files_by_prediction_length):
        metrics_path = os.path.join(
            run_metrics_dir, files_by_prediction_length[prediction_length]
        )
        metrics = pd.read_csv(metrics_path).to_dict()
        metrics_all_runs[model_name][prediction_length] = metrics

In [None]:
# Models whose metrics were actually loaded (missing directories are skipped above).
metrics_all_runs.keys()

In [None]:
# Column names of the metrics CSV loaded for "Our Model" at prediction length 64.
metrics_all_runs["Our Model"][64].keys()

In [None]:
# unrolled_metrics[model_name][prediction_length][metric] -> flat list of values,
# one per CSV row, with the "system" column left out.
unrolled_metrics = defaultdict(dict)
for model_name in metrics_all_runs:
    print(model_name)
    for prediction_length, metrics in metrics_all_runs[model_name].items():
        systems = metrics["system"]
        metrics_unrolled = {}
        for column, values_by_row in metrics.items():
            if column == "system":
                continue
            # to_dict() stored each column as {row_index: value}; keep the values only.
            metrics_unrolled[column] = list(values_by_row.values())
        unrolled_metrics[model_name][prediction_length] = metrics_unrolled

In [None]:
# Prediction lengths available for "Our Model".
unrolled_metrics["Our Model"].keys()

In [None]:
# Metric names available for "Our Model" at prediction length 64.
unrolled_metrics["Our Model"][64].keys()

In [None]:
# Number of sMAPE values (one per CSV row — presumably one per system; confirm).
len(unrolled_metrics["Our Model"][64]["smape"])

In [None]:
# Prediction lengths available for "Our Model" (same check as a few cells above).
unrolled_metrics["Our Model"].keys()

In [None]:
# Number of models for which metrics were loaded.
n_runs = len(unrolled_metrics)
print(n_runs)

In [None]:
# Take the first n_runs + 1 colors of the style's cycle, then drop the one at
# index 3, leaving one color per model.
color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
default_colors = [
    color for idx, color in enumerate(color_cycle[: n_runs + 1]) if idx != 3
]
print(default_colors)

In [None]:
# Box plot of one metric across all models at a single prediction length (128).
# The returned handles are kept so the next cell can render a standalone legend.
# NOTE(review): make_box_plot comes from dystformer.utils; the exact meaning of its
# keyword arguments (e.g. use_inv_spearman, order_by_metric) is not visible in this
# notebook — confirm against the helper before changing them.
selected_metric = "smape"
legend_handles = make_box_plot(
    unrolled_metrics=unrolled_metrics,
    prediction_length=128,
    metric_to_plot=selected_metric,  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=default_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    use_inv_spearman=True,
    order_by_metric="smape",
    save_path=f"baselines_figs/{selected_metric}_128.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(25, 75),
    whisker_percentile_range=(5, 95),
    alpha_val=0.8,
    fig_kwargs={"figsize": (2, 4)},
    box_width=1.0,
)

In [None]:
# Standalone figure whose only content is the legend for the box plot.
legend_fig = plt.figure(figsize=(6, 1))
legend_ax = legend_fig.gca()

# Build the legend from the handles returned by the box-plot call.
legend = legend_ax.legend(
    handles=legend_handles,
    loc="upper center",
    frameon=True,
    ncol=5,
    framealpha=1.0,
    fontsize=16,
)

# No data is drawn on these axes, so remove the ticks.
legend_ax.set_xticks([])
legend_ax.set_yticks([])
legend_fig.tight_layout(pad=0)
legend_fig.savefig(
    "baselines_figs/baselines_legend_horizontal_patches.pdf", bbox_inches="tight"
)
plt.show()
plt.close()

In [None]:
def get_summary_metrics_dict(
    unrolled_metrics: dict, metric_name: str
) -> dict[str, dict[str, list[float]]]:
    """Summarise one metric per model and prediction length.

    Args:
        unrolled_metrics: Mapping ``model -> prediction_length -> metric -> values``.
        metric_name: Key of the metric to summarise.

    Returns:
        A ``defaultdict(dict)`` mapping each model name to a dict with the keys
        ``"prediction_lengths"``, ``"means"``, ``"medians"`` and ``"stds"``. The
        three statistic lists are aligned with ``"prediction_lengths"`` and are
        computed with the NaN-ignoring numpy reductions.

    Raises:
        KeyError: If ``metric_name`` is missing for some prediction length.
    """
    summary = defaultdict(dict)
    for model_name, per_length in unrolled_metrics.items():
        print(model_name)
        entry = summary[model_name]
        entry["prediction_lengths"] = list(per_length.keys())
        entry["means"] = []
        entry["medians"] = []
        entry["stds"] = []
        for length in entry["prediction_lengths"]:
            values = per_length[length][metric_name]
            entry["means"].append(np.nanmean(values))
            entry["medians"].append(np.nanmedian(values))
            entry["stds"].append(np.nanstd(values))
    return summary

In [None]:
# Per-model mean / median / std of sMAPE at every available prediction length.
smape_metrics_dict = get_summary_metrics_dict(unrolled_metrics, "smape")

In [None]:
# The key must match a label from run_metrics_dir_dict ("Our Model"). A wrong label
# does not raise here: the summary is a defaultdict, so it silently returns {}.
smape_metrics_dict["Our Model"].keys()

In [None]:
def plot_metrics_by_prediction_length(
    metrics_dict, metric_name, show_std_envelope=False
):
    """Plot each model's median metric value against prediction length.

    Args:
        metrics_dict: Mapping of model name to a summary dict with the keys
            ``"prediction_lengths"``, ``"medians"``, ``"means"`` and ``"stds"``.
        metric_name: Used as the (bold) figure title.
        show_std_envelope: When true, also shade the band ``mean ± std`` per model.

    Draws on a new 5x4 figure via pyplot; nothing is returned.
    """
    plt.figure(figsize=(5, 4))
    for model_name, summary in metrics_dict.items():
        lengths = summary["prediction_lengths"]
        plt.plot(lengths, summary["medians"], marker="o", label=model_name)
        spread = np.array(summary["stds"])
        if not show_std_envelope:
            continue
        # The band is centred on the means, while the line itself shows the medians.
        lower = summary["means"] - spread
        upper = summary["means"] + spread
        plt.fill_between(lengths, lower, upper, alpha=0.2)
    plt.legend(loc="lower right")
    plt.xlabel("Prediction Length")
    plt.title(metric_name, fontweight="bold")

In [None]:
# all_metrics_dict[metric][model] -> summary statistics per prediction length.
summary_metric_names = ("mse", "mae", "smape", "spearman")
all_metrics_dict = {
    name: get_summary_metrics_dict(unrolled_metrics, name)
    for name in summary_metric_names
}

Order the models by their median sMAPE (lowest first) so that every metric is plotted in the same model order.

In [None]:
# Rank models by their median sMAPE at this prediction length (lowest first).
ordering_prediction_length = 128

median_smape_by_model = {}
for model_name, data in all_metrics_dict["smape"].items():
    print(model_name)
    # Look the horizon up by value instead of assuming it sits at a fixed position:
    # the position depends on which metric files exist for each model.
    prediction_lengths = list(data["prediction_lengths"])
    if ordering_prediction_length not in prediction_lengths:
        raise ValueError(
            f"Model {model_name} has no metrics at prediction length "
            f"{ordering_prediction_length} (available: {prediction_lengths})"
        )
    length_idx = prediction_lengths.index(ordering_prediction_length)
    median_smape_by_model[model_name] = data["medians"][length_idx]

# sorted() is stable, so models with equal medians keep their original order.
model_names_ordering = sorted(median_smape_by_model, key=median_smape_by_model.get)
print(model_names_ordering)

# Reorder all_metrics_dict according to model_names_ordering for each metric
reordered_metrics_dict = {}
for metric_name, metric_data in all_metrics_dict.items():
    reordered_metric_data = {}

    # Add models in the order specified by model_names_ordering
    for model_name in model_names_ordering:
        if model_name not in metric_data:
            raise ValueError(f"Model {model_name} not found in {metric_name}")
        reordered_metric_data[model_name] = metric_data[model_name]

    reordered_metrics_dict[metric_name] = reordered_metric_data
all_metrics_dict = reordered_metrics_dict

In [None]:
# Metrics for which summaries were computed.
all_metrics_dict.keys()

In [None]:
# Model order after the sMAPE-based reordering (same for every metric).
all_metrics_dict["mse"].keys()

In [None]:
# One row of four panels (mse, mae, smape, spearman) versus prediction length.
# The in-figure legend is switched off; the returned handles are used by a later
# cell to render the legend as its own figure.
# NOTE(review): "pearson" is listed in metrics_to_show_std_envelope but is not among
# the plotted metrics — presumably ignored by the helper; confirm.
legend_handles = plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    ["mse", "mae", "smape", "spearman"],
    metrics_to_show_std_envelope=["smape", "pearson", "spearman"],
    n_cols=4,
    n_rows=1,
    save_path="baselines_figs/zeroshot_metrics_autoregressive_rollout_metrics.pdf",
    show_legend=False,
    legend_kwargs={"loc": "upper left", "frameon": True, "fontsize": 10},
    colors=default_colors,
    use_inv_spearman=True,
)

In [None]:
# Inspect the legend handles returned by the multi-panel plot.
legend_handles

In [None]:
# Standalone figure whose only content is the legend for the metric curves.
legend_fig = plt.figure(figsize=(6, 1))
legend_ax = legend_fig.gca()

# Build the legend from the handles returned by the multi-panel plot.
legend = legend_ax.legend(
    handles=legend_handles,
    loc="upper center",
    frameon=True,
    ncol=5,
    framealpha=1.0,
    fontsize=16,
)

# No data is drawn on these axes, so remove the ticks.
legend_ax.set_xticks([])
legend_ax.set_yticks([])
legend_fig.tight_layout(pad=0)
legend_fig.savefig("baselines_figs/baselines_legend_horizontal.pdf", bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Single-panel sMAPE-versus-prediction-length figure.
# NOTE(review): unlike the other cells, save_path has no "baselines_figs/" prefix, so
# the PDF lands in the current working directory — confirm this is intended.
# NOTE(review): the call is the last expression of the cell, so whatever the helper
# returns is also echoed as cell output.
plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    ["smape"],
    metrics_to_show_std_envelope=["smape"],
    n_cols=1,
    n_rows=1,
    individual_figsize=(4, 4.5),
    ylim=(10, None),
    save_path="zeroshot_smape_autoregressive_rollout.pdf",
    legend_kwargs={"frameon": True, "fontsize": 10, "loc": "lower right"},
    colors=default_colors,
)

In [None]:
# Single-panel figure for one metric (here "spearman"), saved under baselines_figs/.
# Note that this rebinds selected_metric, which an earlier cell set to "smape".
# NOTE(review): use_inv_spearman is a flag of the dystformer helper; its effect is
# not visible in this notebook — confirm before interpreting the y-axis.
selected_metric = "spearman"
plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    [selected_metric],
    metrics_to_show_std_envelope=["smape", "spearman"],
    n_cols=1,
    n_rows=1,
    individual_figsize=(4, 4),
    save_path=f"baselines_figs/baselines_{selected_metric}.pdf",
    show_legend=False,
    legend_kwargs={"frameon": True, "fontsize": 10, "loc": "lower right"},
    use_inv_spearman=True,
    colors=default_colors,
)

In [None]:
# Dump the full sMAPE summary (prediction lengths, means, medians, stds) per model.
all_metrics_dict["smape"].items()