In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from panda.utils.plot_utils import apply_custom_style

# Load the project plotting style from the YAML config
# (presumably updates matplotlib rcParams — see panda.utils.plot_utils).
apply_custom_style("../../config/plotting.yaml")

In [None]:
# Directory that receives the figures saved by this notebook; created if missing.
fig_save_dir = os.path.join("../../figures", "eval_metrics")
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Colors of the active matplotlib property cycle; the histogram cells below index into this list.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root directory comes from the WORK environment variable; when it is unset the
# default "" makes the paths below relative to the current working directory.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")
# Directory containing the distributional evaluation results read by this notebook.
eval_results_dir = os.path.join(WORK_DIR, "eval_results_distributional")
# Name of the data-split subdirectory inside each model's results directory.
data_split = "test_zeroshot"
# Optional run subdirectory below the split directory; None means no extra subdirectory.
run_name = None

In [None]:
# Select which Chronos results directory to read: deterministic or non-deterministic runs.
use_chronos_deterministic = False
if use_chronos_deterministic:
    chronos_dirname = "chronos"
else:
    chronos_dirname = "chronos_nondeterministic"

print(f"Using {chronos_dirname} for chronos metrics")


def get_sorted_metric_fnames(save_dir):
    """Return the distributional-metrics JSON file names in ``save_dir``, ordered by window index.

    A file is selected when its name ends with ``.json`` and contains
    ``distributional_metrics``. Files are ordered by the integer that follows
    ``window-`` in the name; files without such a marker are placed last.
    Ties are broken by file name so the result does not depend on the
    arbitrary order in which ``os.listdir`` returns entries.

    Args:
        save_dir: Directory to scan (not recursive).

    Returns:
        List of file names (not full paths).

    Raises:
        FileNotFoundError: If ``save_dir`` does not exist (propagated from ``os.listdir``).
    """
    window_pattern = re.compile(r"window-(\d+)")

    def sort_key(fname):
        match = window_pattern.search(fname)
        window = int(match.group(1)) if match else float("inf")
        # The file name is the tie-breaker that makes the ordering deterministic.
        return (window, fname)

    return sorted(
        (f for f in os.listdir(save_dir) if f.endswith(".json") and "distributional_metrics" in f),
        key=sort_key,
    )


# With no run name the results are read directly from the data-split directory.
run_suffix = run_name if run_name else ""

# Display name of each model -> directory holding its distributional metrics files.
metrics_save_dirs = {
    "Panda": os.path.join(eval_results_dir, "panda", "panda-21M", data_split, run_suffix),
    "Chronos 20M SFT": os.path.join(eval_results_dir, chronos_dirname, "chronos_t5_mini_ft-0", data_split, run_suffix),
    "Chronos 20M": os.path.join(eval_results_dir, chronos_dirname, "chronos_mini_zeroshot", data_split, run_suffix),
    "Dynamix": os.path.join(eval_results_dir, "dynamix", "dynamix", data_split, run_suffix),
}

# Display name of each model -> window-sorted list of its metrics file names.
metrics_fnames = {}
for model_name, save_dir in metrics_save_dirs.items():
    print(f"Loading {model_name} metrics from: {save_dir}")
    fnames = get_sorted_metric_fnames(save_dir)
    metrics_fnames[model_name] = fnames
    # Report this model's own files, not the size/content of the whole dictionary.
    print(f"Found {len(fnames)} {model_name} metrics files: {fnames}")

In [None]:
# Load the first Panda metrics file to inspect the on-disk format
example_fname = metrics_fnames["Panda"][0]
metrics_fpath = os.path.join(metrics_save_dirs["Panda"], example_fname)
with open(metrics_fpath, "rb") as f:
    # JSON object keys are always strings, so re-key the top level by integer
    metrics = {int(key): entry for key, entry in json.load(f).items()}

print(metrics.keys())

In [None]:
# Show the first entry stored under key 1024 of the example metrics file.
metrics[1024][0]

