In [None]:
# Automatically reload imported modules (e.g. the panda.utils helpers) before each cell runs,
# so edits to those modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import defaultdict

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import colors as mcolors
from matplotlib.legend_handler import HandlerLine2D

from panda.utils.eval_utils import get_summary_metrics_dict
from panda.utils.plot_utils import (
    apply_custom_style,
    make_box_plot,
    plot_all_metrics_by_prediction_length,
)

# Apply the project's shared plotting style from config (presumably matplotlib rcParams — confirm).
# The path is relative to the notebook's working directory.
apply_custom_style("../../config/plotting.yaml")

In [None]:
# Output directories.
# - figs_save_dir: shared eval-metrics figure folder.
# - ABLATIONS_FIGS_DIR: relative folder every ablation cell in this notebook saves into
#   ("ablations_figs/..."). savefig does not create missing directories, so create it up front.
figs_save_dir = os.path.join("../../figures", "eval_metrics")
os.makedirs(figs_save_dir, exist_ok=True)

ABLATIONS_FIGS_DIR = "ablations_figs"
os.makedirs(ABLATIONS_FIGS_DIR, exist_ok=True)

In [None]:
# Root of the eval-results tree, read from $WORK. When $WORK is unset this is "",
# so every path derived from it becomes relative to the current working directory.
WORK_DIR = os.environ.get("WORK", "")
# NOTE(review): DATA_DIR is not referenced anywhere else in this notebook.
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Evaluation split sub-directory to read metrics from.
# data_split = "final_skew40/test_zeroshot"
data_split = "test_zeroshot"  # includes test_zeroshot for skew and base systems, and z5_z10 splits as well

# Display label -> run directory name for models WITH channel attention.
# Dict order is preserved downstream (presumably it fixes the plotting/color order — confirm).
run_names_chattn = {
    # "Chattn + PolyEmbedLinAttn": "pft_linattnpolyemb_from_scratch-0",
    "Chattn + MLM + RFF": "pft_stand_rff_only_pretrained-0",
    # "Chattn + MLM + PolyEmbed": "pft_chattn_fullemb_pretrained-0",
    # "Chattn + MLM + PolyEmbed": "pft_chattn_fullemb_quartic_enc-0",
    "Chattn + MLM": "pft_chattn_noembed_pretrained_correct-0",
    "Chattn + RFF": "pft_rff496_proj-0",
    "Chattn + PolyEmbed": "pft_chattn_emb_w_poly-0",
    "Chattn": "pft_stand_chattn_noemb-0",
}

# Display label -> run directory name for univariate models (no channel attention).
run_names_no_chattn = {
    "Univar (wider) + RFF": "pft_emb_equal_param_univariate_from_scratch-0",
    "Univar (wider)": "pft_noemb_equal_param_univariate_from_scratch-0",
    "Univar (deeper)": "pft_equal_param_deeper_univariate_from_scratch_noemb-0",
    "Univar + MLM + RFF": "pft_rff_univariate_pretrained-0",
    "Univar + MLM": "pft_vanilla_pretrained_correct-0",
}

run_names = {
    **run_names_chattn,
    **run_names_no_chattn,
}


def _run_metrics_dir(run_name: str) -> str:
    """Directory holding the per-prediction-length metrics CSVs of one PatchTST run."""
    return os.path.join(WORK_DIR, "eval_results", "patchtst", run_name, data_split)


# run group -> {display label -> metrics directory}
run_metrics_dirs_all_groups = {
    run_group: {run_abbrv: _run_metrics_dir(run_name) for run_abbrv, run_name in group_run_names.items()}
    for run_group, group_run_names in (
        ("chattn", run_names_chattn),
        ("no_chattn", run_names_no_chattn),
    )
}

In [None]:
# Inspect: run labels configured for the univariate ("no_chattn") group.
run_metrics_dirs_all_groups["no_chattn"].keys()

