In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from panda.utils.plot_utils import apply_custom_style

apply_custom_style("../../config/plotting.yaml")

In [None]:
# Directory that the plotting cells below save their figures into.
fig_save_dir = os.path.join("../../figures", "eval_metrics")
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Colors of the active matplotlib property cycle; the histogram cells index into this list per model.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root working directory, read from the WORK environment variable.
# Falls back to "" when WORK is unset, in which case DATA_DIR is the relative path "data".
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Toggle between the deterministic and the non-deterministic Chronos results.
use_chronos_deterministic = False
# Directory name matching the toggle; used as a path component of the Chronos metrics directories.
chronos_dirname = "chronos" if use_chronos_deterministic else "chronos_nondeterministic"


def get_sorted_metric_fnames(save_dir):
    """List the distributional-metrics JSON files in ``save_dir``, ordered by window index.

    Args:
        save_dir: Directory to scan (non-recursively).

    Returns:
        File names (not full paths) that end in ".json" and contain
        "distributional_metrics", sorted by the integer following "window-" in the
        name. Names without a "window-<int>" part sort last; ties are broken by
        file name so the order does not depend on the directory listing order.
    """

    def extract_window(fname):
        m = re.search(r"window-(\d+)", fname)
        return int(m.group(1)) if m else float("inf")

    # The substring test needs an explicit `in f`: a bare non-empty string literal is
    # always truthy, which would let every .json file through.
    fnames = [f for f in os.listdir(save_dir) if f.endswith(".json") and "distributional_metrics" in f]
    return sorted(fnames, key=lambda f: (extract_window(f), f))


# Metrics directories of the three evaluated models, all rooted at WORK_DIR.
# os.path.join is used instead of f"{WORK_DIR}/..." so that an empty WORK_DIR yields a
# relative path rather than one that starts at the filesystem root.
eval_results_root = os.path.join(WORK_DIR, "eval_results_old")
panda_metrics_save_dir = os.path.join(
    eval_results_root, "panda/pft_chattn_emb_w_poly-0/test_zeroshot/metrics_run1"
)
# NOTE: the Chronos results exist in a deterministic ("chronos") and a non-deterministic
# ("chronos_nondeterministic") variant; chronos_dirname selects which one is read.
chronos_sft_metrics_save_dir = os.path.join(
    eval_results_root, chronos_dirname, "chronos_t5_mini_ft-0/test_zeroshot/metrics_run1"
)
chronos_zs_metrics_save_dir = os.path.join(
    eval_results_root, chronos_dirname, "chronos_mini_zeroshot/test_zeroshot/metrics_run1"
)

# Metrics file names per model, sorted by window index.
panda_metrics_fnames = get_sorted_metric_fnames(panda_metrics_save_dir)
chronos_sft_metrics_fnames = get_sorted_metric_fnames(chronos_sft_metrics_save_dir)
chronos_zs_metrics_fnames = get_sorted_metric_fnames(chronos_zs_metrics_save_dir)

print(f"Found {len(panda_metrics_fnames)} panda metrics files: {panda_metrics_fnames}")
print(f"Found {len(chronos_sft_metrics_fnames)} chronos sft metrics files: {chronos_sft_metrics_fnames}")
print(f"Found {len(chronos_zs_metrics_fnames)} chronos zs metrics files: {chronos_zs_metrics_fnames}")

In [None]:
# Peek at the first chronos_sft metrics file to inspect its top-level keys.
# The file name is taken from the listing of the same directory it is joined with
# (previously a chronos_zs file name was combined with the chronos_sft directory).
metrics_fpath = os.path.join(chronos_sft_metrics_save_dir, chronos_sft_metrics_fnames[0])
with open(metrics_fpath, "rb") as f:
    metrics = json.load(f)

print(metrics.keys())

In [None]:
# Helpers for accumulating metric values across all result files (used for the panda, chronos_sft and chronos_zs metrics)


def filter_none(values):
    """Return a new list holding every element of ``values`` that is not None."""
    kept = []
    for item in values:
        if item is not None:
            kept.append(item)
    return kept


