In [None]:
import os
from collections import defaultdict
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from dystformer.utils import apply_custom_style, make_box_plot

In [None]:
# Apply the shared plotting style described by the YAML file (helper imported from
# dystformer.utils). The path is relative, so it resolves against the notebook's
# working directory. NOTE(review): presumably this updates matplotlib rcParams — confirm.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Output directory for evaluation-metric figures; created eagerly (no error if it
# already exists).
# NOTE(review): the plotting cells visible in this notebook save under
# "scalinglaw_figs/..." and never reference figs_save_dir — confirm which is intended.
figs_save_dir = os.path.join("../figs", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

In [None]:
# Root working directory, read from the WORK environment variable. When WORK is
# unset the fallback is "", so every path derived from WORK_DIR becomes relative
# to the current working directory.
WORK_DIR = os.getenv("WORK", "")
# Data directory under the working root.
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Powers of two from 2**0 up to 2**7, i.e. 1, 2, 4, ..., 128.
scalinglaw_splits = [2**exponent for exponent in range(8)]

In [None]:
# Display the split sizes as a quick sanity check.
scalinglaw_splits

In [None]:
# Evaluation split whose metrics are loaded below.
# Alternative split kept for reference: "final_skew40/test_zeroshot"
data_split = "test_zeroshot"

# Short label -> name of the evaluated run directory.
run_names_chattn = {
    "ic1": "pft_chattn_noembed_pretrained_correct-0",
    "ic2": "pft_chattn_mlm_sys10490_ic2-0",
    "ic4": "pft_chattn_mlm_sys5245_ic4-0",
    "ic8": "pft_chattn_mlm_sys2623_ic8-0",
    "ic16": "pft_chattn_mlm_sys1312_ic16-0",
    "ic32": "pft_chattn_mlm_sys656_ic32-0",
    "ic64": "pft_chattn_mlm_sys328_ic64-0",
    "ic128": "pft_chattn_mlm_sys164_ic128-0",
}

# Run group -> {short label -> directory holding that run's metric CSVs}.
# Layout: <WORK_DIR>/eval_results/patchtst/<run name>/<data split>
run_metrics_dirs_all_groups = {
    "chattn": {
        label: os.path.join(WORK_DIR, "eval_results", "patchtst", name, data_split)
        for label, name in run_names_chattn.items()
    },
}

In [None]:
# Display the resolved metric directories so that path mistakes are easy to spot.
run_metrics_dirs_all_groups

In [None]:
# Load every per-prediction-length metrics CSV into
# metrics_all[run_group][run_abbrv][prediction_length].
# Each value is the result of DataFrame.to_dict(), i.e. {column: {row_index: value}}.
metrics_all = defaultdict(lambda: defaultdict(dict))
for run_group, run_metrics_dir_dict in run_metrics_dirs_all_groups.items():
    print(f"Run group: {run_group}")
    for run_abbrv, run_metrics_dir in run_metrics_dir_dict.items():
        if not os.path.exists(run_metrics_dir):
            print(
                f"Run metrics directory does not exist for {run_abbrv}: {run_metrics_dir}"
            )
            continue
        run_abbrv = str(run_abbrv)
        print(f"{run_abbrv}: {run_metrics_dir}")

        # Filter and parse BEFORE sorting: the sort key used to parse every
        # directory entry, so a single stray file (e.g. ".DS_Store") crashed the cell.
        metrics_files = []
        for file in os.listdir(run_metrics_dir):
            if not file.endswith(".csv"):
                continue
            try:
                # Expected file name pattern: "<anything>_pred<length>.csv"
                prediction_length = int(file.split("_pred")[1].split(".csv")[0])
            except (IndexError, ValueError):
                print(f"Skipping file without a parsable prediction length: {file}")
                continue
            metrics_files.append((prediction_length, file))

        # Insert in ascending prediction-length order so downstream dicts iterate sorted.
        for prediction_length, file in sorted(metrics_files, key=lambda item: item[0]):
            metrics = pd.read_csv(os.path.join(run_metrics_dir, file)).to_dict()
            metrics_all[run_group][run_abbrv][prediction_length] = metrics

In [None]:
# Run groups for which at least one metrics file was loaded.
metrics_all.keys()

In [None]:
# Flatten the loaded metrics into
# unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length][metric] -> list of values.
# The "system" column is left out because it holds identifiers, not a metric. It is
# skipped rather than popped so that metrics_all is not mutated and this cell can be
# re-run without a KeyError.
unrolled_metrics_all_groups = defaultdict(lambda: defaultdict(dict))
for run_group, all_metrics_of_run_group in metrics_all.items():
    for run_abbrv, all_metrics_of_run_abbrv in all_metrics_of_run_group.items():
        for prediction_length, metrics in all_metrics_of_run_abbrv.items():
            # Each column is a {row_index: value} mapping; keep only the values.
            unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length] = {
                metric_name: list(values_by_row.values())
                for metric_name, values_by_row in metrics.items()
                if metric_name != "system"
            }

