In [None]:
# Automatically reload edited project modules (e.g. panda.utils) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import re
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from sklearn.metrics import r2_score
from tqdm import tqdm

from panda.utils import (
    apply_custom_style,
)

In [None]:
# Apply the project's shared plotting style; presumably configured from this YAML file (see panda.utils.apply_custom_style).
apply_custom_style("../config/plotting.yaml")

In [None]:
# Colors of the active matplotlib property cycle; indices 0/1/2 give each model a fixed color in the figures below.
DEFAULT_COLORS = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# Root of the working tree; empty (i.e. paths stay relative) when $WORK is not set.
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
def get_sorted_metric_fnames(save_dir):
    """List the distributional-metrics JSON files in ``save_dir``, ordered by window size.

    A file qualifies if its name ends in ``.json`` and contains ``distributional_metrics``.
    Files are sorted by the integer after ``window-`` in the name; files without a window
    tag sort last. Ties (same window, or no window) are broken by filename so the order
    is deterministic. ``os.listdir`` order is arbitrary, and downstream code depends on
    which file is read first.

    Args:
        save_dir: Directory to scan (non-recursive).

    Returns:
        List of bare file names (not joined with ``save_dir``).
    """

    def extract_window(fname):
        match = re.search(r"window-(\d+)", fname)
        return int(match.group(1)) if match else float("inf")

    fnames = [
        f
        for f in os.listdir(save_dir)
        if f.endswith(".json") and "distributional_metrics" in f
    ]
    return sorted(fnames, key=lambda f: (extract_window(f), f))


# Per-model directories holding the distributional-metrics JSON files.
# Built with os.path.join (like DATA_DIR) so that an unset $WORK yields relative paths
# instead of paths rooted at "/".
panda_metrics_save_dir = os.path.join(
    WORK_DIR,
    "eval_results",
    "panda",
    "pft_chattn_emb_w_poly-0",
    "test_zeroshot",
    "metrics_run1",
)
# NOTE: we also have for chronos_nondeterministic, replace "chronos" with "chronos_nondeterministic" in the paths below
chronos_sft_metrics_save_dir = os.path.join(
    WORK_DIR,
    "eval_results",
    "chronos",
    "chronos_t5_mini_ft-0",
    "test_zeroshot",
    "metrics_run1",
)
chronos_zs_metrics_save_dir = os.path.join(
    WORK_DIR,
    "eval_results",
    "chronos",
    "chronos_mini_zeroshot",
    "test_zeroshot",
    "metrics_run1",
)

panda_metrics_fnames = get_sorted_metric_fnames(panda_metrics_save_dir)
chronos_sft_metrics_fnames = get_sorted_metric_fnames(chronos_sft_metrics_save_dir)
chronos_zs_metrics_fnames = get_sorted_metric_fnames(chronos_zs_metrics_save_dir)

print(f"Found {len(panda_metrics_fnames)} panda metrics files: {panda_metrics_fnames}")
print(
    f"Found {len(chronos_sft_metrics_fnames)} chronos sft metrics files: {chronos_sft_metrics_fnames}"
)
print(
    f"Found {len(chronos_zs_metrics_fnames)} chronos zs metrics files: {chronos_zs_metrics_fnames}"
)

In [None]:
# Helpers for averaging each model's metrics across all of its metrics files (used below for panda, chronos_sft and chronos_zs)


def filter_none(values):
    """Return a new list of the elements of ``values`` that are not None, in original order."""
    return list(filter(lambda item: item is not None, values))


# Returned horizon key -> section name inside each system's JSON entry.
_HORIZON_SECTIONS = {
    "pred_horizon": "prediction_horizon",
    "full": "full_trajectory",
}
# Returned metric name -> key inside each system's "max_lyap_rosenstein" section.
_LYAP_KEYS = {
    "max_lyap_r_gt": "max_lyap_gt",
    "max_lyap_r_pred": "max_lyap_pred",
    "max_lyap_r_gtcontext": "max_lyap_gt_with_context",
    "max_lyap_r_predcontext": "max_lyap_pred_with_context",
}


