# MalthusJAX Level 2 Showcase: Operators
This notebook demonstrates the core **Level 2** components of MalthusJAX:

- **AbstractSelectionOperator**
- **AbstractMutationOperator**
- **AbstractCrossoverOperator**

### Key Design Pattern: The Factory Pattern

1. **Instantiation**: Create an Operator Class with static configuration (e.g., `mutation_rate`).
2. **Compilation**: Call `.get_compiled_function()` to obtain a pure, JIT-compiled JAX function.
3. **Population-Wide Execution**: Use `jax.vmap` to apply this pure function across entire populations (JAX arrays).

---

### Workflow Overview

#### 1: Introduction & Setup
- Import necessary components:
    - `jax` and `jax.numpy`
    - Level 1 components (Genomes and Fitness functions) for demo population creation
    - Refactored Level 2 Operator classes


In [1]:
import sys 
import os
import time
# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar
from jax import random, jit, vmap

from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.core.base import JAXTensorizable
from malthusjax.core.fitness.base import AbstractFitnessEvaluator
from malthusjax.core.genome.base import AbstractGenome, AbstractGenomeConfig



In [2]:
# --- Imports ---
import functools
from typing import Callable, List, Dict, Any, Tuple

print(f"JAX running on: {jax.default_backend()}")

# --- JAX Setup ---
# Master PRNG Key
key = random.PRNGKey(42)

# --- Level 1 Imports (from our previous work) ---
# Genomes
from malthusjax.core.genome.binary import BinaryGenome, BinaryGenomeConfig
from malthusjax.core.genome.real import RealGenome, RealGenomeConfig
#from malthusjax.core.genome.permutation import PermutationGenome # Assuming this exists
from malthusjax.core.genome.categorical import CategoricalGenome, CategoricalGenomeConfig

# Fitness
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.core.fitness.real import SphereFitnessEvaluator

# --- Level 2 Imports (NEW) ---
# Base classes (for type hints)
from malthusjax.operators.base import AbstractGeneticOperator
from malthusjax.operators.selection.base import AbstractSelectionOperator
from malthusjax.operators.mutation.base import AbstractMutation
from malthusjax.operators.crossover.base import AbstractCrossover

# Selection Operators
from malthusjax.operators.selection.tournament import TournamentSelection
from malthusjax.operators.selection.roulette import RouletteSelection

# Binary Operators
from malthusjax.operators.mutation.binary import BitFlipMutation, SwapMutation as BinarySwap
from malthusjax.operators.crossover.binary import UniformCrossover as BinaryUniform, SinglePointCrossover as BinarySinglePoint

# Real Operators
from malthusjax.operators.mutation.real import BallMutation
from malthusjax.operators.crossover.real import UniformCrossover as RealUniform, SinglePointCrossover as RealSinglePoint, AverageCrossover

# Permutation Operators
from malthusjax.operators.mutation.permutation import ScrambleMutation, SwapMutation as PermutationSwap
from malthusjax.operators.crossover.permutation import CycleCrossover, PillarCrossover

# Categorical Operators
from malthusjax.operators.mutation.categorical import CategoricalFlipMutation
from malthusjax.operators.crossover.categorical import UniformCrossover as CatUniform, SinglePointCrossover as CatSinglePoint


JAX running on: cpu


## 2. The AbstractFitnessEvaluator

The **AbstractFitnessEvaluator** is designed to define the logic for evaluating a single genome tensor. With MalthusJAX, batch evaluation (via `vmap`) is automatically handled, simplifying the process.

### Key Concept:
You need to implement the following method:



In [3]:
# --- 1. Define Population and Problem ---
POP_SIZE = 100
GENOME_LENGTH = 20

key, init_key, fitness_key = random.split(key, 3)

# --- 2. Create Population ---
# Get the JIT-able init function from the Genome class
bin_config = BinaryGenomeConfig(array_shape=(GENOME_LENGTH,), p=0.5)
init_fn = BinaryGenome.get_random_initialization_compilable_from_config(bin_config)

