# Spectral Separatrix: Coupled Ring Attractor Analysis

**Self-contained Colab notebook for reproducing all computations in the spectral separatrix paper.**

This notebook implements the spectral analysis of coupled ring attractor networks,
demonstrating how cross-inhibition strength $J_\times$ governs a pitchfork bifurcation
from coexistence to winner-take-all dynamics.

**Key results:**
- $J_\times^* \approx 0.3485$ (pitchfork bifurcation)
- $J_\times^{\mathrm{exist}} \approx 0.358$ (coexistence existence limit)
- $\Delta J \approx 0.01$ (razor-thin instability window)
- Two Goldstone modes protected by rotational symmetry
- Critical mode is uniform/DC (total activity competition, not spatial pattern)

**Paper:** *Spectral Separatrix of the Coupled Ring Attractor* (February 2026)

**Citation:** [placeholder]

---

In [None]:
# ============================================================================
# Cell 2: Imports and Model Constants
# ============================================================================

import numpy as np
from scipy.optimize import fsolve
from scipy.special import i0
from scipy.interpolate import RBFInterpolator
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patheffects as pe
from matplotlib.patches import FancyArrowPatch, Arc, Circle, FancyBboxPatch
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.gridspec import GridSpec
from matplotlib.colors import LinearSegmentedColormap, Normalize
from matplotlib.lines import Line2D
from matplotlib import ticker
import warnings
warnings.filterwarnings('ignore')

# ── Model constants ──────────────────────────────────────────────────────────
N = 48                # Neurons per ring (the coupled state vector has 2N entries)
J_0, J_1 = 1.0, 6.0  # Within-network connectivity: W_ij = (-J_0 + J_1 cos(phi_i - phi_j)) / N
KAPPA = 2.0           # Von Mises concentration of the drive used to seed bumps
INPUT_GAIN = 5.0      # Stimulus input gain (scales the von Mises seeding drive)
R_MAX = 1.0           # Sigmoid maximum rate
BETA = 5.0            # Sigmoid steepness
H0 = 0.5              # Sigmoid threshold
# NOTE: sigmoid()/sigmoid_derivative() below default to the literal values
# 1.0, 5.0, 0.5 rather than reading R_MAX/BETA/H0 -- keep them in sync.
DT = 0.1              # Integration time step (Euler updates scale by DT/TAU)
TAU = 10.0            # Membrane time constant (Jacobian eigenvalues are in units of 1/TAU)
GOLDSTONE_THRESH = 1e-3  # |Re(lambda)| below this is treated as a Goldstone (neutral) mode

# ── Colors ───────────────────────────────────────────────────────────────────
# COLORS is used by the analysis figures; the COLOR_* names by the paper figures.
COLORS = {
    'bulk': '#95a5a6',
    'critical': '#c0392b',
    'stable': '#2d5a7b',
    'unstable': '#e67e22',
    'dominance': '#8e44ad',
    'network_A': '#2d5a7b',
    'network_B': '#c0392b',
    'cliff': '#f39c12',
    'background': '#faf8f5',
}
COLOR_A = '#2d5a7b'
COLOR_B = '#c0392b'
COLOR_STABLE = '#2d5a7b'
COLOR_UNSTABLE = '#c0392b'
COLOR_BG = '#faf8f5'

# ── Matplotlib style ─────────────────────────────────────────────────────────
# Global style applied to every figure produced later in the notebook.
plt.rcParams.update({
    'figure.facecolor': COLOR_BG,
    'axes.facecolor': COLOR_BG,
    'font.family': 'serif',
    'font.size': 11,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 0.8,
    'axes.edgecolor': '#444444',
    'text.color': '#2a2a2a',
    'axes.labelcolor': '#2a2a2a',
    'xtick.color': '#555555',
    'ytick.color': '#555555',
})

# Echo the configuration so the notebook output records the parameters used.
print("Imports and constants loaded successfully.")
print(f"Model: N={N} neurons/ring, J_0={J_0}, J_1={J_1}, kappa={KAPPA}")
print(f"Sigmoid: r_max={R_MAX}, beta={BETA}, h0={H0}")
print(f"Dynamics: dt={DT}, tau={TAU}")

In [None]:
# ============================================================================
# Cell 3: Core Utility Functions
# ============================================================================
# Source: spectral_portrait_ring_attractor.py
# These are the fundamental building blocks of the coupled ring attractor model.

def sigmoid(h, r_max=1.0, beta=5.0, h0=0.5):
    """Logistic transfer function r_max / (1 + exp(-beta * (h - h0))).

    Applies elementwise to scalars or NumPy arrays and equals r_max / 2 at
    h = h0.
    """
    exponent = beta * (h0 - h)
    return r_max / (1.0 + np.exp(exponent))


def sigmoid_derivative(h, r_max=1.0, beta=5.0, h0=0.5):
    """Slope of ``sigmoid`` with respect to h.

    Uses the logistic identity sigma'(h) = beta * sigma(h) * (1 - sigma(h)/r_max),
    so only one exponential is evaluated per element.
    """
    rate = sigmoid(h, r_max, beta, h0)
    headroom = 1.0 - rate / r_max
    return beta * rate * headroom


def build_within_weights(N, J_0, J_1):
    """Build the cosine within-ring connectivity.

    Returns
    -------
    W : (N, N) array with W[i, j] = (-J_0 + J_1 * cos(phi_i - phi_j)) / N,
        i.e. uniform inhibition J_0 plus cosine-tuned excitation J_1.
    preferred : (N,) array of preferred angles phi, evenly spaced on [-pi, pi).
    """
    angles = np.linspace(-np.pi, np.pi, N, endpoint=False)
    angle_diff = np.subtract.outer(angles, angles)
    weights = (-J_0 + J_1 * np.cos(angle_diff)) / N
    return weights, angles


def tuning_curve(theta, preferred, kappa):
    """Von Mises tuning curve centred on ``theta``.

    Evaluates the von Mises density exp(kappa * cos(theta - preferred)) /
    (2 * pi * I0(kappa)) at every preferred angle.
    """
    normaliser = 2 * np.pi * i0(kappa)
    return np.exp(kappa * np.cos(theta - preferred)) / normaliser


# Quick verification
# Sanity checks on the helpers: weight-matrix shape and range, the sigmoid at
# its threshold h0 = 0.5 (half of r_max), and its slope there
# (beta * 0.5 * 0.5 = 1.25 with the default parameters).
W_test, pref_test = build_within_weights(N, J_0, J_1)
print(f"Weight matrix shape: {W_test.shape}")
print(f"Weight range: [{W_test.min():.4f}, {W_test.max():.4f}]")
print(f"Sigmoid at h=0.5: {sigmoid(0.5):.4f} (should be 0.5)")
print(f"Sigmoid derivative at h=0.5: {sigmoid_derivative(0.5):.4f}")
print("Core utility functions loaded.")

In [None]:
# ============================================================================
# Cell 4: Core Analysis Functions
# ============================================================================
# Source: spectral_separatrix_goldstone.py
# Fixed-point finding, Jacobian computation, eigenvalue classification.

def residual(x, W, cue_A, J_cross):
    """Return F(x) = -x + sigma(h(x)) for the stacked state x = [r_A, r_B].

    Ring A receives the external input ``cue_A``; each ring is inhibited by
    J_cross times the mean rate of the other ring. Zeros of F are fixed
    points. The split of x uses the module-level ring size ``N``.
    """
    rates_A = x[:N]
    rates_B = x[N:]
    mean_A, mean_B = np.mean(rates_A), np.mean(rates_B)
    input_A = W @ rates_A + cue_A - J_cross * mean_B
    input_B = W @ rates_B - J_cross * mean_A
    return np.concatenate([sigmoid(input_A) - rates_A,
                           sigmoid(input_B) - rates_B])


def jacobian_analytical(x, W, cue_A, J_cross):
    """Analytical 2N x 2N Jacobian of ``residual`` with respect to x.

    With D_X = diag(sigma'(h_X)) and C the N x N matrix whose entries are all
    -J_cross / N (derivative of the mean-field cross-inhibition), the Jacobian is

        [[-I + D_A W,      D_A C],
         [     D_B C, -I + D_B W]]
    """
    r_A, r_B = x[:N], x[N:]
    slope_A = sigmoid_derivative(W @ r_A + cue_A - J_cross * np.mean(r_B))
    slope_B = sigmoid_derivative(W @ r_B - J_cross * np.mean(r_A))
    cross = np.full((N, N), -J_cross / N)
    identity = np.eye(N)
    # Scaling row i by slope[i] is the same as left-multiplying by diag(slope).
    col_A = slope_A[:, np.newaxis]
    col_B = slope_B[:, np.newaxis]
    return np.block([
        [-identity + col_A * W, col_A * cross],
        [col_B * cross, -identity + col_B * W],
    ])


def find_coexistence_fp(W, preferred, J_cross, r_A_init=None, r_B_init=None):
    """Locate a coexistence fixed point (both bumps active) at zero cue.

    1. Starting state: copies of ``r_A_init``/``r_B_init`` when ``r_A_init``
       is given (warm start for parameter continuation). Otherwise bumps are
       seeded at +pi/4 (ring A) and -pi/4 (ring B) by 500 Euler steps under a
       von Mises drive, without cross-inhibition.
    2. 100000 Euler steps of the undriven, cross-inhibited dynamics.
    3. Newton refinement with ``fsolve`` and the analytical Jacobian.

    Returns (r_A, r_B, res), where res is the max absolute residual at the
    returned point. The fsolve convergence flag is not inspected, so callers
    should check ``res``.
    """
    if r_A_init is None:
        # Cold start: grow one bump per ring under an external drive.
        drive_A = INPUT_GAIN * tuning_curve(np.pi / 4, preferred, KAPPA)
        drive_B = INPUT_GAIN * tuning_curve(-np.pi / 4, preferred, KAPPA)
        r_A = sigmoid(W @ (drive_A * 0.3) + drive_A)
        r_B = sigmoid(W @ (drive_B * 0.3) + drive_B)
        for _ in range(500):
            target_A = sigmoid(W @ r_A + drive_A)
            target_B = sigmoid(W @ r_B + drive_B)
            r_A = np.maximum(0, r_A + (target_A - r_A) * DT / TAU)
            r_B = np.maximum(0, r_B + (target_B - r_B) * DT / TAU)
    else:
        # Warm start from a previous solution.
        r_A, r_B = r_A_init.copy(), r_B_init.copy()

    # Relax with the drive removed; both targets use the pre-step state.
    for _ in range(100000):
        target_A = sigmoid(W @ r_A - J_cross * np.mean(r_B))
        target_B = sigmoid(W @ r_B - J_cross * np.mean(r_A))
        r_A = np.maximum(0, r_A + (target_A - r_A) * DT / TAU)
        r_B = np.maximum(0, r_B + (target_B - r_B) * DT / TAU)

    zero_cue = np.zeros(N)
    guess = np.concatenate([r_A, r_B])
    # fsolve calls fprime(x, *args), matching jacobian_analytical's signature.
    solution = fsolve(residual, guess, args=(W, zero_cue, J_cross),
                      fprime=jacobian_analytical, full_output=True,
                      maxfev=10000)[0]
    worst = np.max(np.abs(residual(solution, W, zero_cue, J_cross)))
    return solution[:N], solution[N:], worst


