## 1. Setup and Imports

In [1]:
import sys
import os

# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar
import matplotlib.pyplot as plt

from malthusjax.core.genome import BinaryGenome
from malthusjax.core.solution.binary_solution import BinarySolution
from malthusjax.core.solution.base import FitnessTransforms
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.core.population.population import Population
from malthusjax.operators.mutation.binary import BitFlipMutation, ScrambleMutation
from malthusjax.operators.crossover.binary import UniformCrossover, CycleCrossover
from malthusjax.operators.selection.tournament import TournamentSelection
from malthusjax.operators.selection.roulette import RouletteSelection
from malthusjax.core.base import SerializationContext

print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

JAX version: 0.6.2
Available devices: [CpuDevice(id=0)]


## 2. BitFlipMutation: Introducing Genetic Variation
The BitFlipMutation operator randomly flips bits (0 to 1, or 1 to 0) in a binary genome with a specified probability. This introduces new genetic material into the population, preventing premature convergence.

#### Initialization
You can initialize BitFlipMutation with a default mutation rate (0.01) or a custom rate

In [2]:
# Default initialization
mutation_default = BitFlipMutation()
print(f"Default Mutation Rate: {mutation_default.mutation_rate}") 

# Custom initialization
custom_mutation_rate = 0.05
mutation_custom = BitFlipMutation(mutation_rate=custom_mutation_rate)
print(f"Custom Mutation Rate: {mutation_custom.mutation_rate}")

context = SerializationContext(
    genome_class=BinaryGenome,
    genome_init_params = {'array_size': 10, 'p': 0.5},
    solution_class=BinarySolution,
    fitness_evaluator=BinarySumFitnessEvaluator,
    fitness_transform=FitnessTransforms.minimize,
    population_class=Population)

pop_size = 5
genome_init_params = {'array_size': 10, 'p': 0.5}
population = Population(
    solution_class=BinarySolution,
    max_size=pop_size,
    random_init=True,
    random_key=jar.PRNGKey(10),
    genome_init_params=genome_init_params,
    fitness_transform = FitnessTransforms.minimize,

)
key = jar.PRNGKey(0)
# make a key array of the same size as the population
key = jar.split(key, pop_size)


Default Mutation Rate: 0.01
Custom Mutation Rate: 0.05


### Demonstration of Mutation
Let's see BitFlipMutation in action on a sample genome and then on a population.

In [3]:
# Apply mutation with a high rate (e.g., 0.5) to see clear changes
mutation_demo = BitFlipMutation(mutation_rate=0.5)

print("\nPopulation before mutation:")
for i, sol in enumerate(population.get_solutions()):
    print(f"  Solution {i}: {sol.genome.to_tensor()}")
    
# Apply mutation
mutation_fn = mutation_demo.build(population)
mutated_genomes = mutation_fn(population.to_stack(), key)

print("\nPopulation after mutation:")
for i, genome in enumerate(mutated_genomes):
    print(f"  Mutated Genome {i}: {genome}")   



Population before mutation:
  Solution 0: [1 1 1 0 0 1 0 0 0 0]
  Solution 1: [0 1 1 1 1 1 1 0 0 0]
  Solution 2: [1 1 0 0 1 1 1 1 1 0]
  Solution 3: [0 1 0 1 1 0 1 0 0 1]
  Solution 4: [0 0 1 1 1 1 0 1 0 1]

Population after mutation:
  Mutated Genome 0: [1 0 0 1 1 1 0 1 0 1]
  Mutated Genome 1: [1 0 1 0 0 0 0 0 1 0]
  Mutated Genome 2: [1 1 1 1 1 1 1 0 1 1]
  Mutated Genome 3: [1 1 0 0 0 0 1 1 0 1]
  Mutated Genome 4: [1 0 0 0 0 0 0 1 0 0]


In [4]:
# Apply mutation with a high rate (e.g., 0.5) to see clear changes
mutation_demo = ScrambleMutation(mutation_rate=0.5)


print("\nPopulation before mutation:")
for i, sol in enumerate(population.get_solutions()):
    print(f"  Solution {i}: {sol.genome.to_tensor()}")
    
# Apply mutation
mutation_fn = mutation_demo.build(population)
mutated_genomes = mutation_fn(population.to_stack(), key)

print("\nPopulation after mutation:")
for i, genome in enumerate(mutated_genomes):
    print(f"  Mutated Genome {i}: {genome}")  


Population before mutation:
  Solution 0: [1 1 1 0 0 1 0 0 0 0]
  Solution 1: [0 1 1 1 1 1 1 0 0 0]
  Solution 2: [1 1 0 0 1 1 1 1 1 0]
  Solution 3: [0 1 0 1 1 0 1 0 0 1]
  Solution 4: [0 0 1 1 1 1 0 1 0 1]

