In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from sklearn.metrics import r2_score
from tqdm import tqdm

from panda.utils.plot_utils import apply_custom_style

apply_custom_style("../config/plotting.yaml")

In [None]:
# Colors of the active matplotlib property cycle; plot_metric_histograms indexes into
# this list to give each model its own histogram color.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root directory for all inputs/outputs, read from the WORK environment variable.
# NOTE(review): when WORK is unset this falls back to "", so f-string paths such as
# f"{WORK_DIR}/eval_results/..." resolve to "/eval_results/..." - confirm WORK is exported.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

### Process Saved Metrics

In [None]:
def get_sorted_metric_fnames(save_dir):
    """List the aggregated distributional-metrics JSON files found in ``save_dir``.

    A file is kept when its name ends with ``.json`` and contains both the
    substrings ``"distributional_metrics"`` and ``"all"``.

    Args:
        save_dir: Directory to scan (non-recursively, via ``os.listdir``).

    Returns:
        The matching file names (not full paths), sorted by the integer that
        follows ``window-`` in the name. Names without a ``window-<int>`` part
        sort last.
    """
    # Each substring needs its own `in` test: a bare string literal is always
    # truthy, so `"distributional_metrics" and "all" in f` only checked "all".
    fnames = [
        f
        for f in os.listdir(save_dir)
        if f.endswith(".json") and "distributional_metrics" in f and "all" in f
    ]

    def extract_window(fname):
        # Sort key: the window index, or +inf when the name carries none.
        m = re.search(r"window-(\d+)", fname)
        return int(m.group(1)) if m else float("inf")

    return sorted(fnames, key=extract_window)


# Directory holding the saved evaluation metrics of each model.
panda_metrics_save_dir = f"{WORK_DIR}/eval_results/panda/pft_chattn_emb_w_poly-0/test_zeroshot"
# NOTE(review): both Chronos directories below already point at the "chronos_nondeterministic"
# results; presumably the deterministic runs live under "chronos" instead - confirm before swapping.
chronos_sft_metrics_save_dir = f"{WORK_DIR}/eval_results/chronos_nondeterministic/chronos_t5_mini_ft-0/test_zeroshot"
chronos_zs_metrics_save_dir = f"{WORK_DIR}/eval_results/chronos_nondeterministic/chronos_mini_zeroshot/test_zeroshot"

# Metric file names per model, ordered by the window index embedded in the file name.
panda_metrics_fnames = get_sorted_metric_fnames(panda_metrics_save_dir)
chronos_sft_metrics_fnames = get_sorted_metric_fnames(chronos_sft_metrics_save_dir)
chronos_zs_metrics_fnames = get_sorted_metric_fnames(chronos_zs_metrics_save_dir)

# Sanity check: show how many files were picked up for each model.
print(f"Found {len(panda_metrics_fnames)} panda metrics files: {panda_metrics_fnames}")
print(f"Found {len(chronos_sft_metrics_fnames)} chronos sft metrics files: {chronos_sft_metrics_fnames}")
print(f"Found {len(chronos_zs_metrics_fnames)} chronos zs metrics files: {chronos_zs_metrics_fnames}")

In [None]:
# Helpers for accumulating metric values across all saved files of a model (panda, chronos_sft, chronos_zs)


def filter_none(values):
    """Return a new list with every entry of ``values`` that is not ``None``.

    Falsy entries such as ``0`` or ``""`` are kept; only ``None`` is dropped.
    """
    kept = []
    for item in values:
        if item is not None:
            kept.append(item)
    return kept