def classify_eigenvalues(evals, evecs, preferred, theta_A=np.pi / 4,
                         theta_B=-np.pi / 4, goldstone_thresh=None):
    """
    Separate eigenvalues into Goldstone modes (|Re lambda| < threshold) and
    genuine modes, and label each eigenvector by its dominant character.

    The real part of every eigenvector is normalised and projected onto six
    reference directions built from each ring's own bump geometry, with
    bump_X = cos(phi - theta_X) and slide_X = sin(phi - theta_X):

    - 'dominance':  [bump_A, -bump_B]   (one bump grows, the other shrinks)
    - 'drift_same': [slide_A, slide_B]  (both bumps rotate the same way)
    - 'drift_opp':  [slide_A, -slide_B] (bumps rotate in opposite directions)
    - 'uniform':    [1, -1]             (total-activity competition)
    - 'gold_A' / 'gold_B': rotation of ring A / ring B alone

    Each ring's directions use its own bump centre. Previously both rings
    used +pi/4, while ring B's bump is seeded at -pi/4 by
    find_coexistence_fp.

    Parameters
    ----------
    evals, evecs : output of ``np.linalg.eig`` (eigenvectors in columns).
    preferred : preferred angles of one ring; the state vector is
        2 * len(preferred) long.
    theta_A, theta_B : bump centres of rings A and B.
    goldstone_thresh : threshold on |Re lambda|; None uses GOLDSTONE_THRESH.

    Returns
    -------
    list of dicts (same order as ``evals``) with keys 'eigenvalue',
    'eigenvector', 'is_goldstone', 'projections' and 'character'.
    """
    thresh = GOLDSTONE_THRESH if goldstone_thresh is None else goldstone_thresh
    n = len(preferred)
    zeros = np.zeros(n)

    bump_A, slide_A = np.cos(preferred - theta_A), np.sin(preferred - theta_A)
    bump_B, slide_B = np.cos(preferred - theta_B), np.sin(preferred - theta_B)

    def _unit(v):
        return v / np.linalg.norm(v)

    # Insertion order matters: max() below breaks ties by first key.
    directions = {
        'dominance': _unit(np.concatenate([bump_A, -bump_B])),
        'drift_same': _unit(np.concatenate([slide_A, slide_B])),
        'drift_opp': _unit(np.concatenate([slide_A, -slide_B])),
        'uniform': _unit(np.concatenate([np.ones(n), -np.ones(n)])),
        'gold_A': _unit(np.concatenate([slide_A, zeros])),
        'gold_B': _unit(np.concatenate([zeros, slide_B])),
    }

    results = []
    for lam, vec in zip(evals, np.asarray(evecs).T):
        ev = lam.real
        vec_real = vec.real
        # Small epsilon avoids division by zero for a vanishing real part.
        vec_n = vec_real / (np.linalg.norm(vec_real) + 1e-30)
        proj = {name: abs(np.dot(vec_n, d)) for name, d in directions.items()}
        results.append({
            'eigenvalue': ev,
            'eigenvector': vec_n,
            'is_goldstone': abs(ev) < thresh,
            'projections': proj,
            'character': max(proj, key=proj.get),
        })

    return results


def find_wta_fp(W, preferred, J_cross, dominant='A'):
    """
    Find a winner-take-all fixed point at zero cue.

    The ring named by ``dominant`` ('A'; any other value selects ring B) is
    seeded with a bump centred at +pi/4, and the other ring with a flat rate
    of 0.01. The state is relaxed for 200000 Euler steps under
    cross-inhibition and then refined with ``fsolve``.

    Returns (r_A, r_B, res), where res is the max absolute residual.
    """
    seed_drive = INPUT_GAIN * tuning_curve(np.pi / 4, preferred, KAPPA)
    winner = sigmoid(W @ (seed_drive * 0.5) + seed_drive)
    loser = np.ones(N) * 0.01
    r_A, r_B = (winner, loser) if dominant == 'A' else (loser, winner)

    # Synchronous Euler updates: both targets use the pre-step state.
    for _ in range(200000):
        target_A = sigmoid(W @ r_A - J_cross * np.mean(r_B))
        target_B = sigmoid(W @ r_B - J_cross * np.mean(r_A))
        r_A = np.maximum(0, r_A + (target_A - r_A) * DT / TAU)
        r_B = np.maximum(0, r_B + (target_B - r_B) * DT / TAU)

    zero_cue = np.zeros(N)
    solution = fsolve(residual, np.concatenate([r_A, r_B]),
                      args=(W, zero_cue, J_cross), fprime=jacobian_analytical,
                      full_output=True, maxfev=10000)[0]
    worst = np.max(np.abs(residual(solution, W, zero_cue, J_cross)))
    return solution[:N], solution[N:], worst


print("Core analysis functions loaded.")
print("  - residual(), jacobian_analytical()")
print("  - find_coexistence_fp(), find_wta_fp()")
print("  - classify_eigenvalues()")

## Analysis 1: Goldstone-Separated Eigenvalue Scan

Scan across $J_\times$ values, computing the full eigenspectrum of the $2N \times 2N = 96 \times 96$ Jacobian
at each coexistence fixed point. Goldstone modes (neutral rotations, identified by $|\mathrm{Re}\,\lambda| < 10^{-3}$)
are separated from genuine dynamical modes. The leading (largest real part) non-Goldstone eigenvalue
$\lambda_\mathrm{dom}$ crosses zero at the pitchfork bifurcation $J_\times^* \approx 0.3485$.

This is the core computation of the paper.

In [None]:
%%time
# ============================================================================
# Cell 6: Goldstone-Separated Eigenvalue Scan (THE computation)
# ============================================================================
# Source: spectral_separatrix_goldstone.py — run_analysis() + plotting

def run_analysis():
    """Scan J_cross, separating Goldstone modes from genuine instability.

    For each J_cross on a coarse grid over [0, 0.30] plus a fine grid over
    [0.30, 0.36], find the zero-cue coexistence fixed point, compute the
    eigen-decomposition of the full 2N x 2N Jacobian, and split modes into
    Goldstone (|Re lambda| < GOLDSTONE_THRESH) and genuine ones.

    Returns
    -------
    all_results : list aligned with jc_values. Each entry is a result dict, or
        None when the fixed point failed (residual > 1e-4) or a bump
        collapsed (peak rate < 0.3).
    W, preferred : the clean within-ring weights and preferred angles.
    jc_values : the scanned J_cross grid.

    NOTE(review): because "Goldstone" is decided purely by |Re lambda|, a
    dominance eigenvalue lying within GOLDSTONE_THRESH of zero is counted as
    Goldstone, and lambda_dom then reports the next genuine mode.
    """
    W, preferred = build_within_weights(N, J_0, J_1)

    # Coarse scan + fine scan near critical region
    jc_values = np.sort(np.unique(np.concatenate([
        np.linspace(0.0, 0.30, 10),
        np.linspace(0.30, 0.36, 30),
    ])))

    print("=" * 70)
    print("GOLDSTONE-SEPARATED EIGENVALUE ANALYSIS")
    print("=" * 70)

    all_results = []
    r_A_prev, r_B_prev = None, None

    for jc in jc_values:
        # Continuation: warm-start from the last successful solution.
        if r_A_prev is not None and jc > 0:
            r_A, r_B, res = find_coexistence_fp(W, preferred, jc, r_A_prev, r_B_prev)
        else:
            r_A, r_B, res = find_coexistence_fp(W, preferred, jc)

        # Reject unconverged solutions and states where either bump died.
        if res > 1e-4 or np.max(r_A) < 0.3 or np.max(r_B) < 0.3:
            print(f"  J_cross={jc:.4f}: FAILED or no bumps")
            all_results.append(None)
            continue

        # Full Jacobian
        cue_0 = np.zeros(N)
        x = np.concatenate([r_A, r_B])
        J = jacobian_analytical(x, W, cue_0, jc)
        evals, evecs = np.linalg.eig(J)
        # Order modes by decreasing real part (most unstable first).
        idx = np.argsort(-evals.real)
        evals, evecs = evals[idx], evecs[:, idx]

        classified = classify_eigenvalues(evals, evecs, preferred)

        # Separate Goldstone from genuine
        goldstone = [c for c in classified if c['is_goldstone']]
        genuine = [c for c in classified if not c['is_goldstone']]
        genuine.sort(key=lambda c: -c['eigenvalue'])

        n_gold = len(goldstone)
        n_genuine_pos = sum(1 for c in genuine if c['eigenvalue'] > 1e-6)
        # lambda_dom: largest real part among non-Goldstone modes.
        lam_dom = genuine[0]['eigenvalue'] if genuine else np.nan
        dom_char = genuine[0]['character'] if genuine else 'N/A'
        dom_proj = genuine[0]['projections'] if genuine else {}

        result = {
            'J_cross': jc,
            'max_rA': np.max(r_A), 'max_rB': np.max(r_B),
            'D': np.mean(r_A) - np.mean(r_B),
            'n_goldstone': n_gold,
            'n_genuine_positive': n_genuine_pos,
            'lambda_dom': lam_dom,
            'dom_character': dom_char,
            'dom_projections': dom_proj,
            'goldstone_evals': [c['eigenvalue'] for c in goldstone[:4]],
            'top_genuine_evals': [c['eigenvalue'] for c in genuine[:5]],
            'top_genuine_chars': [c['character'] for c in genuine[:5]],
            'all_evals': evals,
        }
        all_results.append(result)
        r_A_prev, r_B_prev = r_A.copy(), r_B.copy()

        marker = "UNSTABLE" if n_genuine_pos > 0 else "stable"
        gold_str = f"({n_gold} Goldstone)"
        print(f"  J_cross={jc:.4f}: lam_dom={lam_dom:+.6f} [{dom_char}], "
              f"genuine_pos={n_genuine_pos}, {gold_str}  [{marker}]")

    return all_results, W, preferred, jc_values


