# Tutorial 04: Training an RNN on the Flip-Flop Task

In this tutorial we will walk through a two-part workflow using the utilities provided in `separatrix_locator.dynamics`:

1. Train a recurrent neural network on the classic flip-flop task and inspect its behaviour.
2. Treat the trained RNN as a dynamical system, extract its autonomous dynamics, and run the separatrix locator on that learned vector field.


In [None]:
# Install the package when running in Colab
# !pip install git+https://github.com/KabirDabholkar/separatrix_locator.git


## Part 1 · Train the Flip-Flop RNN

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA

from separatrix_locator.distributions import MultivariateGaussian, multiscaler

from separatrix_locator.dynamics import (
    FlipFlopDataset,
    FlipFlopSweepDataset,
    TrainingConfig,
    train_flipflop_rnn,
)

In [None]:
# Task and training settings shared by the cells below.
n_bits, n_time = 1, 50
n_trials, repeats = 32, 5
random_seed = 2

# Output folder handed to TrainingConfig; created up front so it always exists.
save_dir = Path("../rnn_params/tutorial_outputs/flipflop_rnn")
save_dir.mkdir(parents=True, exist_ok=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
training_config = TrainingConfig(
    epochs=1000,
    log_interval=50,
    device=_device,
    save_dir=save_dir,
    save_checkpoint=True,
    save_loss_plot=True,
)

# The two dataset callables differ only in class and number of trials.
_shared_dataset_kwargs = dict(
    repeats=repeats,
    n_time=n_time,
    n_bits=n_bits,
    p=0.2,
    random_seed=random_seed,
)
train_dataset = FlipFlopDataset(n_trials=n_trials, **_shared_dataset_kwargs)
analysis_dataset = FlipFlopSweepDataset(n_trials=16, **_shared_dataset_kwargs)

In [None]:
%%time

# Network sizing: input and output widths both follow n_bits, with 64 hidden units.
network_shape = dict(input_size=n_bits, output_size=n_bits, hidden_size=64)

result = train_flipflop_rnn(
    dataset=train_dataset,
    training_config=training_config,
    **network_shape,
)

print(f"Training finished with {len(result.loss_history)} iterations.")


In [None]:
# Draw the recorded loss history, one point per entry.
_, loss_ax = plt.subplots()
loss_ax.plot(result.loss_history)
loss_ax.set_xlabel("Iteration")
loss_ax.set_ylabel("Loss")
loss_ax.set_title("Training Loss Curve")
plt.show()


In [None]:
# Evaluate the trained network on the training dataset and overlay a few
# input/output traces.
inputs_train, outputs_train = result.evaluate(train_dataset)

# Number of traces (index along axis 1) to draw. A single trace keeps the
# figure readable; raise this to overlay more.
n_traces_to_plot = 1

fig, ax = plt.subplots(figsize=(10, 4))

# Inputs are drawn with solid lines
for i in range(n_traces_to_plot):
    ax.plot(inputs_train[:, i, 0], label=f"Train input bit {i+1}", linestyle='-', alpha=0.7)

# Network outputs are drawn with dashed lines so they can be told apart from
# the inputs (same convention as the analysis-sweep plot)
for i in range(n_traces_to_plot):
    ax.plot(outputs_train[:, i, 0], label=f"Train output bit {i+1}", linestyle='--', alpha=0.9)

ax.set_xlabel("Time step")
ax.set_ylabel("Value")
ax.set_title("Flip-Flop Task: Inputs and Network Outputs")
# ax.legend(loc="best")
plt.tight_layout()
plt.show()


In [None]:
import matplotlib

# Evaluate the network on the analysis sweep and overlay inputs with outputs.
inputs, outputs = result.evaluate(analysis_dataset)

fig, ax = plt.subplots(figsize=(10, 4))

# One trace is drawn per index along axis 1. The count lives in its own
# variable so the `n_bits` setting (used to size the datasets and the network)
# is not overwritten when this cell runs.
# NOTE(review): axis 1 looks like the trial axis rather than the bit axis — confirm.
n_series = inputs.shape[1]
cmap = matplotlib.colormaps["Blues"]

for i in range(n_series):
    # Later traces get a darker shade; input is solid, output is dashed.
    color = cmap((i + 1) / n_series)
    ax.plot(inputs[:, i, 0], label=f"Input bit {i+1}", linestyle='-', alpha=0.7, color=color)
    ax.plot(outputs[:, i, 0], label=f"Output bit {i+1}", linestyle='--', alpha=0.9, color=color)

ax.set_xlabel("Time step")
ax.set_ylabel("Value")
# ax.legend()
ax.set_title("Flip-Flop Task: Inputs and Network Outputs")
plt.tight_layout()
plt.show()


In [None]:
import matplotlib

# Project the hidden states from the analysis sweep onto two principal components.
hidden, hidden_pca, explained = result.compute_hidden_pca(analysis_dataset, n_components=2)

print(f"Explained variance ratio: PC1={explained[0]:.3f}, PC2={explained[1]:.3f}")

plt.figure(figsize=(8, 6))
# One trajectory per index along axis 1. A dedicated name is used so the
# `n_trials` training setting is not overwritten by this plotting cell.
n_trajectories = hidden_pca.shape[1]
cmap = matplotlib.colormaps["Blues"]
# With fewer than 8 trajectories the colour fraction stays below 1, so only
# the lighter part of the colormap is used.
colour_span = max(8, n_trajectories)
for trial_idx in range(n_trajectories):
    color = cmap((trial_idx + 1) / colour_span)
    # Plot trajectory
    plt.plot(
        hidden_pca[:, trial_idx, 0],
        hidden_pca[:, trial_idx, 1],
        alpha=0.7,
        label=f"Trial {trial_idx + 1}",
        color=color,
    )
    # Mark the final point of the trajectory
    plt.scatter(
        hidden_pca[-1, trial_idx, 0],
        hidden_pca[-1, trial_idx, 1],
        color=color,
        edgecolor='black',
        s=60,
        zorder=5,
    )

plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("Hidden State Trajectories (PCA)")
# plt.legend(loc="best")
plt.grid(alpha=0.3)
plt.show()


## Part 2 · Extract RNN Dynamics and Locate the Separatrix


In this section we:

- extract the autonomous hidden-state dynamics from the trained GRU
- build a distribution over visited hidden states
- train a small Koopman-eigenfunction ensemble with `SeparatrixLocator`
- run gradient descent on the learned eigenfunctions to approximate separatrix points


In [None]:
from separatrix_locator.dynamics.rnn import (
    discrete_to_continuous,
    get_autonomous_dynamics_from_model,
    hidden_distribution_with_spectral_norm,
)
from separatrix_locator.core import SeparatrixLocator
from separatrix_locator.core.models import ResNet
from separatrix_locator.utils.odeint_utils import run_odeint_to_final

In [None]:
# Run the analysis sweep through the trained network, keep every hidden state,
# and build the hidden-state sampling distribution from them.
result.model.eval()
with torch.no_grad():
    sweep_inputs, _ = analysis_dataset()
    sweep_tensor = torch.from_numpy(sweep_inputs).to(result.device, dtype=torch.float32)
    _, hidden_eval = result.model(sweep_tensor, return_hidden=True)

hidden_eval_cpu = hidden_eval.detach().cpu()
hidden_dim = hidden_eval_cpu.shape[-1]

# Collapse all leading axes so that each row is a single hidden-state vector.
hidden_eval_flat = hidden_eval_cpu.reshape(-1, hidden_dim)
hidden_distribution = hidden_distribution_with_spectral_norm(hidden_eval_flat)

print(f"Hidden-state dimension: {hidden_dim}")


In [None]:
# Turn the trained model into a vector field: first extract its autonomous
# update, then convert that update with a step size of delta_t = 1.
result.model.to(result.device)
device_name = str(result.device)
autonomous_dynamics = get_autonomous_dynamics_from_model(result.model, device=device_name)
continuous_vector_field = discrete_to_continuous(autonomous_dynamics, delta_t=1.0)

@torch.no_grad()
def rnn_vector_field(x: torch.Tensor) -> torch.Tensor:
    """Evaluate ``continuous_vector_field`` at ``x`` with gradient tracking disabled."""
    return continuous_vector_field(x)

# Smoke test: evaluate the field on a small batch of sampled hidden states.
n_probe = 8
sample_states = hidden_distribution.sample((n_probe,)).to(result.device)
vector_field_sample = rnn_vector_field(sample_states)
print(f"Vector field output shape: {vector_field_sample.shape}")

In [None]:
# Sample initial conditions from the hidden-state distribution
num_samples = 256
init_states = hidden_distribution.sample((num_samples,)).to(result.device)

# Integrate every initial condition forward under the learned RNN vector field
T = 500.0
steps = 100
with torch.no_grad():
    traj = run_odeint_to_final(
        func=continuous_vector_field,
        y0=init_states,
        T=T,
        steps=steps,
        return_last_only=False,
        no_grad=True,
    )  # indexed below as (time, sample, hidden_dim)

# Copy the trajectories to the host once and project all points in a single
# PCA call; the result is reshaped back so each trajectory can be sliced out
# without re-transforming it.
traj_np = traj.cpu().numpy()
n_saved_steps, n_trajectories, state_dim = traj_np.shape

pca = PCA(n_components=2)
traj_pca = pca.fit_transform(traj_np.reshape(-1, state_dim)).reshape(
    n_saved_steps, n_trajectories, 2
)

# Plot trajectories in PCA space (faint lines for each trajectory)
fig, ax = plt.subplots(figsize=(8, 6))
for i in range(n_trajectories):
    ax.plot(traj_pca[:, i, 0], traj_pca[:, i, 1], alpha=0.4)

# The last saved time point of every trajectory marks where it ended up
endpoints_pca = traj_pca[-1]
ax.scatter(endpoints_pca[:, 0], endpoints_pca[:, 1], color='red', label='End points', s=24, marker='o', zorder=3)
ax.set_title("RNN hidden state trajectories in PCA space")
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
ax.legend()
plt.show()


### Extracting attractors and a separatrix point

In [None]:
from sklearn.cluster import KMeans

# Final state of every integrated trajectory, as a NumPy array on the host.
traj_endpoints = traj[-1].cpu().numpy()

# Two-cluster k-means over the endpoints; the cluster centres are used as the
# two attractor estimates in hidden-state space.
kmeans = KMeans(n_clusters=2, random_state=42).fit(traj_endpoints)
attractors = kmeans.cluster_centers_
print(attractors.shape)

### Finding a point on the separatrix by interpolating between attractors

In [None]:
from separatrix_locator.core.separatrix_point import find_separatrix_point_along_line
# The attractor estimates come from NumPy; hand them over as a float32 tensor.
attractors_f32 = torch.from_numpy(attractors).type(torch.float32)

point_on_separatrix = find_separatrix_point_along_line(
    dynamics_function=rnn_vector_field,
    external_input=None,
    attractors=attractors_f32,
    num_points=20,
    num_iterations=2,
    final_time=500.0,
)

### Training distributions: centered at separatrix point, with multiple scales

In [None]:
# Gaussian centred on the separatrix estimate with isotropic covariance 2 * I.
base_covariance = torch.eye(hidden_dim) * 2.0
distribution = MultivariateGaussian(
    dim=hidden_dim,
    mean=point_on_separatrix,
    covariance_matrix=base_covariance,
)

# Wrap the base distribution with three scale factors.
scale_factors = [0.1, 1.0, 3.0]
multiscaled_distribution = multiscaler(distribution, scale_factors)

### Training the Koopman eigenfunction

In [None]:
# Build the eigenfunction network(s) and place them on the training device.
num_models = 1
locator_models = [
    ResNet(
        input_dim=hidden_dim,
        hidden_size=500,
        output_dim=1,
        num_layers=8,
        input_scale_factor=1.0,
    )
    for _ in range(num_models)
]
for model in locator_models:
    model.to(result.device)

# The locator takes the device as a string.
device_name = str(result.device)
locator = SeparatrixLocator(
    num_models=num_models,
    dynamics_dim=hidden_dim,
    models=locator_models,
    lr=1e-3,
    epochs=500,
    use_multiprocessing=False,
    verbose=True,
    device=device_name,
)
locator.to(device_name)

# Fit the Koopman eigenfunction model(s) on the learned RNN dynamics, sampling
# training points from the multi-scale distribution.
fit_options = dict(
    dist_requires_dim=False,
    batch_size=256,
    eigenvalue=1.0,
    print_every_num_epochs=50,
)
results = locator.fit(
    func=rnn_vector_field,
    distribution=multiscaled_distribution,
    **fit_options,
)

### Validating Separatrix Locator

In [None]:
from separatrix_locator.plotting.hermite import find_separatrix_along_curve_using_ODE
from separatrix_locator.utils.interpolation import generate_curves_between_points

# Number of samples along each curve, and the matching curve parameter values.
num_points = 100

alphas = np.linspace(0, 1, num_points)

# Random curves running from the first attractor estimate to the second.
curve_points = generate_curves_between_points(
    x=attractors[0],
    y=attractors[1],
    num_curves=100,
    num_points=num_points,
    rand_scale=3.0,
)

# Locate the change point along each curve by integrating the dynamics.
# The attractors are cast to float32 explicitly, matching the other call sites
# that pass them to the locator utilities.
change_points_alpha, labels_bt = find_separatrix_along_curve_using_ODE(
    dynamics_function=rnn_vector_field,
    attractors=torch.from_numpy(attractors).type(torch.float32),
    alphas=alphas,
    curve_points=curve_points,
    integration_time=500.0,
    attractor_epsilon=0.02,
    kmeans_random_state=42,
    clustering_method="attractor_eps",
)


In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

num_curves, num_points, dim = curve_points.shape

# Flatten the curves to (num_curves * num_points, dim), project to 2D with PCA,
# then restore the (curve, point) layout.
curve_points_flat = curve_points.reshape(-1, dim)
pca = PCA(n_components=2)
curve_points_pca = pca.fit_transform(curve_points_flat)
curve_points_pca_bt = curve_points_pca.reshape(num_curves, num_points, 2)

# Loop invariants, computed once instead of on every segment / curve:
# the label normaliser (guarded against division by zero) and the alpha grid.
label_max = labels_bt.max()
label_scale = label_max if label_max > 0 else 1
curve_alphas = np.linspace(0, 1, num_points)

plt.figure(figsize=(7, 7))
for i in range(num_curves):
    # Colour each segment by the basin label at its starting point
    for j in range(num_points - 1):
        color = plt.cm.coolwarm(labels_bt[i, j] / label_scale)
        plt.plot(
            curve_points_pca_bt[i, j:j+2, 0],
            curve_points_pca_bt[i, j:j+2, 1],
            color=color,
            alpha=0.8,
            linewidth=1,
        )
    # Mark the change point with an 'x': the sample on the uniform alpha grid
    # that lies closest to this curve's change-point alpha
    change_alpha = change_points_alpha[i]
    change_idx = np.argmin(np.abs(curve_alphas - change_alpha))
    plt.plot(
        curve_points_pca_bt[i, change_idx, 0],
        curve_points_pca_bt[i, change_idx, 1],
        marker='x',
        markersize=8,
        markeredgewidth=2,
        color='black',
        linestyle='None',
        label="Change point" if i == 0 else None,  # Label only once for legend
    )

plt.title("Hermite curves (PCA) colored by basin labels\nChange points marked with x's")
plt.xlabel("PC1")
plt.ylabel("PC2")

plt.show()


In [None]:
# Evaluate the trained eigenfunction model(s) at every point of every curve.
KEFvals = locator.predict(torch.from_numpy(curve_points).type(torch.float32))
# Pass matplotlib a detached host array so plotting also works when the
# predictions live on the GPU or track gradients. One line per curve.
_ = plt.plot(KEFvals[:, :, 0].detach().cpu().numpy().T)

In [None]:
# For each curve, take the sample where |KEF| (component 0) is smallest and
# look up the corresponding curve parameter.
min_idx = torch.argmin(torch.abs(KEFvals)[..., 0], dim=1)
# Convert the indices to a host NumPy array before indexing the NumPy alpha
# grid, so the lookup also works when KEFvals is on the GPU.
zero_point_alphas = alphas[min_idx.cpu().numpy()]

In [None]:
# Change-point alphas from the ODE labelling (x) against the alphas where |KEF| is smallest (y).
plt.scatter(x=change_points_alpha, y=zero_point_alphas)

#### Run the same validation with a single method call

In [None]:
# Options grouped by role: how the curves are sampled, how points along them
# are labelled, and which diagnostic plots to draw.
curve_sampling = dict(
    num_curves=100,
    num_points=100,
    rand_scale=3.0,
    alpha_lims=(0.0, 1.0),
)
basin_labelling = dict(
    integration_time=500.0,
    attractor_epsilon=0.02,
    kmeans_random_state=42,
    clustering_method="attractor_eps",
)
diagnostic_plots = dict(plot_pca=True, plot_kef=True, plot_scatter=True)

r2score = locator.validate_with_curves(
    dynamics_function=rnn_vector_field,
    attractors=torch.from_numpy(attractors).type(torch.float32),
    kef_component=0,
    **curve_sampling,
    **basin_labelling,
    **diagnostic_plots,
)
print("Curves R2 score:",r2score)

In [None]:
# Both locator calls sample from the hidden-state distribution with the same settings.
sampling_kwargs = dict(distribution=hidden_distribution, dist_needs_dim=False)

# Prepare the trained models, then run gradient descent to collect candidates.
_ = locator.prepare_models_for_gradient_descent(**sampling_kwargs)

gd_trajectories, separatrix_points = locator.find_separatrix(
    batch_size=96,
    num_steps=200,
    threshold=5e-2,
    **sampling_kwargs,
)

if not separatrix_points:
    print("No separatrix candidates found — consider adjusting threshold or training settings.")
else:
    print(f"Collected {len(separatrix_points)} sets of candidate separatrix points.")
    # Report the candidate count for at most the first two sets.
    for idx, pts in enumerate(separatrix_points[:2]):
        print(f"Model {idx} candidate count: {pts.shape[0]}")


In [None]:
# Visualise separatrix candidates in the leading PCA plane of the hidden states.
# Each non-empty candidate set stays paired with its original model index, so
# (a) candidates from later models are still shown when model 0 found none, and
# (b) the legend names the model that actually produced each set.
candidate_sets = [
    (model_idx, pts)
    for model_idx, pts in enumerate(separatrix_points or [])
    if pts.numel() > 0
]

if candidate_sets:
    # Fit the projection on the visited hidden states, then reuse it for the candidates.
    pca = PCA(n_components=2)
    hidden_pca_flat = pca.fit_transform(hidden_eval_flat.numpy())

    plt.figure(figsize=(6, 6))
    plt.scatter(
        hidden_pca_flat[:, 0],
        hidden_pca_flat[:, 1],
        alpha=0.1,
        label="Hidden states",
    )
    for model_idx, pts in candidate_sets:
        # detach() keeps the NumPy conversion valid if the points track gradients.
        proj = pca.transform(pts.detach().cpu().numpy())
        plt.scatter(
            proj[:, 0],
            proj[:, 1],
            s=40,
            label=f"Separatrix candidates (model {model_idx})",
        )
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.legend()
    plt.title("Separatrix locator candidates in hidden-state PCA space")
    plt.show()
else:
    print("No separatrix candidates to visualise.")

## Next steps

- Experiment with different hidden sizes or cell types by passing `cell_type="LSTM"` to `train_flipflop_rnn`.
- Adjust the dataset parameters to make the task harder (e.g., more bits or longer sequences).
- Swap in alternative hidden-state distributions (e.g., PCA-reduced Gaussians) before running `SeparatrixLocator`.
- Tune the Koopman ensemble size, training epochs, or gradient-descent thresholds to sharpen the separatrix estimate.
- Use the trained checkpoint saved in `../rnn_params/tutorial_outputs/flipflop_rnn` (the `save_dir` configured in Part 1) to initialise other analyses in the separatrix locator pipeline.