def accumulate_metrics(metrics_fnames, metrics_save_dir):
    """Pool per-system metrics from several JSON files and average them.

    Each file is loaded with ``json.load`` and iterated as
    ``{pred_interval: iterable of (system_name, system_entry) pairs}``. Values
    for the same (pred_interval, system_name) are collected across all files
    and entries, then averaged with ``None`` values ignored.

    Args:
        metrics_fnames: File names to read, relative to ``metrics_save_dir``.
        metrics_save_dir: Directory that contains the files.

    Returns:
        A dict with keys ``"avg_hellinger"``, ``"kld"``, ``"max_lyap"``,
        ``"gpdim"`` and ``"prediction_time"``. The first four map
        ``variant -> pred_interval -> system_name -> mean`` (``None`` when no
        non-None value was seen); ``"prediction_time"`` maps
        ``system_name -> mean`` pooled over every prediction interval.
    """

    def new_accum(variants):
        # One {pred_interval: {system_name: [values, ...]}} store per variant.
        return {variant: defaultdict(lambda: defaultdict(list)) for variant in variants}

    distance_variants = ("pred", "pred_with_context", "full")
    estimate_variants = ("gt", "pred", "gt_with_context", "pred_with_context", "full")

    avg_hellinger_accum = new_accum(distance_variants)
    kld_accum = new_accum(distance_variants)
    gpdim_accum = new_accum(estimate_variants)
    max_lyap_accum = new_accum(estimate_variants)
    prediction_time_accum = defaultdict(list)

    # (accumulator, variant, section of the system entry, field inside that section)
    # NOTE(review): the Hellinger "pred_with_context" variant is filled from the
    # "full_trajectory" section - confirm that this naming is intentional.
    sources = (
        (avg_hellinger_accum, "pred", "prediction_horizon", "avg_hellinger_distance"),
        (avg_hellinger_accum, "pred_with_context", "full_trajectory", "avg_hellinger_distance"),
        (kld_accum, "full", "full_trajectory", "kl_divergence"),
        (max_lyap_accum, "gt", "prediction_horizon", "max_lyap_gt"),
        (max_lyap_accum, "pred", "prediction_horizon", "max_lyap_pred"),
        (max_lyap_accum, "full", "full_trajectory", "max_lyap_full_traj"),
        (max_lyap_accum, "gt_with_context", "pred_with_context", "max_lyap_gt_with_context"),
        (max_lyap_accum, "pred_with_context", "pred_with_context", "max_lyap_pred_with_context"),
        (gpdim_accum, "gt", "prediction_horizon", "gpdim_gt"),
        (gpdim_accum, "pred", "prediction_horizon", "gpdim_pred"),
        (gpdim_accum, "gt_with_context", "pred_with_context", "gpdim_gt_with_context"),
        (gpdim_accum, "pred_with_context", "pred_with_context", "gpdim_pred_with_context"),
        (gpdim_accum, "full", "full_trajectory", "gpdim_full_traj"),
    )

    for fname in metrics_fnames:
        with open(os.path.join(metrics_save_dir, fname), "rb") as f:
            file_metrics = json.load(f)
        print(f"number of prediction intervals in {fname}: {len(file_metrics)}")
        for pred_interval, data in file_metrics.items():
            print(pred_interval)
            for system_name, system_entry in tqdm(data, desc=f"Processing {pred_interval}"):
                for accum, variant, section, field in sources:
                    accum[variant][pred_interval][system_name].append(system_entry[section][field])
                # Prediction time is pooled per system, regardless of the interval.
                prediction_time_accum[system_name].append(system_entry["prediction_time"])

    def average(accum):
        # Mean over the collected values of every system, skipping None entries.
        averaged = {variant: defaultdict(dict) for variant in accum.keys()}
        for variant, per_interval in accum.items():
            for pred_interval, per_system in per_interval.items():
                for system_name, values in per_system.items():
                    kept = filter_none(values)
                    averaged[variant][pred_interval][system_name] = float(np.mean(kept)) if kept else None
        return averaged

    prediction_time = {}
    for system_name, times in prediction_time_accum.items():
        times_arr = np.array(filter_none(times))
        prediction_time[system_name] = np.mean(times_arr) if len(times_arr) > 0 else None

    return {
        "avg_hellinger": average(avg_hellinger_accum),
        "kld": average(kld_accum),
        "max_lyap": average(max_lyap_accum),
        "gpdim": average(gpdim_accum),
        "prediction_time": prediction_time,
    }