In [None]:
# Shallow copy of the "chattn" group: run abbreviation -> {prediction length -> metrics}.
unrolled_metrics_all_combined = dict(unrolled_metrics_all_groups["chattn"])

In [None]:
def get_summary_metrics_dict(unrolled_metrics, metric_name):
    """Summarise one metric per model and prediction length, ignoring NaNs.

    Args:
        unrolled_metrics: mapping ``model -> {prediction_length -> {metric -> values}}``.
        metric_name: key of the metric to summarise.

    Returns:
        ``defaultdict(dict)`` mapping each model to a dict with the keys
        ``"prediction_lengths"``, ``"means"``, ``"medians"`` and ``"stds"``.
        The last three are lists aligned with ``"prediction_lengths"``.
    """
    summary = defaultdict(dict)
    for run_label, per_length in unrolled_metrics.items():
        horizons = list(per_length)
        samples = [per_length[horizon][metric_name] for horizon in horizons]

        entry = summary[run_label]
        entry["prediction_lengths"] = horizons
        entry["means"] = [np.nanmean(values) for values in samples]
        entry["medians"] = [np.nanmedian(values) for values in samples]
        entry["stds"] = [np.nanstd(values) for values in samples]
    return summary

In [None]:
def plot_metrics_by_prediction_length(
    metrics_dict, metric_name, show_std_envelope=False
):
    """Plot the per-model median of a metric against prediction length.

    Args:
        metrics_dict: mapping ``model -> summary`` where each summary has the keys
            ``"prediction_lengths"``, ``"medians"``, ``"means"`` and ``"stds"``.
        metric_name: used as the (bold) figure title.
        show_std_envelope: when True, also shade ``mean ± std`` for every model.
            Note that the band is centred on the mean while the line is the median.
    """
    plt.figure(figsize=(5, 4))
    for run_label, summary in metrics_dict.items():
        horizons = summary["prediction_lengths"]
        plt.plot(horizons, summary["medians"], marker="o", label=run_label)
        # ndarray so the list-valued means broadcast in the arithmetic below.
        spread = np.array(summary["stds"])
        if not show_std_envelope:
            continue
        centers = summary["means"]
        plt.fill_between(horizons, centers - spread, centers + spread, alpha=0.2)
    plt.legend(loc="lower right")
    plt.xlabel("Prediction Length")
    plt.title(metric_name, fontweight="bold")

In [None]:
# Run groups that were configured (whether or not their directories exist on disk).
run_metrics_dirs_all_groups.keys()

In [None]:
# Metric columns that are summarised for every run.
metric_names_chosen = ["mse", "mae", "smape", "spearman"]

In [None]:
# all_metrics_dict[run_group][metric_name] -> per-model summary
# (prediction lengths, means, medians, stds) for that metric.
all_metrics_dict = defaultdict(dict)

for run_group in run_metrics_dirs_all_groups:
    all_metrics_dict[run_group] = dict(
        (
            metrics_name,
            get_summary_metrics_dict(
                unrolled_metrics_all_groups[run_group], metrics_name
            ),
        )
        for metrics_name in metric_names_chosen
    )

In [None]:
# Color list of matplotlib's "tab10" colormap.
# NOTE(review): not referenced by any other cell visible in this notebook.
default_colors = plt.cm.tab10.colors

In [None]:
# Runs that have MSE summary statistics in the "chattn" group.
all_metrics_dict["chattn"]["mse"].keys()

