# Training the Regulatory Network

Train a neural network to act as the regulatory function $f(\bar{s})$ using **evolutionary optimization**. The goal is to maximize the utility $U = S_{pat} - S_{rep}$, i.e. to discover patterns with high information content ($S_{pat}$) and low variability across replicates ($S_{rep}$).

**Expected outcome**: the network should converge to a tanh-like lateral-inhibition function, producing alternating on–off patterns.

## Setup & Imports

In [None]:
import sys
from importlib import reload
from pathlib import Path

sys.path.insert(0, '../src')

import jax
import jax.numpy as jnp
from jax import random, jit
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from typing import Tuple, Dict

# Import our modules
from neural_network import RegulatoryNetwork, init_params, get_regulatory_function
from dynamics import run_multiple_replicates, apply_threshold
from utility_function import compute_soft_utility, compute_hard_utility

# Plotting defaults shared by every figure in this notebook
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 5), 'figure.dpi': 100})

# Output directory for figures and saved artifacts (relative to the notebook's cwd)
Path('../figures').mkdir(exist_ok=True)

# Report the JAX runtime so results can be tied to a specific backend
for label, value in (("version", jax.__version__),
                     ("backend", jax.default_backend()),
                     ("devices", jax.devices())):
    print(f"JAX {label}: {value}")
print("✓ Imports successful")

## Simulation & Training Parameters

In [None]:
# System parameters
N_CELLS = 7                # Number of cells in the 1D system
N_REPLICATES = 50          # Replicates per fitness evaluation (more = less noisy fitness estimate)
DT = 0.01                  # Integration time step
T = 20.0                   # Total simulated time
# round() rather than int(): T / DT can land just below an integer in floating point
# (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would silently drop a step.
N_STEPS = round(T / DT)    # Number of integration steps
NOISE_STRENGTH = 0.1       # Stochastic noise level

# Neural network architecture
HIDDEN_DIMS = (8, 8)       # Two hidden layers with 8 neurons each

# Evolution parameters
POPULATION_SIZE = 20       # Number of individuals per generation
N_GENERATIONS = 100        # Total training iterations
MUTATION_STD = 0.1         # Standard deviation of the Gaussian parameter mutations
ELITE_FRACTION = 0.2       # Keep top 20% as parents

# Utility function parameters
SOFT_BANDWIDTH = 0.1       # KDE bandwidth for soft utility

# Random seed
SEED = 42
key = random.PRNGKey(SEED)

print(f"System: {N_CELLS} cells, {N_REPLICATES} replicates")
print(f"Dynamics: T={T}s, dt={DT}, noise={NOISE_STRENGTH}")
print(f"Network: hidden_dims={HIDDEN_DIMS}")
print(f"Evolution: pop={POPULATION_SIZE}, gen={N_GENERATIONS}, σ_mut={MUTATION_STD}")

## Evolution Strategy Implementation

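Simple **(μ+λ)-ES** with Gaussian mutations:
1. Evaluate the fitness (utility) of every individual, with fresh simulation noise each generation
2. Select the top performers as elite parents
3. Fill the rest of the population with Gaussian-mutated copies of randomly chosen parents (the parents themselves survive unchanged)
4. Repeat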

In [None]:
def flatten_params(params: Dict) -> jnp.ndarray:
    """Concatenate every leaf of a parameter pytree into a single 1D vector.

    Leaves are visited in ``jax.tree_util.tree_leaves`` order and each one is
    flattened before concatenation, so the result has one entry per parameter.
    """
    leaves = jax.tree_util.tree_leaves(params)
    return jnp.concatenate([leaf.flatten() for leaf in leaves])

def unflatten_params(flat: jnp.ndarray, template: Dict) -> Dict:
    """Rebuild a parameter pytree from a flat vector, using ``template`` for structure.

    Leaves are filled in ``jax.tree_util.tree_leaves`` order: each template leaf
    consumes as many consecutive entries of ``flat`` as it has elements, reshaped
    to that leaf's shape.

    Args:
        flat: 1D array whose length equals the total number of elements in ``template``.
        template: Pytree giving the target structure and leaf shapes (values are ignored).

    Returns:
        Pytree with the same structure as ``template`` holding slices of ``flat``.

    Raises:
        ValueError: if ``flat`` does not contain exactly as many elements as ``template``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(template)
    # np.shape also handles Python scalars; int(...) matters because np.prod(()) is the
    # float 1.0, which would otherwise produce float slice bounds for scalar leaves.
    shapes = [np.shape(leaf) for leaf in leaves]
    sizes = [int(np.prod(shape)) for shape in shapes]

    expected = sum(sizes)
    if flat.size != expected:
        raise ValueError(
            f"flat has {flat.size} elements but template requires {expected}"
        )

    arrays = []
    idx = 0
    for shape, size in zip(shapes, sizes):
        arrays.append(flat[idx:idx + size].reshape(shape))
        idx += size

    return jax.tree_util.tree_unflatten(treedef, arrays)

# Round-trip sanity check of the flatten/unflatten helpers on a freshly initialised network
model = RegulatoryNetwork(hidden_dims=HIDDEN_DIMS)
key, subkey = random.split(key)
test_params = init_params(model, subkey, (1,))
flat = flatten_params(test_params)
reconstructed = unflatten_params(flat, test_params)

# Actually verify the round trip before reporting success: same tree, identical leaves
same_structure = (jax.tree_util.tree_structure(reconstructed)
                  == jax.tree_util.tree_structure(test_params))
same_values = all(
    bool(jnp.array_equal(a, b))
    for a, b in zip(jax.tree_util.tree_leaves(reconstructed),
                    jax.tree_util.tree_leaves(test_params))
)
if not (same_structure and same_values):
    raise RuntimeError("flatten_params/unflatten_params round trip did not reproduce the parameters")

print("✓ Parameter flattening working")
print(f"  Total parameters: {len(flat)}")
print(f"  Structure: {jax.tree_util.tree_map(lambda x: x.shape, test_params)}")

In [None]:
def compute_fitness(params: Dict, model: RegulatoryNetwork, eval_key: jax.random.PRNGKey) -> float:
    """Evaluate the utility (fitness) of one set of network parameters.

    Builds the regulatory function from ``params``, simulates ``N_REPLICATES``
    noisy replicates of the ``N_CELLS`` system, thresholds the final states into
    patterns and scores them with the soft utility. Higher utility = better fitness.
    The soft utility is used even though the ES itself needs no gradients.

    Args:
        params: Network parameter pytree for ``model``.
        model: Network object; treated as a static (compile-time) jit argument.
        eval_key: PRNG key driving the simulation noise.

    Returns:
        Scalar soft utility.
    """
    # Regulatory function defined by these parameters
    f = get_regulatory_function(model, params)

    # Simulate independent noisy replicates
    final_states = run_multiple_replicates(
        f=f,
        n_cells=N_CELLS,
        n_replicates=N_REPLICATES,
        n_steps=N_STEPS,
        dt=DT,
        noise_strength=NOISE_STRENGTH,
        key=eval_key
    )

    # Binarise final states into patterns
    patterns = apply_threshold(final_states)

    # Only the total utility is the fitness; the entropy components are discarded
    utility, _s_pat, _s_rep = compute_soft_utility(patterns, bandwidth=SOFT_BANDWIDTH)

    return utility


# `model` is an object, not an array: as an ordinary jit argument it would be traced,
# which raises a TypeError unless its class is registered as a pytree. Marking it static
# compiles one version per distinct model instead.
# NOTE(review): static jit arguments must be hashable — assumes RegulatoryNetwork
# instances are (Flax linen modules are); confirm.
compute_fitness = jit(compute_fitness, static_argnames=("model",))

# Smoke-test the fitness function on the randomly initialised test parameters
key, subkey = random.split(key)
test_fitness = compute_fitness(test_params, model, subkey)
print("✓ Fitness function working",
      f"  Test fitness (random params): {test_fitness:.4f}",
      sep="\n")

In [None]:
def evaluate_population(population: jnp.ndarray, model: RegulatoryNetwork, 
                       template: Dict, eval_key: jax.random.PRNGKey) -> jnp.ndarray:
    """Score every individual in ``population`` and return the fitness vector.

    Each row of ``population`` is a flat parameter vector. It is reshaped with
    ``template`` and evaluated with its own PRNG key split from ``eval_key``.
    """
    member_keys = random.split(eval_key, len(population))
    scores = [
        compute_fitness(unflatten_params(genome, template), model, member_key)
        for genome, member_key in zip(population, member_keys)
    ]
    return jnp.array(scores)

def select_parents(population: jnp.ndarray, fitnesses: jnp.ndarray, 
                  n_parents: int) -> jnp.ndarray:
    """Select the ``n_parents`` fittest individuals as parents.

    Args:
        population: Array of individuals, one per row.
        fitnesses: Fitness per individual (higher is better). NaN fitness values
            are ranked as the worst possible score.
        n_parents: Number of parents to keep; values <= 0 select nobody.

    Returns:
        The selected rows of ``population``, ordered from the weakest to the
        strongest selected individual.
    """
    if n_parents <= 0:
        # [-0:] would select the whole population, so handle this case explicitly
        return population[:0]
    # jnp.argsort places NaN last, i.e. it would rank a diverged (NaN) evaluation as the
    # best individual; treat NaN as -inf so such individuals are never chosen as elites.
    ranking_scores = jnp.where(jnp.isnan(fitnesses), -jnp.inf, fitnesses)
    parent_indices = jnp.argsort(ranking_scores)[-n_parents:]  # Top n_parents
    return population[parent_indices]

def mutate_population(parents: jnp.ndarray, population_size: int, 
                     mutation_std: float, mut_key: jax.random.PRNGKey) -> jnp.ndarray:
    """Build the next population: the parents plus Gaussian-mutated offspring.

    Each child copies a uniformly chosen parent and adds i.i.d. Gaussian noise
    with standard deviation ``mutation_std``. Parents are kept unchanged in the
    first ``len(parents)`` rows.

    Args:
        parents: Parent parameter vectors, one per row.
        population_size: Total number of rows in the returned population.
        mutation_std: Standard deviation of the additive Gaussian mutation.
        mut_key: PRNG key for parent choice and mutation noise.

    Returns:
        Array of shape ``(population_size, parents.shape[1])``.

    Raises:
        ValueError: if ``population_size`` is smaller than the number of parents,
            or children are requested but there are no parents.
    """
    n_parents = len(parents)
    n_children = population_size - n_parents
    if n_children < 0:
        raise ValueError(
            f"population_size ({population_size}) is smaller than the number of parents ({n_parents})"
        )
    if n_children > 0 and n_parents == 0:
        raise ValueError("cannot create offspring without at least one parent")

    # Split once so parent choice and noise draw from independent subkeys; consuming
    # mut_key directly and then splitting it as well would reuse the key.
    choice_key, noise_key = random.split(mut_key)

    # Random parent selection for each child
    parent_indices = random.randint(choice_key, (n_children,), 0, n_parents)

    # Additive Gaussian mutations
    noise = random.normal(noise_key, (n_children, parents.shape[1])) * mutation_std
    children = parents[parent_indices] + noise

    # Elites survive unchanged, followed by their offspring
    return jnp.concatenate([parents, children], axis=0)

# Status line: reached only once the ES helper definitions in this cell have run
print("✓ Evolution functions defined")

## Training Loop

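Evolve the population over `N_GENERATIONS` generations, tracking the best and mean fitness. Every individual, including the surviving elites, is re-evaluated with fresh simulation noise each generation. The "best fitness" is therefore a noisy, optimistically biased estimate.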

In [None]:
# Initialize population (random parameters)
print("Initializing population...")
# The template only supplies the parameter tree structure/shapes; its values are unused
param_template = init_params(model, random.PRNGKey(0), (1,))
n_params = len(flatten_params(param_template))
n_parents = max(2, int(POPULATION_SIZE * ELITE_FRACTION))

key, subkey = random.split(key)
population_keys = random.split(subkey, POPULATION_SIZE)
population = jnp.stack([flatten_params(init_params(model, k, (1,)))
                        for k in population_keys])

print(f"Population: {POPULATION_SIZE} individuals, {n_params} params each")
print(f"Parents per generation: {n_parents}\n")

# Training history
best_fitness_history = []
mean_fitness_history = []
best_params = None
best_fitness = float("-inf")  # plain float so comparisons and the saved value have one type

print("Starting training...\n")
print("Gen | Best Fit | Mean Fit | Std Fit")
print("----+----------+----------+---------")

for generation in range(N_GENERATIONS):
    # Evaluate fitness (fresh noise every generation, so elites are re-estimated)
    key, eval_key, mut_key = random.split(key, 3)
    fitnesses = evaluate_population(population, model, param_template, eval_key)

    # NaN-aware statistics: a single diverged simulation (NaN utility) would otherwise
    # turn max/mean/std into NaN and skip best-so-far tracking for this generation.
    gen_best_fitness = float(jnp.nanmax(fitnesses))
    gen_mean_fitness = float(jnp.nanmean(fitnesses))
    gen_std_fitness = float(jnp.nanstd(fitnesses))

    best_fitness_history.append(gen_best_fitness)
    mean_fitness_history.append(gen_mean_fitness)

    # NaN > x is False, so an all-NaN generation never replaces the best-so-far
    if gen_best_fitness > best_fitness:
        best_fitness = gen_best_fitness
        best_idx = jnp.nanargmax(fitnesses)
        best_params = population[best_idx]

    # Print progress
    if generation % 10 == 0 or generation == N_GENERATIONS - 1:
        print(f"{generation:3d} | {gen_best_fitness:+.4f} | {gen_mean_fitness:+.4f} | {gen_std_fitness:.4f}")

    # Selection and mutation
    if generation < N_GENERATIONS - 1:  # Don't mutate after last generation
        parents = select_parents(population, fitnesses, n_parents)
        population = mutate_population(parents, POPULATION_SIZE, MUTATION_STD, mut_key)

# Fail here with a clear message rather than in the plotting cell that unflattens best_params
if best_params is None:
    raise RuntimeError("No finite fitness was observed during training; no best parameters to report.")

print("\n✓ Training complete!")
print(f"  Best fitness: {best_fitness:.4f}")

## Training Results

Visualize learning progress and best network function.

In [None]:
fig, (ax_curve, ax_func) = plt.subplots(1, 2, figsize=(14, 5))

# --- Left panel: best and mean fitness per generation ---
ax_curve.plot(best_fitness_history, label='Best fitness', linewidth=2, color='darkblue')
ax_curve.plot(mean_fitness_history, label='Mean fitness', linewidth=1.5,
              color='steelblue', alpha=0.7, linestyle='--')
ax_curve.set(xlabel='Generation', ylabel='Utility (fitness)', title='Evolution Progress')
ax_curve.legend()
ax_curve.grid(True, alpha=0.3)

# --- Right panel: learned regulatory function vs. a tanh reference ---
s_bar_range = jnp.linspace(0, 1, 200)

# Evaluate the best network found during training on a grid of neighbor averages
best_params_dict = unflatten_params(best_params, param_template)
f_trained = get_regulatory_function(model, best_params_dict)
f_values = f_trained(s_bar_range)

# Reference curve: decreasing tanh centred at s̄ = 0.5 (lateral inhibition)
tanh_strength = 5.0
f_target = jnp.tanh(-tanh_strength * (s_bar_range - 0.5))

ax_func.plot(s_bar_range, f_values, label='Trained NN', linewidth=2.5, color='darkred')
ax_func.plot(s_bar_range, f_target, label=f'Target (tanh, α={tanh_strength})',
             linewidth=2, color='green', linestyle='--', alpha=0.7)
ax_func.axhline(0, color='k', linewidth=0.5, alpha=0.3)
ax_func.axvline(0.5, color='k', linewidth=0.5, alpha=0.3)
ax_func.set(xlabel='Neighbor average $\\bar{s}$',
            ylabel='$ds/dt = f(\\bar{s})$',
            title='Learned Regulatory Function')
ax_func.legend()
ax_func.grid(True, alpha=0.3)
ax_func.set_xlim(0, 1)

plt.tight_layout()
plt.savefig('../figures/training_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Results plotted")

## Generated Patterns

Test trained network: simulate multiple replicates and examine resulting patterns.

In [None]:
# Simulate fresh replicates with the trained regulatory function
key, test_key = random.split(key)
test_replicates = 30

final_states = run_multiple_replicates(
    f=f_trained, n_cells=N_CELLS, n_replicates=test_replicates,
    n_steps=N_STEPS, dt=DT, noise_strength=NOISE_STRENGTH, key=test_key,
)
patterns = apply_threshold(final_states)

# Score the resulting patterns with both the soft (KDE-based) and the hard utility
u_soft, s_pat_soft, s_rep_soft = compute_soft_utility(patterns, bandwidth=SOFT_BANDWIDTH)
u_hard, s_pat_hard, s_rep_hard = compute_hard_utility(patterns)

n_unique = len(jnp.unique(patterns, axis=0))

print("Trained Network Performance:")
print(f"  Soft utility: U = {u_soft:.4f} (S_pat={s_pat_soft:.4f}, S_rep={s_rep_soft:.4f})")
print(f"  Hard utility: U = {u_hard:.4f} (S_pat={s_pat_hard:.4f}, S_rep={s_rep_hard:.4f})")
print("\nPattern statistics:")
print(f"  Mean fate 1 ratio: {jnp.mean(patterns):.3f}")
print(f"  Pattern diversity: {n_unique} unique / {test_replicates} total")

In [None]:
# Visualize patterns
fig, ax = plt.subplots(figsize=(10, 8))

# Heatmap: one row per replicate, one column per cell
im = ax.imshow(patterns, cmap='RdBu_r', aspect='auto', interpolation='nearest')
ax.set(xlabel='Cell index', ylabel='Replicate',
       title=f'Trained Network Patterns (U={u_hard:.3f})')

# Colorbar restricted to the two fate values
cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Cell fate')
cbar.set_ticks([0, 1])

plt.tight_layout()
plt.savefig('../figures/trained_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

# Text dump of the first few patterns as 0/1 strings
print("\nFirst 10 patterns:")
for i, row in enumerate(patterns[:min(10, test_replicates)]):
    bits = ''.join(str(int(x)) for x in row)
    print(f"  {i:2d}: {bits}")

## Comparison: Random vs Trained

Compare the trained network against a single randomly initialized network, using the same simulation settings.

In [None]:
# Baseline: one randomly initialised network (a single draw, so this comparison is noisy)
key, random_key = random.split(key)
random_params = init_params(model, random_key, (1,))
f_random = get_regulatory_function(model, random_params)

# Generate patterns with the random network under the same simulation settings
key, test_key = random.split(key)
random_states = run_multiple_replicates(
    f=f_random,
    n_cells=N_CELLS,
    n_replicates=test_replicates,
    n_steps=N_STEPS,
    dt=DT,
    noise_strength=NOISE_STRENGTH,
    key=test_key
)
random_patterns = apply_threshold(random_states)
u_random, s_pat_random, s_rep_random = compute_hard_utility(random_patterns)

# Comparison plot
fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Random function
ax = axes[0]
f_random_vals = f_random(s_bar_range)
ax.plot(s_bar_range, f_random_vals, label='Random NN', linewidth=2, color='gray')
ax.plot(s_bar_range, f_target, label='Target (tanh)', linewidth=1.5, 
        color='green', linestyle='--', alpha=0.7)
ax.axhline(0, color='k', linewidth=0.5, alpha=0.3)
ax.set_xlabel('$\\bar{s}$')
ax.set_ylabel('$f(\\bar{s})$')
ax.set_title('Random Network')
ax.legend()
ax.grid(True, alpha=0.3)

# Trained function
ax = axes[1]
ax.plot(s_bar_range, f_values, label='Trained NN', linewidth=2, color='darkred')
ax.plot(s_bar_range, f_target, label='Target (tanh)', linewidth=1.5, 
        color='green', linestyle='--', alpha=0.7)
ax.axhline(0, color='k', linewidth=0.5, alpha=0.3)
ax.set_xlabel('$\\bar{s}$')
ax.set_ylabel('$f(\\bar{s})$')
ax.set_title('Trained Network')
ax.legend()
ax.grid(True, alpha=0.3)

# Utility comparison (hard utility for both networks)
ax = axes[2]
categories = ['Random', 'Trained']
utilities = [float(u_random), float(u_hard)]
colors = ['gray', 'darkred']

bars = ax.bar(categories, utilities, color=colors, alpha=0.7, edgecolor='black', linewidth=1.5)
ax.set_ylabel('Utility')
ax.set_title('Performance Comparison')
ax.grid(True, alpha=0.3, axis='y')

# Add values at the bar ends: above positive bars, below negative ones
for bar, util in zip(bars, utilities):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
            f'{util:.3f}', ha='center', va='bottom' if height >= 0 else 'top',
            fontsize=11, fontweight='bold')

plt.tight_layout()
plt.savefig('../figures/random_vs_trained.png', dpi=150, bbox_inches='tight')
plt.show()

gain = float(u_hard - u_random)
baseline = float(u_random)
if baseline != 0:
    print(f"\nPerformance gain: {gain:.4f} ({100 * gain / abs(baseline):.1f}%)")
else:
    # A random net that always produces the same pattern has utility 0, so a relative gain is undefined
    print(f"\nPerformance gain: {gain:.4f} (relative gain undefined: random utility is 0)")

## Save Trained Parameters

Save best network for later use.

In [None]:
import pickle

# Copy parameters to host NumPy arrays so the pickle does not hold JAX device buffers
params_host = jax.device_get(best_params_dict)

save_dict = {
    'params': params_host,
    'fitness': float(best_fitness),
    'hidden_dims': HIDDEN_DIMS,
    'training_config': {
        'n_cells': N_CELLS,
        'n_replicates': N_REPLICATES,
        'population_size': POPULATION_SIZE,
        'n_generations': N_GENERATIONS,
        'mutation_std': MUTATION_STD,
        # Remaining settings needed to reproduce the fitness evaluations
        'dt': DT,
        'T': T,
        'n_steps': N_STEPS,
        'noise_strength': NOISE_STRENGTH,
        'elite_fraction': ELITE_FRACTION,
        'soft_bandwidth': SOFT_BANDWIDTH,
        'seed': SEED,
    },
    'fitness_history': {
        'best': best_fitness_history,
        'mean': mean_fitness_history
    }
}

save_path = '../figures/trained_network.pkl'
# `fh`, not `f`: `f` is the regulatory-function name used throughout this notebook
with open(save_path, 'wb') as fh:
    pickle.dump(save_dict, fh)

print(f"✓ Trained parameters saved to {save_path}")

## Summary

**Training approach**: Simple evolution strategy (μ+λ)-ES with Gaussian mutations

**Key findings**:
- The trained network should converge toward a tanh-like (lateral-inhibition) function
- Higher utility indicates better pattern formation (more information, less noise)
- Expect emergence of alternating on-off patterns (lateral inhibition)

**Next steps**: Notebook 05 for detailed analysis and visualization