In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
import chex
from typing import Any, Type, TypeVar, Generic, ClassVar, Union, Iterator, List, Optional, Tuple

# --- Types & Constants ---
G = TypeVar("G", bound="BaseGenome")

class DistanceMetric:
    HAMMING = "hamming"
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"

# ==========================================
# 1. CONFIGURATION
# ==========================================
@struct.dataclass
class LinearGenomeConfig:
    length: int           # L: Number of instructions
    num_inputs: int       # N: Number of input features
    num_ops: int          # Number of available functions
    max_arity: int = 2    # Arguments per instruction

# ==========================================
# 2. BASE GENOME (Abstract Individual)
# ==========================================
@struct.dataclass
class BaseGenome:
    """
    Abstract base class for a Single Individual.
    """
    # --- Abstract Interface ---
    @classmethod
    def random_init(cls: Type[G], key: chex.PRNGKey, config: Any) -> G:
        """Create ONE random genome."""
        raise NotImplementedError

    def distance(self, other: "BaseGenome", metric: str) -> float:
        raise NotImplementedError

    def autocorrect(self, config: Any) -> "BaseGenome":
        raise NotImplementedError

    @property
    def size(self) -> int:
        raise NotImplementedError

    # --- Shared Logic (Vectorization) ---
    @classmethod
    def create_population(cls: Type[G], key: chex.PRNGKey, config: Any, pop_size: int) -> G:
        """
        Vectorizes random_init to create a raw batch of genomes.
        Returns object with shape (Pop, ...).
        """
        keys = jax.random.split(key, pop_size)
        return jax.vmap(cls.random_init, in_axes=(0, None))(keys, config)

# ==========================================
# 3. CONCRETE GENOME (Linear Implementation)
# ==========================================
@struct.dataclass
class LinearGenome(BaseGenome):
    ops: chex.Array   # Shape (L,)
    args: chex.Array  # Shape (L, Arity)

    @classmethod
    def random_init(cls, key: chex.PRNGKey, config: LinearGenomeConfig) -> "LinearGenome":
        k_ops, k_args = jax.random.split(key)
        
        # 1. Random Opcodes
        ops = jax.random.randint(k_ops, (config.length,), 0, config.num_ops)
        
        # 2. Topological Arguments (Row 'i' can only see 0..N+i-1)
        row_limits = jnp.arange(config.num_inputs, config.num_inputs + config.length)
        
        def gen_row(rk, climit):
            return jax.random.randint(rk, (config.max_arity,), 0, climit)
            
        row_keys = jax.random.split(k_args, config.length)
        args = jax.vmap(gen_row)(row_keys, row_limits)
        
        return cls(ops=ops, args=args)

    def autocorrect(self, config: LinearGenomeConfig) -> "LinearGenome":
        """Fixes invalid references to ensure Topological DAG structure."""
        valid_ops = jnp.clip(self.ops, 0, config.num_ops - 1)
        
        # Re-calculate limits
        row_limits = jnp.arange(config.num_inputs, config.num_inputs + config.length)
        # Max valid index is limit - 1
        max_indices = row_limits[:, None] - 1
        valid_args = jnp.clip(self.args, 0, max_indices)
        
        return self.replace(ops=valid_ops, args=valid_args)

    def distance(self, other: "LinearGenome", metric: str = "hamming") -> float:
        d_ops = jnp.sum(self.ops != other.ops)
        d_args = jnp.sum(self.args != other.args)
        return (d_ops + d_args).astype(jnp.float32)

    @property
    def size(self) -> int:
        return self.ops.shape[-1]

    def render(self, config: LinearGenomeConfig, op_names: Optional[List[str]] = None) -> str:
        """Human-Readable String Representation."""
        ops_cpu = np.array(self.ops)
        args_cpu = np.array(self.args)
        lines = [f"{'Row':<4} | {'Expression':<30} | {'Raw'}"]
        lines.append("-" * 50)
        
        for i in range(config.length):
            op_idx = int(ops_cpu[i])
            op_str = op_names[op_idx] if op_names and op_idx < len(op_names) else f"OP_{op_idx}"
            
            decoded_args = []
            for arg_idx in args_cpu[i]:
                if arg_idx < config.num_inputs:
                    decoded_args.append(f"x_{arg_idx}")
                else:
                    decoded_args.append(f"v_{arg_idx - config.num_inputs}")
            
            expr = f"v_{i} = {op_str}({', '.join(decoded_args)})"
            lines.append(f"{i:<4} | {expr:<30} | {args_cpu[i]}")
            
        return "\n".join(lines)

    def __repr__(self):
        return f"<LinearGenome(L={self.ops.shape[-1]})>"

