## 1. Setup and Imports

In [None]:
import sys
import os

# Make the local `malthusjax` package (in the repo's `src/` directory) importable.
# Set the MALTHUSJAX_SRC environment variable to your own checkout's `src` path;
# the fallback below is the original author's machine-specific location.
# TODO: replace the fallback with a path relative to the notebook, or install
# the package (`pip install -e .`) so no sys.path manipulation is needed.
MALTHUSJAX_SRC = os.environ.get(
    "MALTHUSJAX_SRC",
    "/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src",
)
# Guard so re-running this cell does not keep appending duplicate entries.
if MALTHUSJAX_SRC not in sys.path:
    sys.path.append(MALTHUSJAX_SRC)

import jax
import jax.numpy as jnp
import jax.random as jar
import matplotlib.pyplot as plt

from malthusjax.core.genome import BinaryGenome
from malthusjax.core.fitness.binary_ones import BinarySumFitnessEvaluator
from malthusjax.core.population.population import AbstractPopulation
from malthusjax.operators.selection.tournament import TournamentSelection
from malthusjax.operators.selection.roulette import RouletteSelection

In [2]:
# Problem configuration: 10 binary genomes of length 20, each bit drawn with p=0.5.
population_size = 10
random_key = jar.PRNGKey(42)
genome_init_params = dict(array_size=20, p=0.5)

# Build a randomly initialised population of BinaryGenome individuals.
population = AbstractPopulation(
    genome_cls=BinaryGenome,
    pop_size=population_size,
    random_init=True,
    random_key=random_key,
    genome_init_params=genome_init_params,
    fitness_transform=None,
)

# Score the whole population once. Per the saved output below, each fitness
# equals the number of ones in the genome.
evaluator = BinarySumFitnessEvaluator()
fitness_values = evaluator(population)

print("Initial population:")
for idx, genome in enumerate(population):
    bits = genome.astype(jnp.int32)
    print(f"  Solution {idx}: {bits} Fitness: {fitness_values[idx]}")

Initial population:
  Solution 0: [0 1 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 0] Fitness: 6.0
  Solution 1: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 2: [0 0 1 1 1 1 1 1 1 1 0 0 0 1 1 1 0 1 1 0] Fitness: 13.0
  Solution 3: [1 1 1 0 1 0 0 1 0 1 0 1 0 1 0 0 0 1 0 1] Fitness: 10.0
  Solution 4: [0 0 0 0 0 0 0 1 0 1 0 0 0 1 1 1 1 0 0 0] Fitness: 6.0
  Solution 5: [1 1 1 1 0 1 1 0 1 0 1 0 0 1 0 1 1 1 1 1] Fitness: 14.0
  Solution 6: [1 1 1 0 0 0 0 0 1 1 1 0 1 1 1 1 1 0 0 1] Fitness: 12.0
  Solution 7: [0 1 0 0 0 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1] Fitness: 10.0
  Solution 8: [0 1 1 0 1 1 0 1 0 0 1 0 1 1 1 1 0 1 0 1] Fitness: 12.0
  Solution 9: [1 0 1 0 1 0 1 1 1 1 1 0 0 0 0 0 1 0 1 1] Fitness: 11.0


In [5]:
# Tournament selection: run `population_size` tournaments of 3 individuals each.
# NOTE(review): PRNGKey(42) is the same key used to initialise the population;
# JAX recommends splitting keys (e.g. `jar.split(random_key)`) rather than reusing
# them for unrelated random operations. Kept as-is to preserve the saved outputs.
tournament_selector = TournamentSelection(number_of_tournaments=population_size, tournament_size=3)
selected_population = tournament_selector(population, fitness_values, jar.PRNGKey(42))

# Evaluate the selected population once, instead of re-evaluating the whole
# population on every loop iteration as before.
selected_fitness = evaluator(selected_population)
for i, sol in enumerate(selected_population):
    print(f"  Solution {i}: {sol.astype(jnp.int32)} Fitness: {selected_fitness[i]}")


random_key inside selection call: <class 'jaxlib._jax.ArrayImpl'>
  Solution 0: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 1: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 2: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 3: [0 1 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 0] Fitness: 6.0
  Solution 4: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 5: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 6: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 7: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 8: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 9: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0


In [4]:
# Check that the selection kernel gives the same result eagerly and under jax.jit.
# Presumably `get_compiled_function()` returns a pure (fitness, key) -> indices
# function, matching how it is called below. TODO confirm against the library.
selection_jax = tournament_selector.get_compiled_function()
selection_jiit = jax.jit(selection_jax)  # jit-compiled version of selection_jax
fitness_array = jnp.array(fitness_values)

# Use one shared key so both calls see identical randomness and are comparable.
selection_key = jar.PRNGKey(42)
selected_indices_jax = selection_jax(fitness_array, selection_key)
selected_indices_jiit = selection_jiit(fitness_array, selection_key)
print(f"Selected indices (jax): {selected_indices_jax}")
print(f"Selected indices (jiit): {selected_indices_jiit}")

# Fail loudly if compilation changes the result, rather than relying on
# visually comparing the two printed lines.
assert jnp.array_equal(selected_indices_jax, selected_indices_jiit), (
    "jit-compiled selection disagrees with eager selection"
)
# NOTE(review): in the saved outputs these indices (e.g. [1 9 7 ...]) do not match
# the rows of `selected_population` from the tournament cell, which was also run
# with PRNGKey(42). For example, its Solution 1 equals population[1], not
# population[9]. The selector's __call__ presumably transforms the key before
# use. Confirm this is intended.

Selected indices (jax): [1 9 7 2 5 3 5 8 1 1]
Selected indices (jiit): [1 9 7 2 5 3 5 8 1 1]


In [7]:
# Roulette (fitness-proportionate) selection, drawing `population_size` individuals.
# NOTE(review): this reuses PRNGKey(42) again; consider deriving a fresh key with
# `jar.split` so the selection is not correlated with initialisation.
roulette_selector = RouletteSelection(number_choices=population_size)
selected_population_roulette = roulette_selector(population, fitness_values, jar.PRNGKey(42))

# Print both selection results side by side. The loop replaces two copy-pasted
# print loops, and each population is evaluated once instead of once per row.
for selection_name, shown_population in (
    ("tournament", selected_population),
    ("roulette", selected_population_roulette),
):
    shown_fitness = evaluator(shown_population)
    print(f"\nPopulation after {selection_name} selection:")
    for i, sol in enumerate(shown_population):
        print(f"  Solution {i}: {sol.astype(jnp.int32)} Fitness: {shown_fitness[i]}")

random_key inside selection call: <class 'jaxlib._jax.ArrayImpl'>

Population after tournament selection:
  Solution 0: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 1: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 2: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 3: [0 1 0 0 0 0 0 1 1 0 0 0 0 0 1 1 0 0 1 0] Fitness: 6.0
  Solution 4: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 5: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 6: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 7: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 8: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 9: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0

Population after roulette selection:
  Solution 0: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 1: [0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1] Fitness: 14.0
  Solution 2: [0 