def accumulate_metrics(metrics_fnames, metrics_save_dir):
    """Load every metrics file and average each metric per system across files.

    Args:
        metrics_fnames: Names of the JSON metrics files to read.
        metrics_save_dir: Directory that contains those files.

    Returns:
        A dict with three entries:
          * "avg_hellinger" and "kld", each shaped as
            ``{"pred_horizon" | "full": {pred_interval: {system_name: mean}}}``.
          * "prediction_time", shaped as ``{system_name: mean}`` and pooled over
            all files and all prediction intervals.
        Every mean is a Python float, or None when all collected values were None.
    """
    avg_hellinger_accum = {
        "pred_horizon": defaultdict(lambda: defaultdict(list)),
        "full": defaultdict(lambda: defaultdict(list)),
    }
    kld_accum = {
        "pred_horizon": defaultdict(lambda: defaultdict(list)),
        "full": defaultdict(lambda: defaultdict(list)),
    }
    prediction_time_accum = defaultdict(list)

    for fname in metrics_fnames:
        with open(os.path.join(metrics_save_dir, fname), "rb") as f:
            metrics = json.load(f)
        n_pred_intervals = len(metrics)
        print(f"number of prediction intervals in {fname}: {n_pred_intervals}")
        for pred_interval in metrics:
            print(pred_interval)
            data = metrics[pred_interval]
            # Accept both a {system_name: entry} mapping and a list of [system_name, entry]
            # pairs; iterating a mapping directly would yield only its keys and break the
            # two-name unpacking below.
            system_entries = list(data.items()) if isinstance(data, dict) else data
            for system_name, system_entry in tqdm(system_entries, desc=f"Processing {pred_interval}"):
                avg_hellinger_accum["pred_horizon"][pred_interval][system_name].append(
                    system_entry["prediction_horizon"]["avg_hellinger_distance"]
                )
                avg_hellinger_accum["full"][pred_interval][system_name].append(
                    system_entry["full_trajectory"]["avg_hellinger_distance"]
                )
                kld_accum["pred_horizon"][pred_interval][system_name].append(
                    system_entry["prediction_horizon"]["kl_divergence"]
                )
                kld_accum["full"][pred_interval][system_name].append(system_entry["full_trajectory"]["kl_divergence"])
                prediction_time_accum[system_name].append(system_entry["prediction_time"])

    def mean_or_none(values):
        # Mean over the non-None values; None when nothing is left to average.
        kept = filter_none(values)
        return float(np.mean(kept)) if kept else None

    # Now, take the mean across all files for each metric, skipping None values
    avg_hellinger = {
        "pred_horizon": defaultdict(dict),
        "full": defaultdict(dict),
    }
    kld = {
        "pred_horizon": defaultdict(dict),
        "full": defaultdict(dict),
    }

    for key in ["pred_horizon", "full"]:
        for pred_interval, per_system in avg_hellinger_accum[key].items():
            for system_name, values in per_system.items():
                avg_hellinger[key][pred_interval][system_name] = mean_or_none(values)
        for pred_interval, per_system in kld_accum[key].items():
            for system_name, values in per_system.items():
                kld[key][pred_interval][system_name] = mean_or_none(values)

    prediction_time = {system_name: mean_or_none(times) for system_name, times in prediction_time_accum.items()}

    return {
        "avg_hellinger": avg_hellinger,
        "kld": kld,
        "prediction_time": prediction_time,
    }


# Accumulate metrics for panda, chronos_sft and chronos_zs.
# Each run gets its own progress message so the per-file output printed by
# accumulate_metrics can be attributed to the right model.
print("Accumulating panda metrics...")
panda_metrics = accumulate_metrics(panda_metrics_fnames, panda_metrics_save_dir)
print("Accumulating chronos_sft metrics...")
chronos_sft_metrics = accumulate_metrics(chronos_sft_metrics_fnames, chronos_sft_metrics_save_dir)
print("Accumulating chronos_zs metrics...")
chronos_zs_metrics = accumulate_metrics(chronos_zs_metrics_fnames, chronos_zs_metrics_save_dir)

In [None]:
# Model keys, in the order they are reported throughout the notebook.
models = ["panda", "chronos_sft", "chronos_zs"]
# Explicit lookup table instead of eval(f"{m}_metrics"): same result, but no string
# evaluation, and a missing variable fails on this line with a plain NameError.
metrics_by_model = {
    "panda": panda_metrics,
    "chronos_sft": chronos_sft_metrics,
    "chronos_zs": chronos_zs_metrics,
}
# Re-index the accumulated results as metrics[metric_name][model_key].
metrics = {
    k: {m: metrics_by_model[m][k] for m in models} for k in ["avg_hellinger", "kld", "prediction_time"]
}

In [None]:
# Show the top-level keys of the accumulated panda Hellinger metrics.
metrics["avg_hellinger"]["panda"].keys()