Population after mutation:
  Mutated Genome 0: [1 1 1 0 0 1 0 0 0 0]
  Mutated Genome 1: [0 1 0 1 1 0 1 1 1 0]
  Mutated Genome 2: [1 1 0 0 1 1 1 1 1 0]
  Mutated Genome 3: [0 0 1 1 0 1 1 0 1 0]
  Mutated Genome 4: [1 1 1 1 0 0 0 1 0 1]


## crossovers

In [5]:
# Example usage of UniformCrossover
crossover_operator = UniformCrossover(crossover_rate=0.7)
print(f"Uniform Crossover Rate: {crossover_operator.crossover_rate}")
# Example usage of CycleCrossover
cycle_crossover_operator = CycleCrossover(crossover_rate=0.7)
print(f"Cycle Crossover Rate: {cycle_crossover_operator.crossover_rate}")   


Uniform Crossover Rate: 0.7
Cycle Crossover Rate: 0.7


In [6]:
crossover_demo = UniformCrossover(crossover_rate=0.5)

print("\nPopulation before crossover:")
for i, sol in enumerate(population.get_solutions()):
    print(f"  Solution {i}: {sol.genome.to_tensor()}")
# Apply crossover
crossover_fn = crossover_demo.build(population) 
crossovered_genomes = crossover_fn(population.to_stack(), key)

print("\nPopulation after crossover:")
for i, genome in enumerate(crossovered_genomes):
    print(f"  Crossovered Genome {i}: {genome}")
  


Population before crossover:
  Solution 0: [1 1 1 0 0 1 0 0 0 0]
  Solution 1: [0 1 1 1 1 1 1 0 0 0]
  Solution 2: [1 1 0 0 1 1 1 1 1 0]
  Solution 3: [0 1 0 1 1 0 1 0 0 1]
  Solution 4: [0 0 1 1 1 1 0 1 0 1]

Population after crossover:
  Crossovered Genome 0: [1 1 1 0 0 1 1 0 1 0]
  Crossovered Genome 1: [0 1 1 1 1 1 1 1 0 1]
  Crossovered Genome 2: [1 1 0 0 1 1 1 1 1 0]
  Crossovered Genome 3: [0 1 1 1 1 1 1 0 1 1]
  Crossovered Genome 4: [0 1 1 1 1 1 1 1 1 1]


In [7]:
crossover_demo = CycleCrossover(crossover_rate=0.5)

print("\nPopulation before crossover:")
for i, sol in enumerate(population.get_solutions()):
    print(f"  Solution {i}: {sol.genome.to_tensor()}")
# Apply crossover
crossover_fn = crossover_demo.build(population) 
crossovered_genomes = crossover_fn(population.to_stack(), key)

print("\nPopulation after crossover:")
for i, genome in enumerate(crossovered_genomes):
    print(f"  Crossovered Genome {i}: {genome}")
  
evaluator = BinarySumFitnessEvaluator()
print("\nEvaluating fitness of crossovered genomes:")
fitness_values = evaluator.evaluate_population_stack(
    crossovered_genomes)
for i, fitness in enumerate(fitness_values):
    print(f"  Fitness of Crossovered Genome {i}: {fitness}")    
# Create a new population from the crossovered genomes
new_population = population.from_stack(
    crossovered_genomes,
    context=context,
    fitness_values=fitness_values
)

print("\nNew Population Statistics:")

new_population.get_statistics()

print("\nPopulation best solution:")
best_solution = new_population.get_best_solution()
print(f"  Best Solution Genome: {best_solution.genome.to_tensor()}")
print(f"  Best Solution Fitness: {best_solution.fitness}")
print(f"  Best Solution raw Fitness: {best_solution.raw_fitness}")
print(f"  Best Solution raw Fitness: {best_solution._raw_fitness}")




Population before crossover:
  Solution 0: [1 1 1 0 0 1 0 0 0 0]
  Solution 1: [0 1 1 1 1 1 1 0 0 0]
  Solution 2: [1 1 0 0 1 1 1 1 1 0]
  Solution 3: [0 1 0 1 1 0 1 0 0 1]
  Solution 4: [0 0 1 1 1 1 0 1 0 1]

Population after crossover:
  Crossovered Genome 0: [1 0 0 1 1 1 0 1 0 1]
  Crossovered Genome 1: [1 0 1 0 0 0 0 0 1 0]
  Crossovered Genome 2: [1 1 1 1 1 1 1 0 1 1]
  Crossovered Genome 3: [1 1 0 0 0 0 1 1 0 1]
  Crossovered Genome 4: [1 0 0 0 0 0 0 1 0 0]

