In [None]:
# Automatically re-import edited modules (e.g. dystformer) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path
from typing import Any

import dysts.flows as flows
import numpy as np

from dystformer.utils import make_ensemble_from_arrow_dir

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import TABLEAU_COLORS

from dystformer.utils import apply_custom_style

# Tableau palette colors as a plain list.
COLORS = [*TABLEAU_COLORS.values()]

# Apply matplotlib style from config
apply_custom_style("../config/plotting.yaml")

In [None]:
# Root of the data tree: $WORK/data, or a relative "data" dir when WORK is unset.
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
np.arange(10)  # scratch check: integers 0..9 (stop is exclusive)

In [None]:
def plot_grid_trajs_multivariate(
    ensemble: dict[str, np.ndarray],
    save_path: str | None = None,
    dims_3d: list[int] = [0, 1, 2],
    sample_indices: list[int] | np.ndarray | None = None,
    n_rows_cols: tuple[int, int] | None = None,
    subplot_size: tuple[int, int] = (3, 3),
    row_col_padding: tuple[float, float] = (0.0, 0.0),
    plot_kwargs: dict[str, Any] = {},
    title_kwargs: dict[str, Any] = {},
    custom_colors: list[str] = [],
    show_titles: bool = True,
    show_axes: bool = False,
    plot_projections: bool = False,
) -> None:
    n_systems = len(ensemble)
    if n_rows_cols is None:
        n_rows = int(np.ceil(np.sqrt(n_systems)))
        n_cols = int(np.ceil(n_systems / n_rows))
    else:
        n_rows, n_cols = n_rows_cols

    row_padding, column_padding = row_col_padding
    # Reduce spacing by using smaller padding multipliers
    figsize = (
        n_cols * subplot_size[0] * (1 + column_padding),
        n_rows * subplot_size[1] * (1 + row_padding),
    )
    fig = plt.figure(figsize=figsize)
    plt.subplots_adjust(wspace=column_padding, hspace=row_padding)

    if sample_indices is None:
        sample_indices = np.zeros(len(ensemble), dtype=int)
    # Keep track of the last used color index to avoid consecutive same colors
    last_color_idx = -1

    for i, (system_name, trajectories) in enumerate(ensemble.items()):
        assert trajectories.shape[1] >= len(dims_3d), (
            f"Data has {trajectories.shape[1]} dimensions, but {len(dims_3d)} dimensions were requested for plotting"
        )

        ax = fig.add_subplot(n_rows, n_cols, i + 1, projection="3d")

        sample_idx = sample_indices[i]
        xyz = trajectories[sample_idx, dims_3d, :]

        # Select a color that's different from the last one used
        if len(custom_colors) > 0:
            if len(custom_colors) > 1:
                # Get a new color index that's different from the last one
                available_indices = [
                    j for j in range(len(custom_colors)) if j != last_color_idx
                ]
                color_idx = np.random.choice(available_indices)
                last_color_idx = color_idx
            else:
                # If only one color is available, use it
                color_idx = 0
        else:
            color_idx = 0
        ax.plot(
            *xyz,
            **plot_kwargs,
            color=custom_colors[color_idx] if len(custom_colors) > 0 else None,
            zorder=10,
        )

        if show_titles:
            system_name_title = system_name.replace("_", " + ")
            ax.set_title(f"{system_name_title}", **title_kwargs)
        fig.patch.set_facecolor("white")  # Set the figure's face color to white
        ax.set_facecolor("white")  # Set the axes' face color to white
        # Hide tick marks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])  # type: ignore

        if not show_axes:
            ax.set_axis_off()
        ax.grid(False)

        if plot_projections:
            x_min, x_max = ax.get_xlim3d()  # type: ignore
            y_min, y_max = ax.get_ylim3d()  # type: ignore
            z_min, z_max = ax.get_zlim3d()  # type: ignore
            palpha = 0.1  # 0.15

            proj_color = "black"
            proj_linewidth = 0.3

            # XY plane projection (bottom)
            ax.plot(
                xyz[0],
                xyz[1],
                z_min,
                alpha=palpha,
                linewidth=proj_linewidth,
                color=proj_color,
            )

            # XZ plane projection (back)
            ax.plot(
                xyz[0],
                y_max,
                xyz[2],
                alpha=palpha,
                linewidth=proj_linewidth,
                color=proj_color,
            )

            # YZ plane projection (right)
            ax.plot(
                x_min,
                xyz[1],
                xyz[2],
                alpha=palpha,
                linewidth=proj_linewidth,
                color=proj_color,
            )

    plt.tight_layout()
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path)
    plt.show()
    plt.close()

In [None]:
def make_response_ensemble(ensemble: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """
    Make an ensemble of just the response coordinates for each system in ensemble.
    Use case: when saving the concatenated driver and response coordinates of skew system.
    """
    driver_dims = {
        sys: getattr(flows, sys.split("_")[0])().dimension for sys in ensemble.keys()
    }
    print("got driver dims")
    response_ensemble = {
        sys: ensemble[sys][:, driver_dims[sys] :, :] for sys in ensemble.keys()
    }
    return response_ensemble

In [None]:
# Which data split to visualize, and the seed for any random system sampling.
rseed = 396
split_name = "improved/final_skew40/train"
split_dir = os.path.join(DATA_DIR, split_name)
rng = np.random.default_rng(rseed)

In [None]:
# Hand-picked skew systems ("<Driver>_<Response>") to feature in the figure.
selected_dyst_names = [
    "Lorenz96_Aizawa",
    "Bouali2_StickSlipOscillator",
    "ThomasLabyrinth_TurchinHanski",
    "HyperXu_LorenzCoupled",
    "BlinkingRotlet_YuWang2",
    "SprottP_NewtonLiepnik",
    "Laser_IsothermalChemical",
    "LorenzCoupled_HyperPang",
    "SprottL_HindmarshRose",
    "GuckenheimerHolmes_ZhouChen",
    "DoubleGyre_SanUmSrisuchinwong",
    "SwingingAtwood_Finance",
    "DequanLi_HenonHeiles",
    "CellularNeuralNetwork_SprottC",
    "SprottN_SprottC",
    "CellCycle_HyperWang",
    "Thomas_WindmiReduced",
    "CoevolvingPredatorPrey_HyperPang",
    "SprottM_Torus",
    "PehlivanWei_NuclearQuadrupole",
    "RayleighBenard_IsothermalChemical",
    "RikitakeDynamo_Aizawa",
    "ItikBanksTumor_Lorenz",
    "Colpitts_ForcedVanDerPol",
    "ForcedBrusselator_Lorenz96",
    "Lorenz96_Chua",
    "Halvorsen_SprottM",
    "Bouali2_CellularNeuralNetwork",
    "InteriorSquirmer_Halvorsen",
    "ForcedBrusselator_RabinovichFabrikant",
]
n_selected_systems = len(selected_dyst_names)
print("number of selected systems: ", n_selected_systems)

In [None]:
# Figure size knobs: systems in the grid, and samples per system (not referenced
# elsewhere in this notebook as shown).
n_systems_plot, n_samples_plot = 30, 1

In [None]:
if n_selected_systems < n_systems_plot:
    # Sort the candidates: Path.iterdir() order depends on the filesystem, and
    # rng.choice picks by position, so an unsorted list would defeat the seed.
    all_dyst_names_lst = sorted(
        folder.name for folder in Path(split_dir).iterdir() if folder.is_dir()
    )
    # Filter out selected systems from all_dyst_names_lst
    selected_set = set(selected_dyst_names)
    filtered_dyst_names = [
        name for name in all_dyst_names_lst if name not in selected_set
    ]
    n_to_sample = n_systems_plot - n_selected_systems
    if len(filtered_dyst_names) < n_to_sample:
        raise ValueError(
            f"Need {n_to_sample} extra systems but only "
            f"{len(filtered_dyst_names)} unselected systems exist in {split_dir}"
        )
    # Randomly sample the required number of systems
    dyst_names_lst = rng.choice(
        filtered_dyst_names, size=n_to_sample, replace=False
    ).tolist()
    dyst_names_lst.extend(selected_dyst_names)
else:
    # Copy so later changes to dyst_names_lst cannot alter selected_dyst_names.
    dyst_names_lst = list(selected_dyst_names)

print(f"dyst names: {dyst_names_lst}")
print("number of systems: ", len(dyst_names_lst))

In [None]:
# Notes: skew-system names of interest (reference only; this cell runs no code).
# ForcedFitzHughNagumo_VallisElNino
# SprottMore_CellularNeuralNetwork
# GuckenheimerHolmes_ZhouChen
# Chen_Colpitts
# SprottN_SprottC

In [None]:
# e.g. "improved/final_skew40/train" -> "improved_final_skew40_train"
plot_name_suffix = split_name.replace("/", "_")
plot_save_dir = "dataset_figs"

In [None]:
# Guard against requesting the same system twice before loading any data.
assert len(set(dyst_names_lst)) == len(dyst_names_lst), "duplicate system names in dyst_names_lst"
len(dyst_names_lst)

In [None]:
# Load the requested systems' trajectories (presumably from the arrow files
# under DATA_DIR/split_name).
ensemble = make_ensemble_from_arrow_dir(
    DATA_DIR,
    split_name,
    dyst_names_lst=dyst_names_lst,
)

In [None]:
n_systems = len(ensemble)
print(f"n_systems: {n_systems}")

# Non-fatal check that every requested system was loaded.
# NOTE(review): assumes ensemble keys are the requested system names; confirm
# against make_ensemble_from_arrow_dir.
missing_systems = [name for name in dyst_names_lst if name not in ensemble]
if missing_systems:
    print(
        f"WARNING: {len(missing_systems)} requested systems were not loaded: {missing_systems}"
    )

default_name = f"{n_systems}_systems"
plot_name = f"{default_name}_{plot_name_suffix}" if plot_name_suffix else default_name
save_path = os.path.join(plot_save_dir, f"{plot_name}.pdf")

In [None]:
# Trajectory sample to draw in each subplot (default: sample 0 of every system).
# Overrides are keyed by position in ensemble order, i.e. by subplot slot. Each
# applied override is printed so it is clear which system it hit.
sample_index_overrides = {5: 8}

sample_indices = [0] * n_systems
system_names = list(ensemble)
for position, override_idx in sample_index_overrides.items():
    if position >= n_systems:
        print(f"Skipping sample override for position {position}: only {n_systems} systems loaded")
        continue
    n_available = ensemble[system_names[position]].shape[0]
    if not 0 <= override_idx < n_available:
        print(
            f"Skipping sample override {override_idx} for {system_names[position]}: "
            f"only {n_available} samples available"
        )
        continue
    sample_indices[position] = override_idx
    print(f"Using sample {override_idx} for {system_names[position]} (position {position})")

In [None]:
len(ensemble.keys())  # number of systems actually loaded

In [None]:
# Seven dark, saturated named colors from around the color wheel, used as the
# per-subplot trajectory palette.
# Alternatives previously tried: darkslateblue, seagreen, darkmagenta.
custom_colors = [
    "steelblue",
    "teal",
    "forestgreen",
    "firebrick",
    "darkorange",
    "mediumvioletred",
    "indigo",
]

In [None]:
# Lay the systems out in rows of 5 and derive the row count from the number of
# loaded systems, so the grid always fits (30 systems -> 6 x 5).
n_grid_cols = 5
n_grid_rows = int(np.ceil(n_systems / n_grid_cols))

plot_grid_trajs_multivariate(
    ensemble,
    save_path=save_path,
    sample_indices=sample_indices,
    n_rows_cols=(n_grid_rows, n_grid_cols),
    subplot_size=(4, 4),
    row_col_padding=(0.0, 0.0),
    plot_kwargs={"linewidth": 0.3, "alpha": 0.8},
    title_kwargs={"fontweight": "bold"},
    custom_colors=custom_colors,
    show_titles=False,
    show_axes=True,
    plot_projections=True,
)

In [None]:
def plot_traj(
    x: np.ndarray,
    plot_projections: bool = True,
    save_path: str | None = None,
) -> None:
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection="3d")

    x_3d = x[:3, :]
    ax.plot(*x_3d, linewidth=0.5, zorder=10, color="black", alpha=0.8)  # X,Y,Z

    ax.ticklabel_format(style="sci", scilimits=(0, 0), axis="both")

    if plot_projections:
        x_min, x_max = ax.get_xlim3d()  # type: ignore
        y_min, y_max = ax.get_ylim3d()  # type: ignore
        z_min, z_max = ax.get_zlim3d()  # type: ignore
        palpha = 0.1  # 0.15

        proj_color = None
        proj_linewidth = 0.3

        # XY plane projection (bottom)
        ax.plot(
            x_3d[0],
            x_3d[1],
            z_min,
            alpha=palpha,
            linewidth=proj_linewidth,
            color=proj_color,
        )

        # XZ plane projection (back)
        ax.plot(
            x_3d[0],
            y_max,
            x_3d[2],
            alpha=palpha,
            linewidth=proj_linewidth,
            color=proj_color,
        )

        # YZ plane projection (right)
        ax.plot(
            x_min,
            x_3d[1],
            x_3d[2],
            alpha=palpha,
            linewidth=proj_linewidth,
            color=proj_color,
        )

    # Make clean projection
    ax.grid(False)
    ax.set_facecolor("white")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])  # type: ignore

    plt.tight_layout()
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(
            save_path,
            bbox_inches="tight",
        )
    plt.show()

In [None]:
# Per-system figures go next to the grid figure; reuse the directory constant
# rather than repeating the literal.
save_dir = plot_save_dir

In [None]:
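# Notes: system names of interest (reference only; this cell runs no code).
# All three also appear in selected_dyst_names.
# Bouali2_CellularNeuralNetwork
# InteriorSquirmer_Halvorsen
# ForcedBrusselator_RabinovichFabrikant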

In [None]:
# Set to True to also write one PDF per system into save_dir.
save_individual_figs = False

for dyst_name, trajectories in ensemble.items():
    print(f"Plotting {dyst_name}")
    plot_traj(
        trajectories[0],  # first sample of each system
        plot_projections=True,
        save_path=(
            os.path.join(save_dir, f"{dyst_name}.pdf") if save_individual_figs else None
        ),
    )