# 1. Setup and Imports

In [1]:
import sys
import os
import time
# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar


print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

JAX version: 0.7.0
Available devices: [CpuDevice(id=0)]


In [2]:
from typing import Callable


# 1. BinaryGenome
The BinaryGenome class represents a binary string, which serves as the genetic material for our evolutionary algorithms.

## 1.1 Initialization
You can initialize a BinaryGenome by specifying custom ones like array_size and a probability p and either assigning a specific tensor or allowing manually or random initialization.

In [3]:
from malthusjax.core.genome import BinaryGenome
# Create a valid genome
tensor_1 = jnp.array([0, 1, 1, 0, 1])
genome_init_params = {'array_size': 5, 'p': 0.5}

genome_from_tensor = BinaryGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = BinaryGenome(array_size=5, p=0.5, random_init=True, random_key=my_random_key)
print(f'Random genome:\n\t{random_genome}')

init_fn = BinaryGenome.get_random_initialization_jit(genome_init_params)
random_key, subkey = jar.split(my_random_key)
random_tensor = init_fn(subkey)
print(f'Random tensor from JIT function:\n\t{random_tensor}')

Genome from tensor:
	[False  True  True False  True](size=5, valid=True)
Random genome:
	[False False  True  True  True](size=5, valid=True)
Random tensor from JIT function:
	[False False  True  True  True]


you can transform a genome back to tensor


In [4]:
# you can transform a genome back to tensor
tensor_from_genome = genome_from_tensor.to_tensor()
print(f'Tensor from genome:\n\t{tensor_from_genome}')

Tensor from genome:
	[0 1 1 0 1]


## 1.2 Validation
when creating creating a genome from tensor it will automaticlly clip to 0/1 ensuring all elements are binary, the only way to have an invalid genome is to manually assign a tensor with values other than 0 or 1, or assign a tensor of wrong size.

In [5]:
invalid_tensor_1 = jnp.array([0, 2, 1, 0, 1])
invalid_genome1 = BinaryGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0, -1, 1, 0, 1])
invalid_genome2 = BinaryGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.array([0, 1, 1, 0, 1, 1])

try:
    invalid_genome3 = BinaryGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')

print("-"*50)
print("Manual invalid genome creation and validation:")
print("-"*50)
invalid_genome1 = BinaryGenome(**genome_init_params, random_init=False)
invalid_genome1.genome = invalid_tensor_1
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')

invalid_genome2 = BinaryGenome(**genome_init_params, random_init=False)
invalid_genome2.genome = invalid_tensor_2
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')   
    
invalid_genome3 = BinaryGenome(**genome_init_params, random_init=False)
invalid_genome3.genome = invalid_tensor_3
print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')


Invalid genome1 is valid: True
Invalid genome2 is valid: True
[False  True  True False  True  True] = (5,)
Failed to create invalid genome3: Genome created from tensor [0 1 1 0 1 1] is not valid
--------------------------------------------------
Manual invalid genome creation and validation:
--------------------------------------------------
Invalid genome1 is valid: False
Invalid genome2 is valid: False
[0 1 1 0 1 1] = (5,)
Invalid genome3 is valid: False


to efficently corect invalid genomes you can use .get_autovalidation_jit() that returns a jit compiled function that will correct the genome inplace, this will help later on when using genetic operators that might produce invalid genomes

In [6]:
auto_correction_function = BinaryGenome.get_autocorrection_jit(genome_init_params= genome_init_params)