In [None]:
def _prediction_length_from_filename(filename: str) -> int:
    """Parse the prediction length ``N`` from a metrics filename like ``<prefix>_pred<N>.csv``."""
    return int(filename.split("_pred")[1].split(".csv")[0])


# Load every per-prediction-length metrics CSV into
#   metrics_all[run_group][run_abbrv][prediction_length] = DataFrame.to_dict()
# i.e. {column: {row_index: value}}. Runs whose directory is missing are reported and skipped.
metrics_all = defaultdict(lambda: defaultdict(dict))
for run_group, run_metrics_dir_dict in run_metrics_dirs_all_groups.items():
    print(f"Run group: {run_group}")
    for run_abbrv, run_metrics_dir in run_metrics_dir_dict.items():
        if not os.path.exists(run_metrics_dir):
            print(f"Run metrics dir does not exist: {run_metrics_dir}")
            continue
        print(f"{run_abbrv}: {run_metrics_dir}")
        # Filter BEFORE sorting. The sort key parses "_pred<N>", so sorting the raw
        # directory listing raised IndexError on any unrelated entry (logs, checkpoints, ...).
        metric_files = [
            name for name in os.listdir(run_metrics_dir) if name.endswith(".csv") and "_pred" in name
        ]
        for file in sorted(metric_files, key=_prediction_length_from_filename):
            prediction_length = _prediction_length_from_filename(file)
            metrics = pd.read_csv(os.path.join(run_metrics_dir, file)).to_dict()
            metrics_all[run_group][run_abbrv][prediction_length] = metrics

In [None]:
# Inspect: univariate runs with at least one loaded metrics CSV (runs with a missing directory are absent).
metrics_all["no_chattn"].keys()

In [None]:
# Inspect: run groups for which at least one metrics CSV was loaded.
metrics_all.keys()

In [None]:
# Flatten each loaded metrics table from pandas' {column: {row_index: value}} layout into
# {column: [values in row order]}, dropping the "system" column. metrics_all is left
# untouched: the previous in-place .pop("system") made this cell raise KeyError on re-run.
unrolled_metrics_all_groups = defaultdict(lambda: defaultdict(dict))
for run_group, all_metrics_of_run_group in metrics_all.items():
    print(run_group)
    for run_abbrv, all_metrics_of_run_abbrv in all_metrics_of_run_group.items():
        print(run_abbrv)
        # Keys at this level are prediction lengths (see the loading cell).
        for prediction_length, metrics in all_metrics_of_run_abbrv.items():
            print(prediction_length)
            metrics_unrolled = {k: list(v.values()) for k, v in metrics.items() if k != "system"}
            print(metrics_unrolled.keys())
            unrolled_metrics_all_groups[run_group][run_abbrv][prediction_length] = metrics_unrolled

In [None]:
# All runs from both groups in one dict: channel-attention runs first, then univariate runs.
unrolled_metrics_all_combined = dict(unrolled_metrics_all_groups["chattn"])
unrolled_metrics_all_combined.update(unrolled_metrics_all_groups["no_chattn"])

In [None]:
# Inspect: univariate runs that made it through the unrolling step.
unrolled_metrics_all_groups["no_chattn"].keys()

In [None]:
# Inspect: configured run groups ("chattn", "no_chattn").
run_metrics_dirs_all_groups.keys()

In [None]:
# Metric columns (as loaded from the per-run CSVs) to summarise and plot below.
metric_names_chosen = "mse mae smape spearman".split()

In [None]:
# run group -> metric name -> first element returned by get_summary_metrics_dict
# (the summary used for the per-prediction-length plots further down).
all_metrics_dict = defaultdict(dict)

for run_group in run_metrics_dirs_all_groups:
    group_unrolled = unrolled_metrics_all_groups[run_group]
    summary_by_metric = {}
    for metrics_name in metric_names_chosen:
        summary_by_metric[metrics_name] = get_summary_metrics_dict(group_unrolled, metrics_name)[0]
    all_metrics_dict[run_group] = summary_by_metric