# Aggregate the saved metrics of each of the three models (panda, chronos_sft, chronos_zs).
print("Accumulating panda metrics...")
panda_metrics = accumulate_metrics(
    metrics_fnames=panda_metrics_fnames,
    metrics_save_dir=panda_metrics_save_dir,
)
print("Accumulating chronos_sft metrics...")
chronos_sft_metrics = accumulate_metrics(
    metrics_fnames=chronos_sft_metrics_fnames,
    metrics_save_dir=chronos_sft_metrics_save_dir,
)
print("Accumulating chronos_zs metrics...")
chronos_zs_metrics = accumulate_metrics(
    metrics_fnames=chronos_zs_metrics_fnames,
    metrics_save_dir=chronos_zs_metrics_save_dir,
)

In [None]:
# Inspect the metric groups returned by accumulate_metrics for panda.
panda_metrics.keys()

In [None]:
# Inspect the trajectory variants under which max Lyapunov estimates are stored
# (accumulate_metrics uses "gt", "pred", "gt_with_context", "pred_with_context", "full").
panda_metrics["max_lyap"].keys()

In [None]:
models = ["panda", "chronos_sft", "chronos_zs"]
# Explicit model-name -> accumulated-results mapping. This replaces the previous
# eval(f"{m}_metrics") lookup, which resolved variables from a constructed string.
_metrics_by_model = {
    "panda": panda_metrics,
    "chronos_sft": chronos_sft_metrics,
    "chronos_zs": chronos_zs_metrics,
}
# metrics[model][metric_group] -> the object returned by accumulate_metrics for that group.
metrics = {
    m: {k: _metrics_by_model[m][k] for k in ["avg_hellinger", "kld", "gpdim", "max_lyap", "prediction_time"]}
    for m in models
}

### Inference Time Comparison

In [None]:
# NOTE: the first prediction takes much longer than the rest (probably start-up overhead),
# so the first system is excluded from the timing comparison.
# The first system is read from the untouched `panda_metrics`, and filtered copies are
# stored instead of popping in place. Re-running this cell is therefore a no-op rather
# than dropping one more system each time, and a model that lacks the system no longer
# raises KeyError.
first_system = next(iter(panda_metrics["prediction_time"]), None)
for _model_name in ["panda", "chronos_sft", "chronos_zs"]:
    metrics[_model_name]["prediction_time"] = {
        _system_name: _pred_time
        for _system_name, _pred_time in metrics[_model_name]["prediction_time"].items()
        if _system_name != first_system
    }
    print(metrics[_model_name]["prediction_time"])

In [None]:
# Print the mean and standard deviation of the per-system prediction time for each model.
for model_name in ["panda", "chronos_sft", "chronos_zs"]:
    # accumulate_metrics stores None for a system without any recorded time;
    # np.mean/np.std raise TypeError on None, so drop those entries first.
    prediction_times = [t for t in metrics[model_name]["prediction_time"].values() if t is not None]
    prediction_time_mean = np.mean(prediction_times) if prediction_times else float("nan")
    prediction_time_std = np.std(prediction_times) if prediction_times else float("nan")
    print(f"{model_name} prediction time mean:", prediction_time_mean)
    print(f"{model_name} prediction time std:", prediction_time_std)

## Max Lyapunov Exponent Comparison

In [None]:
# Inspect which trajectory variants are available for the max Lyapunov comparison.
metrics["panda"]["max_lyap"].keys()

In [None]:
# Scatter the predicted max Lyapunov exponent against the reference value for one model
# and one prediction length, report R^2, and overlay y=x plus a least-squares fit.
pred_length = "128"  # prediction-interval key (stored as a string in the metrics dicts)
model_type = "panda"
use_full_traj_gt = True
use_context_with_preds = True
show_figure = True

_model_titles = {"chronos_zs": "Chronos", "chronos_sft": "Chronos SFT", "panda": "Panda"}
if model_type not in _model_titles:
    raise ValueError(f"Invalid model type: {model_type}")