invalid_genome1.genome = auto_correction_function(invalid_genome1.to_tensor())
print(f'After auto-correction, invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2.genome = auto_correction_function(invalid_genome2.to_tensor())
print(f'After auto-correction, invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3.genome = auto_correction_function(invalid_genome3.to_tensor())
print(f'After auto-correction, invalid genome3 is valid: {invalid_genome3._validate()}')

After auto-correction, invalid genome1 is valid: True
After auto-correction, invalid genome2 is valid: True
After auto-correction, invalid genome3 is valid: True


## 1.3 Distance Calculation (Hamming Distance)
The framework supports calculating the Hamming distance between two binary genomes, representing the number of positions at which the corresponding bits are different.

In [7]:
genome1 = BinaryGenome(array_size=10000, p=0.5,random_init=True, random_key=jar.PRNGKey(42))
genome2 = BinaryGenome(array_size=10000, p=0.5,random_init=True, random_key=jar.PRNGKey(43))


distance_fn = BinaryGenome.get_distance_jit()
# Time the JIT version
start = time.time()
print(f"JIT Distance between genome1 and genome2: {distance_fn(genome1.genome, genome2.genome)}")
jit_time = time.time() - start

# Time the regular version
start = time.time()
print(f"Distance between genome1 and genome2: {genome1.distance(genome2)}")
regular_time = time.time() - start

print(f"\nTiming comparison:")
print(f"JIT version: {jit_time:.6f} seconds")
print(f"Regular version: {regular_time:.6f} seconds")

JIT Distance between genome1 and genome2: 5059
Distance between genome1 and genome2: 5059.0

Timing comparison:
JIT version: 0.013116 seconds
Regular version: 0.045763 seconds


# 2. CategoricalGenome
The CategoricalGenome class represents a categorical array, which serves as the genetic material for our evolutionary algorithms.

## 2.1 Initialization
You can initialize a CategoricalGenome by specifying custom ones like array_size and a number of categories n_categories and either assigning a specific tensor or allowing manually or random initialization.

In [8]:
from malthusjax.core.genome.categorical import CategoricalGenome

# Create a valid genome
tensor_1 = jnp.array([0, 2, 1, 0, 1])
genome_init_params = {'array_size': 5, 'n_categories': 3}
genome_from_tensor = CategoricalGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = CategoricalGenome(**genome_init_params, random_init=True, random_key=my_random_key)

#extract the init function
init_fn = CategoricalGenome.get_random_initialization_jit(genome_init_params)
random_key, subkey = jar.split(my_random_key)
random_tensor = init_fn(subkey)
print(f'Random tensor from JIT function:\n\t{random_tensor}')

print(f'Random genome:\n\t{random_genome}')

# you can transform a genome back to tensor
tensor_from_genome = genome_from_tensor.to_tensor()

Genome from tensor:
	CategoricalGenome(genome=[0 1 1 0 1], valid=True)
Random tensor from JIT function:
	[1 0 0 0 1]
Random genome:
	CategoricalGenome(genome=[1 0 0 0 1], valid=True)


## 2.2 Validation
when creating creating a genome from tensor it will automaticlly clip to 0 or (n_classes - 1) ensuring all elements are within boundaries, the only way to have an invalid genome is to manually assign a tensor with values outisde the specific range or assign a tensor of wrong size.

In [9]:
invalid_tensor_1 = jnp.array([0, 5, 1, 0, 1])
invalid_genome1 = CategoricalGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0, -1, 1, 0, 1])
invalid_genome2 = CategoricalGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.array([0, 1, 1, 0, 1, 1])
try:
    invalid_genome3 = CategoricalGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')
    
    
print("-"*50)
print("Manual invalid genome creation and validation:")
print("-"*50)
invalid_genome1 = CategoricalGenome(**genome_init_params, random_init=False)
invalid_genome1.genome = invalid_tensor_1
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2 = CategoricalGenome(**genome_init_params, random_init=False)
invalid_genome2.genome = invalid_tensor_2
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3 = CategoricalGenome(**genome_init_params, random_init=False)
invalid_genome3.genome = invalid_tensor_3
print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')


Invalid genome1 is valid: True
Invalid genome2 is valid: True
[0 1 1 0 1 1] != (5,)
[0 1 1 0 1 1] != (5,)
Invalid genome3 is valid: False
--------------------------------------------------
Manual invalid genome creation and validation:
--------------------------------------------------
Genome values [0 5 1 0 1] out of range [0, 2)
Invalid genome1 is valid: False
Genome values [ 0 -1  1  0  1] out of range [0, 2)
Invalid genome2 is valid: False
[0 1 1 0 1 1] != (5,)
Invalid genome3 is valid: False


In [10]:
auto_correction_function = CategoricalGenome.get_autocorrection_jit(genome_init_params= genome_init_params)