# ==========================================
# 4. BASE POPULATION (Abstract Container)
# ==========================================
@struct.dataclass
class BasePopulation(Generic[G]):
    """
    Abstract Container. 
    Implements List-like behavior and automated Vectorization.
    """
    genes: G
    fitness: chex.Array
    
    GENOME_CLS: ClassVar[Type[G]]

    # --- The "Kebab" Interface (List Behavior) ---
    def __len__(self) -> int:
        return int(self.fitness.shape[0])

    def __getitem__(self, key: Union[int, slice, chex.Array]) -> Union[G, "BasePopulation[G]"]:
        # Slice the genes Pytree automatically
        sliced_genes = jax.tree_util.tree_map(lambda x: x[key], self.genes)
        
        if isinstance(key, int):
            # Return Single Genome
            return sliced_genes
        else:
            # Return Sliced Population
            return self.replace(genes=sliced_genes, fitness=self.fitness[key])

    def __iter__(self) -> Iterator[G]:
        for i in range(len(self)):
            yield self[i]

    # --- Automated Logic (Vectorized) ---
    def autocorrect(self, config: Any) -> "BasePopulation[G]":
        """Autocorrects entire batch."""
        new_genes = jax.vmap(lambda g: g.autocorrect(config))(self.genes)
        return self.replace(genes=new_genes)

    def distance_matrix(self, metric: str = "hamming") -> chex.Array:
        """Computes N x N distance matrix."""
        pair_fn = lambda g1, g2: g1.distance(g2, metric)
        return jax.vmap(jax.vmap(pair_fn, in_axes=(None, 0)), in_axes=(0, None))(self.genes, self.genes)

# ==========================================
# 5. CONCRETE POPULATION (Linear Implementation)
# ==========================================
@struct.dataclass
class LinearPopulation(BasePopulation[LinearGenome]):
    genes: LinearGenome
    fitness: chex.Array
    
    GENOME_CLS: ClassVar[Type[LinearGenome]] = LinearGenome

    @classmethod
    def init_random(cls, key: chex.PRNGKey, config: LinearGenomeConfig, size: int) -> "LinearPopulation":
        # 1. Use Genome Factory (Returns Batch)
        batched_genes = LinearGenome.create_population(key, config, size)
        # 2. Init Fitness
        initial_fitness = jnp.full((size,), -jnp.inf)
        return cls(genes=batched_genes, fitness=initial_fitness)
    
    
# 1. Config
config = LinearGenomeConfig(length=5, num_inputs=2, num_ops=4)
key = jax.random.PRNGKey(42)
op_names = ["ADD", "SUB", "MUL", "SIN"]

# 2. Create Population (The Easy Way)
pop = LinearPopulation.init_random(key, config, size=10)

# 3. Test "Kebab" Indexing
print(f"Population Size: {len(pop)}")

first_guy = pop[0] # Returns LinearGenome
print("\n--- First Individual ---")
print(first_guy.render(config, op_names))

# 4. Test Slicing
subset = pop[:3] # Returns LinearPopulation
print(f"\nSubset Size: {len(subset)}")

# 5. Test Automated Vectorization
matrix = pop.distance_matrix()
print(f"\nDistance Matrix Shape: {matrix.shape}")

Population Size: 10

--- First Individual ---
Row  | Expression                     | Raw
--------------------------------------------------
0    | v_0 = SIN(x_0, x_1)            | [0 1]
1    | v_1 = SUB(x_0, x_0)            | [0 0]
2    | v_2 = SUB(v_1, v_1)            | [3 3]
3    | v_3 = SIN(v_2, v_0)            | [4 2]
4    | v_4 = SIN(x_1, x_1)            | [1 1]

Subset Size: 3

Distance Matrix Shape: (10, 10)


In [2]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from functools import partial
from typing import Any, Tuple, Callable, TypeVar, Generic

# --- A. Define Robust Primitives (The Instruction Set) ---
def op_add(x, y): return x + y
def op_sub(x, y): return x - y
def op_mul(x, y): return x * y
def op_div(x, y): return jnp.where(jnp.abs(y) < 0.001, 1.0, x / y) # Protected
def op_sin(x, y): return jnp.sin(x)
def op_cos(x, y): return jnp.cos(x)