model_type_title = _model_titles[model_type]

# Reference (x axis) and prediction (y axis) variants of the max Lyapunov estimate.
if use_full_traj_gt:
    gt_key = "full"
    gt_key_name = "Full Trajectory"
else:
    gt_key = "gt"
    gt_key_name = "Ground Truth"

if use_context_with_preds:
    pred_key = "pred_with_context"
    pred_key_name = "Context + Prediction"
else:
    pred_key = "pred"
    pred_key_name = "Prediction"

gt_dict = metrics[model_type]["max_lyap"][gt_key].get(pred_length, {})
pred_dict = metrics[model_type]["max_lyap"][pred_key].get(pred_length, {})

# Systems present in both dicts; sorted so the pairing order is reproducible across kernels.
system_names = sorted(set(gt_dict.keys()) & set(pred_dict.keys()))

x_raw = [gt_dict[name] for name in system_names]
y_raw = [pred_dict[name] for name in system_names]

# Keep only pairs where both values are real, finite numbers. The averaged metrics can be
# None (no valid value for a system), and np.isfinite(None) raises TypeError, so None is
# checked explicitly first. np.isfinite already rejects NaN.
x = []
y = []
num_invalid = 0
for xi, yi in zip(x_raw, y_raw):
    if xi is not None and yi is not None and np.isfinite(xi) and np.isfinite(yi):
        x.append(xi)
        y.append(yi)
    else:
        num_invalid += 1

# r2_score(y_true, y_pred): the reference values are the "true" targets.
r2 = r2_score(x, y) if x else float("nan")

print(f"Filtered out {num_invalid} invalid (nan/inf) pairs from {len(x_raw)} total.")

print(f"{model_type_title}: {pred_key_name} vs {gt_key_name} at L_pred={pred_length}, R^2={r2:.3f}")

if show_figure:
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(x, y, color="black", s=5, alpha=0.1, label=None)
    ax.set_xlabel(gt_key_name, fontweight="bold")
    ax.set_ylabel(pred_key_name, fontweight="bold")
    ax.set_title(
        rf"{model_type_title} $\lambda_{{\max}}$ ($L_{{\mathrm{{pred}}}} = {pred_length}$)",
        fontweight="bold",
    )

    # Extent of the identity line: the joint range of the data, or [0, 1] when empty.
    if x:
        y_eq_x_min = min(x + y)
        y_eq_x_max = max(x + y)
    else:
        y_eq_x_min = 0
        y_eq_x_max = 1
    (h1,) = ax.plot([y_eq_x_min, y_eq_x_max], [y_eq_x_min, y_eq_x_max], "r--", label=r"$y=x$")

    handles = [h1]
    labels = [r"$y=x$"]

    # A least-squares line needs at least two points.
    if len(x) > 1:
        m, b = np.polyfit(x, y, 1)
        x_fit = np.array([y_eq_x_min, y_eq_x_max])
        y_fit = m * x_fit + b
        if abs(b) < 1e-10:
            eqn_str = rf"$y = {m:.2f}x$"
        else:
            sign = "+" if b >= 0 else "-"
            eqn_str = rf"$y = {m:.2f}x {sign} {abs(b):.2f}$"
        if not np.isnan(r2):
            eqn_r2_label = rf"{eqn_str}  $(R^2 = {r2:.3f})$"
        else:
            eqn_r2_label = eqn_str
        (h2,) = ax.plot(x_fit, y_fit, color="red", linestyle="-", linewidth=1.5, label=eqn_r2_label)
        # Best-fit line goes first so that y=x appears below it in the legend.
        handles.insert(0, h2)
        labels.insert(0, eqn_r2_label)

    # Scientific notation on both axes (ticklabel_format requires a ScalarFormatter).
    ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    ax.ticklabel_format(style="sci", axis="both", scilimits=(0, 0))

    ax.legend(handles=handles, labels=labels, loc="lower right", fontsize=8, frameon=True)

    fig.tight_layout()
    # Uncomment to save the figure:
    # fig.savefig(
    #     os.path.join(
    #         "../figures",
    #         f"max_lyap_r_full_pred_{pred_length}_{model_type}.pdf",
    #     ),
    #     bbox_inches="tight",
    # )

    plt.show()