In [None]:
# Prediction length to plot; a string because it is used as a key into the metrics dicts.
pred_length = "512"
# Horizon whose metrics are plotted; used as a key into metrics["avg_hellinger"][model].
horizon_name = "pred_horizon"

# When False, the "Chronos 20M" histogram is left out of the plot below.
show_chronos_zs = False


def filter_nans(values):
    """Convert ``values`` to a 1-D float array, dropping None and NaN entries.

    NaN is detected after the conversion to float, so NaNs of any numeric type
    (e.g. ``np.float32``) are removed, not only built-in ``float`` NaNs.

    Args:
        values: Iterable of numbers and/or None.

    Returns:
        ``np.ndarray`` of dtype float holding the remaining values in input order.
    """
    arr = np.array([float(v) for v in values if v is not None], dtype=float)
    return arr[~np.isnan(arr)]


# Hellinger distances per model for the selected horizon and prediction length,
# keyed by the label shown in the legend (cleaned by filter_nans).
avg_hellinger = {
    legend_label: filter_nans(metrics["avg_hellinger"][model_key][horizon_name][pred_length].values())
    for legend_label, model_key in [
        ("Panda", "panda"),
        ("Chronos 20M SFT", "chronos_sft"),
        ("Chronos 20M", "chronos_zs"),
    ]
}

num_bins = 50
alpha_val = 0.6
fig, ax = plt.subplots(figsize=(4, 4))

# Bin edges are computed over all three models (including a hidden one) so that
# every histogram shares the same bins.
all_hellinger = np.concatenate(list(avg_hellinger.values()))
print(f"min hellinger: {all_hellinger.min()}, max hellinger: {all_hellinger.max()}")
bins = np.histogram_bin_edges(all_hellinger, bins=num_bins)

for i, (label, vals) in enumerate(avg_hellinger.items()):
    if label == "Chronos 20M" and not show_chronos_zs:
        continue
    # The color is tied to the model's position in the dict, so it stays the same
    # whether or not another model is hidden.
    if i < len(DEFAULT_COLORS):
        color = DEFAULT_COLORS[i]
    else:
        color = f"tab:{['blue', 'orange', 'green'][i % 3]}"
    ax.hist(
        vals,
        bins=bins,
        histtype="stepfilled",
        color=color,
        edgecolor=color,
        alpha=alpha_val,
        label=label,
        zorder=10 - i,  # earlier models are drawn on top
    )

