In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from panda.utils.plot_utils import apply_custom_style

# Apply the project's plotting configuration before any figure is created.
# NOTE(review): presumably this updates matplotlib rcParams from the YAML file - confirm in panda.utils.plot_utils.
apply_custom_style("../../config/plotting.yaml")

In [None]:
# Output directory for the figures saved by the plotting cells; created if it does not exist yet.
fig_save_dir = os.path.join("../../figures", "eval_metrics")
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Colors of the active matplotlib property cycle; the histogram cells build their per-model palette from it.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root paths derive from the WORK environment variable. When it is unset the
# fallback "" makes the paths below relative to the current working directory.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")
eval_results_dir = os.path.join(WORK_DIR, "eval_results_distributional")
# Path components appended after the model directory when locating metrics files.
data_split = "test_zeroshot"
run_name = "fdiv"

In [None]:
# Selects which Chronos results directory is read:
# "chronos" when True, "chronos_nondeterministic" when False.
use_chronos_deterministic = False
chronos_dirname = "chronos" if use_chronos_deterministic else "chronos_nondeterministic"

print(f"Using {chronos_dirname} for chronos metrics")


def get_sorted_metric_fnames(save_dir):
    """List the distributional-metrics JSON files in ``save_dir``, ordered by window index.

    A file qualifies when its name ends with ``.json`` and contains
    ``distributional_metrics``. The sort key is the integer following
    ``window-`` in the name; names without such a number sort last.

    Args:
        save_dir: Directory to scan (not recursive).

    Returns:
        List of file names (not full paths) in ascending window order.
    """
    window_pattern = re.compile(r"window-(\d+)")

    def window_index(fname):
        match = window_pattern.search(fname)
        if match is None:
            # No window number in the name: push the file to the end.
            return float("inf")
        return int(match.group(1))

    candidates = []
    for entry in os.listdir(save_dir):
        if entry.endswith(".json") and "distributional_metrics" in entry:
            candidates.append(entry)
    candidates.sort(key=window_index)
    return candidates


# An empty/unset run name contributes an empty trailing path component.
run_suffix = run_name or ""

# (display label, first sub-directory, second sub-directory) under eval_results_dir.
# The order here fixes the model order used by every later cell.
model_subdirs = [
    ("Panda", "panda", "panda-21M"),
    ("Chronos 20M SFT", chronos_dirname, "chronos_t5_mini_ft-0"),
    ("Chronos 20M", chronos_dirname, "chronos_mini_zeroshot"),
    ("Chronos 200M", chronos_dirname, "chronos_base_zeroshot"),
    ("Dynamix", "dynamix", "dynamix"),
]
metrics_save_dirs = {
    label: os.path.join(eval_results_dir, top_dir, model_dir, data_split, run_suffix)
    for label, top_dir, model_dir in model_subdirs
}
model_run_names = list(metrics_save_dirs)

# Discover the metrics files of every model, sorted by window index.
metrics_fnames = {}
for model_name in model_run_names:
    save_dir = metrics_save_dirs[model_name]
    print(f"Loading {model_name} metrics from: {save_dir}")
    found_fnames = get_sorted_metric_fnames(save_dir)
    metrics_fnames[model_name] = found_fnames
    print(f"Found {len(found_fnames)} {model_name} metrics files: {found_fnames}")

In [None]:
# Example metrics file: load the first (lowest-sorted) Panda file to inspect its structure.
metrics_fpath = os.path.join(metrics_save_dirs["Panda"], metrics_fnames["Panda"][0])
with open(metrics_fpath, "rb") as f:
    metrics = json.load(f)

# Convert string keys to integers (JSON object keys are always strings)
metrics = {int(k): v for k, v in metrics.items()}

# NOTE: `metrics` is rebound by later cells, so this view is only valid until they run.
print(metrics.keys())

In [None]:
# Peek at the first entry stored under key 1024 of the example file loaded in the previous cell.
metrics[1024][0]

