In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from collections import defaultdict
from pathlib import Path

import dysts.flows as flows
import numpy as np
from tqdm import tqdm

from panda.utils.data_utils import (
    get_system_filepaths,
    load_trajectory_from_arrow,
)
from panda.utils.plot_utils import plot_trajs_multivariate

### Utils and Data Path Setup

In [None]:
# The data root lives under $WORK; when the variable is unset WORK_DIR is "" and
# DATA_DIR degrades to the relative path "data".
WORK_DIR = os.getenv("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Dataset family (relative to DATA_DIR) from which the split and parameter paths are derived.
main_split_name = "improved/final_skew40"

In [None]:
# split_names_lst = [f"{main_split_name}/train", f"{main_split_name}/train_z5_z10"]
# Splits to scan: the two zero-shot test splits (the commented line above selects the train splits).
split_names_lst = [f"{main_split_name}/{suffix}" for suffix in ("test_zeroshot", "test_zeroshot_z5_z10")]

In [None]:
# Resolve each split name to its directory under DATA_DIR.
split_paths_lst = []
for split_name in split_names_lst:
    split_paths_lst.append(os.path.join(DATA_DIR, split_name))
print(f"split paths: {split_paths_lst}")

In [None]:
# Pool the names of all system subdirectories over every split.
# A system present in several splits appears once per split in this list.
subdirs = []
for split_path in split_paths_lst:
    subdirs.extend(d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d)))
# The count is pooled, so report it against all splits instead of naming only the
# last `split_path` (which was also undefined when the split list was empty).
print(f"Found {len(subdirs)} subdirectories across {len(split_paths_lst)} split paths")

In [None]:
# De-duplicate system names. Sorting makes the order reproducible between runs
# (plain set iteration over strings varies with hash randomisation), which in turn
# fixes the key order of every dict built from this list.
unique_subdirs = sorted(set(subdirs))
print(f"Found {len(unique_subdirs)} unique subdirectories across {len(split_paths_lst)} split paths")

In [None]:
# Map each system name to the sorted sample indices found for it across all splits.
# Each filename is parsed as "<sample_idx>_T-4096.arrow"; any other name makes int() raise.
subdir_samples_dict = defaultdict(list)
for subdir in unique_subdirs:
    for split_path in split_paths_lst:
        system_dir = os.path.join(split_path, subdir)
        # A direct isdir() check avoids re-listing the whole split for every system
        # and never tries to list a plain file that happens to share the name.
        if not os.path.isdir(system_dir):
            continue
        subdir_samples_dict[subdir].extend(
            int(filename.split("_T-4096.arrow")[0]) for filename in os.listdir(system_dir)
        )
    # Sort once per system after all splits have contributed; the membership test
    # keeps the defaultdict from creating an entry for a system that was never found.
    if subdir in subdir_samples_dict:
        subdir_samples_dict[subdir].sort()

In [None]:
# Number of distinct systems found on disk.
print(len(subdir_samples_dict))

### Make Filtered Params Dict

In [None]:
# Both successes files share one layout under DATA_DIR and differ only in the train/test component.
_successes_relpath = f"{main_split_name}/parameters/{{split}}/successes.json"
parameters_json_path_train = os.path.join(DATA_DIR, _successes_relpath.format(split="train"))
parameters_json_path_test = os.path.join(DATA_DIR, _successes_relpath.format(split="test"))

In [None]:
# Load the train successes file; the context manager closes the handle that a bare open() left dangling.
with open(parameters_json_path_train) as params_file:
    saved_params_dict_train = json.load(params_file)

In [None]:
# Load the test successes file; the context manager closes the handle that a bare open() left dangling.
with open(parameters_json_path_test) as params_file:
    saved_params_dict_test = json.load(params_file)

In [None]:
# Summarise how many systems each successes file covers.
num_train_systems = len(saved_params_dict_train)
num_test_systems = len(saved_params_dict_test)
print(f"Found {num_train_systems} systems with successful param perts in train")
print(f"Found {num_test_systems} systems with successful param perts in test")
print(f"... for a total of {num_train_systems + num_test_systems} systems with successful param perts")

In [None]:
# Peek at the first ten system names in the train successes file.
list(saved_params_dict_train)[:10]

In [None]:
# Keep, for every system found on disk, only the parameter records whose sample index
# also exists as a trajectory file.
filtered_subdir_samples_dict = {}  # system name -> sorted sample indices that were kept
filtered_params_dict = {}  # system name -> parameter records for those samples
total_systems = 0  # NOTE: accumulates the number of kept samples, not distinct systems

