# Tutorial 3: RNN Flip-Flop Dynamics

This notebook extends the earlier tutorials by replacing analytic dynamics with a recurrent neural network: a GRU with a two-dimensional hidden state that was trained on a one-bit flip-flop task. We will load the pretrained GRU, convert its discrete hidden-state updates into a continuous vector field, and train Koopman eigenfunctions with the Separatrix Locator.

## Learning Objectives
- Load a pretrained GRU-based flip-flop dynamics model
- Convert discrete-time RNN updates into continuous vector fields
- Train and visualise Koopman eigenfunctions for RNN-driven dynamics
- Compare standard and squashed Koopman objectives on the same system



In [None]:
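# Uncomment these lines to install the dependencies when running in Colab.
# (%pip installs into the running kernel's environment, unlike !pip.)
# %pip install torchdiffeq
# %pip install --no-deps git+https://github.com/KabirDabholkar/separatrix_locator.git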

In [None]:
# Imports
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torchdiffeq import odeint

from separatrix_locator import SeparatrixLocator
from separatrix_locator.core.separatrix_point import find_separatrix_point_along_line
from separatrix_locator.distributions import MultivariateGaussian, multiscaler
from separatrix_locator.dynamics.rnn import (
    GRU_RNN,
    discrete_to_continuous,
    get_autonomous_dynamics_from_model,
)
from separatrix_locator.utils import get_estimate_attractor_func


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the RNGs this notebook relies on (torch-based Gaussian sampling, MLP
# initialisation and -- presumably -- any sampling inside SeparatrixLocator) so
# that Restart & Run All reproduces the same figures. torch.manual_seed also
# seeds CUDA; GPU kernels may still be non-deterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)



## 1. Load the pretrained flip-flop RNN

The GRU was trained on one-bit flip-flop input/output pairs. We now repurpose its hidden-state dynamics as an autonomous system by feeding zero inputs and treating the hidden state as the continuous system state.


In [None]:
dim = 2  # hidden-state size of the GRU == dimension of the dynamical system

# Hyperparameters of the pretrained flip-flop GRU.
# NOTE(review): "tau" is never read in this notebook; "speed_factor" rescales
# the vector field in dynamics_func below.
rnn_params = dict(
    ob_size=1,
    act_size=1,
    num_h=dim,
    tau=10.0,
    speed_factor=6.0,
)

checkpoint_path = Path("../rnn_params/1bitflipflop2D/RNNmodel.torch")

rnn_model = GRU_RNN(**{key: rnn_params[key] for key in ("num_h", "ob_size", "act_size")})

if not checkpoint_path.exists():
    print("Warning: checkpoint not found; using randomly initialised weights.")
else:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Unwrap the state dict if the file stores it under one of these keys
    # (checked in this order, each applied to the result of the previous one).
    for wrapper_key in ("model_state_dict", "state_dict"):
        if isinstance(checkpoint, dict) and wrapper_key in checkpoint:
            checkpoint = checkpoint[wrapper_key]
    rnn_model.load_state_dict(checkpoint)

rnn_model.to(device)
# Freeze all weights: the GRU only supplies the dynamics, it is never trained here.
rnn_model.requires_grad_(False)
# cuDNN requires training mode for backward through RNN even without parameter updates.
# NOTE(review): if GRU_RNN contains dropout, train() would also enable it -- confirm.
rnn_model.train()

discrete_dynamics = get_autonomous_dynamics_from_model(rnn_model, device=device)
# Presumably treats one RNN step as a unit time step (delta_t=1.0) of a
# continuous flow -- confirm against discrete_to_continuous.
continuous_dynamics = discrete_to_continuous(discrete_dynamics, delta_t=1.0)


def dynamics_func(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the rescaled continuous-time GRU vector field at ``x``.

    ``x`` is flattened to ``(-1, dim)`` (so its element count must be a
    multiple of ``dim``), moved to the notebook's ``device``, passed through
    ``continuous_dynamics`` and multiplied by ``rnn_params["speed_factor"]``.
    The result is returned with ``x``'s original shape and on ``x``'s device,
    so callers may keep their tensors on the CPU.
    """
    home_device = x.device
    points = x.reshape(-1, dim).to(device)
    velocity = rnn_params["speed_factor"] * continuous_dynamics(points)
    return velocity.reshape(x.shape).to(home_device)


base_distribution = MultivariateGaussian(dim=dim)

# Training distribution: the base Gaussian combined with several scale factors
# so that points both near the attractors and far from them are sampled.
# NOTE(review): exact semantics of `multiscaler` (mixture vs. per-batch scale) -- confirm.
distribution_scales = [0.01, 0.1, 0.5, 1.0, 2.0, 4.0]
distribution = multiscaler(base_distribution, distribution_scales)

# Only report a checkpoint when weights were actually loaded; the loading cell
# already warns when it falls back to random weights. Printing the path as-is
# also avoids `relative_to`, which raises for paths outside '..'.
if checkpoint_path.exists():
    print(f"Loaded GRU parameters from {checkpoint_path}")


### Estimate attracting equilibria and a separatrix seed

We reuse the helper that integrates random initial conditions and clusters their endpoints to obtain candidate attractors. A point near the separatrix can then be found by bisection along the straight line between the two estimated attractors.


In [None]:
estimate_attractors = get_estimate_attractor_func(dynamics_func)

# Integrate random initial conditions and cluster the endpoints into two
# candidate attractors (no gradients needed for this).
with torch.no_grad():
    attractor_a, attractor_b = estimate_attractors(dim, num_inits=256, T=40.0)

# stable_points / separatrix_guess are only consumed by the plotting cells, and
# matplotlib cannot convert CUDA tensors or tensors that require grad to NumPy.
# Keep detached CPU copies so every plot works regardless of `device`.
stable_points = torch.stack((attractor_a, attractor_b)).detach().cpu()
separatrix_guess = find_separatrix_point_along_line(
    dynamics_function=dynamics_func,
    external_input=None,
    attractors=(attractor_a, attractor_b),
    # Presumably: candidate points per refinement round, number of rounds, and
    # the integration horizon used to decide basin membership -- confirm.
    num_points=20,
    num_iterations=4,
    final_time=30.0,
).unsqueeze(0).detach().cpu()

print("Stable equilibria:")
print(stable_points)
print("Separatrix seed:")
print(separatrix_guess)



## 2. Explore the learned vector field

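We evaluate the autonomous dynamics directly on a grid to inspect the vector field, and plot samples from the base Gaussian distribution. The trajectories below also start from this base Gaussian. The training distribution used in Section 3 is a multi-scale version of the same Gaussian (see `distribution_scales`).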


In [None]:
# Evenly spaced 41 x 41 grid on [-2.5, 2.5]^2 (streamplot needs an even grid).
# X_vals, Y_vals, U and V are reused by the eigenfunction plots further down.
lin = torch.linspace(-2.5, 2.5, 41)
X_vals, Y_vals = lin.numpy(), lin.numpy()
X, Y = np.meshgrid(X_vals, Y_vals)

grid_points = torch.from_numpy(np.column_stack((X.ravel(), Y.ravel()))).float()
with torch.no_grad():
    field = dynamics_func(grid_points)
field_np = field.cpu().numpy()
U = field_np[:, 0].reshape(X.shape)
V = field_np[:, 1].reshape(Y.shape)

samples = base_distribution.sample(sample_shape=(2000,))

fig, (ax_field, ax_samples) = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: learned vector field with the estimated attractors and seed.
ax_field.streamplot(X, Y, U, V, color='tab:blue', linewidth=1.0, density=1.2)
ax_field.scatter(stable_points[:, 0], stable_points[:, 1], c='tab:green', s=40, label='Attractors')
ax_field.scatter(separatrix_guess[:, 0], separatrix_guess[:, 1], c='tab:red', s=60, label='Separatrix seed')
ax_field.set_title('Vector field')
ax_field.legend()

# Right panel: samples from the base Gaussian.
ax_samples.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), s=6, alpha=0.25)
ax_samples.set_title('Base Gaussian samples')

for panel in (ax_field, ax_samples):
    panel.set(xlabel='$h_1$', ylabel='$h_2$', aspect='equal')

plt.tight_layout()
plt.show()



### Sample trajectories

Integrating a few random initial conditions illustrates how the learned dynamics settle into different attractors depending on which side of the separatrix they start.


In [None]:
time_points = torch.linspace(0.0, 30.0, 301, device=device)
initial_conditions = base_distribution.sample(sample_shape=(8,)).to(device)


def _autonomous_rhs(t, state):
    # odeint expects f(t, x); the GRU dynamics do not depend on t.
    return dynamics_func(state)


# Each initial condition is integrated separately, then stacked to (n_traj, time, dim).
with torch.no_grad():
    trajectories = torch.stack(
        [odeint(_autonomous_rhs, x0, time_points).cpu() for x0 in initial_conditions]
    )

colors = plt.cm.tab10(np.linspace(0, 1, trajectories.shape[0]))

fig, ax = plt.subplots(figsize=(5, 4))
for traj, color in zip(trajectories, colors):
    ax.plot(traj[:, 0], traj[:, 1], color=color)
    # Mark the starting point of each trajectory.
    ax.scatter(traj[0, 0], traj[0, 1], color=color, s=20)

ax.scatter(stable_points[:, 0], stable_points[:, 1], c='black', marker='x', s=60, label='Attractors')
ax.scatter(separatrix_guess[:, 0], separatrix_guess[:, 1], c='yellow', s=80, edgecolor='k', label='Separatrix seed')
ax.set_xlabel('$h_1$')
ax.set_ylabel('$h_2$')
ax.set_title('Sample trajectories')
ax.legend()
ax.set_aspect('equal')
plt.tight_layout()
plt.show()



## 3. Train Koopman eigenfunctions

We now train neural Koopman eigenfunctions using the Separatrix Locator. First we optimise a standard linear eigenvalue objective, then repeat with the "squashed" cubic right-hand side to sharpen the classification boundary.


In [None]:
def make_model(hidden=256, in_dim=None, out_dim=1):
    """Build the MLP that parameterises a Koopman eigenfunction.

    Architecture: ``in_dim -> hidden -> hidden -> out_dim`` with ``tanh``
    activations after the first two linear layers and a linear output.

    Args:
        hidden: Width of both hidden layers.
        in_dim: Input (state-space) dimension. ``None`` falls back to the
            notebook-level ``dim``, so existing calls ``make_model()`` behave
            exactly as before.
        out_dim: Number of outputs; 1 for a single scalar eigenfunction.

    Returns:
        An ``nn.Sequential`` mapping ``(..., in_dim)`` to ``(..., out_dim)``.
    """
    if in_dim is None:
        in_dim = dim
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.Tanh(),
        nn.Linear(hidden, hidden),
        nn.Tanh(),
        nn.Linear(hidden, out_dim),
    )


standard_model = make_model().to(device)

# Optimisation settings for the standard (linear right-hand side) run.
standard_locator_kwargs = dict(
    dynamics_dim=dim,
    device=device,
    verbose=True,
    epochs=1500,
    lr=1e-3,
)
standard_locator = SeparatrixLocator(models=[standard_model], **standard_locator_kwargs)



In [None]:
# Standard objective: eigenvalue 1 with an identity right-hand side, so the
# learned phi is presumably pushed towards grad(phi) . f = phi -- confirm in
# SeparatrixLocator.fit. The RHS is passed as source text, as the API expects.
standard_fit_config = {
    "func": dynamics_func,
    "distribution": distribution,
    "batch_size": 2048,
    "balance_loss_lambda": 1e-2,
    "eigenvalue": 1.0,
    "RHS_function": "lambda phi: phi",
}
standard_locator.fit(**standard_fit_config)



### Visualise the standard eigenfunction


In [None]:
# Dense 161 x 161 evaluation grid ('ij' indexing: first axis runs along h_1).
# X_grid, Y_grid, coords, X_np and Y_np are reused by the squashed plot below.
grid_lin = torch.linspace(-2.5, 2.5, 161)
X_grid, Y_grid = torch.meshgrid(grid_lin, grid_lin, indexing="ij")
coords = torch.stack((X_grid, Y_grid), dim=-1).reshape(-1, dim).to(device)

with torch.no_grad():
    raw_standard = standard_model(coords)
values_standard = raw_standard.cpu().numpy().reshape(X_grid.shape)

X_np, Y_np = X_grid.cpu().numpy(), Y_grid.cpu().numpy()

fig, ax = plt.subplots(figsize=(5.5, 4.5))
filled = ax.contourf(X_np, Y_np, values_standard, levels=31, cmap='coolwarm')
fig.colorbar(filled, ax=ax, label='Koopman eigenfunction')

# Zero level set of the eigenfunction = the learned separatrix estimate.
zero_level_style = dict(levels=[0.0], colors=['yellow'], linewidths=2.5, linestyles='--')
ax.contour(X_np, Y_np, values_standard, **zero_level_style)

# Overlay the vector field computed on the coarser grid earlier.
ax.streamplot(X_vals, Y_vals, U, V, color='k', density=1.0, linewidth=0.7)

ax.scatter(stable_points[:, 0], stable_points[:, 1], c='tab:green', s=40, label='Attractors')
ax.scatter(separatrix_guess[:, 0], separatrix_guess[:, 1], c='tab:red', s=60, label='Separatrix seed')
ax.set(
    aspect='equal',
    xlabel='$h_1$',
    ylabel='$h_2$',
    title='Standard Koopman eigenfunction level sets',
)
ax.legend()
plt.tight_layout()
plt.show()



### Squashed objective

The squashed objective replaces the linear right-hand side with $\phi - \phi^3$, encouraging eigenfunction values to stay near $\pm 1$ away from the separatrix.


In [None]:
squashed_model = make_model().to(device)

# NOTE(review): this run uses fewer epochs (1000 vs 1500) and a larger balance
# weight (5e-2 vs 1e-2) than the standard run, so the comparison is not
# like-for-like -- confirm this is intentional.
squashed_locator = SeparatrixLocator(
    models=[squashed_model],
    **dict(dynamics_dim=dim, device=device, verbose=True, epochs=1000, lr=1e-3),
)

# Squashed objective: cubic right-hand side phi - phi**3, passed as source text.
squashed_fit_config = {
    "func": dynamics_func,
    "distribution": distribution,
    "batch_size": 2048,
    "balance_loss_lambda": 5e-2,
    "eigenvalue": 1.0,
    "RHS_function": "lambda phi: phi - phi**3",
}
squashed_locator.fit(**squashed_fit_config)



In [None]:
# Evaluate the squashed eigenfunction on the same dense grid as the standard one.
with torch.no_grad():
    raw_squashed = squashed_model(coords)
values_squashed = raw_squashed.cpu().numpy().reshape(X_grid.shape)

fig, ax = plt.subplots(figsize=(5.5, 4.5))
contourf = ax.contourf(X_np, Y_np, values_squashed, levels=31, cmap='coolwarm')
fig.colorbar(contourf, ax=ax, label='Squashed eigenfunction')

# Zero level set of the eigenfunction = the learned separatrix estimate.
ax.contour(
    X_np, Y_np, values_squashed,
    levels=[0.0], colors=['yellow'], linewidths=2.5, linestyles='--',
)
ax.streamplot(X_vals, Y_vals, U, V, color='k', density=1.0, linewidth=0.7)

for points, colour, size, name in (
    (stable_points, 'tab:green', 40, 'Attractors'),
    (separatrix_guess, 'tab:red', 60, 'Separatrix seed'),
):
    ax.scatter(points[:, 0], points[:, 1], c=colour, s=size, label=name)

ax.set_aspect('equal')
ax.set_xlabel('$h_1$')
ax.set_ylabel('$h_2$')
ax.set_title('Squashed Koopman eigenfunction level sets')
ax.legend()
plt.tight_layout()
plt.show()