# The registry used by lax.switch (Index matches OpCode)
OP_FUNCTIONS = (op_add, op_sub, op_mul, op_div, op_sin, op_cos)
OP_NAMES = ["ADD", "SUB", "MUL", "DIV", "SIN", "COS"]

# --- B. The Abstract Evaluator ---
G = TypeVar("G", bound="BaseGenome")
C = TypeVar("C") # Config Type
D = TypeVar("D") # Data Type

@struct.dataclass
class BaseEvaluator(Generic[G, C, D]):
    config: C

    def evaluate(self, genome: G, data: D) -> float:
        """Abstract: Returns scalar fitness for one genome."""
        raise NotImplementedError

    def evaluate_population(self, population: "BasePopulation[G]", data: D) -> "BasePopulation[G]":
        """Auto-Vectorization: Pop x Data -> Pop (with Fitness)"""
        # vmap over genes (axis 0), keep self/data constant
        fitness_scores = jax.vmap(self.evaluate, in_axes=(0, None))(population.genes, data)
        return population.replace(fitness=fitness_scores)

# --- C. The Concrete Linear Evaluator ---
# Data Type: Tuple(Inputs, Targets)
RegressionData = Tuple[chex.Array, chex.Array]

@struct.dataclass
class LinearGPEvaluator(BaseEvaluator[LinearGenome, LinearGenomeConfig, RegressionData]):
    
    def predict_one(self, genome: LinearGenome, x_input: chex.Array) -> chex.Array:
        """
        Interprets one genome on one input vector.
        Returns: Shape (L,) -> The output of EVERY instruction row.
        """
        # 1. Initialize Memory
        total_mem = self.config.num_inputs + self.config.length
        memory = jnp.zeros(total_mem)
        memory = memory.at[:self.config.num_inputs].set(x_input)
        
        # 2. Execution Loop
        def step(current_mem, inputs):
            mem, write_idx = current_mem
            op_code, arg_indices = inputs
            
            # Fetch & Execute
            args_val = jnp.take(mem, arg_indices)
            res = jax.lax.switch(op_code, OP_FUNCTIONS, args_val[0], args_val[1])
            res = jnp.nan_to_num(res, nan=0.0, posinf=1.0, neginf=-1.0)
            
            # Write to register
            new_mem = mem.at[write_idx].set(res)
            
            # CRITICAL CHANGE: 
            # We return 'res' as the history item, NOT 'new_mem'.
            # 'res' is the output of the current Atomic Tree.
            return (new_mem, write_idx + 1), res

        init_val = (memory, self.config.num_inputs)
        
        # history will have shape (Length,)
        # It contains [Output_Row_0, Output_Row_1, ..., Output_Row_L]
        (_, _), history = jax.lax.scan(step, init_val, (genome.ops, genome.args))
        
        return history

    def evaluate(self, genome: LinearGenome, data: RegressionData) -> float:
        X, Y = data
        
        # 1. Vectorize Prediction (Data Parallelism)
        # predict_one returns shape (Length,)
        # vmap(predict_one) over X returns shape (Batch_Size, Length)
        all_preds = jax.vmap(self.predict_one, in_axes=(None, 0))(genome, X)
        
        # all_preds shape: (1000 samples, 20 instructions)
        # Y shape:         (1000 samples,)
        
        # 2. Calculate MSE for EVERY instruction column
        # We broadcast Y to match shape (1000, 20)
        Y_bcast = Y[:, None]
        
        # Squared Error: (1000, 20)
        squared_errors = (all_preds - Y_bcast) ** 2
        
        # Mean over Data Axis (0) -> Result shape (20,)
        # This gives us the MSE for Atomic Tree 0, Atomic Tree 1...
        mse_per_tree = jnp.mean(squared_errors, axis=0)
        
        # 3. The "Symbiotic" Selection
        # The fitness of the Genome is the fitness of its BEST Atomic Tree.
        # We minimize MSE, so we take the minimum error found.
        best_mse = jnp.min(mse_per_tree)
        
        return -mse_per_tree

In [3]:
import numpy as np
from sklearn.datasets import make_regression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# 1. PREPARE DATA (Sklearn)
# Generate a simple regression problem: y = combination of 5 features
X_raw, y_raw = make_regression(n_samples=1000, n_features=5, noise=0.1, random_state=42)