In [None]:
# Runs that have flattened metrics in the "chattn" group.
unrolled_metrics_all_groups["chattn"].keys()

In [None]:
# Runs available in the combined mapping used by the plots below.
unrolled_metrics_all_combined.keys()

In [None]:
# Number of runs that were actually loaded; used to size the color list below.
n_runs = len(unrolled_metrics_all_combined)
print(n_runs)

In [None]:
# One RGBA color per run, sampled from the "Blues" colormap at evenly spaced
# positions from 1.0 down to 0.1; .tolist() turns the array into nested lists.
bar_colors = plt.cm.Blues(np.linspace(1.0, 0.1, n_runs)).tolist()
# Sanity check: should equal n_runs.
print(len(bar_colors))

In [None]:
# Prediction length used for the box plot near the end of the notebook.
selected_pred_length = 512

In [None]:
# Available run abbreviations (first level of the combined mapping).
unrolled_metrics_all_combined.keys()

In [None]:
# Prediction lengths available for the "ic2" run (second level).
unrolled_metrics_all_combined["ic2"].keys()

In [None]:
# Metric names available for the "ic2" run at prediction length 128 (third level).
unrolled_metrics_all_combined["ic2"][128].keys()

In [None]:
# Run abbreviation -> number of systems, used as the x coordinate of the scaling
# plots. For ic2..ic128 the value equals the `sys<N>` token in the corresponding run
# name (e.g. "pft_chattn_mlm_sys10490_ic2-0"); N_sys times the ic factor stays
# roughly constant at about 21k.
# NOTE(review): the "ic1" run name carries no `sys<N>` token — confirm 20979.
ic_to_n_systems = {
    "ic1": 20979,
    "ic2": 10490,
    "ic4": 5245,
    "ic8": 2623,
    "ic16": 1312,
    "ic32": 656,
    "ic64": 328,
    "ic128": 164,
}

In [None]:
def _draw_percentile_box(
    x_center,
    values,
    color,
    label,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(25, 75),
    relative_box_width=0.5,
    alpha_val=0.8,
):
    """Draw one percentile box (box, median line, whiskers, caps) on the current axes.

    Args:
        x_center: x position of the box; also scales the box width.
        values: 1-D array of finite metric values.
        color: face color of the box.
        label: legend label attached to the box patch.
        box_percentile_range: percentiles spanned by the filled box.
        whisker_percentile_range: percentiles reached by the whiskers.
        relative_box_width: box width as a fraction of ``x_center``.
        alpha_val: opacity of the box.
    """
    # Width proportional to x: this assumes the x-axis is log-scaled, so every
    # box ends up with the same apparent width.
    box_width = relative_box_width * x_center
    half_width = box_width / 2
    cap_half_width = half_width * 0.5

    median_val = np.median(values)
    lower_box, upper_box = np.percentile(values, box_percentile_range)
    lower_whisker, upper_whisker = np.percentile(values, whisker_percentile_range)

    plt.gca().add_patch(
        plt.Rectangle(
            (x_center - half_width, lower_box),
            box_width,
            upper_box - lower_box,
            fill=True,
            facecolor=color,
            alpha=alpha_val,
            linewidth=1,
            edgecolor="black",
            zorder=5,
            label=label,
        )
    )

    # Median line, drawn above the box.
    plt.hlines(
        median_val,
        x_center - half_width,
        x_center + half_width,
        colors="black",
        linewidth=2.5,
        zorder=10,
    )

    # One whisker and one cap on each side of the box.
    for box_edge, whisker_end in (
        (lower_box, lower_whisker),
        (upper_box, upper_whisker),
    ):
        plt.vlines(
            x_center,
            box_edge,
            whisker_end,
            colors="black",
            linestyle="-",
            linewidth=1,
            zorder=5,
        )
        plt.hlines(
            whisker_end,
            x_center - cap_half_width,
            x_center + cap_half_width,
            colors="black",
            linewidth=1,
            zorder=5,
        )