# vmap it to create a population
pop_init_fn = jit(vmap(init_fn))

# Create population
init_keys = random.split(init_key, POP_SIZE)
population_tensors = pop_init_fn(init_keys)

print(f"Created population tensors with shape: {population_tensors.shape}")
print(f"Data type: {population_tensors.dtype}")


# --- 3. Evaluate Fitness ---
# Get the JIT-able fitness function
sum_evaluator = BinarySumFitnessEvaluator()
fitness_fn = sum_evaluator.get_tensor_fitness_function()

# vmap it to evaluate the population
# (Note: sum_evaluator.evaluate_batch() does this automatically,
# but we do it manually here to show the pure-function pattern)
pop_fitness_fn = jit(vmap(fitness_fn))

fitness_values = pop_fitness_fn(population_tensors)

print(f"\nCalculated fitness values (shape {fitness_values.shape}):")
print(fitness_values)


Compiling random initialization function with array_shape=(20,), p=0.5
Created population tensors with shape: (100, 20)
Data type: bool

Calculated fitness values (shape (100,)):
[ 8 14  8  9  8 14 11 12 10 15  8  8  9 11  9  7 11  9 10  9 11 12 11 13
 10  9 10 10  9  9 11  7 11  8 11  9  9 11  9 13 13 10 10  6  8 13  9  9
 11 13  8  8 12 14 13  7 10 14  9  9 10 10 10 12 16  6  7  9 11  9 10  6
  6 13  9 10  9  9 11  9 10  9  7 13 10  7 13 12 11  9 12 12 12 15 11 13
 11 12  9  9]


### Auto-Vectorization with `vmap`

One of the most powerful features of MalthusJAX is **auto-vectorization**. The `evaluate_batch` method in the base class leverages `jax.vmap` to seamlessly create a batched version of the tensor function.

#### Why is this important?
- **Efficiency**: Evaluate an entire population in parallel on the accelerator.
- **Simplicity**: A single method call handles the batch evaluation for you.

With this approach, you can process large populations effortlessly, unlocking the full potential of JAX's high-performance capabilities.

In [4]:
# --- 1. Setup Selection ---
print("--- Tournament Selection ---")
key, select_key = random.split(key)

# Ensure POP_SIZE is a static Python integer
pop_size = int(POP_SIZE)

# We want to select POP_SIZE individuals to create the next generation
# (allowing duplicates, as tournament selection does)
tourn_op_factory = TournamentSelection(
    number_of_choices=pop_size,
    tournament_size=4
)

# --- 2. Get the JIT-compiled function ---
# This is a pure function we can use in our main loop

select_fn = tourn_op_factory.get_compiled_function()
# --- 3. Run Selection ---
# Pass in the key and the fitness values

selected_indices = select_fn(select_key, fitness_values)

print(f"Original fitnesses (top 10): {jnp.sort(fitness_values)[-10:]}")
print(f"\nSelected indices (shape {selected_indices.shape}):\n{selected_indices}")

# Prove that selection worked:
selected_fitnesses = fitness_values[selected_indices]
print(f"\nFitnesses of selected (top 10): {jnp.sort(selected_fitnesses)[-10:]}")
print(f"Average fitness BEFORE selection: {jnp.mean(fitness_values)}")
print(f"Average fitness AFTER selection: {jnp.mean(selected_fitnesses)}")

# --- 4. Get the Parent Population ---
# In the main GA loop, we'd use these indices to get the new population
parent_population = population_tensors[selected_indices]
print(f"\nShape of new parent population: {parent_population.shape}")

--- Tournament Selection ---
Original fitnesses (top 10): [13 13 13 14 14 14 14 15 15 16]

Selected indices (shape (100,)):
[53 37 83  9 49 49 90 30 93 94 34  6 13 62 64 42 88 64 86 18 92 49  9 23
  5 88 83 54 86 23 86 73 86 21 74 30 94 61 21 40 83 57 40 21 94 62  9 64
 83 53 68 48 63 75 86 87 27 54 90  1  5 73 54 23  7  6 23 52 57 57 13 64
 97 63 83  1 18  1 90 63 80 45 49 48 75 40 40 21 95  8 36 83 64  5 61  5
 45 87 64 23]