# Important: Scale data for Neural/Evolutionary methods
scaler_x = StandardScaler()
X_scaled = scaler_x.fit_transform(X_raw)
# JAX expects arrays, not numpy
X_jax = jnp.array(X_scaled)
y_jax = jnp.array(y_raw)

print(f"Dataset Shape: X={X_jax.shape}, y={y_jax.shape}")

# 2. CONFIGURE ENGINE
# We have 5 inputs (features)
# We give it 20 instructions (depth) to solve the problem
config = LinearGenomeConfig(
    length=20, 
    num_inputs=5, 
    num_ops=len(OP_FUNCTIONS), # 6 operators
    max_arity=2
)

# 3. INITIALIZE POPULATION
# Let's create a massive batch to see JAX speed
POP_SIZE = 2000
key = jax.random.PRNGKey(0)

print(f"Initializing Population of {POP_SIZE} individuals...")
# Calls the Factory we wrote earlier
population = LinearPopulation.init_random(key, config, size=POP_SIZE)

# 4. RUN EVALUATION
print("Compiling & Evaluating...")
evaluator = LinearGPEvaluator(config=config)

# This triggers the JIT compilation (might take 1-2 sec first time)
# Then it runs 2000 genomes * 1000 samples = 2 Million evaluations instantly
evaluated_pop = evaluator.evaluate_population(population, (X_jax, y_jax))

# 5. INSPECT RESULTS
# Get the winner
best_fitness = jnp.max(evaluated_pop.fitness)

print(f"\nOptimization Complete!")
print(f"Best MSE: {-best_fitness:.4f}") # Remember we minimized negative MSE
print("-" * 30)
print("Best Program Structure:")


Dataset Shape: X=(1000, 5), y=(1000,)
Initializing Population of 2000 individuals...
Compiling & Evaluating...

Optimization Complete!
Best MSE: 3451.2764
------------------------------
Best Program Structure:


In [4]:
best_solution_index = jnp.argmax(evaluated_pop.fitness)
best_genome = evaluated_pop[int(best_solution_index)]
print(best_genome.render(config, OP_NAMES))

Row  | Expression                     | Raw
--------------------------------------------------
0    | v_0 = COS(x_2, x_4)            | [2 4]
1    | v_1 = DIV(x_1, v_0)            | [1 5]
2    | v_2 = MUL(v_1, v_0)            | [6 5]
3    | v_3 = SUB(v_1, x_0)            | [6 0]
4    | v_4 = COS(x_3, v_0)            | [3 5]
5    | v_5 = COS(x_2, v_1)            | [2 6]
6    | v_6 = MUL(v_0, x_0)            | [5 0]
7    | v_7 = SIN(x_0, x_3)            | [0 3]
8    | v_8 = COS(v_4, v_5)            | [ 9 10]
9    | v_9 = SIN(x_0, x_0)            | [0 0]
10   | v_10 = DIV(v_5, x_1)           | [10  1]
11   | v_11 = DIV(v_5, x_3)           | [10  3]
12   | v_12 = DIV(x_1, v_8)           | [ 1 13]
13   | v_13 = COS(v_7, v_10)          | [12 15]
14   | v_14 = SUB(v_6, x_3)           | [11  3]
15   | v_15 = SUB(v_14, v_2)          | [19  7]
16   | v_16 = DIV(v_6, v_7)           | [11 12]
17   | v_17 = DIV(v_8, v_11)          | [13 16]
18   | v_18 = ADD(v_6, x_4)           | [11  4]
19   | v_19

In [5]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Any, TypeVar, Generic

# Generic Types
G = TypeVar("G", bound="BaseGenome")
C = TypeVar("C")

# ==========================================
# 1. ABSTRACT BASE: MUTATION
# ==========================================
@struct.dataclass
class BaseMutation(Generic[G, C]):
    """
    Abstract Mutation Functor.
    """
    # --- STATIC PARAMS (Re-compile if changed) ---
    # Determines the output shape: (N_Offspring, ...)
    num_offspring: int = struct.field(pytree_node=False, default=1)

    def __call__(self, key: chex.PRNGKey, genome: G, config: C) -> G:
        """
        Applies mutation to produce 'num_offspring' children.
        Output Shape: (Num_Offspring, Genome_Size...)
        """
        # 1. Split keys for the static number of children
        keys = jax.random.split(key, self.num_offspring)
        
        # 2. Vectorize the single mutation logic
        # vmap over keys (axis 0), keep genome & config constant
        return jax.vmap(
            lambda k, g, c: self._mutate_one(k, g, c), 
            in_axes=(0, None, None)
        )(keys, genome, config)

    def _mutate_one(self, key: chex.PRNGKey, genome: G, config: C) -> G:
        """Abstract: Logic to produce EXACTLY ONE mutant."""
        raise NotImplementedError


