# 1. Setup and Imports

In [1]:
import sys
import os

# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar

from malthusjax.core.genome import BinaryGenome
from malthusjax.core.fitness.binary_ones import KnapsackFitnessEvaluator, BinarySumFitnessEvaluator
from malthusjax.core.solution.binary_solution import BinarySolution
from malthusjax.core.population.population import Population
from malthusjax.core.solution.base import FitnessTransforms
from malthusjax.operators.selection.tournament import TournamentSelection


print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

JAX version: 0.6.2
Available devices: [CpuDevice(id=0)]


# 2. BinaryGenome: The Building Block of Solutions
The BinaryGenome class represents a binary string, which serves as the genetic material for our evolutionary algorithms.

## Initialization
You can initialize a BinaryGenome with default parameters or specify custom ones like array_size and a probability p for random initialization

In [2]:
# Create a valid genome
valid_genome = BinaryGenome(array_size=5, p=0.5, random_init=True)
valid_genome.genome = jnp.array([0, 1, 1, 0, 1])
print(f"Valid Genome is valid: {valid_genome._validate()}")

# Create an invalid genome (contains a '2')
invalid_genome = BinaryGenome(array_size=5, p=0.5, random_init=False)
invalid_genome.genome = jnp.array([0, 2, 1, 0, 1])
print(f"Invalid Genome is valid: {invalid_genome._validate()}") 

Valid Genome is valid: True
Invalid Genome is valid: False


In [3]:
# Create and use binary genome
genome = BinaryGenome(array_size=10, p=0.3, random_init=True)

# Clone without importing copy
genome_copy = genome.clone()

# Convert to/from tensors with context
tensor = genome.to_tensor()
reconstructed = BinaryGenome.from_tensor(tensor=tensor, genome_init_params={'array_size': 10, 'p':.3})

# Use with JAX transformations
genomes = [BinaryGenome(array_size=5, p=0.5, random_init=True) for _ in range(10)]
jax.tree_util.tree_map(lambda x: x.shape, genomes)  # Using tree_util.tree_map

{'array_size': 10, 'p': 0.3}


[BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state),
 BinaryGenome(array_size=5, invalid_state)]

## Validation
Genomes can be validated to ensure they contain only binary values (0s and 1s).

In [4]:
# Create a valid genome
valid_genome = BinaryGenome(array_size=5, p=0.5, random_init=False)
valid_genome.genome = jnp.array([0, 1, 1, 0, 1])
print(f"Valid Genome is valid: {valid_genome._validate()}")

# Create an invalid genome (contains a '2')
invalid_genome = BinaryGenome(array_size=5, p=0.5, random_init=False)
invalid_genome.genome = jnp.array([0, 2, 1, 0, 1])
print(f"Invalid Genome is valid: {invalid_genome._validate()}")

Valid Genome is valid: True
Invalid Genome is valid: False


## Distance Calculation (Hamming Distance)
The framework supports calculating the Hamming distance between two binary genomes, representing the number of positions at which the corresponding bits are different.

In [5]:
genome1 = BinaryGenome(array_size=5, p=0.5,random_init=False)
genome1.genome = jnp.array([0, 1, 1, 0, 1])

genome2 = BinaryGenome(array_size=5, p=0.5,random_init=False)
genome2.genome = jnp.array([1, 1, 0, 0, 1])

# Hamming distance should be 2 (positions 0 and 2 are different)
print(f"Distance between genome1 and genome2: {genome1.distance(genome2)}")

Distance between genome1 and genome2: 2.0


# 3. BinarySolution: Adding Fitness to Genomes
A BinarySolution encapsulates a BinaryGenome and its associated fitness value.

## Initialization and Fitness Assignment
Solutions can be initialized with a genome and their fitness can be set and retrieved.

In [7]:
# Create solution with genome
from malthusjax.core.genome.binary import BinaryGenome
from malthusjax.core.solution.base import AbstractSolution
genome = BinaryGenome(array_size=10, p=0.5, random_init=True)
solution = AbstractSolution(genome=genome)

# Create solution from class
solution2 = AbstractSolution(
    genome_cls=BinaryGenome,
    random_init=True,
    genome_init_params={'array_size': 10, 'p': 0.5},
)