In [None]:
# Matplotlib's tab10 qualitative palette.
# NOTE(review): default_colors is not referenced anywhere else in this notebook — candidate for removal.
default_colors = plt.cm.tab10.colors

In [None]:
# Number of runs per group that actually have loaded metrics. Counting loaded runs (not the
# configured ones) keeps bar_colors and the legend-handle slicing below aligned with the runs
# make_box_plot receives via unrolled_metrics_all_combined, even if a results directory is missing.
n_runs_chattn = len(unrolled_metrics_all_groups["chattn"])
n_runs_no_chattn = len(unrolled_metrics_all_groups["no_chattn"])

In [None]:
def _shade_colors(cmap, n_runs):
    """Return ``n_runs`` RGBA colors (as lists) sampled from ``cmap`` at levels 0.75 down to 0.1."""
    return cmap(np.linspace(0.75, 0.1, n_runs)).tolist()


# Group-coded palettes: reds for channel-attention runs, greys for univariate runs.
bar_colors_chattn = _shade_colors(plt.cm.Reds, n_runs_chattn)
print(bar_colors_chattn)
bar_colors_no_chattn = _shade_colors(plt.cm.Greys, n_runs_no_chattn)
print(bar_colors_no_chattn)
# Same group order as unrolled_metrics_all_combined (channel-attention runs first).
bar_colors = [*bar_colors_chattn, *bar_colors_no_chattn]

In [None]:
# Prediction length used for the box plots and in the saved figure filenames below.
# It is reassigned to 512 further down for the rollout plots.
selected_pred_length = 128

In [None]:
# SMAPE box plot at the selected prediction length. Its returned handles are reused by the
# standalone legend figures and the label -> color mapping below.
# Saved under a distinct "_for_legend" name: the per-metric loop further down writes
# ablations_figs/smape_{selected_pred_length}.pdf and previously overwrote this figure.
legend_handles = make_box_plot(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot="smape",  # Specify which metric to plot
    sort_runs=True,  # Optionally sort runs by their metric values
    colors=bar_colors,
    # title=rf"$L_{{pred}}$ = {selected_pred_length}",
    title_kwargs={"fontsize": 10},
    save_path=f"ablations_figs/smape_{selected_pred_length}_for_legend.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(30, 70),
    alpha_val=0.8,
    fig_kwargs={"figsize": (3.2, 5)},
)

In [None]:
# Standalone single-column legend for the ablation box plot: a bold "Multivariate" section
# (channel-attention runs), a blank spacer row, then a bold "Univariate" section.
# NOTE(review): assumes the first n_runs_chattn entries of legend_handles are the
# channel-attention runs, i.e. make_box_plot returns handles in input-run order — confirm.
fig, ax = plt.subplots(figsize=(3, 5))

multivariate_entries = legend_handles[:n_runs_chattn]
univariate_entries = legend_handles[n_runs_chattn:]

# Zero-width lines render only their label text: used as section headers and as a spacer row.
channel_attention_header = mlines.Line2D([0], [0], color="black", label="Multivariate", linewidth=0)
univariate_header = mlines.Line2D([0], [0], color="black", label="Univariate", linewidth=0)
spacer = mlines.Line2D([0], [0], color="none", label=" ", linewidth=0)

legend = ax.legend(
    handles=[
        channel_attention_header,
        *multivariate_entries,
        spacer,
        univariate_header,
        *univariate_entries,
    ],
    loc="upper center",
    frameon=True,
    ncol=1,
    framealpha=1.0,
    fontsize=14,
    handler_map={artist: HandlerLine2D() for artist in (channel_attention_header, univariate_header, spacer)},
)

# Bold, left-aligned section header labels.
section_titles = {"Multivariate", "Univariate"}
for label in legend.get_texts():
    if label.get_text() not in section_titles:
        continue
    label.set_fontweight("bold")
    label.set_ha("left")
    label.set_position((0, 0))

ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout(pad=0)
fig.savefig(
    f"ablations_figs/ablations_legend_vertical_{selected_pred_length}.pdf",
    bbox_inches="tight",
)
plt.show()
plt.close(fig)

