In [None]:
# Reload imported modules (e.g. the project's dystformer package) before executing
# each cell, so edits to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import json
import os
from typing import Any

import dysts.flows as flows
import numpy as np
from dysts.systems import DynSys
from tqdm import tqdm

from dystformer.coupling_maps import RandomAdditiveCouplingMap
from dystformer.skew_system import SkewProduct
from dystformer.utils import load_trajectory_from_arrow, plot_trajs_multivariate

In [None]:
rseed = 99
# One shared Generator for every random draw in this notebook -> reproducible runs
rng = np.random.default_rng(seed=rseed)

In [None]:
# Root of the working area; when $WORK is unset this is "" and DATA_DIR becomes
# the relative path "data".
WORK_DIR = os.getenv("WORK", default="")
DATA_DIR = os.path.join(WORK_DIR, "data")

### Utils

In [None]:
def init_skew_system_from_params(
    driver_name: str,
    response_name: str,
    param_dict: dict[str, Any],
    **kwargs,
) -> DynSys:
    """
    Rebuild a skew-product system from one saved parameter entry.

    The driver and response classes are looked up by name in ``dysts.flows`` and
    instantiated with their saved parameters. The coupling map is restored with
    ``RandomAdditiveCouplingMap._deserialize``, so the saved entry is assumed to
    come from that coupling-map class.

    Args:
        driver_name: name of the driver flow class in ``dysts.flows``.
        response_name: name of the response flow class in ``dysts.flows``.
        param_dict: saved entry; must contain ``"driver_params"``,
            ``"response_params"`` and ``"coupling_map"``.
        **kwargs: forwarded unchanged to ``SkewProduct``.

    Returns:
        The reconstructed ``SkewProduct`` instance.

    Raises:
        ValueError: if a required key is absent. The first missing key, in the
            order listed above, is reported.
    """
    system_name = f"{driver_name}_{response_name}"
    missing_key = next(
        (
            key
            for key in ("driver_params", "response_params", "coupling_map")
            if key not in param_dict
        ),
        None,
    )
    if missing_key is not None:
        raise ValueError(
            f"Key {missing_key} not found in param_dict for {system_name}"
        )

    driver_system = getattr(flows, driver_name)(
        parameters=param_dict["driver_params"]
    )
    response_system = getattr(flows, response_name)(
        parameters=param_dict["response_params"]
    )

    return SkewProduct(
        driver=driver_system,
        response=response_system,
        coupling_map=RandomAdditiveCouplingMap._deserialize(
            param_dict["coupling_map"]
        ),
        **kwargs,
    )

### Test system

In [None]:
# Location of the filtered parameter dictionaries written by an upstream step
output_dir = "../outputs"
params_filename = "filtered_params_dict.json"
output_path = os.path.join(output_dir, params_filename)

In [None]:
# Load the saved parameter dictionaries (system name -> list of sample entries).
# The context manager closes the file handle; `json.load(open(...))` leaves it
# open until garbage collection.
with open(output_path, encoding="utf-8") as params_file:
    reloaded_params_dicts = json.load(params_file)

In [None]:
# Skew-product systems are identified as "<Driver>_<Response>"
test_system_name: str = "Rucklidge_PanXuZhou"

In [None]:
# Use the first saved parameter sample of the system under test
test_system_samples = reloaded_params_dicts[test_system_name]
reloaded_test_params = test_system_samples[0]

In [None]:
# Display the reloaded parameter entry as the cell output for a visual sanity check
reloaded_test_params

In [None]:
# Skew-product systems are named "<Driver>_<Response>"; only those are handled here.
# NOTE: `sys` shadows the stdlib module name, but later cells refer to it as-is.
is_skew = "_" in test_system_name
if is_skew:
    driver_name, response_name = test_system_name.split("_")
    sys = init_skew_system_from_params(driver_name, response_name, reloaded_test_params)
else:
    # Fail here instead of with a confusing NameError on `sys` in a later cell.
    raise NotImplementedError(
        f"{test_system_name} is not a skew-product system; only skew systems are supported here"
    )

In [None]:
# Integrate from the initial condition stored with the saved parameters
sys.ic = np.array(reloaded_test_params["ic"])
print(sys.ic)

if not sys.has_jacobian():
    print(f"Jacobian not implemented for {test_system_name}")