# Set fitness and compare
solution.raw_fitness = 0.85
solution2.raw_fitness = 0.90

print(solution < solution2)  # True (lower fitness)

# Tensor operations with context
tensor = solution.to_tensor()
context = solution.get_serialization_context()
reconstructed = AbstractSolution.from_tensor(tensor, context=context)

True
{'array_size': 10, 'p': 0.5}


In [8]:
# Simple creation
genome_init_params = {'array_size': 10, 'p': 0.3}

solution = BinarySolution(genome_init_params=genome_init_params, random_init=True)

# Tensor operations (context handled automatically)
tensor = solution.to_tensor()
reconstructed = BinarySolution.from_tensor(tensor, 
                                            genome_init_params=genome_init_params)


print(solution.genome_init_params)
# Convenience operations
neighbor = solution.flip_bit(5)
#distance = solution.hamming_distance(neighbor)

# Direct from array
binary_data = jnp.array([1, 0, 1, 1, 0, 1, 0, 0, 1, 1])
solution2 = BinarySolution.from_binary_array(binary_data,genome_init_params=genome_init_params)

# JAX operations work seamlessly
solutions = [BinarySolution(genome_init_params = genome_init_params) for _ in range(10)]
#jax.tree_util.tree_map(lambda x: x.shape, solutions) 

{'array_size': 10, 'p': 0.3}
{'array_size': 10, 'p': 0.3}




In [9]:
solution.genome # Should be BinaryGenome

BinaryGenome(array_size=10, p=0.3, valid=True, semantic_key='dabc061d27...')

## Fitness Transformations
Fitness values can be transformed (e.g., for minimization problems).

In [10]:
genome_init_params = {'array_size': 5, 'p': 0.5}
my_ft = lambda x: x**2 # or try FitnessTransforms.minimize
new_solution = BinarySolution(genome_init_params= genome_init_params, random_init=True, fitness_transform=my_ft)
new_solution._raw_fitness = 15
print(f"Solution Transformed Fitness (minimizing): {new_solution.fitness}") 

Solution Transformed Fitness (minimizing): 225


## Cloning Solutions
Solutions can be cloned to create independent copies.

In [11]:
genome_init_params = {'array_size': 10, 'p': 0.3}

original_solution = BinarySolution(genome_init_params = genome_init_params, random_init=True, random_key=jar.PRNGKey(0))
original_solution.raw_fitness = 100.0
print(f"Original Solution Genome: {original_solution.genome}, Fitness: {original_solution.raw_fitness}")
cloned_solution = original_solution.clone()
print(f"Original Solution: {original_solution.genome}, Fitness: {original_solution.raw_fitness}") 
print(f"Cloned Solution: {cloned_solution.genome}, Fitness: {cloned_solution.raw_fitness}") 
print(f"Are they the same object? {original_solution is cloned_solution}")

Original Solution Genome: BinaryGenome(size=10, valid=True), Fitness: 100.0
Original Solution: BinaryGenome(size=10, valid=True), Fitness: 100.0
Cloned Solution: BinaryGenome(size=10, valid=True), Fitness: 100.0
Are they the same object? False


# 4. Fitness Evaluators: Quantifying Solution Quality
MalthusJAX provides various fitness evaluators to assign a fitness score to solutions.

## BinarySumFitnessEvaluator
This evaluator calculates fitness as the sum of '1's in a binary genome.


In [12]:
binary_evaluator = BinarySumFitnessEvaluator()

# Create a population of binary solutions
population_size = 5
genome_init_params = {'array_size': 10, 'p': 0.3}
binary_solutions = [
    BinarySolution( genome_init_params, random_init=True, random_key=jar.PRNGKey(i))
    for i in range(population_size)
]

print("Binary Solutions before evaluation:")
for i, sol in enumerate(binary_solutions):
    print(f"  Solution {i}: {sol.genome.to_tensor()}, Fitness: {sol.raw_fitness}")

# Evaluate the solutions
binary_evaluator.evaluate_solutions(binary_solutions) 

print("\nBinary Solutions after evaluation (BinarySumFitnessEvaluator):")
for i, sol in enumerate(binary_solutions):
    print(f"  Solution {i}: {sol.genome.to_tensor()}, Fitness: {sol.raw_fitness}")

