In [1]:
import sys
import os
import time

import jax # type: ignore
import jax.numpy as jnp # type: ignore
from jax import random, jit # type: ignore
import jax.random as jar # type: ignore

# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')
from typing import Callable, Dict, Tuple, Optional
import jax # type: ignore
import jax.numpy as jnp # type: ignore
import jax.random as jar # type: ignore
from jax.random import PRNGKey # type: ignore
from jax import Array # type: ignore
import flax.struct # type: ignore
import functools

from malthusjax.core.genome import AbstractGenome
from malthusjax.core.fitness import AbstractFitnessEvaluator
from malthusjax.operators.selection import AbstractSelectionOperator
from malthusjax.operators.crossover import AbstractCrossover
from malthusjax.operators.mutation import AbstractMutation
from malthusjax.engine.base import AbstractEngine, AbstractState
from malthusjax.engine.ResearchEngine import ResearchEngine, ResearchState, FullIntermediateState


In [2]:
# Import MalthusJAX components
from malthusjax.core.genome.binary import BinaryGenome, BinaryGenomeConfig
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.operators.selection.tournament import TournamentSelection
from malthusjax.operators.crossover.binary import UniformCrossover
from malthusjax.operators.mutation.binary import BitFlipMutation

print("✅ MalthusJAX components imported successfully")

# --- GA configuration ---------------------------------------------------------
# Each generation keeps the `elitism` best individuals unchanged and fills the
# remaining `pop_size - elitism` slots with offspring.
pop_size = 100
elitism = 5
num_offspring = pop_size - elitism

# --- Components -----------------------------------------------------------------
# NOTE(review): `BinaryGenomeConfig` is imported above, but a `BinaryGenome`
# instance is what gets passed as the genome representation here; confirm
# which of the two ResearchEngine expects.
genome_config = BinaryGenome(array_shape=(100,), p=0.2)
fitness_evaluator = BinarySumFitnessEvaluator()
# One selected parent index per offspring slot.
selector = TournamentSelection(number_of_choices=num_offspring, tournament_size=3)
crossover = UniformCrossover(crossover_rate=0.8, n_outputs=1)
mutator = BitFlipMutation(mutation_rate=0.05)

# Keyword arguments for ResearchEngine(**components).
components = {
    'genome_representation': genome_config,
    'fitness_evaluator': fitness_evaluator,
    'selection_operator': selector,
    'crossover_operator': crossover,
    'mutation_operator': mutator,
    'elitism': elitism
}

print("🧬 GA Configuration:")
print(f"  Problem: Maximize sum of {genome_config.array_shape[0]} bits")
print(f"  Population size: {pop_size}")
print(f"  Elitism: {elitism}")
print(f"  Selection: Tournament (size={selector.tournament_size})")
print(f"  Crossover: Uniform (rate={crossover.crossover_rate})")
print(f"  Mutation: Bit flip (rate={mutator.mutation_rate})")

✅ MalthusJAX components imported successfully
🧬 GA Configuration:
  Problem: Maximize sum of 100 bits
  Population size: 100
  Elitism: 5
  Selection: Tournament (size=3)
  Crossover: Uniform (rate=0.8)
  Mutation: Bit flip (rate=0.05)


In [3]:
# Baseline: run the ResearchEngine for 200 generations from PRNG seed 10.
research_engine = ResearchEngine(**components)
research_final_state, research_intermediates = research_engine.run(
    key=jar.PRNGKey(10),
    num_generations=200,
    pop_size=pop_size,
)

# Expected: malthusjax.engine.ResearchEngine.FullIntermediateState
print(type(research_intermediates))
# Expected: malthusjax.engine.ResearchEngine.ResearchState
print(type(research_final_state))

<class 'malthusjax.engine.ResearchEngine.FullIntermediateState'>
<class 'malthusjax.engine.ResearchEngine.ResearchState'>