# ==========================================
# 2. ABSTRACT BASE: CROSSOVER
# ==========================================
@struct.dataclass
class BaseCrossover(Generic[G, C]):
    """
    Abstract Crossover Functor.
    """
    # --- STATIC PARAMS (Re-compile if changed) ---
    num_offspring: int = struct.field(pytree_node=False, default=1)

    def __call__(self, key: chex.PRNGKey, p1: G, p2: G, config: C) -> G:
        """
        Combines two parents to produce 'num_offspring' children.
        Output Shape: (Num_Offspring, Genome_Size...)
        """
        keys = jax.random.split(key, self.num_offspring)
        
        # vmap over keys (axis 0), keep parents & config constant
        return jax.vmap(
            lambda k, a, b, c: self._cross_one(k, a, b, c),
            in_axes=(0, None, None, None)
        )(keys, p1, p2, config)

    def _cross_one(self, key: chex.PRNGKey, p1: G, p2: G, config: C) -> G:
        """Abstract: Logic to produce EXACTLY ONE child."""
        raise NotImplementedError


# ==========================================
# 3. CONCRETE: LINEAR MUTATION
# ==========================================
@struct.dataclass
class LinearMutation(BaseMutation[LinearGenome, LinearGenomeConfig]):
    """
    Standard Linear GP Mutation.
    """
    # --- DYNAMIC PARAMS (Annealable/Mutable) ---
    # These are treated as inputs to the kernel.
    op_rate: float  = 0.1   # Prob to flip Opcode
    arg_rate: float = 0.1   # Prob to flip Argument

    def _mutate_one(self, key: chex.PRNGKey, genome: LinearGenome, config: LinearGenomeConfig) -> LinearGenome:
        k_op, k_arg, k_noise = jax.random.split(key, 3)

        # 1. Generate Boolean Masks (Where to mutate?)
        # Uses self.op_rate / self.arg_rate dynamically
        mask_ops = jax.random.bernoulli(k_op, self.op_rate, genome.ops.shape)
        mask_args = jax.random.bernoulli(k_arg, self.arg_rate, genome.args.shape)

        # 2. Generate Random Noise (New values)
        # Note: We reuse random generation logic inline for speed
        noise_ops = jax.random.randint(k_noise, genome.ops.shape, 0, config.num_ops)
        
        # Args can point up to (Inputs + Length)
        # Autocorrect will fix forward-references later
        max_mem = config.num_inputs + config.length
        # Re-split k_noise for args
        k_noise_args = jax.random.split(k_noise)[0]
        noise_args = jax.random.randint(k_noise_args, genome.args.shape, 0, max_mem)

        # 3. Apply Changes (If mask is True, take Noise, else take Old)
        new_ops = jnp.where(mask_ops, noise_ops, genome.ops)
        new_args = jnp.where(mask_args, noise_args, genome.args)

        # 4. Construct & Repair
        # Enforce Topological Validity (DAG)
        return genome.replace(ops=new_ops, args=new_args).autocorrect(config)


# ==========================================
# 4. CONCRETE: LINEAR CROSSOVER
# ==========================================
@struct.dataclass
class LinearCrossover(BaseCrossover[LinearGenome, LinearGenomeConfig]):
    """
    Uniform Crossover (Coin Flip Mixing).
    """
    # --- DYNAMIC PARAMS ---
    mixing_ratio: float = 0.5  # 0.5 = Balanced mix

    def _cross_one(self, key: chex.PRNGKey, p1: LinearGenome, p2: LinearGenome, config: LinearGenomeConfig) -> LinearGenome:
        # 1. Generate Mixing Mask
        # True = From P1, False = From P2
        mask = jax.random.bernoulli(key, self.mixing_ratio, p1.ops.shape)

        # 2. Mix Opcodes
        child_ops = jnp.where(mask, p1.ops, p2.ops)

        # 3. Mix Arguments
        # Broadcast mask (L,) -> (L, Arity) to keep args coupled
        mask_expanded = mask[:, None]
        child_args = jnp.where(mask_expanded, p1.args, p2.args)

        # 4. Return
        # Uniform crossover between two valid topological parents 
        # is guaranteed to be valid, so no autocorrect needed.
        return p1.replace(ops=child_ops, args=child_args)

