In [None]:
# Re-import edited project modules (e.g. panda.utils, panda.attractor) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import Counter, defaultdict

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np

from panda.utils import (
    load_trajectory_from_arrow,
    plot_trajs_multivariate,
)

In [None]:
# apply_custom_style("../config/plotting.yaml")

In [None]:
# Datasets live under $WORK/data; when WORK is unset this resolves to a CWD-relative "data" dir.
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Dataset split (relative to DATA_DIR) and the system sub-directory inside it to inspect.
split_name, system_name = "improved/final_skew40/train", "Thomas_Sakarya"

In [None]:
# Select the last file (in sorted filename order) from this system's directory.
subdir = os.path.join(DATA_DIR, split_name, system_name)
files_lst = sorted(os.listdir(subdir))
if not files_lst:
    # Fail with a clear message instead of an opaque IndexError on files_lst[-1].
    raise FileNotFoundError(f"No trajectory files found in {subdir!r}")
filepath = os.path.join(subdir, files_lst[-1])

In [None]:
# Drop the first `transient_time` samples along axis 1 (time) before any analysis.
transient_time = 512
dyst_coords, _ = load_trajectory_from_arrow(filepath)
if dyst_coords.shape[1] <= transient_time:
    # Slicing would silently yield an empty trajectory and fail much later downstream.
    raise ValueError(
        f"Trajectory in {filepath!r} has only {dyst_coords.shape[1]} time steps; "
        f"cannot drop a transient of {transient_time}"
    )
dyst_coords = dyst_coords[:, transient_time:]

In [None]:
np.shape(dyst_coords)  # (n_dims, n_timesteps) after transient removal

In [None]:
# Add a leading axis so the single trajectory is passed as a batch of one
# (presumably the layout plot_trajs_multivariate expects — see its docs).
plot_trajs_multivariate(
    dyst_coords[np.newaxis, ...],
    show_plot=True,
    plot_name=f"{system_name}",
)

In [None]:
# One time-series panel per state dimension (axis 0 of dyst_coords).
dim = dyst_coords.shape[0]
for dim_idx in range(dim):
    fig, ax = plt.subplots(figsize=(5, 2))
    # Use a plain "-" fmt: "b-" together with color= makes matplotlib warn about a
    # redundantly defined color.
    ax.plot(dyst_coords[dim_idx], "-", color="tab:blue")
    ax.set_title(f"Dimension {dim_idx}")
    ax.set_xlabel("Time step (after transient)")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when there are many dimensions

In [None]:
# Reference trajectories with known character, used to sanity-check the diagnostics below.
# All share the (dims, time) layout of dyst_coords:
#   - test_system_periodic: sin(t), sin(2t), sin(3t) — a regular, periodic orbit;
#   - test_system_fourier: per dimension, a sum of N_MODES sinusoids with random
#     angular frequency and phase, both drawn from [0, 2*pi);
#   - test_system_noise: a Gaussian random walk (cumulative sum of white noise).
SEED = 0
N_TIMESTEPS = 4096
N_MODES = 10
rng = np.random.default_rng(SEED)  # seeded so the notebook is reproducible

t = np.linspace(0, 10 * np.pi, N_TIMESTEPS)

test_system_periodic = np.array(
    [
        np.sin(t),  # x-coordinate
        np.sin(2 * t),  # y-coordinate
        np.sin(3 * t),  # z-coordinate
    ]
)

# Shapes (3, N_MODES, 1) broadcast against t to (3, N_MODES, N_TIMESTEPS); summing the
# mode axis gives one superposition per dimension.
fourier_freqs = rng.uniform(0, 2 * np.pi, size=(3, N_MODES, 1))
fourier_phases = rng.uniform(0, 2 * np.pi, size=(3, N_MODES, 1))
test_system_fourier = np.sin(fourier_freqs * t + fourier_phases).sum(axis=1)

# Same length as the other reference systems.
test_system_noise = rng.standard_normal((3, N_TIMESTEPS)).cumsum(axis=1)

### Power Spectrum

In [None]:
from scipy.fft import rfft

from panda.attractor import check_power_spectrum

In [None]:
# Run panda.attractor's power-spectrum check on the post-transient trajectory; the
# displayed result is whatever check_power_spectrum returns (see its docs for meaning).
check_power_spectrum(dyst_coords)