def _lambda_crossing(valid, tol=1e-4):
    """Return the interpolated J_cross where lambda_dom goes stable -> unstable.

    Scans consecutive entries of ``valid`` (result dicts ordered by J_cross)
    for the first pair with lambda_dom < -tol followed by lambda_dom > tol.
    Returns the linearly interpolated zero crossing, or None if there is no
    such pair. NaN eigenvalues never satisfy the comparisons.
    """
    for left, right in zip(valid, valid[1:]):
        l1, l2 = left['lambda_dom'], right['lambda_dom']
        if l1 < -tol and l2 > tol:
            jc1, jc2 = left['J_cross'], right['J_cross']
            return jc1 + (0 - l1) * (jc2 - jc1) / (l2 - l1)
    return None


def plot_goldstone_analysis(results):
    """Plot the 9-panel Goldstone-separated analysis figure and print a summary.

    ``results`` is the list returned by ``run_analysis``; None entries
    (failed scan points) are skipped. The same first stable->unstable
    crossing is annotated on the plot and printed as J_cross*.
    """
    valid = [r for r in results if r is not None]
    if not valid:
        # Every scan point failed: there is nothing meaningful to plot.
        print("No valid coexistence fixed points; nothing to plot.")
        return
    jc_v = [r['J_cross'] for r in valid]
    jc_crit = _lambda_crossing(valid)

    fig = plt.figure(figsize=(22, 14))
    gs = GridSpec(3, 3, figure=fig, hspace=0.45, wspace=0.35)

    # (0,0) lambda_dom vs J_cross (the money plot)
    ax = fig.add_subplot(gs[0, 0])
    lam_dom = [r['lambda_dom'] for r in valid]
    colors = ['#e74c3c' if l > 1e-4 else '#2196F3' for l in lam_dom]
    ax.scatter(jc_v, lam_dom, c=colors, s=40, zorder=3, edgecolors='black', linewidths=0.5)
    ax.plot(jc_v, lam_dom, '-', color='gray', lw=0.8, alpha=0.5)
    ax.axhline(y=0, color='black', ls='--', lw=1.5, alpha=0.4)
    ax.set_xlabel('J_cross')
    ax.set_ylabel('lambda_dom (first non-Goldstone)')
    ax.set_title('Dominance Eigenvalue\n(Goldstone modes removed)', fontweight='bold')

    if jc_crit is not None:
        ax.axvline(x=jc_crit, color='#e74c3c', ls=':', lw=2, alpha=0.8)
        # nanmax: lambda_dom is NaN when a point had no genuine modes.
        ax.annotate(f'J* = {jc_crit:.4f}', xy=(jc_crit, 0),
                    xytext=(jc_crit - 0.08, np.nanmax(lam_dom) * 0.7),
                    fontsize=11, fontweight='bold', color='#e74c3c',
                    arrowprops=dict(arrowstyle='->', color='#e74c3c', lw=1.5))

    # (0,1) Number of Goldstone modes
    ax = fig.add_subplot(gs[0, 1])
    ax.plot(jc_v, [r['n_goldstone'] for r in valid], 'o-', color='#9C27B0',
            lw=2, ms=5)
    ax.set_xlabel('J_cross')
    ax.set_ylabel('# Goldstone modes (|lambda| < 1e-3)')
    ax.set_title('Goldstone Mode Count', fontweight='bold')
    ax.set_yticks(range(max(r['n_goldstone'] for r in valid) + 2))

    # (0,2) Top 5 genuine eigenvalues
    ax = fig.add_subplot(gs[0, 2])
    n_top = min(5, min(len(r['top_genuine_evals']) for r in valid))
    for k in range(n_top):
        vals = [r['top_genuine_evals'][k] for r in valid]
        ax.plot(jc_v, vals, 'o-', ms=2, lw=1.5, label=f'lambda_{k+1}(genuine)')
    ax.axhline(y=0, color='black', ls='--', alpha=0.4)
    ax.set_xlabel('J_cross')
    ax.set_ylabel('Re(lambda)')
    ax.set_title('Top 5 Non-Goldstone Eigenvalues', fontweight='bold')
    ax.legend(fontsize=7)

    # (1,0) Eigenvector character
    ax = fig.add_subplot(gs[1, 0])
    proj_keys = ['dominance', 'drift_same', 'drift_opp', 'uniform']
    proj_colors = {'dominance': '#e74c3c', 'drift_same': '#2196F3',
                   'drift_opp': '#9C27B0', 'uniform': '#4CAF50'}
    for key in proj_keys:
        vals = [r['dom_projections'].get(key, 0) for r in valid]
        ax.plot(jc_v, vals, 'o-', color=proj_colors[key], lw=1.5, ms=3, label=key)
    ax.set_xlabel('J_cross')
    ax.set_ylabel('|Projection|')
    ax.set_title('Dominant Mode Character', fontweight='bold')
    ax.legend(fontsize=7)
    ax.set_ylim(-0.05, 1.05)

    # (1,1) Goldstone eigenvalue magnitudes
    ax = fig.add_subplot(gs[1, 1])
    for k in range(4):
        vals, jcs = [], []
        for r in valid:
            if len(r['goldstone_evals']) > k:
                vals.append(abs(r['goldstone_evals'][k]))
                jcs.append(r['J_cross'])
        if vals:
            ax.semilogy(jcs, vals, 'o-', ms=3, lw=1, label=f'|lambda_gold_{k+1}|')
    ax.set_xlabel('J_cross')
    ax.set_ylabel('|lambda| (log scale)')
    ax.set_title('Goldstone Eigenvalue Magnitudes', fontweight='bold')
    ax.legend(fontsize=7)

    # (1,2) Bump heights
    ax = fig.add_subplot(gs[1, 2])
    ax.plot(jc_v, [r['max_rA'] for r in valid], 'o-', color=COLORS['network_A'],
            lw=2, ms=3, label='max(r_A)')
    ax.plot(jc_v, [r['max_rB'] for r in valid], 'o-', color=COLORS['network_B'],
            lw=2, ms=3, label='max(r_B)')
    ax.set_xlabel('J_cross')
    ax.set_ylabel('Peak firing rate')
    ax.set_title('Bump Heights (cue=0)', fontweight='bold')
    ax.legend(fontsize=8)

    # (2,0-2) Full eigenvalue spectrum at 3 key J_cross values
    show_jc = [0.1, 0.3, 0.35]
    for col, jc_target in enumerate(show_jc):
        ax = fig.add_subplot(gs[2, col])
        matches = [r for r in valid if abs(r['J_cross'] - jc_target) < 0.015]
        if matches:
            m = min(matches, key=lambda r: abs(r['J_cross'] - jc_target))
            evals_sorted = np.sort(m['all_evals'].real)[::-1]
            colors_ev = ['#e74c3c' if e > 1e-3 else '#9C27B0' if abs(e) < 1e-3
                         else '#2196F3' for e in evals_sorted]
            ax.bar(range(len(evals_sorted)), evals_sorted, color=colors_ev, width=1.0)
            ax.axhline(y=0, color='black', ls='--', lw=1, alpha=0.5)
            ax.set_xlabel('Eigenvalue index')
            ax.set_ylabel('Re(lambda)')
            ax.set_title(f'Full Spectrum at J_cross={m["J_cross"]:.3f}\n'
                         f'Gold={m["n_goldstone"]}, Unstable={m["n_genuine_positive"]}',
                         fontsize=9, fontweight='bold')
            ax.set_xlim(-1, min(30, len(evals_sorted)))

    fig.suptitle('Goldstone-Separated Spectral Analysis\n'
                 'Distinguishing neutral rotation modes from genuine instability',
                 fontsize=14, fontweight='bold', y=1.01)
    plt.tight_layout()
    plt.show()

    # Print summary
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    for r in valid:
        status = "UNSTABLE" if r['n_genuine_positive'] > 0 else "stable"
        print(f"  J={r['J_cross']:.4f}: lam_dom={r['lambda_dom']:+.6f} "
              f"[{r['dom_character']:15s}] gold={r['n_goldstone']} "
              f"gen_pos={r['n_genuine_positive']}  {status}")

    # Precise critical point (same crossing as annotated on the plot)
    if jc_crit is not None:
        print(f"\n  *** CRITICAL J_cross* = {jc_crit:.4f} ***")


# Run the analysis
# run_analysis() returns (per-J_cross results, clean weights, preferred angles,
# scanned J_cross grid); the plot skips scan points stored as None.
goldstone_results, W_gold, preferred_gold, jc_values_gold = run_analysis()
plot_goldstone_analysis(goldstone_results)

## Analysis 2: Heterogeneity Test (GLM 5 Prediction)

GLM 5 predicted that adding heterogeneity (random, symmetric Gaussian perturbations) to the within-ring
weight matrices should **widen** the instability window from the razor-thin $\Delta J \approx 0.01$
to a broader regime. This section tests that prediction by:

1. Building noisy weight matrices at various amplitudes $\sigma$
2. Measuring where coexistence becomes unstable ($\lambda_\mathrm{dom}$ crosses zero)
3. Measuring where coexistence ceases to exist (bumps collapse)
4. Tracking critical slowing down near $J_\times^*$ (on the noise-free weights)

If GLM 5 is right: the window widens with noise amplitude.
If wrong: the window narrows or stays the same.

In [None]:
%%time
# ============================================================================
# Cell 8: Heterogeneity and Critical Slowing Down Experiments
# ============================================================================
# Source: heterogeneity_test.py

def build_noisy_weights(N, J_0, J_1, noise_sigma):
    """Return cosine within-ring weights plus symmetric Gaussian noise.

    Raw noise entries are drawn from NumPy's global RNG with standard
    deviation noise_sigma / sqrt(N), then symmetrised as (X + X.T) / 2 so
    the perturbed matrix stays symmetric.

    Returns (W_noisy, preferred, W_clean).
    """
    W_clean, preferred = build_within_weights(N, J_0, J_1)
    raw = np.random.randn(N, N) * noise_sigma / np.sqrt(N)
    symmetric_noise = (raw + raw.T) / 2
    return W_clean + symmetric_noise, preferred, W_clean


def residual_het(x, W, J_cross):
    """Fixed-point residual for the heterogeneity test.

    This is ``residual`` with a zero cue on ring A. Adding a zero cue leaves
    the input currents unchanged, so the result is identical.
    """
    return residual(x, W, np.zeros(N), J_cross)


def jacobian_het(x, W, J_cross):
    """Jacobian for the heterogeneity test: ``jacobian_analytical`` at zero cue."""
    return jacobian_analytical(x, W, np.zeros(N), J_cross)


