# MalthusJAX Level 3 Showcase: Engines & Complete Evolution


### Key Components:
- **AbstractEvolutionEngine**: The main orchestrator that combines all components
- **Complete GA Pipeline**: Selection → Crossover → Mutation → Evaluation
- **Performance Benchmarks**: Speed comparisons and scalability demonstrations
- **Multiple Problem Types**: Binary and real-valued problems, including the multimodal Rastrigin function

This notebook combines the Level 1 genomes and fitness evaluators with the Level 2 operators into complete, JIT-compiled evolutionary runs.

In [1]:
import sys 
import os
import time
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')  # adjust to the src/ directory of your local MalthusJAX checkout

import jax # type: ignore
import jax.numpy as jnp # type: ignore
import jax.random as jar # type: ignore
from jax import random, jit, vmap # type: ignore

print(f"JAX running on: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

# Level 1: Genomes & Fitness
from malthusjax.core.genome.binary import BinaryGenome, BinaryGenomeConfig
from malthusjax.core.genome.real import RealGenome, RealGenomeConfig
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.core.fitness.real import SphereFitnessEvaluator, RastriginFitnessEvaluator

# Level 2: Operators
from malthusjax.operators.selection.tournament import TournamentSelection
from malthusjax.operators.mutation.binary import BitFlipMutation
from malthusjax.operators.mutation.real import BallMutation
from malthusjax.operators.crossover.binary import UniformCrossover as BinaryUniform
from malthusjax.operators.crossover.real import AverageCrossover

# Level 3: Engines (NEW)
from malthusjax.engine.BasicMalthusEngine import BasicMalthusEngine
#from malthusjax.engine.genetic import GeneticAlgorithmEngine
#from malthusjax.engine.config import EvolutionConfig

key = jar.PRNGKey(42)

JAX running on: cpu
Available devices: [CpuDevice(id=0)]


In [2]:
# --- Additional imports ---
# Cell 1 already imported the core Level 1-3 classes, printed the backend,
# and seeded the master PRNG `key`; this cell adds typing helpers, the
# abstract base classes, and the full set of Level 2 operators.
import functools
from typing import Callable, List, Dict, Any, Tuple

# --- Level 1: additional genome types ---
#from malthusjax.core.genome.permutation import PermutationGenome # Assuming this exists
from malthusjax.core.genome.categorical import CategoricalGenome, CategoricalGenomeConfig

# --- Level 2: operators ---
# Base classes (for type hints)
from malthusjax.operators.base import AbstractGeneticOperator
from malthusjax.operators.selection.base import AbstractSelectionOperator
from malthusjax.operators.mutation.base import AbstractMutation
from malthusjax.operators.crossover.base import AbstractCrossover

# Selection Operators
from malthusjax.operators.selection.tournament import TournamentSelection
from malthusjax.operators.selection.roulette import RouletteSelection

# Binary Operators
from malthusjax.operators.mutation.binary import BitFlipMutation, SwapMutation as BinarySwap
from malthusjax.operators.crossover.binary import UniformCrossover as BinaryUniform, SinglePointCrossover as BinarySinglePoint

# Real Operators
from malthusjax.operators.mutation.real import BallMutation
from malthusjax.operators.crossover.real import UniformCrossover as RealUniform, SinglePointCrossover as RealSinglePoint, AverageCrossover

# Permutation Operators
from malthusjax.operators.mutation.permutation import ScrambleMutation, SwapMutation as PermutationSwap
from malthusjax.operators.crossover.permutation import CycleCrossover, PillarCrossover

# Categorical Operators
from malthusjax.operators.mutation.categorical import CategoricalFlipMutation
from malthusjax.operators.crossover.categorical import UniformCrossover as CatUniform, SinglePointCrossover as CatSinglePoint


## 1. The AbstractEvolutionEngine

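The **AbstractEvolutionEngine** is the heart of Level 3, orchestrating the complete evolutionary process. The examples in this notebook use the concrete `BasicMalthusEngine`, which is configured with a genome representation, a fitness evaluator, the three operators, and an `elitism` count, and is then driven by `run(key, num_generations, pop_size)`.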

### Key Responsibilities:
- **Population Management**: Initialize and maintain populations
- **Operator Coordination**: Sequence selection, crossover, mutation in the correct order
- **Fitness Evaluation**: Batch evaluate populations efficiently
- **Evolution Loop**: Run complete generations with JIT compilation
- **Statistics Tracking**: Monitor best fitness, diversity, convergence
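To make "operator coordination" concrete, one generation can be written as a single pure function that maps a population to the next one. The sketch below is plain JAX rather than MalthusJAX's API: it assumes a boolean OneMax population (fitness = number of ones, maximized), and the helper names `tournament_select`, `uniform_crossover`, `bit_flip`, and `generation_step` are hypothetical.

```python
import jax
import jax.numpy as jnp


def onemax(genome):
    # Fitness of one genome: number of True bits (to be maximized).
    return jnp.sum(genome)


def tournament_select(key, population, fitness, tournament_size=3):
    # For each slot in the new population, sample `tournament_size`
    # competitors uniformly at random and keep the fittest one.
    pop_size = population.shape[0]
    idx = jax.random.randint(key, (pop_size, tournament_size), 0, pop_size)
    winners = idx[jnp.arange(pop_size), jnp.argmax(fitness[idx], axis=1)]
    return population[winners]


def uniform_crossover(key, parents_a, parents_b):
    # Each gene is copied from parent A or parent B with probability 0.5.
    mask = jax.random.bernoulli(key, 0.5, parents_a.shape)
    return jnp.where(mask, parents_a, parents_b)


def bit_flip(key, population, rate=0.02):
    # Flip each bit independently with probability `rate`.
    flips = jax.random.bernoulli(key, rate, population.shape)
    return jnp.logical_xor(population, flips)


@jax.jit
def generation_step(key, population):
    # Evaluation -> selection -> crossover -> mutation, all vectorized.
    k_sel_a, k_sel_b, k_cx, k_mut = jax.random.split(key, 4)
    fitness = jax.vmap(onemax)(population)
    parents_a = tournament_select(k_sel_a, population, fitness)
    parents_b = tournament_select(k_sel_b, population, fitness)
    children = uniform_crossover(k_cx, parents_a, parents_b)
    return bit_flip(k_mut, children)


init_key, step_key = jax.random.split(jax.random.PRNGKey(0))
population = jax.random.bernoulli(init_key, 0.5, (100, 50))
next_population = generation_step(step_key, population)
print(next_population.shape, next_population.dtype)
```

Because every helper works on whole arrays, the entire step compiles into one XLA program, which is the property the engine relies on.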

### The Engine Pattern:
1. **Configure**: Set up genome config, operators, and evolution parameters
2. **Initialize**: Create initial population and JIT-compile the evolution step
3. **Evolve**: Run the main evolution loop
4. **Results**: Extract best solutions and statistics
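The four phases map naturally onto a jitted `jax.lax.scan` over generations. The following is an illustrative sketch, not `BasicMalthusEngine`'s implementation: it assumes a real-valued Sphere objective (minimized), binary tournaments, average crossover, Gaussian mutation, and elitism, and `run_ga` is a hypothetical name.

```python
import functools

import jax
import jax.numpy as jnp


def sphere(x):
    # Objective to minimize; global optimum 0 at the origin.
    return jnp.sum(x ** 2)


@functools.partial(jax.jit, static_argnames=("pop_size", "dim", "num_generations", "elitism"))
def run_ga(key, pop_size=100, dim=5, num_generations=50, elitism=2, sigma=0.1):
    # 1. Configure: the static arguments fix every array shape,
    #    so the whole run compiles once per configuration.
    # 2. Initialize: uniform random population in [-1, 1]^dim.
    key, init_key = jax.random.split(key)
    population = jax.random.uniform(init_key, (pop_size, dim), minval=-1.0, maxval=1.0)

    def step(carry, _):
        key, population = carry
        key, k_a, k_b, k_mut = jax.random.split(key, 4)
        fitness = jax.vmap(sphere)(population)
        # Elitism: copy the `elitism` lowest-fitness individuals unchanged.
        elite = population[jnp.argsort(fitness)[:elitism]]
        n_children = pop_size - elitism

        def pick(k):
            # Binary tournament per child slot: the lower fitness wins.
            idx = jax.random.randint(k, (n_children, 2), 0, pop_size)
            winners = jnp.where(fitness[idx[:, 0]] <= fitness[idx[:, 1]], idx[:, 0], idx[:, 1])
            return population[winners]

        # Average crossover followed by Gaussian mutation.
        children = 0.5 * (pick(k_a) + pick(k_b))
        children = children + sigma * jax.random.normal(k_mut, children.shape)
        new_population = jnp.concatenate([elite, children], axis=0)
        # Record the best fitness of the population that entered this generation.
        return (key, new_population), jnp.min(fitness)

    # 3. Evolve: scan the step over all generations inside one compiled program.
    (key, population), history = jax.lax.scan(step, (key, population), None, length=num_generations)

    # 4. Results: best individual of the final population plus the fitness history.
    final_fitness = jax.vmap(sphere)(population)
    best = population[jnp.argmin(final_fitness)]
    return best, jnp.min(final_fitness), history


best, best_fitness, history = run_ga(jax.random.PRNGKey(0))
print(best_fitness, history.shape)
```

With elitism, the best fitness in `history` can never get worse from one generation to the next, which is a quick sanity check for any engine that claims to keep elites.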

### Example 1: Binary Optimization

In [3]:
# --- 1. Configure the Problem ---
print("=== Binary Optimization: Maximize Sum of Bits ===")

# Genome configuration
binary_config = BinaryGenomeConfig(array_shape=(50,), p=0.5)

# Fitness evaluator
fitness_evaluator = BinarySumFitnessEvaluator()

# Operators
selection_op = TournamentSelection(number_of_choices=100, tournament_size=3)
crossover_op = BinaryUniform(crossover_rate=0.8, n_outputs=1)
mutation_op = BitFlipMutation(mutation_rate=0.02)

# Evolution settings: BasicMalthusEngine takes these directly (elitism in the
# constructor; num_generations and pop_size in run()), so no EvolutionConfig
# object is created here.
#   population size: 100 | generations: 50 | elitism: 2

=== Binary Optimization: Maximize Sum of Bits ===

In [4]:
# --- 2. Build the Engine ---
# Re-seed so this cell is reproducible on its own.
key = jar.PRNGKey(42)
key, engine_key = jar.split(key)

genome = BinaryGenome(array_shape=binary_config.array_shape, p=binary_config.p)

binary_engine = BasicMalthusEngine(
    genome_representation=genome,
    fitness_evaluator=fitness_evaluator,
    selection_operator=selection_op,
    crossover_operator=crossover_op,
    mutation_operator=mutation_op,
    elitism=2
)

# --- 3. Run Evolution ---
# Note: this timing includes JIT compilation on the first run.
print("\nRunning binary optimization...")
start_time = time.time()

final_state, history = binary_engine.run(engine_key,
                                         num_generations=50,
                                         pop_size=100)  # Specify the population size

end_time = time.time()
print(f"Evolution completed in {end_time - start_time:.2f} seconds")

# --- 4. Results ---
best_fitness = final_state.best_fitness
best_genome = final_state.best_genome
generation = final_state.generation

print(f"\nFinal Results:")
print(f"Best fitness: {best_fitness} / {binary_config.array_shape[0]}")
print(f"Best genome: {best_genome}")
print(f"Generations: {generation}")
print(f"Fraction of ones: {best_fitness / binary_config.array_shape[0] * 100:.1f}%")


Running binary optimization...
Evolution completed in 2.84 seconds

Final Results:
Best fitness: 15 / 50
Best genome: [ True False False False  True False False  True False False  True False
 False False False  True  True False  True False False  True  True False
 False  True  True False False False False False False  True  True False
 False False False False False False False False False False  True  True
 False False]
Generations: 51
Fraction of ones: 30.0%


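The binary run reached a best fitness of only 15 out of 50, which is below the roughly 25 ones expected from a uniformly random 50-bit string with `p=0.5`, so the population drifted away from the all-ones optimum. A plausible explanation is a sign mismatch (the engine minimizing fitness while `BinarySumFitnessEvaluator` rewards more ones); the fitness direction is worth checking before treating this run as a success.

### Example 2: Real-Valued Optimization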

Let's tackle a more challenging problem: the **Rastrigin function**, a classic multimodal optimization benchmark with many local optima.

#### Rastrigin Function:
- **Global minimum**: f(0, 0, ..., 0) = 0
- **Search space**: [-5.12, 5.12]^n
- **Difficulty**: Highly multimodal with many local minima

This demonstrates the engine's versatility across different problem types.
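For reference, the standard Rastrigin definition is $f(\mathbf{x}) = A n + \sum_{i=1}^{n} \left[x_i^2 - A\cos(2\pi x_i)\right]$ with $A = 10$. The JAX sketch below assumes this standard form; `RastriginFitnessEvaluator` may use a different constant or sign convention, and `rastrigin` here is a hypothetical standalone function.

```python
import jax
import jax.numpy as jnp


def rastrigin(x, a=10.0):
    # f(x) = a*n + sum_i (x_i^2 - a*cos(2*pi*x_i)); global minimum 0 at x = 0.
    # Works on a single genome of shape (n,) or a population of shape (pop_size, n).
    n = x.shape[-1]
    return a * n + jnp.sum(x ** 2 - a * jnp.cos(2.0 * jnp.pi * x), axis=-1)


points = jnp.array([
    [0.0] * 10,  # global optimum
    [1.0] * 10,  # close to a local minimum on the integer lattice
    [0.5] * 10,  # close to a local maximum between lattice points
])
print(jax.jit(rastrigin)(points))
```

The three rows evaluate to 0, 10 and 202.5 (up to float32 rounding), which shows how a point only one unit from the optimum in every coordinate already sits in a different basin.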

In [5]:
# --- 1. Configure Real Problem ---
print("=== Real-Valued Optimization: Rastrigin Function ===")

# Genome configuration (10-dimensional Rastrigin)
real_config = RealGenomeConfig(array_shape=(10,), min_val=-5.12, max_val=5.12)

# Fitness evaluator (Rastrigin is minimization)
rastrigin_evaluator = RastriginFitnessEvaluator()

# Real-valued operators
real_selection = TournamentSelection(number_of_choices=200, tournament_size=4)
real_crossover = AverageCrossover(blend_rate=0.3, n_outputs=1)
real_mutation = BallMutation(mutation_rate=0.8, mutation_strength=0.1)

# Evolution settings for this run (passed to BasicMalthusEngine and run() below):
#   population size: 200 | generations: 100 | elitism: 1

print(f"Problem dimensions: {real_config.array_shape[0]}")
print(f"Search space: [{real_config.min_val}, {real_config.max_val}]^{real_config.array_shape[0]}")
print(f"Global optimum: f(0, 0, ..., 0) = 0")

=== Real-Valued Optimization: Rastrigin Function ===
Problem dimensions: 10
Search space: [-5.12, 5.12]^10
Global optimum: f(0, 0, ..., 0) = 0


In [6]:
# --- 2. Create and Run Real Engine ---
key, real_engine_key = jar.split(key)

real_engine = BasicMalthusEngine(
    genome_representation=RealGenome(array_shape=real_config.array_shape, min_val=real_config.min_val, max_val=real_config.max_val),
    fitness_evaluator=rastrigin_evaluator,
    selection_operator=real_selection,
    crossover_operator=real_crossover,
    mutation_operator=real_mutation,
    elitism=1
)

print("\nRunning Rastrigin optimization...")
start_time = time.time()

real_final_state, real_history = real_engine.run(real_engine_key,
                                                num_generations=100,
                                                pop_size=200)  # Specify the population size

end_time = time.time()
print(f"Evolution completed in {end_time - start_time:.2f} seconds")

# --- 3. Results ---
best_real_fitness = real_final_state.best_fitness
best_real_genome = real_final_state.best_genome

print(f"\nFinal Results:")
print(f"Best fitness: {best_real_fitness:.6f}")
print(f"Best genome: {best_real_genome}")
print(f"Distance from global optimum: {jnp.linalg.norm(best_real_genome):.6f}")

# Show convergence
print(f"\nConvergence Progress:")
for i in [0, 25, 50, 75, 99]:
    if i < len(real_history):
        print(f"Generation {i:2d}: {real_history[i]:.6f}")


Running Rastrigin optimization...
Evolution completed in 2.42 seconds

Final Results:
Best fitness: 55.539959
Best genome: [ 1.1485001   0.91059196 -1.2185602  -0.03232902 -0.11216843 -2.2040331
  0.8476451   0.3067146   0.15139309 -0.06845435]
Distance from global optimum: 3.056909

Convergence Progress:
Generation  0: 93.886360
Generation 25: 55.539959
Generation 50: 55.539959
Generation 75: 55.539959
Generation 99: 55.539959


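The best fitness is unchanged at 55.54 from generation 25 through 99, and the best genome is still about 3.06 away from the origin, so this run has stalled in a local minimum of Rastrigin (the global minimum is 0). A larger population, more generations, or a different `mutation_strength` are the usual levers to try.

## 2. Performance Analysis & JIT Compilation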

One of MalthusJAX's key advantages is **JIT compilation** of the entire evolutionary step. Let's benchmark this performance gain:

### What Gets JIT-Compiled:
- **Population Initialization**: Vectorized genome creation
- **Fitness Evaluation**: Batch evaluation of entire populations  
- **Selection**: Tournament selection across the population
- **Crossover**: Vectorized parent pairing and offspring creation
- **Mutation**: Batch mutation of all individuals
- **Complete Generation**: The entire evolutionary step as one function

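In the benchmark below, the second and third runs are about 130× faster than the first, because the first run includes tracing and XLA compilation while later runs reuse the compiled program. Note that this compares MalthusJAX against its own first call, not against a non-JAX implementation.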
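The gap between the first and later calls comes from JAX caching the compiled executable per function and input shape/dtype. A minimal sketch in plain JAX that times both calls; `evolve_step` is a hypothetical stand-in for one generation, not a MalthusJAX function.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def evolve_step(key, population):
    # Stand-in for one generation: mutate every genome with Gaussian noise,
    # then keep the better half (lower sum of squares) and duplicate it.
    candidates = population + 0.1 * jax.random.normal(key, population.shape)
    fitness = jnp.sum(candidates ** 2, axis=1)
    best_half = candidates[jnp.argsort(fitness)[: population.shape[0] // 2]]
    return jnp.concatenate([best_half, best_half], axis=0)


def timed(fn, *args):
    # Block on the result so the timing covers execution, not just dispatch.
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start


demo_key = jax.random.PRNGKey(0)
demo_population = jax.random.uniform(demo_key, (200, 10), minval=-5.0, maxval=5.0)

_, first = timed(evolve_step, demo_key, demo_population)   # trace + compile + run
_, second = timed(evolve_step, demo_key, demo_population)  # cached executable, run only
print(f"first call: {first:.4f}s, second call: {second:.4f}s")
```

Calling `evolve_step` with a population of a different shape would trigger a fresh compilation, so benchmarks should warm up once per shape.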

In [7]:
# --- Performance Comparison ---
print("=== Performance Benchmark ===")

def benchmark_engine_performance():
    """Compare first run (compilation) vs subsequent runs (JIT optimized)"""
    
    # Small problem for quick benchmarking
    
    bench_engine = BasicMalthusEngine(
        genome_representation = RealGenome(array_shape=(5,), min_val=-1.0, max_val=1.0),
        fitness_evaluator = SphereFitnessEvaluator(),
        selection_operator = TournamentSelection(50, 3),
        crossover_operator = AverageCrossover(0.8, 1),
        mutation_operator = BallMutation(0.1, 0.1),
        elitism=2
    )
    
    key1, key2, key3 = jar.split(key, 3)
    
    # First run (includes tracing + JIT compilation time)
    print("First run (with JIT compilation)...")
    start = time.time()
    _, history = bench_engine.run(key1, num_generations=50, pop_size=50)
    history.block_until_ready()  # wait for asynchronous execution to finish
    first_time = time.time() - start
    
    # Second run (reuses the compiled code)
    print("Second run (JIT optimized)...")
    start = time.time()
    _, history = bench_engine.run(key2, num_generations=50, pop_size=50)
    history.block_until_ready()
    second_time = time.time() - start
    
    # Third run (verify consistency)
    start = time.time()
    _, history = bench_engine.run(key3, num_generations=50, pop_size=50)
    history.block_until_ready()
    third_time = time.time() - start
    
    print(f"\nPerformance Results:")
    print(f"First run (with compilation): {first_time:.3f}s")
    print(f"Second run (JIT optimized):   {second_time:.3f}s")
    print(f"Third run (JIT optimized):    {third_time:.3f}s")
    print(f"Speedup after compilation:    {first_time / second_time:.1f}x")

benchmark_engine_performance()

=== Performance Benchmark ===
First run (with JIT compilation)...
Second run (JIT optimized)...

Performance Results:
First run (with compilation): 2.241s
Second run (JIT optimized):   0.017s
Third run (JIT optimized):    0.015s
Speedup after compilation:    132.4x
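These timings are only meaningful because the code waits for results: JAX dispatches work asynchronously, so a jitted call can return before the computation finishes. `jax.block_until_ready` accepts any pytree, so a whole result structure can be synchronized at once. In the sketch below, `RunResult` and `fake_run` are hypothetical stand-ins for an engine run; whether MalthusJAX's state object is itself a pytree is not shown in this notebook.

```python
import time
from typing import NamedTuple

import jax
import jax.numpy as jnp


class RunResult(NamedTuple):
    # Hypothetical result container; NamedTuples are JAX pytrees.
    best_fitness: jax.Array
    history: jax.Array


@jax.jit
def fake_run(key):
    # Stand-in for an engine run: 50 "generations" of cumulative fitness values.
    history = jnp.cumsum(jax.random.uniform(key, (50,)))
    return RunResult(best_fitness=history[-1], history=history)


demo_key = jax.random.PRNGKey(0)
fake_run(demo_key)  # warm-up call compiles the function

start = time.perf_counter()
result = jax.block_until_ready(fake_run(demo_key))  # waits on every array leaf
elapsed = time.perf_counter() - start
print(f"{elapsed:.6f}s", result.history.shape)
```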


In [8]:
# --- Scalability Test ---
print("\n=== Scalability Analysis ===")

def test_population_scaling():
    """Test how performance scales with population size"""
    
    population_sizes = [50, 100, 200, 500, 1000, 10000]
    times = []
    
    base_config = RealGenomeConfig(array_shape=(100,), min_val=-1.0, max_val=1.0)
    base_fitness = SphereFitnessEvaluator()
    
    for pop_size in population_sizes:
        print(f"Testing population size: {pop_size}")
        # A new engine and a new population shape are used on every iteration,
        # so each timing below also includes JIT compilation.
        engine = BasicMalthusEngine(
            genome_representation = RealGenome(array_shape=base_config.array_shape, min_val=base_config.min_val, max_val=base_config.max_val),
            fitness_evaluator = base_fitness,
            selection_operator = TournamentSelection(100, 3),  # number_of_choices stays at 100 for every pop_size
            crossover_operator = AverageCrossover(0.8, 1),
            mutation_operator = BallMutation(0.1, 0.1),
            elitism=2
        )
        key1, test_key = jar.split(key)
        start = time.time()
        _, history = engine.run(test_key, num_generations=50, pop_size=pop_size)
        history.block_until_ready()  # wait for asynchronous execution to finish
        elapsed = time.time() - start
        times.append(elapsed)
        print(f"Time taken: {elapsed:.3f}s")

    print("\nScalability Results:")
    for pop_size, t in zip(population_sizes, times):
        print(f"Population {pop_size:4d}: {t:.3f}s")

test_population_scaling()


=== Scalability Analysis ===
Testing population size: 50
Time taken: 1.244s
Testing population size: 100
Time taken: 0.960s
Testing population size: 200
Time taken: 1.054s
Testing population size: 500
Time taken: 1.052s
Testing population size: 1000
Time taken: 1.025s
Testing population size: 10000
Time taken: 1.253s

Scalability Results:
Population   50: 1.244s
Population  100: 0.960s
Population  200: 1.054s
Population  500: 1.052s
Population 1000: 1.025s
Population 10000: 1.253s
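The times above stay near 1 s from 50 to 10,000 individuals. A likely reason is that each iteration builds a new engine and uses a new population shape, so every measurement is dominated by tracing and compilation rather than steady-state execution. JAX's ahead-of-time API (`jax.jit(f).lower(...).compile()`) separates the two costs; the sketch below applies it to a hypothetical vectorized Sphere evaluation (`evaluate_population`), not to `BasicMalthusEngine` itself.

```python
import time

import jax
import jax.numpy as jnp


def evaluate_population(population):
    # Hypothetical workload: Sphere fitness for every genome in the population.
    return jnp.sum(population ** 2, axis=1)


def compile_and_run_times(pop_size, dim=100):
    population = jnp.ones((pop_size, dim))

    # Compilation is specialized to this exact shape, so each pop_size compiles anew.
    start = time.perf_counter()
    compiled = jax.jit(evaluate_population).lower(population).compile()
    compile_time = time.perf_counter() - start

    # Execution of the already-compiled program only.
    start = time.perf_counter()
    fitness = compiled(population)
    fitness.block_until_ready()
    run_time = time.perf_counter() - start
    return compile_time, run_time, fitness


for pop_size in [50, 1000, 10000]:
    c, r, _ = compile_and_run_times(pop_size)
    print(f"pop {pop_size:5d}: compile {c:.4f}s, run {r:.5f}s")
```

Reporting the two numbers separately makes it clear whether a scaling curve reflects the algorithm's cost or the compiler's.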


In [9]:
# --- Final Performance Demonstration ---
print("=== Final Performance Summary ===")

def comprehensive_benchmark():
    """Run a comprehensive benchmark across different problem types"""
    
    problems = [
        ("Binary (50-bit)", BinaryGenomeConfig((50,), 0.5), BinarySumFitnessEvaluator(), 100),
        ("Real (10D Sphere)", RealGenomeConfig((10,), -5, 5), SphereFitnessEvaluator(), 100),
        ("Real (20D Rastrigin)", RealGenomeConfig((20,), -5.12, 5.12), RastriginFitnessEvaluator(), 200)
    ]
    
    print("Problem Type           | Pop Size | Generations | Time     | Best Fitness")
    print("-" * 75)
    
    for name, genome_config, fitness_eval, pop_size in problems:
        
        # Configure engine
        if "Binary" in name:
            genome_rep = BinaryGenome(array_shape=genome_config.array_shape, p=genome_config.p)
            selection_op = TournamentSelection(pop_size, 3)
            crossover_op = BinaryUniform(0.8, 1)
            mutation_op = BitFlipMutation(0.02)
        else:
            genome_rep = RealGenome(array_shape=genome_config.array_shape, min_val=genome_config.min_val, max_val=genome_config.max_val)
            selection_op = TournamentSelection(pop_size, 4)
            crossover_op = AverageCrossover(0.7, 1)
            mutation_op = BallMutation(0.1, 0.1)
            
            
        engine = BasicMalthusEngine(
            genome_representation = genome_rep,
            fitness_evaluator = fitness_eval,
            selection_operator = selection_op,
            crossover_operator = crossover_op,
            mutation_operator = mutation_op,
            elitism=2
        )
        
        key1, test_key = jar.split(key)
        start = time.time()
        final_state, history = engine.run(test_key, num_generations=50, pop_size=pop_size)
        history.block_until_ready()  # timing includes JIT compilation for this new engine
        elapsed = time.time() - start
        
        print(f"{name:22s} | {pop_size:8d} | {50:11d} | {elapsed:.3f}s | {final_state.best_fitness:.6f}")
    
comprehensive_benchmark()


=== Final Performance Summary ===
Problem Type           | Pop Size | Generations | Time     | Best Fitness
---------------------------------------------------------------------------
Binary (50-bit)        |      100 |          50 | 0.654s | 16.000000
Real (10D Sphere)      |      100 |          50 | 1.718s | -143.590576
Real (20D Rastrigin)   |      200 |          50 | 1.200s | 137.804749
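Two rows in this table deserve a second look. The 10-D Sphere row reports a negative best fitness, so `SphereFitnessEvaluator` appears to return the negated sum of squares (best value 0); yet -143.59 is worse than a typical random point in [-5, 5]^10, whose expected sum of squares is about 83. Likewise, 16 of 50 bits is below the roughly 25 ones of a random string. Both are consistent with the engine minimizing objectives written to be maximized, while the 20-D Rastrigin row, a minimization objective, lands well below the roughly 370 expected for a random point in [-5.12, 5.12]^20. This is an inference from the printed numbers, not something the notebook verifies. Making the direction explicit helps avoid such mismatches; the sketch below is plain JAX, and `as_minimization` is a hypothetical helper, not part of MalthusJAX.

```python
import jax
import jax.numpy as jnp


def as_minimization(fitness_fn, maximize):
    # Hypothetical helper: return an objective where lower is always better.
    if maximize:
        return lambda genome: -fitness_fn(genome)
    return fitness_fn


def onemax(genome):
    # Number of ones; intended to be maximized.
    return jnp.sum(genome)


objective = jax.vmap(as_minimization(onemax, maximize=True))

genomes = jnp.array([
    [0, 0, 0, 0],
    [1, 0, 1, 0],
    [1, 1, 1, 1],
])
# A minimizing engine now prefers the genome with the most ones.
best_index = jnp.argmin(objective(genomes))
print(genomes[best_index])
```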