In [None]:
def accumulate_metrics(metrics_fnames, metrics_save_dir):
    """Average distributional metrics per system across several metrics files.

    Each file is a JSON object mapping a prediction interval (string key, converted
    to ``int``) to an iterable of ``(system_name, system_entry)`` pairs. A
    ``system_entry`` may hold one sub-dictionary per horizon
    (``"prediction_horizon"``, ``"full_trajectory"``) containing
    ``"avg_hellinger_distance"`` and ``"kl_divergence"``, plus an optional
    ``"prediction_time"``.

    Args:
        metrics_fnames: File names to read, relative to ``metrics_save_dir``.
        metrics_save_dir: Directory containing the files.

    Returns:
        dict with three keys:
        ``"avg_hellinger"`` and ``"kld"`` map
        ``horizon -> pred_interval -> system_name -> mean``;
        ``"prediction_time"`` maps ``system_name -> mean``, pooled over all files
        and prediction intervals.
        Every mean is a Python ``float``, or ``None`` when no non-``None`` value was
        recorded. NaN values are not filtered and propagate into the mean.

    Raises:
        KeyError: If a horizon present in an entry lacks one of the two metrics.
    """
    HORIZONS = ["prediction_horizon", "full_trajectory"]
    METRICS = ["avg_hellinger_distance", "kl_divergence"]

    def mean_or_none(values):
        """Mean of the non-None values, or None when nothing is left after filtering."""
        kept = [v for v in values if v is not None]
        # Guarding on the *filtered* list avoids np.mean([]) -> NaN + RuntimeWarning.
        return float(np.mean(kept)) if kept else None

    # accum[metric][horizon][pred_interval][system_name] -> list of values, one per file
    accum = {metric: {horizon: defaultdict(lambda: defaultdict(list)) for horizon in HORIZONS} for metric in METRICS}
    prediction_time_accum = defaultdict(list)

    # Accumulate values from all files
    for fname in metrics_fnames:
        with open(os.path.join(metrics_save_dir, fname), "rb") as f:
            metrics = json.load(f)
        metrics = {int(k) if isinstance(k, str) else k: v for k, v in metrics.items()}

        print(f"Processing {fname}: {len(metrics)} prediction interval(s)")

        for pred_interval, data in metrics.items():
            for system_name, system_entry in tqdm(data, desc=f"Interval {pred_interval}"):
                # A horizon may be absent from an entry; only present horizons are recorded
                for horizon in HORIZONS:
                    if horizon in system_entry:
                        for metric in METRICS:
                            accum[metric][horizon][pred_interval][system_name].append(system_entry[horizon][metric])

                # Prediction time is pooled per system, regardless of prediction interval
                if "prediction_time" in system_entry:
                    prediction_time_accum[system_name].append(system_entry["prediction_time"])

    def compute_means(data_accum):
        """Collapse the per-file value lists of one metric into per-system means."""
        result = {horizon: defaultdict(dict) for horizon in HORIZONS}
        for horizon in HORIZONS:
            for pred_interval, systems in data_accum[horizon].items():
                for system_name, values in systems.items():
                    result[horizon][pred_interval][system_name] = mean_or_none(values)
        return result

    prediction_time = {system: mean_or_none(times) for system, times in prediction_time_accum.items()}

    return {
        "avg_hellinger": compute_means(accum["avg_hellinger_distance"]),
        "kld": compute_means(accum["kl_divergence"]),
        "prediction_time": prediction_time,
    }


# Display name of each model -> accumulated (file-averaged) metrics for that model
metrics_by_modelname = {}
for model_name, save_dir in metrics_save_dirs.items():
    print(f"Accumulating {model_name} metrics...")
    metrics = accumulate_metrics(metrics_fnames[model_name], save_dir)
    metrics_by_modelname[model_name] = metrics
    print(f"Accumulated {model_name} metrics")

In [None]:
# Short model keys used by the analysis cells below -> display names that key
# `metrics_by_modelname`.
MODEL_KEY_TO_NAME = {
    "panda": "Panda",
    "chronos_sft": "Chronos 20M SFT",
    "chronos_zs": "Chronos 20M",
    "dynamix": "Dynamix",
}

# Regroup the accumulated results as metrics[metric_key][model_key] -> per-model data.
# Looked up directly in `metrics_by_modelname`: evaluating "<name>_metrics" as code
# referenced variables that do not exist (and display names containing spaces are
# not even valid identifiers).
metrics = {
    metric_key: {
        model_key: metrics_by_modelname[model_name][metric_key]
        for model_key, model_name in MODEL_KEY_TO_NAME.items()
    }
    for metric_key in ["avg_hellinger", "kld", "prediction_time"]
}