def _nested_list_accum():
    """Return an empty ``{pred_interval: {system_name: [values]}}`` accumulator."""
    return defaultdict(lambda: defaultdict(list))


def _mean_or_none(values):
    """Mean of the non-None entries of ``values`` as a float, or None if there are none."""
    filtered = filter_none(values)
    return float(np.mean(filtered)) if filtered else None


def _reduce_nested(accum):
    """Collapse a ``{pred_interval: {system_name: [values]}}`` accumulator to per-system means."""
    reduced = defaultdict(dict)
    for pred_interval, per_system in accum.items():
        for system_name, values in per_system.items():
            reduced[pred_interval][system_name] = _mean_or_none(values)
    return reduced


def accumulate_metrics(metrics_fnames, metrics_save_dir):
    """Average one model's distributional metrics across several metrics JSON files.

    Each file maps ``pred_interval -> per-system entries``. The per-system entries may be
    either a ``{system_name: entry}`` dict or a sequence of ``(system_name, entry)`` pairs.
    Each entry must provide:
    - "prediction_horizon" and "full_trajectory" sections, each with
      "avg_hellinger_distance" and "kl_divergence";
    - a "max_lyap_rosenstein" section with the keys listed in ``_LYAP_KEYS``;
    - a "prediction_time" value.

    Args:
        metrics_fnames: File names (relative to ``metrics_save_dir``) to read.
        metrics_save_dir: Directory containing those files.

    Returns:
        dict with these keys:
        - "avg_hellinger", "kld": ``{"pred_horizon" | "full": {pred_interval: {system: mean}}}``
        - "max_lyap_r_gt", "max_lyap_r_pred", "max_lyap_r_gtcontext",
          "max_lyap_r_predcontext": ``{pred_interval: {system: mean}}``
        - "prediction_time": ``{system: mean}``, pooled over all intervals and files.

        Means skip None values. A system whose values are all None maps to None.
        Per-interval mappings are ``defaultdict(dict)``, so unknown intervals read as ``{}``.

    Raises:
        KeyError: if an entry lacks one of the expected sections/keys.
    """
    hellinger_accum = {key: _nested_list_accum() for key in _HORIZON_SECTIONS}
    kld_accum = {key: _nested_list_accum() for key in _HORIZON_SECTIONS}
    lyap_accum = {name: _nested_list_accum() for name in _LYAP_KEYS}
    # Keyed by system only: timings are pooled over every prediction interval and file.
    # Insertion order == order systems are first seen, which later cells rely on.
    prediction_time_accum = defaultdict(list)

    for fname in metrics_fnames:
        # json.load accepts a binary file and detects the UTF-8/16/32 encoding itself.
        with open(os.path.join(metrics_save_dir, fname), "rb") as f:
            file_metrics = json.load(f)
        print(f"number of prediction intervals in {fname}: {len(file_metrics)}")
        for pred_interval, data in file_metrics.items():
            # Iterating a dict directly would yield only keys, breaking the
            # (system_name, entry) unpacking, so use .items() for dicts.
            entries = data.items() if isinstance(data, dict) else data
            for system_name, system_entry in tqdm(
                entries, desc=f"Processing {pred_interval}"
            ):
                for key, section in _HORIZON_SECTIONS.items():
                    horizon_metrics = system_entry[section]
                    hellinger_accum[key][pred_interval][system_name].append(
                        horizon_metrics["avg_hellinger_distance"]
                    )
                    kld_accum[key][pred_interval][system_name].append(
                        horizon_metrics["kl_divergence"]
                    )
                lyap_metrics = system_entry["max_lyap_rosenstein"]
                for name, lyap_key in _LYAP_KEYS.items():
                    lyap_accum[name][pred_interval][system_name].append(
                        lyap_metrics[lyap_key]
                    )
                prediction_time_accum[system_name].append(
                    system_entry["prediction_time"]
                )

    results = {
        "avg_hellinger": {
            key: _reduce_nested(acc) for key, acc in hellinger_accum.items()
        },
        "kld": {key: _reduce_nested(acc) for key, acc in kld_accum.items()},
    }
    for name, acc in lyap_accum.items():
        results[name] = _reduce_nested(acc)
    results["prediction_time"] = {
        system_name: _mean_or_none(times)
        for system_name, times in prediction_time_accum.items()
    }
    return results