In [6]:
# 1. Setup
config = LinearGenomeConfig(length=10, num_inputs=2, num_ops=5)
key = jax.random.PRNGKey(42)
parent = LinearGenome.random_init(key, config)

# -----------------------------------------------------------
# SCENARIO A: Standard Evolution (1 Parent -> 1 Child)
# -----------------------------------------------------------
mut_standard = LinearMutation(num_offspring=1, op_rate=0.1, arg_rate=0.1)
child_batch = mut_standard(key, parent, config)

print(f"Standard Output Shape: {child_batch.ops.shape}")
# Output: (1, 10) -> Batch of size 1


# -----------------------------------------------------------
# SCENARIO B: "Explosive" Evolution (1 Parent -> 10 Mutants)
# (Useful for Evolution Strategies / ES)
# -----------------------------------------------------------
mut_explosive = LinearMutation(num_offspring=10, op_rate=0.2, arg_rate=0.2)
cloud_batch = mut_explosive(key, parent, config)

print(f"Explosive Output Shape: {cloud_batch.ops.shape}")
# Output: (10, 10) -> Batch of size 10 generated instantly


# -----------------------------------------------------------
# SCENARIO C: Twin Crossover (2 Parents -> 2 Children)
# -----------------------------------------------------------
cross_twin = LinearCrossover(num_offspring=2, mixing_ratio=0.5)
p1 = LinearGenome.random_init(key, config)
p2 = LinearGenome.random_init(key, config)

twins = cross_twin(key, p1, p2, config)
print(f"Twins Shape: {twins.ops.shape}")
# Output: (2, 10)

Standard Output Shape: (1, 10)
Explosive Output Shape: (10, 10)
Twins Shape: (2, 10)


In [7]:
import jax
import jax.numpy as jnp
import chex


def run_variation_demo():
    print("=== MalthusJax Variation Demo ===\n")
    
    # 1. SETUP
    # Length 5, 2 Inputs (x0, x1), 4 Ops (ADD, SUB, MUL, SIN)
    config = LinearGenomeConfig(length=5, num_inputs=2, num_ops=4)
    key = jax.random.PRNGKey(42)
    op_names = ["ADD", "SUB", "MUL", "SIN"]

    # 2. CREATE PARENTS
    k1, k2, k3 = jax.random.split(key, 3)
    p1 = LinearGenome.random_init(k1, config)
    p2 = LinearGenome.random_init(k2, config)

    print("--- Parent 1 ---")
    print(p1.render(config, op_names))
    
    # 3. TEST CROSSOVER (Static: 2 Children, Dynamic: 50% Mix)
    print("\n\n>>> Applying Twin Crossover (2 Offspring)...")
    crossover_op = LinearCrossover(num_offspring=2, mixing_ratio=0.5)
    
    # Run JIT-compiled
    @jax.jit
    def run_cross(k, a, b):
        return crossover_op(k, a, b, config)

    twins = run_cross(k3, p1, p2)
    
    print(f"Output Shape: {twins.ops.shape}") # Should be (2, 5)
    
    # Slice the batch to see Child 0
    child_0 = jax.tree_util.tree_map(lambda x: x[0], twins)
    print("\n--- Child 0 (Mixed Logic) ---")
    print(child_0.render(config, op_names))

    # 4. TEST MUTATION (Static: 3 Mutants, Dynamic: High Noise)
    print("\n\n>>> Applying Explosive Mutation (3 Mutants, High Rate)...")
    # High mutation rate to force changes
    mutation_op = LinearMutation(num_offspring=3, op_rate=0.5, arg_rate=0.5)
    
    @jax.jit
    def run_mut(k, g):
        return mutation_op(k, g, config)

    k_mut = jax.random.split(key)[0]
    mutants = run_mut(k_mut, p1)
    
    print(f"Output Shape: {mutants.ops.shape}") # Should be (3, 5)
    
    # Slice to see Mutant 0
    mutant_0 = jax.tree_util.tree_map(lambda x: x[0], mutants)
    print("\n--- Mutant 0 (Noisy & Autocorrected) ---")
    print(mutant_0.render(config, op_names))

# Run it
run_variation_demo()