In [None]:
# first_system = list(metrics["prediction_time"]["panda"].keys())[0]
# metrics["prediction_time"]["panda"].pop(first_system)
# metrics["prediction_time"]["chronos_sft"].pop(first_system)
# metrics["prediction_time"]["chronos_zs"].pop(first_system)
# print(metrics["prediction_time"]["panda"])
# print(metrics["prediction_time"]["chronos_sft"])
# print(metrics["prediction_time"]["chronos_zs"])

In [None]:
# # Print prediction time mean and std for both panda and chronos_sft

# for model_name in ["panda", "chronos_sft", "chronos_zs"]:
#     prediction_times = list(metrics["prediction_time"][model_name].values())
#     prediction_time_mean = np.mean(prediction_times)
#     prediction_time_std = np.std(prediction_times)
#     print(f"{model_name} prediction time mean:", prediction_time_mean)
#     print(f"{model_name} prediction time std:", prediction_time_std)

In [None]:
# Prediction intervals available for Panda's prediction-horizon Hellinger distances.
metrics["avg_hellinger"]["panda"]["prediction_horizon"].keys()

In [None]:
# Count missing (None) and NaN entries among Panda's per-system Hellinger means at 1024
panda_hellinger_by_system = metrics["avg_hellinger"]["panda"]["prediction_horizon"][1024]
values = list(panda_hellinger_by_system.values())
non_null_values = [v for v in values if v is not None]
num_nones = len(values) - len(non_null_values)
num_nans = sum(np.isnan(v) for v in non_null_values)
print(f"Number of None values: {num_nones}")
print(f"Number of NaN values: {num_nans}")

In [None]:
# Prediction interval and horizon whose Hellinger distances are histogrammed below.
pred_length = 1024
horizon_name = "prediction_horizon"

# Whether the "Chronos 20M" series is drawn in the histogram.
show_chronos_zs = False


def filter_nans(values):
    """Return ``values`` as a float array with ``None`` and float-NaN entries removed."""
    kept = []
    for v in values:
        if v is None:
            continue
        if isinstance(v, float) and np.isnan(v):
            continue
        kept.append(float(v))
    return np.array(kept, dtype=float)


# Legend label -> short model key, in plotting order
hellinger_model_keys = {
    "Panda": "panda",
    "Chronos 20M SFT": "chronos_sft",
    "Chronos 20M": "chronos_zs",
    "Dynamix": "dynamix",
}
avg_hellinger = {
    label: filter_nans(metrics["avg_hellinger"][model_key][horizon_name][pred_length].values())
    for label, model_key in hellinger_model_keys.items()
}

# First three cycle colors plus a fixed color for the fourth model
colors = DEFAULT_COLORS[:3] + ["olive"]
print(colors)

num_bins = 50
alpha_val = 0.6

# Shared bin edges computed over all models, so the histograms are directly comparable
all_hellinger = np.concatenate(list(avg_hellinger.values()))
print(f"min hellinger: {all_hellinger.min()}, max hellinger: {all_hellinger.max()}")
bins = np.histogram_bin_edges(all_hellinger, bins=num_bins)

fig, ax = plt.subplots(figsize=(4, 4))
for idx, (label, vals) in enumerate(avg_hellinger.items()):
    hide_series = label == "Chronos 20M" and not show_chronos_zs
    if hide_series:
        continue
    series_color = colors[idx]
    ax.hist(
        vals,
        bins=bins,
        color=series_color,
        edgecolor=series_color,
        alpha=alpha_val,
        zorder=10 - idx,  # earlier models are drawn on top
        histtype="stepfilled",
        label=label,
    )