In [None]:
def accumulate_metrics(metrics_fnames, metrics_save_dir):
    """Accumulate distributional metrics across multiple files.

    Each file is a JSON object mapping a prediction interval (string key,
    converted to ``int``) to an iterable of ``(system_name, system_entry)``
    pairs. For every horizon present in a ``system_entry`` the Hellinger and
    KL values are collected per prediction interval and system;
    ``prediction_time`` is collected per system across all intervals and files.

    Args:
        metrics_fnames: File names (relative to ``metrics_save_dir``) to read.
        metrics_save_dir: Directory containing the metrics files.

    Returns:
        dict with keys

        * ``"avg_hellinger"`` and ``"kld"``:
          ``{horizon: {pred_interval: {system_name: mean or None}}}``
        * ``"prediction_time"``: ``{system_name: mean or None}``

        A mean is ``None`` whenever no non-``None`` value was recorded for it.

    Raises:
        KeyError: If a horizon entry lacks one of the expected metric keys.
    """
    HORIZONS = ["prediction_horizon", "full_trajectory"]
    METRICS = ["avg_hellinger_distance", "kl_divergence"]

    def _mean_or_none(values):
        """Mean of the non-None entries of ``values``, or None if there are none."""
        kept = [v for v in values if v is not None]
        # Guard on the *filtered* list: np.mean([]) would return NaN and warn.
        return float(np.mean(kept)) if kept else None

    # metric -> horizon -> pred_interval -> system_name -> list of raw values
    accum = {metric: {horizon: defaultdict(lambda: defaultdict(list)) for horizon in HORIZONS} for metric in METRICS}
    prediction_time_accum = defaultdict(list)

    # Accumulate values from all files
    for fname in metrics_fnames:
        with open(os.path.join(metrics_save_dir, fname), "rb") as f:
            file_metrics = json.load(f)
        file_metrics = {int(k) if isinstance(k, str) else k: v for k, v in file_metrics.items()}

        print(f"Processing {fname}: {len(file_metrics)} prediction interval(s)")

        for pred_interval, data in file_metrics.items():
            for system_name, system_entry in tqdm(data, desc=f"Interval {pred_interval}"):
                # A system may lack some horizons; only the present ones contribute.
                for horizon in HORIZONS:
                    if horizon not in system_entry:
                        continue
                    for metric in METRICS:
                        accum[metric][horizon][pred_interval][system_name].append(system_entry[horizon][metric])

                # Prediction time is pooled per system, regardless of interval.
                if "prediction_time" in system_entry:
                    prediction_time_accum[system_name].append(system_entry["prediction_time"])

    def compute_means(data_accum):
        """Collapse the per-system value lists of one metric into means."""
        result = {horizon: defaultdict(dict) for horizon in HORIZONS}
        for horizon in HORIZONS:
            for pred_interval, systems in data_accum[horizon].items():
                for system_name, values in systems.items():
                    result[horizon][pred_interval][system_name] = _mean_or_none(values)
        return result

    # Same None-filtering rule as the metric means, so an all-None list gives None, not NaN.
    prediction_time = {system: _mean_or_none(times) for system, times in prediction_time_accum.items()}

    return {
        "avg_hellinger": compute_means(accum["avg_hellinger_distance"]),
        "kld": compute_means(accum["kl_divergence"]),
        "prediction_time": prediction_time,
    }


# Aggregate every model's metrics files into per-system means, keyed by model label.
metrics_by_modelname = {}
for model_name, model_save_dir in metrics_save_dirs.items():
    print(f"Accumulating {model_name} metrics...")
    # Use a dedicated name so this loop does not rebind the global `metrics`.
    model_metrics = accumulate_metrics(metrics_fnames[model_name], model_save_dir)
    metrics_by_modelname[model_name] = model_metrics
    print(f"Accumulated {model_name} metrics")

