In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
if os.path.exists("../custom_style.mplstyle"):
    plt.style.use(["ggplot", "../custom_style.mplstyle"])

In [None]:
figs_save_dir = os.path.join("../figures", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

In [None]:
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
data_split = "final_skew40/test_zeroshot"

run_metrics_dir_dict = {
    "Our Model": os.path.join(
        WORK_DIR,
        "eval_results",
        "patchtst",
        "pft_stand_rff_only_pretrained-0",
        # "pft_chattn_emb_w_poly-0"
        # "pft_chattn_noembed_pretrained_correct-0",
        data_split,
    ),
    "Chronos 20M Finetune": os.path.join(
        WORK_DIR,
        "eval_results",
        "chronos",
        "chronos_bolt_mini-12",
        # "chronos_mini_ft-0",
        # "chronos_finetune_stand_updated-0",
        data_split,
    ),
    "Chronos 20M": os.path.join(
        WORK_DIR,
        "eval_results",
        "chronos",
        "chronos_mini_zeroshot",
        data_split,
    ),
    "Time MOE 50M": os.path.join(
        WORK_DIR,
        "eval_results",
        "timemoe",
        "timemoe-50m",
        data_split,
    ),
    "TimesFM 200M": os.path.join(
        WORK_DIR,
        "eval_results",
        "timesfm",
        "timesfm-200m",
        data_split,
    ),
    "Mean": os.path.join(
        WORK_DIR,
        "eval_results",
        "baselines",
        "mean",
        data_split,
    ),
    "Fourier": os.path.join(
        WORK_DIR,
        "eval_results",
        "baselines",
        "fourier",
        data_split,
    ),
}

In [None]:
run_metrics_dir_dict.keys()

In [None]:
metrics_all_runs = defaultdict(dict)
for model_name, run_metrics_dir in run_metrics_dir_dict.items():
    print(model_name)
    for file in sorted(
        os.listdir(run_metrics_dir),
        key=lambda x: int(x.split("_pred")[1].split(".csv")[0]),
    ):
        if file.endswith(".csv"):
            prediction_length = int(file.split("_pred")[1].split(".csv")[0])
            # print(prediction_length)
            with open(os.path.join(run_metrics_dir, file), "r") as f:
                metrics = pd.read_csv(f).to_dict()
                metrics_all_runs[model_name][prediction_length] = metrics

In [None]:
metrics_all_runs.keys()

In [None]:
metrics_all_runs["Our Model"][64].keys()

In [None]:
unrolled_metrics = defaultdict(dict)
for model_name, all_metrics_of_model in metrics_all_runs.items():
    print(model_name)
    for prediction_length, metrics in all_metrics_of_model.items():
        systems = metrics["system"]
        metrics_unrolled = {
            k: list(v.values()) for k, v in metrics.items() if k != "system"
        }
        print(metrics_unrolled.keys())
        unrolled_metrics[model_name][prediction_length] = metrics_unrolled

In [None]:
np.median(unrolled_metrics["Our Model"][128]["r2_score"])

In [None]:
unrolled_metrics["Our Model"].keys()

In [None]:
unrolled_metrics["Our Model"][64].keys()

In [None]:
len(unrolled_metrics["Our Model"][64]["smape"])

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def make_box_plot(
    unrolled_metrics: dict[str, dict[int, dict[str, list[float]]]],
    prediction_length: int,
    run_names: list[str],
    ylim=(1e-5, 1e5),
    verbose: bool = False,
    metrics_to_exclude: list[str] | None = None,
    use_rescaled_smape: bool = False,
    run_names_to_exclude: list[str] | None = None,
    title: str | None = None,
    title_kwargs: dict | None = None,
    figsize: tuple[int, int] = (6, 4),
    legend_kwargs: dict | None = None,
    show_xticks: bool = True,
    ylabel: str = "",
    yticks_kwargs: dict | None = None,
    box_saturation: float = 0.5,
    whisker_percentile_range: tuple[int, int] = (0, 90),
    show_legend: bool = True,
    save_path: str | None = None,
):
    """Draw a grouped box plot of per-system metrics for several runs.

    One group of boxes is drawn per metric (x axis), with one box per run
    inside each group (hue). Outlier points are hidden. Whiskers span the
    ``whisker_percentile_range`` percentiles. A black line marks each box's
    NaN-ignoring median.

    Args:
        unrolled_metrics: ``{run_name: {prediction_length: {metric: [values...]}}}``.
        prediction_length: which prediction length to plot.
        run_names: runs to plot, in box / legend order.
        ylim: y-axis limits passed to ``plt.ylim``.
        verbose: print the median of every (metric, run) pair.
        metrics_to_exclude: metric names to leave out.
        use_rescaled_smape: divide sMAPE values by 100.
        run_names_to_exclude: runs to drop from ``run_names``. They do not need
            to exist in ``unrolled_metrics``.
        title: optional figure title.
        title_kwargs: kwargs for ``plt.title``.
        figsize: figure size in inches.
        legend_kwargs: kwargs for ``plt.legend`` when ``show_legend`` is True.
        show_xticks: show rotated, bold metric names on the x axis.
        ylabel: y-axis label.
        yticks_kwargs: kwargs for ``plt.yticks``.
        box_saturation: seaborn ``saturation`` for the box colours.
        whisker_percentile_range: (low, high) percentiles for the whiskers.
        show_legend: draw the legend; otherwise remove seaborn's default one.
        save_path: if given, save the figure there.

    Raises:
        ValueError: if no runs remain after applying ``run_names_to_exclude``.
    """
    excluded_runs = set(run_names_to_exclude or ())
    excluded_metrics = set(metrics_to_exclude or ())
    legend_kwargs = legend_kwargs or {}
    title_kwargs = title_kwargs or {}
    yticks_kwargs = yticks_kwargs or {}

    # Drop excluded runs *before* any lookup, so an excluded name never has to
    # exist in ``unrolled_metrics``.
    run_names = [name for name in run_names if name not in excluded_runs]
    if not run_names:
        raise ValueError("make_box_plot: no runs left to plot after exclusions")

    metrics_by_run_name = {
        run_name: unrolled_metrics[run_name][prediction_length]
        for run_name in run_names
    }
    # The first run's metric columns decide which metrics are plotted.
    metric_names = [
        name
        for name in metrics_by_run_name[run_names[0]]
        if name not in excluded_metrics
    ]

    def pretty_metric_name(name: str) -> str:
        # Tick-label form of a metric name, e.g. "mse" -> "MSE", "smape" -> "sMAPE".
        if name in ("mse", "mae", "rmse", "mape"):
            return name.upper()
        if name == "smape":
            return "sMAPE"
        return name.capitalize()

    metric_names_title = [pretty_metric_name(name) for name in metric_names]

    plot_data = []
    median_data = []
    for metric_name, metric_title in zip(metric_names, metric_names_title):
        for run_name in run_names:
            values = metrics_by_run_name[run_name][metric_name]
            if metric_name == "smape" and use_rescaled_smape:
                values = [x / 100 for x in values]
            median_value = np.nanmedian(values)
            plot_data.extend((metric_title, v, run_name) for v in values)
            median_data.append((metric_title, median_value, run_name))
            if verbose:
                print(f"{metric_title} - {run_name} median: {median_value}")

    # Long-format DataFrame for seaborn.
    df = pd.DataFrame(plot_data, columns=["Metric", "Value", "Run"])

    # Total width of one metric's group of boxes. The boxplot and the
    # median-line placement below both use it, so they cannot drift apart.
    group_total_width = 0.8
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.figure(figsize=figsize)
    ax = sns.boxplot(
        data=df,
        x="Metric",
        y="Value",
        hue="Run",
        # Pin the category / hue order explicitly. Box positions then match
        # the offsets computed below even if some (metric, run) pair is empty.
        order=metric_names_title,
        hue_order=run_names,
        dodge=True,
        width=group_total_width,
        fliersize=0,  # Don't show outlier points
        palette=colors[: len(run_names)],
        saturation=box_saturation,
        whis=whisker_percentile_range,  # type: ignore
    )

    # Categorical axes place categories at x = 0, 1, 2, ... in ``order``.
    x_positions = {name: pos for pos, name in enumerate(metric_names_title)}
    n_runs = len(run_names)
    box_width = group_total_width / n_runs  # space per run within a category
    median_line_width = box_width * 0.8
    run_index = {name: idx for idx, name in enumerate(run_names)}

    # Median lines centred on each dodged box.
    for metric, median_value, run in median_data:
        offset = (run_index[run] - (n_runs - 1) / 2) * box_width
        x_line_center = x_positions[metric] + offset
        ax.hlines(
            y=median_value,
            xmin=x_line_center - median_line_width / 2,
            xmax=x_line_center + median_line_width / 2,
            color="black",
            linewidth=2,
            zorder=3,
        )

    plt.ylim(ylim)
    plt.ylabel(ylabel, fontweight="bold")
    plt.yticks(**yticks_kwargs)
    if show_legend:
        plt.legend(**legend_kwargs)
    else:
        # Seaborn attaches a hue legend by default; drop it if present.
        existing_legend = ax.get_legend()
        if existing_legend is not None:
            existing_legend.remove()
    if title is not None:
        plt.title(title, **title_kwargs)
    if show_xticks:
        plt.xticks(rotation=15, fontweight="bold")  # rotated for readability
    else:
        plt.xticks([])
    plt.xlabel("")
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

In [None]:
# max_smape_ours = np.max(unrolled_metrics["Our Model"][128]["smape"])
# max_smape_chronos = np.max(unrolled_metrics["Chronos Finetune"][128]["smape"])
# max_smape_val = max(max_smape_ours, max_smape_chronos)
# print(max_smape_val)

In [None]:
# median_r2score_ours = np.median(unrolled_metrics["Our Model"][128]["r2_score"])
# median_r2score_chronos = np.median(
#     unrolled_metrics["Chronos Finetune"][128]["r2_score"]
# )
# print(median_r2score_ours, median_r2score_chronos)

In [None]:
plt.rcParams["path.simplify_threshold"] = 0.0

In [None]:
unrolled_metrics["Fourier"].keys()

In [None]:
make_box_plot(
    unrolled_metrics,
    128,
    run_names=[
        "Our Model",
        "Chronos 20M Finetune",
        "Chronos 20M",
        "Time MOE 50M",
        "TimesFM 200M",
        "Mean",
        "Fourier",
    ],
    ylim=(-0.1, 2),
    metrics_to_exclude=["r2_score"],
    run_names_to_exclude=["pft_stand_pretrained_vanilla"],
    use_rescaled_smape=True,
    title="Metrics for Zeroshot Systems",
    legend_kwargs={"frameon": True, "fontsize": 7},
    title_kwargs={"fontweight": "bold"},
    box_saturation=0.8,
    # Save into the figures directory created at the top of the notebook
    # rather than the current working directory.
    save_path=os.path.join(figs_save_dir, "zeroshot_metrics_strip_bar_plot.pdf"),
)

In [None]:
make_box_plot(
    unrolled_metrics,
    128,
    run_names=[
        "Our Model",
        "Chronos 20M Finetune",
        "Chronos 20M",
        "Time MOE 50M",
        "TimesFM 200M",
    ],
    ylim=(0, 140),
    metrics_to_exclude=["r2_score", "spearman", "mae", "mse"],
    run_names_to_exclude=["pft_stand_pretrained_vanilla"],
    use_rescaled_smape=False,
    title=None,
    figsize=(2, 3.5),
    show_legend=False,
    legend_kwargs={"frameon": True, "fontsize": 6, "loc": "upper left"},
    ylabel="sMAPE",
    yticks_kwargs={"fontsize": 5},
    show_xticks=False,
    # title_kwargs={"fontweight": "bold", "fontsize": 9, "y": 1.05},
    box_saturation=0.8,
    whisker_percentile_range=(1, 91),
    # Save into the figures directory created at the top of the notebook
    # rather than the current working directory.
    save_path=os.path.join(figs_save_dir, "zeroshot_smape_metrics_comparison.pdf"),
)

In [None]:
def get_summary_metrics_dict(
    unrolled_metrics: dict, metric_name: str
) -> dict[str, dict[str, list[float]]]:
    """Summarise one metric per model and prediction length.

    Args:
        unrolled_metrics: ``{model_name: {prediction_length: {metric: [values...]}}}``.
        metric_name: the metric to summarise (e.g. ``"smape"``).

    Returns:
        A dict keyed by model name. Each value has the keys
        ``"prediction_lengths"``, ``"means"``, ``"medians"``, ``"stds"`` and
        ``"counts"``. The statistic lists are aligned with
        ``"prediction_lengths"``, which follows the input dict's key order.
        NaNs are ignored. ``"counts"`` is the number of non-NaN values, i.e.
        the sample size behind each mean, needed for standard errors.
    """
    summary_metrics_dict = defaultdict(dict)
    for model_name, metrics_dict in unrolled_metrics.items():
        print(model_name)
        prediction_lengths = list(metrics_dict.keys())
        means, medians, stds, counts = [], [], [], []
        for prediction_length in prediction_lengths:
            metric_val = np.asarray(
                metrics_dict[prediction_length][metric_name], dtype=float
            )
            means.append(np.nanmean(metric_val))
            medians.append(np.nanmedian(metric_val))
            stds.append(np.nanstd(metric_val))
            counts.append(int(np.count_nonzero(~np.isnan(metric_val))))
        summary = summary_metrics_dict[model_name]
        summary["prediction_lengths"] = prediction_lengths
        summary["means"] = means
        summary["medians"] = medians
        summary["stds"] = stds
        summary["counts"] = counts
    return summary_metrics_dict

In [None]:
smape_metrics_dict = get_summary_metrics_dict(unrolled_metrics, "smape")

In [None]:
smape_metrics_dict["Ours"].keys()

In [None]:
def plot_metrics_by_prediction_length(
    metrics_dict, metric_name, show_std_envelope=False
):
    """Plot one metric's median against prediction length, one line per model.

    Args:
        metrics_dict: ``get_summary_metrics_dict`` output for a single metric.
        metric_name: used as the figure title.
        show_std_envelope: if True, shade mean ± one standard deviation.
            NOTE(review): the band is centred on the *mean* while the line
            shows the *median*, so the two need not line up.
    """
    plt.figure(figsize=(5, 4))
    for model_name, summary in metrics_dict.items():
        lengths = summary["prediction_lengths"]
        plt.plot(lengths, summary["medians"], marker="o", label=model_name)
        if not show_std_envelope:
            continue
        spread = np.array(summary["stds"])
        lower = summary["means"] - spread
        upper = summary["means"] + spread
        plt.fill_between(lengths, lower, upper, alpha=0.2)
    plt.legend(loc="lower right")
    plt.xlabel("Prediction Length")
    plt.title(metric_name, fontweight="bold")

In [None]:
# plot_metrics_by_prediction_length(smape_metrics_dict, "sMAPE", show_std_envelope=True)

In [None]:
all_metrics_dict = {
    metrics_name: get_summary_metrics_dict(unrolled_metrics, metrics_name)
    for metrics_name in [
        "mse",
        "mae",
        "smape",
        "r2_score",
        "spearman",
    ]
}

In [None]:
all_metrics_dict["spearman"]["Our Model"]

In [None]:
# plot_metrics_by_prediction_length(
#     all_metrics_dict["spearman"], "Spearman", show_std_envelope=False
# )

In [None]:
all_metrics_dict.keys()

In [None]:
all_metrics_dict["mse"].keys()

In [None]:
# Marker styles assigned to models in order; reused cyclically when there are
# more models than markers.
markers = ["o", "s", "v", "D", "X"]


def plot_all_metrics_by_prediction_length(
    all_metrics_dict: dict[str, dict[str, dict[str, list[float]]]],
    metric_names: list[str],
    metrics_to_show_std_envelope: list[str],
    n_rows: int = 2,
    n_cols: int = 3,
    individual_figsize: tuple[int, int] = (4, 4),
    save_path: str | None = None,
    ylim: tuple[float | None, float | None] = (None, None),
    legend_kwargs: dict | None = None,
):
    """Plot median metric vs. prediction length, one subplot per metric.

    Args:
        all_metrics_dict: ``{metric: get_summary_metrics_dict(...)}``.
        metric_names: metrics to plot, one subplot each, filled row-major.
        metrics_to_show_std_envelope: metrics that also get a shaded band of
            mean ± standard error. The standard error is std / sqrt(n), where
            n is the per-length ``"counts"`` entry of the summary. If a summary
            has no ``"counts"``, the band falls back to mean ± 1 std.
        n_rows, n_cols: subplot grid shape.
        individual_figsize: size (inches) of each subplot.
        save_path: if given, save the figure there.
        ylim: y limits applied to every subplot (``None`` entries keep autoscale).
        legend_kwargs: kwargs for the legend drawn on the first subplot.

    Raises:
        ValueError: if there are more metrics than subplots.
    """
    legend_kwargs = legend_kwargs or {}
    num_metrics = len(metric_names)
    if num_metrics > n_rows * n_cols:
        raise ValueError(
            f"{num_metrics} metrics do not fit in a {n_rows}x{n_cols} grid"
        )
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(individual_figsize[0] * n_cols, individual_figsize[1] * n_rows),
        squeeze=False,  # always a 2-D array, so flatten() works for any grid
    )
    axes = axes.flatten()
    for i, (ax, metric_name) in enumerate(zip(axes, metric_names)):
        metrics_dict = all_metrics_dict[metric_name]
        show_envelope = metric_name in metrics_to_show_std_envelope
        for j, (model_name, metrics) in enumerate(metrics_dict.items()):
            (line,) = ax.plot(
                metrics["prediction_lengths"],
                metrics["medians"],
                marker=markers[j % len(markers)],
                label=model_name,
                markersize=6,
            )
            if show_envelope:
                means = np.asarray(metrics["means"], dtype=float)
                stds = np.asarray(metrics["stds"], dtype=float)
                counts = metrics.get("counts")
                if counts is not None:
                    # Standard error of the mean per prediction length; the
                    # max(.., 1) avoids dividing by zero for all-NaN cells.
                    sample_sizes = np.maximum(np.asarray(counts, dtype=float), 1.0)
                    half_width = stds / np.sqrt(sample_sizes)
                else:
                    half_width = stds
                ax.fill_between(
                    metrics["prediction_lengths"],
                    means - half_width,
                    means + half_width,
                    alpha=0.1,
                    color=line.get_color(),  # match the band to its line
                )
        if i == 0:
            ax.legend(**legend_kwargs)
        ax.set_xlabel("Prediction Length", fontweight="bold", fontsize=12)
        # Tick every prediction length that any model reports.
        all_lengths = sorted(
            {pl for m in metrics_dict.values() for pl in m["prediction_lengths"]}
        )
        ax.set_xticks(all_lengths)
        name = metric_name.replace("_", " ")
        if name in ["mse", "mae", "rmse", "mape"]:
            name = name.upper()
        elif name == "smape":
            name = "sMAPE"
        else:
            name = name.capitalize()
        ax.set_title(name, fontweight="bold", fontsize=16)

    # Hide any unused subplots
    for ax in axes[num_metrics:]:
        ax.set_visible(False)
    if ylim is not None:
        for ax in axes:
            ax.set_ylim(ylim)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    ["mse", "mae", "smape", "spearman"],
    metrics_to_show_std_envelope=["smape", "pearson", "spearman"],
    n_cols=4,
    n_rows=1,
    # Save into the figures directory created at the top of the notebook
    # rather than the current working directory.
    save_path=os.path.join(
        figs_save_dir, "zeroshot_metrics_autoregressive_rollout_metrics.pdf"
    ),
    legend_kwargs={"loc": "upper left", "frameon": True, "fontsize": 10},
)

In [None]:
plot_all_metrics_by_prediction_length(
    all_metrics_dict,
    ["smape"],
    metrics_to_show_std_envelope=["smape"],
    n_cols=1,
    n_rows=1,
    individual_figsize=(4, 4.5),
    ylim=(20, None),
    # Save into the figures directory created at the top of the notebook
    # rather than the current working directory.
    save_path=os.path.join(figs_save_dir, "zeroshot_smape_autoregressive_rollout.pdf"),
    legend_kwargs={"frameon": True, "fontsize": 10, "loc": "lower right"},
)

In [None]:
# # make bar plot of mse and smape for each model at prediction length 128
# # make bar plot of mse and smape for each model at prediction length 128
# # smape_metrics_dict = get_summary_metrics_dict(unrolled_metrics, "smape")

# colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
# plt.figure(figsize=(5, 3))
# plt.bar(
#     ["Chronos Finetune", "Ours"],
#     [
#         np.median(unrolled_metrics["Chronos 20M Finetune"][128]["smape"]),
#         np.median(unrolled_metrics["Our Model"][128]["smape"]),
#     ],
#     color=[colors[0], colors[1]],
# )
# plt.show()