In [None]:
# Standalone two-column legend for the ablation box plot. Legend entries fill column by
# column, so when both groups have the same number of runs the "Multivariate" section
# lands in the first column and the "Univariate" section in the second.
fig, ax = plt.subplots(figsize=(5, 2))

multivariate_entries = legend_handles[:n_runs_chattn]
univariate_entries = legend_handles[n_runs_chattn:]

# Zero-width lines render only their label text and are used as section headers.
channel_attention_header = mlines.Line2D([0], [0], color="black", label="Multivariate", linewidth=0)
univariate_header = mlines.Line2D([0], [0], color="black", label="Univariate", linewidth=0)
# The spacer is only registered in handler_map; it is not an entry in this layout.
spacer = mlines.Line2D([0], [0], color="none", label=" ", linewidth=0)

legend = ax.legend(
    handles=[channel_attention_header, *multivariate_entries, univariate_header, *univariate_entries],
    loc="upper center",
    frameon=True,
    ncol=2,
    framealpha=1.0,
    fontsize=14,
    handler_map={artist: HandlerLine2D() for artist in (channel_attention_header, univariate_header, spacer)},
)

# Bold, left-aligned section header labels.
section_titles = {"Multivariate", "Univariate"}
for label in legend.get_texts():
    if label.get_text() not in section_titles:
        continue
    label.set_fontweight("bold")
    label.set_ha("left")
    label.set_position((0, 0))

ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout(pad=0)
fig.savefig(
    f"ablations_figs/ablations_legend_horizontal_{selected_pred_length}.pdf",
    bbox_inches="tight",
)
plt.show()
plt.close(fig)

In [None]:
# Four-column variant of the horizontal legend (saved with a "_v2" suffix).
fig, ax = plt.subplots(figsize=(5, 2))

multivariate_entries = legend_handles[:n_runs_chattn]
univariate_entries = legend_handles[n_runs_chattn:]

# Zero-width lines render only their label text and are used as section headers.
channel_attention_header = mlines.Line2D([0], [0], color="black", label="Multivariate", linewidth=0)
univariate_header = mlines.Line2D([0], [0], color="black", label="Univariate", linewidth=0)
# The spacer is only registered in handler_map; it is not an entry in this layout.
spacer = mlines.Line2D([0], [0], color="none", label=" ", linewidth=0)

legend = ax.legend(
    handles=[channel_attention_header, *multivariate_entries, univariate_header, *univariate_entries],
    loc="upper center",
    frameon=True,
    ncol=4,
    framealpha=1.0,
    fontsize=14,
    handler_map={artist: HandlerLine2D() for artist in (channel_attention_header, univariate_header, spacer)},
)

# Bold, left-aligned section header labels.
section_titles = {"Multivariate", "Univariate"}
for label in legend.get_texts():
    if label.get_text() not in section_titles:
        continue
    label.set_fontweight("bold")
    label.set_ha("left")
    label.set_position((0, 0))

ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout(pad=0)
fig.savefig(
    f"ablations_figs/ablations_legend_horizontal_{selected_pred_length}_v2.pdf",
    bbox_inches="tight",
)
plt.show()
plt.close(fig)