num_timesteps = 4311
num_periods = 40
integration_kwargs = dict(
    pts_per_period=num_timesteps // num_periods,
    return_times=True,
    atol=1e-10,
    rtol=1e-8,
)
ts, traj = sys.make_trajectory(num_timesteps, **integration_kwargs)

In [None]:
# Drop the first 5% as transient and add a leading batch axis
# (assuming traj is (timesteps, dims), this gives (1, timesteps - transient, dims))
transient_frac = 0.05
transient_length = int(transient_frac * num_timesteps)
trajectory = traj[None, transient_length:, :]
print(trajectory.shape)

# Move to (batch, dim, time), then split the state into the driver block
# (first sys.driver_dim coordinates) and the response block (the rest)
trajectory_to_plot = trajectory.transpose(0, 2, 1)
driver_coords = trajectory_to_plot[:, : sys.driver_dim, :]
response_coords = trajectory_to_plot[:, sys.driver_dim :, :]

plot_kwargs = dict(
    save_dir=None,
    standardize=True,
    plot_projections=False,
    show_plot=True,
)
for name, coords in (("driver", driver_coords), ("response", response_coords)):
    plot_trajs_multivariate(
        coords,
        plot_name=f"reconstructed_{test_system_name}_{name}",
        **plot_kwargs,
    )

### Remake trajectory with new IC

In [None]:
# Shape of the full integrated trajectory, before the transient is dropped
traj.shape

In [None]:
# Report the IC currently set on the system, before it is replaced below
old_ic_message = f"Old initial condition: \n {sys.ic}"
print(old_ic_message)

In [None]:
# Pick ONE post-transient time index and take the whole state vector at that time
# as the new IC. This keeps the IC on a state the system actually visited.
# Drawing an independent time index per coordinate would splice together
# coordinates from different times and can land off the attractor.
n_timesteps, n_dims = traj.shape
new_ic_idx = rng.integers(transient_length, n_timesteps)
# Copy so the new IC does not alias the trajectory array
new_ic = np.array(traj[new_ic_idx])
print(f"New initial condition: \n {new_ic}")

In [None]:
# Re-integrate the same system, this time starting from the freshly sampled IC
sys.ic = new_ic

if not sys.has_jacobian():
    print(f"Jacobian not implemented for {test_system_name}")

num_timesteps, num_periods = 4311, 40
ts, traj = sys.make_trajectory(
    num_timesteps,
    return_times=True,
    pts_per_period=num_timesteps // num_periods,
    rtol=1e-8,
    atol=1e-10,
)

In [None]:
# Drop the transient, add a batch axis, move to (batch, dim, time), then split
# the state into driver (first sys.driver_dim coordinates) and response (rest)
trajectory = traj[None, transient_length:, :]
print(trajectory.shape)
trajectory_to_plot = np.swapaxes(trajectory, 1, 2)
driver_coords = trajectory_to_plot[:, : sys.driver_dim, :]
response_coords = trajectory_to_plot[:, sys.driver_dim :, :]

coords_by_subsystem = {"driver": driver_coords, "response": response_coords}
for name, coords in coords_by_subsystem.items():
    plot_trajs_multivariate(
        coords,
        save_dir=None,
        plot_name=f"reconstructed_{test_system_name}_{name}",
        standardize=True,
        plot_projections=False,
        show_plot=True,
    )

### QA: check that the saved Arrow files match the sample indices in the saved parameters

In [None]:
def check_matching_sample_idx(
    saved_params_dict: list[dict],
    system_arrow_dirs_lst: list[str],
    verbose: bool = False,
    arrow_suffix: str = "_T-4096.arrow",
) -> bool:
    """
    Check that one system's saved parameter entries and its saved Arrow files
    cover exactly the same sample indices.

    Args:
        saved_params_dict: saved parameter entries for one system. Despite the
            name, this is iterated as a sequence of dicts; each entry needs a
            ``"sample_idx"`` key.
        system_arrow_dirs_lst: directories that may contain this system's
            ``<sample_idx><arrow_suffix>`` files. Missing directories count as
            empty.
        verbose: if True, print both sorted index lists and the matched file
            paths.
        arrow_suffix: filename suffix that follows the sample index.

    Returns:
        True iff the sorted sample indices from the params (duplicates
        included) equal the sorted indices parsed from the Arrow filenames.
    """
    params_sample_idx_lst = sorted(entry["sample_idx"] for entry in saved_params_dict)

    arrow_sample_idx_paths_lst: list[str] = []
    for system_arrow_dir in system_arrow_dirs_lst:
        # Only match files with the expected suffix, so stray .arrow files
        # (e.g. a different trajectory length) cannot break the int() parse below.
        # Escape the directory so glob metacharacters in it are taken literally.
        pattern = os.path.join(
            glob.escape(system_arrow_dir), f"*{glob.escape(arrow_suffix)}"
        )
        arrow_sample_idx_paths_lst.extend(glob.glob(pattern))

    arrow_sample_idx_lst = sorted(
        int(os.path.basename(arrow_file_path).removesuffix(arrow_suffix))
        for arrow_file_path in arrow_sample_idx_paths_lst
    )

    if verbose:
        print(arrow_sample_idx_lst)
        print(params_sample_idx_lst)
        print(arrow_sample_idx_paths_lst)
    return arrow_sample_idx_lst == params_sample_idx_lst