Evaluating fitness of crossovered genomes:
  Fitness of Crossovered Genome 0: 6
  Fitness of Crossovered Genome 1: 3
  Fitness of Crossovered Genome 2: 9
  Fitness of Crossovered Genome 3: 5
  Fitness of Crossovered Genome 4: 2

New Population Statistics:

Population best solution:
  Best Solution Genome: [1 1 1 1 1 1 1 0 1 1]
  Best Solution Fitness: 9
  Best Solution raw Fitness: 9
  Best Solution raw Fitness: 9


## stack evaluation

In [8]:
fitness_evaluator = BinarySumFitnessEvaluator()
population = Population(
    solution_class=BinarySolution,
    max_size=10,
    random_init=True,
    random_key=jar.PRNGKey(10),
    genome_init_params={'array_size': 10, 'p': 0.5},
    fitness_transform=FitnessTransforms.minimize,
)

evaluator.evaluate_population(population)
selection_operator = TournamentSelection(number_of_tournaments=100, tournament_size=3)
selection_operator2 = RouletteSelection(number_choices=100)

In [9]:
len(population.get_fitness_values())

10

In [10]:
fn = selection_operator.build(population)
fn2 = selection_operator2.build(population)
tournament_selection = fn(population.get_fitness_values(), jar.PRNGKey(0))
tournament_selection2 = fn2(population.get_fitness_values(), jar.PRNGKey(0))
print("Tournament Selection Indices:", tournament_selection)
print("Roulette Selection Indices:", tournament_selection2)

new_population = population.from_list_of_indexes(tournament_selection2)

Tournament Selection Indices: [9 3 3 6 0 0 3 6 7 5 9 7 9 6 3 3 6 6 6 9 9 9 0 6 7 4 9 0 7 5 0 6 6 6 5 0 3
 9 9 1 9 0 0 6 5 4 9 6 6 7 9 5 4 0 6 5 7 9 9 9 8 9 0 0 0 9 8 9 0 8 9 9 6 6
 6 6 9 0 5 6 6 9 8 9 7 6 0 9 9 8 3 3 6 0 5 9 7 8 3 8]
Roulette Selection Indices: [0 0 6 4 3 8 6 2 2 8 0 9 3 4 1 0 1 2 4 9 9 0 4 2 4 3 1 2 7 7 9 3 4 4 5 4 2
 7 4 5 7 7 3 2 4 6 6 9 1 2 4 4 8 3 2 2 7 1 7 6 5 6 9 2 3 8 8 1 7 2 1 3 2 1
 1 1 6 1 4 2 8 0 1 1 3 4 1 3 5 6 8 3 1 7 1 5 5 7 1 7]


In [11]:
new_population.get_statistics()

{'pop_size': 100,
 'max_fitness': -4.0,
 'min_fitness': -7.0,
 'avg_fitness': -5.389999866485596,
 'fitness_std': 0.9580709934234619,
 '25th_percentile': -6.0,
 '50th_percentile': -5.0,
 '75th_percentile': -5.0}

In [12]:
population.get_statistics()

{'pop_size': 10,
 'max_fitness': -4.0,
 'min_fitness': -7.0,
 'avg_fitness': -5.099999904632568,
 'fitness_std': 0.9433980584144592,
 '25th_percentile': -5.75,
 '50th_percentile': -5.0,
 '75th_percentile': -4.250000476837158}

In [13]:
random_key = jar.PRNGKey(0)
n_iterations = 100
pop_size = 10


crossover_demo = UniformCrossover(crossover_rate=0.5)
mutation_demo = BitFlipMutation(mutation_rate=0.5)
#selection_demo = TournamentSelection(number_of_tournaments=pop_size, tournament_size=20)
selection_demo = RouletteSelection(number_choices=pop_size)
demo_evaluator = BinarySumFitnessEvaluator()
mutation_layer_fn = mutation_demo.build(population)
crossover_layer_fn = crossover_demo.build(population)
demo_selection_layer_fn = selection_demo.build(population)

demo_population = Population(
    solution_class=BinarySolution,
    max_size=pop_size,
    random_init=True,
    random_key=random_key,
    genome_init_params={'array_size': 10, 'p': 0.5},
    fitness_transform=None,
)
demo_population_stack = demo_population.to_stack()