invalid_genome1.genome = auto_correction_function(invalid_genome1.to_tensor())
print(f'After auto-correction, invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2.genome = auto_correction_function(invalid_genome2.to_tensor())
print(f'After auto-correction, invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3.genome = auto_correction_function(invalid_genome3.to_tensor())
print(f'After auto-correction, invalid genome3 is valid: {invalid_genome3._validate()}')

After auto-correction, invalid genome1 is valid: True
After auto-correction, invalid genome2 is valid: True
After auto-correction, invalid genome3 is valid: True


## 2.3 Distance Calculation (Hamming Distance)
The framework supports calculating the Hamming distance between two binary genomes, representing the number of positions at which the corresponding bits are different.

In [11]:
genome1 = CategoricalGenome(array_size=10000, num_categories=5, random_init=True, random_key=jar.PRNGKey(42))
genome2 = CategoricalGenome(array_size=10000, num_categories=5, random_init=True, random_key=jar.PRNGKey(43))


distance_fn = CategoricalGenome.get_distance_jit()
# Time the JIT version
start = time.time()
print(f"JIT Distance between genome1 and genome2: {distance_fn(genome1.genome, genome2.genome)}")
jit_time = time.time() - start

# Time the regular version
start = time.time()
print(f"Distance between genome1 and genome2: {genome1.distance(genome2)}")
regular_time = time.time() - start

print(f"\nTiming comparison:")
print(f"JIT version: {jit_time:.6f} seconds")
print(f"Regular version: {regular_time:.6f} seconds")

JIT Distance between genome1 and genome2: 7919
Distance between genome1 and genome2: 7919

Timing comparison:
JIT version: 0.011787 seconds
Regular version: 0.010426 seconds


# 3. PermutationGenome
The PermutationGenome class represents a permutation array, which serves as the genetic material for our evolutionary algorithms.
## 1.1 Initialization
You can initialize a PermutationGenome by specifying custom parameters: permutation_start and permutation_end, and either assigning a specific tensor or allowing manually or random initialization.

In [12]:
jnp.arange(4,10)

Array([4, 5, 6, 7, 8, 9], dtype=int32)

In [13]:
from malthusjax.core.genome.permutation import PermutationGenome

# Create a valid genome
tensor_1 = jnp.array([3, 0, 2, 1, 4])
genome_init_params = {'permutation_start': 0, 'permutation_end': 5}
genome_from_tensor = PermutationGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = PermutationGenome(**genome_init_params, random_init=True, random_key=my_random_key)
print(f'Random genome:\n\t{random_genome}')
# you can transform a genome back to tensor
tensor_from_genome = genome_from_tensor.to_tensor()
print(f'Tensor from genome:\n\t{tensor_from_genome}')

Genome from tensor:
	PermutationGenome(permutation_start=0, permutation_end=5 
Random genome:
	PermutationGenome(permutation_start=0, permutation_end=5 
Tensor from genome:
	[3 0 2 1 4]


## 3.2 Validation
when creating creating a genome from tensor it will automaticlly clip to 0 or (n_classes - 1) ensuring all elements are within boundaries, the only way to have an invalid genome is to manually assign a tensor with values outisde the specific range or assign a tensor of wrong size.

In [14]:
invalid_tensor_1 = jnp.array([0, 1, 1, 4, 3])
invalid_genome1 = PermutationGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0, -1, 2, 3, 4])
invalid_genome2 = PermutationGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.arange(6)
try:
    invalid_genome3 = PermutationGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')


Genome contains duplicates: [0 1 1 4 3]
Genome contains duplicates: [0 1 1 4 3]
Invalid genome1 is valid: False
Genome values [ 0 -1  2  3  4] out of range [0, 5]
Genome values [ 0 -1  2  3  4] out of range [0, 5]
Invalid genome2 is valid: False
[0 1 2 3 4 5] = (5,)
[0 1 2 3 4 5] = (5,)
Invalid genome3 is valid: False


In [15]:
# Permutations are tricky to auto-correct, so we will skip that part

# 4 RealGenome
The RealGenome class represents a real-valued array, which serves as the genetic material for our evolutionary algorithms.
## 4.1 Initialization
You can initialize a RealGenome by specifying custom ones like array_size and bounds (a tuple defining the minimum and maximum values) and either assigning a specific tensor or allowing manually or random initialization.

In [16]:
from malthusjax.core.genome.real import RealGenome


# Create a valid genome
genome_init_params = {'minval': -1.0, 'maxval': 1.0, 'array_size': 5}
tensor_1 = jnp.array([0.5, -0.2, 0.1, 0.0, -0.9])
genome_from_tensor = RealGenome.from_tensor(tensor_1, genome_init_params= genome_init_params)
print(f'Genome from tensor:\n\t{genome_from_tensor}')

# Create a random genome
my_random_key = jar.PRNGKey(42)
random_genome = RealGenome(**genome_init_params, random_init= True, random_key= my_random_key)
print(f'Random genome:\n\t{random_genome}')

Genome from tensor:
	RealGenome(size=5, valid=True)
Random genome:
	RealGenome(size=5, valid=True)


In [17]:
invalid_tensor_1 = jnp.array([0.5, -1.2, 0.1, 0.0, -0.9])
invalid_genome1 = RealGenome.from_tensor(invalid_tensor_1, genome_init_params= genome_init_params)
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_tensor_2 = jnp.array([0.5, 1.2, 0.1, 0.0, -0.9])
invalid_genome2 = RealGenome.from_tensor(invalid_tensor_2, genome_init_params= genome_init_params)
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_tensor_3 = jnp.array([0.5, -0.2, 0.1, 0.0, -0.9, 0.3])
try:
    invalid_genome3 = RealGenome.from_tensor(invalid_tensor_3, genome_init_params= genome_init_params)
    print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')
except ValueError as e:
    print(f'Failed to create invalid genome3: {e}')
    
    
print("-"*50)
print("Manual invalid genome creation and validation:")
print("-"*50)
invalid_genome1 = RealGenome(**genome_init_params, random_init=False)
invalid_genome1.genome = invalid_tensor_1
print(f'Invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2 = RealGenome(**genome_init_params, random_init=False)
invalid_genome2.genome = invalid_tensor_2
print(f'Invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3 = RealGenome(**genome_init_params, random_init=False)
invalid_genome3.genome = invalid_tensor_3
print(f'Invalid genome3 is valid: {invalid_genome3._validate()}')


Genome values [ 0.5 -1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Genome values [ 0.5 -1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome1 is valid: False
Genome values [ 0.5  1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Genome values [ 0.5  1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome2 is valid: False
[ 0.5 -0.2  0.1  0.  -0.9  0.3] = (5,)
[ 0.5 -0.2  0.1  0.  -0.9  0.3] = (5,)
Invalid genome3 is valid: False
--------------------------------------------------
Manual invalid genome creation and validation:
--------------------------------------------------
Genome values [ 0.5 -1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome1 is valid: False
Genome values [ 0.5  1.2  0.1  0.  -0.9] out of range [-1.0, 1.0]
Invalid genome2 is valid: False
[ 0.5 -0.2  0.1  0.  -0.9  0.3] = (5,)
Invalid genome3 is valid: False


In [18]:
auto_correction_function = RealGenome.get_autocorrection_jit(genome_init_params= genome_init_params)


invalid_genome1.genome = auto_correction_function(invalid_genome1.to_tensor())
print(f'After auto-correction, invalid genome1 is valid: {invalid_genome1._validate()}')
invalid_genome2.genome = auto_correction_function(invalid_genome2.to_tensor())
print(f'After auto-correction, invalid genome2 is valid: {invalid_genome2._validate()}')
invalid_genome3.genome = auto_correction_function(invalid_genome3.to_tensor())
print(f'After auto-correction, invalid genome3 is valid: {invalid_genome3._validate()}')

After auto-correction, invalid genome1 is valid: True
After auto-correction, invalid genome2 is valid: True
After auto-correction, invalid genome3 is valid: True


# multi trial

In [25]:
from malthusjax.core.multigenome import AbstractMultiGenome

genome_classes_dct = {
    'binary': BinaryGenome,
    'binary2': BinaryGenome,
    'categorical': CategoricalGenome,
    'permutation': PermutationGenome,
    'real': RealGenome
}

genome_init_params_dct = {
    'binary': {'array_size': 10, 'p': 0.5},
    'binary2': {'array_size': 10, 'p': 0.5},
    'categorical': {'array_size': 10, 'n_categories': 3},
    'permutation': {'permutation_start': 0, 'permutation_end': 10},
    'real': {'minval': -1.0, 'maxval': 1.0, 'array_size': 10}
}


multigenome1 = AbstractMultiGenome(
    genome_types_dict = genome_classes_dct,
    genome_init_params = genome_init_params_dct,
    random_init = True,
    random_key = jax.random.PRNGKey(42)
)

print(multigenome1)

AbstractMultiGenome(size=(10, 10, 10, 10, 10), valid=True)


In [None]:
multigenome2 = AbstractMultiGenome(
    genome_types_dict = genome_classes_dct,
    genome_init_params = genome_init_params_dct,
    random_init = True,
    random_key = jax.random.PRNGKey(43)
)

In [20]:
distance_fn_list = [
    genome_cls.get_distance_jit() for genome_cls in genome_classes_dct.values()
]
auto_correction_fn_list = [
    genome_cls.get_autocorrection_jit(genome_init_params_dct[genome_type]) for genome_type, genome_cls in genome_classes_dct.items()
]
init_fn_list = [
    genome_cls.get_random_initialization_jit(genome_init_params_dct[genome_type]) for genome_type, genome_cls in genome_classes_dct.items()
]



batch_execute_distance = AbstractMultiGenome.create_batch_executor(distance_fn_list, as_tuple=True)
batch_execute_auto_correction = AbstractMultiGenome.create_batch_executor(auto_correction_fn_list, as_tuple=True)
batch_execute_init = AbstractMultiGenome.create_batch_executor(init_fn_list, as_tuple=True)

batch_distance_inputs = [
    (multigenome1.to_tensors(as_tuple=True)[i], multigenome2.to_tensors(as_tuple=True)[i]) for i in range(len(genome_classes_dct))
]   
batch_distance = batch_execute_distance(batch_distance_inputs)
print("-"*50)
print("type:", type(batch_distance))
print(f"Batch distance results: {batch_distance}")
print("-"*50)
batch_init = batch_execute_init([(jax.random.PRNGKey(i)) for i in range(len(init_fn_list))])
print("-"*50)
print("type:", type(batch_init))
print(f"Batch init results: {batch_init}")
print("-"*50)

batch_auto_correction_inputs = [ (multigenome1.to_tensors(as_tuple=True)[i],) for i in range(len(genome_classes_dct)) ]
batch_auto_correction = batch_execute_auto_correction(batch_auto_correction_inputs)
print("-"*50)
print("type:", type(batch_auto_correction))
print(f"Batch auto-correction results: {batch_auto_correction}")
print("-"*50)

--------------------------------------------------
type: <class 'tuple'>
Batch distance results: (Array(5, dtype=int32), Array(4, dtype=int32), Array(4, dtype=int32), Array(8, dtype=int32), Array(2.3389573, dtype=float32))
--------------------------------------------------
--------------------------------------------------
type: <class 'tuple'>
Batch init results: (Array([False, False,  True,  True, False,  True,  True, False, False,
        True], dtype=bool), Array([ True, False,  True,  True, False, False, False, False,  True,
       False], dtype=bool), Array([1, 0, 1, 1, 1, 0, 0, 1, 0, 1], dtype=int32), Array([0, 5, 6, 7, 4, 2, 1, 3, 9, 8], dtype=int32), Array([ 0.7610934 ,  0.53977776, -0.71992755,  0.26197267,  0.00271249,
        0.18649173, -0.5124204 ,  0.78336215,  0.7290621 , -0.21929836],      dtype=float32))
--------------------------------------------------
--------------------------------------------------
type: <class 'tuple'>
Batch auto-correction results: (Array([0, 

In [21]:
# hard way 
distance_fn_list = [genome.get_distance_jit() for genome in multigenome1._genome_list]
batch_executor = multigenome1.create_batch_executor(distance_fn_list, as_tuple=True)
batch_distance_inputs = [
    (multigenome1.to_tensors(as_tuple=True)[i], multigenome2.to_tensors(as_tuple=True)[i]) for i in range(len(multigenome1._genome_list))
]
batch_exec = batch_executor(batch_distance_inputs)
print("type:", type(batch_exec))
print("Batch distance results:", batch_exec)

# easy way
multigenome1 - multigenome2

type: <class 'tuple'>
Batch distance results: (Array(5, dtype=int32), Array(4, dtype=int32), Array(4, dtype=int32), Array(8, dtype=int32), Array(2.3389573, dtype=float32))


(Array(5, dtype=int32),
 Array(4, dtype=int32),
 Array(4, dtype=int32),
 Array(8, dtype=int32),
 Array(2.3389573, dtype=float32))

In [22]:
# hard way

init_fn_list = [genome.get_random_initialization_jit(params) for genome, params in zip(multigenome1._genome_list, multigenome1.genome_init_params_dict.values())]
batch_executor = multigenome1.create_batch_executor(init_fn_list, as_tuple=True)
batch_init_inputs = [(jax.random.PRNGKey(i),) for i in range(len(init_fn_list))]
batch_exec = batch_executor(batch_init_inputs)
print("type:", type(batch_exec))
print("Batch init results:", batch_exec)

#easy way
multigenome = AbstractMultiGenome(
    genome_types_dict = genome_classes_dct,
    genome_init_params = genome_init_params_dct,
    random_init = True,
    random_key = jax.random.PRNGKey(42)
)
multigenome.to_tensors(as_tuple=True)


type: <class 'tuple'>
Batch init results: (Array([False, False,  True,  True, False,  True,  True, False, False,
        True], dtype=bool), Array([ True, False,  True,  True, False, False, False, False,  True,
       False], dtype=bool), Array([1, 0, 1, 1, 1, 0, 0, 1, 0, 1], dtype=int32), Array([0, 5, 6, 7, 4, 2, 1, 3, 9, 8], dtype=int32), Array([ 0.7610934 ,  0.53977776, -0.71992755,  0.26197267,  0.00271249,
        0.18649173, -0.5124204 ,  0.78336215,  0.7290621 , -0.21929836],      dtype=float32))


(Array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32),
 Array([1, 1, 1, 1, 0, 1, 0, 0, 0, 1], dtype=int32),
 Array([0, 0, 0, 0, 0, 0, 1, 0, 0, 1], dtype=int32),
 Array([3, 4, 5, 6, 1, 9, 0, 8, 2, 7], dtype=int32),
 Array([ 0.7768438 , -0.5759094 ,  0.30274057, -0.77292156,  0.37062407,
         0.19877958,  0.06566715, -0.40798068, -0.25738406,  0.5588684 ],      dtype=float32))

In [23]:
#hard way
auto_correction_fn_list = [genome.get_autocorrection_jit(params) for genome, params in zip(multigenome1._genome_list, multigenome1.genome_init_params_dict.values())]
batch_executor = multigenome1.create_batch_executor(auto_correction_fn_list, as_tuple=True)
batch_auto_correction_inputs = [(multigenome1.to_tensors(as_tuple=True)[i],) for i in range(len(auto_correction_fn_list))]
batch_exec = batch_executor(batch_auto_correction_inputs)
print("type:", type(batch_exec))
print("Batch auto-correction results:", batch_exec)

# easy way
multigenome1.auto_correct()

type: <class 'tuple'>
Batch auto-correction results: (Array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32), Array([1, 1, 1, 1, 0, 1, 0, 0, 0, 1], dtype=int32), Array([0, 0, 0, 0, 0, 0, 1, 0, 0, 1], dtype=int32), Array([3, 4, 5, 6, 1, 9, 0, 8, 2, 7], dtype=int32), Array([ 0.7768438 , -0.5759094 ,  0.30274057, -0.77292156,  0.37062407,
        0.19877958,  0.06566715, -0.40798068, -0.25738406,  0.5588684 ],      dtype=float32))


(Array([0, 0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32),
 Array([1, 1, 1, 1, 0, 1, 0, 0, 0, 1], dtype=int32),
 Array([0, 0, 0, 0, 0, 0, 1, 0, 0, 1], dtype=int32),
 Array([3, 4, 5, 6, 1, 9, 0, 8, 2, 7], dtype=int32),
 Array([ 0.7768438 , -0.5759094 ,  0.30274057, -0.77292156,  0.37062407,
         0.19877958,  0.06566715, -0.40798068, -0.25738406,  0.5588684 ],      dtype=float32))