In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path
from typing import Any

import dysts.flows as flows
import numpy as np

from dystformer.utils import make_ensemble_from_arrow_dir

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import TABLEAU_COLORS

from dystformer.utils import apply_custom_style, safe_standardize

# Tableau palette flattened to a list of color values (dict order preserved).
COLORS = [*TABLEAU_COLORS.values()]

# Apply the matplotlib style described by the plotting config; the path is
# relative, so it is resolved against the notebook's working directory.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Workspace root is read from $WORK; when unset it is "" and DATA_DIR becomes
# the relative path "data".
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Scratch expression: displays the integers 0..9; the value is not stored.
np.arange(10)

In [None]:
def plot_grid_trajs_multivariate(
    ensemble: dict[str, np.ndarray],
    save_path: str | None = None,
    standardize: bool = False,
    dims_3d: list[int] = [0, 1, 2],
    sample_indices: list[int] | np.ndarray | None = None,
    n_rows_cols: tuple[int, int] | None = None,
    subplot_size: tuple[int, int] = (3, 3),
    row_col_padding: tuple[float, float] = (0.2, 0.2),
    plot_kwargs: dict[str, Any] = {},
    scatter_kwargs: dict[str, Any] = {},
    title_kwargs: dict[str, Any] = {},
    color: str = "black",
    show_start_end: bool = True,
    show_titles: bool = True,
) -> None:
    n_systems = len(ensemble)
    if n_rows_cols is None:
        n_rows = int(np.ceil(np.sqrt(n_systems)))
        n_cols = int(np.ceil(n_systems / n_rows))
    else:
        n_rows, n_cols = n_rows_cols

    row_padding, column_padding = row_col_padding
    figsize = (
        n_cols * subplot_size[0] * (1 + column_padding),
        n_rows * subplot_size[1] * (1 + row_padding),
    )
    fig = plt.figure(figsize=figsize)

    if sample_indices is None:
        sample_indices = np.zeros(len(ensemble), dtype=int)

    for i, (system_name, trajectories) in enumerate(ensemble.items()):
        assert trajectories.shape[1] >= len(dims_3d), (
            f"Data has {trajectories.shape[1]} dimensions, but {len(dims_3d)} dimensions were requested for plotting"
        )

        if standardize:
            trajectories = safe_standardize(trajectories)

        ax = fig.add_subplot(n_rows, n_cols, i + 1, projection="3d")

        sample_idx = sample_indices[i]
        xyz = trajectories[sample_idx, dims_3d, :]
        ax.plot(*xyz, **plot_kwargs, color=color)

        if show_start_end:
            ic_pt = xyz[:, 0]
            ax.scatter(*ic_pt, marker="*", **scatter_kwargs, color=color)

            end_pt = xyz[:, -1]
            ax.scatter(*end_pt, marker="x", **scatter_kwargs, color=color)

        if show_titles:
            system_name_title = system_name.replace("_", " + ")
            ax.set_title(f"{system_name_title}", **title_kwargs)
        fig.patch.set_facecolor("white")  # Set the figure's face color to white
        ax.set_facecolor("white")  # Set the axes' face color to white
        # Hide tick marks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])  # type: ignore
        # Hide axes
        ax.set_axis_off()
        # Ensure 3D grid is off
        ax.grid(False)
        ax.grid(False)

    plt.subplots_adjust(hspace=row_padding, wspace=column_padding)

    plt.tight_layout()
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        print("Plotting grid of 3D trajectories and saving to ", save_path)
        plt.savefig(save_path)
    plt.show()
    plt.close()

