In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Layer the project's style sheet on top of matplotlib's "ggplot" style.
# The sheet is optional: when the file is missing, no style is applied at all
# (not even "ggplot") and matplotlib's current settings are kept.
if os.path.exists("../custom_style.mplstyle"):
    plt.style.use(["ggplot", "../custom_style.mplstyle"])

In [None]:
# Output directory for figures, relative to the current working directory.
# It is created up front; an already existing directory is not an error.
figs_save_dir = os.path.join("../figs", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

In [None]:
# Root of the working area, read from the WORK environment variable.
# An unset variable yields "", which makes DATA_DIR the relative path "data".
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Evaluation split whose metric files are analysed in this notebook.
data_split = "final_skew40/test_zeroshot"

# Plot label -> run name, for the runs grouped under "chattn".
run_names_chattn = {
    "mlm + chattn + embedding": "pft_stand_rff_only_pretrained",
    "mlm + chattn": "pft_chattn_noembed_pretrained_correct",
    "chattn + embedding": "pft_fullyfeat_from_scratch",
    "chattn": "pft_stand_chattn_noemb",
}

# Plot label -> run name, for the runs grouped under "no_chattn".
run_names_no_chattn = {
    "vanilla (wider) + embedding": "pft_emb_equal_param_univariate_from_scratch",
    "vanilla (wider)": "pft_noemb_equal_param_univariate_from_scratch",
    "vanilla (deeper)": "pft_equal_param_deeper_univariate_from_scratch_noemb",
    "vanilla + mlm + embedding": "pft_rff_univariate_pretrained",
    "vanilla + mlm": "pft_vanilla_pretrained_correct",
}

# All runs in one mapping: "chattn" labels first, then "no_chattn" labels.
run_names = dict(run_names_chattn)
run_names.update(run_names_no_chattn)

# group -> plot label -> directory holding that run's metric files for
# `data_split`. Every run directory carries a "-0" suffix after the run name.
run_metrics_dirs_all_groups = {
    group_name: {
        run_abbrv: os.path.join(
            "../eval_results", "patchtst", run_name + "-0", data_split
        )
        for run_abbrv, run_name in group_run_names.items()
    }
    for group_name, group_run_names in (
        ("chattn", run_names_chattn),
        ("no_chattn", run_names_no_chattn),
    )
}

In [None]:
# Sanity check: plot labels configured for the "no_chattn" group.
run_metrics_dirs_all_groups["no_chattn"].keys()

In [None]:
def _prediction_length_from_filename(filename: str) -> int | None:
    """Return the integer between "_pred" and ".csv" in ``filename``.

    Returns None when the name does not end in ".csv", has no "_pred" marker,
    or the text between the two is not an integer, so unrelated directory
    entries can be skipped instead of crashing the load.
    """
    if not filename.endswith(".csv") or "_pred" not in filename:
        return None
    try:
        return int(filename.split("_pred")[1].split(".csv")[0])
    except ValueError:
        return None


# group -> plot label -> prediction length -> DataFrame.to_dict() of the CSV,
# i.e. {column: {row_index: value}}.
metrics_all = defaultdict(lambda: defaultdict(dict))
for run_group, run_metrics_dir_dict in run_metrics_dirs_all_groups.items():
    print(f"Run group: {run_group}")
    for run_abbrv, run_metrics_dir in run_metrics_dir_dict.items():
        print(f"{run_abbrv}: {run_metrics_dir}")
        # Filter BEFORE sorting: the sort key must never see a file name it
        # cannot parse (hidden files, logs, sub-directories, ...).
        metric_files = []
        for filename in os.listdir(run_metrics_dir):
            prediction_length = _prediction_length_from_filename(filename)
            if prediction_length is not None:
                metric_files.append((prediction_length, filename))
        # Ascending prediction length, so downstream dicts iterate in that order.
        metric_files.sort(key=lambda pair: pair[0])
        for prediction_length, filename in metric_files:
            metrics = pd.read_csv(os.path.join(run_metrics_dir, filename)).to_dict()
            metrics_all[run_group][run_abbrv][prediction_length] = metrics

In [None]:
# Sanity check: plot labels of the "no_chattn" runs that were loaded.
metrics_all["no_chattn"].keys()

In [None]:
# Sanity check: run groups present in metrics_all.
metrics_all.keys()

In [None]:
# group -> plot label -> prediction length -> {metric column: list of values}.
# Each column of the loaded table ({row_index: value}) is flattened to a list.
unrolled_metrics_all_groups = defaultdict(lambda: defaultdict(dict))
for run_group, all_metrics_of_run_group in metrics_all.items():
    print(run_group)
    for run_abbrv, all_metrics_of_run_abbrv in all_metrics_of_run_group.items():
        print(run_abbrv)
        for prediction_length, metrics in all_metrics_of_run_abbrv.items():
            print(prediction_length)
            # Skip the "system" column instead of popping it: popping mutated
            # metrics_all, so re-running this cell raised KeyError("system").
            metrics_unrolled = {
                column: list(per_row.values())
                for column, per_row in metrics.items()
                if column != "system"
            }
            print(metrics_unrolled.keys())
            unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length] = (
                metrics_unrolled
            )

