In [None]:
# Automatically reload edited project modules (e.g. panda.utils) before running each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os

import dysts.flows as flows  # type: ignore
import numpy as np

from panda.utils import (
    apply_custom_style,
    check_dict_match,
    init_skew_system_from_params,
    load_dyst_samples,
    plot_grid_trajs_multivariate,
    test_system_jacobian,
)

In [None]:
# Apply the shared matplotlib style so every figure below follows the project plotting config.
PLOTTING_STYLE_PATH = "../config/plotting.yaml"
apply_custom_style(PLOTTING_STYLE_PATH)

In [None]:
# One seeded generator for all randomness in this notebook (used later to draw a new IC).
rseed = 99
rng = np.random.default_rng(seed=rseed)

In [None]:
# Root data directory; when $WORK is unset this falls back to "data" relative to the CWD.
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

### Set test system

In [None]:
# Dataset and split to inspect. An alternative dataset is "final_skew40_tanh_random".
dataset_name = "improved/final_skew40"
split_name = "test_zeroshot"
# Directory that saved sample trajectories are loaded from (passed to load_dyst_samples below).
split_dir = os.path.join(DATA_DIR, dataset_name, split_name)

In [None]:
# JSON file with the filtered parameter records for every system in this split.
params_json_path = os.path.join(
    DATA_DIR, dataset_name, "parameters", split_name, "filtered_params_dict.json"
)

In [None]:
# Maps system name -> list of per-sample parameter dicts (indexed by system name, then sample index).
# The context manager makes sure the file handle is closed after parsing.
with open(params_json_path) as params_file:
    reloaded_params_dicts = json.load(params_file)

In [None]:
# Choose one system by its position among the saved parameter entries.
system_idx = 2
system_names = list(reloaded_params_dicts)
test_system_name = system_names[system_idx]
print(test_system_name)

In [None]:
# Which saved sample of the chosen system to inspect; it indexes both the loaded
# trajectories and the system's list of saved parameter records.
sample_idx = 0

### Plot trajectory from the saved Arrow file

In [None]:
# Load the first (sample_idx + 1) saved samples for this system and keep sample `sample_idx`.
loaded_samples = load_dyst_samples(
    test_system_name,
    data_dir=split_dir,
    one_dim_target=False,
    num_samples=sample_idx + 1,
)
test_dyst_traj = loaded_samples[sample_idx]  # type: ignore

print(test_dyst_traj.shape)

In [None]:
# NOTE(review): repeats the shape printout already emitted by the loading cell; can be dropped.
print(test_dyst_traj.shape)

In [None]:
# Plot the saved sample. The leading None adds a size-1 axis (presumably the batch axis
# the grid plotter iterates over).
saved_plot_options = dict(
    save_path=None,
    subplot_size=(4, 4),
    show_axes=True,
    show_titles=True,
    plot_projections=True,
    projections_alpha=0.08,
    plot_kwargs={"linewidth": 0.3},
)
plot_grid_trajs_multivariate(
    {test_system_name: test_dyst_traj[None, :, :]}, **saved_plot_options
)

### Load saved params

In [None]:
# Saved parameter record (a dict) for the chosen system and sample; inspected in the next cells.
reloaded_test_params = reloaded_params_dicts[test_system_name][sample_idx]

In [None]:
# Display the full saved parameter record.
reloaded_test_params

In [None]:
# Top-level fields in the saved record (e.g. "ic", "coupling_map", "response_params").
reloaded_test_params.keys()

In [None]:
# Saved coupling-map parameters; checked later against the rebuilt system's serialization.
reloaded_test_params["coupling_map"]

In [None]:
# Skew-product systems are named "<driver>_<response>". The rest of this notebook
# (driver/response splitting, pure-response integration) needs `sys` and `response_name`,
# so fail fast with a clear message instead of a later NameError.
is_skew = "_" in test_system_name
if not is_skew:
    raise ValueError(
        f"{test_system_name!r} is not a skew system; this notebook expects a "
        "'<driver>_<response>' system name"
    )

driver_name, response_name = test_system_name.split("_")
# Rebuild the coupled system from the saved parameter record.
sys = init_skew_system_from_params(
    driver_name,
    response_name,
    reloaded_test_params,
)

In [None]:
# Instance attributes of the rebuilt skew system.
vars(sys)