Binary Solutions before evaluation:
  Solution 0: [1 1 0 0 1 1 1 0 1 0], Fitness: None
  Solution 1: [0 0 1 1 0 0 1 1 0 0], Fitness: None
  Solution 2: [0 0 1 0 1 0 0 0 0 1], Fitness: None
  Solution 3: [1 0 0 0 0 1 1 1 0 0], Fitness: None
  Solution 4: [0 0 0 0 1 0 1 1 1 0], Fitness: None

Binary Solutions after evaluation (BinarySumFitnessEvaluator):
  Solution 0: [1 1 0 0 1 1 1 0 1 0], Fitness: 6.0
  Solution 1: [0 0 1 1 0 0 1 1 0 0], Fitness: 4.0
  Solution 2: [0 0 1 0 1 0 0 0 0 1], Fitness: 3.0
  Solution 3: [1 0 0 0 0 1 1 1 0 0], Fitness: 4.0
  Solution 4: [0 0 0 0 1 0 1 1 1 0], Fitness: 4.0


## KnapsackFitnessEvaluator
This evaluator solves a binary knapsack problem, where each item has a weight and a value. Solutions are penalized if their total weight exceeds a limit.

In [13]:

knapsack_weights = jnp.array([2, 3, 4, 5, 6])
knapsack_values = jnp.array([6, 5, 4, 3, 2])
weight_limit = 10.0

knapsack_evaluator = KnapsackFitnessEvaluator(
    weights=knapsack_weights,
    values=knapsack_values,
    weight_limit=weight_limit
)



# Create some solutions for the knapsack problem
knapsack_solutions_data = [
    jnp.array([1, 1, 0, 0, 0]),  # weight=5, value=11 (under limit)
    jnp.array([0, 0, 0, 0, 0]),  # weight=0, value=0
    jnp.array([0, 0, 1, 1, 0]),  # weight=9, value=7 (under limit)
    jnp.array([1, 1, 1, 1, 1]),  # weight=20, exceeding limit (default penalty -1.0)
]

knapsack_solutions = []
genome_init_params = {'array_size': 5, 'p': 0.3}
for i, data in enumerate(knapsack_solutions_data):
    solution = BinarySolution(genome_init_params = genome_init_params, random_init=False)
    solution.genome.genome = data
    knapsack_solutions.append(solution)
print("\nKnapsack Solutions before evaluation:")
for i, sol in enumerate(knapsack_solutions):
    print(f"  Solution {i}: {sol.genome.to_tensor()}, Fitness: {sol.raw_fitness}")

# Evaluate the knapsack solutions
binary_evaluator.evaluate_solutions(knapsack_solutions) 

print("\nKnapsack Solutions after evaluation (KnapsackFitnessEvaluator):")
for i, sol in enumerate(knapsack_solutions):
    print(f"  Solution {i}: {sol.genome.to_tensor()}, Fitness: {sol.raw_fitness}")

# Demonstrate custom penalty for exceeding weight limit
custom_penalty_evaluator = KnapsackFitnessEvaluator(
    weights=knapsack_weights,
    values=knapsack_values,
    weight_limit=weight_limit,
    default_exceding_weight_penalization=-99.0
)
invalid_solution = BinarySolution.from_tensor(tensor=jnp.array([1, 1, 1, 1, 1]), genome_init_params=genome_init_params)

custom_penalty_evaluator.evaluate_single_solution(invalid_solution)
print(f"\nSolution exceeding limit with custom penalty (-99.0): {invalid_solution.raw_fitness}")


Knapsack Solutions before evaluation:
  Solution 0: [1 1 0 0 0], Fitness: None
  Solution 1: [0 0 0 0 0], Fitness: None
  Solution 2: [0 0 1 1 0], Fitness: None
  Solution 3: [1 1 1 1 1], Fitness: None

Knapsack Solutions after evaluation (KnapsackFitnessEvaluator):
  Solution 0: [1 1 0 0 0], Fitness: 2.0
  Solution 1: [0 0 0 0 0], Fitness: 0.0
  Solution 2: [0 0 1 1 0], Fitness: 2.0
  Solution 3: [1 1 1 1 1], Fitness: 5.0
{'array_size': 5, 'p': 0.3}

Solution exceeding limit with custom penalty (-99.0): -99.0
