In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import Counter, defaultdict

import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np

from dystformer.utils import (
    load_trajectory_from_arrow,
    plot_trajs_multivariate,
)

In [None]:
# Data lives under $WORK/data; with WORK unset this falls back to a relative "data" dir.
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
# Dataset split and system to inspect. Other (split, system) pairs used before:
#   ("final_base40/train", "Tsucs2")
#   ("big_base80_run1/train", "Aizawa")
#   ("final_base20/train", "Lorenz")
split_name, system_name = "improved/final_skew40/train", "RikitakeDynamo_LuChen"

In [None]:
# <DATA_DIR>/<split>/<system>/1_T-4096.arrow (presumably sample 1 with T=4096 steps,
# judging by the filename — confirm against the data generator).
filepath = os.path.join(
    os.path.join(DATA_DIR, split_name, system_name),  # per-system directory
    "1_T-4096.arrow",
)
dyst_coords, _ = load_trajectory_from_arrow(filepath)

In [None]:
np.shape(dyst_coords)  # later cells treat axis 0 as the state dimension, axis 1 as time

In [None]:
# The helper is fed a stack of trajectories (presumably (n_trajs, dim, T) — confirm),
# so prepend a length-1 batch axis.
plot_trajs_multivariate(
    dyst_coords[np.newaxis, ...],
    plot_name=f"{system_name}",
    show_plot=True,
)

In [None]:
# One small time-series panel per state dimension (row of dyst_coords).
dim = dyst_coords.shape[0]
for i in range(dim):
    fig, ax = plt.subplots(figsize=(5, 2))
    # Specify the color once: a "b-" fmt string together with color= makes matplotlib
    # warn that the color is redundantly defined.
    ax.plot(dyst_coords[i], "-", color="tab:blue")
    ax.set_title(f"Dimension {i}")
    ax.set_xlabel("Timestep")
    plt.show()

In [None]:
# Synthetic reference trajectories, each of shape (3, n_samples), used to sanity-check
# the diagnostics below against signals of known character:
#   - periodic: sinusoids at commensurate frequencies (regular)
#   - fourier:  sum of 10 random-frequency / random-phase sinusoids per dimension
#   - noise:    cumulative sum of Gaussian noise (random walk)
# A fixed seed keeps these references identical across Restart & Run All.
TEST_SEED = 0
test_rng = np.random.default_rng(TEST_SEED)

n_samples = 4096
t = np.linspace(0, 10 * np.pi, n_samples)

test_system_periodic = np.array(
    [
        np.sin(t),  # x-coordinate
        np.sin(2 * t),  # y-coordinate
        np.sin(3 * t),  # z-coordinate
    ]
)

n_modes = 10
# Shapes (3, n_modes, 1) broadcast against t -> (3, n_modes, n_samples); summing over
# the modes axis gives one signal per dimension.
freqs = test_rng.random((3, n_modes, 1)) * 2 * np.pi  # random angular frequencies
phases = test_rng.random((3, n_modes, 1)) * 2 * np.pi  # random phases
test_system_fourier = np.sin(freqs * t + phases).sum(axis=1)

# Same length as the other references so all diagnostics see equally long series.
test_system_noise = test_rng.standard_normal((3, n_samples)).cumsum(axis=1)

### Power Spectrum

In [None]:
from scipy.fft import rfft

from dystformer.attractor import check_power_spectrum

In [None]:
# Run the project's power-spectrum check (dystformer.attractor) on the loaded trajectory;
# its return value, defined in that module, is shown as this cell's output.
check_power_spectrum(dyst_coords)

In [None]:
def plot_power_spectrum(traj: np.ndarray):
    """
    Plot the power spectrum |rfft(x)|^2 of each dimension of a trajectory (log y-axis).

    Parameters:
        traj: array of shape (dim, T); a 1D array of shape (T,) is treated as dim = 1.
            The FFT is taken along the time axis (axis 1).

    The x-axis is the rfft frequency-bin index (0 .. T // 2), not a physical frequency,
    because the sampling interval is not known here.
    """
    traj = np.atleast_2d(traj)
    power = np.abs(rfft(traj, axis=1)) ** 2  # type: ignore
    n_dims, n_freqs = power.shape
    # squeeze=False keeps `axes` a 2D array even for a single dimension, so indexing
    # works uniformly (plt.subplots(1, 1) would otherwise return a bare Axes).
    fig, axes = plt.subplots(
        n_dims, 1, figsize=(10, 2 * n_dims), sharex=True, squeeze=False
    )
    axes = axes[:, 0]
    freq_bins = np.arange(n_freqs)

    for i, ax in enumerate(axes):
        # Color given once; a "b-" fmt plus color= triggers a redundancy warning.
        ax.plot(freq_bins, power[i], "-", color="tab:blue")
        ax.set_yscale("log")
        ax.set_ylabel(f"Dim {i + 1}")
        ax.grid(True)

    axes[-1].set_xlabel("Frequency bin")
    fig.suptitle("Power Spectrum")

    # Adjust layout to prevent overlap, leaving room for the title.
    fig.tight_layout()
    fig.subplots_adjust(top=0.95)
    plt.show()


plot_power_spectrum(dyst_coords)

In [None]:
# plot_power_spectrum(test_system_fourier)

### Grassberger-Procaccia Dimension