In [None]:
# Split directories (relative to DATA_DIR) that may contain saved Arrow files
split_names_lst = ["copy/final_skew40/train", "copy/final_skew40/train_z5_z10"]
saved_arrow_dirs_lst = [os.path.join(DATA_DIR, split) for split in split_names_lst]
print(saved_arrow_dirs_lst)

# Per-split subdirectory belonging to the system under test
system_arrow_dirs_lst = [
    os.path.join(split_dir, test_system_name) for split_dir in saved_arrow_dirs_lst
]
print(system_arrow_dirs_lst)

In [None]:
# Verbose QA check for the test system only; the cell output is the True/False result
check_matching_sample_idx(
    reloaded_params_dicts[test_system_name], system_arrow_dirs_lst, verbose=True
)

In [None]:
# Run the sample-index QA over every system and report ALL mismatches.
# Stopping at the first one would hide how widespread a problem is.
mismatched_systems = []
for system_name, param_dict in tqdm(
    reloaded_params_dicts.items(), desc="Checking sample idx of all systems..."
):
    curr_system_arrow_dirs_lst = [
        os.path.join(saved_arrow_dir, system_name)
        for saved_arrow_dir in saved_arrow_dirs_lst
    ]
    res = check_matching_sample_idx(param_dict, curr_system_arrow_dirs_lst)
    if not res:
        print(f"Mismatch found for {system_name}")
        mismatched_systems.append(system_name)

print(
    f"{len(mismatched_systems)} / {len(reloaded_params_dicts)} systems "
    "have mismatched sample idx"
)

### Make trajectory from saved params with a new IC sampled from the saved Arrow file (NOTE: this method can't be used as-is, because only the response coordinates are saved to the Arrow file)

In [None]:
test_sample_idx = reloaded_test_params["sample_idx"]
print(test_sample_idx)

# Search each candidate split directory for this sample's Arrow file; the first hit wins
arrow_filename = f"{test_sample_idx}_T-4096.arrow"
saved_traj = None
for system_arrow_dir in system_arrow_dirs_lst:
    cand_path = os.path.join(system_arrow_dir, arrow_filename)
    print(cand_path)
    if not os.path.exists(cand_path):
        continue
    saved_traj, _ = load_trajectory_from_arrow(cand_path)
    break

if saved_traj is None:
    raise ValueError("No saved traj found")
print(saved_traj.shape)

In [None]:
# Sample a new initial condition from the saved trajectory.
# NOTE: the Arrow file holds only the response coordinates, so this IC covers
# the response subsystem only. saved_traj is indexed as (dim, time) below.
n_timepoints = saved_traj.shape[1]  # type: ignore
n_dims_response = saved_traj.shape[0]  # type: ignore
# Draw ONE time index and take the whole state at that time, so the IC is a
# point the response actually visited. Independent per-coordinate indices would
# splice together coordinates from different times.
random_idx = rng.integers(0, n_timepoints)  # type: ignore
new_ic = np.array(saved_traj[:, random_idx])  # type: ignore

# The full saved IC is (driver coords, response coords), so the response part is
# the LAST n_dims_response entries. Slicing from index n_dims_response onward is
# only correct when the driver and response have the same dimension.
full_saved_ic = reloaded_test_params["ic"]
old_ic = full_saved_ic[len(full_saved_ic) - n_dims_response :]
print(f"Old initial condition: {old_ic}")
print(f"Sampled new initial condition from index {random_idx} of the trajectory")
print(f"New initial condition shape: {new_ic.shape}")
print(f"New initial condition: {new_ic}")