In [None]:
# Reload imported project modules (e.g. dystformer.utils) before each cell executes,
# so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

import numpy as np
from dysts.analysis import corr_gpdim
from dysts.metrics import estimate_kl_divergence, smape, spearman

from dystformer.utils import (
    load_trajectory_from_arrow,
    plot_trajs_multivariate,
)

In [None]:
# Root for all experiment artifacts; falls back to "" (i.e. paths relative to the
# current directory) when the $WORK environment variable is not set.
WORK_DIR = os.getenv("WORK", "")

# Top-level artifact directories, all directly under WORK_DIR.
DATA_DIR, METRICS_DIR, CHECKPOINT_DIR = (
    os.path.join(WORK_DIR, subdir) for subdir in ("data", "eval_results", "checkpoints")
)

# Saved forecasts/labels from evaluation runs live under DATA_DIR/eval.
EVAL_DIR = os.path.join(DATA_DIR, "eval")

In [None]:
# Names of the evaluation runs to compare: our PatchTST-based model and the
# fine-tuned Chronos baseline.
run_name, chronos_ft_run_name = "pft_chattn_emb_w_poly-0", "chronos_bolt_mini-12"

In [None]:
# pft_model = PatchTSTPipeline.from_pretrained(
#     mode="predict",
#     pretrain_path=os.path.join(CHECKPOINT_DIR, run_name, "checkpoint-final"),
#     device_map="cuda:2",
# )

In [None]:
# chronos_ft_model = ChronosPipeline.from_pretrained(
#     os.path.join(CHECKPOINT_DIR, chronos_ft_run_name, "checkpoint-final"),
#     device_map="cuda:2",
#     torch_dtype=torch.float32,
# )

In [None]:
# Data split (relative to DATA_DIR) and the dynamical system to inspect.
# Other systems previously examined here: "VallisElNino_Hopfield",
# "ForcedFitzHughNagumo_Hopfield".
split_name = "improved/final_skew40/test_zeroshot"
system_name = "HyperWang_Duffing"

In [None]:
# Split path below the "improved/" prefix, e.g. "final_skew40/test_zeroshot".
# (Raises IndexError if split_name does not contain "improved/".)
split_subdir = split_name.split("improved/")[1]

# model key -> (eval subdirectory for that model family, run name)
model_runs = {
    "chronos_ft": ("chronos", chronos_ft_run_name),
    "pft": ("patchtst", run_name),
}

# Forecast directory per model: EVAL_DIR/<family>/<run>/<split>/forecasts/<system>
pred_dir_dict = {
    model_key: os.path.join(
        EVAL_DIR, family, model_run, split_subdir, "forecasts", system_name
    )
    for model_key, (family, model_run) in model_runs.items()
}

# Ground-truth directory per model: a "labels/<system>" sibling of the
# "forecasts/<system>" directory (resolved to an absolute path).
gt_dir_dict = {
    model_key: os.path.join(
        Path(forecast_dir).resolve().parent.parent, "labels", system_name
    )
    for model_key, forecast_dir in pred_dir_dict.items()
}

In [None]:
# Forecast files saved for our model (PFT) on this system.
os.listdir(path=pred_dir_dict["pft"])

In [None]:
# Forecast files saved for the fine-tuned Chronos baseline on this system.
# This cell previously duplicated the PFT listing; it now inspects the other model.
os.listdir(pred_dir_dict["chronos_ft"])

In [None]:
# Ground-truth label files paired with the PFT forecasts.
os.listdir(path=gt_dir_dict["pft"])

In [None]:
# Directory of full-length trajectories for the selected system in this split.
dyst_path_parts = (DATA_DIR, split_name, system_name)
dyst_dir = os.path.join(*dyst_path_parts)

In [None]:
# Full-length trajectory files available for this system.
os.listdir(path=dyst_dir)

In [None]:
# Integer sample ids parsed from the "<id>_T-4096.arrow" trajectory files.
# os.listdir returns entries in arbitrary order, and the next cell indexes this list
# positionally with `sample_idx`. Sorting makes that selection reproducible across
# machines/filesystems. Files that don't match the naming pattern are skipped instead
# of crashing int().
dyst_file_suffix = "_T-4096.arrow"
dyst_dir_sample_idx_vals = sorted(
    int(fname[: -len(dyst_file_suffix)])
    for fname in os.listdir(dyst_dir)
    if fname.endswith(dyst_file_suffix)
)
print(dyst_dir_sample_idx_vals)

In [None]:
# Sample to inspect.
# NOTE(review): sample_idx is used both as a position into dyst_dir_sample_idx_vals
# and directly as the forecast/label file stem below. These refer to the same sample
# only if the eval run numbered its output files in the same order — confirm.
sample_idx = 4

dyst_filepath = os.path.join(
    dyst_dir, f"{dyst_dir_sample_idx_vals[sample_idx]}_T-4096.arrow"
)
dyst_coords, _ = load_trajectory_from_arrow(dyst_filepath)

# Layout of each saved forecast/label file along axis 1 (time): the first
# `context_length` steps are the context, then `pred_length` forecast steps.
context_length = 512
pred_length = 128

pred_coords_dict = {}
gt_coords_dict = {}
context_coords_dict = {}