In [None]:
def plot_power_spectrum(traj: np.ndarray):
    """Plot the per-dimension power spectrum |rfft(x)|^2 on a log scale.

    Args:
        traj: trajectory with dimensions along axis 0 and time along axis 1.
            A 1-D array is treated as a single dimension.

    The x-axis is the rfft bin index (0 .. n_timesteps // 2), not a physical
    frequency, since no sampling interval is available here.
    """
    traj = np.atleast_2d(traj)
    power = np.abs(rfft(traj, axis=1)) ** 2  # type: ignore
    d, n_freqs = power.shape
    # squeeze=False keeps `axes` 2-D even when d == 1 (plt.subplots(1, 1) would
    # otherwise return a bare Axes that cannot be indexed).
    fig, axes = plt.subplots(d, 1, figsize=(10, 2 * d), sharex=True, squeeze=False)
    freq_bins = np.arange(n_freqs)

    for i, ax in enumerate(axes[:, 0]):
        # Plain "-" fmt: combining "b-" with color= triggers a redundant-color warning.
        ax.plot(freq_bins, power[i], "-", color="tab:blue")
        ax.set_yscale("log")
        ax.set_ylabel(f"Dim {i + 1}")
        ax.grid(True)

    # Shared x-axis, so only the bottom panel gets a label.
    axes[-1, 0].set_xlabel("Frequency bin")
    fig.suptitle("Power Spectrum")

    # tight_layout ignores the suptitle; reserve room for it afterwards.
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    plt.show()


plot_power_spectrum(dyst_coords)

In [None]:
# plot_power_spectrum(test_system_fourier)

### Grassberger-Procaccia Dimension

In [None]:
from dysts.analysis import gp_dim

In [None]:
np.shape(dyst_coords)  # (n_dims, n_timesteps); gp_dim below receives the transpose

In [None]:
# gp_dim is fed time-major data (n_timesteps, n_dims), i.e. the transpose of dyst_coords.
gp = gp_dim(np.transpose(dyst_coords))

In [None]:
gp  # Grassberger-Procaccia dimension estimate; bare last expression uses rich display

In [None]:
# from panda.utils import compute_gp_dimension
# compute_gp_dimension(test_system_periodic.T)
# compute_gp_dimension(test_system_noise.T)
# compute_gp_dimension(test_system_fourier.T)

### Additional Checks

#### Limit Cycle Test

In [None]:
from scipy.spatial.distance import cdist

In [None]:
# NOTE: this duplicates the limit cycle check in attractor.py so that its intermediate quantities (distance matrix, recurrence indices and times) can be returned for plotting.