In [None]:
# Report the period and dt of both subsystems, in the order:
# driver period, response period, driver dt, response dt.
for attr_name in ("period", "dt"):
    for part_name in ("driver", "response"):
        part_value = getattr(getattr(sys, part_name), attr_name)
        print(f"{part_name} {attr_name}: {part_value}")

In [None]:
# Parameters held by the rebuilt skew system.
sys.params

In [None]:
# Inspect the skew system's `unbounded_indices` attribute.
sys.unbounded_indices

In [None]:
# NOTE(review): duplicates the earlier `sys.params` display cell; can be dropped.
sys.params

In [None]:
# Round-trip check: the coupling map rebuilt from saved params should serialize
# back to the saved coupling-map record.
reloaded_coupling_map_params = reloaded_test_params["coupling_map"]
serialized_coupling_map_params = sys.coupling_map._serialize()

check_dict_match(reloaded_coupling_map_params, serialized_coupling_map_params)

In [None]:
# Restore the saved initial condition so the re-integration below starts from the saved sample's IC.
sys.ic = np.array(reloaded_test_params["ic"])
print("initial condition: ", sys.ic)

if not sys.has_jacobian():
    print(f"Jacobian not implemented for {test_system_name}")
else:
    print(f"Jacobian implemented for {test_system_name}")
    # Evaluate once at (ic, t=0) and reuse the result for both printouts.
    jac_at_ic = sys.jac(sys.ic, 0)
    print(f"Jacobian shape: {np.shape(jac_at_ic)}")
    print(f"Jacobian: {jac_at_ic}")

In [None]:
# Validate the skew system's analytic Jacobian via panda.utils.test_system_jacobian
# (the eps argument suggests a finite-difference comparison; confirm in panda.utils).
skew_jacobian_check_options = dict(
    num_timesteps=5120,
    num_periods=50,
    transient=200,
    n_points_sample=10,
    eps=1e-8,
)
test_system_jacobian(sys, **skew_jacobian_check_options)

In [None]:
# Re-integrate the skew system from its current `sys.ic` with tight solver tolerances.
num_timesteps, num_periods = 5120, 50
points_per_period = num_timesteps // num_periods

ts, traj = sys.make_trajectory(
    num_timesteps,
    pts_per_period=points_per_period,
    return_times=True,
    atol=1e-10,
    rtol=1e-8,
)

In [None]:
# Drop the initial transient, then split the state into driver and response coordinates.
transient_frac = 0.2
transient_length = int(transient_frac * num_timesteps)
trajectory = traj[None, transient_length:, :]  # (1, time, dim)
print(trajectory.shape)

# Reorder to (1, dim, time), the layout the other plotting cells pass to the grid plotter.
trajectory_to_plot = trajectory.transpose(0, 2, 1)
driver_coords = trajectory_to_plot[:, : sys.driver_dim, :]
response_coords = trajectory_to_plot[:, sys.driver_dim :, :]

coords_by_part = {"driver": driver_coords, "response": response_coords}
for part_name, part_coords in coords_by_part.items():
    plot_grid_trajs_multivariate(
        {part_name: part_coords},
        save_path=None,
        subplot_size=(4, 4),
        show_axes=True,
        show_titles=True,
        plot_projections=True,
        projections_alpha=0.08,
        plot_kwargs={"linewidth": 0.3},
    )

### Check that the saved trajectory matches the recomputed trajectory

In [None]:
# Compare the stored sample with the response part of the re-integrated trajectory.
saved_traj = test_dyst_traj
# Remove only the leading batch axis added by `traj[None, ...]`; `.squeeze()` would also
# collapse any other length-1 axis and silently change the shape being compared.
reloaded_traj = response_coords[0]

assert saved_traj.shape == reloaded_traj.shape, (
    f"saved_traj.shape: {saved_traj.shape}, reloaded_traj.shape: {reloaded_traj.shape}"
)

mse = np.mean((saved_traj - reloaded_traj) ** 2)
rmse = np.sqrt(mse)
print(f"MSE: {mse}")
print(f"RMSE: {rmse}")

### Integrate Pure Response

Integrate the response system on its own, without any coupling to the driver.

In [None]:
# Number of leading state coordinates that belong to the driver.
sys.driver_dim

In [None]:
# The response IC is the part of the saved skew IC after the driver coordinates.
full_saved_ic = reloaded_test_params["ic"]
response_ic = full_saved_ic[sys.driver_dim :]
print(f"response_ic: {response_ic}")