def make_scaling_plot(
    unrolled_metrics: dict,
    prediction_length: int = 128,
    metric_to_plot: str = "smape",
    stat_to_plot: Literal["median", "mean"] = "median",
    colormap: str = "Blues",
    legend_kwargs: dict | None = None,
    figsize: tuple = (4, 4),
    save_path: str | None = None,
    use_inv_spearman: bool = True,
    show_legend: bool = True,
    title: str | None = None,
) -> None:
    """Plot one metric against the number of systems at a single prediction length.

    The x position of every run is looked up in the module-level ``ic_to_n_systems``
    mapping, and the x-axis is log2-scaled. NaN values are dropped before any
    statistic is computed.

    Args:
        unrolled_metrics: ``run abbreviation -> {prediction_length -> {metric -> values}}``.
        prediction_length: prediction length whose values are plotted.
        metric_to_plot: metric key to plot.
        stat_to_plot: ``"median"`` draws a percentile box per run (box 40-60,
            whiskers 25-75); ``"mean"`` draws the mean with standard-error bars.
        colormap: matplotlib colormap name; one color per run, ordered by
            ascending number of systems.
        legend_kwargs: extra keyword arguments for ``plt.legend`` (default: none).
        figsize: figure size in inches.
        save_path: if given, the figure is saved there; a missing parent
            directory is created.
        use_inv_spearman: plot ``1 - value`` when the metric is ``"spearman"``.
        show_legend: whether to draw the legend.
        title: optional bold figure title.

    Raises:
        ValueError: if ``stat_to_plot`` is neither ``"median"`` nor ``"mean"``.
        KeyError: if a run abbreviation is missing from ``ic_to_n_systems`` or the
            requested prediction length / metric is missing for a run.
    """
    if stat_to_plot not in ("median", "mean"):
        raise ValueError(f"Invalid stat_to_plot: {stat_to_plot}")

    invert_metric = metric_to_plot == "spearman" and use_inv_spearman
    if metric_to_plot == "smape":
        metric_to_plot_title = "sMAPE"
    elif invert_metric:
        metric_to_plot_title = "1 - Spearman"
    else:
        metric_to_plot_title = metric_to_plot.upper()

    # number of systems -> raw metric values at the requested prediction length
    metric_by_n_systems = {
        ic_to_n_systems[ic_split]: metrics_by_predlength[prediction_length][
            metric_to_plot
        ]
        for ic_split, metrics_by_predlength in unrolled_metrics.items()
    }
    runs_sorted = sorted(metric_by_n_systems.items())

    # plt.get_cmap works on every matplotlib release; plt.cm.get_cmap was removed in 3.9.
    colors = plt.get_cmap(colormap)(np.linspace(0, 1.0, len(runs_sorted)))

    plt.figure(figsize=figsize)
    for color, (n_systems, metric_vals) in zip(colors, runs_sorted):
        metric_vals = np.asarray(metric_vals, dtype=float)
        metric_vals = metric_vals[~np.isnan(metric_vals)]
        if metric_vals.size == 0:
            # Nothing to summarise; percentiles/means of an empty array are undefined.
            print(f"No finite {metric_to_plot} values for N_sys={n_systems}; skipping")
            continue
        if invert_metric:
            metric_vals = 1 - metric_vals

        label = rf"$N_{{sys}}={n_systems}$"
        if stat_to_plot == "median":
            _draw_percentile_box(n_systems, metric_vals, color, label)
        else:
            mean_val = np.mean(metric_vals)
            ste_val = np.std(metric_vals) / np.sqrt(len(metric_vals))
            plt.scatter(
                n_systems,
                mean_val,
                s=36,  # equivalent to markersize=6 squared
                edgecolors="black",
                linewidths=0.2,
                label=label,
                color=color,
            )
            plt.errorbar(
                n_systems,
                mean_val,
                yerr=ste_val,
                fmt="none",
                color=color,
                capsize=5,  # T-shaped caps on the error bars
            )

    if show_legend:
        plt.legend(**(legend_kwargs or {}))
    if title is not None:
        plt.title(title, fontweight="bold")
    plt.xlabel("Number of Systems", fontweight="bold")
    plt.ylabel(metric_to_plot_title, fontweight="bold")
    plt.xscale("log", base=2)
    plt.tight_layout()
    if save_path is not None:
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
# Percentile-box scaling plot: median sMAPE vs. number of systems at L_pred = 512.
metric_to_plot = "smape"
prediction_length = 512
stat_to_plot = "median"
# legend_kwargs has no effect while show_legend=False; it is kept so the legend can
# be switched on without further edits. save_path is relative to the notebook's
# working directory.
make_scaling_plot(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    stat_to_plot=stat_to_plot,
    prediction_length=prediction_length,
    colormap="cividis_r",
    show_legend=False,
    title=rf"$L_{{pred}}={prediction_length}$",
    legend_kwargs={"loc": "upper right", "frameon": True, "ncol": 1, "fontsize": 8},
    save_path=f"scalinglaw_figs/{metric_to_plot}_{prediction_length}_{stat_to_plot}.pdf",
)

