# 1. Setup and Imports

In [1]:
import sys
import os
import time
# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar


print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

JAX version: 0.7.0
Available devices: [CpuDevice(id=0)]


# 1. BinaryGenome
The BinaryGenome class represents a binary string, which serves as the genetic material for our evolutionary algorithms.

## 1.1 Initialization
You can initialize a BinaryGenome by specifying custom ones like array_size and a probability p and either assigning a specific tensor or allowing manually or random initialization.

In [2]:
from malthusjax.core.genome import BinaryGenome
# Create a valid genome
tensor_1 = jnp.array([0, 1, 1, 0, 1])
genome_init_params = {'array_size': 5, 'p': 0.5}

genome_from_tensor = BinaryGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = BinaryGenome(array_size=5, p=0.5, random_init=True, random_key=my_random_key)
print(f'Random genome:\n\t{random_genome}')

init_fn = BinaryGenome.get_random_initialization_jit(genome_init_params)
random_key, subkey = jar.split(my_random_key)
random_tensor = init_fn(subkey)
print(f'Random tensor from JIT function:\n\t{random_tensor}')

Genome from tensor:
	[False  True  True False  True](size=5, valid=True)
Random genome:
	[False False  True  True  True](size=5, valid=True)
Random tensor from JIT function:
	[False False  True  True  True]


you can transform a genome back to tensor


In [3]:
# you can transform a genome back to tensor
tensor_from_genome = genome_from_tensor.to_tensor()
print(f'Tensor from genome:\n\t{tensor_from_genome}')

Tensor from genome:
	[0 1 1 0 1]


## 1.2 Validation
when creating creating a genome from tensor it will automaticlly clip to 0/1 ensuring all elements are binary, the only way to have an invalid genome is to manually assign a tensor with values other than 0 or 1, or assign a tensor of wrong size.

In [4]:
invalid_tensor_1 = jnp.array([0, 2, 1, 0, 1])
invalid_genome1 = BinaryGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0, -1, 1, 0, 1])
invalid_genome2 = BinaryGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.array([0, 1, 1, 0, 1, 1])

try:
    invalid_genome3 = BinaryGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')

print("-"*50)
print("Manual invalid genome creation and validation:")
print("-"*50)
invalid_genome1 = BinaryGenome(**genome_init_params, random_init=False)
invalid_genome1.genome = invalid_tensor_1
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')

invalid_genome2 = BinaryGenome(**genome_init_params, random_init=False)
invalid_genome2.genome = invalid_tensor_2
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')   
    
invalid_genome3 = BinaryGenome(**genome_init_params, random_init=False)
invalid_genome3.genome = invalid_tensor_3
print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')


Invalid genome1 is valid: True
Invalid genome2 is valid: True
[False  True  True False  True  True] = (5,)
Failed to create invalid genome3: Genome created from tensor [0 1 1 0 1 1] is not valid
--------------------------------------------------
Manual invalid genome creation and validation:
--------------------------------------------------
Invalid genome1 is valid: False
Invalid genome2 is valid: False
[0 1 1 0 1 1] = (5,)
Invalid genome3 is valid: False


to efficently corect invalid genomes you can use .get_autovalidation_jit() that returns a jit compiled function that will correct the genome inplace, this will help later on when using genetic operators that might produce invalid genomes

In [5]:
auto_correction_function = BinaryGenome.get_autocorrection_jit(genome_init_params= genome_init_params)


