# MalthusJAX Level 1 Showcase: Genomes & Fitness

Welcome to the **MalthusJAX Level 1 Showcase**! This notebook highlights the core components of MalthusJAX at Level 1:

### Key Components:
- **AbstractFitnessEvaluator**: Learn how to define JIT-compatible and auto-vectorized fitness functions.
- **AbstractGenome**: Explore JAX-native genome representations (Pytrees) with:
    - **Static Configuration**: Fixed parameters defining the genome structure.
    - **Dynamic Tensor Data**: Flexible data for runtime operations.

Dive in to understand how these components work together to enable efficient and scalable computations!

In [1]:
import sys 
import os
import time
# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.core.base import JAXTensorizable
from malthusjax.core.fitness.base import AbstractFitnessEvaluator
from malthusjax.core.genome.base import AbstractGenome, AbstractGenomeConfig



In [2]:
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
import functools
from typing import Callable, List, Dict, Any, Tuple
import sys
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

# Fitness Functions
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator, KnapsackFitnessEvaluator
from malthusjax.core.fitness.real import SphereFitnessEvaluator, RastriginFitnessEvaluator

# Genomes
from malthusjax.core.genome.binary import BinaryGenome, BinaryGenomeConfig
from malthusjax.core.genome.real import RealGenome, RealGenomeConfig
from malthusjax.core.genome.categorical import CategoricalGenome, CategoricalGenomeConfig

# --- JAX Setup ---
print(f"JAX running on: {jax.default_backend()}")
print(f"Available devices: {jax.devices()}")

# Master PRNG Key
key = random.PRNGKey(42)

JAX running on: cpu
Available devices: [CpuDevice(id=0)]


## 2. The AbstractFitnessEvaluator

The **AbstractFitnessEvaluator** is designed to define the logic for evaluating a single genome tensor. With MalthusJAX, batch evaluation (via `vmap`) is automatically handled, simplifying the process.

### Key Concept:
You need to implement the following method:



In [3]:
# 1. Instantiate the evaluator
sum_evaluator = BinarySumFitnessEvaluator()

# 2. Get the JIT-compatible tensor function
# This is the function we'll use inside our GA loop
tensor_fn = sum_evaluator.get_tensor_fitness_function()

# 3. JIT-compile it
jit_tensor_fn = jit(tensor_fn)

# 4. Test on a single genome tensor
key, subkey = random.split(key)
dummy_genome = random.bernoulli(subkey, p=0.5, shape=(10,)).astype(jnp.float32)

print(f"Dummy Genome: {dummy_genome}")
fitness = jit_tensor_fn(dummy_genome)
print(f"Fitness (JITted): {fitness}")
print(f"Fitness (should match): {jnp.sum(dummy_genome)}")

Dummy Genome: [0. 0. 1. 1. 1. 1. 1. 1. 1. 1.]
Fitness (JITted): 8.0
Fitness (should match): 8.0


### Auto-Vectorization with `vmap`

One of the most powerful features of MalthusJAX is **auto-vectorization**. The `evaluate_batch` method in the base class leverages `jax.vmap` to seamlessly create a batched version of the tensor function.

#### Why is this important?
- **Efficiency**: Evaluate an entire population in parallel on the accelerator.
- **Simplicity**: A single method call handles the batch evaluation for you.

With this approach, you can process large populations effortlessly, unlocking the full potential of JAX's high-performance capabilities.

In [4]:
# 1. Create a BATCH of genomes
pop_size = 5
key, subkey = random.split(key)
population_tensors = random.bernoulli(subkey, p=0.5, shape=(pop_size, 10)).astype(jnp.float32)
print(f"Population shape: {population_tensors.shape}\n")
print(population_tensors)

# 2. Evaluate the whole batch
# This call uses jax.vmap and jax.jit under the hood
fitness_values = sum_evaluator.evaluate_batch(population_tensors)

print(f"\nFitness values: {fitness_values}")

# 3. Verify
manual_sums = jnp.sum(population_tensors, axis=1)
print(f"Manual sums: {manual_sums}")

Population shape: (5, 10)

[[1. 1. 1. 1. 0. 1. 0. 0. 0. 1.]
 [0. 1. 0. 0. 1. 1. 0. 1. 0. 1.]
 [1. 0. 0. 0. 0. 0. 0. 1. 1. 0.]
 [0. 1. 1. 1. 0. 0. 1. 0. 1. 1.]
 [1. 0. 1. 0. 1. 1. 1. 1. 0. 1.]]

Fitness values: [6.0, 5.0, 3.0, 6.0, 7.0]
Manual sums: [6. 5. 3. 6. 7.]


Example 2: KnapsackFitnessEvaluator
This evaluator shows how to pass static arguments (weights, values, limits) into the JIT-compiled function using functools.partial. The get_tensor_fitness_function returns a closure that already knows about these static values.