for i in range(n_iterations):
    print(f"\nIteration {i+1}/{n_iterations}")
    # Create a new key for this iteration
    iteration_key = jar.fold_in(random_key, i)
    
    # Evaluate fitness
    fitness_values = demo_evaluator.evaluate_population_stack(demo_population_stack)

    # Select parents
    selection_key = jar.fold_in(iteration_key, 0)
    selected_indices = demo_selection_layer_fn(fitness_values, selection_key)
    selected_population = demo_population.from_list_of_indexes(selected_indices)
    
    # Crossover
    crossover_key = jar.fold_in(iteration_key, 1)
    crossover_keys = jar.split(crossover_key, pop_size)
    crossovered_genomes = crossover_layer_fn(selected_population.to_stack(), crossover_keys)

    # Mutation
    mutation_key = jar.fold_in(iteration_key, 2)
    mutation_keys = jar.split(mutation_key, pop_size)
    mutated_genomes = mutation_layer_fn(crossovered_genomes, mutation_keys)

    # Create new population
    demo_population_stack = mutated_genomes


Iteration 1/100

Iteration 2/100

Iteration 3/100

Iteration 4/100

Iteration 5/100

Iteration 6/100

Iteration 7/100

Iteration 8/100

Iteration 9/100

Iteration 10/100

Iteration 11/100

Iteration 12/100

Iteration 13/100

Iteration 14/100

Iteration 15/100

Iteration 16/100

Iteration 17/100

Iteration 18/100

Iteration 19/100

Iteration 20/100

Iteration 21/100

Iteration 22/100

Iteration 23/100

Iteration 24/100

Iteration 25/100

Iteration 26/100

Iteration 27/100

Iteration 28/100

Iteration 29/100

Iteration 30/100

Iteration 31/100

Iteration 32/100

Iteration 33/100

Iteration 34/100

Iteration 35/100

Iteration 36/100

Iteration 37/100

Iteration 38/100

Iteration 39/100

Iteration 40/100

Iteration 41/100

Iteration 42/100

Iteration 43/100

Iteration 44/100

Iteration 45/100

Iteration 46/100

Iteration 47/100

Iteration 48/100

Iteration 49/100

Iteration 50/100

Iteration 51/100

Iteration 52/100

Iteration 53/100

Iteration 54/100

Iteration 55/100

Iteration 56/100



In [14]:
demo_evaluator.evaluate_population(demo_population)
new_population = demo_population.from_stack(demo_population_stack)
demo_evaluator.evaluate_population(new_population)

In [15]:
demo_population.get_statistics()


{'pop_size': 10,
 'max_fitness': 7.0,
 'min_fitness': 2.0,
 'avg_fitness': 5.0,
 'fitness_std': 1.2649110555648804,
 '25th_percentile': 5.0,
 '50th_percentile': 5.0,
 '75th_percentile': 5.749999523162842}

In [16]:
new_population.get_statistics()

{'pop_size': 10,
 'max_fitness': 7.0,
 'min_fitness': 2.0,
 'avg_fitness': 4.099999904632568,
 'fitness_std': 1.6401219367980957,
 '25th_percentile': 3.0,
 '50th_percentile': 4.0,
 '75th_percentile': 5.499999046325684}

In [17]:
# Example usage of BinarySumFitnessEvaluator
fitness_evaluator = BinarySumFitnessEvaluator()
print(f"Fitness Evaluator: {fitness_evaluator}")
# Evaluate fitness of the new population
fitness_scores = fitness_evaluator.evaluate_population_stack(crossovered_genomes)
print("\nFitness Scores of the new population:")
fitness_scores

# ok I need now to define a function that will tournament selection with the fitness scores that will return the indices of the selected genomes,
# the function must be jttable

tournament_size = 4
number_of_tournaments = 10
# random vectors of integers between 1 and pop_size
tournament_indices = jax.random.randint(jar.PRNGKey(0), (number_of_tournaments, tournament_size), 0, population.size)
# fitness matrix is a matrix with the same shape as tournament_indices but with the fitness scores of the genomes indexed by the tournament_indices
fitness_matrix = jnp.take(fitness_scores, tournament_indices)
print(fitness_matrix)
# tournament_winners are the indices of the maximum fitness scores in each row of the fitness matrix
tournament_winners = jnp.argmax(fitness_matrix, axis=1)
print("\nTournament Winners Indices:")
print(tournament_winners)

# Select the winning genomes using the indices
winning_genomes = jnp.take(crossovered_genomes, tournament_winners, axis=0)

print("\nWinning Genomes:")
for i, genome in enumerate(winning_genomes):    
    print(f"  Winning Genome {i}: {genome}")

#now_let's make it a jax.jit function that takes the fitness scores and tournament indices and returns the winning genomes
# make the static parameters tournament_size and number_of_tournaments
from functools import partial