def find_coexistence_het(W, J_cross, preferred):
    """Find the zero-cue coexistence fixed point for (possibly noisy) weights W.

    Bumps are seeded at +pi/4 (A) and -pi/4 (B) with 500 driven Euler steps,
    relaxed for 50000 undriven, cross-inhibited steps, then refined with
    ``fsolve``. Returns (r_A, r_B, max_abs_residual).
    """
    def relax(rates, target):
        # One forward-Euler step of tau dr/dt = -r + target, clipped at zero.
        return np.maximum(0, rates + (target - rates) * DT / TAU)

    drive_A, drive_B = (INPUT_GAIN * tuning_curve(centre, preferred, KAPPA)
                        for centre in (np.pi / 4, -np.pi / 4))
    r_A = sigmoid(W @ (drive_A * 0.3) + drive_A)
    r_B = sigmoid(W @ (drive_B * 0.3) + drive_B)

    # Phase 1: grow bumps under external drive (no cross-inhibition).
    for _ in range(500):
        r_A, r_B = (relax(r_A, sigmoid(W @ r_A + drive_A)),
                    relax(r_B, sigmoid(W @ r_B + drive_B)))

    # Phase 2: drive off, cross-inhibition on. The tuple on the right is fully
    # evaluated before assignment, so both rings update synchronously.
    for _ in range(50000):
        r_A, r_B = (relax(r_A, sigmoid(W @ r_A - J_cross * np.mean(r_B))),
                    relax(r_B, sigmoid(W @ r_B - J_cross * np.mean(r_A))))

    solution = fsolve(residual_het, np.concatenate([r_A, r_B]),
                      args=(W, J_cross), fprime=jacobian_het,
                      full_output=True, maxfev=5000)[0]
    worst = np.max(np.abs(residual_het(solution, W, J_cross)))
    return solution[:N], solution[N:], worst


def analyze_at_jcross(W, J_cross, preferred):
    """Return (lambda_dom, n_goldstone, has_bumps, convergence_time).

    lambda_dom : largest real part among non-Goldstone eigenvalues
        (|Re lambda| >= GOLDSTONE_THRESH) of the coexistence Jacobian.
        NaN if there is no valid state; 0.0 if every mode is "Goldstone".
    n_goldstone : number of eigenvalues with |Re lambda| < GOLDSTONE_THRESH.
    has_bumps : True iff the Newton residual is <= 1e-4 and both rings
        peak above 0.3.
    convergence_time : elapsed simulated time (steps * DT) for a small
        perturbation along the dominant genuine eigenvector to shrink to 1/e
        of its initial distance from the fixed point. Capped at 5000 * DT;
        NaN when there is no valid state or no genuine mode.

    Perturbing along the dominant eigenvector, rather than randomly, is what
    makes this a critical-slowing-down measurement. A random kick is mostly
    absorbed by fast bulk modes, and its Goldstone components never decay,
    so its 1/e time does not track lambda_dom. Along the eigenvector, each
    linearised Euler step scales the deviation by (1 + lambda * DT / TAU),
    so the expected time is roughly TAU / |lambda_dom|. For a complex pair,
    Re(lambda) sets the envelope.
    """
    r_A, r_B, res = find_coexistence_het(W, J_cross, preferred)

    has_bumps = np.max(r_A) > 0.3 and np.max(r_B) > 0.3
    if res > 1e-4 or not has_bumps:
        return np.nan, 0, False, np.nan

    n = len(r_A)
    x = np.concatenate([r_A, r_B])
    J = jacobian_het(x, W, J_cross)
    evals, evecs = np.linalg.eig(J)
    order = np.argsort(-evals.real)          # most unstable first
    evals_real = evals.real[order]
    evecs = evecs[:, order]

    # Separate Goldstone from genuine
    goldstone_mask = np.abs(evals_real) < GOLDSTONE_THRESH
    n_gold = np.sum(goldstone_mask)
    genuine_idx = np.flatnonzero(~goldstone_mask)

    if genuine_idx.size == 0:
        return 0.0, n_gold, True, np.nan

    lam_dom = evals_real[genuine_idx[0]]

    # Perturb along the dominant genuine mode with RMS amplitude ~1e-3 per
    # neuron (the same scale as the former random kick).
    direction = evecs[:, genuine_idx[0]].real
    direction = direction / (np.linalg.norm(direction) + 1e-30)
    x_pert = x + 0.001 * np.sqrt(2 * n) * direction
    r_A_p, r_B_p = x_pert[:n], x_pert[n:]

    def distance(a, b):
        # Distance of the perturbed state (a, b) from the fixed point (r_A, r_B).
        return np.sqrt(np.mean((a - r_A)**2) + np.mean((b - r_B)**2))

    max_steps = 5000
    # Initial distance is measured before any integration step.
    target = distance(r_A_p, r_B_p) / np.e
    conv_time = max_steps * DT
    for step in range(1, max_steps + 1):
        h_A = W @ r_A_p - J_cross * np.mean(r_B_p)
        h_B = W @ r_B_p - J_cross * np.mean(r_A_p)
        r_A_p = np.maximum(0, r_A_p + (-r_A_p + sigmoid(h_A)) * DT / TAU)
        r_B_p = np.maximum(0, r_B_p + (-r_B_p + sigmoid(h_B)) * DT / TAU)
        if distance(r_A_p, r_B_p) < target:
            conv_time = step * DT    # time elapsed after `step` Euler steps
            break

    return lam_dom, n_gold, True, conv_time


# ── Experiment 1: Heterogeneity ──────────────────────────────────────────────

def heterogeneity_experiment():
    """Scan J_cross on noisy within-ring weights for several noise amplitudes.

    For each noise sigma and each of n_trials seeds (42 + trial), builds
    noisy weights and runs ``analyze_at_jcross`` over the J_cross grid. It
    then prints, per trial, the instability onset (smallest valid J_cross
    with lambda_dom > 0), the existence limit (largest J_cross with a valid
    coexistence state) and their difference.

    Returns
    -------
    all_data : dict mapping sigma -> list (one per trial) of lists of
        per-J_cross result dicts ('J_cross', 'lambda_dom', 'n_goldstone',
        'has_bumps', 'conv_time').
    jc_values : the scanned J_cross grid.
    """
    print("=" * 70)
    print("EXPERIMENT 1: HETEROGENEITY TEST")
    print("=" * 70)

    noise_levels = [0.0, 0.05, 0.1, 0.2, 0.3, 0.5]
    jc_values = np.linspace(0.25, 0.38, 30)
    n_trials = 5

    all_data = {}

    for sigma in noise_levels:
        print(f"\n  Noise sigma = {sigma:.2f}")
        trial_results = []

        for trial in range(n_trials):
            # Seed per trial so each trial's noise realisation is reproducible.
            np.random.seed(42 + trial)
            W, preferred, _ = build_noisy_weights(N, J_0, J_1, sigma)
            results = []

            for jc in jc_values:
                lam, ng, has_bumps, ct = analyze_at_jcross(W, jc, preferred)
                results.append({
                    'J_cross': jc, 'lambda_dom': lam,
                    'n_goldstone': ng, 'has_bumps': has_bumps,
                    'conv_time': ct
                })

            trial_results.append(results)

            valid = [r for r in results if r['has_bumps'] and not np.isnan(r['lambda_dom'])]
            if valid:
                jc_exist_max = max(r['J_cross'] for r in valid)
                unstable = [r for r in valid if r['lambda_dom'] > 0]
                jc_unstable_min = min(r['J_cross'] for r in unstable) if unstable else None

                # Explicit None check: a float onset must not be judged by truthiness.
                if jc_unstable_min is not None:
                    print(f"    Trial {trial}: unstable from {jc_unstable_min:.4f}, "
                          f"exists to {jc_exist_max:.4f}, "
                          f"window = {jc_exist_max - jc_unstable_min:.4f}")
                else:
                    print(f"    Trial {trial}: all stable, exists to {jc_exist_max:.4f}")

        all_data[sigma] = trial_results

    return all_data, jc_values


# ── Experiment 2: Critical Slowing Down ──────────────────────────────────────

def critical_slowing_experiment():
    """Sweep J_cross on the clean (noise-free) weights near the critical point.

    Returns a list with one dict per scanned J_cross (failed points included),
    holding 'J_cross', 'lambda_dom', 'conv_time' and 'has_bumps'. Valid points
    are also printed.
    """
    print("\n" + "=" * 70)
    print("EXPERIMENT 2: CRITICAL SLOWING DOWN")
    print("=" * 70)

    W, preferred = build_within_weights(N, J_0, J_1)

    results = []
    for jc in np.linspace(0.20, 0.36, 40):
        lam, _, has_bumps, ct = analyze_at_jcross(W, jc, preferred)
        results.append({'J_cross': jc, 'lambda_dom': lam,
                        'conv_time': ct, 'has_bumps': has_bumps})
        if not has_bumps or np.isnan(lam):
            continue
        print(f"  J_cross={jc:.4f}: lambda_dom={lam:+.6f}, tau_conv={ct:.1f}")

    return results


# ── Plotting ─────────────────────────────────────────────────────────────────

def _mean_std(values):
    """(mean, std) of a list; (nan, 0) when empty and std 0 for a single value."""
    return (np.mean(values) if values else np.nan,
            np.std(values) if len(values) > 1 else 0)


def _trial_windows(trials):
    """Per-trial (onset, collapse) pairs for one noise level.

    collapse is the largest J_cross with a valid coexistence state, and onset
    is the smallest valid J_cross with lambda_dom > 0 (None if every valid
    point is stable). Trials with no valid points are skipped. Keeping the
    pair together guarantees that a window width uses one trial's own onset
    and collapse.
    """
    windows = []
    for trial in trials:
        valid = [r for r in trial if r['has_bumps'] and not np.isnan(r['lambda_dom'])]
        if not valid:
            continue
        collapse = max(r['J_cross'] for r in valid)
        unstable = [r['J_cross'] for r in valid if r['lambda_dom'] > 0]
        windows.append((min(unstable) if unstable else None, collapse))
    return windows


