# Genetic Programming Tree Evolution with JAX

This notebook demonstrates a genetic programming approach for evolving mathematical expression trees using JAX. The trees combine dataset features through mathematical operations to create new derived features.

## 1. Setup and Imports

**Dependencies**:
- JAX for high-performance numerical computing and automatic differentiation
- scikit-learn for synthetic dataset generation
- MalthusJAX for genetic algorithm components

In [None]:
import sys
import os
import time

# Add the src directory to the path so we can import malthusjax
sys.path.append('/Users/leonardodicaterina/Documents/GitHub/MalthusJAX/src')

import jax
import jax.numpy as jnp
import jax.random as jar
from jax import lax

from sklearn.datasets import make_classification
from malthusjax.core.genome.categorical import CategoricalGenome, CategoricalGenomeConfig

print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())

## 2. Dataset Generation

**Purpose**: Create a synthetic classification dataset for testing tree evolution.

**Configuration Schema**:
```
n_features_dataset: 5     # Number of input features
depth_instructions: 20    # Number of tree computation nodes
n_operations: 5          # Number of available mathematical operations
```

**Dataset Structure**:
```
X: (100, 5) - Input features matrix
y: (100,)   - Binary classification targets
```


In [None]:
# Configuration
n_features_dataset = 5
depth_instructions = 20
n_operations = 5

# Generate random dataset
X, y = make_classification(
    n_samples=100, 
    n_features=n_features_dataset, 
    n_informative=3, 
    n_redundant=2, 
    random_state=42
)
# Convert to jax arrays
X = jnp.array(X)
y = jnp.array(y)

print(f"Dataset shape: {X.shape}")
print(f"Target shape: {y.shape}")


## 3. Tree Operations Definition

**Purpose**: Define the mathematical operations available for tree nodes.

**Operation Schema**:
```
Index | Operation | Formula
------|-----------|--------
  0   | Addition  | x1 + x2
  1   | Subtraction | x1 - x2
  2   | Multiplication | x1 * x2
  3   | Division  | x1 / (x2 + ε)
  4   | Maximum   | max(x1, x2)
```

**Implementation**: Uses `jax.lax.switch` for efficient operation selection based on genome indices.
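
As a minimal sketch of how `jax.lax.switch` behaves once the operation index is a traced value rather than a Python integer, the snippet below passes the operands explicitly and maps over a vector of indices. The helper name `apply_op` is hypothetical and is not part of the notebook or of MalthusJAX. Every branch has to return the same shape and dtype, and an out-of-range index is clamped into bounds rather than raising.

```python
import jax
import jax.numpy as jnp

def apply_op(op_index, a, b):
    # Hypothetical helper: branches receive the operands as arguments.
    branches = [
        lambda u, v: u + v,              # 0: addition
        lambda u, v: u - v,              # 1: subtraction
        lambda u, v: u * v,              # 2: multiplication
        lambda u, v: u / (v + 1e-8),     # 3: division with epsilon
        lambda u, v: jnp.maximum(u, v),  # 4: maximum
    ]
    return jax.lax.switch(op_index, branches, a, b)

# Map over a vector of operation indices, as an operations genome would supply them.
op_indices = jnp.arange(5)
apply_all = jax.jit(jax.vmap(apply_op, in_axes=(0, None, None)))
results = apply_all(op_indices, jnp.float32(10.0), jnp.float32(5.0))
print(results)
```

With operands 10 and 5, the five entries correspond to the sum, difference, product, quotient and maximum in table order.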



In [None]:
def operation_function(x1_value, x2_value, operation_index):
    """Define mathematical operations for tree nodes"""
    return jax.lax.switch(operation_index, [
        lambda: x1_value + x2_value,        # 0: Addition
        lambda: x1_value - x2_value,        # 1: Subtraction
        lambda: x1_value * x2_value,        # 2: Multiplication
        lambda: x1_value / (x2_value + 1e-8),  # 3: Division (with small epsilon)
        lambda: jnp.maximum(x1_value, x2_value)  # 4: Maximum
    ])

# Test the operations
print("Testing operations:")
x1, x2 = 10.0, 5.0
for op_index in range(n_operations):
    result = operation_function(x1, x2, op_index)
    print(f"Operation {op_index}: {x1} op {x2} = {result}")

## 4. Genome Initialization

**Purpose**: Create genetic representations for tree structure and operations.

**Genome Architecture**:
```
Tree Representation (depth = depth_instructions):
┌─────────────────┬──────────────────────────┬─────────────────┐
│   Input Layer   │  Computation             │   Operations    │
│   (Features)    │    Nodes                 │    Genome       │
├─────────────────┼──────────────────────────┼─────────────────┤
│ X[0], X[1], ... │ Node[0], Node[1], ...    │ Op[0], Op[1],   │
│ X[n_features-1] │ Node[depth-1]            │ ... Op[depth-1] │
└─────────────────┴──────────────────────────┴─────────────────┘
```

**Index Genome Schema**:
- Shape: `(20, 2)` - Each row represents a tree node with 2 input indices
- Values: Range from `0` to `(n_features + current_depth - 1)`
- Constraint: Ensures nodes can only reference earlier computed values

**Operations Genome Schema**:
- Shape: `(20,)` - One operation per tree node
- Values: Range from `0` to `4` (operation indices)
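
The index constraint can be checked directly. The sketch below rebuilds an index genome with the same per-row upper bounds described above and verifies that every row only references features or earlier nodes. The names `upper_bounds`, `sample_row` and `is_valid` are illustrative assumptions, not MalthusJAX API.

```python
import jax
import jax.numpy as jnp
import jax.random as jar

n_features, depth = 5, 20  # same values as the configuration above

# Row i may reference the n_features inputs plus the i nodes computed before it.
upper_bounds = jnp.arange(n_features, n_features + depth)  # exclusive bounds

keys = jar.split(jar.PRNGKey(0), depth)

def sample_row(key, bound):
    # Two input indices per node, drawn from [0, bound).
    return jar.randint(key, shape=(2,), minval=0, maxval=bound)

index_genome = jax.vmap(sample_row)(keys, upper_bounds)

# A genome is valid when row i only contains indices below n_features + i.
is_valid = jnp.all((index_genome >= 0) & (index_genome < upper_bounds[:, None]))
print(index_genome.shape, bool(is_valid))
```

The same check is useful after mutation or crossover, since either operator could otherwise introduce a forward reference.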


In [None]:
random_genome = CategoricalGenome(array_shape=(6,2),
                                  num_categories=5,
                                  random_init=True,
                                  random_key=jar.PRNGKey(0))
random_genome.genome

In [None]:
def init_atomic_tree(random_key, max_value, num_elements=2):
    """Initialize a single tree node with random indices"""
    return jar.randint(random_key, shape=(num_elements,), minval=0, maxval=max_value)


# Create index genome with proper constraints for each depth level
keys = jar.split(jar.PRNGKey(0), num=depth_instructions)
depth_trees = jnp.arange(n_features_dataset, n_features_dataset + depth_instructions)
init_individual_tree = jax.vmap(lambda key, depth: init_atomic_tree(key, depth, 2))
index_genome = init_individual_tree(keys, depth_trees)

# Create operations genome
operations_genome = CategoricalGenome(
    array_shape=(depth_instructions,), 
    num_categories=n_operations, 
    random_init=True, 
    random_key=jar.PRNGKey(123)
)
init_individual_operations = operations_genome.get_random_initialization_pure()
print(f"Index genome shape: {index_genome.shape}")
print(f"Operations genome shape: {operations_genome.genome.shape}")
print(f"Index genome:\n{index_genome}")

In [None]:
print(f"Operations genome:\n{operations_genome.genome}")

## 5. Tree Evaluation Functions

**Purpose**: Implement the core tree evaluation logic for computing derived features.

**Evaluation Flow**:
```
Input Sample → Tree Evaluation → Derived Features

Step-by-step Process:
1. Initialize results_history array
2. For each tree node (depth 0 to 19):
   a. Get input indices from index_genome
   b. Retrieve values (features or computed nodes)
   c. Apply operation from operations_genome
   d. Store result in results_history
3. Return all computed values
```
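
The step-by-step process above can be sketched as a plain-Python reference interpreter, which is handy for cross-checking a JAX implementation on small inputs. The function `evaluate_reference` and the three-node toy program are hypothetical and exist only for illustration.

```python
import numpy as np

OPS = [
    lambda a, b: a + b,           # 0: addition
    lambda a, b: a - b,           # 1: subtraction
    lambda a, b: a * b,           # 2: multiplication
    lambda a, b: a / (b + 1e-8),  # 3: division with epsilon
    lambda a, b: max(a, b),       # 4: maximum
]

def evaluate_reference(x, ops_genome, index_genome):
    """Node i reads two earlier values and appends one result."""
    n_features = len(x)
    values = list(x)  # positions 0..n_features-1 hold the raw features
    for op, (i0, i1) in zip(ops_genome, index_genome):
        values.append(OPS[op](values[i0], values[i1]))
    return np.array(values[n_features:])  # computed nodes only

# Toy program over 2 features:
#   node 0 = x0 + x1, node 1 = node 0 * x0, node 2 = node 1 - node 0
x = np.array([2.0, 3.0])
ops = [0, 2, 1]
idx = [(0, 1), (2, 0), (3, 2)]
node_values = evaluate_reference(x, ops, idx)
print(node_values)
```

For this input the nodes evaluate to 5, 10 and 5, which can be verified by hand from the comments.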

**Value Resolution Schema**:
```
Index Range        | Source
0 to 4            | Original features X[0:5]
5 to 24           | Computed nodes results_history[0:20]
```

**Vectorization**: Uses `jax.vmap` to process multiple samples in parallel.
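
A small sketch of the `in_axes=(0, None)` pattern: the first argument is mapped over its leading axis while the genome tuple is shared, unbatched, across all samples. The function `first_node` and the toy arrays are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

def first_node(x_sample, genome_tuple):
    ops, idx = genome_tuple  # shared genome, not batched
    a = x_sample[idx[0, 0]]
    b = x_sample[idx[0, 1]]
    return jnp.where(ops[0] == 0, a + b, a - b)

# None applies to the whole genome tuple, so neither array gains a batch axis.
batched = jax.vmap(first_node, in_axes=(0, None))

X_toy = jnp.arange(12.0).reshape(4, 3)          # 4 samples, 3 features
genome = (jnp.array([0]), jnp.array([[0, 2]]))  # one "add" node reading features 0 and 2
out = batched(X_toy, genome)
print(out)
```

Each output entry is feature 0 plus feature 2 of the corresponding row.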



In [None]:

def evaluate_single_row(X_sample, genome_tuple):
    """Evaluate the tree for a single sample"""
    
    
    operations_genome, indexes_genome = genome_tuple
    
    def body_fn(carry, input_element):
        results_history = carry
        
        index_0 = indexes_genome[input_element, 0]
        index_1 = indexes_genome[input_element, 1]
        
        def get_value(index, X_sample, results_history):
            """Get value from either original features or computed results"""
            return jax.lax.cond(
                index < len(X_sample),
                lambda: X_sample[index],  # Feature value
                lambda: results_history[index - len(X_sample)]  # Computed node value
            )
        
        value0 = get_value(index_0, X_sample, results_history)
        value1 = get_value(index_1, X_sample, results_history)
        
        atomic_result = operation_function(value0, value1, operations_genome[input_element])
        results_history = results_history.at[input_element].set(atomic_result)
        
        return results_history, atomic_result
    
    # Initialize and execute scan
    initial_results_history = jnp.zeros((depth_instructions,))
    # Scan over every node so the output has the documented (depth_instructions,) shape
    input_data = jnp.arange(depth_instructions)
    
    final_results_history, all_results = lax.scan(
        f=body_fn,
        init=initial_results_history,
        xs=input_data
    )
    
    return all_results

# Vectorize for multiple samples
evaluate_multiple_samples = jax.vmap(
    evaluate_single_row, 
    in_axes=(0, None)
)

In [None]:
import pandas as pd
from malthusjax.core.fitness.linear_gp import TreeGPEvaluator

LinearGPEvaluator_instance = TreeGPEvaluator(
    dataframe = X[0],
    n_instructions_in_genome = depth_instructions
)

fitness_function = LinearGPEvaluator_instance.get_pure_fitness_function()

In [None]:
fitness_function(operations_genome.genome, index_genome)

## 6. Testing and Validation

**Purpose**: Validate tree evaluation functionality with single and multiple samples.

**Test Cases**:
1. **Single Sample Test**: Verify correct computation for one data point
2. **Multiple Sample Test**: Ensure vectorization works across batch of samples

**Expected Output Schema**:
```
Single Sample: (20,) - 20 derived features for one sample
Multiple Samples: (batch_size, 20) - 20 derived features per sample
```



In [None]:
# Test single sample
print("\n=== Single Sample Test ===")
single_sample = X[0]
single_result = evaluate_single_row(single_sample, (operations_genome.genome, index_genome))
print(f"Single sample result shape: {single_result.shape}")
print(f"First 5 results: {single_result[:5]}")

# Test multiple samples
print("\n=== Multiple Samples Test ===")
X_batch = X[:10]
batch_results = evaluate_multiple_samples(X_batch,(operations_genome.genome, index_genome))
print(f"Batch results shape: {batch_results.shape}")
print(f"Sample 0 results: {batch_results[0][:5]}")
print(f"Sample 1 results: {batch_results[1][:5]}")

## 7. Performance Testing

**Purpose**: Benchmark JAX JIT compilation performance for tree evaluation.

**Performance Optimization**:
- JIT compilation for faster execution
- Warm-up run to trigger compilation
- Timing measurement with `jax.block_until_ready()`

**Benchmarking Process**:
1. Compile function with JIT
2. Warm-up run (compilation overhead)
3. Timed evaluation of full dataset
4. Report performance metrics
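
A minimal timing helper following these steps might look like the sketch below. `time_jitted` is a hypothetical name, and the toy function stands in for the tree evaluator. Because `jax.jit` caches compiled code per input shape and dtype, the warm-up call uses exactly the same arguments as the timed calls.

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=10):
    """Return (first_call_seconds, mean_steady_state_seconds)."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # first call traces and compiles
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(repeats):
        # Same shapes and dtypes, so the cached executable is reused.
        jax.block_until_ready(jitted(*args))
    steady = (time.perf_counter() - t0) / repeats
    return first, steady

first, steady = time_jitted(lambda x: jnp.tanh(x @ x.T).sum(), jnp.ones((100, 5)))
print(f"first call: {first:.4f}s, steady state: {steady:.6f}s")
```

Reporting the two numbers separately makes it clear how much of the wall time is compilation overhead.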


In [None]:
# JIT compile for performance
jitted_evaluate_multiple = jax.jit(evaluate_multiple_samples)

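# Warm up JIT on the full dataset: jax.jit caches per input shape, so warming up
# on X_batch (10 rows) would leave the compilation for X (100 rows) inside the timed call
_ = jax.block_until_ready(
    jitted_evaluate_multiple(X, (operations_genome.genome, index_genome))
)
start_time = time.perf_counter()
jitted_results = jitted_evaluate_multiple(X, (operations_genome.genome, index_genome))
# Wait until the computation is done before stopping the clock
jax.block_until_ready(jitted_results)
end_time = time.perf_counter()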

print("\n=== Performance Test ===")
print(f"Evaluated {X.shape[0]} samples in {end_time - start_time:.4f} seconds")
print(f"Final results shape: {jitted_results.shape}")

## 8. Population Initialization

**Purpose**: Create a population of genetic individuals for evolutionary optimization.

**Population Structure**:
```
Population Architecture:
┌─────────────────┬─────────────────┬─────────────────┐
│   Individual 0  │   Individual 1  │   Individual N  │
├─────────────────┼─────────────────┼─────────────────┤
│ Index Genome    │ Index Genome    │ Index Genome    │
│ (20, 2)         │ (20, 2)         │ (20, 2)         │
├─────────────────┼─────────────────┼─────────────────┤
│ Operations      │ Operations      │ Operations      │
│ Genome (20,)    │ Genome (20,)    │ Genome (20,)    │
└─────────────────┴─────────────────┴─────────────────┘
```

**Vectorization Schema**:
```
Target Shapes:
├── Population Index Genomes: (population_size, depth_instructions, 2)
├── Population Operations: (population_size, depth_instructions)
└── Evaluation Results: (population_size, n_samples, depth_instructions)
```

**Computational Complexity**:
```
Time Complexity: O(p × n × d)
├── p = population_size (50)
├── n = n_samples (100) 
└── d = depth_instructions (20)

Total Operations: 50 × 100 × 20 = 100,000
```

**Memory Complexity**: 
```
Space Requirements: O(p × n × d)
├── Population storage: ~12 KB for 50 individuals (60 int32 values each)
├── Evaluation results: ~400 KB for full population (100,000 float32 values)
└── Intermediate computations: backend-dependent, same order as the results
```

**Implementation Strategy**:
- Use `jax.vmap` to vectorize genome initialization across population
- Leverage JAX's parallel processing for efficient population evaluation
- JIT compile population operations for maximum performance
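
The nested-`vmap` strategy can be sketched with a stand-in evaluator. Here `toy_evaluate` is a hypothetical simplification in which nodes only read input features, and the sizes `p`, `n`, `f`, `d` are small illustrative values. The point is the resulting `(population_size, n_samples, depth_instructions)` layout.

```python
import jax
import jax.numpy as jnp

def toy_evaluate(x_sample, genome_tuple):
    """Stand-in evaluator: one output per node, shape (depth,)."""
    ops, idx = genome_tuple
    a = x_sample[idx[:, 0]]
    b = x_sample[idx[:, 1]]
    return jnp.where(ops == 0, a + b, a * b)

# Inner vmap: over samples, genome shared. Outer vmap: over individuals, data shared.
over_samples = jax.vmap(toy_evaluate, in_axes=(0, None))
over_population = jax.jit(jax.vmap(over_samples, in_axes=(None, 0)))

p, n, f, d = 4, 6, 3, 5  # population, samples, features, nodes
key_x, key_ops, key_idx = jax.random.split(jax.random.PRNGKey(0), 3)
X_toy = jax.random.normal(key_x, (n, f))
pop_ops = jax.random.randint(key_ops, (p, d), 0, 2)
pop_idx = jax.random.randint(key_idx, (p, d, 2), 0, f)

out = over_population(X_toy, (pop_ops, pop_idx))
print(out.shape)  # (4, 6, 5)
```

Passing `0` for the genome tuple maps every array inside it over its leading axis, so the operations and index genomes stay aligned per individual.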

In [None]:
'''population_size = 50

# One PRNG key per (individual, node)
keys = jar.split(jar.PRNGKey(42), num=population_size * depth_instructions)
keys = keys.reshape((population_size, depth_instructions, 2))
# Stack the per-node upper bounds so every individual gets its own copy
depth_trees = jnp.arange(n_features_dataset, n_features_dataset + depth_instructions)
depth_trees = jnp.tile(depth_trees, (population_size, 1))
print(f"Keys shape: {keys.shape}")
print(f"Depth trees shape: {depth_trees.shape}")
# vmap the per-individual initializers over the population
init_population_trees = jax.vmap(init_individual_tree)
init_population_operations = jax.vmap(init_individual_operations)
population_index_genomes = init_population_trees(keys, depth_trees)
population_operations_genomes = init_population_operations(jar.split(jar.PRNGKey(123), population_size))
print(f"Population index genomes shape: {population_index_genomes.shape}")
print(f"Population operations genomes shape: {population_operations_genomes.shape}")
# Evaluate the whole population: X is shared, the (operations, indices) tuple is mapped over axis 0
evaluate_population = jax.vmap(evaluate_multiple_samples, in_axes=(None, 0))
population_results = evaluate_population(X, (population_operations_genomes, population_index_genomes))
print(f"Population results shape: {population_results.shape}")'''

## 9. Final Multi-Genome


In [None]:
from malthusjax.core.genome.categorical import CategoricalGenome, CategoricalGenomeConfig
from malthusjax.core.genome.tree_operand import TreeOperandGenome, TreeOperandGenomeConfig
from malthusjax.core.multigenome.base import AbstractMultiGenome


list_genome_init_params = {
    "operations_genome": {'array_shape': (depth_instructions,), 'num_categories': n_operations},
    "operand_genome": {
        'n_operands_per_node': 2,
        'maximum_depth': depth_instructions,
        'n_features_dataset': n_features_dataset
        }
    
    }

genome_types_dict = {
    'operations_genome': CategoricalGenome,
    'operand_genome': TreeOperandGenome
}

multi_genome = AbstractMultiGenome.from_config_tuple(
    config_tuple=(list_genome_init_params, genome_types_dict), random_init = True, random_key=jar.PRNGKey(0))

multi_genome.to_tensors()


# evaluate_single_row unpacks its second argument as (operations, operand indices),
# so to_tensors() must yield the tensors in that order
single_evaluation = evaluate_single_row(X[0], multi_genome.to_tensors())
print(f"Single evaluation shape: {single_evaluation.shape}")
multiple_evaluation = evaluate_multiple_samples(X[:10], multi_genome.to_tensors())
print(f"Multiple evaluation shape: {multiple_evaluation.shape}")

In [None]:
single_init_fn = multi_genome.get_random_initialization_pure()
jit_single_init_fn = jax.jit(single_init_fn)
jit_single_init_fn(jar.PRNGKey(0))


pop_size = 20
keys = jar.split(jar.PRNGKey(0), num=pop_size)
population = jax.vmap(jit_single_init_fn)(keys)

In [None]:
print(f"type {type(population)}")
print(f"element [0] type: {type(population[0])}")
print(f"element [0] shape: {population[0].shape}")
print(f"element [1]: type {type(population[1])}")
print(f"element [1] shape: {population[1].shape}")

In [None]:
result = jax.vmap(evaluate_multiple_samples, in_axes=(None, 0))(X, population)
print(f"Result shape: {result.shape}")

In [None]:
import functools

def get_full_mutation_function(n_features_dataset, depth_instructions):
  
  def get_mutate_columns(_n_features_dataset, _depth_instructions):
    
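    def mutate_element_add_noise(element, keys, max_value: int):
      """
      With probability 0.2, resample a single scalar element uniformly
      from [0, max_value); otherwise return it unchanged.
      keys holds two PRNG keys: one for the coin flip, one for the new value.
      """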
      return jax.lax.cond(
        jar.uniform(keys[0]) < 0.2,
        lambda: jar.randint(keys[1], shape=(), minval=0, maxval=max_value),
        lambda: element
      )
    def mutation_fn(element, key, max_value):
      return  jax.vmap(mutate_element_add_noise, in_axes=(0,0,0))(element, key, max_value)
    return functools.partial(mutation_fn, max_value=jnp.arange(_n_features_dataset, _n_features_dataset + _depth_instructions))
  column_mutation_function = get_mutate_columns(n_features_dataset, depth_instructions)
  return jax.vmap(column_mutation_function, in_axes=(1, 1), out_axes=1)

genome_to_mutate = multi_genome._genome_list[1].genome

keys_flat = jar.split(jar.PRNGKey(0), genome_to_mutate.shape[0] * genome_to_mutate.shape[1]*2) # Shape: (R*C*2, 2)
keys_3d = keys_flat.reshape(genome_to_mutate.shape + (2,)+(2,)) # Shape: (R, C, 2, 2), two keys per element
mutation_fn = get_full_mutation_function(n_features_dataset, depth_instructions)
mutation_fn = jax.jit(mutation_fn)
mutated_genome_all_columns = mutation_fn(genome_to_mutate, keys_3d)

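print("Original genome:")
print(genome_to_mutate)
print("Mutated genome (all columns):")
print(mutated_genome_all_columns)
print("Difference (original - mutated, non-zero where a mutation changed the value):")
print(genome_to_mutate - mutated_genome_all_columns)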

In [None]:
def crossover_function(genome_a, genome_b, crossover_point):
    """Perform crossover between two genomes at the specified point."""
    return jnp.where(
        jnp.arange(genome_a.shape[0]).reshape((-1, 1)) < crossover_point,
        genome_a,
        genome_b
    )
    
crossover_point = 10
genome_a = multi_genome._genome_list[1].genome
genome_b = mutated_genome_all_columns
crossover_genome = crossover_function(genome_a, genome_b, crossover_point)
crossover_genome