=== MalthusJax Variation Demo ===

--- Parent 1 ---
Row  | Expression                     | Raw
--------------------------------------------------
0    | v_0 = SIN(x_0, x_1)            | [0 1]
1    | v_1 = SUB(x_0, x_0)            | [0 0]
2    | v_2 = SUB(v_1, v_1)            | [3 3]
3    | v_3 = SIN(v_2, v_0)            | [4 2]
4    | v_4 = SIN(x_1, x_1)            | [1 1]


>>> Applying Twin Crossover (2 Offspring)...
Output Shape: (2, 5)

--- Child 0 (Mixed Logic) ---
Row  | Expression                     | Raw
--------------------------------------------------
0    | v_0 = SIN(x_0, x_1)            | [0 1]
1    | v_1 = MUL(x_0, x_1)            | [0 1]
2    | v_2 = SUB(v_1, v_1)            | [3 3]
3    | v_3 = SIN(v_2, v_0)            | [4 2]
4    | v_4 = MUL(v_1, v_2)            | [3 4]


>>> Applying Explosive Mutation (3 Mutants, High Rate)...
Output Shape: (3, 5)

--- Mutant 0 (Noisy & Autocorrected) ---
Row  | Expression                     | Raw
--------------------------------

In [8]:
def run_jit_demo():
    print("=== JIT Compilation Verification ===\n")
    config = LinearGenomeConfig(length=5, num_inputs=2, num_ops=10)
    master_key = jax.random.PRNGKey(42)

    # --- STEP A: JIT-Compiled Initialization ---
    print("1. Compiling Initialization...", end=" ")
    
    @jax.jit
    def create_parent(key):
        return LinearGenome.random_init(key, config)
    
    # Run once to compile
    k1, k2, k3 = jax.random.split(master_key, 3)
    parent_1 = create_parent(k1)
    parent_2 = create_parent(k2)
    print("Done.")
    print(f"   Parent 1 Shape: {parent_1.ops.shape}")
    print(f"   > Row 0: {parent_1.render(config).splitlines()[0]}")

    # --- STEP B: JIT-Compiled Crossover (Twins) ---
    print("\n2. Compiling Twin Crossover (Static=2 Children)...", end=" ")
    
    # Instantiate operator OUTSIDE JIT (Configures static shapes)
    crossover_op = LinearCrossover(num_offspring=2, mixing_ratio=0.5)

    @jax.jit
    def run_crossover(key, p1, p2):
        return crossover_op(key, p1, p2, config)

    # Run execution
    offspring_batch = run_crossover(k3, parent_1, parent_2)
    print("Done.")
    
    # Check Result
    print(f"   Offspring Batch Shape: {offspring_batch.ops.shape} (Should be 2, 5)")
    # Access Child 0 from batch
    child_0 = jax.tree_util.tree_map(lambda x: x[0], offspring_batch)
    print(f"   > Child 0 Row 0: {child_0.render(config).splitlines()[0]}")

    # --- STEP C: JIT-Compiled Mutation (Explosive) ---
    print("\n3. Compiling Explosive Mutation (Static=5 Mutants)...", end=" ")
    
    # Instantiate operator with different Static Config
    mutation_op = LinearMutation(num_offspring=5, op_rate=0.2, arg_rate=0.2)

    @jax.jit
    def run_mutation(key, genome):
        return mutation_op(key, genome, config)

    k_mut = jax.random.split(master_key)[0]
    mutant_batch = run_mutation(k_mut, parent_1)
    print("Done.")
    
    # Check Result
    print(f"   Mutant Batch Shape: {mutant_batch.ops.shape} (Should be 5, 5)")
    
    # Verify Autocorrect worked (Row 0 must imply inputs only)
    mutant_0 = jax.tree_util.tree_map(lambda x: x[0], mutant_batch)
    print(f"   > Mutant 0 Row 0: {mutant_0.render(config).splitlines()[0]}")
    
    print("\n[SUCCESS] All operators compiled and ran on accelerator.")

run_jit_demo()

=== JIT Compilation Verification ===

1. Compiling Initialization... Done.
   Parent 1 Shape: (5,)
   > Row 0: Row  | Expression                     | Raw

2. Compiling Twin Crossover (Static=2 Children)... Done.
   Offspring Batch Shape: (2, 5) (Should be 2, 5)
   > Child 0 Row 0: Row  | Expression                     | Raw

