In [None]:
# Reload edited modules (e.g. the local `panda` package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import time

import dysts.flows as flows
import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.analysis import max_lyapunov_exponent_rosenstein  # type: ignore

from panda.chronos.pipeline import ChronosPipeline
from panda.patchtst.pipeline import PatchTSTPipeline
from panda.utils import (
    apply_custom_style,
    get_system_filepaths,
    load_trajectory_from_arrow,
    make_clean_projection,
    plot_3d_and_univariate,
    safe_standardize,
)

In [None]:
apply_custom_style("../config/plotting.yaml")

In [None]:
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
split_name = "improved/final_base40"
# system_name = "Coullet_AnishchenkoAstakhov"
# system_name = "ForcedBrusselator_Rossler"
system_name = "LorenzStenflo"

In [None]:
if "_" in system_name:
    driver_name, response_name = system_name.split("_")
else:
    driver_name, response_name = None, None
print(driver_name)
print(response_name)

In [None]:
# Locate the zero-shot test trajectories for the selected system and pick one sample.
split_dir = os.path.join(DATA_DIR, split_name)
print(split_dir)
files_lst = get_system_filepaths(system_name, split_dir, split="test_zeroshot")
# Fail with a readable message instead of a bare IndexError on an empty listing.
# len() (not truthiness) keeps this valid for both lists and arrays.
if len(files_lst) == 0:
    raise FileNotFoundError(
        f"No test_zeroshot trajectories found for {system_name!r} in {split_dir}"
    )
print(f"Found {len(files_lst)} file(s)")

sample_idx = 0
filepath = files_lst[sample_idx]
print(filepath)

In [None]:
transient_time = 0
dyst_coords, _ = load_trajectory_from_arrow(filepath)
dyst_coords = dyst_coords[:, transient_time:]

In [None]:
standardize = False
if standardize:
    dyst_coords = safe_standardize(dyst_coords)

In [None]:
dyst_coords.shape

In [None]:
plot_3d_and_univariate(
    dyst_coords[:, :None],
    figsize=(6, 8),
    plot_kwargs={"linewidth": 0.5, "alpha": 0.8},
    custom_colors=["tab:blue"],
    # plot_title=f"{driver_name} + {response_name}",
    title_kwargs={"fontsize": 14, "fontweight": "bold"},
)

### Lyapunov Spectrum

In [None]:
from panda.utils import lyap_wolf

In [None]:
sys = getattr(flows, "Lorenz")()
print(f"dt = {sys.dt}")
print(f"1 / dt = {1 / sys.dt}")

In [None]:
# Faster version, check with cell below (longer, more accurate)
# Compute Lyapunov exponents
start_time = time.time()
lambdas = lyap_wolf(f=sys.rhs, x0=sys.ic, dt=0.01, jac=None, n_steps=10_000)
end_time = time.time()
print(f"Time taken: {end_time - start_time:.2f} seconds")
print("Lyapunov spectrum:", lambdas)

In [None]:
n_steps = int(100 // sys.dt)  # just a heuristic
print(f"Using {n_steps} steps...")

has_jac = sys.has_jacobian and sys._jac.__dict__ != {}

if has_jac:
    print("Jacobian is available")
    print(f"Jacobian: {sys.jac(sys.ic, 0.0)}")
else:
    print("Jacobian is not available")

# Compute Lyapunov exponents
lambdas = lyap_wolf(
    f=sys.rhs,
    x0=sys.ic,
    dt=sys.dt,
    jac=sys.jac if has_jac else None,
    n_steps=n_steps,
)

print("Lyapunov spectrum:", lambdas)

In [None]:
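# Kaplan-Yorke dimension estimate for a 3D flow
# (assumes lambdas sorted descending with lambdas[1] ~ 0):
# 2 + lambdas[0] / abs(lambdas[2])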

### Rosenstein Max Lyapunov Exponent

In [None]:
# dyst_coords.shape

In [None]:
# # max_lyapunov_exponent_rosenstein(dyst_coords.T, trajectory_len=102)
# max_lyapunov_exponent_rosenstein(dyst_coords.T)

In [None]:
# from nolds.measures import lyap_r

In [None]:
# lyap_r(
#     dyst_coords[0],
#     emb_dim=10,
#     lag=None,
#     min_tsep=None,
#     tau=1,
#     min_neighbors=20,
#     trajectory_len=20,
#     debug_plot=True,
# )

# Forecasts

In [None]:
run_name = "pft_chattn_emb_w_poly-0"
# run_name = "panda_nh12_dmodel768_mixedp-4"
# run_name = "panda_nh10_dmodel640-1"
# # run_name = "pft_polyfeats_repro-0"

pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    device_map="cuda:0",
)
panda_kwargs = {
    "limit_prediction_length": False,
    "sliding_context": True,
    "is_chronos": False,
}

In [None]:
chronos_kwargs = {
    "is_chronos": True,
    "limit_prediction_length": False,
    "num_samples": 10,
    "deterministic": False,
}

chronos_sft_run_name = "chronos_t5_mini_ft-0"

chronos_sft = ChronosPipeline.from_pretrained(
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{chronos_sft_run_name}/checkpoint-final",
    device_map="cuda:3",
    torch_dtype=torch.float32,
)

In [None]:
def get_model_prediction(
    model,
    context: np.ndarray,
    groundtruth: np.ndarray,
    prediction_length: int,
    is_chronos: bool = False,
    title: str | None = None,
    save_path: str | None = None,
    show_plot: bool = True,
    figsize: tuple[int, int] = (6, 8),
    color: str = "red",
    verbose: bool = True,
    **kwargs,
) -> tuple[np.ndarray, float]:
    context_tensor = (
        torch.from_numpy(context.T if not is_chronos else context).float()
        # .to(model.device)
    )
    if not is_chronos:
        context_tensor = context_tensor[None, ...]

    start_time = time.time()
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    elapsed_time = time.time() - start_time

    if is_chronos:
        if not kwargs.get("deterministic", False):
            pred = np.median(pred, axis=1)
        pred = pred.T

    if verbose:
        print(f"context tensor shape: {context_tensor.shape}")
        print(f"context tensor device: {context_tensor.device}")
        print(f"pred shape: {pred.shape}")
        print(f"Prediction time: {elapsed_time:.4f} seconds")

    if show_plot:
        total_length = context.shape[1] + prediction_length
        context_ts = np.arange(context.shape[1]) / total_length
        pred_ts = np.arange(context.shape[1], total_length) / total_length

        # Add the last time point of context to the beginning of groundtruth
        # This ensures continuity between context and groundtruth in the plot
        if context.shape[1] > 0 and groundtruth.shape[1] > 0:
            last_context_point = context[:, -1][
                :, np.newaxis
            ]  # Get last point and reshape to column vector
            groundtruth = np.hstack(
                (last_context_point, groundtruth)
            )  # Prepend to groundtruth

            # Prepend last context point to prediction timeline and data for continuity
            pred_ts = np.concatenate(([context_ts[-1]], pred_ts))
            if pred.shape[0] + 1 == len(pred_ts):
                pred = np.vstack((context[:, -1], pred))

        # Create figure with gridspec layout
        fig = plt.figure(figsize=figsize)

        # Create main grid with padding for colorbar
        outer_grid = fig.add_gridspec(2, 1, height_ratios=[0.65, 0.35], hspace=-0.1)

        # Create sub-grid for the plots
        gs = outer_grid[1].subgridspec(
            3, 1, height_ratios=[0.2] * 3, wspace=0, hspace=0
        )
        ax_3d = fig.add_subplot(outer_grid[0], projection="3d")

        ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
        ax_3d.plot(*groundtruth[:3], linestyle="-", color="black", label="Groundtruth")
        ax_3d.plot(*pred.T[:3], color=color, label="Prediction")
        # make_arrow_axes(ax_3d)
        make_clean_projection(ax_3d)

        if title is not None:
            title_name = title.replace("_", " ")
            ax_3d.set_title(title_name, fontweight="bold")

        axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
        for i, ax in enumerate(axes_1d):
            ax.plot(context_ts, context[i], alpha=0.5, color="black")
            ax.plot(pred_ts, groundtruth[i], linestyle="-", color="black")
            ax.plot(pred_ts, pred[:, i], color=color)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_aspect("auto")

        if save_path is not None:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            print(f"saving fig to: {save_path}")
            plt.savefig(save_path, bbox_inches="tight")
    pred = pred.T
    return pred, elapsed_time

In [None]:
context_length = 512
prediction_length = 512

start_time = 512
end_time = start_time + context_length

In [None]:
context = dyst_coords[:, start_time:end_time]
groundtruth = dyst_coords[:, end_time : end_time + prediction_length]
print(f"context shape: {context.shape}")
print(f"groundtruth shape: {groundtruth.shape}")

In [None]:
save_path = os.path.join(
    "../figures/panda/",
    f"{system_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{prediction_length}_run1.pdf",
)

panda_pred, panda_elapsed_time = get_model_prediction(
    pft_model,
    context,
    groundtruth,
    prediction_length,
    verbose=True,
    show_plot=True,
    save_path=save_path,
    **panda_kwargs,
)

In [None]:
save_path = os.path.join(
    "../figures/chronos_sft",
    f"{system_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{prediction_length}_nsamples-{chronos_kwargs['num_samples']}_run2.pdf",
)

chronos_sft_pred, chronos_sft_elapsed_time = get_model_prediction(
    chronos_sft,
    context,
    groundtruth,
    prediction_length,
    verbose=True,
    show_plot=True,
    save_path=save_path,
    color="tab:blue",
    **chronos_kwargs,
)

In [None]:
# Forecasts come back from get_model_prediction transposed to (n_dims, n_steps).
chronos_sft_pred.shape

In [None]:
dyst_coords.shape

In [None]:
# Rosenstein estimate of the largest Lyapunov exponent on the true continuation.
# `.T` puts time on the leading axis — presumably what dysts expects; confirm.
max_lyapunov_exponent_rosenstein(groundtruth.T)

In [None]:
panda_pred.shape

In [None]:
# Same estimate on the Panda forecast, for comparison with the groundtruth value.
max_lyapunov_exponent_rosenstein(panda_pred.T)

In [None]:
# Same estimate on the fine-tuned Chronos forecast.
max_lyapunov_exponent_rosenstein(chronos_sft_pred.T)

### Distributional Metrics

In [None]:
from dysts.metrics import (  # type: ignore
    average_hellinger_distance,
    estimate_kl_divergence,
)

In [None]:
dyst_coords.shape

In [None]:
estimate_kl_divergence(dyst_coords.T, panda_pred.T, n_samples=10_000)

In [None]:
estimate_kl_divergence(dyst_coords.T, chronos_sft_pred.T, n_samples=10_000)

In [None]:
# average_hellinger_distance(groundtruth.T, panda_pred.T)
average_hellinger_distance(groundtruth.T, panda_pred.T)

In [None]:
average_hellinger_distance(groundtruth.T, chronos_sft_pred.T)

In [None]:
average_hellinger_distance(dyst_coords.T, panda_pred.T)

In [None]:
average_hellinger_distance(dyst_coords.T, chronos_sft_pred.T)

In [None]:
panda_pred.T[:128].shape

In [None]:
average_hellinger_distance(dyst_coords.T, panda_pred.T[:256])

In [None]:
estimate_kl_divergence(dyst_coords.T, chronos_sft_pred.T[:256], n_samples=10_000)