# Accumulate metrics for panda, chronos_sft and chronos_zs
print("Accumulating panda metrics...")
panda_metrics = accumulate_metrics(panda_metrics_fnames, panda_metrics_save_dir)
print("Accumulating chronos_sft metrics...")
chronos_sft_metrics = accumulate_metrics(
    chronos_sft_metrics_fnames, chronos_sft_metrics_save_dir
)
print("Accumulating chronos_zs metrics...")
chronos_zs_metrics = accumulate_metrics(
    chronos_zs_metrics_fnames, chronos_zs_metrics_save_dir
)

In [None]:
# Regroup the per-model results as metrics[metric_name][model_name].
# The inner values are the same objects as in *_metrics (no copies), so in-place edits
# through `metrics` are visible in the per-model results too.
metrics = {
    metric_name: {
        model_name: model_results[metric_name]
        for model_name, model_results in (
            ("panda", panda_metrics),
            ("chronos_sft", chronos_sft_metrics),
            ("chronos_zs", chronos_zs_metrics),
        )
    }
    for metric_name in (
        "avg_hellinger",
        "kld",
        "max_lyap_r_gt",
        "max_lyap_r_pred",
        "max_lyap_r_gtcontext",
        "max_lyap_r_predcontext",
        "prediction_time",
    )
}

In [None]:
# NOTE: the first prediction of an eval run takes much longer (presumably one-off
# warm-up overhead), so drop it from each model's timing statistics.
# Each model's own first-recorded system is used: the dicts keep the order in which
# systems were first seen. Assuming every model visited systems in Panda's order would
# drop the wrong system (or raise KeyError) whenever the orders differ.
# NOTE: this mutates `metrics` in place; re-running the cell drops one more system.
for model_name, times_by_system in metrics["prediction_time"].items():
    if not times_by_system:
        print(f"{model_name}: no prediction times recorded")
        continue
    first_system = next(iter(times_by_system))
    times_by_system.pop(first_system)
    print(
        f"{model_name}: dropped warm-up system {first_system!r}, "
        f"{len(times_by_system)} systems remain"
    )

In [None]:
# Print prediction time mean and std (over systems) for each model.
# Systems with no recorded time (None) are skipped; np.mean cannot handle None.
for model_name in ["panda", "chronos_sft", "chronos_zs"]:
    prediction_times = [
        t for t in metrics["prediction_time"][model_name].values() if t is not None
    ]
    prediction_time_mean = np.mean(prediction_times)
    prediction_time_std = np.std(prediction_times)
    print(f"{model_name} prediction time mean:", prediction_time_mean)
    print(f"{model_name} prediction time std:", prediction_time_std)

In [None]:
# Scatter of the Rosenstein max Lyapunov exponent estimated on context + ground truth (x)
# vs. context + prediction (y), for one prediction interval and one model.
pred_length = "128"
model_type = "chronos_sft"
gtcontext_dict = metrics["max_lyap_r_gtcontext"][model_type].get(pred_length, {})
predcontext_dict = metrics["max_lyap_r_predcontext"][model_type].get(pred_length, {})


def _is_finite_number(value):
    """True for a non-None value that is a finite number."""
    return value is not None and bool(np.isfinite(value))


# Pair the two estimates by system name, keeping only systems where both are finite
# (None/NaN/inf would break r2_score, np.polyfit and min/max below). Sorting fixes the
# order, since set iteration order of strings varies between interpreter sessions.
system_names = sorted(
    name
    for name in set(gtcontext_dict) & set(predcontext_dict)
    if _is_finite_number(gtcontext_dict[name])
    and _is_finite_number(predcontext_dict[name])
)
x = [gtcontext_dict[name] for name in system_names]
y = [predcontext_dict[name] for name in system_names]