ax.set_ylabel("Count", fontweight="bold")
ax.legend(loc="upper right")
ax.set_title(f"Avg Hellinger Distance ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
fig.tight_layout()
fig.savefig(
    os.path.join(fig_save_dir, f"avg_hellinger_distribution_{horizon_name}_{pred_length}.pdf"),
    bbox_inches="tight",
)
plt.show()

In [None]:
# Show the top-level keys of the accumulated panda KL-divergence metrics.
metrics["kld"]["panda"].keys()

In [None]:
# Prediction length to plot; a string because it is used as a key into the metrics dicts.
pred_length = "512"
# Horizon whose metrics are plotted; used as a key into metrics["kld"][model].
horizon_name = "full"

# When False, the "Chronos 20M" histogram is left out of the plot below.
show_chronos_zs = False


# Keep only the strictly positive KL divergence values
def pos_vals(vals):
    """Return the strictly positive entries of ``vals`` as a list.

    None entries are skipped instead of raising ``TypeError`` in the comparison.
    NaN entries are dropped as well, because ``nan > 0`` evaluates to False.
    """
    return [x for x in vals if x is not None and x > 0]


# Positive KL divergences per model for the selected horizon and prediction length,
# keyed by the label shown in the legend.
kld_dict = {
    "Panda": pos_vals(metrics["kld"]["panda"][horizon_name][pred_length].values()),
    "Chronos 20M SFT": pos_vals(metrics["kld"]["chronos_sft"][horizon_name][pred_length].values()),
    "Chronos 20M": pos_vals(metrics["kld"]["chronos_zs"][horizon_name][pred_length].values()),
}

all_kld_pos = np.concatenate(list(kld_dict.values()))
num_bins = 50
if len(all_kld_pos) > 0:
    # histogram_bin_edges returns num_bins + 1 edges, i.e. exactly num_bins bins shared by
    # all models; np.linspace(min, max, num_bins) gave one bin fewer than requested.
    bins = np.histogram_bin_edges(all_kld_pos, bins=num_bins)
else:
    # Nothing to derive edges from: let matplotlib pick num_bins bins itself.
    bins = num_bins
    print("No positive values found")

plt.figure(figsize=(4, 4))
alpha_val = 0.6
for i, (label, vals) in enumerate(kld_dict.items()):
    if not show_chronos_zs and label == "Chronos 20M":
        continue
    # The color is tied to the model's position in the dict, so it stays the same
    # whether or not another model is hidden.
    color = DEFAULT_COLORS[i] if i < len(DEFAULT_COLORS) else f"tab:{['blue', 'orange', 'green'][i % 3]}"
    plt.hist(
        vals,
        bins=bins,
        color=color,
        edgecolor=color,
        alpha=alpha_val,
        histtype="stepfilled",
        label=label,
        zorder=10 - i,  # earlier models are drawn on top
    )
plt.yscale("log")
plt.ylabel("Count", fontweight="bold")
plt.legend(loc="upper right")
plt.title(f"KL Divergence ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
plt.tight_layout()
plt.savefig(os.path.join(fig_save_dir, f"kld_distribution_{horizon_name}_{pred_length}_log.pdf"), bbox_inches="tight")
plt.show()

In [None]:
pred_length = "512"
horizon_name = "full"

# Align the three models on system name before comparing them: differencing the raw
# .values() lists is only correct when every model holds the same systems in the same
# order. Systems are kept in panda's order and only if all three models have them.
kld_by_model = {m: metrics["kld"][m][horizon_name][pred_length] for m in ["panda", "chronos_sft", "chronos_zs"]}
shared_systems = [
    name for name in kld_by_model["panda"] if all(name in per_system for per_system in kld_by_model.values())
]
full_kld_panda = [kld_by_model["panda"][name] for name in shared_systems]
full_kld_chronos_sft = [kld_by_model["chronos_sft"][name] for name in shared_systems]
full_kld_chronos_zs = [kld_by_model["chronos_zs"][name] for name in shared_systems]

# Difference between Chronos SFT and Panda KL divergences, per system.
# dtype=float maps None entries to NaN so the subtraction cannot fail on an object array.
kld_diff = np.array(full_kld_chronos_sft, dtype=float) - np.array(full_kld_panda, dtype=float)

plt.figure(figsize=(4, 4))
# Only finite differences are binned: the automatic bin range cannot be derived from NaN.
plt.hist(
    kld_diff[np.isfinite(kld_diff)], bins=30, color="gray", edgecolor="black", alpha=0.7, histtype="stepfilled"
)
plt.axvline(0, color="k", linestyle="dotted", linewidth=1.5)
plt.xlabel("$D_{{KL}}$ (Chronos SFT - Panda)", fontweight="bold")
plt.ylabel("Count", fontweight="bold")
plt.title(f"Difference in $D_{{KL}}$ ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
plt.yscale("log")
# tight_layout runs after the scale change so it accounts for the final tick labels.
plt.tight_layout()
plt.show()

In [None]:
# Mean and standard deviation of the per-system KL difference (Chronos SFT - Panda).
# dtype=float maps None entries to NaN, and nanmean/nanstd skip NaN, so one missing
# system neither raises on an object array nor turns the whole summary into NaN.
kld_diff_values = np.array(full_kld_chronos_sft, dtype=float) - np.array(full_kld_panda, dtype=float)
mean_kld_diff = np.nanmean(kld_diff_values)
std_kld_diff = np.nanstd(kld_diff_values)
print(f"Mean KL diff: {mean_kld_diff:.4f}, Std KL diff: {std_kld_diff:.4f}")

In [None]:
# Mean and standard deviation of the per-system KL difference (Chronos ZS - Chronos SFT).
# dtype=float maps None entries to NaN, and nanmean/nanstd skip NaN, so one missing
# system neither raises on an object array nor turns the whole summary into NaN.
kld_diff_values = np.array(full_kld_chronos_zs, dtype=float) - np.array(full_kld_chronos_sft, dtype=float)
mean_kld_diff = np.nanmean(kld_diff_values)
std_kld_diff = np.nanstd(kld_diff_values)
print(f"Mean KL diff: {mean_kld_diff:.4f}, Std KL diff: {std_kld_diff:.4f}")

In [None]:
# Mean and standard deviation of the per-system KL difference (Chronos ZS - Panda).
# dtype=float maps None entries to NaN, and nanmean/nanstd skip NaN, so one missing
# system neither raises on an object array nor turns the whole summary into NaN.
kld_diff_values = np.array(full_kld_chronos_zs, dtype=float) - np.array(full_kld_panda, dtype=float)
mean_kld_diff = np.nanmean(kld_diff_values)
std_kld_diff = np.nanstd(kld_diff_values)
print(f"Mean KL diff: {mean_kld_diff:.4f}, Std KL diff: {std_kld_diff:.4f}")

In [None]:
pred_lengths = ["128", "256", "512"]
horizon_name = "full"
# (label, minuend model key, subtrahend model key) for every pairwise comparison.
pairs = [
    ("Chronos SFT - Panda", "chronos_sft", "panda"),
    ("Chronos ZS - Chronos SFT", "chronos_zs", "chronos_sft"),
    ("Chronos ZS - Panda", "chronos_zs", "panda"),
]
# Every model that appears on either side of a pair, so no key has to be added by hand.
models_in_pairs = sorted({key for _, first, second in pairs for key in (first, second)})
for pred_length in pred_lengths:
    # Header line so the three blocks of output can be told apart.
    print(f"Prediction length {pred_length}:")
    # dtype=float maps None entries to NaN instead of producing an object array.
    klds = {
        key: np.array(list(metrics["kld"][key][horizon_name][pred_length].values()), dtype=float)
        for key in models_in_pairs
    }
    for label, key1, key2 in pairs:
        diff = klds[key1] - klds[key2]
        # NaN-aware statistics, in line with the NaN filtering of the per-model summaries.
        print(f"Mean KL diff ({label}): {np.nanmean(diff):.4f}, Std KL diff: {np.nanstd(diff):.4f}")

In [None]:
pred_lengths = ["128", "256", "512"]
horizon_name = "pred_horizon"
# Every model that appears on either side of a pair (`pairs` comes from the KL cell above).
models_in_pairs = sorted({key for _, first, second in pairs for key in (first, second)})
for pred_length in pred_lengths:
    # Header line so the three blocks of output can be told apart.
    print(f"Prediction length {pred_length}:")
    # dtype=float maps None entries to NaN instead of producing an object array.
    hells = {
        key: np.array(list(metrics["avg_hellinger"][key][horizon_name][pred_length].values()), dtype=float)
        for key in models_in_pairs
    }
    for label, key1, key2 in pairs:
        diff = hells[key1] - hells[key2]
        # NaN-aware statistics so a single missing system does not turn the summary into NaN.
        print(
            f"Mean avg_hellinger diff ({label}): {np.nanmean(diff):.2f}, "
            f"Std avg_hellinger diff: {np.nanstd(diff):.2f}"
        )

### Summarize

In [None]:
# Per-model mean/std of the KL divergence over systems, for the prediction horizon only.
horizon_name = "pred_horizon"
print(f"horizon_name: {horizon_name}")
print("KL Divergence")
for model_key in ["panda", "chronos_sft", "chronos_zs"]:
    print(f"Model: {model_key}")
    for pred_length in pred_lengths:
        # dtype=float converts None entries to NaN; without it a None makes the array
        # object-typed and np.isnan raises TypeError.
        kld_values = np.array(list(metrics["kld"][model_key][horizon_name][pred_length].values()), dtype=float)
        kld_values_no_nan = kld_values[~np.isnan(kld_values)]
        mean_kld = kld_values_no_nan.mean()
        std_kld = kld_values_no_nan.std()
        print(f"  Prediction length {pred_length}: mean kld = {mean_kld:.4f}, std kld = {std_kld:.4f}")

In [None]:
# Per-model mean/std of the KL divergence over systems, for the full trajectory.
horizon_name = "full"
print(f"horizon_name: {horizon_name}")
print("KL Divergence")
for model_key in ["panda", "chronos_sft", "chronos_zs"]:
    print(f"Model: {model_key}")
    for pred_length in pred_lengths:
        # dtype=float converts None entries to NaN; without it a None makes the array
        # object-typed and np.isnan raises TypeError.
        kld_values = np.array(list(metrics["kld"][model_key][horizon_name][pred_length].values()), dtype=float)
        kld_values_no_nan = kld_values[~np.isnan(kld_values)]
        mean_kld = kld_values_no_nan.mean()
        std_kld = kld_values_no_nan.std()
        print(f"  Prediction length {pred_length}: mean kld = {mean_kld:.4f}, std kld = {std_kld:.4f}")