In [None]:
# Re-index the accumulated results as metric kind -> model label -> values.
# This rebinds `metrics`; all cells below read this structure.
metrics = {
    k: {m: metrics_by_modelname[m][k] for m in metrics_save_dirs.keys()}
    for k in ["avg_hellinger", "kld", "prediction_time"]
}

In [None]:
# Prediction intervals available for Panda's avg Hellinger results on the prediction horizon.
metrics["avg_hellinger"]["Panda"]["prediction_horizon"].keys()

In [None]:
# Sanity check: how many of Panda's per-system values at interval 1024 are missing (None) or NaN.
values = list(metrics["avg_hellinger"]["Panda"]["prediction_horizon"][1024].values())
num_nones = 0
num_nans = 0
for v in values:
    if v is None:
        num_nones += 1
    elif np.isnan(v):
        num_nans += 1
print(f"Number of None values: {num_nones}")
print(f"Number of NaN values: {num_nans}")

In [None]:
# Which prediction length and horizon the histogram below shows.
pred_length = 1024
horizon_name = "prediction_horizon"

# Visibility toggles: `show_chronos_zs` gates "Chronos 20M" and "Chronos 200M",
# `show_chronos_sft` gates "Chronos 20M SFT".
show_chronos_zs = False
show_chronos_sft = True


def filter_nans(values):
    """Return the usable entries of ``values`` as a 1-D float array.

    Entries that are ``None`` or NaN are dropped. Each remaining entry is
    converted with ``float`` *before* the NaN test, so NaNs of any numeric
    type (e.g. ``np.float32``) are removed, not only Python ``float`` NaNs.

    Args:
        values: Iterable of numbers and/or ``None``.

    Returns:
        ``np.ndarray`` of dtype float containing the kept values in order;
        empty when nothing is kept.
    """
    kept = []
    for v in values:
        if v is None:
            continue
        as_float = float(v)
        if np.isnan(as_float):
            continue
        kept.append(as_float)
    return np.array(kept, dtype=float)


# Per-model arrays of valid avg Hellinger distances at the selected horizon / prediction length.
avg_hellinger = {
    model_key: filter_nans(metrics["avg_hellinger"][model_key][horizon_name][pred_length].values())
    for model_key in model_run_names
}

# Palette: first four colors of the default cycle plus a fixed pink for the fifth model.
colors = DEFAULT_COLORS[:4] + ["#FFB5B8"]
print(colors)

num_bins = 50
plt.figure(figsize=(4, 4))
# Shared bin edges over all models so the histograms are directly comparable.
all_hellinger = np.concatenate(list(avg_hellinger.values()))
print(f"min hellinger: {all_hellinger.min()}, max hellinger: {all_hellinger.max()}")
bins = np.histogram_bin_edges(all_hellinger, bins=num_bins)
alpha_val = 0.7

# Labels excluded from the figure according to the visibility toggles.
hidden_labels = []
if not show_chronos_zs:
    hidden_labels += ["Chronos 200M", "Chronos 20M"]
if not show_chronos_sft:
    hidden_labels.append("Chronos 20M SFT")

# Drawing order: Panda on top, then Dynamix, everything else underneath.
layer_by_label = {"Dynamix": 5, "Panda": 10}

for i, (label, vals) in enumerate(avg_hellinger.items()):
    if label in hidden_labels:
        continue
    # Index by position among *all* models so each model keeps its color when others are hidden.
    model_color = colors[i]
    plt.hist(
        vals,
        bins=bins,
        color=model_color,
        edgecolor=model_color,
        alpha=alpha_val,
        zorder=layer_by_label.get(label, 1),
        histtype="stepfilled",
        label=label,
    )