In [None]:
# Flat plot label -> metrics mapping across both groups ("chattn" entries
# first); on a label clash the "no_chattn" entry would win.
unrolled_metrics_all_combined = dict(unrolled_metrics_all_groups["chattn"])
unrolled_metrics_all_combined.update(unrolled_metrics_all_groups["no_chattn"])

In [None]:
# Sanity check: plot labels available in the unrolled "no_chattn" metrics.
unrolled_metrics_all_groups["no_chattn"].keys()

In [None]:
def get_summary_metrics_dict(unrolled_metrics, metric_name):
    """Summarise one metric per model and prediction length.

    Args:
        unrolled_metrics: mapping model name -> prediction length ->
            {metric name: list of values}.
        metric_name: key of the metric to summarise.

    Returns:
        defaultdict mapping each model name to a dict with the keys
        "prediction_lengths", "means", "medians" and "stds". The last three
        are lists aligned with "prediction_lengths"; NaN values are ignored
        by all three statistics.

    Each model name is printed as it is processed.
    """
    summary = defaultdict(dict)
    for model_name, per_length_metrics in unrolled_metrics.items():
        print(model_name)
        horizons = list(per_length_metrics)
        samples = [per_length_metrics[horizon][metric_name] for horizon in horizons]
        entry = summary[model_name]
        entry["prediction_lengths"] = horizons
        entry["means"] = [np.nanmean(values) for values in samples]
        entry["medians"] = [np.nanmedian(values) for values in samples]
        entry["stds"] = [np.nanstd(values) for values in samples]
    return summary

In [None]:
def plot_metrics_by_prediction_length(
    metrics_dict, metric_name, show_std_envelope=False
):
    """Plot one metric's median against prediction length, one line per model.

    Args:
        metrics_dict: mapping model name -> dict with the keys
            "prediction_lengths", "medians", "means" and "stds".
        metric_name: used as the figure title.
        show_std_envelope: when True, also shade mean +/- std for each model.
            Note that the band is centred on the mean while the line shows
            the median.
    """
    plt.figure(figsize=(5, 4))
    for label, summary in metrics_dict.items():
        horizons = summary["prediction_lengths"]
        plt.plot(horizons, summary["medians"], marker="o", label=label)
        std_envelope = np.array(summary["stds"])
        if not show_std_envelope:
            continue
        # "means" may be a plain list; arithmetic with the array broadcasts it.
        lower = summary["means"] - std_envelope
        upper = summary["means"] + std_envelope
        plt.fill_between(horizons, lower, upper, alpha=0.2)
    plt.legend(loc="lower right")
    plt.xlabel("Prediction Length")
    plt.title(metric_name, fontweight="bold")

In [None]:
# Sanity check: run groups configured in run_metrics_dirs_all_groups.
run_metrics_dirs_all_groups.keys()

In [None]:
# Metrics for which per-model summaries are computed.
metric_names_chosen = ["mse", "mae", "smape", "r2_score", "spearman"]

In [None]:
# group -> metric name -> per-model summary from get_summary_metrics_dict.
all_metrics_dict = defaultdict(dict)

for run_group in run_metrics_dirs_all_groups:
    all_metrics_dict[run_group] = dict(
        (
            metrics_name,
            get_summary_metrics_dict(
                unrolled_metrics_all_groups[run_group], metrics_name
            ),
        )
        for metrics_name in metric_names_chosen
    )

In [None]:
colors = plt.cm.tab10.colors