invalid_genome1.genome = auto_correction_function(invalid_genome1.to_tensor())
print(f'After auto-correction, invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2.genome = auto_correction_function(invalid_genome2.to_tensor())
print(f'After auto-correction, invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3.genome = auto_correction_function(invalid_genome3.to_tensor())
print(f'After auto-correction, invalid genome3 is valid: {invalid_genome3._validate()}')

After auto-correction, invalid genome1 is valid: True
After auto-correction, invalid genome2 is valid: True
After auto-correction, invalid genome3 is valid: True


## 1.3 Distance Calculation (Hamming Distance)
The framework supports calculating the Hamming distance between two binary genomes, representing the number of positions at which the corresponding bits are different.

In [6]:
genome1 = BinaryGenome(array_size=10000, p=0.5,random_init=True, random_key=jar.PRNGKey(42))
genome2 = BinaryGenome(array_size=10000, p=0.5,random_init=True, random_key=jar.PRNGKey(43))


distance_fn = BinaryGenome.get_distance_jit()
# Time the JIT version
start = time.time()
print(f"JIT Distance between genome1 and genome2: {distance_fn(genome1.genome, genome2.genome)}")
jit_time = time.time() - start

# Time the regular version
start = time.time()
print(f"Distance between genome1 and genome2: {genome1.distance(genome2)}")
regular_time = time.time() - start

print(f"\nTiming comparison:")
print(f"JIT version: {jit_time:.6f} seconds")
print(f"Regular version: {regular_time:.6f} seconds")

JIT Distance between genome1 and genome2: 5059
Distance between genome1 and genome2: 5059.0

Timing comparison:
JIT version: 0.011806 seconds
Regular version: 0.033880 seconds


# 2. CategoricalGenome
The CategoricalGenome class represents a categorical array, which serves as the genetic material for our evolutionary algorithms.

## 2.1 Initialization
You can initialize a CategoricalGenome by specifying custom ones like array_size and a number of categories n_categories and either assigning a specific tensor or allowing manually or random initialization.

In [7]:
from malthusjax.core.genome.categorical import CategoricalGenome

# Create a valid genome
tensor_1 = jnp.array([0, 2, 1, 0, 1])
genome_init_params = {'array_size': 5, 'n_categories': 3}
genome_from_tensor = CategoricalGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = CategoricalGenome(**genome_init_params, random_init=True, random_key=my_random_key)

#extract the init function
init_fn = CategoricalGenome.get_random_initialization_jit(genome_init_params)
random_key, subkey = jar.split(my_random_key)
random_tensor = init_fn(subkey)
print(f'Random tensor from JIT function:\n\t{random_tensor}')

print(f'Random genome:\n\t{random_genome}')

# you can transform a genome back to tensor
tensor_from_genome = genome_from_tensor.to_tensor()

Genome from tensor:
	CategoricalGenome(genome=[0 1 1 0 1], valid=True)
Random tensor from JIT function:
	[1 0 0 0 1]
Random genome:
	CategoricalGenome(genome=[1 0 0 0 1], valid=True)


## 2.2 Validation
when creating creating a genome from tensor it will automaticlly clip to 0 or (n_classes - 1) ensuring all elements are within boundaries, the only way to have an invalid genome is to manually assign a tensor with values outisde the specific range or assign a tensor of wrong size.

In [8]:
invalid_tensor_1 = jnp.array([0, 5, 1, 0, 1])
invalid_genome1 = CategoricalGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0, -1, 1, 0, 1])
invalid_genome2 = CategoricalGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.array([0, 1, 1, 0, 1, 1])
try:
    invalid_genome3 = CategoricalGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')
    
    
print("-"*50)
print("Manual invalid genome creation and validation:")
print("-"*50)
invalid_genome1 = CategoricalGenome(**genome_init_params, random_init=False)
invalid_genome1.genome = invalid_tensor_1
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2 = CategoricalGenome(**genome_init_params, random_init=False)
invalid_genome2.genome = invalid_tensor_2
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3 = CategoricalGenome(**genome_init_params, random_init=False)
invalid_genome3.genome = invalid_tensor_3
print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')


Invalid genome1 is valid: True
Invalid genome2 is valid: True
[0 1 1 0 1 1] != (5,)
[0 1 1 0 1 1] != (5,)
Invalid genome3 is valid: False
--------------------------------------------------
Manual invalid genome creation and validation:
--------------------------------------------------
Genome values [0 5 1 0 1] out of range [0, 2)
Invalid genome1 is valid: False
Genome values [ 0 -1  1  0  1] out of range [0, 2)
Invalid genome2 is valid: False
[0 1 1 0 1 1] != (5,)
Invalid genome3 is valid: False


In [9]:
auto_correction_function = CategoricalGenome.get_autocorrection_jit(genome_init_params= genome_init_params)