plt.ylabel("Count", fontweight="bold")
plt.legend(loc="upper left", frameon=True)
plt.title(f"Avg Hellinger Distance ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
plt.tight_layout()
hellinger_fig_path = os.path.join(fig_save_dir, f"avg_hellinger_distribution_{horizon_name}_{pred_length}.pdf")
plt.savefig(hellinger_fig_path, bbox_inches="tight")
plt.show()

In [None]:
# Which prediction length and horizon the KL histogram below shows.
pred_length = 1024
horizon_name = "prediction_horizon"

# When False, "Chronos 20M" and "Chronos 200M" are left out of the figure.
show_chronos_zs = False


def pos_vals(vals):
    """Keep the strictly positive entries of ``vals``.

    ``None`` entries and entries for which ``x > 0`` is false (zero,
    negatives, NaN) are dropped; order is preserved.

    Args:
        vals: Iterable of numbers and/or ``None``.

    Returns:
        List of the kept values, unchanged.
    """
    kept = []
    for x in vals:
        if x is None or not x > 0:
            continue
        kept.append(x)
    return kept


# Positive KL divergences per model at the selected horizon / prediction length.
kld_dict = {
    model_key: pos_vals(metrics["kld"][model_key][horizon_name][pred_length].values())
    for model_key in model_run_names
}

all_kld_pos = np.concatenate(list(kld_dict.values()))
num_bins = 50
if len(all_kld_pos) > 0:
    # histogram_bin_edges returns num_bins + 1 edges, i.e. exactly num_bins bins
    # (np.linspace(min, max, num_bins) would give only num_bins - 1 bins).
    bins = np.histogram_bin_edges(all_kld_pos, bins=num_bins)
else:
    bins = num_bins
    print("No positive values found")

plt.figure(figsize=(4, 4))
alpha_val = 0.6
# `colors` is the palette defined in the Hellinger histogram cell; indexing by the
# position among all models keeps each model's color consistent across figures.
for i, (label, vals) in enumerate(kld_dict.items()):
    if not show_chronos_zs and label in ["Chronos 200M", "Chronos 20M"]:
        continue
    plt.hist(
        vals,
        bins=bins,
        color=colors[i],
        edgecolor=colors[i],
        alpha=alpha_val,
        histtype="stepfilled",
        label=label,
        zorder=10 - i,  # earlier models are drawn on top of later ones
    )
plt.yscale("log")
plt.ylabel("Count", fontweight="bold")
plt.legend(loc="upper right")
plt.title(f"KL Divergence ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
plt.tight_layout()
plt.savefig(os.path.join(fig_save_dir, f"kld_distribution_{horizon_name}_{pred_length}_log.pdf"), bbox_inches="tight")
plt.show()

In [None]:
pred_length = 1024
horizon_name = "prediction_horizon"

# Per-model positive KL values. Each list is filtered independently and follows that
# model's own system order, so these lists must NOT be paired element-by-element.
full_kld_dict = {
    model_key: pos_vals(metrics["kld"][model_key][horizon_name][pred_length].values()) for model_key in model_run_names
}

# Compute difference between Chronos SFT and Panda, pairing values by system name so
# that every difference compares the same system. A system contributes only when both
# models have a positive (hence non-None, non-NaN) KL value for it.
chronos_kld = metrics["kld"]["Chronos 20M SFT"][horizon_name][pred_length]
panda_kld = metrics["kld"]["Panda"][horizon_name][pred_length]
kld_diff = np.array(
    [
        chronos_kld[system] - panda_kld[system]
        for system in chronos_kld
        if system in panda_kld
        and chronos_kld[system] is not None
        and chronos_kld[system] > 0
        and panda_kld[system] is not None
        and panda_kld[system] > 0
    ],
    dtype=float,
)

plt.figure(figsize=(4, 4))
plt.hist(kld_diff, bins=30, color="gray", edgecolor="black", alpha=0.7, histtype="stepfilled")
plt.axvline(0, color="k", linestyle="dotted", linewidth=1.5)
plt.xlabel("$D_{{KL}}$ (Chronos SFT - Panda)", fontweight="bold")
plt.ylabel("Count", fontweight="bold")
plt.title(f"Difference in $D_{{KL}}$ ($L_{{\\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold")
# Set the scale before tight_layout so the layout accounts for the log tick labels.
plt.yscale("log")
plt.tight_layout()
plt.show()

In [None]:
pred_lengths = [128, 256, 512, 1024]
horizon_name = "prediction_horizon"
print(f"horizon_name: {horizon_name}")

# (printed label, model whose value is the minuend, model whose value is subtracted)
pairs = [
    ("Chronos 20M SFT - Panda", "Chronos 20M SFT", "Panda"),
    ("Chronos 20M - Chronos 20M SFT", "Chronos 20M", "Chronos 20M SFT"),
    ("Chronos 200M - Chronos 20M SFT", "Chronos 200M", "Chronos 20M SFT"),
    ("Dynamix - Panda", "Dynamix", "Panda"),
]
for pred_length in pred_lengths:
    print(f"Prediction length: {pred_length}")
    # Per-model {system_name: mean KL} mappings for this prediction length.
    kld_by_system = {key: metrics["kld"][key][horizon_name][pred_length] for key in model_run_names}

    for label, key1, key2 in pairs:
        first, second = kld_by_system[key1], kld_by_system[key2]
        # Pair by system name (not by position) so each difference compares the same
        # system even if the models list systems in different orders or miss some.
        diff = np.array(
            [
                first[system] - second[system]
                for system in first
                if system in second and first[system] is not None and second[system] is not None
            ],
            dtype=float,
        )
        # Drop NaN differences so a single NaN metric does not turn the summary into NaN.
        diff = diff[~np.isnan(diff)]
        if diff.size:
            print(f"Mean KL diff ({label}): {diff.mean():.4f}, Std KL diff: {diff.std():.4f}")
        else:
            print(f"No valid data for {label}")

In [None]:
pred_lengths = [128, 256, 512, 1024]
horizon_name = "prediction_horizon"
print(f"horizon_name: {horizon_name}")

# (printed label, model whose value is the minuend, model whose value is subtracted)
pairs = [
    ("Chronos 20M SFT - Panda", "Chronos 20M SFT", "Panda"),
    ("Chronos 20M - Chronos 20M SFT", "Chronos 20M", "Chronos 20M SFT"),
    ("Chronos 200M - Chronos 20M SFT", "Chronos 200M", "Chronos 20M SFT"),
    ("Dynamix - Panda", "Dynamix", "Panda"),
]
for pred_length in pred_lengths:
    print(f"Prediction length: {pred_length}")
    # Per-model {system_name: mean avg Hellinger distance} mappings for this prediction length.
    hellinger_by_system = {
        key: metrics["avg_hellinger"][key][horizon_name][pred_length] for key in model_run_names
    }

    for label, key1, key2 in pairs:
        first, second = hellinger_by_system[key1], hellinger_by_system[key2]
        # Pair by system name (not by position) so each difference compares the same
        # system even if the models list systems in different orders or miss some.
        diff = np.array(
            [
                first[system] - second[system]
                for system in first
                if system in second and first[system] is not None and second[system] is not None
            ],
            dtype=float,
        )
        # Drop NaN differences so a single NaN metric does not turn the summary into NaN.
        diff = diff[~np.isnan(diff)]
        if diff.size:
            print(f"Mean Hellinger diff ({label}): {diff.mean():.4f}, Std Hellinger diff: {diff.std():.4f}")
        else:
            print(f"No valid data for {label}")

In [None]:
# Mean / std of the per-system avg Hellinger distance, per horizon and model, at length 1024.
pred_lengths = [1024]
horizon_names = ["prediction_horizon", "full_trajectory"]
separator = "-" * 100
for horizon_name in horizon_names:
    print(separator)
    print("Avg Hellinger")
    print(separator)
    print(f"horizon_name: {horizon_name}")
    for model_key in model_run_names:
        print(f"Model: {model_key}")
        per_length = metrics["avg_hellinger"][model_key][horizon_name]
        for pred_length in pred_lengths:
            # Keep only usable numbers: skip None and NaN entries.
            usable = [v for v in per_length[pred_length].values() if not (v is None or np.isnan(v))]
            if not usable:
                print(f"  Prediction length {pred_length}: No valid data")
                continue
            mean_hell, std_hell = np.mean(usable), np.std(usable)
            print(
                f"  Prediction length {pred_length}: mean avg_hellinger = {mean_hell:.4f}, std avg_hellinger = {std_hell:.4f}"
            )

In [None]:
# Mean / std of the per-system KL divergence, per horizon and model, at length 1024.
pred_lengths = [1024]
horizon_names = ["prediction_horizon", "full_trajectory"]
banner = "-" * 100
for horizon_name in horizon_names:
    for header_line in (banner, "KL Divergence", banner, f"horizon_name: {horizon_name}"):
        print(header_line)
    for model_key in model_run_names:
        print(f"Model: {model_key}")
        for pred_length in pred_lengths:
            per_system = metrics["kld"][model_key][horizon_name][pred_length]
            # Discard None and NaN entries before summarising.
            kept = []
            for v in per_system.values():
                if v is None or np.isnan(v):
                    continue
                kept.append(v)
            if kept:
                mean_kld = np.mean(kept)
                std_kld = np.std(kept)
                summary = f"  Prediction length {pred_length}: mean kld = {mean_kld:.4f}, std kld = {std_kld:.4f}"
            else:
                summary = f"  Prediction length {pred_length}: No valid data"
            print(summary)

In [None]:
# Mean / std of the per-system KL divergence, per horizon and model, for every prediction length.
pred_lengths = [128, 256, 512, 1024]
horizon_names = ["prediction_horizon", "full_trajectory"]
rule = "-" * 100
for horizon_name in horizon_names:
    print(rule)
    print("KL Divergence")
    print(rule)
    print(f"horizon_name: {horizon_name}")
    for model_key in model_run_names:
        print(f"Model: {model_key}")
        kld_by_length = metrics["kld"][model_key][horizon_name]
        for pred_length in pred_lengths:
            # Only non-None, non-NaN values enter the statistics.
            finite_klds = [v for v in kld_by_length[pred_length].values() if not (v is None or np.isnan(v))]
            if len(finite_klds) == 0:
                print(f"  Prediction length {pred_length}: No valid data")
                continue
            mean_kld, std_kld = np.mean(finite_klds), np.std(finite_klds)
            print(f"  Prediction length {pred_length}: mean kld = {mean_kld:.4f}, std kld = {std_kld:.4f}")

In [None]:
# Mean / std of the per-system avg Hellinger distance, per horizon and model, for every prediction length.
pred_lengths = [128, 256, 512, 1024]
horizon_names = ["prediction_horizon", "full_trajectory"]
divider = "-" * 100
for horizon_name in horizon_names:
    for header_line in (divider, "Avg Hellinger", divider, f"horizon_name: {horizon_name}"):
        print(header_line)
    for model_key in model_run_names:
        print(f"Model: {model_key}")
        for pred_length in pred_lengths:
            per_system = metrics["avg_hellinger"][model_key][horizon_name][pred_length]
            # Discard None and NaN entries before summarising.
            kept = []
            for v in per_system.values():
                if v is None or np.isnan(v):
                    continue
                kept.append(v)
            if kept:
                mean_hell = np.mean(kept)
                std_hell = np.std(kept)
                print(
                    f"  Prediction length {pred_length}: mean avg_hellinger = {mean_hell:.4f}, std avg_hellinger = {std_hell:.4f}"
                )
            else:
                print(f"  Prediction length {pred_length}: No valid data")