In [1]:
import malthusjax as mjx
import jax
import jax.numpy as jnp
import jax.random as jar

# Verify we are running on the intended backend
print(f"MalthusJAX Version: {mjx.__version__}")
print(f"JAX Backend: {jax.devices()[0]}")

# Master Random Key.
# The stochastic steps in later cells obtain their subkeys by splitting this
# key, so a single named seed controls the randomness of a Restart & Run All.
SEED = 42
key = jar.PRNGKey(SEED)

MalthusJAX Version: 0.2.0
JAX Backend: TFRT_CPU_0


In [2]:
# 1. Define the Problem (100-bit string)
genome_config = mjx.BinaryGenomeConfig(length=100)

# 2. Define Engine Parameters (Static)
# We want 1000 individuals, running for 50 generations
# We preserve the top 5 individuals (Elitism)
params = mjx.StandardEngineParams(
    pop_size=1000,
    num_generations=50,
    elitism=5
)

# 3. Initialize Random Population (Batch Creation)
# NOTE: this rebinds `key`, so re-running this cell on its own draws a
# different population than a clean top-to-bottom run would.
key, k_pop = jar.split(key)
initial_pop = mjx.BinaryPopulation.init_random(k_pop, genome_config, params.pop_size)

# Sanity check: one row of bits per individual, one column per genome position.
expected_shape = (params.pop_size, genome_config.length)
assert initial_pop.genes.bits.shape == expected_shape, (
    f"unexpected population shape {initial_pop.genes.bits.shape}, expected {expected_shape}"
)

print(f"Population Shape: {initial_pop.genes.bits.shape}")

Population Shape: (1000, 100)


In [3]:
# 4. Assemble the Engine
# Each genetic operator is constructed up front under a descriptive name and
# then handed to the engine, so its hyperparameters are visible in one place.
bitsum_evaluator = mjx.BinarySumEvaluator(mjx.BinarySumConfig(maximize=True))
tournament_selection = mjx.selection.Tournament(
    num_selections=params.pop_size,
    tournament_size=3,
)
uniform_crossover = mjx.crossover.Uniform(num_offspring=2, crossover_rate=0.8)
bitflip_mutation = mjx.mutation.BitFlip(num_offspring=1, mutation_rate=0.01)

engine = mjx.StandardGeneticEngine(
    evaluator=bitsum_evaluator,
    selection=tournament_selection,
    crossover=uniform_crossover,
    mutation=bitflip_mutation,
)

# 5. Create Initial State
# This runs one evaluation to populate the 'best_fitness' and 'fitness' fields
key, init_key = jar.split(key)
state = engine.init_state(init_key, initial_pop)

print(f"Initial Best Fitness: {state.best_fitness}")

Initial Best Fitness: 65.0


In [7]:
# 6. Run Evolution (JIT Compiled)
print("Compiling & Running...")

# The .run() method handles the jax.lax.scan loop automatically.
# NOTE(review): `history` is returned but not inspected in this cell;
# presumably it holds per-generation metrics — TODO confirm and plot it.
final_state, history, elapsed = engine.run(
    state,
    params,
    time_it=True,
    verbose=True
)

# 7. Analyze Results
print(f"\nOptimization Complete in {elapsed:.4f}s")
print(f"Final Best Fitness: {final_state.best_fitness}/{genome_config.length}")

# Check if we solved it: the score is reported out of genome_config.length,
# so a solved run reaches that value (an all-ones genome for this objective).
# `best_genome` is a single genome stored on the engine state — no indexing
# into the population happens here.
best_genes = final_state.best_genome.genes.bits
print(f"Best Genome Sample: {best_genes}")

Compiling & Running...
Running evolution (JIT compilation automatic)

Optimization Complete in 0.3909s
Final Best Fitness: 100.0/100
Best Genome Sample: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