Fitnesses of selected (top 10): [15 15 15 15 16 16 16 16 16 16]
Average fitness BEFORE selection: 10.179999351501465
Average fitness AFTER selection: 12.460000038146973

Shape of new parent population: (100, 20)


Example 2: KnapsackFitnessEvaluator
This evaluator shows how to pass static arguments (weights, values, limits) into the JIT-compiled function using functools.partial. The get_tensor_fitness_function returns a closure that already knows about these static values.

In [5]:
# --- 1. Setup Mutation ---
print("--- BitFlip Mutation ---")
key, mutate_key = random.split(key)

# Instantiate the "factory" with static config
bitflip_op_factory = BitFlipMutation(mutation_rate=0.1)
   
# --- 2. Get the JIT-compiled function ---
# This function mutates *one* genome
mutate_one_fn = bitflip_op_factory.get_compiled_function()

#mutate just one genome for demonstration
sample_genome = parent_population[0]
mutated_genome = mutate_one_fn(sample_genome, mutate_key)
print(f"Original genome: {sample_genome}")
print(f"Mutated genome:  {mutated_genome}")

# --- 3. vmap it to create a population-wide mutator ---
# We map over axis 0 of keys and axis 0 of genomes
mutate_pop_fn = jit(vmap(mutate_one_fn, in_axes=(0, 0)))

# --- 4. Run Mutation ---
# We'll mutate the `parent_population` we just selected
print(f"Original parent (index 0):\n{parent_population.shape}")


# Create a key for each individual
pop_size = parent_population.shape[0]
master_key = jar.PRNGKey(0)
mutate_keys = jar.split(master_key, pop_size)   # shape (pop_size, 2)
mutated_population = mutate_pop_fn(parent_population, mutate_keys)

print(f"\nMutated parent (index 0):\n{mutated_population[0]}")

# See the difference
diff = (parent_population != mutated_population).astype(jnp.int32)
print(f"\nTotal bits flipped: {jnp.sum(diff)} out of {diff.size} bits ({100 * jnp.sum(diff) / diff.size:.2f}%)")

--- BitFlip Mutation ---
mutation_mask shape: (1, 20)
Original genome: [ True False  True False  True False  True False  True  True False  True
  True  True  True  True  True  True  True False]
Mutated genome:  [[ True False  True  True  True False  True False  True  True False  True
   True  True  True  True  True  True  True  True]]
Original parent (index 0):
(100, 20)
mutation_mask shape: (1, 20)

Mutated parent (index 0):
[[ True False  True False  True False  True False  True False  True  True
   True  True  True  True  True  True  True False]]

Total bits flipped: 94358 out of 200000 bits (47.18%)


### 3. Selection Operators

**Goal**: Select `N` individuals to be parents.

**Pattern**:
1. Instantiate a **Selection** class (e.g., `TournamentSelection`) with static parameters:
    - `number_of_choices`
    - `tournament_size`
2. Call `get_compiled_function()` to obtain the pure JAX function.
3. The pure function signature:
    ```python
    (key, fitness_values) -> selected_indices
    ```

In [6]:
# --- 1. Setup Selection ---
print("--- Tournament Selection ---")
key, select_key = random.split(key)

# We want to select POP_SIZE individuals to create the next generation
# (allowing duplicates, as tournament selection does)
tourn_op_factory = TournamentSelection(
    number_of_choices=POP_SIZE,
    tournament_size=4
)

# --- 2. Get the JIT-compiled function ---
# This is a pure function we can use in our main loop
select_fn = jit(tourn_op_factory.get_compiled_function())

# --- 3. Run Selection ---
# Pass in the key and the fitness values
selected_indices = select_fn(select_key, fitness_values)

print(f"Original fitnesses (top 10): {jnp.sort(fitness_values)[-10:]}")
print(f"\nSelected indices (shape {selected_indices.shape}):\n{selected_indices}")