zs_counter = 0  # systems on disk that appear in neither successes file
# Iterating the items view directly lets tqdm infer the total for its progress bar.
for system_name, samples_lst in tqdm(subdir_samples_dict.items(), desc="Checking all subdirs..."):
    # A system is expected to appear in exactly one of the two files;
    # if it is in both, the train entry takes precedence.
    if system_name in saved_params_dict_train:
        system_param_dict = saved_params_dict_train[system_name]
    elif system_name in saved_params_dict_test:
        system_param_dict = saved_params_dict_test[system_name]
    else:
        zs_counter += 1
        continue

    # Sample indices present both on disk and in the saved parameters.
    kept_sample_idxs = set(samples_lst) & {d["sample_idx"] for d in system_param_dict}
    if not kept_sample_idxs:
        continue

    # Sorted so the stored list does not depend on set iteration order.
    filtered_samples_lst = sorted(kept_sample_idxs)

    total_systems += len(filtered_samples_lst)
    # Select the records for the kept indices; membership is tested against the set, not the list.
    filtered_params_dict[system_name] = [
        param_dict for param_dict in system_param_dict if param_dict["sample_idx"] in kept_sample_idxs
    ]
    filtered_subdir_samples_dict[system_name] = filtered_samples_lst

print("zs_counter: ", zs_counter)

In [None]:
# Number of systems that kept at least one sample after filtering.
len(filtered_subdir_samples_dict)

In [None]:
# Despite the name, this is the total number of kept samples summed over all systems.
total_systems

### Save Filtered Params Dict

In [None]:
# Directory (relative to the current working directory) that receives the filtered-params JSON.
output_dir = "../../outputs"
# exist_ok=True replaces the exists()/makedirs() pair and removes its check-then-create race.
os.makedirs(output_dir, exist_ok=True)

# Destination file for the filtered parameters dictionary.
# output_path = os.path.join(output_dir, "filtered_params_dict.json")
output_path = os.path.join(output_dir, "filtered_params_dict_test_zeroshot.json")

In [None]:
# Display the destination path before anything is written to it.
output_path

NOTE: the cell below writes the filtered params dict to `output_path`, overwriting any existing file; comment it out if you do not want to save.