In [None]:
# Same plot as the previous cell, but showing the mean with standard-error bars
# instead of percentile boxes (only stat_to_plot differs).
metric_to_plot = "smape"
prediction_length = 512
stat_to_plot = "mean"
# legend_kwargs has no effect while show_legend=False. save_path is relative to the
# notebook's working directory.
make_scaling_plot(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    stat_to_plot=stat_to_plot,
    prediction_length=prediction_length,
    colormap="cividis_r",
    show_legend=False,
    title=rf"$L_{{pred}}={prediction_length}$",
    legend_kwargs={"loc": "upper right", "frameon": True, "ncol": 1, "fontsize": 8},
    save_path=f"scalinglaw_figs/{metric_to_plot}_{prediction_length}_{stat_to_plot}.pdf",
)

In [None]:
def make_scaling_plot_v2(
    unrolled_metrics: dict,
    prediction_lengths: list[int] | None = None,
    metric_to_plot: str = "smape",
    colormap: str = "Blues",
    legend_kwargs: dict | None = None,
    figsize: tuple = (4, 4),
    save_path: str | None = None,
    use_inv_spearman: bool = True,
    show_legend: bool = True,
    ylim: tuple | None = None,
) -> None:
    """Plot the mean of a metric against the number of systems, one line per prediction length.

    Each line is surrounded by a shaded ``mean ± standard error`` band. The x position
    of every run is looked up in the module-level ``ic_to_n_systems`` mapping, and the
    x-axis is log2-scaled. NaN values are dropped before the statistics are computed.

    Args:
        unrolled_metrics: ``run abbreviation -> {prediction_length -> {metric -> values}}``.
        prediction_lengths: prediction lengths to draw; defaults to ``[128, 256, 512]``.
        metric_to_plot: metric key to plot.
        colormap: matplotlib colormap name; one color per prediction length,
            sampled from the range [0, 0.8] of the colormap.
        legend_kwargs: extra keyword arguments for ``plt.legend`` (default: none).
        figsize: figure size in inches.
        save_path: if given, the figure is saved there; a missing parent
            directory is created.
        use_inv_spearman: plot ``1 - value`` when the metric is ``"spearman"``.
        show_legend: whether to draw the legend.
        ylim: optional y-axis limits passed to ``plt.ylim``.

    Raises:
        KeyError: if a run abbreviation is missing from ``ic_to_n_systems`` or a
            requested prediction length / metric is missing for a run.
    """
    if prediction_lengths is None:
        prediction_lengths = [128, 256, 512]

    invert_metric = metric_to_plot == "spearman" and use_inv_spearman
    if metric_to_plot == "smape":
        metric_to_plot_title = "sMAPE"
    elif invert_metric:
        metric_to_plot_title = "1 - Spearman"
    else:
        metric_to_plot_title = metric_to_plot.upper()

    # stats_by_length[prediction_length][n_systems] = (mean, standard error)
    stats_by_length = defaultdict(dict)
    for ic_split, metrics_by_predlength in unrolled_metrics.items():
        n_systems = int(ic_to_n_systems[ic_split])
        for prediction_length in prediction_lengths:
            metric_vals = np.asarray(
                metrics_by_predlength[prediction_length][metric_to_plot], dtype=float
            )
            metric_vals = metric_vals[~np.isnan(metric_vals)]
            if invert_metric:
                metric_vals = 1 - metric_vals
            if metric_vals.size == 0:
                # Keep the point as NaN instead of dividing by sqrt(0).
                mean_val = ste_val = np.nan
            else:
                mean_val = np.mean(metric_vals)
                ste_val = np.std(metric_vals) / np.sqrt(metric_vals.size)
            stats_by_length[prediction_length][n_systems] = (mean_val, ste_val)

    lengths_sorted = sorted(stats_by_length.items())
    # plt.get_cmap works on every matplotlib release; plt.cm.get_cmap was removed in 3.9.
    colors = plt.get_cmap(colormap)(np.linspace(0, 0.8, len(lengths_sorted)))

    plt.figure(figsize=figsize)
    for color, (prediction_length, stats_by_n_systems) in zip(colors, lengths_sorted):
        # Sort by number of systems so the line never zig-zags, whatever the
        # iteration order of unrolled_metrics is.
        n_systems_sorted = sorted(stats_by_n_systems)
        mean_vals = np.array([stats_by_n_systems[n][0] for n in n_systems_sorted])
        ste_vals = np.array([stats_by_n_systems[n][1] for n in n_systems_sorted])
        plt.plot(
            n_systems_sorted,
            mean_vals,
            marker="o",
            linestyle="-",
            label=rf"$L_{{pred}}={prediction_length}$",
            color=color,
            alpha=0.8,
        )
        plt.fill_between(
            n_systems_sorted,
            mean_vals - ste_vals,
            mean_vals + ste_vals,
            alpha=0.2,
            color=color,
        )

    if show_legend:
        plt.legend(**(legend_kwargs or {}))
    plt.xlabel("Number of Systems", fontweight="bold")
    plt.ylabel(metric_to_plot_title, fontweight="bold")
    plt.xscale("log", base=2)
    if ylim is not None:
        plt.ylim(ylim)
    plt.tight_layout()
    if save_path is not None:
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
# Prediction lengths available for the "ic2" run, used as the reference list for
# the combined plot. The plotting function looks each of these up for every run
# and raises KeyError if one is missing.
all_pred_lengths = list(unrolled_metrics_all_combined["ic2"].keys())
print(all_pred_lengths)