In [None]:
# Report R^2 between predicted and reference max Lyapunov exponents for every model and
# every prediction length (no figure).
pred_lengths = ["128", "512"]
model_types = ["panda", "chronos_sft", "chronos_zs"]
use_full_traj_gt = True
use_context_with_preds = True

if use_full_traj_gt:
    gt_key = "full"
    gt_key_name = "Full Trajectory"
else:
    gt_key = "gt"
    gt_key_name = "Ground Truth"

if use_context_with_preds:
    pred_key = "pred_with_context"
    pred_key_name = "Context + Prediction"
else:
    pred_key = "pred"
    pred_key_name = "Prediction"

print(f"({pred_key_name}) vs {gt_key_name}")

_model_titles = {"chronos_zs": "Chronos", "chronos_sft": "Chronos SFT", "panda": "Panda"}

for model_type in model_types:
    if model_type not in _model_titles:
        raise ValueError(f"Invalid model type: {model_type}")
    model_type_title = _model_titles[model_type]

    print(f"Model Type: {model_type}")

    for pred_length in pred_lengths:
        print(f"Prediction Length L_pred = {pred_length}")

        gt_dict = metrics[model_type]["max_lyap"][gt_key].get(pred_length, {})
        pred_dict = metrics[model_type]["max_lyap"][pred_key].get(pred_length, {})

        # Systems present in both dicts; sorted so the pairing order is reproducible.
        system_names = sorted(set(gt_dict.keys()) & set(pred_dict.keys()))

        x_raw = [gt_dict[name] for name in system_names]
        y_raw = [pred_dict[name] for name in system_names]

        # Keep only pairs of real, finite numbers. Averaged metrics can be None, and
        # np.isfinite(None) raises TypeError, so None is checked explicitly first.
        x = []
        y = []
        num_invalid = 0
        for xi, yi in zip(x_raw, y_raw):
            if xi is not None and yi is not None and np.isfinite(xi) and np.isfinite(yi):
                x.append(xi)
                y.append(yi)
            else:
                num_invalid += 1

        # r2_score(y_true, y_pred): the reference values are the "true" targets.
        r2 = r2_score(x, y) if x else float("nan")

        print(f"Filtered out {num_invalid} invalid (nan/inf) pairs from {len(x_raw)} total.")

        print(f"R^2={r2:.3f}")

In [None]:
# Load the Rosenstein max Lyapunov exponent estimates of the full trajectories.
# NOTE(review): "4096" is presumably the trajectory length used for the estimate - confirm.
full_traj_lyap_r_fpath = os.path.join(WORK_DIR, "eval_results", "dataset", "max_lyap_r_test_zeroshot.json")
# Use a context manager so the file handle is closed instead of leaked.
with open(full_traj_lyap_r_fpath) as f:
    full_traj_lyap_r_lst = json.load(f)["4096"]

In [None]:
# Map the first element of every loaded entry to its "max_lyap_rosenstein" value.
full_traj_lyap_r_dict = dict(
    (entry[0], entry[1]["max_lyap_rosenstein"]) for entry in full_traj_lyap_r_lst
)

In [None]:
# Inspect the keys of the loaded Rosenstein estimates (presumably system names - TODO confirm).
full_traj_lyap_r_dict.keys()

In [None]:
# Report R^2 between predicted and reference max Lyapunov exponents for every model and
# every prediction length, this time using the predictions alone (without context).
pred_lengths = ["128", "512"]
model_types = ["panda", "chronos_sft", "chronos_zs"]
use_full_traj_gt = True
use_context_with_preds = False

if use_full_traj_gt:
    gt_key = "full"
    gt_key_name = "Full Trajectory"
else:
    gt_key = "gt"
    gt_key_name = "Ground Truth"

if use_context_with_preds:
    pred_key = "pred_with_context"
    pred_key_name = "Context + Prediction"