invalid_genome1.genome = auto_correction_function(invalid_genome1.to_tensor())
print(f'After auto-correction, invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2.genome = auto_correction_function(invalid_genome2.to_tensor())
print(f'After auto-correction, invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3.genome = auto_correction_function(invalid_genome3.to_tensor())
print(f'After auto-correction, invalid genome3 is valid: {invalid_genome3._validate()}')

After auto-correction, invalid genome1 is valid: True
After auto-correction, invalid genome2 is valid: True
After auto-correction, invalid genome3 is valid: True


## 2.3 Distance Calculation (Hamming Distance)
The framework supports calculating the Hamming distance between two binary genomes, representing the number of positions at which the corresponding bits are different.

In [10]:
genome1 = CategoricalGenome(array_size=10000, num_categories=5, random_init=True, random_key=jar.PRNGKey(42))
genome2 = CategoricalGenome(array_size=10000, num_categories=5, random_init=True, random_key=jar.PRNGKey(43))


distance_fn = CategoricalGenome.get_distance_jit()
# Time the JIT version
start = time.time()
print(f"JIT Distance between genome1 and genome2: {distance_fn(genome1.genome, genome2.genome)}")
jit_time = time.time() - start

# Time the regular version
start = time.time()
print(f"Distance between genome1 and genome2: {genome1.distance(genome2)}")
regular_time = time.time() - start

print(f"\nTiming comparison:")
print(f"JIT version: {jit_time:.6f} seconds")
print(f"Regular version: {regular_time:.6f} seconds")

JIT Distance between genome1 and genome2: 7919
Distance between genome1 and genome2: 7919

Timing comparison:
JIT version: 0.012116 seconds
Regular version: 0.010109 seconds


# 3. PermutationGenome
The PermutationGenome class represents a permutation array, which serves as the genetic material for our evolutionary algorithms.
## 1.1 Initialization
You can initialize a PermutationGenome by specifying custom parameters: permutation_start and permutation_end, and either assigning a specific tensor or allowing manually or random initialization.

In [11]:
jnp.arange(4,10)

Array([4, 5, 6, 7, 8, 9], dtype=int32)

In [12]:
from malthusjax.core.genome.permutation import PermutationGenome

# Create a valid genome
tensor_1 = jnp.array([3, 0, 2, 1, 4])
genome_init_params = {'permutation_start': 0, 'permutation_end': 5}
genome_from_tensor = PermutationGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = PermutationGenome(**genome_init_params, random_init=True, random_key=my_random_key)
print(f'Random genome:\n\t{random_genome}')
# you can transform a genome back to tensor
tensor_from_genome = genome_from_tensor.to_tensor()
print(f'Tensor from genome:\n\t{tensor_from_genome}')

Genome from tensor:
	PermutationGenome(permutation_start=0, permutation_end=5 
Random genome:
	PermutationGenome(permutation_start=0, permutation_end=5 
Tensor from genome:
	[3 0 2 1 4]


## 3.2 Validation
when creating creating a genome from tensor it will automaticlly clip to 0 or (n_classes - 1) ensuring all elements are within boundaries, the only way to have an invalid genome is to manually assign a tensor with values outisde the specific range or assign a tensor of wrong size.

In [13]:
invalid_tensor_1 = jnp.array([0, 1, 1, 4, 3])
invalid_genome1 = PermutationGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0, -1, 2, 3, 4])
invalid_genome2 = PermutationGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.arange(6)
try:
    invalid_genome3 = PermutationGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')


Genome contains duplicates: [0 1 1 4 3]
Genome contains duplicates: [0 1 1 4 3]
Invalid genome1 is valid: False
Genome values [ 0 -1  2  3  4] out of range [0, 5]
Genome values [ 0 -1  2  3  4] out of range [0, 5]
Invalid genome2 is valid: False
[0 1 2 3 4 5] = (5,)
[0 1 2 3 4 5] = (5,)
Invalid genome3 is valid: False


In [14]:
# Permutations are tricky to auto-correct, so we will skip that part

# 4 RealGenome
The RealGenome class represents a real-valued array, which serves as the genetic material for our evolutionary algorithms.
## 4.1 Initialization
You can initialize a RealGenome by specifying custom ones like array_size and bounds (a tuple defining the minimum and maximum values) and either assigning a specific tensor or allowing manually or random initialization.

In [15]:
from malthusjax.core.genome.real import RealGenome


# Create a valid genome
genome_init_params = {'minval': -1.0, 'maxval': 1.0, 'array_size': 5}
tensor_1 = jnp.array([0.5, -0.2, 0.1, 0.0, -0.9])
genome_from_tensor = RealGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = RealGenome(**genome_init_params, random_init= True, random_key= my_random_key)
print(f'Random genome:\n\t{random_genome}')

Genome from tensor:
	RealGenome(size=5, valid=True)
Random genome:
	RealGenome(size=5, valid=True)


In [16]:
invalid_tensor_1 = jnp.array([0.5, -1.2, 0.1, 0.0, -0.9])
invalid_genome1 = RealGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0.5, 1.2, 0.1, 0.0, -0.9])
invalid_genome2 = RealGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.array([0.5, -0.2, 0.1, 0.0, -0.9, 0.3])
try:
    invalid_genome3 = RealGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')
    
    