# R^2 of the prediction-based estimate against the ground-truth-based one, i.e.
# agreement with the y = x line (not the goodness of fit of the regression line).
# It is undefined for fewer than two points.
r2 = r2_score(x, y) if len(x) > 1 else float("nan")

fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(x, y, color="black", s=5, alpha=0.1)
ax.set_xlabel("Context + Ground Truth", fontweight="bold")
ax.set_ylabel("Context + Prediction", fontweight="bold")
ax.set_title(
    rf"$\lambda_{{\max}}$ ($L_{{\mathrm{{pred}}}} = {pred_length}$)", fontweight="bold"
)

# Legend order: best-fit line (with equation and R^2) first, then y = x.
handles = []
labels = []
if x:
    y_eq_x_min = min(x + y)
    y_eq_x_max = max(x + y)
    (h1,) = ax.plot(
        [y_eq_x_min, y_eq_x_max], [y_eq_x_min, y_eq_x_max], "r--", label=r"$y=x$"
    )
    if len(x) > 1:
        # Least-squares line y = m*x + b through the points.
        m, b = np.polyfit(x, y, 1)
        x_fit = np.array([y_eq_x_min, y_eq_x_max])
        y_fit = m * x_fit + b
        if abs(b) < 1e-10:
            eqn_str = rf"$y = {m:.2f}x$"
        else:
            sign = "+" if b >= 0 else "-"
            eqn_str = rf"$y = {m:.2f}x {sign} {abs(b):.2f}$"
        if np.isnan(r2):
            eqn_r2_label = eqn_str
        else:
            eqn_r2_label = rf"{eqn_str}  $(R^2 = {r2:.3f})$"
        (h2,) = ax.plot(
            x_fit, y_fit, color="red", linestyle="-", linewidth=1.5, label=eqn_r2_label
        )
        handles.append(h2)
        labels.append(eqn_r2_label)
    handles.append(h1)
    labels.append(r"$y=x$")
else:
    print(
        f"No finite max_lyap pairs for {model_type} at pred_length={pred_length}; "
        "skipping reference lines"
    )

ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=True))
ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
ax.ticklabel_format(style="sci", axis="both", scilimits=(0, 0))

if handles:
    ax.legend(
        handles=handles, labels=labels, loc="lower right", fontsize=8, frameon=True
    )

fig.tight_layout()
fig.savefig(
    os.path.join(
        "../figures",
        f"max_lyap_r_gtcontext_predcontext_{pred_length}_{model_type}.pdf",
    ),
    bbox_inches="tight",
)

plt.show()

In [None]:
pred_length = "512"
horizon_name = "pred_horizon"
# NOTE: despite the "full" in these names, they hold the `horizon_name` values
# (here the prediction horizon). Names are kept because later cells use them.
avg_full_hellinger_panda = list(
    metrics["avg_hellinger"]["panda"][horizon_name][pred_length].values()
)
avg_full_hellinger_chronos_sft = list(
    metrics["avg_hellinger"]["chronos_sft"][horizon_name][pred_length].values()
)
avg_full_hellinger_chronos_zs = list(
    metrics["avg_hellinger"]["chronos_zs"][horizon_name][pred_length].values()
)

zs_color = DEFAULT_COLORS[2] if len(DEFAULT_COLORS) > 2 else "tab:green"
# (label, values, color). Systems whose metric is None are skipped for plotting only,
# because np.histogram cannot bin None. The lists above stay unfiltered.
hist_series = [
    (
        label,
        [v for v in values if v is not None],
        color,
    )
    for label, values, color in (
        ("Panda", avg_full_hellinger_panda, DEFAULT_COLORS[0]),
        ("Chronos 20M SFT", avg_full_hellinger_chronos_sft, DEFAULT_COLORS[1]),
        ("Chronos 20M", avg_full_hellinger_chronos_zs, zs_color),
    )
]

# Common bins for all histograms so the distributions are directly comparable.
all_hellinger = [v for _, values, _ in hist_series for v in values]
bins = np.histogram_bin_edges(all_hellinger, bins=25)