else:
    pred_key = "pred"
    pred_key_name = "Prediction"

print(f"({pred_key_name}) vs {gt_key_name}")

_model_titles = {"chronos_zs": "Chronos", "chronos_sft": "Chronos SFT", "panda": "Panda"}

for model_type in model_types:
    if model_type not in _model_titles:
        raise ValueError(f"Invalid model type: {model_type}")
    model_type_title = _model_titles[model_type]

    print(f"Model Type: {model_type}")

    for pred_length in pred_lengths:
        print(f"Prediction Length L_pred = {pred_length}")

        # Alternative reference: set gt_dict = full_traj_lyap_r_dict to compare against
        # the Rosenstein estimates of the full trajectories instead.
        gt_dict = metrics[model_type]["max_lyap"][gt_key].get(pred_length, {})
        pred_dict = metrics[model_type]["max_lyap"][pred_key].get(pred_length, {})

        # Systems present in both dicts; sorted so the pairing order is reproducible.
        system_names = sorted(set(gt_dict.keys()) & set(pred_dict.keys()))

        x_raw = [gt_dict[name] for name in system_names]
        y_raw = [pred_dict[name] for name in system_names]

        # Keep only pairs of real, finite numbers. Averaged metrics can be None, and
        # np.isfinite(None) raises TypeError, so None is checked explicitly first.
        x = []
        y = []
        for xi, yi in zip(x_raw, y_raw):
            if xi is not None and yi is not None and np.isfinite(xi) and np.isfinite(yi):
                x.append(xi)
                y.append(yi)

        # r2_score(y_true, y_pred): the reference values are the "true" targets.
        # NOTE: scipy.stats.pearsonr can be swapped in here to get a Pearson correlation.
        r2 = r2_score(x, y) if x else float("nan")

        print(f"R^2={r2:.3f}")

In [None]:
# Inspect which prediction intervals are available for the chronos_sft "pred" estimates.
metrics["chronos_sft"]["max_lyap"]["pred"].keys()

In [None]:
import pandas as pd
from scipy.stats import pearsonr

# Rosenstein estimates as a plain list, in the insertion order of full_traj_lyap_r_dict.
full_lyaps = list(full_traj_lyap_r_dict.values())

In [None]:
# gt_key and pred_key are globals left over from whichever Lyapunov comparison cell ran
# last; print them to confirm which variant the following cells pick up (pred_key selects
# the per-model estimates collected into lyaps_per_model).
print(f"gt key: {gt_key}")
print(f"pred key: {pred_key}")

In [None]:
# Per-model max Lyapunov estimates for the variant currently selected by pred_key:
# model name -> prediction interval -> system name -> value.
lyaps_per_model = {
    model_name: model_metrics["max_lyap"][pred_key] for model_name, model_metrics in metrics.items()
}

In [None]:
# Pearson correlation between the full-trajectory Rosenstein exponents and each model's
# estimates, per prediction length.
results = defaultdict(dict)
for model in lyaps_per_model.keys():
    lyaps_for_model = lyaps_per_model[model]
    for pred_length in lyaps_for_model.keys():
        if pred_length == "128":  # TODO: remove when we have 128 predictions
            continue
        model_lyaps_by_system = lyaps_for_model[pred_length]

        # Pair values by system name. Pairing the two dicts by position is only correct
        # when both happen to share the same key order, which nothing guarantees.
        shared_systems = sorted(set(full_traj_lyap_r_dict) & set(model_lyaps_by_system))
        pairs = [
            (full_traj_lyap_r_dict[name], model_lyaps_by_system[name])
            for name in shared_systems
        ]
        # Drop pairs with a missing (None) or non-finite value; pearsonr cannot handle them.
        pairs = [
            (ref, est)
            for ref, est in pairs
            if ref is not None and est is not None and np.isfinite(ref) and np.isfinite(est)
        ]
        if len(pairs) != len(full_traj_lyap_r_dict):
            print(f"{model} L_pred={pred_length}: using {len(pairs)} of {len(full_traj_lyap_r_dict)} systems")

        # pearsonr needs at least two observations.
        if len(pairs) < 2:
            results[model][pred_length] = {"corr": float("nan"), "pval": float("nan")}
            continue

        ref_lyaps, model_lyaps = zip(*pairs)
        result = pearsonr(ref_lyaps, model_lyaps)
        results[model][pred_length] = {
            "corr": float(f"{result.statistic:.3f}"),
            "pval": float(f"{result.pvalue:.3e}"),
        }