for model_name, pred_dir in pred_dir_dict.items():
    pred_filepath = os.path.join(pred_dir, f"{sample_idx}_T-1024.arrow")
    pred_coords_with_context, _ = load_trajectory_from_arrow(pred_filepath)

    gt_filepath = os.path.join(gt_dir_dict[model_name], f"{sample_idx}_T-1024.arrow")
    gt_coords_with_context, _ = load_trajectory_from_arrow(gt_filepath)

    forecast_window = slice(context_length, context_length + pred_length)
    pred_coords = pred_coords_with_context[:, forecast_window]
    gt_coords = gt_coords_with_context[:, forecast_window]
    context_coords = gt_coords_with_context[:, :context_length]

    # Slicing past the end silently truncates. Fail loudly so that metrics are never
    # computed on a partial forecast window.
    assert pred_coords.shape[1] == pred_length, (
        f"{model_name}: expected {pred_length} forecast steps in {pred_filepath}, "
        f"got {pred_coords.shape[1]}"
    )
    assert gt_coords.shape[1] == pred_length, (
        f"{model_name}: expected {pred_length} label steps in {gt_filepath}, "
        f"got {gt_coords.shape[1]}"
    )
    # The forecast file also stores the context it was conditioned on. It must match
    # the labels, otherwise the pair is misaligned.
    assert np.allclose(context_coords, pred_coords_with_context[:, :context_length]), (
        f"{model_name}: context in {pred_filepath} does not match labels in {gt_filepath}"
    )

    pred_coords_dict[model_name] = pred_coords
    gt_coords_dict[model_name] = gt_coords
    context_coords_dict[model_name] = context_coords

In [None]:
# Report shapes for every model. The bare `pred_coords`/`gt_coords` names left over
# from the loading loop hold only whichever model was processed last.
print(f"Dyst coords shape: {dyst_coords.shape}")
for model_name in pred_coords_dict:
    print(f"[{model_name}] Preds coords shape: {pred_coords_dict[model_name].shape}")
    print(f"[{model_name}] GT coords shape: {gt_coords_dict[model_name].shape}")

In [None]:
# plot_trajs_multivariate is given a stack of trajectories, so add a leading axis
# of length 1 (presumably (n_trajs, dim, time) — confirm against dystformer.utils).
full_traj_batch = np.expand_dims(dyst_coords, 0)
plot_trajs_multivariate(
    full_traj_batch,
    plot_name=f"{system_name} Full Trajectory",
    show_plot=True,
)
# To view only the context window instead, plot np.expand_dims(context_coords, 0)
# with plot_name=f"{system_name} Context".

In [None]:
# Our model's forecast window, followed by the matching ground truth.
for traj_coords, title in (
    (pred_coords_dict["pft"], f"{system_name} Predictions (Our Model)"),
    (gt_coords_dict["pft"], f"{system_name} Ground Truth"),
):
    plot_trajs_multivariate(
        np.expand_dims(traj_coords, 0), plot_name=title, show_plot=True
    )

In [None]:
# Fine-tuned Chronos forecast window, followed by the matching ground truth.
for traj_coords, title in (
    (pred_coords_dict["chronos_ft"], f"{system_name} Predictions (Chronos FT)"),
    (gt_coords_dict["chronos_ft"], f"{system_name} Ground Truth"),
):
    plot_trajs_multivariate(
        np.expand_dims(traj_coords, 0), plot_name=title, show_plot=True
    )

In [None]:
# Shape of the full-length trajectory.
np.shape(dyst_coords)

In [None]:
# corr_gpdim_val = corr_gpdim(dyst_coords.T, dyst_coords.T, standardize=False)
# print(corr_gpdim_val)

In [None]:
# Shape of our model's forecast window.
np.shape(pred_coords_dict["pft"])

In [None]:
# Correlation-dimension comparison of forecast vs. ground truth, per model.
# Coordinates are sliced on axis 1 by time step, so `.T` passes them time-major.
for label, key in (("Our model", "pft"), ("Chronos FT", "chronos_ft")):
    corr_gpdim_val = corr_gpdim(
        pred_coords_dict[key].T, gt_coords_dict[key].T, standardize=False
    )
    print(f"{label} corr_gpdim_val: {corr_gpdim_val}")

In [None]:
# gpdim_val = compute_gp_dimension(dyst_coords)
# print(gpdim_val)

In [None]:
# sMAPE of forecast vs. ground truth over the forecast window, per model.
# Stored in `smape_val` rather than reusing `corr_gpdim_val`, which would
# overwrite the correlation-dimension result computed above.
for label, key in (("Our model", "pft"), ("Chronos FT", "chronos_ft")):
    smape_val = smape(pred_coords_dict[key].T, gt_coords_dict[key].T)
    print(f"{label} smape: {smape_val}")

In [None]:
# Spearman rank correlation of forecast vs. ground truth, per model.
for label, key in (("Our model", "pft"), ("Chronos FT", "chronos_ft")):
    spearman_val = spearman(pred_coords_dict[key].T, gt_coords_dict[key].T)
    print(f"{label} spearman: {spearman_val}")

In [None]:
# Estimated KL divergence between forecast and ground-truth state distributions, per model.
# NOTE(review): KL is asymmetric and the predictions are passed first here. Confirm
# this matches the (true, generated) argument order expected by
# dysts.metrics.estimate_kl_divergence.
for label, key in (("Our model", "pft"), ("Chronos FT", "chronos_ft")):
    kl_divergence_val = estimate_kl_divergence(
        pred_coords_dict[key].T, gt_coords_dict[key].T
    )
    print(f"{label} kl_divergence: {kl_divergence_val}")