print("-"*50)
print("Manual invalid genome creation and validation:")
print("-"*50)
invalid_genome1 = RealGenome(**genome_init_params, random_init=False)
invalid_genome1.genome = invalid_tensor_1
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2 = RealGenome(**genome_init_params, random_init=False)
invalid_genome2.genome = invalid_tensor_2
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3 = RealGenome(**genome_init_params, random_init=False)
invalid_genome3.genome = invalid_tensor_3
print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')


Genome values [ 0.5 -1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Genome values [ 0.5 -1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome1 is valid: False
Genome values [ 0.5  1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Genome values [ 0.5  1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome2 is valid: False
[ 0.5 -0.2  0.1  0.  -0.9  0.3] = (5,)
[ 0.5 -0.2  0.1  0.  -0.9  0.3] = (5,)
Invalid genome3 is valid: False
--------------------------------------------------
Manual invalid genome creation and validation:
--------------------------------------------------
Genome values [ 0.5 -1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome1 is valid: False
Genome values [ 0.5  1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome2 is valid: False
[ 0.5 -0.2  0.1  0.  -0.9  0.3] = (5,)
Invalid genome3 is valid: False


In [17]:
auto_correction_function = RealGenome.get_autocorrection_jit(genome_init_params= genome_init_params)


invalid_genome1.genome = auto_correction_function(invalid_genome1.to_tensor())
print(f'After auto-correction, invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2.genome = auto_correction_function(invalid_genome2.to_tensor())
print(f'After auto-correction, invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3.genome = auto_correction_function(invalid_genome3.to_tensor())
print(f'After auto-correction, invalid genome3 is valid: {invalid_genome3._validate()}')

After auto-correction, invalid genome1 is valid: True
After auto-correction, invalid genome2 is valid: True
After auto-correction, invalid genome3 is valid: True


# demo of heterogeneous real genome

In [5]:
import jax.random as jar
import jax.numpy as jnp
dist = jar.normal(
    key=jar.PRNGKey(42), shape=(5,), dtype=jnp.float32
)

# transform distribution from standard normal to mean = 5 and std = 2
dist = dist * 2 + 5
print(f'Sample from normal distribution with mean 5 and std 2:\n\t{dist}')

Sample from normal distribution with mean 5 and std 2:
	[4.943391  5.9342637 5.591406  5.3070917 4.7519345]


In [None]:
jar.uniform(key=jar.PRNGKey(42), shape=(5,), dtype=jnp.float32)

# tsudent distribution
jar.t(key=jar.PRNGKey(42), df=10, shape=(5,), dtype=jnp.float32)

#exponential distribution
jar.exponential(key=jar.PRNGKey(42), shape=(5,), scalar=None, dtype=jnp.float32)

Array([0.6708175, 1.1388006, 0.95782  , 0.8232925, 0.5990097], dtype=float32)

In [None]:
from malthusjax.core.genome.real_heterogeneous import RealGenome_het

@jax.jit
def normal_dist(key, shape, mean=0.0, std=1.0):
    return jar.normal(key, shape) * std + mean

@jax.jit
def t_dist(key, shape, df=10, mean=0.0, std=1.0):
    return jar.t(key, df, shape) * std + mean


genome_init_params = {
    'array_shape': 5,
    'distributions': [normal_dist, t_dist, jar.exponential, normal_dist, jar.uniform],
    'distribution_params': [
        {'mean': 0.0, 'std': 1.0},          # Normal distribution
        {'df': 10, 'mean': 0.0, 'std': 1.0},# Student's t-distribution
        {'scale': 1.0},                     # Exponential distribution
        {'mean': 5.0, 'std': 2.0},          # Normal distribution with different params
        {'minval': 0.0, 'maxval': 10.0}    # Uniform distribution with different params
    ]
}   
random_genome = RealGenome_het(**genome_init_params, random_init=True, random_key=jar.PRNGKey(42))
print(f'Random heterogeneous real genome:\n\t{random_genome}')

AttributeError: module 'jax.random' has no attribute 'KeyArray'