alpha_val = 0.6
fig, ax = plt.subplots(figsize=(4, 4))
for label, values, color in hist_series:
    ax.hist(
        values,
        bins=bins,
        color=color,
        edgecolor=color,
        alpha=alpha_val,
        histtype="stepfilled",
        label=label,
    )
# ax.set_xlabel("Average Hellinger", fontweight="bold")
ax.set_ylabel("Count", fontweight="bold")
ax.legend(loc="upper right")
# Raw f-string: "\m" in a plain string literal is an invalid escape sequence.
ax.set_title(
    rf"Avg Hellinger Distance ($L_{{\mathrm{{pred}}}} = {pred_length}$)",
    fontweight="bold",
)
fig.tight_layout()
fig.savefig(
    os.path.join(
        "../figures",
        f"avg_hellinger_distribution_{horizon_name}_{pred_length}.pdf",
    ),
    bbox_inches="tight",
)
plt.show()

In [None]:
# Per-system difference between Chronos SFT and Panda Hellinger distances for the
# current `horizon_name` / `pred_length`. Values are paired by system name.
# Subtracting the two value lists positionally only lines up when both models list
# exactly the same systems in the same order. Systems missing or None for either model
# are skipped.
hellinger_sft_by_system = metrics["avg_hellinger"]["chronos_sft"][horizon_name][
    pred_length
]
hellinger_panda_by_system = metrics["avg_hellinger"]["panda"][horizon_name][
    pred_length
]
hellinger_diff = np.array(
    [
        hellinger_sft_by_system[name] - hellinger_panda_by_system[name]
        for name in hellinger_panda_by_system
        if hellinger_panda_by_system[name] is not None
        and hellinger_sft_by_system.get(name) is not None
    ]
)

fig, ax = plt.subplots(figsize=(4, 4))
ax.hist(
    hellinger_diff,
    bins=30,
    color="gray",
    edgecolor="black",
    alpha=0.7,
    histtype="stepfilled",
)
ax.axvline(0, color="k", linestyle="dotted", linewidth=1.5)  # zero-difference reference
ax.set_xlabel("Avg Hellinger (Chronos SFT - Panda)", fontweight="bold")
ax.set_ylabel("Count", fontweight="bold")
# Raw f-string: "\m" in a plain string literal is an invalid escape sequence.
ax.set_title(
    rf"Difference in Avg Hellinger ($L_{{\mathrm{{pred}}}} = {pred_length}$)",
    fontweight="bold",
)
fig.tight_layout()
# fig.savefig(
#     os.path.join(
#         "../figures",
#         f"avg_hellinger_diff_distribution_{horizon_name}_{pred_length}.pdf",
#     ),
#     bbox_inches="tight",
# )
plt.show()

In [None]:
pred_length = "512"
horizon_name = "full"
full_kld_panda = list(metrics["kld"]["panda"][horizon_name][pred_length].values())
full_kld_chronos_sft = list(
    metrics["kld"]["chronos_sft"][horizon_name][pred_length].values()
)
full_kld_chronos_zs = list(
    metrics["kld"]["chronos_zs"][horizon_name][pred_length].values()
)

# The x-axis is log-scaled, so only strictly positive values can be shown.
# Missing (None) values are dropped as well; `None > 0` raises TypeError.
all_kld = full_kld_panda + full_kld_chronos_sft + full_kld_chronos_zs
all_kld_pos = [v for v in all_kld if v is not None and v > 0]
full_kld_panda_pos = [v for v in full_kld_panda if v is not None and v > 0]
full_kld_chronos_sft_pos = [v for v in full_kld_chronos_sft if v is not None and v > 0]
full_kld_chronos_zs_pos = [v for v in full_kld_chronos_zs if v is not None and v > 0]

# Report drops: unequal drop counts make the histograms less directly comparable.
for model_label, kld_values, kld_pos in (
    ("Panda", full_kld_panda, full_kld_panda_pos),
    ("Chronos 20M SFT", full_kld_chronos_sft, full_kld_chronos_sft_pos),
    ("Chronos 20M", full_kld_chronos_zs, full_kld_chronos_zs_pos),
):
    n_dropped = len(kld_values) - len(kld_pos)
    if n_dropped:
        print(f"{model_label}: dropped {n_dropped} missing or non-positive KL values")