In [5]:
# 1. Define problem statics
item_weights = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
item_values = jnp.array([5.0, 8.0, 1.0, 10.0, 7.0])
weight_limit = 10.0

# 2. Instantiate evaluator
knapsack_evaluator = KnapsackFitnessEvaluator(
    weights=item_weights,
    values=item_values,
    weight_limit=weight_limit
)

# 3. Create a test population
key, subkey = random.split(key)
knapsack_pop = random.bernoulli(subkey, p=0.4, shape=(4, 5)).astype(jnp.float32)
print("Knapsack Population:\n", knapsack_pop)

# 4. Evaluate batch
knapsack_fitness = knapsack_evaluator.evaluate_batch(knapsack_pop)

# 5. Check results
for i in range(4):
    genome = knapsack_pop[i]
    w = jnp.sum(genome * item_weights)
    v = jnp.sum(genome * item_values)
    fit = knapsack_fitness[i]
    print(f"Genome {i}: Weight={w}, Value={v}, Fitness={fit}")


Knapsack Population:
 [[1. 1. 1. 1. 0.]
 [0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 1.]
 [1. 0. 0. 1. 1.]]
Genome 0: Weight=10.0, Value=24.0, Fitness=24.0
Genome 1: Weight=2.0, Value=8.0, Fitness=8.0
Genome 2: Weight=6.0, Value=12.0, Fitness=12.0
Genome 3: Weight=10.0, Value=22.0, Fitness=22.0


## 3. The AbstractGenome

Genomes in MalthusJAX are designed as **Pytrees**, a fundamental concept in JAX.

### What is a Pytree?
A **Pytree** in JAX is a structure that separates its data into two categories:
- **Dynamic Data (Children):** JAX arrays, such as the genome tensor.
- **Static Data (Auxiliary):** Everything else, like `array_shape`, `p`, `min_val`, etc.

### Key Components:
- **Static Data:** Encapsulated in `...Config` objects, which are ideal for holding genome configurations.
- **Dynamic Data:** Represented as JAX arrays, enabling efficient computation.

### The Contract:
1. Define a `@dataclass(frozen=True)` for your `...GenomeConfig`.
2. Register the `...GenomeConfig` as a Pytree (already done for you).
3. Implement the method:
    ```python
    get_random_initialization_pure_from_config(cls, config)
    ```

This design ensures seamless integration with JAX's functional programming paradigm, enabling efficient and scalable genome operations.

### Example 1: BinaryGenome + BinarySumFitnessEvaluator

In this example, we demonstrate how to obtain a **JIT-compilable initialization function** directly from the `BinaryGenome` class.

#### Highlights:
- **BinaryGenome**: Represents the genome structure.
- **BinarySumFitnessEvaluator**: Evaluates the fitness by summing binary values.

Follow along to see how these components work together seamlessly!

In [6]:
# 1. Create a static configuration
bin_config = BinaryGenomeConfig(array_shape=(20,), p=0.5)
print(f"Config: {bin_config}")

# 2. Get the JIT-compilable init function *from the class*
# This is a partial function that "knows" the config
init_fn_pure = BinaryGenome.get_random_initialization_pure_from_config(bin_config)

# 3. JIT-compile it
jit_init_fn = jit(init_fn_pure)

# 4. Create a new genome tensor
key, subkey = random.split(key)
new_genome_tensor = jit_init_fn(subkey)
print(f"\nNew genome tensor (shape {new_genome_tensor.shape}):\n{new_genome_tensor}")

Config: BinaryGenomeConfig(array_shape=(20,), p=0.5)

New genome tensor (shape (20,)):
[False  True  True False False  True  True  True  True  True False False
 False  True  True False  True  True False False]


### vmap + jit for Population Creation

Harness the power of **vmap** and **jit** to efficiently create an entire population in one step. This approach combines the elegance of JAX's functional programming with the speed of just-in-time compilation.

In [7]:
# 1. vmap our JITted init function
pop_init_fn = jit(vmap(jit_init_fn))

# 2. Create a batch of keys
pop_size = 10
key, *keys_batch = random.split(key, pop_size + 1)
keys_array = jnp.stack(keys_batch)
print(f"Keys shape: {keys_array.shape}")

# 3. Create the whole population
population = pop_init_fn(keys_array)
print(f"\nPopulation shape: {population.shape}")

# 4. Evaluate the new population
fitnesses = sum_evaluator.evaluate_batch(population)
print(f"\nFitnesses: {fitnesses}")

Keys shape: (10, 2)

Population shape: (10, 20)

Fitnesses: [6.0, 10.0, 12.0, 13.0, 11.0, 13.0, 14.0, 11.0, 9.0, 11.0]


### Example 2: RealGenome + SphereFitnessEvaluator

The **RealGenome** seamlessly integrates with the **SphereFitnessEvaluator**, showcasing the versatility and elegance of this abstraction. The workflow remains consistent, highlighting the power of reusable and composable components.