def check_not_limit_cycle(
    traj: np.ndarray,
    tolerance: float = 1e-3,
    min_prop_recurrences: float = 0.0,
    min_counts_per_rtime: int = 100,
    min_block_length: int = 1,
    min_recurrence_time: int = 1,
    enforce_endpoint_recurrence: bool = False,
    return_computed_quantities: bool = False,
) -> bool | tuple[bool, dict]:
    """
    Recurrence-based limit cycle test (copy of the check in attractor.py, with its
    intermediate quantities exposed for plotting).

    Pairs of time points (i, j) with i < j whose states are closer than ``tolerance``
    (Euclidean, after casting distances to float16, and strictly > 0) count as
    recurrences; ``j - i`` is the recurrence time. The trajectory is reported as a
    limit cycle only if it survives every enabled heuristic below; any early exit
    reports "not a limit cycle".

    Args:
        traj: trajectory with time along axis 1; the columns (rows of ``traj.T``)
            are the states compared pairwise.
        tolerance: distance below which two states count as a recurrence.
        min_prop_recurrences: minimum number of recurrences (after the
            ``min_recurrence_time`` filter) as a fraction of the number of time points.
        min_counts_per_rtime: at least one recurrence time must occur at least
            this many times.
        min_block_length: when > 1, recurrences must also form runs of consecutive
            index pairs sharing one recurrence time (Heuristic 3).
        min_recurrence_time: recurrence times shorter than this are ignored.
        enforce_endpoint_recurrence: if True, require some recurrence pair that
            involves the first or the last time point.
        return_computed_quantities: if True and the result is False, also return
            the intermediate quantities.

    Returns:
        True if the trajectory is not a limit cycle, False otherwise.
        If False and return_computed_quantities is True, returns a tuple
        (False, {"dist_matrix", "recurrence_indices", "recurrence_times"}).
        A True result is always a bare bool, even if return_computed_quantities is True.
    """
    n = traj.shape[1]

    # Step 1: Calculate the pairwise distance matrix, shape should be (N, N)
    # float16 presumably to save memory on the (N, N) matrix; note float16 rounds small
    # distances to 0 and overflows to inf above ~65504.
    dist_matrix = cdist(traj.T, traj.T, metric="euclidean").astype(np.float16)
    # Keep only pairs i < j; the diagonal and lower triangle become 0.
    dist_matrix = np.triu(dist_matrix, k=1)

    # Step 2: Get recurrence times from thresholding distance matrix
    # `> 0` discards the zeroed triangle/diagonal and pairs whose distance rounds to 0.
    # nonzero() returns (row_indices, col_indices) in row-major order.
    recurrence_indices = np.asarray(
        (dist_matrix < tolerance) & (dist_matrix > 0)
    ).nonzero()

    n_recurrences = len(recurrence_indices[0])
    if n_recurrences == 0:
        return True

    if enforce_endpoint_recurrence:
        # check if an eps neighborhood around either n-1 or 0 is in either of the recurrence indices
        # With k=1 above, row indices are <= n-2 and column indices >= 1, so with eps = 0
        # this effectively requires row index 0 or column index n-1 to appear.
        # Uses all recurrences, before the min_recurrence_time filter below.
        eps = 0
        if not any(
            (n - 1) - max(indices) <= eps or min(indices) - 0 <= eps
            for indices in recurrence_indices
        ):
            return True

    # get recurrence times
    recurrence_times = np.abs(recurrence_indices[0] - recurrence_indices[1])
    recurrence_times = recurrence_times[recurrence_times >= min_recurrence_time]

    # Heuristic 1: Check if there are enough recurrences to consider a limit cycle
    n_recurrences = len(recurrence_times)
    if n_recurrences < int(min_prop_recurrences * n):
        return True

    # Heuristic 2: Check if there are enough valid recurrence times
    # A recurrence time is "valid" if it occurs at least min_counts_per_rtime times.
    rtimes_counts = Counter(recurrence_times)
    n_valid_rtimes = sum(
        1 for count in rtimes_counts.values() if count >= min_counts_per_rtime
    )
    if n_valid_rtimes < 1:
        return True

    # Heuristic 3: Check if the valid recurrence times are formed of blocks of consecutive timepoints
    # Passes as soon as one run grows longer than 2 * min_block_length, or two completed
    # runs longer than min_block_length have been seen.
    if min_block_length > 1:
        # NOTE(review): rtimes_dict is filled below but never read.
        rtimes_dict = defaultdict(list)
        block_length = 1
        prev_rtime = None
        prev_t1 = None
        prev_t2 = None
        rtimes_is_valid = False
        num_blocks = 0
        # assuming recurrence_indices[0] is sorted
        # NOTE(review): pairs come in row-major order, so diagonal neighbours
        # (t1, t2) -> (t1 + 1, t2 + 1) are only adjacent here when row t1 holds no other
        # recurrences after t2 — confirm this matches the intended run detection.
        for t1, t2 in zip(*recurrence_indices):
            rtime = abs(t2 - t1)
            if rtime < min_recurrence_time:
                continue
            # prev_rtime is None on the first pair, so the comparison short-circuits
            # before abs() is applied to the None-valued prev_t1 / prev_t2.
            if (
                rtime == prev_rtime
                and abs(t1 - prev_t1) == 1
                and abs(t2 - prev_t2) == 1
            ):
                block_length += 1
            else:
                # The previous run just ended; count it if it was long enough.
                if block_length > min_block_length:
                    rtimes_dict[prev_rtime].append(block_length)
                    num_blocks += 1
                block_length = 1
            prev_t1, prev_t2, prev_rtime = t1, t2, rtime
            if block_length > min_block_length * 2:
                rtimes_is_valid = True
                break
            if num_blocks >= 2:  # if valid, save computation and break
                rtimes_is_valid = True
                break
        # NOTE(review): the final run is not flushed after the loop, so a trailing run of
        # length in (min_block_length, 2 * min_block_length] is never counted as a block.
        if not rtimes_is_valid:
            return True

    computed_quantities = {
        "dist_matrix": dist_matrix,
        "recurrence_indices": recurrence_indices,
        "recurrence_times": recurrence_times,
    }
    if return_computed_quantities:
        return False, computed_quantities
    return False