pd.DataFrame(results).T

In [None]:
import pandas as pd

# R^2 between the full-trajectory Rosenstein exponents (reference) and each model's
# estimates, per prediction length.
results = defaultdict(dict)
for model in lyaps_per_model.keys():
    lyaps_for_model = lyaps_per_model[model]
    for pred_length in lyaps_for_model.keys():
        if pred_length == "128":  # TODO: remove when we have 128 predictions
            continue
        model_lyaps_by_system = lyaps_for_model[pred_length]

        # Pair values by system name. Pairing the two dicts by position is only correct
        # when both happen to share the same key order, which nothing guarantees.
        shared_systems = sorted(set(full_traj_lyap_r_dict) & set(model_lyaps_by_system))
        pairs = [
            (full_traj_lyap_r_dict[name], model_lyaps_by_system[name])
            for name in shared_systems
        ]
        # Drop pairs with a missing (None) or non-finite value; r2_score rejects them.
        pairs = [
            (ref, est)
            for ref, est in pairs
            if ref is not None and est is not None and np.isfinite(ref) and np.isfinite(est)
        ]
        if len(pairs) != len(full_traj_lyap_r_dict):
            print(f"{model} L_pred={pred_length}: using {len(pairs)} of {len(full_traj_lyap_r_dict)} systems")

        if not pairs:
            results[model][pred_length] = {"r2": float("nan")}
            continue

        ref_lyaps, model_lyaps = zip(*pairs)
        # r2_score(y_true, y_pred): the reference values are the "true" targets.
        r2 = r2_score(ref_lyaps, model_lyaps)
        results[model][pred_length] = {
            "r2": float(f"{r2:.3f}"),
        }

pd.DataFrame(results).T

## Distributional Metrics

In [None]:
def filter_nans(values):
    """Return the usable entries of ``values`` as a 1-D float array.

    ``None`` entries and ``float`` NaNs are dropped; every remaining entry is
    converted with ``float()``. ``values`` may be any iterable, for example a
    ``dict_values`` view.
    """
    kept = [
        float(v)
        for v in list(values)
        if not (v is None or (isinstance(v, float) and np.isnan(v)))
    ]
    return np.array(kept, dtype=float)