In [None]:
# Instantiate the standalone dysts flow for the response, using the saved response parameters.
response_flow_cls = getattr(flows, response_name)
pure_response_sys = response_flow_cls(parameters=reloaded_test_params["response_params"])

pure_response_attrs = vars(pure_response_sys)
print(pure_response_attrs.keys())
print(pure_response_attrs["param_list"])
print(pure_response_attrs["params"])

In [None]:
# Instance attributes of the standalone response flow.
vars(pure_response_sys)

In [None]:
# Start the standalone response from the response part of the saved skew IC.
pure_response_sys.ic = np.array(response_ic)
print(pure_response_sys.ic)

if not pure_response_sys.has_jacobian():
    print(f"Jacobian not implemented for {response_name}")
else:
    print(f"Jacobian implemented for {response_name}")
    # Evaluate once at (ic, t=0). np.shape also works if jac returns a nested list
    # rather than an ndarray.
    response_jac_at_ic = pure_response_sys.jac(pure_response_sys.ic, 0)
    print(f"Jacobian shape: {np.shape(response_jac_at_ic)}")
    print(f"Jacobian: {response_jac_at_ic}")

In [None]:
# Validate the analytic Jacobian of the reconstructed pure response, using the same
# settings as the skew-system check.
response_jacobian_check_options = dict(
    num_timesteps=5120,
    num_periods=50,
    transient=200,
    n_points_sample=10,
    eps=1e-8,
)
test_system_jacobian(pure_response_sys, **response_jacobian_check_options)

In [None]:
# Integrate the uncoupled response with the same resolution and tolerances as the skew system.
num_timesteps, num_periods = 5120, 50
response_points_per_period = num_timesteps // num_periods

pure_response_ts, pure_response_traj = pure_response_sys.make_trajectory(
    num_timesteps,
    pts_per_period=response_points_per_period,
    return_times=True,
    atol=1e-10,
    rtol=1e-8,
)

In [None]:
# Shape of the pure-response trajectory (transposed before plotting in the next cell).
pure_response_traj.shape

In [None]:
# Plot the uncoupled response: transpose to (dim, time), then add a leading size-1 axis.
pure_response_plot_options = dict(
    save_path=None,
    subplot_size=(4, 4),
    show_axes=True,
    show_titles=True,
    plot_projections=True,
    projections_alpha=0.08,
    plot_kwargs={"linewidth": 0.3},
)
plot_grid_trajs_multivariate(
    {response_name: pure_response_traj.T[None, :, :]}, **pure_response_plot_options
)

### Remake the skew trajectory from a new initial condition

In [None]:
# Shape of the full (transient-included) skew trajectory sampled from below.
traj.shape

In [None]:
# Current (saved) initial condition, shown for comparison with the new one.
print(f"Old initial condition: \n {sys.ic}")

In [None]:
# Build a new IC by taking each coordinate from an independently drawn time index of the
# full (transient-included) trajectory. Because the indices differ per coordinate, the
# resulting point is generally not a state the trajectory actually visited.
n_timesteps, n_dims = traj.shape[:2]
new_ic_idx = rng.integers(0, n_timesteps, size=n_dims)
new_ic = traj[new_ic_idx, np.arange(n_dims)]
print(f"New initial condition: \n {new_ic}")

In [None]:
# Re-integrate the skew system from the freshly sampled initial condition.
sys.ic = new_ic

if not sys.has_jacobian():
    print(f"Jacobian not implemented for {test_system_name}")

num_timesteps, num_periods = 5120, 50
ts, traj = sys.make_trajectory(
    num_timesteps,
    pts_per_period=num_timesteps // num_periods,
    return_times=True,
    atol=1e-10,
    rtol=1e-8,
)

In [None]:
# Drop the same transient as before and replot the driver and response parts.
trajectory = traj[None, transient_length:, :]  # (1, time, dim)
print(trajectory.shape)

trajectory_to_plot = trajectory.transpose(0, 2, 1)  # (1, dim, time)
driver_coords = trajectory_to_plot[:, : sys.driver_dim, :]
response_coords = trajectory_to_plot[:, sys.driver_dim :, :]

coords_by_part = {"driver": driver_coords, "response": response_coords}
for part_name, part_coords in coords_by_part.items():
    plot_grid_trajs_multivariate(
        {part_name: part_coords},
        save_path=None,
        subplot_size=(4, 4),
        plot_kwargs={"linewidth": 0.3},
        show_axes=True,
        show_titles=True,
        plot_projections=True,
        projections_alpha=0.08,
    )