# Prove that selection worked:
selected_fitnesses = fitness_values[selected_indices]
print(f"\nFitnesses of selected (top 10): {jnp.sort(selected_fitnesses)[-10:]}")
print(f"Average fitness BEFORE selection: {jnp.mean(fitness_values)}")
print(f"Average fitness AFTER selection: {jnp.mean(selected_fitnesses)}")

# --- 4. Get the Parent Population ---
# In the main GA loop, we'd use these indices to get the new population
parent_population = population_tensors[selected_indices]
print(f"\nShape of new parent population: {parent_population.shape}")

--- Tournament Selection ---
Original fitnesses (top 10): [13 13 13 14 14 14 14 15 15 16]

Selected indices (shape (100,)):
[ 6 90 23 87 95 73 95 26 54 90 83  7 39 54 92 39  9 16 45 20 93 57 54 96
 24 21 87 42 57  9 48 39  1 32  1  6 86 54 41 91  7 49  5 88 42 70 38 86
  6 53  8  8  1 90 95 39  1  6 83 95 93  5 20 21  1 49  1 94 86  1 95 63
 83 95 30 36 87 97 16 70 56 40 53 57  5 91 90 64  5 93 92 20 62 53  5 86
 87 77 57 73]

Fitnesses of selected (top 10): [14 14 14 14 15 15 15 15 15 16]
Average fitness BEFORE selection: 10.179999351501465
Average fitness AFTER selection: 12.389999389648438

Shape of new parent population: (100, 20)


### 4. Mutation Operators

**Goal**: Apply random variation to individuals.

**Pattern**:
1. **Instantiate** a Mutation class (e.g., `BitFlipMutation`) with static parameters (`mutation_rate`).
2. **Compile** the mutation function using `get_compiled_function()` to obtain a pure JAX function.
3. **Function Signature**:
    ```python
    (key, genome_tensor) -> mutated_genome_tensor
    ```
4. **Vectorization**: Use `vmap` to apply the mutation function across an entire population.

In [7]:
# --- 1. Setup Mutation ---
print("--- BitFlip Mutation ---")
key, mutate_key = random.split(key)

# Instantiate the "factory" with static config
bitflip_op_factory = BitFlipMutation(mutation_rate=0.1)

# --- 2. Get the JIT-compiled function ---
# This function mutates *one* genome
mutate_one_fn = jit(bitflip_op_factory.get_compiled_function())

# --- 3. vmap it to create a population-wide mutator ---
# We map over axis 0 of keys and axis 0 of genomes
mutate_pop_fn = jit(vmap(mutate_one_fn, in_axes=(0, 0)))

# --- 4. Run Mutation ---
# We'll mutate the `parent_population` we just selected
print(f"Original parent (index 0):\n{parent_population[0]}")

# Create a key for each individual
mutate_keys = random.split(mutate_key, POP_SIZE)
mutated_population = mutate_pop_fn( parent_population, mutate_keys)

print(f"\nMutated parent (index 0):\n{mutated_population[0]}")

# See the difference
diff = (parent_population != mutated_population).astype(jnp.int32)
print(f"\nTotal bits flipped: {jnp.sum(diff)}")

--- BitFlip Mutation ---
Original parent (index 0):
[False  True False False  True  True  True False False False False  True
  True  True  True False False  True  True  True]
mutation_mask shape: (1, 20)

Mutated parent (index 0):
[[False  True False False  True  True  True False  True False False  True
   True  True  True False False False  True  True]]

Total bits flipped: 93474


### Example 2: BallMutation (Real-Valued)

The same pattern works seamlessly across different genome types.

In [8]:
# --- 1. Setup Real Population ---
print("--- Ball Mutation ---")
key, init_key_real, mutate_key_real = random.split(key, 3)

real_config = RealGenomeConfig(array_shape=(10,), min_val=-1.0, max_val=1.0)
real_init_fn = jit(vmap(RealGenome.get_random_initialization_compilable_from_config(real_config)))
real_keys = random.split(init_key_real, POP_SIZE)
real_population = real_init_fn(real_keys)