In [None]:
def plot_metric_histograms(
    metric_name,
    pred_length,
    horizon_name,
    model_labels,
    model_keys,
    bins=25,
    log_scale=False,
    xlabel=None,
    ylabel="Count",
    title=None,
    filename=None,
    alpha_val=0.6,
    legend_loc="upper right",
):
    """Overlay one histogram per model of a per-system metric.

    Values are read from the module-level ``metrics`` dict as
    ``metrics[model][metric_name][horizon_name][pred_length]``; ``None`` and
    non-finite values are dropped before plotting. All models share the same
    bin edges, computed from the pooled values.

    Args:
        metric_name: Metric group, e.g. ``"avg_hellinger"`` or ``"kld"``.
        pred_length: Prediction-interval key (a string such as ``"512"``).
        horizon_name: Trajectory variant key, e.g. ``"pred_with_context"`` or ``"full"``.
        model_labels: Legend labels, parallel to ``model_keys``.
        model_keys: Keys of ``metrics`` to plot.
        bins: Number of bins.
        log_scale: Use log-spaced bins and a log x axis; only strictly positive
            values are plotted in that case.
        xlabel: Optional x-axis label.
        ylabel: Y-axis label.
        title: Optional figure title.
        filename: When given, the figure is also saved to this path.
        alpha_val: Fill opacity of each histogram.
        legend_loc: Legend location passed to matplotlib.
    """
    # Gather the valid values of every model. NaN/inf are removed for all metrics because
    # np.histogram_bin_edges raises on a non-finite data range.
    metric_data = []
    for model in model_keys:
        values = metrics[model][metric_name][horizon_name][pred_length].values()
        arr = np.array([v for v in values if v is not None], dtype=float)
        metric_data.append(arr[np.isfinite(arr)])

    # Shared bin edges across models.
    if log_scale:
        all_vals = np.concatenate([d[d > 0] for d in metric_data])
        if len(all_vals) > 0:
            # bins + 1 edges give `bins` bins, matching the linear branch below.
            bins = np.logspace(np.log10(np.min(all_vals)), np.log10(np.max(all_vals)), bins + 1)
        else:
            print("No positive values found")
    else:
        all_vals = np.concatenate(metric_data)
        bins = np.histogram_bin_edges(all_vals, bins=bins)

    fallback_colors = ["blue", "orange", "green", "red", "purple", "brown", "pink", "gray", "olive", "cyan"]

    fig, ax = plt.subplots(figsize=(4, 4))
    for i, (arr, label) in enumerate(zip(metric_data, model_labels)):
        # Use the style's color cycle; fall back to the tab palette when it runs out.
        color = DEFAULT_COLORS[i] if i < len(DEFAULT_COLORS) else f"tab:{fallback_colors[i % 10]}"
        arr_plot = arr[arr > 0] if log_scale else arr
        ax.hist(
            arr_plot,
            bins=bins,
            color=color,
            edgecolor=color,
            alpha=alpha_val,
            histtype="stepfilled",
            label=label,
        )
    if log_scale:
        ax.set_xscale("log")
    if xlabel:
        ax.set_xlabel(xlabel, fontweight="bold")
    ax.set_ylabel(ylabel, fontweight="bold")
    ax.legend(loc=legend_loc)
    if title:
        ax.set_title(title, fontweight="bold")
    fig.tight_layout()
    if filename:
        fig.savefig(filename, bbox_inches="tight")
    plt.show()


# Plot the distribution of the average Hellinger distance across systems for each model.
plot_metric_histograms(
    metric_name="avg_hellinger",
    pred_length="512",
    # NOTE(review): accumulate_metrics fills this variant from the "full_trajectory" section
    # of each entry - confirm that it is the gt_with_context vs pred_with_context distance.
    horizon_name="pred_with_context",  # between gt_with_context and pred_with_context
    model_labels=["Panda", "Chronos 20M SFT", "Chronos 20M"],
    model_keys=["panda", "chronos_sft", "chronos_zs"],
    bins=25,
    log_scale=False,
    # xlabel="Average Hellinger",
    ylabel="Count",
    title="Avg Hellinger Distance ($L_{\\mathrm{pred}} = 512$)",
    filename=os.path.join(
        "../figures",
        "avg_hellinger_distribution_pred_with_context_512.pdf",
    ),
    alpha_val=0.6,
    legend_loc="upper right",
)

# Plot the distribution of the KL divergence across systems for each model (log-spaced bins).
# NOTE(review): the output file name says "pred_with_context" while horizon_name is "full";
# left unchanged in case the figure is referenced elsewhere - confirm which one is intended.
plot_metric_histograms(
    metric_name="kld",
    pred_length="128",
    horizon_name="full",  # between the full trajectory and the predictions
    model_labels=["Panda", "Chronos 20M SFT", "Chronos 20M"],
    model_keys=["panda", "chronos_sft", "chronos_zs"],
    bins=20,
    log_scale=True,
    # xlabel="Average KL Divergence",
    ylabel="Count",
    title="KL Divergence ($L_{\\mathrm{pred}} = 128$)",
    filename=os.path.join(
        "../figures",
        "kld_distribution_pred_with_context_128_log.pdf",
    ),
    alpha_val=0.6,
    legend_loc="upper left",
)