In [None]:
# Combined scaling plot: mean sMAPE vs. number of systems, one line per prediction length.
metric_to_plot = "smape"
# NOTE(review): prediction_lengths and stat_to_plot are assigned here but not passed
# to the call below, which uses all_pred_lengths — confirm whether that is intended.
prediction_lengths = [128, 256, 512]
stat_to_plot = "mean"
make_scaling_plot_v2(
    unrolled_metrics_all_combined,
    metric_to_plot=metric_to_plot,
    prediction_lengths=all_pred_lengths,
    colormap="cividis",
    show_legend=True,
    # Only the lower y limit is fixed; None leaves the upper limit to matplotlib.
    ylim=(18, None),
    legend_kwargs={"loc": "lower center", "frameon": True, "ncol": 4, "fontsize": 5},
    # save_path=f"scalinglaw_figs/{metric_to_plot}_combined.pdf",
)

In [None]:
# Box plot of sMAPE across runs at the selected prediction length. The returned
# legend handles are reused by the stand-alone legend figure in the next cell.
legend_handles = make_box_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot="smape",  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=bar_colors,
    title=None,
    title_kwargs={"fontsize": 10},
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(25, 75),
    alpha_val=0.8,
    show_legend=True,
    legend_kwargs={"loc": "lower right", "frameon": True, "ncol": 1, "framealpha": 1.0},
    # Derive the file name from the plotted prediction length: it was hard-coded
    # as "smape_128.pdf" although selected_pred_length is what gets plotted.
    save_path=f"scalinglaw_figs/smape_{selected_pred_length}.pdf",
)

In [None]:
# Stand-alone figure that contains only the legend built from the box-plot handles.
plt.figure(figsize=(4, 0.6))
# Centered legend, three columns, fully opaque frame.
plt.legend(
    handles=legend_handles,
    loc="center",
    frameon=True,
    ncol=3,
    framealpha=1.0,
)
# Hide the tick marks so that only the legend remains visible.
plt.xticks([])
plt.yticks([])
plt.tight_layout(pad=0)
# Saving is currently disabled:
# plt.savefig("ablations_figs/ablations_legend.pdf", bbox_inches="tight")
plt.show()
# Release the figure once it has been displayed.
plt.close()