In [4]:
def _research_step_fn(state: ResearchState,
                      _: None,  # unused scan input
                      fitness_fn: Callable,
                      selection_fn: Callable,
                      crossover_fn: Callable,
                      mutation_fn: Callable,
                      pop_size: int,
                      elitism: int) -> Tuple[ResearchState, FullIntermediateState]:
    """Run one JIT-compatible research GA generation and capture every stage.

    Shaped as a ``jax.lax.scan`` body: ``state`` is the carry and the second
    argument is the (ignored) per-step scan input. Every array shape depends
    only on ``pop_size`` and ``elitism``, so both must be static Python ints.

    Args:
        state: Current GA state (population, fitness, PRNG key, generation
            counter and metrics).
        _: Unused scan input.
        fitness_fn: Scores one genome; it is vmapped over the new population.
        selection_fn: ``(key, fitness) -> indices``. Only the first
            ``pop_size - elitism`` indices of each call are used.
        crossover_fn: ``(key, parent_1, parent_2) -> offspring``, vmapped over
            offspring slots. If it returns an extra axis directly after the
            batch axis (e.g. ``(num_offspring, 1, *genome_shape)`` for
            ``n_outputs=1``), only the first child is kept.
        mutation_fn: ``(key, genome) -> genome``, vmapped over offspring.
        pop_size: Total population size.
        elitism: Number of best individuals copied unchanged into the next
            generation (0 is allowed).

    Returns:
        ``(new_state, full_intermediate)``. ``new_state`` holds the next
        population and its fitness. ``full_intermediate`` records the selected
        parent indices, raw and mutated offspring, elite indices, the new
        population and the operator diagnostics computed below.
        Note that ``best_fitness`` in both outputs is the best fitness of the
        *incoming* population, not of ``new_population``.
    """
    key, sel_key1, sel_key2, cross_key, mut_key = jar.split(state.key, 5)

    # --- Phase 0: Elitism (static shape) ---
    sorted_indices = jnp.argsort(-state.fitness)  # descending fitness
    elite_indices = sorted_indices[:elitism]
    elite_individuals = state.population[elite_indices]
    # Taken from sorted_indices (not elite_indices) so elitism == 0 still works.
    best_index = sorted_indices[0]
    current_best_fitness = state.fitness[best_index]

    # --- Phase 1: Selection (fixed number of offspring) ---
    num_offspring = pop_size - elitism  # static Python int
    selected_indices_1 = selection_fn(sel_key1, state.fitness)[:num_offspring]
    selected_indices_2 = selection_fn(sel_key2, state.fitness)[:num_offspring]

    # Selection pressure: variance of how often each individual was picked as
    # either parent (0 when every individual is picked equally often).
    selection_counts = (
        jnp.bincount(selected_indices_1, length=pop_size)
        + jnp.bincount(selected_indices_2, length=pop_size)
    )
    selection_pressure = jnp.var(selection_counts)

    # --- Phase 2: Crossover (static shapes) ---
    parent_1 = state.population[selected_indices_1]
    parent_2 = state.population[selected_indices_2]

    crossover_keys = jar.split(cross_key, num_offspring)
    crossover_fn_batched = jax.vmap(crossover_fn, in_axes=(0, 0, 0))
    offspring_raw = crossover_fn_batched(crossover_keys, parent_1, parent_2)

    # Drop a per-call output axis, if the crossover adds one, so offspring match
    # the parents' (num_offspring, *genome_shape). Comparing against the
    # parents' rank, rather than a hard-coded 2, keeps this correct for
    # multi-dimensional genomes.
    if offspring_raw.ndim > parent_1.ndim:
        offspring_raw = offspring_raw[:, 0]

    # Mean |offspring - parent midpoint|.
    # NOTE(review): if every gene is copied from one of the two parents, this
    # equals half the mean |parent_1 - parent_2| (for binary genomes: half the
    # fraction of differing positions). It therefore measures parent
    # diversity more than crossover "success".
    parent_avg = (parent_1.astype(jnp.float32) + parent_2.astype(jnp.float32)) / 2
    crossover_success_rate = jnp.mean(
        jnp.abs(offspring_raw.astype(jnp.float32) - parent_avg)
    )

    # --- Phase 3: Mutation (fixed shapes) ---
    mutation_keys = jar.split(mut_key, num_offspring)
    mutation_fn_batched = jax.vmap(mutation_fn, in_axes=(0, 0))
    offspring_final = mutation_fn_batched(mutation_keys, offspring_raw)

    # Fraction of genes changed by mutation.
    mutation_impact = jnp.mean(
        jnp.not_equal(offspring_final, offspring_raw).astype(jnp.float32)
    )

    # --- Phase 4: Population assembly (static shape) ---
    # Writing into a zeros_like buffer keeps the population's shape and dtype
    # fixed. Rows [0, elitism) hold elites; the remaining rows hold offspring.
    new_population = jnp.zeros_like(state.population)
    new_population = new_population.at[:elitism].set(elite_individuals)
    new_population = new_population.at[elitism:elitism + num_offspring].set(offspring_final)

    new_fitness = jax.vmap(fitness_fn)(new_population)

    # --- Complete intermediate record ---
    full_intermediate = FullIntermediateState(
        selected_indices_1=selected_indices_1,
        selected_indices_2=selected_indices_2,
        selection_pressure=selection_pressure,
        offspring_raw=offspring_raw,
        crossover_success_rate=crossover_success_rate,
        offspring_final=offspring_final,
        mutation_impact=mutation_impact,
        elite_indices=elite_indices,
        new_population=new_population,
        best_fitness=current_best_fitness
    )

    # --- Next state ---
    updated_metrics = state.metrics.update_selection_pressure(selection_pressure)
    new_state = ResearchState(
        population=new_population,
        fitness=new_fitness,
        best_genome=state.population[best_index],
        best_fitness=current_best_fitness,
        key=key,
        generation=state.generation + 1,
        metrics=updated_metrics
    )

    return new_state, full_intermediate