@partial(jax.jit, static_argnums=(1, 2))  # Make tournament_size and number_of_tournaments static
def tournament_selection(fitness_scores, tournament_size=4, number_of_tournaments=3, key=jar.PRNGKey(0)):
    """
    Perform tournament selection on fitness scores.
    
    Args:
        fitness_scores: Array of fitness values for each individual
        tournament_size: Number of individuals in each tournament (static)
        number_of_tournaments: Number of tournaments to run (static)
        key: JAX random key for reproducibility
        
    Returns:
        Array of indices of tournament winners
    """
        
    pop_size = fitness_scores.shape[0]
    
    # Generate random tournament indices
    tournament_indices = jax.random.randint(
        key, 
        (number_of_tournaments, tournament_size), 
        0, 
        pop_size
    )
    
    # Get fitness values for tournament participants
    fitness_matrix = jnp.take(fitness_scores, tournament_indices)
    
    # Find winners (indices with maximum fitness in each tournament)
    local_winners = jnp.argmax(fitness_matrix, axis=1)
    
    # Convert local indices back to global population indices
    tournament_winners = tournament_indices[jnp.arange(number_of_tournaments), local_winners]
    
    return tournament_winners
# Now we can call the tournament_selection function with the fitness scores
tournament_key = jar.PRNGKey(42)  # Use a different key for reproducibility
tournament_winners = tournament_selection(
    fitness_scores, 
    tournament_size=4, 
    number_of_tournaments=10, 
    key=tournament_key
)
winning_genomes = jnp.take(crossovered_genomes, tournament_winners, axis=0)

print("\nWinning Genomes from JIT Tournament Selection:")
for i, genome in enumerate(winning_genomes):
    print(f"  Winning Genome {i}: {genome}")
    
from functools import partial


@partial(jax.jit, static_argnums=(2,))  # Make number_of_tournaments static
def roulette_selection(key, fitness_scores, number_of_tournaments=3):
    pop_size = fitness_scores.shape[0]
    # calculate probabilities
    probabilities = fitness_scores / jnp.sum(fitness_scores)
    # choose random numbers based on the probabilities
    selected_indices = jax.random.choice(key, jnp.arange(pop_size), 
                                       shape=(number_of_tournaments,), p=probabilities)
    # select the genomes using the indices
    selected_genomes = jnp.take(crossovered_genomes, selected_indices, axis=0)
    return selected_genomes

# Now we can call the roulette_selection function with the fitness scores and a key
roulette_key = jar.PRNGKey(0)
selected_genomes = roulette_selection(roulette_key, fitness_scores, number_of_tournaments=5)
print("\nSelected Genomes from Roulette Selection:")
for i, genome in enumerate(selected_genomes):   
    print(f"  Selected Genome {i}: {genome}")
    


Fitness Evaluator: <malthusjax.core.fitness.binary_ones.BinarySumFitnessEvaluator object at 0x151d61dc0>

Fitness Scores of the new population:
[[5 7 7 9]
 [6 9 7 9]
 [9 9 7 9]
 [7 5 5 7]
 [8 7 7 6]
 [9 7 7 7]
 [9 7 8 7]
 [8 8 5 9]
 [5 9 9 9]
 [9 8 5 8]]

Tournament Winners Indices:
[3 1 0 0 0 0 0 3 1 0]

Winning Genomes:
  Winning Genome 0: [1 1 1 1 1 1 1 0 1 1]
  Winning Genome 1: [0 0 1 1 0 1 1 1 1 0]
  Winning Genome 2: [1 0 1 1 0 0 1 1 1 1]
  Winning Genome 3: [1 0 1 1 0 0 1 1 1 1]
  Winning Genome 4: [1 0 1 1 0 0 1 1 1 1]
  Winning Genome 5: [1 0 1 1 0 0 1 1 1 1]
  Winning Genome 6: [1 0 1 1 0 0 1 1 1 1]
  Winning Genome 7: [1 1 1 1 1 1 1 0 1 1]
  Winning Genome 8: [0 0 1 1 0 1 1 1 1 0]
  Winning Genome 9: [1 0 1 1 0 0 1 1 1 1]

Winning Genomes from JIT Tournament Selection:
  Winning Genome 0: [1 1 1 1 1 1 0 1 1 1]
  Winning Genome 1: [1 1 1 0 1 1 1 1 1 1]
  Winning Genome 2: [1 1 1 1 1 1 0 1 1 1]
  Winning Genome 3: [1 1 1 1 1 0 1 1 0 0]
  Winning Genome 4: [1 1 1 1 1 1 1 0 1 1