In [None]:
# One box plot per chosen metric at the selected prediction length, without legends.
# order_by_metric="smape" presumably gives every panel the same (SMAPE-based) run order,
# and use_inv_spearman presumably flips Spearman so all panels read the same way — confirm.
for metric_to_plot in metric_names_chosen:
    # Rebuilt every iteration so the nested kwargs dicts are fresh objects per call,
    # exactly as with a literal keyword-argument call.
    box_plot_kwargs = dict(
        unrolled_metrics=unrolled_metrics_all_combined,
        prediction_length=selected_pred_length,
        metric_to_plot=metric_to_plot,
        sort_runs=True,
        colors=bar_colors,
        title=None,
        title_kwargs={"fontsize": 10},
        order_by_metric="smape",
        save_path=f"ablations_figs/{metric_to_plot}_{selected_pred_length}.pdf",
        ylabel_fontsize=12,
        show_xlabel=False,
        show_legend=False,
        # Only relevant if show_legend is switched on.
        legend_kwargs={
            "loc": "upper left",
            "frameon": True,
            "ncol": 1,
            "framealpha": 0.8,
            # "prop": {"weight": "bold", "size": 5},
            "prop": {"size": 6.8},
        },
        box_percentile_range=(40, 60),
        whisker_percentile_range=(25, 75),
        alpha_val=0.8,
        fig_kwargs={"figsize": (3.2, 5)},
        use_inv_spearman=True,
    )
    make_box_plot(**box_plot_kwargs)

In [None]:
# Map each legend label (run name) to its box face color as a hex string, for reuse in later plots.
custom_colors_dict = {}
for handle in legend_handles:
    run_label = handle.get_label()
    hex_color = mcolors.rgb2hex(handle.get_facecolor())
    print(run_label, hex_color)
    custom_colors_dict[run_label] = hex_color

In [None]:
# Inspect: run groups that have summary metrics.
all_metrics_dict.keys()

In [None]:
# Inspect: SMAPE summary (from get_summary_metrics_dict) for the channel-attention runs.
all_metrics_dict["chattn"]["smape"]

In [None]:
# metric name -> summary for all runs, merging channel-attention runs then univariate runs.
all_metrics_dict_all = {}
for metrics_name in metric_names_chosen:
    merged_summary = dict(all_metrics_dict["chattn"][metrics_name])
    merged_summary.update(all_metrics_dict["no_chattn"][metrics_name])
    all_metrics_dict_all[metrics_name] = merged_summary

In [None]:
# Inspect: run labels in the merged SMAPE summary (both groups).
all_metrics_dict_all["smape"].keys()

In [None]:
# Inspect: run label -> hex color mapping taken from the legend handles.
custom_colors_dict

In [None]:
# Inspect: run labels in the merged SMAPE summary (both groups); compare with custom_colors_dict keys.
all_metrics_dict_all["smape"].keys()

In [None]:
# Metrics vs. prediction length for the channel-attention runs, one panel per metric.
# The grid width is derived from the metric list: it was hard-coded to 4 columns while only
# 3 metrics are plotted (presumably leaving an empty trailing panel — confirm).
# TODO: an earlier note flagged an undiagnosed "weird bug" here; Spearman is excluded
# (see the commented-out std-envelope option).
metrics_to_plot_by_length = ["mse", "mae", "smape"]
plot_all_metrics_by_prediction_length(
    all_metrics_dict["chattn"],
    metrics_to_plot_by_length,
    # metrics_to_show_std_envelope=["smape", "spearman"],
    n_rows=1,
    n_cols=len(metrics_to_plot_by_length),
    colors=custom_colors_dict,
    show_legend=False,
)

### Ablations box plot for prediction length 512 (rollout), reusing the run colors taken from the 128-step legend handles

In [None]:
# Inspect: RGBA face color of the first legend handle.
legend_handles[0].get_facecolor()

In [None]:
# Inspect: label (run name) of the first legend handle.
legend_handles[0].get_label()

In [None]:
# Map each legend label (run name) to its box face color as a hex string, so the rollout
# plot below colors every run the same way as the 128-step plot.
custom_colors_dict = {}
for handle in legend_handles:
    run_label = handle.get_label()
    hex_color = mcolors.rgb2hex(handle.get_facecolor())
    print(run_label, hex_color)
    custom_colors_dict[run_label] = hex_color

In [None]:
# Inspect: run labels that have an assigned color.
custom_colors_dict.keys()