def plot_all_metrics_by_prediction_length(
    all_metrics_dict: dict[str, dict[str, dict[str, list[float]]]],
    metric_names: list[str],
    runs_to_exclude: list[str],
    metrics_to_show_std_envelope: list[str],
    n_rows: int = 2,
    n_cols: int = 3,
    limit_num_prediction_lengths: int | None = None,
    title: str = "",
):
    """Draw a grid of panels, one per metric, of median vs. prediction length.

    Args:
        all_metrics_dict: metric name -> model name -> dict with the keys
            "prediction_lengths", "medians", "means" and "stds".
        metric_names: metrics to plot, one panel each, in this order. Metrics
            beyond the ``n_rows * n_cols`` available panels are not drawn.
        runs_to_exclude: model names to skip.
        metrics_to_show_std_envelope: metrics for which a mean +/- std band is
            shaded in addition to the median line.
        n_rows: number of panel rows.
        n_cols: number of panel columns.
        limit_num_prediction_lengths: keep only the first N prediction lengths
            of every model; None keeps all of them.
        title: optional figure-level title.

    The legend is drawn on the first panel only. Panels without a metric are
    hidden. The figure is shown with ``plt.show()``; nothing is returned.
    """
    num_metrics = len(metric_names)
    # squeeze=False always returns a 2-D array, so flatten() also works for a
    # 1x1 grid (where plt.subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(4 * n_cols, 4 * n_rows),
        squeeze=False,
    )
    axes = axes.flatten()

    for i, (ax, metric_name) in enumerate(zip(axes, metric_names)):
        metrics_dict = all_metrics_dict[metric_name]
        # Tick positions come from everything drawn on this panel, not from
        # whichever model happened to be iterated last.
        plotted_lengths = set()
        for j, (model_name, metrics) in enumerate(metrics_dict.items()):
            if model_name in runs_to_exclude:
                continue
            # j counts excluded models too, so a model keeps its colour no
            # matter which runs are excluded; wrap around past the palette.
            color = colors[j % len(colors)]
            prediction_lengths = metrics["prediction_lengths"][
                :limit_num_prediction_lengths
            ]
            medians = metrics["medians"][:limit_num_prediction_lengths]
            ax.plot(
                prediction_lengths,
                medians,
                marker="o",
                label=model_name,
                color=color,
                alpha=0.7,
            )
            plotted_lengths.update(prediction_lengths)
            if metric_name in metrics_to_show_std_envelope:
                means = np.asarray(
                    metrics["means"][:limit_num_prediction_lengths], dtype=float
                )
                std_envelope = np.asarray(
                    metrics["stds"][:limit_num_prediction_lengths], dtype=float
                )
                ax.fill_between(
                    prediction_lengths,
                    means - std_envelope,
                    means + std_envelope,
                    alpha=0.2,
                    color=color,
                )
        if i == 0:
            ax.legend(loc="lower right", fontsize=10, frameon=True)
        ax.set_xlabel("Prediction Length")
        # Nothing plotted (empty metric or every run excluded): keep default ticks.
        if plotted_lengths:
            ax.set_xticks(sorted(plotted_lengths))
        name = metric_name.replace("_", " ")
        if name in ["mse", "mae", "rmse", "mape", "smape"]:
            name = name.upper()
        else:
            name = name.capitalize()
        ax.set_title(name, fontweight="bold")

    # Hide any unused subplots
    for ax in axes[num_metrics:]:
        ax.set_visible(False)

    if title:
        fig.suptitle(title, fontweight="bold", fontsize=16, y=1.05)
    fig.tight_layout()
    plt.show()

In [None]:
# Median metric vs. prediction length for the "no_chattn" group. No std bands
# are drawn; list metric names (e.g. "smape", "spearman") in
# metrics_to_show_std_envelope to enable them.
plot_all_metrics_by_prediction_length(
    all_metrics_dict=all_metrics_dict["no_chattn"],
    metric_names=["mse", "mae", "smape", "spearman"],
    runs_to_exclude=[],
    metrics_to_show_std_envelope=[],
    n_cols=4,
    limit_num_prediction_lengths=None,
    title="No Channel Attention",
)

In [None]:
# Median metric vs. prediction length for the "chattn" group. No std bands
# are drawn; list metric names (e.g. "smape", "spearman") in
# metrics_to_show_std_envelope to enable them.
plot_all_metrics_by_prediction_length(
    all_metrics_dict=all_metrics_dict["chattn"],
    metric_names=["mse", "mae", "smape", "spearman"],
    runs_to_exclude=[],
    metrics_to_show_std_envelope=[],
    n_cols=4,
    limit_num_prediction_lengths=None,
    title="Channel Attention",
)