3. Compiling Explosive Mutation (Static=5 Mutants)... Done.
   Mutant Batch Shape: (5, 5) (Should be 5, 5)
   > Mutant 0 Row 0: Row  | Expression                     | Raw

[SUCCESS] All operators compiled and ran on accelerator.


In [9]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Generic

# Generic Genome Type not needed here, strictly math!

@struct.dataclass
class BaseSelection:
    """
    Abstract Selection Operator.
    Operates purely on Fitness Arrays to return Indices.
    """
    # STATIC: How many parents do we want to pick?
    num_selections: int = struct.field(pytree_node=False)

    def __call__(self, key: chex.PRNGKey, fitness: chex.Array) -> chex.Array:
        """
        Args:
            key: RNG Key
            fitness: Shape (Pop_Size, ...) - Can be 1D or 2D (Symbiotic)
            
        Returns:
            Selected Indices: Shape (num_selections,) int32
        """
        raise NotImplementedError

In [10]:
@struct.dataclass
class SymbioticTournament(BaseSelection):
    # STATIC CONFIG
    tournament_size: int = struct.field(pytree_node=False, default=3)
    symbionts_per_genome: int = struct.field(pytree_node=False, default=3)

    def __call__(self, key: chex.PRNGKey, fitness_matrix: chex.Array) -> chex.Array:
        """
        Performs tournament selection on the 'unrolled' best atomic trees.
        
        Args:
            fitness_matrix: Shape (N, L) - Fitness of every atomic tree.
            
        Returns:
            indices: Shape (num_selections,) - Indices of the winning GENOMES (0..N-1).
        """
        N, L = fitness_matrix.shape
        
        # 1. Filter: Get the Top K Symbionts per Genome
        # We don't want all 50 instructions entering (too much garbage).
        # We pick the elite sub-components.
        # values shape: (N, K)
        top_k_values, _ = jax.lax.top_k(fitness_matrix, self.symbionts_per_genome)
        
        # 2. Flatten: Create the "Symbiont Pool"
        # We treat this as a population of size (N * K)
        pool_fitness = top_k_values.ravel() # Shape (N*K, )
        total_candidates = pool_fitness.shape[0]
        
        # 3. Tournament: Select Winning Symbionts
        
        # A. Pick Random Contenders (Indices into the pool)
        # Shape: (Num_Selections, Tournament_Size)
        k_tourn = jax.random.split(key)[0]
        contestants = jax.random.randint(
            k_tourn, 
            shape=(self.num_selections, self.tournament_size), 
            minval=0, 
            maxval=total_candidates
        )
        
        # B. Get their scores
        # We use 'take' to gather fitness values
        scores = jnp.take(pool_fitness, contestants)
        
        # C. Find the Winner (Argmax)
        # Shape: (Num_Selections,) -> Local index 0..T-1
        winner_local_idx = jnp.argmax(scores, axis=1)
        
        # D. Get the Winner's Pool Index
        # advanced indexing: pick the winner from the contestants row
        # vmap over the rows
        winner_pool_indices = jax.vmap(lambda row, idx: row[idx])(contestants, winner_local_idx)
        
        # 4. Map Back: Symbiont ID -> Genome ID
        # If we have 3 symbionts per genome:
        # Symbiont 0, 1, 2  -> Genome 0
        # Symbiont 3, 4, 5  -> Genome 1
        # Formula: Genome_ID = Symbiont_ID // K
        
        winner_genome_indices = winner_pool_indices // self.symbionts_per_genome
        
        return winner_genome_indices

In [11]:
pop.fitness.shape

(10,)

In [12]:
# 1. Setup
# Assume pop.fitness is (100, 50) -> 100 Genomes, 50 Atomic Trees each
selector = SymbioticTournament(
    num_selections=100,    # We want to select 100 parents
    tournament_size=4,     # Standard pressure
    symbionts_per_genome=5 # Let the top 5 atomic trees represent the genome
)

# 2. THE SELECTION STEP (Pure Math)
# "Here are the scores, who wins?"
# Returns: Array([0, 0, 5, 12, 99...])
selected_indices = selector(key, pop.fitness)

# 3. THE GATHERING STEP (Data Movement)
# "Okay, grab those genomes."
# Your BasePopulation.__getitem__ handles this automatically!
parents = pop[selected_indices]

# Now 'parents' is a LinearPopulation of size 100, ready for Crossover.

ValueError: not enough values to unpack (expected 2, got 1)