In [None]:
# Convert numpy arrays to lists for JSON serialization
def convert_numpy_to_list(obj):
    """Recursively replace NumPy values with plain Python ones so ``json.dump`` accepts them.

    Args:
        obj: Arbitrary value; ``dict``, ``list`` and ``tuple`` containers are walked recursively.

    Returns:
        ``ndarray`` -> nested list (``tolist()``); NumPy scalar -> Python scalar (``item()``);
        ``dict`` -> new dict with converted values; ``list``/``tuple`` -> new list with
        converted items; anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # NumPy scalars such as np.int64 are not JSON-encodable on their own.
        return obj.item()
    if isinstance(obj, dict):
        return {k: convert_numpy_to_list(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # json writes tuples as arrays anyway, so returning a list loses nothing.
        return [convert_numpy_to_list(item) for item in obj]
    return obj


# Build a JSON-safe copy: every per-sample parameter dict goes through convert_numpy_to_list.
serializable_params_dict = {
    system_name: [convert_numpy_to_list(param_dict) for param_dict in param_dicts]
    for system_name, param_dicts in filtered_params_dict.items()
}

# Write the copy to disk with a two-space indent.
with open(output_path, "w") as out_file:
    json.dump(serializable_params_dict, out_file, indent=2)

print(f"Saved filtered parameters to {output_path}")

### Test Param Dict

In [None]:
# Read the saved JSON back; the context manager closes the handle that a bare open() left dangling.
with open(output_path) as saved_file:
    reloaded_params_dicts = json.load(saved_file)

In [None]:
# Number of systems in the reloaded file.
len(reloaded_params_dicts)

In [None]:
# Total number of parameter records in the reloaded file, summed over all systems.
tot_systems_reloaded = sum(len(param_dicts) for param_dicts in reloaded_params_dicts.values())
print(f"tot_systems_reloaded: {tot_systems_reloaded}")

In [None]:
# System used for the round-trip and trajectory checks in the cells that follow.
# test_system_name = "LorenzBounded_YuWang2"
test_system_name = "HastingsPowell_LuChenCheng"

In [None]:
# First parameter record of the chosen system, as read back from the JSON file.
reloaded_test_params = reloaded_params_dicts[test_system_name][0]

In [None]:
# Display the reloaded record for visual inspection.
reloaded_test_params

In [None]:
# In-memory (pre-serialisation) parameter records for the same system.
test_param_dicts = filtered_params_dict[test_system_name]

In [None]:
# First in-memory record, the counterpart of `reloaded_test_params`.
test_params = test_param_dicts[0]

In [None]:
# Display the in-memory record for comparison with the reloaded one.
test_params

In [None]:
# Round-trip check: every field of the in-memory record must match the reloaded one.
# "coupling_map" is not compared; "ic" is compared with np.allclose, everything else with ==.
for key_name, expected_value in test_params.items():
    if key_name == "coupling_map":
        continue
    reloaded_value = reloaded_test_params[key_name]
    if key_name == "ic":
        assert np.allclose(reloaded_value, expected_value)
    else:
        assert reloaded_value == expected_value
    # Reaching this line means the field passed its check.
    print(key_name)

### Check Trajectories. NOTE: More testing done in `test_reloaded_params.ipynb`

In [None]:
from panda.utils import init_skew_system_from_params

In [None]:
# Skew system names have the form "<driver>_<response>"; rebuild the system from the reloaded params.
is_skew = "_" in test_system_name
if not is_skew:
    # Fail here with a clear message: otherwise `sys` stays undefined (or stale from an
    # earlier run) and the following cells break in a confusing way.
    raise ValueError(f"Expected a skew system name of the form '<driver>_<response>', got {test_system_name!r}")
driver_name, response_name = test_system_name.split("_")
sys = init_skew_system_from_params(driver_name, response_name, reloaded_test_params)

In [None]:
# NOTE: `sys` is the reconstructed system object created above, not the standard-library module.
sys.name

In [None]:
# Integration settings.
num_timesteps = 4311
num_periods = 40
pts_per_period = num_timesteps // num_periods

# Start from the initial condition stored with the reloaded parameters.
sys.ic = np.array(reloaded_test_params["ic"])
print(sys.ic)

if not sys.has_jacobian():
    print(f"Jacobian not implemented for {test_system_name}")

# Integrate and keep both the time stamps and the states.
ts, traj = sys.make_trajectory(
    num_timesteps,
    pts_per_period=pts_per_period,
    return_times=True,
    atol=1e-10,
    rtol=1e-8,
)

In [None]:
# Discard the first 5% of the time steps as transient and add a leading axis of length one.
transient_frac = 0.05
transient_length = int(transient_frac * num_timesteps)
trajectory = np.expand_dims(traj[transient_length:, :], axis=0)
print(trajectory.shape)

# Swap the last two axes, then split axis 1 into driver and response coordinates.
trajectory_to_plot = np.swapaxes(trajectory, 1, 2)
split_at = sys.driver_dim
driver_coords = trajectory_to_plot[:, :split_at, :]
response_coords = trajectory_to_plot[:, split_at:, :]

for name, coords in (("driver", driver_coords), ("response", response_coords)):
    plot_trajs_multivariate(
        coords,
        save_dir="tests/figs",
        plot_name=f"reconstructed_{test_system_name}_{name}",
        standardize=True,
        plot_projections=False,
        show_plot=True,
    )

In [None]:
sample_idx = 0

# Look up the system's files in the first split and keep at most the one at position `sample_idx`.
system_filepaths = get_system_filepaths(test_system_name, DATA_DIR, split_names_lst[0])
filepaths = system_filepaths[sample_idx : sample_idx + 1]
print(f"{test_system_name} filepaths: ", filepaths)

In [None]:
def accumulate_coords(
    filepaths: list[Path], one_dim_target: bool = False, num_samples: int | None = None
) -> np.ndarray:
    dyst_coords_samples = []
    for filepath in filepaths:
        if num_samples is not None and len(dyst_coords_samples) >= num_samples:
            break
        dyst_coords, _ = load_trajectory_from_arrow(filepath, one_dim_target)
        dyst_coords_samples.append(dyst_coords)

    dyst_coords_samples = np.array(dyst_coords_samples)  # type: ignore
    return dyst_coords_samples

In [None]:
# Load the selected file(s); axis 1 of the result is taken as the number of coordinates
# (presumably the layout is (samples, dims, timesteps) — TODO confirm against load_trajectory_from_arrow).
dyst_coords_samples = accumulate_coords(filepaths, one_dim_target=False)
coords_dim = dyst_coords_samples.shape[1]
print(f"{test_system_name} coords_dim: ", coords_dim)

In [None]:
# plot the trajectories
plot_name = f"{test_system_name}"

# Decide what to draw: driver and response separately for a skew system, otherwise the full state.
is_skew = "_" in test_system_name
if is_skew and coords_dim >= 6:  # hacky check
    driver_name, _ = test_system_name.split("_")
    # The driver's dimension marks where axis 1 is split between driver and response.
    driver_dim = getattr(flows, driver_name)().dimension
    driver_coords = dyst_coords_samples[:, :driver_dim, :]
    response_coords = dyst_coords_samples[:, driver_dim:, :]
    plot_jobs = [
        (f"{plot_name}_driver", driver_coords),
        (f"{plot_name}_response", response_coords),
    ]
else:
    plot_jobs = [(plot_name, dyst_coords_samples)]

# Every figure uses the same plotting options.
for job_plot_name, job_coords in plot_jobs:
    plot_trajs_multivariate(
        job_coords,
        save_dir="tests/figs",
        plot_name=job_plot_name,
        samples_subset=None,
        standardize=True,
        plot_projections=False,
        show_plot=True,
    )