In [None]:
# metric name -> model name -> summary, merged over both groups ("chattn"
# models first; on a name clash the "no_chattn" summary would win).
all_metrics_dict_all = {
    metrics_name: {
        model_name: summary
        for group_name in ("chattn", "no_chattn")
        for model_name, summary in all_metrics_dict[group_name][metrics_name].items()
    }
    for metrics_name in metric_names_chosen
}
# Same panels as the per-group figures, with every model on one set of axes.
plot_all_metrics_by_prediction_length(
    all_metrics_dict=all_metrics_dict_all,
    metric_names=["mse", "mae", "smape", "spearman"],
    runs_to_exclude=[],
    metrics_to_show_std_envelope=[],
    n_cols=4,
    limit_num_prediction_lengths=None,
    title="All Models",
)

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

colors = plt.cm.tab10.colors


def make_strip_plot(
    unrolled_metrics: dict[str, dict[int, dict[str, list[float]]]],
    prediction_length: int,
    selected_run_names: list[str] | None = None,
    ylim: tuple[float, float] = (1e-5, 1e5),
    verbose: bool = False,
    metrics_to_exclude: list[str] | None = None,
    use_rescaled_smape: bool = False,
    run_names_to_exclude: list[str] | None = None,
    plot_violin: bool = False,
    title: str = "Metrics",
):
    """Compare runs metric by metric at a single prediction length.

    For every metric (x axis) the value distribution of each run (hue) is
    drawn side by side, either as box plot with overlaid strip plot or as
    violin plot, and the NaN-ignoring median of each run is marked with a
    black horizontal line.

    Args:
        unrolled_metrics: run name -> prediction length ->
            {metric name: list of values}.
        prediction_length: which prediction length to plot.
        selected_run_names: runs to consider; None selects every run in
            ``unrolled_metrics``.
        ylim: y-axis limits.
        verbose: print the median of every (metric, run) pair.
        metrics_to_exclude: metric names to leave out; None excludes nothing.
        use_rescaled_smape: divide the values of the "smape" metric by 100.
        run_names_to_exclude: run names to leave out; None excludes nothing.
        plot_violin: draw violins instead of box + strip plots.
        title: axes title.

    Raises:
        ValueError: if no run is selected or every selected run is excluded.
        KeyError: if a selected run has no entry for ``prediction_length``.

    The metric names are taken from the first selected run. The figure is
    shown with ``plt.show()``; nothing is returned.
    """
    metrics_to_exclude = list(metrics_to_exclude or [])
    run_names_to_exclude = list(run_names_to_exclude or [])

    # Extract metrics data for the given prediction_length and run_names
    if selected_run_names is None:
        selected_run_names = list(unrolled_metrics.keys())
    if not selected_run_names:
        raise ValueError("make_strip_plot needs at least one run to plot")

    metrics_by_run_name = {
        run_name: unrolled_metrics[run_name][prediction_length]
        for run_name in selected_run_names
    }

    run_names = list(metrics_by_run_name.keys())
    metric_names = [
        name
        for name in metrics_by_run_name[run_names[0]].keys()
        if name not in metrics_to_exclude
    ]
    run_names = [name for name in run_names if name not in run_names_to_exclude]
    if not run_names:
        raise ValueError("every selected run is excluded; nothing to plot")

    print(f"run_names: {run_names}")
    print(f"metric_names: {metric_names}")

    # Create pretty titles for x-axis tick labels
    metric_names_title = [
        name.replace("_", " ").upper()
        if name in ["mse", "mae", "rmse", "mape", "smape"]
        else name.replace("_", " ").capitalize()
        for name in metric_names
    ]

    fig, ax = plt.subplots(figsize=(6, 4))

    plot_data = []
    median_data = []
    for metric_name, metric_title in zip(metric_names, metric_names_title):
        for run_name in run_names:
            values = metrics_by_run_name[run_name][metric_name]
            if metric_name == "smape" and use_rescaled_smape:
                values = [x / 100 for x in values]
            median_value = np.nanmedian(values)
            plot_data.extend([(metric_title, v, run_name) for v in values])
            median_data.append((metric_title, median_value, run_name))
            if verbose:
                print(f"{metric_title} - {run_name} median: {median_value}")

    # Create DataFrame for use with seaborn
    df = pd.DataFrame(plot_data, columns=["Metric", "Value", "Run"])

    n = len(run_names)
    # One colour per run; wrap around when there are more runs than colours.
    palette = [colors[i % len(colors)] for i in range(n)]
    # Pin category and hue order on every layer so that boxes, points and the
    # median lines computed below always line up.
    shared_kwargs = dict(
        data=df,
        x="Metric",
        y="Value",
        hue="Run",
        order=metric_names_title,
        hue_order=run_names,
        dodge=True,
        palette=palette,
        ax=ax,
    )

    if plot_violin:
        # NOTE(review): `scale` is reported as renamed to `density_norm` in
        # newer seaborn releases - confirm against the installed version.
        sns.violinplot(
            **shared_kwargs,
            alpha=0.7,
            scale="width",
            cut=0,
        )
    else:
        sns.boxplot(
            **shared_kwargs,
            width=0.8,
            fliersize=0,  # Don't show outlier points
            saturation=0.5,
        )

        # Individual values, drawn behind the boxes (zorder=0).
        sns.stripplot(
            **shared_kwargs,
            alpha=0.5,
            size=3,
            jitter=0.15,
            legend=False,
            zorder=0,
        )

    # With an explicit `order`, category k is centred at x == k.
    x_positions = {
        metric_title: pos for pos, metric_title in enumerate(metric_names_title)
    }

    # Each category spans a total width of 0.8 (the box width set above).
    dodge_width = 0.8
    # The space allocated per hue (run) within the category:
    group_width = dodge_width / n
    # Keep the median line slightly narrower than one run's slot.
    median_line_width = group_width * 0.8
    # Dodge offset of every run, so that the hues are evenly spaced.
    run_offsets = {
        run: (run_index - (n - 1) / 2) * group_width
        for run_index, run in enumerate(run_names)
    }

    # Draw horizontal lines for medians centered on the corresponding dodge positions
    for metric_title, median_value, run in median_data:
        x_line_center = x_positions[metric_title] + run_offsets[run]
        ax.hlines(
            y=median_value,
            xmin=x_line_center - median_line_width / 2,
            xmax=x_line_center + median_line_width / 2,
            color="black",
            linewidth=2,
            zorder=3,
        )

    ax.set_ylim(ylim)
    ax.set_ylabel("")
    ax.legend(frameon=True, fontsize=4, loc="lower right")
    ax.set_title(title, fontweight="bold")
    plt.xticks(rotation=15)  # rotate x-tick labels for readability
    plt.show()