ax.set_ylabel("Count", fontweight="bold")
ax.legend(loc="upper right")
ax.set_title(f"Avg Hellinger Distance ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
fig.tight_layout()
hellinger_fig_path = os.path.join(fig_save_dir, f"avg_hellinger_distribution_{horizon_name}_{pred_length}.pdf")
fig.savefig(hellinger_fig_path, bbox_inches="tight")
plt.show()

In [None]:
# Horizon names available in Panda's KL-divergence metrics.
metrics["kld"]["panda"].keys()

In [None]:
# Prediction interval and horizon whose KL divergences are histogrammed below.
pred_length = 1024
horizon_name = "full_trajectory"

# Whether the "Chronos 20M" series is drawn in the histogram.
show_chronos_zs = False


# Extract and filter positive KL divergence values
def pos_vals(vals):
    """Return a list of the entries of ``vals`` that are not ``None`` and strictly positive."""
    return list(filter(lambda x: x is not None and x > 0, vals))


# Legend label -> strictly positive per-system KL divergences, in plotting order
kld_dict = {
    "Panda": pos_vals(metrics["kld"]["panda"][horizon_name][pred_length].values()),
    "Chronos 20M SFT": pos_vals(metrics["kld"]["chronos_sft"][horizon_name][pred_length].values()),
    "Chronos 20M": pos_vals(metrics["kld"]["chronos_zs"][horizon_name][pred_length].values()),
    "Dynamix": pos_vals(metrics["kld"]["dynamix"][horizon_name][pred_length].values()),
}

all_kld_pos = np.concatenate(list(kld_dict.values()))
num_bins = 50
if len(all_kld_pos) > 0:
    # histogram_bin_edges returns num_bins + 1 edges (exactly num_bins bins) and
    # widens the range when all values are equal; linspace(min, max, num_bins)
    # produced one bin too few and degenerate edges in that case.
    bins = np.histogram_bin_edges(all_kld_pos, bins=num_bins)
else:
    bins = num_bins
    print("No positive values found")

# Same palette as the Hellinger histogram, so each model keeps its color across figures
kld_colors = DEFAULT_COLORS[:3] + ["olive"]

plt.figure(figsize=(4, 4))
alpha_val = 0.6
for i, (label, vals) in enumerate(kld_dict.items()):
    if not show_chronos_zs and label == "Chronos 20M":
        continue
    color = kld_colors[i]
    plt.hist(
        vals,
        bins=bins,
        color=color,
        edgecolor=color,
        alpha=alpha_val,
        histtype="stepfilled",
        label=label,
        zorder=10 - i,  # earlier models are drawn on top
    )
plt.yscale("log")
plt.ylabel("Count", fontweight="bold")
plt.legend(loc="upper right")
plt.title(f"KL Divergence ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
plt.tight_layout()
plt.savefig(os.path.join(fig_save_dir, f"kld_distribution_{horizon_name}_{pred_length}_log.pdf"), bbox_inches="tight")
plt.show()

In [None]:
pred_length = 1024
horizon_name = "full_trajectory"

# Per-system KL divergences, keyed by system name
kld_panda_by_system = metrics["kld"]["panda"][horizon_name][pred_length]
kld_chronos_sft_by_system = metrics["kld"]["chronos_sft"][horizon_name][pred_length]

# Plot difference between Chronos SFT and Panda KL divergences.
# Values are paired by system name: the two dictionaries are built independently,
# so pairing by position would silently mismatch systems whenever their order or
# coverage differs between the models.
shared_systems = [name for name in kld_chronos_sft_by_system if name in kld_panda_by_system]
num_unmatched_systems = len(kld_chronos_sft_by_system) + len(kld_panda_by_system) - 2 * len(shared_systems)
if num_unmatched_systems:
    print(f"Ignored {num_unmatched_systems} systems present for only one of the two models")

# Keep only pairs where both values are not None
kld_pairs = [
    (kld_chronos_sft_by_system[name], kld_panda_by_system[name])
    for name in shared_systems
    if kld_chronos_sft_by_system[name] is not None and kld_panda_by_system[name] is not None
]
num_skipped_pairs = len(shared_systems) - len(kld_pairs)
print(f"Skipped {num_skipped_pairs} pairs due to None values")
kld_diff = np.array([c - p for c, p in kld_pairs])

plt.figure(figsize=(4, 4))
plt.hist(kld_diff, bins=30, color="gray", edgecolor="black", alpha=0.7, histtype="stepfilled")
plt.axvline(0, color="k", linestyle="dotted", linewidth=1.5)
plt.xlabel("$D_{{KL}}$ (Chronos SFT - Panda)", fontweight="bold")
plt.ylabel("Count", fontweight="bold")
plt.title(f"Difference in $D_{{KL}}$ ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
# Set the scale before tight_layout so the layout accounts for the log tick labels
plt.yscale("log")
plt.tight_layout()
plt.show()

In [None]:
pred_lengths = [1024]
horizon_name = "full_trajectory"
# (label, minuend model key, subtrahend model key) for each reported difference
pairs = [
    ("Chronos SFT - Panda", "chronos_sft", "panda"),
    ("Chronos ZS - Chronos SFT", "chronos_zs", "chronos_sft"),
    ("Chronos ZS - Panda", "chronos_zs", "panda"),
    ("Dynamix - Panda", "dynamix", "panda"),
]
for pred_length in pred_lengths:
    print(f"Prediction length: {pred_length}")
    all_keys = ["panda", "chronos_sft", "chronos_zs", "dynamix"]
    # Model key -> {system name -> KL divergence}
    klds = {key: metrics["kld"][key][horizon_name][pred_length] for key in all_keys}

    for label, key1, key2 in pairs:
        # Pair by system name (not by position, which mismatches systems when the two
        # models' dictionaries differ in order or coverage) and drop None values
        valid_pairs = [
            (klds[key1][name], klds[key2][name])
            for name in klds[key1]
            if name in klds[key2] and klds[key1][name] is not None and klds[key2][name] is not None
        ]
        if valid_pairs:
            diff = np.array([v1 - v2 for v1, v2 in valid_pairs])
            print(f"Mean KL diff ({label}): {diff.mean():.4f}, Std KL diff: {diff.std():.4f}")
        else:
            print(f"No valid data for {label}")

In [None]:
# Uses `pairs` and `pred_lengths` defined in the KL-divergence difference cell above
horizon_name = "prediction_horizon"
for pred_length in pred_lengths:
    print(f"Prediction length: {pred_length}")
    all_keys = ["panda", "chronos_sft", "chronos_zs", "dynamix"]
    # Model key -> {system name -> average Hellinger distance}
    hells = {key: metrics["avg_hellinger"][key][horizon_name][pred_length] for key in all_keys}

    for label, key1, key2 in pairs:
        # Pair by system name (not by position, which mismatches systems when the two
        # models' dictionaries differ in order or coverage) and drop None / NaN values
        candidate_pairs = [(hells[key1][name], hells[key2][name]) for name in hells[key1] if name in hells[key2]]
        valid_pairs = [
            (v1, v2)
            for v1, v2 in candidate_pairs
            if v1 is not None and v2 is not None and not np.isnan(v1) and not np.isnan(v2)
        ]
        if valid_pairs:
            diff = np.array([v1 - v2 for v1, v2 in valid_pairs])
            print(f"Mean avg_hellinger diff ({label}): {diff.mean():.4f}, Std avg_hellinger diff: {diff.std():.4f}")
        else:
            print(f"No valid data for {label}")

In [None]:
# Per-model mean/std of the per-system average Hellinger distance.
# `horizon_name` and `pred_lengths` are whatever the preceding cells left them as.
print("Avg Hellinger")
for model_key in ["panda", "chronos_sft", "chronos_zs", "dynamix"]:
    print(f"Model: {model_key}")
    for pred_length in pred_lengths:
        hell_by_system = metrics["avg_hellinger"][model_key][horizon_name][pred_length]
        # Keep only usable numbers: drop None and NaN entries
        hell_valid = [v for v in hell_by_system.values() if not (v is None or np.isnan(v))]
        if not hell_valid:
            print(f"  Prediction length {pred_length}: No valid data")
            continue
        mean_hell, std_hell = np.mean(hell_valid), np.std(hell_valid)
        print(
            f"  Prediction length {pred_length}: mean avg_hellinger = {mean_hell:.4f}, std avg_hellinger = {std_hell:.4f}"
        )

In [None]:
# Per-model mean/std of the per-system KL divergence.
# NOTE(review): `horizon_name` is inherited from the preceding cells (last set to
# "prediction_horizon"), whereas the KL histograms use "full_trajectory" — confirm
# which horizon is intended here.
print("KL Divergence")
for model_key in ["panda", "chronos_sft", "chronos_zs", "dynamix"]:
    print(f"Model: {model_key}")
    for pred_length in pred_lengths:
        kld_by_system = metrics["kld"][model_key][horizon_name][pred_length]
        # Keep only usable numbers: drop None and NaN entries
        kld_valid = [v for v in kld_by_system.values() if not (v is None or np.isnan(v))]
        if not kld_valid:
            print(f"  Prediction length {pred_length}: No valid data")
            continue
        mean_kld, std_kld = np.mean(kld_valid), np.std(kld_valid)
        print(f"  Prediction length {pred_length}: mean kld = {mean_kld:.4f}, std kld = {std_kld:.4f}")