In [5]:
# Advance one extra generation past the engine's final state, using the pure
# (traceable) function form of each operator.
new_state, full_intermediate = _research_step_fn(
    research_final_state,
    None,
    fitness_fn=fitness_evaluator.get_pure_fitness_function(),
    selection_fn=selector.get_pure_function(),
    crossover_fn=crossover.get_pure_function(),
    mutation_fn=mutator.get_pure_function(),
    pop_size=pop_size,
    elitism=elitism,
)

print(f"New state generation: {new_state.generation}")
print(f"New best fitness: {new_state.best_fitness}")

New state generation: 201
New best fitness: 98


In [None]:
# Enhanced step function with fitness-based convergence detection
def _research_step_fn_with_convergence(state: ResearchState,
                                       convergence_input: Dict,  # convergence trackers from the previous step
                                       fitness_fn: Callable,
                                       selection_fn: Callable,
                                       crossover_fn: Callable,
                                       mutation_fn: Callable,
                                       pop_size: int,
                                       elitism: int,
                                       alpha: float = 0.3,
                                       stagnation_threshold: float = 0.001,
                                       stagnation_patience: int = 5,
                                       decline_threshold: float = -0.001) -> Tuple[ResearchState, Dict]:
    """Run one research GA generation and update convergence-detection signals.

    The GA part (elitism, two-parent selection, crossover, mutation, population
    assembly) follows the same pipeline as ``_research_step_fn``. On top of
    that, the change in best fitness is tracked across generations.

    Args:
        state: Current GA state.
        convergence_input: Dict with
            - ``'prev_best_fitness'``: best fitness seen by the previous step.
            - ``'fitness_delta_ema'``: exponential moving average of the
              best-fitness deltas.
            - ``'stagnation_counter'``: consecutive generations whose
              ``|delta| < stagnation_threshold``.
        fitness_fn, selection_fn, crossover_fn, mutation_fn: Pure operator
            functions. They are vmapped exactly as in ``_research_step_fn``.
        pop_size: Total population size (static).
        elitism: Number of elites copied unchanged (static; 0 is allowed).
        alpha: EMA weight given to the newest fitness delta.
        stagnation_threshold: ``|delta|`` below this counts as a stagnant
            generation.
        stagnation_patience: Stagnant generations in a row that raise
            ``is_stagnant``.
        decline_threshold: EMA values below this raise ``is_declining``.

    Returns:
        ``(new_state, convergence_output)``. ``convergence_output`` contains the
        updated ``prev_best_fitness``, ``fitness_delta_ema`` and
        ``stagnation_counter``. Feed those three back in as the next call's
        ``convergence_input``, e.g. via a ``jax.lax.scan`` carry, so that the
        trackers accumulate. It also contains ``fitness_delta``,
        ``is_stagnant``, ``is_declining``, ``convergence_risk``,
        ``selection_pressure``, ``crossover_success_rate``, ``mutation_impact``
        and ``best_fitness``.
    """
    key, sel_key1, sel_key2, cross_key, mut_key = jar.split(state.key, 5)

    # Convergence trackers handed over by the previous generation.
    prev_best_fitness = convergence_input['prev_best_fitness']
    fitness_delta_ema = convergence_input['fitness_delta_ema']
    stagnation_counter = convergence_input['stagnation_counter']

    # --- Phase 0: Elitism (static shape) ---
    sorted_indices = jnp.argsort(-state.fitness)  # descending fitness
    elite_indices = sorted_indices[:elitism]
    elite_individuals = state.population[elite_indices]
    # Taken from sorted_indices (not elite_indices) so elitism == 0 still works.
    best_index = sorted_indices[0]
    current_best_fitness = state.fitness[best_index]

    # --- Convergence analysis ---
    fitness_delta = current_best_fitness - prev_best_fitness

    # EMA of the delta; `alpha` (α) is the weight on the newest value.
    new_fitness_delta_ema = alpha * fitness_delta + (1 - alpha) * fitness_delta_ema

    # Count consecutive near-zero improvements; any larger change resets to 0.
    new_stagnation_counter = jnp.where(
        jnp.abs(fitness_delta) < stagnation_threshold,
        stagnation_counter + 1,
        0
    )

    # Convergence warning signals
    is_stagnant = new_stagnation_counter >= stagnation_patience
    is_declining = new_fitness_delta_ema < decline_threshold
    convergence_risk = jnp.logical_or(is_stagnant, is_declining)

    # --- Standard GA operations ---
    num_offspring = pop_size - elitism  # static Python int
    selected_indices_1 = selection_fn(sel_key1, state.fitness)[:num_offspring]
    selected_indices_2 = selection_fn(sel_key2, state.fitness)[:num_offspring]

    # Variance of how often each individual was picked as either parent.
    selection_counts = (
        jnp.bincount(selected_indices_1, length=pop_size)
        + jnp.bincount(selected_indices_2, length=pop_size)
    )
    selection_pressure = jnp.var(selection_counts)

    parent_1 = state.population[selected_indices_1]
    parent_2 = state.population[selected_indices_2]

    crossover_keys = jar.split(cross_key, num_offspring)
    crossover_fn_batched = jax.vmap(crossover_fn, in_axes=(0, 0, 0))
    offspring_raw = crossover_fn_batched(crossover_keys, parent_1, parent_2)

    # Drop a per-call output axis, if present; comparing against the parents'
    # rank keeps this correct for multi-dimensional genomes.
    if offspring_raw.ndim > parent_1.ndim:
        offspring_raw = offspring_raw[:, 0]

    # Mean |offspring - parent midpoint| (NOTE(review): mostly reflects how
    # different the two parents are when genes are copied from either parent).
    crossover_success_rate = jnp.mean(jnp.abs(
        offspring_raw.astype(jnp.float32) -
        (parent_1.astype(jnp.float32) + parent_2.astype(jnp.float32)) / 2
    ))

    mutation_keys = jar.split(mut_key, num_offspring)
    mutation_fn_batched = jax.vmap(mutation_fn, in_axes=(0, 0))
    offspring_final = mutation_fn_batched(mutation_keys, offspring_raw)

    # Fraction of genes changed by mutation.
    mutation_impact = jnp.mean(jnp.not_equal(offspring_final, offspring_raw).astype(jnp.float32))

    # Population assembly: elites first, then offspring; shape/dtype fixed.
    new_population = jnp.zeros_like(state.population)
    new_population = new_population.at[:elitism].set(elite_individuals)
    new_population = new_population.at[elitism:elitism + num_offspring].set(offspring_final)

    new_fitness = jax.vmap(fitness_fn)(new_population)

    # --- Output: next trackers plus per-generation diagnostics ---
    convergence_output = {
        'prev_best_fitness': current_best_fitness,
        'fitness_delta_ema': new_fitness_delta_ema,
        'stagnation_counter': new_stagnation_counter,
        'fitness_delta': fitness_delta,
        'is_stagnant': is_stagnant,
        'is_declining': is_declining,
        'convergence_risk': convergence_risk,
        'selection_pressure': selection_pressure,
        'crossover_success_rate': crossover_success_rate,
        'mutation_impact': mutation_impact,
        'best_fitness': current_best_fitness
    }

    updated_metrics = state.metrics.update_selection_pressure(selection_pressure)
    new_state = ResearchState(
        population=new_population,
        fitness=new_fitness,
        best_genome=state.population[best_index],
        best_fitness=current_best_fitness,
        key=key,
        generation=state.generation + 1,
        metrics=updated_metrics
    )

    return new_state, convergence_output