In [None]:
# Show the combined plot label -> run name mapping.
run_names

In [None]:
# Per-metric distributions of the "no_chattn" runs at prediction length 128.
# r2_score is left out and smape is divided by 100.
# Individual runs can be dropped with run_names_to_exclude=[...].
make_strip_plot(
    unrolled_metrics=unrolled_metrics_all_groups["no_chattn"],
    prediction_length=128,
    selected_run_names=None,
    ylim=(0, 2),
    metrics_to_exclude=["r2_score"],
    use_rescaled_smape=True,
    plot_violin=False,
    title="No Channel Attention",
)

In [None]:
# Sanity check: plot labels available in the unrolled "chattn" metrics.
unrolled_metrics_all_groups["chattn"].keys()

In [None]:
# Per-metric distributions of the "chattn" runs at prediction length 128.
# r2_score is left out and smape is divided by 100.
make_strip_plot(
    unrolled_metrics=unrolled_metrics_all_groups["chattn"],
    prediction_length=128,
    selected_run_names=None,
    ylim=(0, 2),
    metrics_to_exclude=["r2_score"],
    use_rescaled_smape=True,
    plot_violin=False,
    title="Channel Attention",
)

In [None]:
# Sanity check: plot labels in the combined (both groups) unrolled metrics.
unrolled_metrics_all_combined.keys()

In [None]:
# Per-metric distributions of all runs at prediction length 128, on a tighter
# y range (0 to 1) than the per-group plots. r2_score is left out and smape is
# divided by 100.
make_strip_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=128,
    selected_run_names=None,
    ylim=(0, 1),
    metrics_to_exclude=["r2_score"],
    use_rescaled_smape=True,
    plot_violin=False,
    title="All Models",
)