print(f"Original real genome (index 0):\n{real_population[0]}")

# --- 2. Setup Mutation ---
ball_op_factory = BallMutation(mutation_rate=0.5, mutation_strength=0.1)
mutate_one_real_fn = jit(ball_op_factory.get_compiled_function())

# --- 3. vmap and Run ---
mutate_pop_real_fn = jit(vmap(mutate_one_real_fn, in_axes=(0, 0)))
real_mutate_keys = random.split(mutate_key_real, POP_SIZE)

mutated_real_population = mutate_pop_real_fn(real_mutate_keys, real_population)

print(f"\nMutated real genome (index 0):\n{mutated_real_population[0]}")

# --- 4. (Demo) Don't forget Autocorrection! ---
# The mutation *could* have pushed values out of bounds.
# We get the JIT-able correction function from the genome config.
autocorrect_fn = jit(RealGenome.get_autocorrection_compilable_from_config(real_config))
autocorrect_pop_fn = jit(vmap(autocorrect_fn))

# Apply correction (in a real GA, you'd do this after mutation)
final_population = autocorrect_pop_fn(mutated_real_population)
print(f"\nFinal corrected genome (index 0):\n{final_population[0]}")
print(f"Min value in pop: {jnp.min(final_population)}, Max value in pop: {jnp.max(final_population)}")


--- Ball Mutation ---
Original real genome (index 0):
[-0.5403094   0.38333678 -0.4683802   0.53857994 -0.712852   -0.8242552
  0.34233332 -0.5783129   0.43055868  0.99890995]

Mutated real genome (index 0):
[-0.5403094   0.29410166 -0.4683802   0.53857994 -0.6486067  -0.79814005
  0.33577687 -0.5783129   0.43055868  1.0857137 ]

Final corrected genome (index 0):
[-0.5403094   0.29410166 -0.4683802   0.53857994 -0.6486067  -0.79814005
  0.33577687 -0.5783129   0.43055868  1.        ]
Min value in pop: -1.0, Max value in pop: 1.0


## 5. Crossover Operators

**Goal**: Create new offspring from pairs of parents.

**Pattern**:
1. **Instantiate** a Crossover class (e.g., `UniformCrossover`) with static parameters:
    - `crossover_rate`
    - `n_outputs`
2. **Compile** the crossover function using `get_compiled_function()` to obtain the pure JAX function.
3. **Function Signature**:
    ```python
    (key, parent1_tensor, parent2_tensor) -> offspring_batch
    ```
4. **Output**:
    - The `offspring_batch` has shape `(n_outputs, ...genome_shape)`.
5. **Vectorization**:
    - Use `vmap` to apply the crossover function across the entire population.

In [9]:
# --- 1. Setup Crossover ---
print("--- Binary Uniform Crossover ---")
key, cross_key = random.split(key)

# Instantiate the "factory"
# We'll create 2 offspring for every 1 pair of parents
cross_op_factory = BinaryUniform(crossover_rate=0.5, n_outputs=2)

# --- 2. Get the JIT-compiled function ---
# This function crosses *one pair* of parents
cross_one_pair_fn = jit(cross_op_factory.get_compiled_function())

# --- 3. Prepare Parent Pairs ---
# We need two populations of parents. We'll shuffle our
# parent_population to get two different-ordered lists.
key, shuffle_key1, shuffle_key2 = random.split(key, 3)

parents_1 = random.permutation(shuffle_key1, parent_population, axis=0)
parents_2 = random.permutation(shuffle_key2, parent_population, axis=0)

print(f"Parent 1 (index 0):\n{parents_1[0]}")
print(f"Parent 2 (index 0):\n{parents_2[0]}")

# --- 4. vmap the Crossover Function ---
# We map over axis 0 of keys, axis 0 of parents_1, and axis 0 of parents_2
cross_pop_fn = jit(vmap(cross_one_pair_fn, in_axes=(0, 0, 0)))