In [None]:
# Demo: Evolution with real-time convergence detection
def run_evolution_with_convergence_tracking(initial_state: ResearchState,
                                            num_generations: int,
                                            fitness_fn, selection_fn, crossover_fn, mutation_fn,
                                            pop_size: int, elitism: int):
    """Run ``num_generations`` convergence-tracked GA steps with ``jax.lax.scan``.

    The convergence trackers (``prev_best_fitness``, ``fitness_delta_ema`` and
    ``stagnation_counter``) travel in the scan *carry* together with the GA
    state. Each generation therefore sees the values produced by the previous
    one. Passing them as scan ``xs`` would give every step the same initial
    values, and the original ``jnp.array`` of dicts could not be built at all.

    Args:
        initial_state: GA state to start from.
        num_generations: Number of steps (a static Python int).
        fitness_fn, selection_fn, crossover_fn, mutation_fn: Pure operators
            forwarded to ``_research_step_fn_with_convergence``.
        pop_size: Total population size (static).
        elitism: Number of elites per generation (static).

    Returns:
        ``(final_state, all_outputs)``. ``all_outputs`` is the per-step output
        dict, with every value stacked along a leading axis of length
        ``num_generations``.
    """
    # Bake the static arguments into the step function.
    step_fn = functools.partial(
        _research_step_fn_with_convergence,
        fitness_fn=fitness_fn,
        selection_fn=selection_fn,
        crossover_fn=crossover_fn,
        mutation_fn=mutation_fn,
        pop_size=pop_size,
        elitism=elitism
    )

    tracker_keys = ('prev_best_fitness', 'fitness_delta_ema', 'stagnation_counter')

    def scan_body(carry, _):
        state, tracker = carry
        new_state, outputs = step_fn(state, tracker)
        # Feed the updated trackers forward to the next generation.
        new_tracker = {name: outputs[name] for name in tracker_keys}
        return (new_state, new_tracker), outputs

    # Explicit dtypes so the carry's types match what one step returns:
    # - the best fitness uses the population's fitness dtype;
    # - the EMA mixes a fitness delta with float32 state;
    # - the counter is an int32.
    fitness_dtype = initial_state.fitness.dtype
    initial_tracker = {
        'prev_best_fitness': jnp.asarray(initial_state.best_fitness, dtype=fitness_dtype),
        'fitness_delta_ema': jnp.asarray(0.0, dtype=jnp.result_type(fitness_dtype, jnp.float32)),
        'stagnation_counter': jnp.asarray(0, dtype=jnp.int32),
    }

    (final_state, _), all_outputs = jax.lax.scan(
        scan_body,
        (initial_state, initial_tracker),
        xs=None,
        length=num_generations
    )

    return final_state, all_outputs