def plot_experiments(het_data, jc_values, csd_results):
    """Plot the heterogeneity and critical-slowing-down results and print a summary.

    het_data : dict sigma -> list of per-trial result lists
        (from heterogeneity_experiment).
    jc_values : J_cross grid used for het_data.
    csd_results : list of dicts from critical_slowing_experiment.
    """
    fig = plt.figure(figsize=(22, 14))
    gs = GridSpec(2, 3, figure=fig, hspace=0.40, wspace=0.35)

    # (0,0) lambda_dom vs J_cross for different noise levels
    ax = fig.add_subplot(gs[0, 0])
    cmap = plt.cm.viridis
    noise_levels = sorted(het_data.keys())
    for i, sigma in enumerate(noise_levels):
        all_lams = [
            [r['lambda_dom'] if r['has_bumps'] and not np.isnan(r['lambda_dom']) else np.nan
             for r in trial]
            for trial in het_data[sigma]
        ]
        mean_lams = np.nanmean(all_lams, axis=0)
        c = cmap(i / max(1, len(noise_levels) - 1))
        ax.plot(jc_values, mean_lams, 'o-', color=c, ms=3, lw=1.5,
                label=f'sigma={sigma:.2f}')
    ax.axhline(y=0, color='black', ls='--', lw=1.5, alpha=0.4)
    ax.set_xlabel('J_cross')
    ax.set_ylabel('lambda_dom (mean over trials)')
    ax.set_title('Dominance Eigenvalue vs Noise', fontweight='bold')
    ax.legend(fontsize=7, ncol=2)

    # (0,1) Instability onset J_cross vs noise level
    ax = fig.add_subplot(gs[0, 1])
    onset_jc, collapse_jc, window_widths = [], [], []
    for sigma in noise_levels:
        windows = _trial_windows(het_data[sigma])
        onset_jc.append(_mean_std([o for o, _ in windows if o is not None]))
        collapse_jc.append(_mean_std([c for _, c in windows]))
        # Width per trial: that trial's collapse minus that trial's onset.
        window_widths.append(_mean_std([c - o for o, c in windows if o is not None]))

    onset_means = [o[0] for o in onset_jc]
    onset_stds = [o[1] for o in onset_jc]
    collapse_means = [c[0] for c in collapse_jc]
    collapse_stds = [c[1] for c in collapse_jc]

    ax.errorbar(noise_levels, onset_means, yerr=onset_stds,
                fmt='o-', color=COLORS['critical'], lw=2, ms=6, capsize=3,
                label='Instability onset')
    ax.errorbar(noise_levels, collapse_means, yerr=collapse_stds,
                fmt='s-', color=COLORS['stable'], lw=2, ms=6, capsize=3,
                label='Coexistence collapse')
    ax.fill_between(noise_levels, onset_means, collapse_means,
                    alpha=0.2, color=COLORS['dominance'])
    ax.set_xlabel('Noise sigma')
    ax.set_ylabel('J_cross')
    ax.set_title('Instability Window vs Heterogeneity', fontweight='bold')
    ax.legend(fontsize=8)

    # (0,2) Window width vs noise
    ax = fig.add_subplot(gs[0, 2])
    widths_means = [w[0] for w in window_widths]
    widths_stds = [w[1] for w in window_widths]
    ax.errorbar(noise_levels, widths_means, yerr=widths_stds,
                fmt='o-', color=COLORS['dominance'], lw=2, ms=8, capsize=4)
    ax.set_xlabel('Noise sigma')
    ax.set_ylabel('Delta_J_cross (window width)')
    ax.set_title('GLM 5 Prediction: Wider Window?', fontweight='bold')
    ax.axhline(y=widths_means[0] if not np.isnan(widths_means[0]) else 0,
               color='gray', ls=':', lw=1, alpha=0.5, label='sigma=0 baseline')
    ax.legend(fontsize=8)

    # (1,0) Critical slowing down: lambda_dom
    ax = fig.add_subplot(gs[1, 0])
    csd_valid = [r for r in csd_results if r['has_bumps'] and not np.isnan(r['lambda_dom'])]
    jc_csd = [r['J_cross'] for r in csd_valid]
    lam_csd = [r['lambda_dom'] for r in csd_valid]
    colors_csd = ['#e74c3c' if l > 0 else '#2196F3' for l in lam_csd]
    ax.scatter(jc_csd, lam_csd, c=colors_csd, s=30, zorder=3)
    ax.plot(jc_csd, lam_csd, '-', color='gray', lw=0.8, alpha=0.5)
    ax.axhline(y=0, color='black', ls='--', lw=1)
    ax.set_xlabel('J_cross')
    ax.set_ylabel('lambda_dom')
    ax.set_title('lambda_dom Near Critical Point (clean)', fontweight='bold')

    # (1,1) Critical slowing down: convergence time (stable side only)
    ax = fig.add_subplot(gs[1, 1])
    ct_valid = [r for r in csd_valid if not np.isnan(r['conv_time']) and r['lambda_dom'] < 0]
    if ct_valid:
        ax.plot([r['J_cross'] for r in ct_valid],
                [r['conv_time'] for r in ct_valid],
                'o-', color=COLORS['dominance'], lw=2, ms=5)
        ax.set_xlabel('J_cross')
        ax.set_ylabel('Convergence time tau')
        ax.set_title('Critical Slowing Down', fontweight='bold')

    # (1,2) tau vs 1/|lambda_dom|
    ax = fig.add_subplot(gs[1, 2])
    if ct_valid:
        inv_lam = [1.0 / abs(r['lambda_dom']) for r in ct_valid if abs(r['lambda_dom']) > 1e-5]
        ct_vals = [r['conv_time'] for r in ct_valid if abs(r['lambda_dom']) > 1e-5]
        ax.scatter(inv_lam, ct_vals, c=COLORS['dominance'], s=30)
        if len(inv_lam) > 2:
            coeffs = np.polyfit(inv_lam, ct_vals, 1)
            x_fit = np.linspace(min(inv_lam), max(inv_lam), 100)
            ax.plot(x_fit, np.polyval(coeffs, x_fit), '--', color='gray', lw=1.5,
                    label=f'slope={coeffs[0]:.2f}')
            ax.legend(fontsize=8)
        ax.set_xlabel('1/|lambda_dom|')
        ax.set_ylabel('Convergence time tau')
        ax.set_title('tau vs 1/|lambda_dom| (Linear = CSD)', fontweight='bold')

    fig.suptitle('Heterogeneity and Critical Slowing Down\n'
                 'Testing GLM 5 predictions about the separatrix',
                 fontsize=14, fontweight='bold', y=1.01)
    plt.tight_layout()
    plt.show()

    # Summary
    print(f"\n{'='*70}")
    print("GLM 5 PREDICTION TEST SUMMARY")
    print(f"{'='*70}")
    for sigma in noise_levels:
        widths = [c - o for o, c in _trial_windows(het_data[sigma]) if o is not None]
        if widths:
            print(f"  sigma={sigma:.2f}: window = {np.mean(widths):.4f} +/- {np.std(widths):.4f}")
        else:
            print(f"  sigma={sigma:.2f}: no instability window found")
    print(f"{'='*70}")


# Run both experiments
het_data, jc_het = heterogeneity_experiment()   # noisy weights, seeded per trial
csd_results = critical_slowing_experiment()     # clean (noise-free) weights
plot_experiments(het_data, jc_het, csd_results)

## Paper Figures

Publication-quality figures for the spectral separatrix paper:
- **Fig 5**: Coexistence existence boundary (bump heights vs $J_\times$)
- **Fig 6**: Pitchfork bifurcation diagram ($D$ vs $J_\times$)
- **Fig 7**: Eigenvector comparison (Goldstone vs critical dominance mode)

In [None]:
%%time
# ============================================================================
# Cell 10: Paper Figures 5, 6, 7
# ============================================================================
# Source: generate_paper_figures.py

# Critical parameter values
# Pitchfork point J* and coexistence existence limit, as quoted in the header.
# NOTE(review): hard-coded rather than read from the scans above -- confirm they agree.
J_CROSS_STAR = 0.3485
J_CROSS_EXIST = 0.358


# ── Fig 5: Coexistence Existence Boundary ────────────────────────────────────

def fig5_existence():
    """
    Fig 5: Coexistence Existence Boundary
    max(r_A) and max(r_B) vs J_cross, showing the sharp collapse
    at J_cross ~ 0.36.

    The coexistence fixed point is followed by continuation in J_cross. Each
    solve is warm-started from the last solution that converged
    (residual < 1e-4) with both bump peaks above 0.3. When a solve's residual
    exceeds 1e-4, a second solve from ``find_coexistence_fp``'s default guess
    is tried. Whichever of the two candidates has the smaller residual is kept,
    rather than the warm-start result being discarded unconditionally.

    Every sweep point whose solve does not raise is plotted, converged or not,
    so the collapse past the existence limit stays visible. The number of
    unconverged plotted points is reported at the end.
    """
    print("  Generating Fig 5: Coexistence Existence Boundary...")
    W, preferred = build_within_weights(N, J_0, J_1)

    # Sample densely across the pitchfork / existence window (0.30-0.37).
    jc_values = np.sort(np.unique(np.concatenate([
        np.linspace(0.01, 0.30, 15),
        np.linspace(0.30, 0.37, 40),
        np.linspace(0.37, 0.40, 10),
    ])))

    res_tol = 1e-4   # residual below which a solve counts as converged
    bump_min = 0.3   # both peaks must exceed this to seed the next solve

    max_rA_list, max_rB_list = [], []
    jc_valid = []
    n_unconverged = 0
    r_A_prev, r_B_prev = None, None

    for jc in jc_values:
        try:
            if r_A_prev is not None:
                r_A, r_B, res = find_coexistence_fp(W, preferred, jc,
                                                     r_A_prev, r_B_prev)
            else:
                r_A, r_B, res = find_coexistence_fp(W, preferred, jc)

            if res > res_tol:
                # Fall back to the default initial guess, but only accept it
                # if it actually improves on the residual we already have.
                r_A_alt, r_B_alt, res_alt = find_coexistence_fp(W, preferred, jc)
                if res_alt < res:
                    r_A, r_B, res = r_A_alt, r_B_alt, res_alt

            # Compute both peaks before appending so the three lists can never
            # get out of step if one of the reductions raises.
            peak_A, peak_B = np.max(r_A), np.max(r_B)
            max_rA_list.append(peak_A)
            max_rB_list.append(peak_B)
            jc_valid.append(jc)

            if res < res_tol:
                if peak_A > bump_min and peak_B > bump_min:
                    r_A_prev, r_B_prev = r_A.copy(), r_B.copy()
            else:
                n_unconverged += 1

            print(f"    J_cross={jc:.4f}: max(r_A)={peak_A:.3f}, "
                  f"max(r_B)={peak_B:.3f}, res={res:.2e}")
        except Exception as e:
            # Best effort: a failed solve just leaves a gap in the curve.
            print(f"    J_cross={jc:.4f}: FAILED ({e})")
            continue

    if n_unconverged:
        print(f"    NOTE: {n_unconverged} plotted point(s) have residual "
              f">= {res_tol:.0e} (not converged).")

    jc_valid = np.array(jc_valid)
    max_rA = np.array(max_rA_list)
    max_rB = np.array(max_rB_list)

    fig, ax = plt.subplots(1, 1, figsize=(5.5, 4.0))

    ax.plot(jc_valid, max_rA, '-o', color=COLOR_A, lw=2.0, ms=3,
            label=r'$\max(r^A)$', zorder=3)
    ax.plot(jc_valid, max_rB, '-s', color=COLOR_B, lw=2.0, ms=3,
            label=r'$\max(r^B)$', zorder=3)

    ax.axvline(x=J_CROSS_STAR, color='#8e44ad', ls='--', lw=1.5, alpha=0.7)
    ax.text(J_CROSS_STAR - 0.003, 0.05,
            r'$J_\times^* \approx 0.349$' + '\n(pitchfork)',
            ha='right', va='bottom', fontsize=8.5, color='#8e44ad',
            fontweight='bold')

    ax.axvline(x=J_CROSS_EXIST, color='#e67e22', ls='--', lw=1.5, alpha=0.7)
    ax.text(J_CROSS_EXIST + 0.003, 0.05,
            r'$J_\times^{exist} \approx 0.358$' + '\n(existence)',
            ha='left', va='bottom', fontsize=8.5, color='#e67e22',
            fontweight='bold')

    # Window between pitchfork and existence limit: coexistence is a saddle.
    ax.axvspan(J_CROSS_STAR, J_CROSS_EXIST, alpha=0.08, color='#e67e22',
               label='Coexistence saddle')

    ax.set_xlabel(r'Cross-inhibition strength $J_\times$', fontsize=11)
    ax.set_ylabel('Peak firing rate', fontsize=11)
    ax.set_title('Coexistence Existence Boundary', fontweight='bold', fontsize=12)
    ax.legend(fontsize=9, loc='center left')
    ax.set_xlim(0, 0.40)
    ax.set_ylim(-0.05, 1.05)

    plt.tight_layout()
    plt.show()
    print("    Fig 5 complete.")