# --- 5. Run Crossover ---
# We only need POP_SIZE / n_outputs keys, since each key generates n_outputs offspring
num_pairs = POP_SIZE // 2
cross_keys = random.split(cross_key, num_pairs)

# Run crossover on the first 50 pairs
offspring_batches = cross_pop_fn(
    cross_keys,
    parents_1[:num_pairs],
    parents_2[:num_pairs]
)

print(f"\nShape of offspring batches: {offspring_batches.shape}")
print(f"(num_pairs, n_outputs_per_pair, genome_length)")

# --- 6. Flatten into the new population ---
new_population = offspring_batches.reshape((POP_SIZE, GENOME_LENGTH))
print(f"\nFinal new population shape: {new_population.shape}")

print(f"\nOffspring 0 (from pair 0):\n{new_population[0]}")
print(f"Offspring 1 (from pair 0):\n{new_population[1]}")

--- Binary Uniform Crossover ---
Parent 1 (index 0):
[ True False  True  True False  True  True False  True  True  True False
  True  True False  True False  True  True False]
Parent 2 (index 0):
[ True False  True  True False  True  True  True  True False False False
 False False  True False  True  True  True  True]

Shape of offspring batches: (50, 2, 20)
(num_pairs, n_outputs_per_pair, genome_length)

Final new population shape: (100, 20)

Offspring 0 (from pair 0):
[ True False  True  True False  True  True False  True False False False
  True False False  True False  True  True False]
Offspring 1 (from pair 0):
[ True False  True  True False  True  True  True  True  True  True False
 False  True  True False  True  True  True  True]


### Example 2: AverageCrossover (Real-Valued)

The same `vmap`-based pattern applies seamlessly to real-valued genomes.

In [10]:
# --- 1. Setup Real Crossover ---
print("--- Real Average Crossover ---")
key, cross_key_real, shuffle_key_real = random.split(key, 3)

# We'll use the 'real_population' from the mutation demo
parents_real_1 = real_population
parents_real_2 = random.permutation(shuffle_key_real, real_population, axis=0)

print(f"Real Parent 1 (index 0):\n{parents_real_1[0].round(2)}")
print(f"Real Parent 2 (index 0):\n{parents_real_2[0].round(2)}")

# --- 2. Get JIT function (with blend_rate=0.5) ---
avg_cross_factory = AverageCrossover(blend_rate=0.5, n_outputs=2)
cross_one_avg_fn = jit(avg_cross_factory.get_compiled_function())

# --- 3. vmap and Run ---
cross_pop_avg_fn = jit(vmap(cross_one_avg_fn, in_axes=(0, 0, 0)))
avg_cross_keys = random.split(cross_key_real, num_pairs) # Use 50 keys

avg_offspring_batches = cross_pop_avg_fn(
    avg_cross_keys,
    parents_real_1[:num_pairs],
    parents_real_2[:num_pairs]
)

# --- 4. Check results ---
new_real_population = avg_offspring_batches.reshape((POP_SIZE, 10))
print(f"\nOffspring 0 (should be avg of P1[0] and P2[0]):\n{new_real_population[0].round(2)}")

# Verification
manual_avg = (parents_real_1[0] + parents_real_2[0]) / 2.0
print(f"Manual average:\n{manual_avg.round(2)}")

--- Real Average Crossover ---
Real Parent 1 (index 0):
[-0.53999996  0.38       -0.47        0.53999996 -0.71       -0.82
  0.34       -0.58        0.42999998  1.        ]
Real Parent 2 (index 0):
[ 0.29999998  0.98999995  0.79999995  0.68        0.45        0.66999996
 -0.90999997  0.32999998 -0.48        0.42      ]

Offspring 0 (should be avg of P1[0] and P2[0]):
[-0.12  0.69  0.17  0.61 -0.13 -0.08 -0.28 -0.12 -0.03  0.71]
Manual average:
[-0.12  0.69  0.17  0.61 -0.13 -0.08 -0.28 -0.12 -0.03  0.71]