In [8]:
# 1. Create config and evaluator
real_config = RealGenomeConfig(array_shape=(5,), min_val=-5.0, max_val=5.0)
sphere_evaluator = SphereFitnessEvaluator()

print(f"Config: {real_config}")

# 2. Get and JIT the init function
real_init_fn = RealGenome.get_random_initialization_pure_from_config(real_config)
jit_real_init_fn = jit(real_init_fn)

# 3. Create a population
key, *keys_batch = random.split(key, pop_size + 1)
real_population = jit(vmap(jit_real_init_fn))(jnp.stack(keys_batch))

print(f"\nReal Population (shape {real_population.shape}):\n{real_population[0]}...")

# 4. Evaluate
real_fitnesses = sphere_evaluator.evaluate_batch(real_population)
print(f"\nSphere Fitnesses:\n{real_fitnesses}")


Config: RealGenomeConfig(array_shape=(5,), min_val=-5.0, max_val=5.0)

Real Population (shape (10, 5)):
[ 2.9283488  -3.5115457  -0.37373662  4.0129232  -2.9601252 ]...

Sphere Fitnesses:
[-45.9117546081543, -37.0478515625, -42.4464225769043, -35.13422775268555, -38.68138122558594, -36.138999938964844, -59.504310607910156, -48.80213165283203, -55.42524337768555, -25.40340232849121]


### Example 3: Autocorrection

One of the standout features of the **AbstractGenome** is the `get_autocorrection_pure_from_config` method. This method provides a **JIT-compatible function** to automatically "fix" invalid genomes, ensuring they remain within the defined bounds (e.g., after mutation or crossover).

#### Key Highlights:
- **Functionality**: For `RealGenome`, this is implemented as a simple `jnp.clip`.
- **Efficiency**: The function is fully JIT-compatible, enabling seamless integration into high-performance workflows.

This feature ensures that your evolutionary algorithms maintain valid genome representations at all times, enhancing robustness and reliability.

In [9]:
# 1. Get the autocorrection function
autocorrect_fn = RealGenome.get_autocorrection_pure_from_config(real_config)
jit_autocorrect_fn = jit(autocorrect_fn)

# 2. Create an "invalid" genome tensor (outside the bounds)
invalid_genome = jnp.array([-10.0, 1.0, 5.0, 9.9, -7.2])
print(f"Invalid genome: {invalid_genome}")

# 3. Fix it
corrected_genome = jit_autocorrect_fn(invalid_genome)
print(f"Corrected genome: {corrected_genome}")

Invalid genome: [-10.    1.    5.    9.9  -7.2]
Corrected genome: [-5.  1.  5.  5. -5.]


## 4. Teaser: A Fully JIT-Compiled Evolutionary Step

Experience the power of JAX with a fully **JIT-compiled evolutionary step**! This streamlined design showcases the seamless composition of functions, paving the way for:

- **Level 2: Operators** - Advanced genetic operators for selection, crossover, and mutation.
- **Level 3: Engine** - A robust and scalable evolutionary engine.

Stay tuned for the next levels, where we unlock the full potential of evolutionary algorithms with JAX!

In [10]:
# 1. Define the "step" function
# We'll use the real_config and sphere_evaluator from above
fitnesses_fn = sphere_evaluator.get_batch_fitness_function()

def run_one_generation_step(key, population):
    # (In Level 2, we would add selection, crossover, mutation here)
    
    # For now, just evaluate
    
    fitnesses = fitnesses_fn(population)
    # Find the best
    best_fitness = jnp.min(fitnesses) # Sphere is minimization
    best_genome = population[jnp.argmin(fitnesses)]
    
    # Create a new population (just for demo)
    key, *keys_batch = random.split(key, pop_size + 1)
    new_population = jit(vmap(jit_real_init_fn))(jnp.stack(keys_batch))
    
    return new_population, best_fitness, best_genome

# 2. JIT the entire step
jit_step = jit(run_one_generation_step)

# 3. Run the step
key, subkey = random.split(key)
final_pop, best_fit, best_ind = jit_step(subkey, real_population)

print(f"Best fitness from first pop: {best_fit}")
print(f"Best genome: {best_ind}")
print(f"\nNew population created, e.g.:\n{final_pop[0]}")

# 4. Run it again (will be fast, already compiled)
key, subkey = random.split(key)
_, best_fit_2, _ = jit_step(subkey, final_pop)
print(f"\nBest fitness from second pop: {best_fit_2}")


Best fitness from first pop: -59.504310607910156
Best genome: [-4.264735   4.818348   1.7367709 -1.9551146 -3.355745 ]

New population created, e.g.:
[ 0.90730906  0.7500577  -0.7754016   0.05871177 -0.16430616]

Best fitness from second pop: -63.92742156982422