# Log-spaced bins shared by all histograms
if all_kld_pos:
    min_kld = min(all_kld_pos)
    max_kld = max(all_kld_pos)
    bins = np.logspace(np.log10(min_kld), np.log10(max_kld), 20)
else:
    bins = 25  # fallback
    print("No positive values found")

zs_color = DEFAULT_COLORS[2] if len(DEFAULT_COLORS) > 2 else "tab:green"
alpha_val = 0.6
fig, ax = plt.subplots(figsize=(4, 4))
for hist_values, hist_color, hist_label in (
    (full_kld_panda_pos, DEFAULT_COLORS[0], "Panda"),
    (full_kld_chronos_sft_pos, DEFAULT_COLORS[1], "Chronos 20M SFT"),
    (full_kld_chronos_zs_pos, zs_color, "Chronos 20M"),
):
    ax.hist(
        hist_values,
        bins=bins,
        color=hist_color,
        edgecolor=hist_color,
        alpha=alpha_val,
        histtype="stepfilled",
        label=hist_label,
    )
ax.set_xscale("log")
# ax.set_xlabel("Average KL Divergence", fontweight="bold")
ax.set_ylabel("Count", fontweight="bold")
ax.legend(loc="upper left")
# Raw f-string: "\m" in a plain string literal is an invalid escape sequence.
ax.set_title(
    rf"KL Divergence ($L_{{\mathrm{{pred}}}} = {pred_length}$)",
    fontweight="bold",
)
fig.tight_layout()
fig.savefig(
    os.path.join(
        "../figures",
        f"kld_distribution_{horizon_name}_{pred_length}_log.pdf",
    ),
    bbox_inches="tight",
)
plt.show()

In [None]:
# Per-system difference between Chronos SFT and Panda KL divergences for the current
# `horizon_name` / `pred_length`. Values are paired by system name.
# Subtracting the two value lists positionally only lines up when both models list
# exactly the same systems in the same order. Systems missing or None for either model
# are skipped.
kld_sft_by_system = metrics["kld"]["chronos_sft"][horizon_name][pred_length]
kld_panda_by_system = metrics["kld"]["panda"][horizon_name][pred_length]
kld_diff = np.array(
    [
        kld_sft_by_system[name] - kld_panda_by_system[name]
        for name in kld_panda_by_system
        if kld_panda_by_system[name] is not None
        and kld_sft_by_system.get(name) is not None
    ]
)

fig, ax = plt.subplots(figsize=(4, 4))
ax.hist(
    kld_diff,
    bins=30,
    color="gray",
    edgecolor="black",
    alpha=0.7,
    histtype="stepfilled",
)
ax.axvline(0, color="k", linestyle="dotted", linewidth=1.5)  # zero-difference reference
# Plain (non-f) string, so single braces are the intended mathtext grouping.
ax.set_xlabel(r"$D_{KL}$ (Chronos SFT - Panda)", fontweight="bold")
ax.set_ylabel("Count", fontweight="bold")
# Raw f-string: "\m" in a plain string literal is an invalid escape sequence.
ax.set_title(
    rf"Difference in $D_{{KL}}$ ($L_{{\mathrm{{pred}}}} = {pred_length}$)",
    fontweight="bold",
)
fig.tight_layout()
# fig.savefig(
#     os.path.join(
#         "../figures",
#         f"avg_kld_diff_distribution_{horizon_name}_{pred_length}.pdf",
#     ),
#     bbox_inches="tight",
# )
plt.show()

In [None]:
# Mean/std over systems of the KL difference Chronos SFT - Panda for the current
# `horizon_name` / `pred_length`. Values are paired by system name; systems missing
# or None for either model are skipped.
kld_first = metrics["kld"]["chronos_sft"][horizon_name][pred_length]
kld_second = metrics["kld"]["panda"][horizon_name][pred_length]
paired_kld_diff = np.array(
    [
        kld_first[name] - kld_second[name]
        for name in kld_first
        if kld_first[name] is not None and kld_second.get(name) is not None
    ]
)
mean_kld_diff = np.mean(paired_kld_diff)
std_kld_diff = np.std(paired_kld_diff)
print(
    f"Mean KL diff (Chronos SFT - Panda): {mean_kld_diff:.4f}, Std KL diff: {std_kld_diff:.4f}"
)