In [None]:
# is_not_limit_cycle_result = check_not_limit_cycle(
#     dyst_coords,
#     tolerance=1e-3,
#     min_prop_recurrences=0.1,
#     min_counts_per_rtime=200,
#     min_block_length=50,
#     enforce_endpoint_recurrence=True,
#     return_computed_quantities=True,
# )
# print(is_not_limit_cycle_result)

In [None]:
def plot_recurrence_times(
    traj: np.ndarray,
    dist_matrix: np.ndarray,
    recurrence_times: np.ndarray,
    recurrence_indices: tuple[np.ndarray, np.ndarray],
    title: str | None = None,
):
    """Visualize the recurrence quantities returned by ``check_not_limit_cycle``.

    Draws three stacked panels:
      1. a histogram of recurrence times;
      2. the first three coordinates of ``traj`` in 3D. The first half of the time
         series is blue and the second half orange; the start is marked "*" and
         the end "x";
      3. the distance matrix on a log color scale, with recurrence pairs overlaid
         as black dots.

    Args:
        traj: trajectory with time along axis 1 (uses its first three dimensions).
        dist_matrix: pairwise distance matrix over the time points of ``traj``.
        recurrence_times: 1-D array of recurrence times to histogram.
        recurrence_indices: (row_indices, col_indices) into ``dist_matrix``, as
            produced by ``np.nonzero``.
        title: title of the 3D panel. If None, falls back to the part of the
            notebook-level ``system_name`` before the first "_".
    """
    dyst_name = title if title is not None else system_name.split("_")[0]
    n = traj.shape[1]  # number of time points
    half = n // 2

    # Create each panel exactly once. Building a 3x1 grid first and then calling
    # add_subplot on the same slots would stack duplicate axes on top of each other.
    fig = plt.figure(figsize=(10, 18))
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312, projection="3d")
    ax3 = fig.add_subplot(313)

    ax1.hist(recurrence_times, bins=100, edgecolor="black")
    ax1.set_xlabel("Recurrence Time")
    ax1.set_ylabel("Frequency")
    ax1.set_title("Recurrence Times")
    ax1.grid(True)

    xyz = traj[:3, :]
    ic_point = traj[:3, 0]
    final_point = traj[:3, -1]
    ax2.plot(*xyz[:, :half], alpha=0.5, linewidth=1, color="tab:blue")
    ax2.plot(*xyz[:, half:], alpha=0.5, linewidth=1, color="tab:orange")
    ax2.scatter(*ic_point, marker="*", s=100, alpha=0.5, color="tab:blue")
    ax2.scatter(*final_point, marker="x", s=100, alpha=0.5, color="tab:orange")
    ax2.set_xlabel("X")
    ax2.set_ylabel("Y")
    ax2.set_zlabel("Z")  # type: ignore
    ax2.set_title(dyst_name)

    n_rows, n_cols = dist_matrix.shape
    # meshgrid(x, y) yields (n_rows, n_cols) grids with X = column index and
    # Y = row index, so cell dist_matrix[r, c] is drawn centred at (x=c, y=r).
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    pcolormesh = ax3.pcolormesh(
        X,
        Y,
        dist_matrix,
        cmap="viridis_r",
        shading="auto",
        norm=colors.LogNorm(),  # non-positive entries (the zeroed triangle) are masked
    )
    fig.colorbar(pcolormesh, ax=ax3)
    # Plot (col, row) so each marker sits on its own cell in the heatmap.
    ax3.scatter(
        recurrence_indices[1],
        recurrence_indices[0],
        color="black",
        s=20,
        alpha=0.5,
    )
    ax3.set_title("Recurrence Distance Matrix")
    ax3.set_xlabel("Time")
    ax3.set_ylabel("Time")
    ax3.set_aspect("equal")
    fig.tight_layout()
    plt.show()