# ── Fig 6: Pitchfork Bifurcation Diagram ─────────────────────────────────────

def fig6_pitchfork():
    """
    Fig 6: Pitchfork Bifurcation Diagram
    D = mean(r_A) - mean(r_B) vs J_cross.

    Coexistence branch: followed by continuation in J_cross. A point is kept
    only if its residual is < 1e-4 and both bump peaks exceed 0.2. Points with
    J_cross >= J_CROSS_STAR are drawn dashed (saddle), the rest solid.

    WTA branches: the A-dominant and B-dominant solves are attempted
    independently at every J_cross, so an exception on one branch does not
    discard the other. A point is kept when residual < 1e-3 and D has the
    branch's sign with |D| > 0.01.
    """
    print("  Generating Fig 6: Pitchfork Bifurcation Diagram...")
    W, preferred = build_within_weights(N, J_0, J_1)

    # Coexistence branch
    jc_coex = np.sort(np.unique(np.concatenate([
        np.linspace(0.01, 0.30, 12),
        np.linspace(0.30, 0.36, 30),
    ])))

    D_coex = []
    jc_coex_valid = []
    r_A_prev, r_B_prev = None, None

    for jc in jc_coex:
        try:
            if r_A_prev is not None:
                r_A, r_B, res = find_coexistence_fp(W, preferred, jc,
                                                     r_A_prev, r_B_prev)
            else:
                r_A, r_B, res = find_coexistence_fp(W, preferred, jc)

            if res < 1e-4 and np.max(r_A) > 0.2 and np.max(r_B) > 0.2:
                D = np.mean(r_A) - np.mean(r_B)
                D_coex.append(D)
                jc_coex_valid.append(jc)
                # Only well-converged two-bump states seed the next solve.
                r_A_prev, r_B_prev = r_A.copy(), r_B.copy()
                print(f"    Coex J_cross={jc:.4f}: D={D:.6f}, res={res:.2e}")
        except Exception as e:
            print(f"    Coex J_cross={jc:.4f}: exception ({e})")

    jc_coex_valid = np.array(jc_coex_valid)
    D_coex = np.array(D_coex)

    # WTA branches (only exist above J_cross*)
    jc_wta = np.sort(np.unique(np.concatenate([
        np.linspace(J_CROSS_STAR + 0.001, 0.36, 20),
        np.linspace(0.36, 0.50, 15),
    ])))

    # dominant network -> (expected sign of D, accepted J_cross, accepted D)
    wta = {'A': (1.0, [], []), 'B': (-1.0, [], [])}

    for jc in jc_wta:
        for dominant, (sign, jc_list, D_list) in wta.items():
            # Each branch gets its own try so a failure on one branch does not
            # skip the other branch at the same J_cross.
            try:
                r_A, r_B, res = find_wta_fp(W, preferred, jc, dominant=dominant)
                D = np.mean(r_A) - np.mean(r_B)
            except Exception as e:
                print(f"    WTA-{dominant} J_cross={jc:.4f}: exception ({e})")
                continue
            # sign * D > 0.01  <=>  D > 0.01 (A-dominant) / D < -0.01 (B-dominant)
            if res < 1e-3 and sign * D > 0.01:
                D_list.append(D)
                jc_list.append(jc)
                print(f"    WTA-{dominant} J_cross={jc:.4f}: D={D:.4f}, res={res:.2e}")

    jc_wta_A_valid = np.array(wta['A'][1])
    D_wta_A = np.array(wta['A'][2])
    jc_wta_B_valid = np.array(wta['B'][1])
    D_wta_B = np.array(wta['B'][2])

    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(5.5, 4.0))

    mask_stable = jc_coex_valid < J_CROSS_STAR
    mask_unstable = jc_coex_valid >= J_CROSS_STAR

    if np.any(mask_stable):
        ax.plot(jc_coex_valid[mask_stable], D_coex[mask_stable],
                '-', color=COLOR_STABLE, lw=2.5, label='Coexistence (stable)',
                zorder=3)
    if np.any(mask_unstable):
        ax.plot(jc_coex_valid[mask_unstable], D_coex[mask_unstable],
                '--', color=COLOR_UNSTABLE, lw=2.0, label='Coexistence (saddle)',
                zorder=3)

    if len(D_wta_A) > 0:
        ax.plot(jc_wta_A_valid, D_wta_A, '-', color=COLOR_STABLE, lw=2.5,
                label='WTA (stable)', zorder=3)
    if len(D_wta_B) > 0:
        # Same style as the A branch; unlabeled so the legend shows one entry.
        ax.plot(jc_wta_B_valid, D_wta_B, '-', color=COLOR_STABLE, lw=2.5,
                zorder=3)

    ax.plot(J_CROSS_STAR, 0, 'o', color='#8e44ad', ms=10, zorder=10,
            markeredgecolor='white', markeredgewidth=1.5)
    ax.annotate(r'$J_\times^* \approx 0.349$',
                xy=(J_CROSS_STAR, 0),
                xytext=(J_CROSS_STAR - 0.06, 0.12),
                fontsize=9.5, fontweight='bold', color='#8e44ad',
                arrowprops=dict(arrowstyle='->', color='#8e44ad', lw=1.5))

    ax.axhline(y=0, color='gray', ls=':', lw=0.8, alpha=0.5)
    ax.axvline(x=J_CROSS_EXIST, color='#e67e22', ls='--', lw=1.2, alpha=0.5)
    ax.text(J_CROSS_EXIST + 0.005, -0.22,
            r'$J_\times^{exist}$', fontsize=9, color='#e67e22',
            fontweight='bold')

    ax.set_xlabel(r'Cross-inhibition strength $J_\times$', fontsize=11)
    ax.set_ylabel(r'Dominance $D = \bar{r}^A - \bar{r}^B$', fontsize=11)
    ax.set_title('Pitchfork Bifurcation: Coexistence to WTA', fontweight='bold',
                 fontsize=12)
    ax.legend(fontsize=8.5, loc='upper left')
    ax.set_xlim(0, 0.50)

    plt.tight_layout()
    plt.show()
    print("    Fig 6 complete.")


# ── Fig 7: Eigenvector Comparison ────────────────────────────────────────────

def _plot_mode_panel(ax, vec, title, orient_dominance=False):
    """Plot the network-A (first N) and network-B (last N) halves of an eigenvector.

    The real part is plotted; the vector may be a complex array coming from
    ``np.linalg.eig``. Both halves are scaled so the largest |component| is 1.
    An eigenvector's sign is arbitrary. With ``orient_dominance=True`` the sign
    is chosen so that mean(v_A) >= mean(v_B) ("A up, B down").
    """
    vec = np.real(np.asarray(vec))
    v_A, v_B = vec[:N], vec[N:]
    if orient_dominance and np.mean(v_A) < np.mean(v_B):
        v_A, v_B = -v_A, -v_B
    norm = max(np.max(np.abs(v_A)), np.max(np.abs(v_B)))
    if norm > 0:
        v_A = v_A / norm
        v_B = v_B / norm

    neuron_idx = np.arange(N)
    ax.plot(neuron_idx, v_A, '-', color=COLOR_A, lw=2.0, label='Network A')
    ax.plot(neuron_idx, v_B, '-', color=COLOR_B, lw=2.0, label='Network B')
    ax.axhline(y=0, color='gray', ls=':', lw=0.8, alpha=0.5)
    ax.set_xlabel('Neuron index $i$', fontsize=10)
    ax.set_ylabel('Eigenvector component', fontsize=10)
    ax.set_title(title, fontsize=10, fontweight='bold')
    ax.legend(fontsize=8, loc='upper right')
    ax.set_xlim(0, N - 1)