In [None]:
# Mean/std over systems of the KL difference Chronos ZS - Chronos SFT for the current
# `horizon_name` / `pred_length`. Values are paired by system name; systems missing
# or None for either model are skipped.
kld_first = metrics["kld"]["chronos_zs"][horizon_name][pred_length]
kld_second = metrics["kld"]["chronos_sft"][horizon_name][pred_length]
paired_kld_diff = np.array(
    [
        kld_first[name] - kld_second[name]
        for name in kld_first
        if kld_first[name] is not None and kld_second.get(name) is not None
    ]
)
mean_kld_diff = np.mean(paired_kld_diff)
std_kld_diff = np.std(paired_kld_diff)
print(
    f"Mean KL diff (Chronos ZS - Chronos SFT): {mean_kld_diff:.4f}, Std KL diff: {std_kld_diff:.4f}"
)

In [None]:
# Mean/std over systems of the KL difference Chronos ZS - Panda for the current
# `horizon_name` / `pred_length`. Values are paired by system name; systems missing
# or None for either model are skipped.
kld_first = metrics["kld"]["chronos_zs"][horizon_name][pred_length]
kld_second = metrics["kld"]["panda"][horizon_name][pred_length]
paired_kld_diff = np.array(
    [
        kld_first[name] - kld_second[name]
        for name in kld_first
        if kld_first[name] is not None and kld_second.get(name) is not None
    ]
)
mean_kld_diff = np.mean(paired_kld_diff)
std_kld_diff = np.std(paired_kld_diff)
print(
    f"Mean KL diff (Chronos ZS - Panda): {mean_kld_diff:.4f}, Std KL diff: {std_kld_diff:.4f}"
)

In [None]:
pred_lengths = ["128", "256", "512"]
horizon_name = "full"
# (label, minuend model, subtrahend model)
pairs = [
    ("Chronos SFT - Panda", "chronos_sft", "panda"),
    ("Chronos ZS - Chronos SFT", "chronos_zs", "chronos_sft"),
    ("Chronos ZS - Panda", "chronos_zs", "panda"),
]
# Mean/std over systems of per-system KL differences for every model pair and
# prediction length. Values are paired by system name, not by list position.
# Systems missing or None for either model are skipped. Each line is tagged with
# its prediction length so the output is unambiguous.
for pred_length in pred_lengths:
    for label, key1, key2 in pairs:
        kld1 = metrics["kld"][key1][horizon_name][pred_length]
        kld2 = metrics["kld"][key2][horizon_name][pred_length]
        diff = np.array(
            [
                kld1[name] - kld2[name]
                for name in kld1
                if kld1[name] is not None and kld2.get(name) is not None
            ]
        )
        print(
            f"[L_pred={pred_length}] Mean KL diff ({label}): {diff.mean():.4f}, Std KL diff: {diff.std():.4f}"
        )

In [None]:
horizon_name = "pred_horizon"
# Mean/std over systems of per-system avg Hellinger differences for every model pair
# (`pairs`) and prediction length (`pred_lengths`). Values are paired by system name,
# not by list position. Systems missing or None for either model are skipped. Each line
# is tagged with its prediction length so the output is unambiguous.
for pred_length in pred_lengths:
    for label, key1, key2 in pairs:
        hell1 = metrics["avg_hellinger"][key1][horizon_name][pred_length]
        hell2 = metrics["avg_hellinger"][key2][horizon_name][pred_length]
        diff = np.array(
            [
                hell1[name] - hell2[name]
                for name in hell1
                if hell1[name] is not None and hell2.get(name) is not None
            ]
        )
        print(
            f"[L_pred={pred_length}] Mean avg_hellinger diff ({label}): {diff.mean():.2f}, Std avg_hellinger diff: {diff.std():.2f}"
        )