# Continue from the engine's final state with convergence tracking enabled.
print("🚀 Running evolution with convergence tracking...")

enhanced_final_state, convergence_outputs = run_evolution_with_convergence_tracking(
    initial_state=research_final_state,
    num_generations=50,  # shorter run to see convergence patterns
    fitness_fn=fitness_evaluator.get_pure_fitness_function(),
    selection_fn=selector.get_pure_function(),
    crossover_fn=crossover.get_pure_function(),
    mutation_fn=mutator.get_pure_function(),
    pop_size=pop_size,
    elitism=elitism
)

print(f"Final generation: {enhanced_final_state.generation}")
print(f"Final best fitness: {enhanced_final_state.best_fitness}")

In [None]:
# Analyze the convergence patterns
import matplotlib.pyplot as plt

def plot_convergence_tracking(convergence_outputs):
    """Plot per-generation convergence diagnostics and print a short summary.

    Args:
        convergence_outputs: Mapping of metric name to a per-generation array.
            Reads ``best_fitness``, ``fitness_delta``, ``fitness_delta_ema``,
            ``stagnation_counter``, ``convergence_risk``,
            ``selection_pressure``, ``crossover_success_rate`` and
            ``mutation_impact``.

    Draws a 2x3 grid of panels:
    - best fitness;
    - raw and EMA fitness delta;
    - stagnation counter;
    - risk-flagged generations;
    - selection pressure;
    - operator effectiveness.

    Then prints summary statistics. If no generations were recorded, it prints a
    notice and returns without plotting.
    """
    # Move everything to host once; masking, reductions and matplotlib then
    # operate on plain NumPy arrays.
    outputs = jax.device_get(convergence_outputs)

    num_generations = len(outputs['best_fitness'])
    if num_generations == 0:
        # Nothing to plot, and max()/[-1] below would fail on empty arrays.
        print("No generations recorded; nothing to plot.")
        return

    generations = jax.device_get(jnp.arange(num_generations))
    best_fitness = outputs['best_fitness']
    risk_mask = outputs['convergence_risk'].astype(bool)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Plot 1: Best fitness over time
    axes[0, 0].plot(generations, best_fitness, 'b-', linewidth=2)
    axes[0, 0].set_title('Best Fitness Evolution')
    axes[0, 0].set_xlabel('Generation')
    axes[0, 0].set_ylabel('Best Fitness')
    axes[0, 0].grid(True)

    # Plot 2: Fitness delta (improvement per generation)
    axes[0, 1].plot(generations, outputs['fitness_delta'], 'g-', alpha=0.7, label='Raw Delta')
    axes[0, 1].plot(generations, outputs['fitness_delta_ema'], 'r-', linewidth=2, label='EMA Delta')
    axes[0, 1].axhline(y=0, color='black', linestyle='--', alpha=0.5)
    axes[0, 1].set_title('Fitness Improvement Rate')
    axes[0, 1].set_xlabel('Generation')
    axes[0, 1].set_ylabel('Fitness Delta')
    axes[0, 1].legend()
    axes[0, 1].grid(True)

    # Plot 3: Stagnation counter (5 is the step function's default patience)
    axes[0, 2].plot(generations, outputs['stagnation_counter'], 'orange', linewidth=2)
    axes[0, 2].axhline(y=5, color='red', linestyle='--', alpha=0.7, label='Stagnation Threshold')
    axes[0, 2].set_title('Stagnation Counter')
    axes[0, 2].set_xlabel('Generation')
    axes[0, 2].set_ylabel('Consecutive Stagnant Generations')
    axes[0, 2].legend()
    axes[0, 2].grid(True)

    # Plot 4: Generations flagged as convergence risk
    axes[1, 0].scatter(generations[risk_mask],
                       best_fitness[risk_mask],
                       color='red', s=50, alpha=0.7, label='High Convergence Risk')
    axes[1, 0].plot(generations, best_fitness, 'b-', alpha=0.5)
    axes[1, 0].set_title('Convergence Risk Detection')
    axes[1, 0].set_xlabel('Generation')
    axes[1, 0].set_ylabel('Best Fitness')
    axes[1, 0].legend()
    axes[1, 0].grid(True)

    # Plot 5: Selection pressure
    axes[1, 1].plot(generations, outputs['selection_pressure'], 'purple', linewidth=2)
    axes[1, 1].set_title('Selection Pressure')
    axes[1, 1].set_xlabel('Generation')
    axes[1, 1].set_ylabel('Selection Pressure (Variance)')
    axes[1, 1].grid(True)

    # Plot 6: Operator effectiveness
    axes[1, 2].plot(generations, outputs['crossover_success_rate'], 'cyan',
                    linewidth=2, label='Crossover Success')
    axes[1, 2].plot(generations, outputs['mutation_impact'], 'magenta',
                    linewidth=2, label='Mutation Impact')
    axes[1, 2].set_title('Operator Effectiveness')
    axes[1, 2].set_xlabel('Generation')
    axes[1, 2].set_ylabel('Rate/Impact')
    axes[1, 2].legend()
    axes[1, 2].grid(True)

    plt.tight_layout()
    plt.show()

    # Summary statistics
    risk_generation_count = int(risk_mask.sum())
    max_stagnation = int(outputs['stagnation_counter'].max())
    avg_improvement = float(outputs['fitness_delta_ema'].mean())
    final_delta = float(outputs['fitness_delta'][-1])

    print("📊 CONVERGENCE ANALYSIS SUMMARY:")
    print(f"   🚨 High-risk generations: {risk_generation_count}/{num_generations} ({risk_generation_count/num_generations*100:.1f}%)")
    print(f"   ⏸️  Maximum consecutive stagnation: {max_stagnation} generations")
    print(f"   📈 Average improvement rate (EMA): {avg_improvement:.4f}")
    print(f"   🎯 Final fitness improvement: {final_delta:.4f}")

# Visualize the convergence-tracked run and print its summary statistics
plot_convergence_tracking(convergence_outputs=convergence_outputs)