In [None]:
from dysts.analysis import gp_dim

In [None]:
np.shape(dyst_coords)  # (dim, T): transposed below so each row is one state vector

In [None]:
gp = gp_dim(np.transpose(dyst_coords))  # one state vector per row (time along axis 0)

In [None]:
# Show the gp_dim result computed above.
print(gp)

In [None]:
from scipy.spatial.distance import pdist


def compute_gp_dimension(points, num_r=50, scaling_range_idx=None, plot=True):
    """
    Estimate the Grassberger-Procaccia (correlation) dimension D2 of a point cloud.

    The correlation integral C(r) is the fraction of point pairs closer than r; D2 is the
    slope of log10 C(r) vs log10 r over a scaling region. Optionally plots the curve and
    the fitted line.

    Parameters:
        points (ndarray): Array of points with shape (T, dim).
        num_r (int): Number of logarithmically spaced r values between the smallest and
            the largest pairwise distance.
        scaling_range_idx (slice or index array): Indices into the r grid selecting the
            scaling region for the linear fit. If None, the middle 50% of the grid is used.
            Points where C(r) == 0 (log C = -inf) are always excluded from the fit.
        plot (bool): If True (default), plot log C(r) vs log r and the fitted line.

    Returns:
        D2 (float or None): Estimated correlation dimension, or None if the fit failed
            (e.g. fewer than two finite points in the scaling region).

    Raises:
        ValueError: if fewer than two points are given.
    """
    distances = pdist(points, metric="euclidean")
    n_pairs = len(distances)
    if n_pairs == 0:
        raise ValueError("Need at least two points to compute the correlation integral.")

    r_min = np.min(distances)
    r_max = np.max(distances)
    print(f"r_min: {r_min}, r_max: {r_max}")
    # Coincident points give r_min == 0; clamp so that log10(r_min) is finite.
    r_min = max(r_min, 1e-10)
    r_vals = np.logspace(np.log10(r_min), np.log10(r_max), num_r)

    # C(r) = (#pairs with distance < r) / #pairs. On sorted distances,
    # searchsorted(side="left") returns exactly that count for every r in one pass,
    # instead of rescanning all pairs for each r.
    sorted_distances = np.sort(distances)
    C = np.searchsorted(sorted_distances, r_vals, side="left") / n_pairs

    log_r = np.log10(r_vals)
    # C(r_min) is 0 (no pair is strictly closer than the smallest distance), so log10
    # gives -inf there; silence the warning and drop non-finite points from the fit.
    with np.errstate(divide="ignore"):
        log_C = np.log10(C)

    if scaling_range_idx is None:
        # Default scaling region: the middle 50% of the r grid.
        scaling_range = slice(num_r // 4, 3 * num_r // 4)
    else:
        scaling_range = scaling_range_idx

    fit_log_r = np.atleast_1d(log_r[scaling_range])
    fit_log_C = np.atleast_1d(log_C[scaling_range])
    finite = np.isfinite(fit_log_C)
    fit_log_r = fit_log_r[finite]
    fit_log_C = fit_log_C[finite]

    D2 = None
    slope = intercept = None
    if len(fit_log_r) < 2:
        print("Error fitting line: fewer than two finite points in the scaling region")
    else:
        try:
            slope, intercept = np.polyfit(fit_log_r, fit_log_C, 1)
            D2 = float(slope)  # the slope corresponds to the correlation dimension D2
        except Exception as e:  # best-effort: report and return None instead of aborting
            print(f"Error fitting line: {e}")
            slope = intercept = None

    if D2 is None:
        print("Estimated correlation (GP) dimension D2: unavailable")
    else:
        print("Estimated correlation (GP) dimension D2: {:.3f}".format(D2))

    if plot:
        plt.figure(figsize=(5, 5))
        plt.plot(
            log_r,
            log_C,
            "o-",
            markersize=3,
            label=r"$\log C(r)$ vs $\log r$",
            color="tab:blue",
        )
        if D2 is not None:
            # Color given once; an "r" fmt plus color= triggers a redundancy warning.
            plt.plot(
                fit_log_r,
                slope * fit_log_r + intercept,
                "-",
                linewidth=2,
                label=rf"Fit: slope $D_2 \approx {slope:.3f}$",
                color="tab:red",
            )
        plt.xlabel(r"$\log(r)$", fontsize=14)
        plt.ylabel(r"$\log C(r)$", fontsize=14)
        plt.title("Correlation Integral (Grassberger-Procaccia)", fontsize=16)
        plt.legend()
        plt.grid(True)
        plt.show()

    return D2

In [None]:
# Estimate D2 for the loaded trajectory (one state vector per row) and plot the fit.
compute_gp_dimension(np.transpose(dyst_coords))

In [None]:
# compute_gp_dimension(test_system_periodic.T)

In [None]:
# compute_gp_dimension(test_system_noise.T)

In [None]:
# compute_gp_dimension(test_system_fourier.T)

### Time Delay with MI

In [None]:
def mutual_information(x: np.ndarray, y: np.ndarray, bins: int = 64) -> float:
    """
    Estimate the mutual information (in nats) between two 1D arrays x and y.

    Builds a 2D histogram with `bins` bins per axis, normalizes it to joint
    probabilities p_xy (summing to 1) and evaluates
        I(x; y) = sum_{i,j} p_xy[i,j] * log(p_xy[i,j] / (p_x[i] * p_y[j]))
    over the non-empty cells. Returns 0.0 when there are no samples.
    """
    counts, _, _ = np.histogram2d(x, y, bins=bins)
    total = counts.sum()
    if total == 0:
        return 0.0
    # Probabilities, not densities: density=True would divide by the bin area, which
    # rescales and shifts the estimate by an amount that depends on the data range.
    pxy = counts / total
    px = pxy.sum(axis=1)
    py = pxy.sum(axis=0)

    # Only non-empty joint cells contribute; their marginals are then positive too,
    # so every log below is finite.
    nonzero = pxy > 0
    p_joint = pxy[nonzero]
    p_indep = np.outer(px, py)[nonzero]
    return float(np.sum(p_joint * (np.log(p_joint) - np.log(p_indep))))

In [None]:
def optimal_delay(
    x: np.ndarray,
    max_delay: int = 50,
    bins: int = 64,
    conv_window_size: int = 3,
    first_k_minima_to_consider: int = 3,
    plot: bool = False,
) -> int:
    """
    Choose a delay from the mutual-information curve of a scalar time series.

    Computes I(tau) = I( x(t), x(t+tau) ) for tau in {1, 2, ..., max_delay}, smooths the
    curve with a moving average of width `conv_window_size`, and among the first
    `first_k_minima_to_consider` strict local minima of the smoothed curve returns the
    tau of the most prominent one. If the smoothed curve has no local minimum, the tau
    of the global minimum of the raw curve is returned.

    Parameters:
        x: 1D array of shape (T,)
        max_delay: maximum time lag to consider; must be smaller than T
        bins: number of bins per axis for the histogram MI estimate
        conv_window_size: width of the moving-average window (odd widths keep the
            smoothed curve centred on integer taus)
        first_k_minima_to_consider: number of leading minima to choose from (values
            below 1 are treated as 1)
        plot: if True, plot the raw and smoothed MI curves and the selected delay

    Returns:
        The selected delay tau, as a 1-based number of timesteps.

    Raises:
        ValueError: if max_delay >= len(x).
    """
    assert x.ndim == 1, "x must be a 1D array"
    T = len(x)
    if max_delay >= T:
        raise ValueError(f"max_delay ({max_delay}) must be smaller than len(x) ({T})")
    taus = np.arange(1, max_delay + 1)
    # Compare only the overlapping segments x[0 : T - tau] and x[tau : T].
    mi_values = np.array(
        [mutual_information(x[: T - tau], x[tau:], bins=bins) for tau in taus]
    )

    # 1. Smooth the MI curve with a moving average to reduce estimator noise.
    smoothed_mi = np.convolve(
        mi_values, np.ones(conv_window_size) / conv_window_size, mode="valid"
    )
    # Index shift from a smoothed sample to the raw sample at its window centre
    # (exact for odd window sizes).
    offset = conv_window_size // 2
    # At least one minimum must be considered, otherwise there is nothing to pick.
    k = max(1, first_k_minima_to_consider)

    # 2. Find the first k strict local minima of the smoothed curve, and
    # 3. their prominence: depth below the lower of the highest values on either side.
    minima_indices = []
    prominences = []
    for i in range(1, len(smoothed_mi) - 1):
        if smoothed_mi[i] < smoothed_mi[i - 1] and smoothed_mi[i] < smoothed_mi[i + 1]:
            minima_indices.append(i)
            left_max = np.max(smoothed_mi[: i + 1])
            right_max = np.max(smoothed_mi[i:])
            prominences.append(min(left_max, right_max) - smoothed_mi[i])
            if len(prominences) >= k:
                break

    # Defined on both branches: the plot below uses it to color the minima markers.
    num_to_consider = min(k, len(minima_indices))
    if num_to_consider == 0:
        # No local minimum: fall back to the global minimum of the raw curve (1-based).
        first_min = int(np.argmin(mi_values)) + 1
    else:
        # Most prominent of the first k minima, shifted to its window centre, 1-based.
        best_idx = int(np.argmax(prominences[:num_to_consider]))
        first_min = int(minima_indices[best_idx] + offset + 1)

    if plot:
        plt.figure(figsize=(6, 4))
        plt.plot(taus, mi_values, marker="o", alpha=0.6, label="Original MI")

        # Centre of each smoothing window on the tau axis (correct for any width).
        smoothed_taus = taus[: len(smoothed_mi)] + (conv_window_size - 1) / 2
        if len(smoothed_taus) == len(smoothed_mi):
            plt.plot(smoothed_taus, smoothed_mi, "r-", linewidth=2, label="Smoothed MI")

        # Mark all detected minima on the raw curve.
        for rank, idx in enumerate(minima_indices):
            tau_at_min = idx + offset + 1
            if rank < num_to_consider:
                plt.plot(tau_at_min, mi_values[tau_at_min - 1], "go", markersize=8)
            else:
                plt.plot(tau_at_min, mi_values[tau_at_min - 1], "yo", markersize=6)

        plt.axvline(
            first_min,
            color="red",
            linestyle="--",
            label=f"Selected min at τ={first_min}",
        )
        plt.xlabel("Delay τ")
        plt.ylabel("Mutual Information")
        plt.title("Mutual Information vs. Delay")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()

    return first_min


def optimal_sampling_interval(
    trajectory, max_delay=100, bins=64, observable="x", k=1, plot=False
):
    """
    Pick a sampling interval for a (dim, T) trajectory from one observable's MI curve.

    The trajectory is reduced to a scalar series (one coordinate, or the Euclidean norm
    of the state at each timestep). That series is handed to `optimal_delay`, which
    returns the delay of the most prominent of the first `k` minima of the smoothed
    mutual-information curve.

    Parameters:
      trajectory: NumPy array of shape (dim, T)
      max_delay: largest delay (in timesteps) to evaluate
      bins: histogram bins per axis for the MI estimate
      observable: 'x', 'y', 'z' (rows 0, 1, 2) or 'norm'
      k: how many leading minima `optimal_delay` chooses among
      plot: if True, `optimal_delay` plots mutual information vs. delay

    Returns:
      tau_opt: the selected delay, in timesteps.

    Raises:
      ValueError: unknown observable, or too few coordinates for 'y' / 'z'.
    """
    coord_names = ("x", "y", "z")
    too_few_coords = {
        "y": "Trajectory does not have a second coordinate for 'y'.",
        "z": "Trajectory does not have a third coordinate for 'z'.",
    }

    if observable == "norm":
        series = np.linalg.norm(trajectory, axis=0)
    elif observable in coord_names:
        row = coord_names.index(observable)
        # 'x' (row 0) needs no check; 'y' / 'z' need at least row + 1 coordinates.
        if row > 0 and trajectory.shape[0] < row + 1:
            raise ValueError(too_few_coords[coord_names[row]])
        series = trajectory[row, :]
    else:
        raise ValueError("Unknown observable. Choose 'x', 'y', 'z', or 'norm'.")

    return optimal_delay(
        series, max_delay=max_delay, bins=bins, first_k_minima_to_consider=k, plot=plot
    )

In [None]:
# Choose the subsampling interval from the MI curve of the first coordinate ('x').
tau_opt = optimal_sampling_interval(
    dyst_coords,
    observable="x",
    k=1,
    max_delay=100,
    bins=64,
    plot=True,
)
print("Optimal sampling interval (delay τ):", tau_opt)

### Zero-One Test

In [None]:
def compute_translation_variables(phi, c):
    """
    Translation variables of the 0-1 test for observable `phi` and constant `c`:

        p(n) = sum_{j=1..n} phi(j) cos(j c),   q(n) = sum_{j=1..n} phi(j) sin(j c)

    Returns the pair of cumulative-sum arrays (p, q), each as long as `phi`.
    """
    angles = c * np.arange(1, len(phi) + 1)
    p = np.cumsum(phi * np.cos(angles))
    q = np.cumsum(phi * np.sin(angles))
    return p, q


def compute_mean_square_displacement(p, q, max_shift_ratio=0.1):
    """
    Mean square displacement of the (p, q) path for shifts n = 1 .. int(max_shift_ratio * len(p)):

        MSD(n) = mean_j [ (p[j + n] - p[j])^2 + (q[j + n] - q[j])^2 ]

    Parameters:
      p, q: translation variables (1D arrays of equal length)
      max_shift_ratio: largest shift, as a fraction of the series length.

    Returns:
      n_vals: integer shifts 1 .. max_shift
      MSD: float array holding one mean square displacement per shift.
    """
    largest_shift = int(max_shift_ratio * len(p))
    n_vals = np.arange(1, largest_shift + 1)
    msd_per_shift = []
    for shift in n_vals:
        dp = p[shift:] - p[:-shift]
        dq = q[shift:] - q[:-shift]
        msd_per_shift.append(np.mean(dp**2 + dq**2))
    MSD = np.array(msd_per_shift, dtype=float)
    return n_vals, MSD


def compute_K_statistic(n_vals, MSD):
    """
    0-1 test statistic K: the Pearson correlation between the shifts and the MSD.

    K near 1 means the MSD grows linearly with the shift (chaotic signature); K near 0
    means the MSD stays bounded (regular dynamics).
    """
    return np.corrcoef(n_vals, MSD)[0, 1]


def zero_one_test(phi, c=None, threshold=0.5, plot=False):
    """
    Performs the 0-1 test for chaos on a scalar observable.

    Parameters:
      phi: 1D NumPy array representing the observable (length T).
      c: constant in (0, pi); if None, a value is drawn uniformly from (pi/5, 4*pi/5)
         using NumPy's global RNG, staying away from 0 and pi to avoid resonances.
      threshold: threshold on |K| to decide if the system is chaotic.
      plot: if True, plots the (p, q) trajectory and MSD vs. shift index.

    Returns:
      K: the computed correlation coefficient (nan when the MSD has zero variance).
      is_chaotic: boolean, True if |K| > threshold.
    """
    if c is None:
        # Draw from (pi/5, 4*pi/5), the range the docstring promises and the one used
        # for the c sweeps elsewhere.
        c = np.random.uniform(np.pi / 5, 4 * np.pi / 5)

    p, q = compute_translation_variables(phi, c)
    n_vals, MSD = compute_mean_square_displacement(p, q)
    K = compute_K_statistic(n_vals, MSD)

    if plot:
        _, axs = plt.subplots(1, 2, figsize=(12, 5))
        axs[0].plot(p, q, lw=1)
        axs[0].set_title("Translation Variables (p, q)")
        axs[0].set_xlabel("p")
        axs[0].set_ylabel("q")
        axs[1].plot(n_vals, MSD, "o-", lw=1)
        axs[1].set_title("Mean Square Displacement (MSD)")
        axs[1].set_xlabel("n (shift)")
        axs[1].set_ylabel("MSD")
        plt.tight_layout()
        plt.show()

    # np.abs(nan) > threshold is False, so an undefined K is reported as not chaotic.
    is_chaotic = np.abs(K) > threshold
    return K, is_chaotic


def test_0_1_for_chaos(traj, observable="x", c=None, threshold=0.5, plot=False):
    """
    Applies the 0–1 test for chaos to a multidimensional trajectory.

    Parameters:
      traj: NumPy array of shape (dim, T).
      observable: which observable to use; options:
                  'x', 'y', 'z' (for individual coordinates) or 'norm' (Euclidean norm of the state vector).
      threshold: threshold on the K statistic to decide if chaotic.
      plot: if True, produces diagnostic plots.

    Returns:
      K: computed K-statistic.
      is_chaotic: boolean, True if the test indicates chaos.
    """
    dim, T = traj.shape

    if observable == "norm":
        # Use the Euclidean norm of the state as the observable.
        phi = np.linalg.norm(traj, axis=0)
    elif observable == "x":
        phi = traj[0, :]
    elif observable == "y":
        if dim < 2:
            raise ValueError("Trajectory does not have a y dimension.")
        phi = traj[1, :]
    elif observable == "z":
        if dim < 3:
            raise ValueError("Trajectory does not have a z dimension.")
        phi = traj[2, :]
    else:
        raise ValueError("Invalid observable. Choose 'x', 'y', 'z', or 'norm'.")

    K, is_chaotic = zero_one_test(phi, c=c, threshold=threshold, plot=plot)
    return K, is_chaotic

In [None]:
# # Test using the Euclidean norm as observable.
# K, chaotic = test_0_1_for_chaos(
#     dyst_coords[:, ::tau_opt], observable="x", c=default_c_val, threshold=0.5, plot=True
# )
# print("K-statistic:", K)
# print("Chaotic:", chaotic)

In [None]:
def run_zero_one_sweep(
    traj,
    c_vals: np.ndarray,
    observable="x",
    subsample_interval: int = 1,
    threshold: float = 0.5,
) -> tuple[float, np.ndarray]:
    """
    Run the 0-1 test on `traj` once for every constant in `c_vals`.

    Parameters:
      traj: trajectory of shape (dim, T).
      c_vals: the c constants to sweep.
      observable: forwarded to `test_0_1_for_chaos` ('x', 'y', 'z' or 'norm').
      subsample_interval: keep every `subsample_interval`-th timestep (applied only
        when > 1) before testing.
      threshold: a run counts as chaotic when its K exceeds this value.

    Returns:
      chaos_score: fraction of c values whose signed K (not |K|) exceeds `threshold`.
      K_vals: array of K statistics, one per c value.
    """
    series = traj[:, ::subsample_interval] if subsample_interval > 1 else traj
    K_vals = np.array(
        [
            test_0_1_for_chaos(
                series,
                observable=observable,
                c=c_val,
                threshold=threshold,
                plot=False,
            )[0]
            for c_val in c_vals
        ]
    )
    chaos_score = float(np.sum(K_vals > threshold) / len(K_vals))
    return chaos_score, K_vals

In [None]:
# 100 constants c drawn uniformly from (pi/5, 4*pi/5), away from 0 and pi to avoid
# resonances. A seeded generator makes the sweep reproducible on Restart & Run All.
C_SWEEP_SEED = 0
c_vals = np.random.default_rng(C_SWEEP_SEED).uniform(np.pi / 5, 4 * np.pi / 5, 100)

In [None]:
# Drop the first half of the trajectory as transient, then run the 0-1 sweep on the
# z-coordinate, subsampled every `tau_opt` steps.
transient_prop = 0.5
len_transient = int(dyst_coords.shape[1] * transient_prop)
print(f"Transient length: {len_transient}")

chaos_score, K_vals = run_zero_one_sweep(
    dyst_coords[:, len_transient:],
    c_vals=c_vals,
    observable="z",
    subsample_interval=tau_opt,
    threshold=0.25,
)
print(f"Dyst Test chaos score: {chaos_score}")
print("median K", np.median(K_vals))
print("mean K", np.mean(K_vals))

In [None]:
# K statistic for each swept c value.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(c_vals, K_vals)
ax.set_xlabel("c")
ax.set_ylabel("K")
ax.set_title("K-statistic vs c")
fig.tight_layout()
plt.show()

In [None]:
# # Test using the Euclidean norm as observable.
# K, chaotic = test_0_1_for_chaos(
#     test_system_fourier[:, ::tau_opt],
#     observable="x",
#     c=default_c_val,
#     threshold=0.5,
#     plot=True,
# )
# print("K-statistic:", K)
# print("Chaotic:", chaotic)

In [None]:
# chaos_score, K_vals = run_zero_one_sweep(
#     test_system_fourier,
#     c_vals=c_vals,
#     observable="x",
#     subsample_interval=tau_opt,
#     threshold=0.5,
# )
# print(f"Fourier Test chaos score: {chaos_score}")
# print("median K", np.median(K_vals))
# print("mean K", np.mean(K_vals))

In [None]:
# # Test using the Euclidean norm as observable.
# K, chaotic = test_0_1_for_chaos(
#     test_system_periodic[:, ::tau_opt],
#     observable="x",
#     c=default_c_val,
#     threshold=0.5,
#     plot=True,
# )
# print("K-statistic:", K)
# print("Chaotic:", chaotic)

In [None]:
# chaos_score, K_vals = run_zero_one_sweep(
#     test_system_periodic,
#     c_vals=c_vals,
#     observable="x",
#     subsample_interval=tau_opt,
#     threshold=0.5,
# )
# print(f"Periodic Test chaos score: {chaos_score}")
# print("median K", np.median(K_vals))
# print("mean K", np.mean(K_vals))

In [None]:
# # Test using the Euclidean norm as observable.
# K, chaotic = test_0_1_for_chaos(
#     test_system_noise[:, ::tau_opt],
#     observable="x",
#     c=default_c_val,
#     threshold=0.5,
#     plot=True,
# )
# print("K-statistic:", K)
# print("Chaotic:", chaotic)

In [None]:
# chaos_score, K_vals = run_zero_one_sweep(
#     test_system_noise,
#     c_vals=c_vals,
#     observable="x",
#     subsample_interval=tau_opt,
#     threshold=0.5,
# )
# print(f"Noise Test chaos score: {chaos_score}")
# print("median K", np.median(K_vals))
# print("mean K", np.mean(K_vals))

In [None]:
def z1test(x, c_min=np.pi / 5, c_max=4 * np.pi / 5, n_c=100):
    """
    Gottwald-Melbourne 0-1 test for chaos.

    Parameters:
        x (array-like): Input time series data (flattened to 1D).
        c_min, c_max (float): interval from which the random c values are drawn
            (default [pi/5, 4pi/5]).
        n_c (int): number of random c values (default 100, as in the MATLAB original).

    Returns:
        kmedian (float): Test statistic near 0 for non-chaotic data and near 1 for chaotic data.

    Notes:
        - The c values are drawn with np.random.rand, so the result depends on NumPy's
          global RNG state.
        - For each c, cumulative sums p and q are formed and the mean-square displacement
          M(n), n = 1..round(N/10), is computed with the oscillatory correction term
          mean(x)^2 (1 - cos(n c)) / (1 - cos c) subtracted.
        - The Pearson correlation between n and M(n) is computed for each c value, and
          the median over c is returned.
        - A warning is printed when the data look oversampled.

    Translated from MATLAB code provided by Paul Matthews https://www.mathworks.com/matlabcentral/fileexchange/25050-0-1-test-for-chaos
        based on the method proposed by Gottwald and Melbourne
    """
    # Ensure x is a 1D numpy array
    x = np.asarray(x).flatten()
    N = len(x)
    j = np.arange(1, N + 1)  # equivalent to MATLAB's [1:N]

    # t runs from 1 to round(N/10). MATLAB's round() rounds halves away from zero, while
    # Python's round() rounds half to even (round(2.5) == 2); integer arithmetic
    # reproduces the MATLAB value for N >= 0.
    t_max = (N + 5) // 10
    t = np.arange(1, t_max + 1)
    M = np.zeros(t_max)

    # n_c random c values in [c_min, c_max]
    c = c_min + np.random.rand(n_c) * (c_max - c_min)
    kcorr = np.zeros(n_c)
    x_mean_sq = np.mean(x) ** 2  # loop-invariant part of the correction term

    for its, c_val in enumerate(c):
        # Cumulative sums p and q
        p = np.cumsum(x * np.cos(j * c_val))
        q = np.cumsum(x * np.sin(j * c_val))

        # M(n) for n = 1, ..., round(N/10)
        for n in range(1, t_max + 1):
            # p[n:] corresponds to MATLAB p(n+1:N) and p[:-n] to p(1:N-n)
            diff_p = p[n:] - p[:-n]
            diff_q = q[n:] - q[:-n]
            term = np.mean(diff_p**2 + diff_q**2)
            correction = x_mean_sq * (1 - np.cos(n * c_val)) / (1 - np.cos(c_val))
            M[n - 1] = term - correction

        # Pearson correlation coefficient between t and M
        kcorr[its] = np.corrcoef(t, M)[0, 1]

    # Two crude attempts to check for oversampling (as in the MATLAB original).
    # Guard the ratio: for constant x the mean step is 0 and the ratio is undefined.
    mean_step = np.mean(np.abs(np.diff(x)))
    condition1 = bool(mean_step > 0) and (np.max(x) - np.min(x)) / mean_step > 10
    median_lower = np.median(kcorr[c < np.mean(c)])
    median_upper = np.median(kcorr[c > np.mean(c)])
    condition2 = (median_lower - median_upper) > 0.5
    if condition1 or condition2:
        print("Warning: data is probably oversampled.")
        print("Use coarser sampling or reduce the maximum value of c.")

    return np.median(kcorr)

In [None]:
# z1test(dyst_coords[0, ::tau_opt], c_min=np.pi / 5, c_max=4 * np.pi / 5)

In [None]:
# z1test(test_system_fourier[0, ::tau_opt], c_min=np.pi / 5, c_max=4 * np.pi / 5)

In [None]:
# z1test(test_system_periodic[0, ::tau_opt], c_min=np.pi / 5, c_max=4 * np.pi / 5)

In [None]:
# z1test(test_system_noise[0, ::tau_opt], c_min=np.pi / 5, c_max=4 * np.pi / 5)

### Additional Checks

#### Limit Cycle Test

In [None]:
from scipy.spatial.distance import cdist

In [None]:
def check_not_limit_cycle(
    traj: np.ndarray,
    tolerance: float = 1e-3,
    min_prop_recurrences: float = 0.0,
    min_counts_per_rtime: int = 100,
    min_block_length: int = 1,
    min_recurrence_time: int = 1,
    enforce_endpoint_recurrence: bool = False,
    return_computed_quantities: bool = False,
) -> bool | tuple[bool, dict]:
    """
    limit cycle test from attractor.py, exposed here for plotting purposes
    Returns: True if the trajectory is not a limit cycle, False otherwise
        If False and also return_computed_quantities is True, returns a tuple (False, computed_quantities)

    A pair of timepoints (i, j) with i < j is a recurrence when the Euclidean distance
    between the two states is below `tolerance` and nonzero. The trajectory is reported
    as a limit cycle (False) only if it survives every check below. Each failing check
    returns a bare True immediately, even when return_computed_quantities is True.

    Parameters:
        traj: trajectory with time along axis 1 (n = traj.shape[1] timepoints). An
            n x n distance matrix is built, so memory grows quadratically in n.
        tolerance: distance below which two states count as a recurrence.
        min_prop_recurrences: minimum number of recurrences (after the
            min_recurrence_time filter), as a fraction of n.
        min_counts_per_rtime: a recurrence time is valid if it occurs at least this many
            times; at least one valid recurrence time is required.
        min_block_length: if > 1, recurrences must also form runs of consecutive
            (t1, t2) pairs sharing one recurrence time (heuristic 3).
        min_recurrence_time: recurrence times below this value are ignored.
        enforce_endpoint_recurrence: if True, some recurrence row or column index must
            equal 0 or n - 1.
        return_computed_quantities: if True and the result is False, also return a dict
            with keys "dist_matrix", "recurrence_indices" and "recurrence_times".
    """
    n = traj.shape[1]

    # Step 1: Calculate the pairwise distance matrix, shape should be (N, N)
    # NOTE(review): the float16 cast saves memory but keeps only ~3 significant digits
    # and overflows to inf above ~65504; distances near `tolerance` are compared at
    # that precision.
    dist_matrix = cdist(traj.T, traj.T, metric="euclidean").astype(np.float16)
    # Keep the strict upper triangle (i < j) so each pair appears once; the diagonal
    # and lower triangle become 0.
    dist_matrix = np.triu(dist_matrix, k=1)

    # Step 2: Get recurrence times from thresholding distance matrix
    # `> 0` removes the zeroed entries above, but also any pair whose float16 distance
    # is exactly 0.
    recurrence_indices = np.asarray(
        (dist_matrix < tolerance) & (dist_matrix > 0)
    ).nonzero()

    # No recurrences at all: cannot be a limit cycle.
    n_recurrences = len(recurrence_indices[0])
    if n_recurrences == 0:
        return True

    if enforce_endpoint_recurrence:
        # check if an eps neighborhood around either n-1 or 0 is in either of the recurrence indices
        # With eps = 0 this requires an index exactly equal to n - 1 or 0; the loop runs
        # over the (row indices, column indices) arrays.
        eps = 0
        if not any(
            (n - 1) - max(indices) <= eps or min(indices) - 0 <= eps
            for indices in recurrence_indices
        ):
            return True

    # get recurrence times
    # (upper triangle, so this is j - i for each recurrent pair)
    recurrence_times = np.abs(recurrence_indices[0] - recurrence_indices[1])
    recurrence_times = recurrence_times[recurrence_times >= min_recurrence_time]

    # Heuristic 1: Check if there are enough recurrences to consider a limit cycle
    n_recurrences = len(recurrence_times)
    if n_recurrences < int(min_prop_recurrences * n):
        return True

    # Heuristic 2: Check if there are enough valid recurrence times
    # (a distinct recurrence time is valid if it occurs at least min_counts_per_rtime times)
    rtimes_counts = Counter(recurrence_times)
    n_valid_rtimes = sum(
        1 for count in rtimes_counts.values() if count >= min_counts_per_rtime
    )
    if n_valid_rtimes < 1:
        return True

    # Heuristic 3: Check if the valid recurrence times are formed of blocks of consecutive timepoints
    # Passes as soon as one run exceeds 2 * min_block_length, or once two completed runs
    # longer than min_block_length have been seen.
    # NOTE(review): np.nonzero yields indices in row-major order, so consecutive zip
    # entries usually share t1 (same row). A diagonal run (t1 + 1, t2 + 1) is therefore
    # only detected when rows hold a single recurrence each. Also, the final run is
    # never flushed into num_blocks after the loop ends. Confirm both against
    # attractor.py before relying on this heuristic.
    if min_block_length > 1:
        rtimes_dict = defaultdict(list)
        block_length = 1
        prev_rtime = None
        prev_t1 = None
        prev_t2 = None
        rtimes_is_valid = False
        num_blocks = 0
        # assuming recurrence_indices[0] is sorted
        for t1, t2 in zip(*recurrence_indices):
            rtime = abs(t2 - t1)
            if rtime < min_recurrence_time:
                continue
            # Continue the current run when both indices advance by one with the same
            # recurrence time (on the first pair, prev_rtime is None so this is False).
            if (
                rtime == prev_rtime
                and abs(t1 - prev_t1) == 1
                and abs(t2 - prev_t2) == 1
            ):
                block_length += 1
            else:
                # Run ended: record it if it was long enough, then start a new run.
                if block_length > min_block_length:
                    rtimes_dict[prev_rtime].append(block_length)
                    num_blocks += 1
                block_length = 1
            prev_t1, prev_t2, prev_rtime = t1, t2, rtime
            if block_length > min_block_length * 2:
                rtimes_is_valid = True
                break
            if num_blocks >= 2:  # if valid, save computation and break
                rtimes_is_valid = True
                break
        if not rtimes_is_valid:
            return True

    computed_quantities = {
        "dist_matrix": dist_matrix,
        "recurrence_indices": recurrence_indices,
        "recurrence_times": recurrence_times,
    }
    if return_computed_quantities:
        return False, computed_quantities
    return False

In [None]:
# is_not_limit_cycle_result = check_not_limit_cycle(
#     dyst_coords,
#     tolerance=1e-3,
#     min_prop_recurrences=0.1,
#     min_counts_per_rtime=200,
#     min_block_length=50,
#     enforce_endpoint_recurrence=True,
#     return_computed_quantities=True,
# )
# print(is_not_limit_cycle_result)

In [None]:
def plot_recurrence_times(
    traj: np.ndarray,
    dist_matrix: np.ndarray,
    recurrence_times: np.ndarray,
    recurrence_indices: np.ndarray,
    dyst_name: str | None = None,
):
    """
    Diagnostic figure for the quantities computed by the limit-cycle check.

    Panels, top to bottom:
        1. histogram of recurrence times;
        2. 3D phase portrait of the first three coordinates, with the first half of the
           trajectory in blue and the second half in orange (start marked *, end marked x);
        3. the distance matrix on a log color scale, with recurrent pairs overlaid.

    Parameters:
        traj: trajectory of shape (dim, T) with dim >= 3.
        dist_matrix: pairwise distance matrix (rows and columns are timepoints).
        recurrence_times: 1D array of recurrence times to histogram.
        recurrence_indices: pair (row indices, column indices) of recurrent timepoints.
        dyst_name: title of the phase portrait; defaults to the part of the global
            `system_name` before the first underscore.
    """
    if dyst_name is None:
        dyst_name = system_name.split("_")[0]
    n = traj.shape[1]  # number of timepoints, used to split the trajectory in half

    # Create each panel once, with its own projection. Adding axes at 312/313 on top of
    # a plt.subplots(3, 1) grid would leave stray empty 2D axes behind those panels.
    fig = plt.figure(figsize=(10, 18))

    ax1 = fig.add_subplot(311)
    ax1.hist(recurrence_times, bins=100, edgecolor="black")
    ax1.set_xlabel("Recurrence Time")
    ax1.set_ylabel("Frequency")
    ax1.set_title("Recurrence Times")
    ax1.grid(True)

    xyz = traj[:3, :]
    xyz1 = xyz[:, : n // 2]
    xyz2 = xyz[:, n // 2 :]
    ic_point = traj[:3, 0]
    final_point = traj[:3, -1]
    ax2 = fig.add_subplot(312, projection="3d")
    ax2.plot(*xyz1, alpha=0.5, linewidth=1, color="tab:blue")
    ax2.plot(*xyz2, alpha=0.5, linewidth=1, color="tab:orange")
    ax2.scatter(*ic_point, marker="*", s=100, alpha=0.5, color="tab:blue")
    ax2.scatter(*final_point, marker="x", s=100, alpha=0.5, color="tab:orange")
    ax2.set_xlabel("X")
    ax2.set_ylabel("Y")
    ax2.set_zlabel("Z")  # type: ignore
    ax2.set_title(dyst_name)

    ax3 = fig.add_subplot(313)
    n_rows, n_cols = dist_matrix.shape
    # X[i, j] = j and Y[i, j] = i, both shaped like dist_matrix, so entry
    # dist_matrix[i, j] is drawn at (x=j, y=i).
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))
    pcolormesh = ax3.pcolormesh(
        X,
        Y,
        np.asarray(dist_matrix, dtype=float),  # upcast (e.g. from float16) for plotting
        cmap="viridis_r",
        shading="auto",
        norm=colors.LogNorm(),
    )
    fig.colorbar(pcolormesh, ax=ax3)
    # A recurrent pair (row i, column j) sits at (x=j, y=i) in the mesh above.
    ax3.scatter(
        recurrence_indices[1],
        recurrence_indices[0],
        color="black",
        s=20,
        alpha=0.5,
    )
    ax3.set_title("Recurrence Distance Matrix")
    ax3.set_xlabel("Time")
    ax3.set_ylabel("Time")
    ax3.set_aspect("equal")
    fig.tight_layout()
    plt.show()