In [None]:
# Rollout horizon. This rebinds the notebook-global selected_pred_length, which the rollout
# legend cell below also uses in its output filename.
selected_pred_length = 512

# SMAPE box plot at the rollout length, colored via the label -> hex mapping built from the
# 128-step legend handles (presumably matched to runs by label inside make_box_plot — confirm).
rollout_box_plot_kwargs = dict(
    unrolled_metrics=unrolled_metrics_all_combined,
    prediction_length=selected_pred_length,
    metric_to_plot="smape",
    sort_runs=True,
    colors=custom_colors_dict,
    # title=rf"$L_{{pred}}$ = {selected_pred_length}",
    title_kwargs={"fontsize": 10},
    save_path=f"ablations_figs/smape_{selected_pred_length}.pdf",
    ylabel_fontsize=12,
    show_xlabel=False,
    box_percentile_range=(40, 60),
    whisker_percentile_range=(30, 70),
    alpha_val=0.8,
    fig_kwargs={"figsize": (3.2, 5)},
)
legend_handles_rollout = make_box_plot(**rollout_box_plot_kwargs)

In [None]:
# Disabled: per-metric box plots at the rollout prediction length (selected_pred_length, set to 512 above).
# Uncomment to regenerate them; they use custom_colors_dict so run colors match the 128-step plots.
# for metric_to_plot in metric_names_chosen:
#     make_box_plot(
#         unrolled_metrics=unrolled_metrics_all_combined,
#         prediction_length=selected_pred_length,
#         metric_to_plot=metric_to_plot,  # Specify which metric to plot
#         sort_runs=True,  # Optionally sort runs by their metric values
#         colors=custom_colors_dict,
#         # title=rf"$L_{{pred}}$ = {selected_pred_length}",
#         title_kwargs={"fontsize": 10},
#         save_path=f"ablations_figs/{metric_to_plot}_{selected_pred_length}.pdf",
#         ylabel_fontsize=12,
#         show_xlabel=False,
#         box_percentile_range=(40, 60),
#         whisker_percentile_range=(30, 70),
#         alpha_val=0.8,
#         fig_kwargs={"figsize": (3.2, 5)},
#         use_inv_spearman=True,
#     )

In [None]:
# Standalone single-column legend for the rollout box plot: a bold "Multivariate" section
# (channel-attention runs), a blank spacer row, then a bold "Univariate" section.
# NOTE(review): assumes the first n_runs_chattn entries of legend_handles_rollout are the
# channel-attention runs, i.e. make_box_plot returns handles in input-run order — confirm.
fig, ax = plt.subplots(figsize=(3, 5))

multivariate_entries = legend_handles_rollout[:n_runs_chattn]
univariate_entries = legend_handles_rollout[n_runs_chattn:]

# Zero-width lines render only their label text: used as section headers and as a spacer row.
channel_attention_header = mlines.Line2D([0], [0], color="black", label="Multivariate", linewidth=0)
univariate_header = mlines.Line2D([0], [0], color="black", label="Univariate", linewidth=0)
spacer = mlines.Line2D([0], [0], color="none", label=" ", linewidth=0)

legend = ax.legend(
    handles=[
        channel_attention_header,
        *multivariate_entries,
        spacer,
        univariate_header,
        *univariate_entries,
    ],
    loc="upper center",
    frameon=True,
    ncol=1,
    framealpha=1.0,
    fontsize=14,
    handler_map={artist: HandlerLine2D() for artist in (channel_attention_header, univariate_header, spacer)},
)

# Bold, left-aligned section header labels.
section_titles = {"Multivariate", "Univariate"}
for label in legend.get_texts():
    if label.get_text() not in section_titles:
        continue
    label.set_fontweight("bold")
    label.set_ha("left")
    label.set_position((0, 0))

ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout(pad=0)
fig.savefig(
    f"ablations_figs/ablations_legend_rollout_vertical_{selected_pred_length}.pdf",
    bbox_inches="tight",
)
plt.show()
plt.close(fig)