# Finkelstein Fontolan RNN Loading & Dynamics

This tutorial demonstrates how to load the pretrained recurrent neural network (RNN) from `configs/finkelstein_fontolan.py`, and how to choose the training distribution for the Koopman eigenfunction, which defines the domain of interest.


In [None]:
# Install the package when running in Colab
# !pip install torchdiffeq
# !pip install compose
# !pip install --no-deps git+https://github.com/KabirDabholkar/separatrix_locator.git

In [None]:
from pathlib import Path
import importlib
from typing import Optional

import numpy as np
import torch
import scipy.io as sio

from separatrix_locator.utils.finkelstein_fontolan_RNN import init_network, extract_opposite_attractors_from_model
from separatrix_locator.utils.finkelstein_fontolan_task import initialize_task
from separatrix_locator.dynamics.rnn import get_autonomous_dynamics_from_model, discrete_to_continuous
from separatrix_locator.utils.odeint_utils import run_odeint_to_final


In [None]:
def load_params_dict(params_mat_path: Path) -> dict:
    """Read the RNN parameter struct stored in a MATLAB ``.mat`` file.

    The file must contain a top-level struct called ``params``. Scalar fields
    are unwrapped from their 1x1 MATLAB arrays; array fields are cast to
    ``float32``.

    Args:
        params_mat_path: Location of the ``.mat`` file.

    Returns:
        Dict with keys ``N``, ``dt``, ``tau``, ``f0``, ``beta0``, ``theta0``,
        ``M``, ``h`` (flattened), ``eff_dt``, ``sigma_noise_cd`` (``100 / N``),
        ``des_out_left``, ``des_out_right`` and ``ramp_train``.
    """
    struct = sio.loadmat(str(params_mat_path))["params"][0, 0]

    def scalar(field: str) -> float:
        # MATLAB scalars come back as 1x1 arrays.
        return float(struct[field][0, 0])

    def array(field: str) -> np.ndarray:
        return np.asarray(struct[field], dtype=np.float32)

    n_units = int(struct["N"][0, 0])
    result = {"N": n_units}
    for key in ("dt", "tau", "f0", "beta0", "theta0"):
        result[key] = scalar(key)
    result["M"] = array("M")
    result["h"] = np.asarray(struct["h"].flatten(), dtype=np.float32)
    result["eff_dt"] = scalar("eff_dt")
    result["sigma_noise_cd"] = 100.0 / n_units
    result["des_out_left"] = array("des_out_left")
    result["des_out_right"] = array("des_out_right")
    result["ramp_train"] = array("ramp_train")
    return result

# Locate the repository root relative to the installed package: parents[2] is two
# directories above the package directory that holds separatrix_locator/__init__.py.
# NOTE(review): this assumes an editable/source checkout (e.g. a src/ layout); after a
# plain pip install (as in the Colab cell) it likely resolves inside site-packages,
# where rnn_params/ does not exist — confirm.
package = importlib.import_module("separatrix_locator")
PROJECT_ROOT = Path(package.__file__).resolve().parents[2]
PARAMS_ROOT = PROJECT_ROOT / "rnn_params" / "finkelstein_fontolan"
INPUT_DATA_DIR = PARAMS_ROOT / "input_data"

# Read the network parameters from the repository's rnn_params directory.
params_dict = load_params_dict(INPUT_DATA_DIR / "params_data_wramp.mat")
# Helper utilities assume CPU tensors; keep everything on CPU for compatibility.
device = torch.device("cpu")
model = init_network(params_dict, device=device)
# Hidden-state dimension = number of units N read from the .mat file.
dim = int(params_dict["N"])

print(f"Loaded Finkelstein Fontolan RNN with hidden dimension {dim} on {device}.")


In [None]:
# Task dataset built from the same input_data directory. The trailing "/" is
# appended because initialize_task takes a string prefix — presumably it joins file
# names onto it (TODO confirm).
dataset = initialize_task(str(INPUT_DATA_DIR) + "/", N_trials_cd=10)

# Wrap the RNN as a discrete-time map driven by (state, external input), run
# deterministically. NOTE(review): output_id=1 presumably selects the hidden-state
# output of the model — confirm against get_autonomous_dynamics_from_model.
discrete_dynamics = get_autonomous_dynamics_from_model(
    model,
    device=device,
    rnn_submodule_name=None,
    kwargs={"deterministic": True, "batch_first": False},
    output_id=1,
)
# Turn the discrete map into a continuous-time vector field with step size 1.0
# (presumably (f(x) - x) / delta_t — TODO confirm).
continuous_dynamics = discrete_to_continuous(discrete_dynamics, delta_t=1.0)

# Time rescaling of the vector field so that bistability is observed at O(1) time.
speed_factor = 60.0 ### make it so that bistability is observed at O(1) time.

def dynamics_function(x: torch.Tensor, external_input: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Evaluate the time-rescaled continuous RNN vector field at ``x``.

    Args:
        x: State(s). When ``x`` has more than one axis, ``x.shape[-2]`` is used
            as the batch size for the default input.
        external_input: Input fed to the RNN. If omitted, an all-zero input with
            3 channels (one row per batch entry) is used.

    Returns:
        ``continuous_dynamics(x, external_input)`` multiplied by ``speed_factor``.
    """
    if external_input is not None:
        return continuous_dynamics(x, external_input) * speed_factor
    rows = 1 if x.dim() <= 1 else x.shape[-2]
    zero_input = torch.zeros(rows, 3, device=x.device)
    return continuous_dynamics(x, zero_input) * speed_factor


# Range of the third input channel used when extracting the attractors; the
# static input used for all integrations below takes its lower end.
input_range = (0.9, 0.92)
with torch.no_grad():
    # Initial estimates of the two opposite attractors, derived from model + dataset.
    attractors = extract_opposite_attractors_from_model(model, dataset, input_range=input_range)
    # Constant external input for all integrations: zeros on the first two
    # channels, input_range[0] on the third.
    static_external_input = torch.tensor([0.0, 0.0, input_range[0]], dtype=torch.float32, device=device)
    # Refine the estimates by integrating them forward to T=30, keeping the whole
    # trajectory (return_last_only=False) so its length can be reported below.
    # NOTE(review): if extract_opposite_attractors_from_model already returns a
    # tensor, torch.as_tensor would avoid torch.tensor's copy-construct warning.
    trajectory = run_odeint_to_final(
        lambda state, u: dynamics_function(state, u),
        torch.tensor(attractors, dtype=torch.float32, device=device),
        T=30,
        inputs=static_external_input,
        steps=10,
        return_last_only=False,
        no_grad=True,
    )
    # The final integrated states are taken as the refined attractors.
    attractors = trajectory[-1]
print(f"Integrated {trajectory.shape[0]} time steps for {trajectory.shape[1]} trajectories.")


In [None]:
from separatrix_locator.core.separatrix_point import find_separatrix_point_along_line
from separatrix_locator.distributions.gaussian import MultivariateGaussian, MultivariateGaussianList

with torch.no_grad():
    # Search the segment between the two attractors for a point on the separatrix
    # under the static input. The meaning of num_points / num_iterations / final_time
    # is assumed from their names (grid size, refinement rounds, integration horizon)
    # — confirm against find_separatrix_point_along_line.
    point_on_separatrix = find_separatrix_point_along_line(
        dynamics_function,
        static_external_input,
        (torch.as_tensor(attractors[0]), torch.as_tensor(attractors[1])),
        num_points=10,
        num_iterations=4,
        final_time=10,
    ).cpu()

# Geometry of the attractor pair, used to size and orient the proposal distributions.
attractor_a = torch.as_tensor(attractors[0]).cpu()
attractor_b = torch.as_tensor(attractors[1]).cpu()
attractor_vector = attractor_b - attractor_a
attractor_distance = torch.linalg.norm(attractor_vector).item()
# Unit vector pointing from A to B; the 1e-8 avoids division by zero if A == B.
unit_vector = attractor_vector / (torch.linalg.norm(attractor_vector) + 1e-8)
identity = torch.eye(dim)

# Distances to the nearest attractor below 5% of |A - B| count as "converged".
convergence_threshold = 0.05 * attractor_distance

print(f"Point on separatrix shape: {tuple(point_on_separatrix.shape)}")
print(f"Attractor distance: {attractor_distance:.3f}")
print(f"Convergence threshold: {convergence_threshold:.3f}")



In [None]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple

from sklearn.decomposition import PCA

# Dedicated seeded generator, for any code that accepts an explicit ``generator=``.
rng = torch.Generator(device="cpu").manual_seed(0)
# ``torch.distributions.MultivariateNormal.sample`` (used in the experiments below)
# has no ``generator`` argument and draws from the global RNG. Seed that too;
# otherwise the sampled initial conditions differ on every run despite ``rng``.
torch.manual_seed(0)

def integrate_samples(
    samples: torch.Tensor,
    return_full: bool = False,
    *,
    T: Optional[float] = None,
    steps: int = 50,
) -> torch.Tensor:
    """Integrate ``samples`` forward under ``dynamics_function`` with the static input.

    Args:
        samples: Initial states; moved to ``device`` before integration.
        return_full: If True, return the full trajectory produced by
            ``run_odeint_to_final`` instead of only the final states.
        T: Integration horizon. Defaults to the notebook-level ``duration``,
            looked up at call time. Pass ``T`` explicitly to call this helper
            before ``duration`` has been defined.
        steps: Number of output steps requested from ``run_odeint_to_final``.

    Returns:
        The integration result, moved to the CPU.
    """
    # Only touch the global when no horizon was given, so an explicit T works
    # even if ``duration`` does not exist yet.
    horizon = duration if T is None else T
    result = run_odeint_to_final(
        # Kept as a lambda: the integrator is given a two-argument (state, u)
        # callable, matching the other call sites in this notebook.
        lambda x, u: dynamics_function(x, u),
        samples.to(device),
        T=horizon,
        inputs=static_external_input,
        steps=steps,
        return_last_only=not return_full,
        no_grad=True,
    )
    return result.cpu()


def min_distance_over_time(
    states: torch.Tensor,
    attractors: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
    """Return the distance from each state to the nearer of two attractors.

    Args:
        states: Tensor whose last axis is the state vector. All leading axes
            (e.g. time and batch of a full trajectory) are preserved.
        attractors: Optional ``(A, B)`` pair. Defaults to the notebook-level
            ``attractor_a`` / ``attractor_b``.

    Returns:
        Tensor with the leading shape of ``states`` holding
        ``min(||x - A||, ||x - B||)``.
    """
    if attractors is None:
        attractors = (attractor_a, attractor_b)
    first, second = attractors
    dist_a = torch.linalg.norm(states - first, dim=-1)
    dist_b = torch.linalg.norm(states - second, dim=-1)
    return torch.minimum(dist_a, dist_b)


def plot_pca_trajectories(trajectories: torch.Tensor, title: str, max_trajs: int = 10) -> None:
    """Plot up to ``max_trajs`` trajectories projected onto their first two PCs.

    Args:
        trajectories: Tensor of shape ``(time_steps, batch_size, dim)``.
        title: Figure title.
        max_trajs: Maximum number of trajectories (leading batch entries) to plot.

    Start points are marked with green stars and end points with red crosses.
    """
    time_steps, batch_size, _ = trajectories.shape
    limit = min(batch_size, max_trajs)
    # detach()/cpu() so tensors that require grad or live off-CPU can still be
    # converted to NumPy for scikit-learn (plain .numpy() raises for those).
    traj_subset = trajectories[:, :limit].detach().cpu().reshape(time_steps * limit, -1).numpy()
    # One PCA fitted on all plotted points, so every trajectory shares a projection.
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(traj_subset)
    reduced = reduced.reshape(time_steps, limit, 2)

    plt.figure(figsize=(6, 4))
    for i in range(limit):
        plt.plot(reduced[:, i, 0], reduced[:, i, 1], marker="o", markersize=3, alpha=0.7)
        plt.scatter(reduced[0, i, 0], reduced[0, i, 1], c="g", marker="*", s=80)
        plt.scatter(reduced[-1, i, 0], reduced[-1, i, 1], c="r", marker="X", s=80)
    plt.title(title)
    plt.xlabel("PC 1")
    plt.ylabel("PC 2")
    plt.tight_layout()
    plt.show()


def classify_convergence(
    final_states: torch.Tensor,
    attractors: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    threshold: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decide which final states ended up at one of the two attractors.

    Args:
        final_states: States whose last axis is the state vector, typically
            ``(batch, dim)``.
        attractors: Optional ``(A, B)`` pair. Defaults to the notebook-level
            ``attractor_a`` / ``attractor_b``.
        threshold: Distance below which a state counts as converged. Defaults
            to the notebook-level ``convergence_threshold``.

    Returns:
        ``(converged_mask, min_dist)``: a boolean mask and the distance of each
        state to the nearer attractor.
    """
    if attractors is None:
        attractors = (attractor_a, attractor_b)
    if threshold is None:
        threshold = convergence_threshold
    first, second = attractors
    # dim=-1 (not 1) so any leading batch shape works, matching min_distance_over_time.
    dist_a = torch.linalg.norm(final_states - first, dim=-1)
    dist_b = torch.linalg.norm(final_states - second, dim=-1)
    min_dist = torch.minimum(dist_a, dist_b)
    converged_mask = min_dist < threshold
    return converged_mask, min_dist


def summarize_convergence(mask: torch.Tensor) -> Dict[str, int]:
    """Count converged vs. non-converged ("spurious") entries of a boolean mask.

    Args:
        mask: 1-D boolean tensor, True where a trajectory converged.

    Returns:
        Dict with integer counts under ``converged``, ``spurious`` and ``total``.
    """
    n_total = int(mask.shape[0])
    n_converged = int(mask.sum().item())
    return dict(converged=n_converged, spurious=n_total - n_converged, total=n_total)



# Step 1: isotropic Gaussians of increasing scale
This is a naive choice. Let's see what happens:

In [None]:
num_samples = 50
# Integration horizon; integrate_samples reads this global when called.
duration = 30
# Each scale multiplies the standard deviation, in units of |A - B| (see below).
isotropic_scales = torch.tensor([2.5e-2, 5e-2, 1e-1, 5e-1, 1.0], dtype=torch.float32)
base_sigma = attractor_distance
isotropic_results: List[Dict[str, object]] = []

for scale in isotropic_scales:
    # Isotropic Gaussian centred on the separatrix point: covariance (scale * |A - B|)^2 * I.
    covariance = (scale * base_sigma) ** 2 * identity
    proposal = torch.distributions.MultivariateNormal(point_on_separatrix, covariance_matrix=covariance)
    samples = proposal.sample((num_samples,))
    # Keep full trajectories so convergence can also be inspected over time.
    trajectories = integrate_samples(samples, return_full=True)
    final_states = trajectories[-1]
    min_distances_time = min_distance_over_time(trajectories)
    converged_mask, min_distances_final = classify_convergence(final_states)
    summary = summarize_convergence(converged_mask)
    # Store per-scale diagnostics for the plotting cells below.
    summary.update({
        "scale": float(scale),
        "min_distance_time": min_distances_time,
        "min_distance_final": min_distances_final,
        "trajectories": trajectories,
    })
    isotropic_results.append(summary)

In [None]:
# One row per isotropic scale: left = distance to nearest attractor over time,
# right = 2-D PCA projection of (up to) the first 10 trajectories.
rows = len(isotropic_results)
time_axis = torch.linspace(0.0, duration, isotropic_results[0]["min_distance_time"].shape[0]).numpy()
fig, axes = plt.subplots(rows, 2, figsize=(6, 1.8 * rows))
if rows == 1:
    # plt.subplots returns a 1-D axes array for a single row; keep 2-D indexing working.
    axes = np.expand_dims(axes, axis=0)

for row, result in enumerate(isotropic_results):
    min_distances_time = result["min_distance_time"].numpy()
    trajectories = result["trajectories"]

    # Calculate convergence using only the final time point for each trajectory
    final_min_distances = min_distances_time[-1, :]  # last time point for each trajectory
    converged_mask = final_min_distances < convergence_threshold

    # Assign colours: blue for converged, orange for not converged
    line_colours = ["C0" if conv else "C1" for conv in converged_mask]

    # Plot each trajectory's min_distances_time as a separate line, colored by convergence
    for idx in range(min_distances_time.shape[1]):
        axes[row, 0].plot(time_axis, min_distances_time[:, idx], alpha=0.6, color=line_colours[idx])
    axes[row, 0].axhline(convergence_threshold, color="r", linestyle="--")
    axes[row, 0].annotate(f"σ={result['scale']:.2g}", xy=(0.05, 0.93), xycoords='axes fraction',
                          fontsize=10, ha='left', va='top', bbox=dict(boxstyle="round", fc="w", alpha=0.7))
    # No per-row ylabel: the column title set after the loop describes the y-axis.
    if row == rows - 1:
        axes[row, 0].set_xlabel("time")

    time_steps, batch_size, _ = trajectories.shape
    limit = min(batch_size, 10)
    traj_subset = trajectories[:, :limit].reshape(time_steps * limit, -1).numpy()
    # PCA is refitted for every row, so projections are not comparable across rows.
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(traj_subset).reshape(time_steps, limit, 2)
    # Colour PCA lines by same converged_mask
    for idx in range(limit):
        color = "C0" if converged_mask[idx] else "C1"
        axes[row, 1].plot(reduced[:, idx, 0], reduced[:, idx, 1], marker="o", markersize=3, alpha=0.6, color=color)
        axes[row, 1].scatter(reduced[0, idx, 0], reduced[0, idx, 1], c="g", marker="*", s=70)
        axes[row, 1].scatter(reduced[-1, idx, 0], reduced[-1, idx, 1], c="r", marker="X", s=70)
    axes[row, 1].annotate(f"σ={result['scale']:.2g}", xy=(0.05, 0.93), xycoords='axes fraction',
                          fontsize=10, ha='left', va='top', bbox=dict(boxstyle="round", fc="w", alpha=0.7))



axes[0, 0].set_title("min{dist to A, dist to B}", fontsize=12)
fig.tight_layout()
plt.show()


In [None]:
# Stacked bar chart per isotropic scale: converged counts at the bottom,
# trajectories that ended elsewhere ("spurious") stacked on top.
labels = [f"{result['scale']:.2g}" for result in isotropic_results]
converged_counts = [result["converged"] for result in isotropic_results]
spurious_counts = [result["spurious"] for result in isotropic_results]

x = np.arange(len(labels))
plt.figure(figsize=(8, 4))
plt.bar(x, converged_counts, label="converged to A or B")
plt.bar(x, spurious_counts, bottom=converged_counts, label="went elsewhere")
plt.xticks(x, labels)
plt.xlabel("Isotropic scale (relative σ)")
plt.ylabel("Number of trajectories")
plt.title("identifying the domain of interest")
plt.legend()
plt.tight_layout()
plt.show()

print("Isotropic experiment summary:")
for result in isotropic_results:
    print(
        f"scale={result['scale']:.2g}: converged={result['converged']}, spurious={result['spurious']} (total={result['total']})"
    )


Some trajectories converge to other, 'spurious' attractors beyond the two of interest.

# Step 2: Anisotropic Gaussian, elongated along the attractor axis and narrower along the other axes


Let's pick the distribution more carefully to include the two attractors but make it narrow along the other dimensions. We express this in the covariance:

 $$\Sigma = \sigma_{\perp}^2 I + (\sigma_{\parallel}^2 - \sigma_{\perp}^2)(\mathbf{u}\mathbf{u}^\top),$$
 
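 where $\mathbf{u} = (B - A)/\lVert B - A \rVert$ is the unit vector along the line joining the two attractors (its sign is irrelevant, since only $\mathbf{u}\mathbf{u}^\top$ appears), and $\sigma_{\parallel}$, $\sigma_{\perp}$ are the widths of the distribution along $\mathbf{u}$ and perpendicular to it, respectively. Below, $\sigma_{\parallel}$ is 3 times the RMS per-coordinate separation of the attractors and $\sigma_{\perp} = 3$. Each scale factor multiplies the whole covariance $\Sigma$, so the widths scale by its square root.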

In [None]:
# Multipliers applied to the whole covariance in the next cell (variance scale,
# so standard deviations scale by sqrt(scale)).
scale_factors = torch.tensor([5e-2, 1e-1, 5e-1, 1.0], dtype=torch.float32)

anisotropic_sigma_a_multiplier = 3.0
anisotropic_sigma_b = 3.0
# RMS per-coordinate separation of the attractors (computed once; it was
# previously duplicated).
normalised_dist = ((attractor_a - attractor_b)**2).mean().sqrt()

sigma_parallel = anisotropic_sigma_a_multiplier * normalised_dist
sigma_perp = anisotropic_sigma_b
# Sigma = sigma_perp^2 I + (sigma_parallel^2 - sigma_perp^2) u u^T:
# width sigma_parallel along u, sigma_perp in every orthogonal direction.
base_covariance = (sigma_perp ** 2) * identity
base_covariance += ((sigma_parallel ** 2) - (sigma_perp ** 2)) * torch.outer(unit_vector, unit_vector)

In [None]:
anisotropic_results: List[Dict[str, object]] = []
scaled_gaussians: List[MultivariateGaussian] = []

for scale in scale_factors:
    # NOTE: scale multiplies the covariance, so standard deviations grow as
    # sqrt(scale) (unlike the isotropic cell, where scale multiplies sigma).
    covariance = base_covariance * scale
    proposal = torch.distributions.MultivariateNormal(point_on_separatrix, covariance_matrix=covariance)
    samples = proposal.sample((num_samples,))
    trajectories = integrate_samples(samples, return_full=True)
    final_states = trajectories[-1]
    min_distances_time = min_distance_over_time(trajectories)
    converged_mask, min_distances_final = classify_convergence(final_states)
    summary = summarize_convergence(converged_mask)
    summary.update({
        "scale": float(scale),
        "min_distance_time": min_distances_time,
        "min_distance_final": min_distances_final,
        "trajectories": trajectories,
    })
    anisotropic_results.append(summary)

    # Fit a Gaussian to the *final* (integrated) states of converged trajectories.
    # NOTE(review): these states sit near A/B, so the fit describes the attractors'
    # spread rather than the initial-condition domain. If a distribution over initial
    # conditions was intended, samples[converged_mask] may be what was meant — confirm.
    stable_samples = final_states[converged_mask]
    if stable_samples.shape[0] < 2:
        # Too few converged states for a sample covariance; fall back to all final states.
        stable_samples = final_states
    mean = stable_samples.mean(dim=0)
    centered = stable_samples - mean
    # Unbiased sample covariance plus a 1e-6 ridge for regularisation (the estimate
    # is rank-deficient whenever there are fewer samples than dimensions).
    covariance_est = (centered.T @ centered) / max(stable_samples.shape[0] - 1, 1)
    covariance_est += 1e-6 * identity
    scaled_gaussians.append(MultivariateGaussian(dim=dim, mean=mean, covariance_matrix=covariance_est))

# Bundle one fitted Gaussian per scale factor.
ic_distribution_fit = MultivariateGaussianList(
    scaled_gaussians,
    name="FF_FIM_forward_empirical",
    scales=scale_factors.tolist(),
)



In [None]:
# One row per anisotropic scale factor: left = distance to nearest attractor over
# time, right = 2-D PCA projection of (up to) the first 10 trajectories.
rows = len(anisotropic_results)
time_axis = torch.linspace(0.0, duration, anisotropic_results[0]["min_distance_time"].shape[0]).numpy()
fig, axes = plt.subplots(rows, 2, figsize=(6, 1.8 * rows))
if rows == 1:
    # plt.subplots returns a 1-D axes array for a single row; keep 2-D indexing working.
    axes = np.expand_dims(axes, axis=0)

for row, result in enumerate(anisotropic_results):
    min_distances_time = result["min_distance_time"].numpy()
    trajectories = result["trajectories"]

    # Calculate convergence using only the final time point for each trajectory
    final_min_distances = min_distances_time[-1, :]  # last time point for each trajectory
    converged_mask = final_min_distances < convergence_threshold

    # Assign colours: blue for converged, orange for not converged
    line_colours = ["C0" if conv else "C1" for conv in converged_mask]

    # Plot each trajectory's min_distances_time as a separate line, colored by convergence
    for idx in range(min_distances_time.shape[1]):
        axes[row, 0].plot(time_axis, min_distances_time[:, idx], alpha=0.6, color=line_colours[idx])
    axes[row, 0].axhline(convergence_threshold, color="r", linestyle="--")
    # Labelled "scale" rather than "σ": here the factor multiplies the covariance,
    # so the actual widths are sqrt(scale) times the base widths.
    axes[row, 0].annotate(f"scale={result['scale']:.2g}", xy=(0.05, 0.93), xycoords='axes fraction',
                          fontsize=10, ha='left', va='top', bbox=dict(boxstyle="round", fc="w", alpha=0.7))
    if row == rows - 1:
        axes[row, 0].set_xlabel("time")

    time_steps, batch_size, _ = trajectories.shape
    limit = min(batch_size, 10)
    traj_subset = trajectories[:, :limit].reshape(time_steps * limit, -1).numpy()
    # PCA is refitted for every row, so projections are not comparable across rows.
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(traj_subset).reshape(time_steps, limit, 2)
    # Colour PCA lines by same converged_mask
    for idx in range(limit):
        color = "C0" if converged_mask[idx] else "C1"
        axes[row, 1].plot(reduced[:, idx, 0], reduced[:, idx, 1], marker="o", markersize=3, alpha=0.6, color=color)
        axes[row, 1].scatter(reduced[0, idx, 0], reduced[0, idx, 1], c="g", marker="*", s=70)
        axes[row, 1].scatter(reduced[-1, idx, 0], reduced[-1, idx, 1], c="r", marker="X", s=70)
    axes[row, 1].annotate(f"scale={result['scale']:.2g}", xy=(0.05, 0.93), xycoords='axes fraction',
                          fontsize=10, ha='left', va='top', bbox=dict(boxstyle="round", fc="w", alpha=0.7))

axes[0, 0].set_title("min{dist to A, dist to B}", fontsize=12)
fig.tight_layout()
plt.show()


In [None]:
# Stacked bar chart per anisotropic scale factor: converged counts at the bottom,
# trajectories that ended elsewhere ("spurious") stacked on top.
labels = [f"{result['scale']:.2g}" for result in anisotropic_results]
converged_counts = [result["converged"] for result in anisotropic_results]
spurious_counts = [result["spurious"] for result in anisotropic_results]

x = np.arange(len(labels))
plt.figure(figsize=(8, 4))
plt.bar(x, converged_counts, label="converged to A or B")
plt.bar(x, spurious_counts, bottom=converged_counts, label="went elsewhere")
plt.xticks(x, labels)
plt.xlabel("Anisotropic scale factor")
plt.ylabel("Number of trajectories")
plt.title("identifying the domain of interest (anisotropic design)")
plt.legend()
plt.tight_layout()
plt.show()

# Show the fitted distribution list, then the per-scale convergence counts.
print(ic_distribution_fit)
print("Anisotropic experiment summary:")
for result in anisotropic_results:
    print(
        f"scale={result['scale']:.2g}: converged={result['converged']}, spurious={result['spurious']} (total={result['total']})"
    )


We managed to find distributions that are _mostly_ in the two-basin domain of interest.