In [None]:
def make_response_ensemble(ensemble: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """
    Make an ensemble of just the response coordinates for each system in ensemble.
    Use case: when saving the concatenated driver and response coordinates of skew system.

    Each key is expected to look like "<Driver>_<Response>". The driver class is
    looked up by name in `dysts.flows`, and its `dimension` gives the number of
    leading coordinates (along axis 1) to drop from that system's array.

    Args:
        ensemble: maps system name -> array with at least three axes, where
            axis 1 holds the driver coordinates followed by the response ones.

    Returns:
        A new dict with the same keys (and order) holding only the response
        coordinates of each array.

    Raises:
        AttributeError: if a driver name is not an attribute of `dysts.flows`.
    """
    # Several skew systems can share a driver; instantiate each driver only once.
    driver_dim_cache: dict[str, int] = {}
    response_ensemble: dict[str, np.ndarray] = {}
    for system_name, trajectories in ensemble.items():
        driver_name = system_name.split("_")[0]
        if driver_name not in driver_dim_cache:
            driver_dim_cache[driver_name] = getattr(flows, driver_name)().dimension
        response_ensemble[system_name] = trajectories[
            :, driver_dim_cache[driver_name] :, :
        ]
    return response_ensemble

In [None]:
# Which split to draw from and how much of it to plot.
split_name = "improved/final_skew40/train"
n_systems_plot = 10
n_samples_plot = 1

# Seeded generator so random system selection is repeatable from a fresh kernel.
rseed = 443
rng = np.random.default_rng(seed=rseed)

split_dir = os.path.join(DATA_DIR, split_name)

In [None]:
# Hand-picked skew systems ("<Driver>_<Response>") to include in the figure.
# Commented-out entries are candidates that are currently excluded.
selected_dyst_names = [
    "LorenzCoupled_HyperPang",
    "Bouali2_StickSlipOscillator",
    "ThomasLabyrinth_TurchinHanski",
    "HyperXu_LorenzCoupled",
    "BlinkingRotlet_YuWang2",
    # "Laser_IsothermalChemical",
    "DoubleGyre_SanUmSrisuchinwong",
    "SwingingAtwood_Finance",
    "DequanLi_HenonHeiles",
    # "CellularNeuralNetwork_SprottC",
    "CellCycle_HyperWang",
    # "Thomas_WindmiReduced",
    # "CoevolvingPredatorPrey_HyperPang",
    # "SprottM_Torus",
    # "PehlivanWei_NuclearQuadrupole",
    # "RayleighBenard_IsothermalChemical",
    "RikitakeDynamo_Aizawa",
    # # "ItikBanksTumor_Lorenz",
    # "Colpitts_ForcedVanDerPol",
    # # "ForcedBrusselator_Lorenz96",
    # # "Lorenz96_Chua",
    # "Halvorsen_SprottM",
]
n_selected_systems = len(selected_dyst_names)
print("number of selected systems: ", n_selected_systems)

In [None]:
# Systems to plot: the hand-picked ones, topped up with random draws from the
# split directory when fewer than n_systems_plot were selected.
n_extra_systems = max(0, n_systems_plot - n_selected_systems)
if n_extra_systems > 0:
    # Sorted because directory iteration order is filesystem-dependent; a stable
    # candidate order keeps the seeded draw reproducible.
    all_dyst_names_lst = sorted(
        folder.name for folder in Path(split_dir).iterdir() if folder.is_dir()
    )
    # Filter out selected systems from all_dyst_names_lst
    filtered_dyst_names = [
        name for name in all_dyst_names_lst if name not in selected_dyst_names
    ]
    # Randomly sample the required number of systems
    sampled_dyst_names = rng.choice(
        filtered_dyst_names, size=n_extra_systems, replace=False
    ).tolist()
else:
    sampled_dyst_names = []

# Build a new list instead of aliasing selected_dyst_names: extending the alias
# with itself duplicated every name and mutated the selection on each re-run.
dyst_names_lst = sampled_dyst_names + list(selected_dyst_names)
print(f"dyst names: {dyst_names_lst}")
print("number of systems: ", len(dyst_names_lst))

In [None]:
# Flatten the split path into a filename-friendly suffix ("a/b" -> "a_b").
plot_name_suffix = split_name.replace("/", "_")
# Figures are written to this directory, relative to the working directory.
plot_save_dir = "dataset_figs"

In [None]:
# Load the data for the chosen systems of this split.
# NOTE(review): presumably returns a dict of system name -> trajectory array, as
# consumed by plot_grid_trajs_multivariate — confirm against dystformer.utils.
ensemble = make_ensemble_from_arrow_dir(
    DATA_DIR, split_name, dyst_names_lst=dyst_names_lst
)

In [None]:
n_systems = len(ensemble)
# assert n_systems == n_systems_plot

# File name is "<count>_systems", followed by the split suffix when there is one.
default_name = f"{n_systems}_systems"
if plot_name_suffix:
    plot_name = f"{default_name}_{plot_name_suffix}"
else:
    plot_name = default_name

save_path = os.path.join(plot_save_dir, plot_name + ".pdf")

In [None]:
# Plot the first sample (index 0) of every system.
sample_indices = [0 for _ in range(n_systems)]

In [None]:
# Sanity check: display how many systems were actually loaded.
len(ensemble)

In [None]:
# Five panels per row; the row count follows the number of loaded systems so the
# grid always has room for every one (10 systems -> the original 2 x 5 layout).
n_grid_cols = 5
n_grid_rows = max(1, -(-len(ensemble) // n_grid_cols))  # ceiling division

plot_grid_trajs_multivariate(
    ensemble,
    save_path=save_path,
    standardize=True,
    sample_indices=sample_indices,
    n_rows_cols=(n_grid_rows, n_grid_cols),
    subplot_size=(4, 4),
    row_col_padding=(0.0, 0.00),
    plot_kwargs={"linewidth": 0.3, "alpha": 0.8},
    title_kwargs={"fontweight": "bold"},
    color="black",
    show_start_end=False,
    show_titles=False,
)