def fig7_eigenvectors():
    """
    Fig 7: Eigenvector Comparison
    Left: A Goldstone eigenvector (sine-shaped, one network only)
    Right: The critical dominance eigenvector (spatially uniform/DC, anti-phase)

    The coexistence fixed point is computed at J_cross = 0.34. If it is
    unreliable (residual > 1e-4 or a bump peak < 0.3), J_cross = 0.33 is tried
    and a warning is printed if that is unreliable too. The Jacobian spectrum
    is split into Goldstone and genuine modes by ``classify_eigenvalues``.
    The Goldstone mode with the largest gold_A/gold_B projection is plotted
    on the left. The genuine mode with the largest eigenvalue is plotted on
    the right, sign-oriented so the panel's "A up, B down" annotation holds.
    Returns early (no figure) if either mode class is empty.
    """
    print("  Generating Fig 7: Eigenvector Comparison...")
    W, preferred = build_within_weights(N, J_0, J_1)

    def _unreliable(r_A, r_B, res):
        # Not converged, or one bump has effectively vanished.
        return res > 1e-4 or np.max(r_A) < 0.3 or np.max(r_B) < 0.3

    jc_target = 0.34

    r_A, r_B, res = find_coexistence_fp(W, preferred, jc_target)
    print(f"    J_cross={jc_target}: res={res:.2e}, "
          f"max(r_A)={np.max(r_A):.3f}, max(r_B)={np.max(r_B):.3f}")

    if _unreliable(r_A, r_B, res):
        print("    WARNING: Fixed point may not be reliable, trying J_cross=0.33...")
        jc_target = 0.33
        r_A, r_B, res = find_coexistence_fp(W, preferred, jc_target)
        print(f"    J_cross={jc_target}: res={res:.2e}")
        if _unreliable(r_A, r_B, res):
            print("    WARNING: Fixed point at J_cross=0.33 is also unreliable; "
                  "the eigenvectors below may not describe the coexistence state.")

    cue_0 = np.zeros(N)
    x = np.concatenate([r_A, r_B])
    J = jacobian_analytical(x, W, cue_0, jc_target)
    evals, evecs = np.linalg.eig(J)

    # Order modes by decreasing real part (most unstable first).
    idx = np.argsort(-evals.real)
    evals = evals[idx]
    evecs = evecs[:, idx]

    classified = classify_eigenvalues(evals, evecs, preferred)

    goldstone_modes = [c for c in classified if c['is_goldstone']]
    genuine_modes = [c for c in classified if not c['is_goldstone']]
    genuine_modes.sort(key=lambda c: -c['eigenvalue'])

    if not goldstone_modes:
        print("    WARNING: No Goldstone modes found!")
        return
    if not genuine_modes:
        print("    WARNING: No non-Goldstone (critical) modes found!")
        return

    goldstone_mode = max(goldstone_modes,
                         key=lambda c: max(c['projections']['gold_A'],
                                           c['projections']['gold_B']))
    critical_mode = genuine_modes[0]

    print(f"    Goldstone: lambda={goldstone_mode['eigenvalue']:.2e}, "
          f"char={goldstone_mode['character']}")
    print(f"    Critical:  lambda={critical_mode['eigenvalue']:.6f}, "
          f"char={critical_mode['character']}")

    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))

    # LEFT: Goldstone eigenvector
    _plot_mode_panel(
        axes[0], goldstone_mode['eigenvector'],
        'Goldstone Mode\n' + r'$\lambda \approx$'
        + f' {goldstone_mode["eigenvalue"]:.1e}  (rotation)')

    # RIGHT: Critical dominance eigenvector, oriented to match the annotation.
    ax = axes[1]
    _plot_mode_panel(
        ax, critical_mode['eigenvector'],
        'Critical Dominance Mode\n' + r'$\lambda_{dom} = $'
        + f'{critical_mode["eigenvalue"]:.4f}  (DC/uniform)',
        orient_dominance=True)

    ax.text(0.5, 0.02, 'Anti-phase: A up, B down\n(spatially uniform)',
            transform=ax.transAxes, ha='center', va='bottom',
            fontsize=8, color='#8e44ad', style='italic',
            bbox=dict(boxstyle='round,pad=0.3', facecolor=COLOR_BG,
                      edgecolor='#8e44ad', alpha=0.8))

    fig.suptitle(r'Eigenvector Comparison at $J_\times = $' + f'{jc_target}',
                 fontsize=12, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
    print("    Fig 7 complete.")


# ── Generate all three figures ───────────────────────────────────────────────
# Runs the three figure builders in order. Each one builds its own weight
# matrix, runs its own solver sweep, shows its figure and returns nothing;
# they keep only local state, so their order here is presentational.
print("=" * 70)
print("GENERATING PAPER FIGURES 5, 6, 7")
print("=" * 70)

print("\n[1/3] Fig 5: Coexistence Existence Boundary")
fig5_existence()

print("\n[2/3] Fig 6: Pitchfork Bifurcation Diagram")
fig6_pitchfork()

print("\n[3/3] Fig 7: Eigenvector Comparison")
fig7_eigenvectors()

print("\nAll paper figures generated.")

## Phase Diagram (Fig 8)

Publication-quality phase diagram in $(J_\times, \mathrm{drive\;strength})$ space showing
the swap error landscape from 128,000 stochastic simulation trials.

**Note:** This cell requires pre-computed data files (`agent_1.json` through `agent_4.json`)
from the 4-agent parameter sweep. If the data files are not found in `results/`, `data/`, or the
current directory, the cell prints a description of the expected files and how to provide them
instead of drawing the figure. The simulation is computationally expensive (128,000 trials across
a 2D parameter grid) and is best run as a batch job rather than interactively in this notebook.

In [None]:
%%time
# ============================================================================
# Cell 12: Phase Diagram (Fig 8) — Requires Pre-Computed Data
# ============================================================================
# Source: generate_fig8_phase_diagram.py
#
# This cell expects agent_1.json through agent_4.json in 'results/', 'data/',
# or the current directory. If running on Colab, upload these files first or
# skip this cell.

import json
import glob
import os

# ── Data Loading ─────────────────────────────────────────────────────────────

def load_all_results(results_dir='results'):
    """Load and merge all agent results into arrays.

    Reads every ``agent_*.json`` file in *results_dir* (in sorted filename
    order) and keeps the entries of its ``"results"`` list whose ``"status"``
    is ``"ok"``.

    Returns
    -------
    tuple of four 1-D ndarrays
        ``(J_cross, input_gain, swap_rate, correct_rate)``, aligned by index.
        All four are empty when no matching file or no ``"ok"`` entry exists.

    Entries without a ``"status"`` key and files without a ``"results"`` key
    are skipped instead of raising ``KeyError``.
    """
    all_j, all_d, all_swap, all_correct = [], [], [], []
    for path in sorted(glob.glob(os.path.join(results_dir, "agent_*.json"))):
        # JSON is UTF-8 by specification; don't rely on the locale's encoding.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        for r in data.get("results", []):
            if r.get("status") != "ok":
                continue
            all_j.append(r["J_cross"])
            all_d.append(r["input_gain"])
            all_swap.append(r["swap_rate"])
            all_correct.append(r["correct_rate"])
    return (np.array(all_j), np.array(all_d),
            np.array(all_swap), np.array(all_correct))


# ── Colormap ─────────────────────────────────────────────────────────────────

def make_colormap():
    """Return the 512-level ``'phase_swap'`` colormap used for swap-error rates.

    Eight (position, RGB) stops run from dark navy at 0, through a pale
    beige near 0.14 and yellow/orange, to dark red at 1. Half of the stops
    lie in [0, 0.14], concentrating colour changes at low swap rates.
    Each channel is piecewise linear with no discontinuities (the left and
    right values at every stop are equal).
    """
    stops = (
        (0.00, (0.07, 0.16, 0.28)),
        (0.04, (0.13, 0.32, 0.50)),
        (0.10, (0.50, 0.65, 0.72)),
        (0.14, (0.85, 0.84, 0.78)),
        (0.24, (0.92, 0.78, 0.40)),
        (0.38, (0.85, 0.48, 0.20)),
        (0.55, (0.70, 0.22, 0.15)),
        (1.00, (0.32, 0.07, 0.05)),
    )
    channel_names = ('red', 'green', 'blue')
    segment_data = {
        channel: [(position, rgb[k], rgb[k]) for position, rgb in stops]
        for k, channel in enumerate(channel_names)
    }
    return LinearSegmentedColormap('phase_swap', segment_data, N=512)


# ── Interpolation ────────────────────────────────────────────────────────────

def interpolate_swap(j_vals, d_vals, swap_vals, n_grid=300, *,
                     j_range=(0.18, 8.5), d_range=(0.8, 8.3),
                     smoothing=0.8, clip=(0, 55)):
    """RBF interpolation of swap rate onto a regular grid in LOG(J_cross) space.

    Both axes are mapped to [0, 1] before fitting: J_cross via log10 over
    *j_range*, drive linearly over *d_range*. The two inputs therefore carry
    comparable weight in the thin-plate-spline distance.

    Parameters
    ----------
    j_vals, d_vals, swap_vals : array_like
        Scattered samples: J_cross (must be > 0), drive, and swap rate.
    n_grid : int
        Number of grid points along each axis.
    j_range, d_range : (float, float)
        Grid extent in J_cross (log-spaced) and drive (linear).
    smoothing : float
        ``RBFInterpolator`` smoothing parameter.
    clip : (float, float) or None
        Output limits; ``None`` disables clipping.

    Returns
    -------
    J_mesh, D_mesh, swap_grid : ndarray, each of shape (n_grid, n_grid)
        Rows vary in drive, columns in J_cross.

    Raises
    ------
    ValueError
        If any J_cross value is not strictly positive (log10 undefined).
    """
    j_vals = np.asarray(j_vals, dtype=float)
    d_vals = np.asarray(d_vals, dtype=float)
    if np.any(j_vals <= 0):
        raise ValueError("J_cross values must be positive for log-space interpolation")

    log_j_min, log_j_max = np.log10(j_range[0]), np.log10(j_range[1])
    d_min, d_max = d_range

    def _normalize(log_j, d):
        # Map (log10 J, drive) onto the unit square used for the RBF fit.
        return ((log_j - log_j_min) / (log_j_max - log_j_min),
                (d - d_min) / (d_max - d_min))

    j_norm, d_norm = _normalize(np.log10(j_vals), d_vals)

    log_j_grid = np.linspace(log_j_min, log_j_max, n_grid)
    d_grid = np.linspace(d_min, d_max, n_grid)
    LOG_J_mesh, D_mesh = np.meshgrid(log_j_grid, d_grid)
    J_mesh = 10**LOG_J_mesh
    j_mesh_norm, d_mesh_norm = _normalize(LOG_J_mesh, D_mesh)

    coords = np.column_stack([j_norm, d_norm])
    query = np.column_stack([j_mesh_norm.ravel(), d_mesh_norm.ravel()])
    rbf = RBFInterpolator(coords, swap_vals, kernel='thin_plate_spline',
                          smoothing=smoothing)
    swap_grid = rbf(query).reshape(J_mesh.shape)
    if clip is not None:
        # The spline can over/undershoot between samples; keep rates plausible.
        swap_grid = np.clip(swap_grid, *clip)

    return J_mesh, D_mesh, swap_grid


# ── Figure Generation ────────────────────────────────────────────────────────

def _print_phase_data_summary(j_vals, d_vals, swap_vals):
    """Print the ranges of the loaded sweep and the lowest swap rate in the valley."""
    print(f"Loaded {len(j_vals)} data points")
    print(f"  J_cross range: [{j_vals.min():.3f}, {j_vals.max():.3f}]")
    print(f"  Drive range:   [{d_vals.min():.3f}, {d_vals.max():.3f}]")
    print(f"  Swap range:    [{swap_vals.min():.1f}%, {swap_vals.max():.1f}%]")

    # Valley region: 1 < J_cross < 2 at low drive (< 3.5).
    vmask = (j_vals > 1.0) & (j_vals < 2.0) & (d_vals < 3.5)
    if vmask.any():
        vs, vj, vd = swap_vals[vmask], j_vals[vmask], d_vals[vmask]
        mi = np.argmin(vs)
        print(f"  Valley minimum: {vs[mi]:.1f}% at J_cross={vj[mi]:.3f}, drive={vd[mi]:.2f}")


def _annotate_phase_regions(ax, stroke_dk, stroke_md):
    """Write the four regime labels and the two failure-mode annotations on ``ax``."""
    ax.text(0.215, 4.5, 'Coexistence',
            fontsize=11.5, color='#b0d8e8', fontweight='bold', fontstyle='italic',
            ha='center', va='center', rotation=90,
            path_effects=stroke_dk, zorder=20)

    ax.text(0.36, 6.0, 'WTA\nonset',
            fontsize=9, color='#ffe0b2', fontweight='bold', fontstyle='italic',
            ha='center', va='center',
            path_effects=stroke_dk, zorder=20)

    ax.text(1.35, 2.2, 'The Valley',
            fontsize=10.5, color='#b0d8e8', fontweight='bold', fontstyle='italic',
            ha='center', va='center',
            path_effects=stroke_dk, zorder=20)

    ax.text(4.5, 5.0, 'Overpowering',
            fontsize=11, color='#ffcdd2', fontweight='bold', fontstyle='italic',
            ha='center', va='center',
            path_effects=stroke_dk, zorder=20)

    # Alleman et al. annotations
    ax.annotate(
        'Selection failure\n(Alleman et al., 2024)',
        xy=(0.45, 3.0), xytext=(1.2, 7.2),
        fontsize=8.5, color='#e0e0e0', fontstyle='italic', ha='center',
        arrowprops=dict(arrowstyle='->', color='#cccccc', lw=1.0,
                        connectionstyle='arc3,rad=0.15'),
        path_effects=stroke_md, zorder=20,
    )

    ax.annotate(
        'Representation\nfailure',
        xy=(5.5, 5.0), xytext=(5.8, 7.0),
        fontsize=8.5, color='#e0e0e0', fontstyle='italic', ha='center',
        arrowprops=dict(arrowstyle='->', color='#cccccc', lw=1.0,
                        connectionstyle='arc3,rad=-0.1'),
        path_effects=stroke_md, zorder=20,
    )


def _draw_delta_j_bracket(ax, j_lo, j_hi, stroke_dk, y_bkt=1.2):
    """Draw a yellow bracket at height ``y_bkt`` spanning [j_lo, j_hi], labelled ΔJ."""
    for x in (j_lo, j_hi):
        ax.plot([x, x], [y_bkt - 0.15, y_bkt + 0.15],
                color='#ffeb3b', lw=1.3, zorder=25)
    ax.plot([j_lo, j_hi], [y_bkt, y_bkt],
            color='#ffeb3b', lw=1.3, zorder=25)
    ax.text(
        0.5 * (j_lo + j_hi), y_bkt - 0.25, r'$\Delta J \!\approx\! 0.01$',
        fontsize=7, color='#ffeb3b', ha='center', va='top',
        path_effects=stroke_dk, zorder=25,
    )


def generate_phase_diagram(j_vals, d_vals, swap_vals, trials_per_point=500):
    """Generate the publication-quality phase diagram.

    Parameters
    ----------
    j_vals, d_vals, swap_vals : ndarray
        Parallel 1-D arrays of J_cross, drive (input gain) and swap rate (%),
        one entry per parameter point, e.g. from ``load_all_results``.
    trials_per_point : int, optional
        Simulation trials behind each point. Used only in the subtitle, whose
        total-trial count is ``len(j_vals) * trials_per_point``, so it stays
        correct when some parameter points are missing from the data.
    """
    n_points = len(j_vals)
    _print_phase_data_summary(j_vals, d_vals, swap_vals)

    # Critical couplings (pitchfork, coexistence existence limit). Kept local so
    # this cell does not depend on the paper-figures cell having been run.
    j_star, j_exist = 0.3485, 0.358

    # Interpolation
    J_mesh, D_mesh, swap_grid = interpolate_swap(j_vals, d_vals, swap_vals)

    # Use phase-diagram-specific style overrides
    with plt.rc_context({
        'axes.spines.top': True,
        'axes.spines.right': True,
    }):
        fig, ax = plt.subplots(figsize=(7, 5))

        cmap = make_colormap()
        vmax = 52

        im = ax.pcolormesh(
            J_mesh, D_mesh, swap_grid,
            cmap=cmap, norm=Normalize(vmin=0, vmax=vmax),
            shading='gouraud', rasterized=True,
        )

        ax.set_xscale('log')

        # Contour lines (5% solid, others dashed)
        contour_levels = [5, 10, 20, 30, 40]
        cs = ax.contour(
            J_mesh, D_mesh, swap_grid,
            levels=contour_levels,
            colors='white',
            linewidths=[1.6, 0.6, 0.6, 0.6, 0.6],
            linestyles=['solid', 'dashed', 'dashed', 'dashed', 'dashed'],
            alpha=0.6,
        )
        fmt_dict = {5: '5%', 10: '10%', 20: '20%', 30: '30%', 40: '40%'}
        ax.clabel(cs, levels=[5, 20, 40], fontsize=7.5, fmt=fmt_dict,
                  colors='white', inline=True, inline_spacing=6)

        # Data points (subtle)
        ax.scatter(j_vals, d_vals, s=8, c='white', alpha=0.25,
                   edgecolors='none', zorder=5, marker='o')

        # Colorbar
        cbar = fig.colorbar(im, ax=ax, shrink=0.85, aspect=28, pad=0.015)
        cbar.set_label('Swap error rate (%)', fontsize=11, labelpad=10)
        cbar.ax.tick_params(labelsize=9)
        cbar.set_ticks([0, 5, 10, 20, 30, 40, 50])

        # Critical vertical lines
        ax.axvline(x=j_star, color='#00e5ff', ls='--', lw=1.5, alpha=0.9, zorder=10)
        ax.axvline(x=j_exist, color='#76ff03', ls=':', lw=1.4, alpha=0.85, zorder=10)

        # Text styling: dark outlines keep labels legible over any colour.
        stroke_dk = [pe.withStroke(linewidth=3.0, foreground='#111111')]
        stroke_md = [pe.withStroke(linewidth=2.2, foreground='#1a1a1a')]

        _annotate_phase_regions(ax, stroke_dk, stroke_md)

        # Legend
        legend_elements = [
            Line2D([0], [0], color='#00e5ff', ls='--', lw=1.5,
                   label=r'$J_{\mathrm{cross}}^{\,*} \approx 0.349$  (pitchfork)'),
            Line2D([0], [0], color='#76ff03', ls=':', lw=1.4,
                   label=r'$J_{\mathrm{cross}}^{\,\mathrm{exist}} \approx 0.358$  (existence limit)'),
        ]
        legend = ax.legend(
            handles=legend_elements,
            loc='lower right', fontsize=8,
            framealpha=0.92, facecolor=COLOR_BG,
            edgecolor='#888888', borderpad=0.5,
            handlelength=2.5,
        )
        legend.get_frame().set_linewidth(0.5)

        _draw_delta_j_bracket(ax, j_star, j_exist, stroke_dk)

        # Axes
        ax.set_xlabel(r'Cross-inhibition strength  $J_{\mathrm{cross}}$', fontsize=12)
        ax.set_ylabel('Drive strength  (input gain)', fontsize=12)
        ax.set_xlim(0.18, 8.5)
        ax.set_ylim(0.85, 8.2)

        # Plain-number ticks on the log axis, no minor ticks.
        ax.set_xticks([0.2, 0.3, 0.5, 1, 2, 3, 5, 8])
        ax.get_xaxis().set_major_formatter(ticker.ScalarFormatter())
        ax.get_xaxis().set_minor_formatter(ticker.NullFormatter())
        ax.xaxis.set_minor_locator(ticker.NullLocator())
        ax.yaxis.set_minor_locator(ticker.AutoMinorLocator(2))

        # Title
        ax.set_title(
            'Phase diagram of swap errors in coupled ring attractors',
            fontsize=12.5, fontweight='bold', pad=18,
        )
        ax.text(
            0.5, 1.005,
            f'$N\\!=\\!48$ neurons/ring  |  {trials_per_point} trials/point  |  '
            f'{n_points} parameter points  |  '
            f'{n_points * trials_per_point:,} total trials',
            fontsize=7.5, color='#777777', ha='center', va='bottom',
            transform=ax.transAxes,
        )

        plt.tight_layout()
        plt.show()
    print("Phase diagram generated.")


# ── Try to load data and generate ────────────────────────────────────────────
# Only data discovery/loading is guarded. Errors raised while *plotting*
# propagate instead of being misreported as "data not found".
try:
    # Check multiple possible locations for the data files. Any agent_*.json
    # counts (not just agent_1.json), matching what load_all_results reads.
    results_dir = None
    for candidate in ['results', 'data', '.']:
        if glob.glob(os.path.join(candidate, 'agent_*.json')):
            results_dir = candidate
            break

    if results_dir is None:
        raise FileNotFoundError("No agent_*.json files found")

    j_vals, d_vals, swap_vals, correct_vals = load_all_results(results_dir)
    if len(j_vals) == 0:
        raise ValueError("No valid data points loaded")

except (FileNotFoundError, ValueError) as e:
    print(f"Phase diagram data not found: {e}")
    print()
    print("To generate the phase diagram, you need the pre-computed simulation data.")
    print("The data consists of 4 JSON files from parallel stochastic simulations:")
    print()
    print("  agent_1.json: J=[0.2, 0.5]  x drive=[1.0, 3.5]  (8x8 grid, 500 trials each)")
    print("  agent_2.json: J=[0.2, 0.5]  x drive=[3.5, 8.0]  (8x8 grid, 500 trials each)")
    print("  agent_3.json: J=[0.5, 3.0]  x drive=[1.0, 3.5]  (8x8 grid, 500 trials each)")
    print("  agent_4.json: J=[3.0, 8.0]  x drive=[3.5, 8.0]  (8x8 grid, 500 trials each)")
    print()
    print("Total: 256 parameter points x 500 trials = 128,000 stochastic simulations")
    print()
    print("If you have these files, upload them to a 'results/' directory in Colab")
    print("and re-run this cell.")
    print()
    print("To regenerate from scratch, you would need to run the full stochastic")
    print("simulation sweep, which takes approximately 2-4 hours on a single CPU.")
else:
    generate_phase_diagram(j_vals, d_vals, swap_vals)