In [None]:
import jax
import jax.numpy as jnp
from flax import struct
from functools import partial
import chex
from typing import Any, Tuple, List, Type, TypeVar,Generic

# 1. Distance metric names (plain string constants, not an enum.Enum)
class DistanceMetric:
    """Namespace of string identifiers accepted as the ``metric`` argument of genome distances.

    The attributes are ordinary ``str`` values, so they compare equal to the
    matching literal (``DistanceMetric.HAMMING == "hamming"``).
    """

    HAMMING = "hamming"      # count of differing positions
    EUCLIDEAN = "euclidean"  # L2 norm of the element-wise difference
    MANHATTAN = "manhattan"  # L1 norm of the element-wise difference
    
# Type variable standing for "some BaseGenome subclass". Classmethods annotated
# with ``cls: Type[T]`` use it to state that they return an instance of that subclass.
T = TypeVar("T", bound="BaseGenome")

# 2. The Abstract Base Class
@struct.dataclass
class BaseGenome:
    """
    Abstract root of every genome type.

    The class is a flax ``struct.dataclass``; concrete genomes subclass it,
    declare their array fields, and override the hooks below (each of which
    raises ``NotImplementedError`` here). The population-level helpers at the
    bottom are implemented once in terms of those hooks.
    """

    # ------------------------------------------------------------------
    # Hooks to override in subclasses
    # ------------------------------------------------------------------
    @classmethod
    def random_init(cls: Type[T], rng_key: chex.PRNGKey, config: Any) -> T:
        """Build a single random genome from ``rng_key`` and ``config``.

        ``create_population`` calls this positionally as ``(key, config)``,
        so overrides must keep that argument order.
        """
        raise NotImplementedError

    def distance(self, other: "BaseGenome", metric: str = DistanceMetric.HAMMING) -> float:
        """Return the distance between ``self`` and ``other`` under ``metric``."""
        raise NotImplementedError

    def autocorrect(self, config: Any) -> "BaseGenome":
        """Return a copy of this genome repaired to satisfy the constraints in ``config``."""
        raise NotImplementedError

    @property
    def size(self) -> int:
        """Genome size; the meaning is defined by the subclass."""
        raise NotImplementedError

    @property
    def shape(self) -> Tuple[int, ...]:
        """Genome shape; the meaning is defined by the subclass."""
        raise NotImplementedError

    @property
    def len(self) -> int:
        """Alias for ``size``."""
        return self.size

    # ------------------------------------------------------------------
    # Shared, concrete helpers (inherited as-is)
    # ------------------------------------------------------------------
    @classmethod
    def create_population(cls: Type[T], rng_key: chex.PRNGKey, config: Any, pop_size: int) -> T:
        """Create ``pop_size`` random genomes stacked into one batched genome.

        ``rng_key`` is split into one key per individual; ``random_init`` is
        then vmapped over those keys while ``config`` is shared by all of them.
        """
        member_keys = jax.random.split(rng_key, pop_size)
        # Map over the keys (axis 0); broadcast the config (None).
        batched_init = jax.vmap(cls.random_init, in_axes=(0, None))
        return batched_init(member_keys, config)

    @classmethod
    def distance_matrix(cls: Type[T], population: T, metric: str = DistanceMetric.HAMMING) -> chex.Array:
        """Return the matrix of pairwise distances for a batched ``population``.

        Entry ``[i, j]`` is ``population[i].distance(population[j], metric)``.
        """
        def _pair(lhs: T, rhs: T) -> float:
            return lhs.distance(rhs, metric)

        # Inner map: one fixed genome against every member (varies the column).
        one_vs_all = jax.vmap(_pair, in_axes=(None, 0))
        # Outer map: repeat for every member as the fixed genome (varies the row).
        all_vs_all = jax.vmap(one_vs_all, in_axes=(0, None))
        return all_vs_all(population, population)

In [None]:
# 3. BASE POPULATION (Abstract Interface)
@struct.dataclass
class BasePopulation(Generic[T]):
    """Abstract container for a batch of genomes plus their fitness values.

    Subclass contract:
      * declare a ``fitness`` dataclass field holding the fitness array;
      * override the ``size`` property;
      * override the ``init_random`` factory.

    ``fitness`` is deliberately NOT declared as a property on this base class.
    ``struct.dataclass`` produces frozen dataclasses whose generated
    ``__init__`` assigns fields through ``object.__setattr__``. An inherited
    setter-less property named ``fitness`` would intercept that assignment and
    raise ``AttributeError``, so a subclass with a ``fitness`` field could
    never be instantiated. The dataclass machinery would also pick the
    property object up as the field's default value.
    """

    @property
    def size(self) -> int:
        """Number of individuals in the population (override in subclasses)."""
        raise NotImplementedError

    @classmethod
    def init_random(cls, key: chex.PRNGKey, config: Any, size: int) -> "BasePopulation[T]":
        """Factory: create a population object (genes + initial fitness) of ``size`` individuals."""
        raise NotImplementedError

In [3]:
from typing import Optional
import numpy as np

@struct.dataclass
class LinearGenomeConfig:
    """Layout parameters for a ``LinearGenome``.

    Addressable memory grows with the row index: row ``i`` may reference
    addresses ``0 .. num_inputs + i - 1``. The last row can therefore reach at
    most index ``num_inputs + length - 2``, i.e. its exclusive limit is
    ``num_inputs + length - 1``.
    """
    # L: number of atomic trees (instructions / rows)
    length: int
    # N: number of features in the database (input addresses 0 .. N-1)
    num_inputs: int
    # Number of available functions (ADD, SUB, ...)
    num_ops: int
    # Arguments stored per instruction
    max_arity: int = 2

@struct.dataclass
class LinearGenome(BaseGenome):
    """Linear genome: a list of instructions, each an opcode plus argument addresses.

    Addressing is topological. Addresses ``0 .. num_inputs-1`` denote input
    features and address ``num_inputs + j`` denotes the output of row ``j``.
    Row ``i`` may only reference addresses strictly below ``num_inputs + i``,
    i.e. inputs and earlier rows.
    """
    ops: chex.Array   # Shape (L,): opcode index of each row
    args: chex.Array  # Shape (L, Arity): argument addresses of each row

    @classmethod
    def random_init(cls, key: chex.PRNGKey, config: LinearGenomeConfig) -> "LinearGenome":
        """Sample a random, topologically valid genome.

        Args:
            key: PRNG key; split internally into one key for the opcodes and
                one per row for the arguments.
            config: supplies ``length``, ``num_inputs``, ``num_ops`` and
                ``max_arity``.

        Returns:
            A genome with ``ops`` of shape ``(length,)`` drawn from
            ``[0, num_ops)`` and ``args`` of shape ``(length, max_arity)``
            where row ``i`` is drawn from ``[0, num_inputs + i)``.
        """
        k_ops, k_args = jax.random.split(key, 2)

        # 1. Opcodes: uniform integers in [0, num_ops).
        ops = jax.random.randint(
            k_ops, shape=(config.length,),
            minval=0, maxval=config.num_ops
        )

        # 2. Arguments: every row gets its own exclusive upper bound.
        #    Row 0 (address N)   can see 0 .. N-1    -> limit N
        #    Row i (address N+i) can see 0 .. N+i-1  -> limit N+i
        # NOTE(review): with num_inputs == 0 the first row's range [0, 0) is
        # empty; confirm configs always have num_inputs >= 1.
        row_limits = jnp.arange(config.num_inputs, config.num_inputs + config.length)

        def generate_row_args(r_key, r_limit):
            # One row of argument addresses in [0, r_limit).
            return jax.random.randint(
                r_key, shape=(config.max_arity,), minval=0, maxval=r_limit
            )

        # One independent key per row, then generate all rows in a single vmap.
        row_keys = jax.random.split(k_args, config.length)
        args = jax.vmap(generate_row_args)(row_keys, row_limits)

        return cls(ops=ops, args=args)

    def autocorrect(self, config: LinearGenomeConfig) -> "LinearGenome":
        """
        Return a copy clipped into the valid range.

        Opcodes are clipped to ``[0, num_ops - 1]``. The arguments of row
        ``i`` are clipped to ``[0, num_inputs + i - 1]``, so a row never
        references itself or a later row.
        """
        # 1. Opcodes
        valid_ops = jnp.clip(self.ops, 0, config.num_ops - 1)

        # 2. Arguments: rebuild the per-row exclusive limits [N, N+1, ...].
        row_limits = jnp.arange(config.num_inputs, config.num_inputs + config.length)

        # (L,) -> (L, 1) so each limit broadcasts across that row's arguments.
        # The largest valid INDEX is limit - 1.
        max_valid_indices = row_limits[:, None] - 1

        # E.g. a mutation that made row 0 point at index 99 is pulled back here.
        valid_args = jnp.clip(self.args, 0, max_valid_indices)

        return self.replace(ops=valid_ops, args=valid_args)

    def distance(self, other: "LinearGenome", metric: str = DistanceMetric.HAMMING) -> float:
        """Distance between this genome and ``other`` over both ``ops`` and ``args``.

        Args:
            other: genome to compare against.
            metric: one of the ``DistanceMetric`` names:
                * ``"hamming"``: number of differing entries;
                * ``"euclidean"``: L2 norm of the element-wise differences;
                * ``"manhattan"``: L1 norm of the element-wise differences.

        Returns:
            A float32 scalar array.

        Raises:
            ValueError: if ``metric`` is not one of the supported names.
                Earlier versions silently returned ``0.0`` here, which made
                every pair of genomes look identical.
        """
        if metric == DistanceMetric.HAMMING:
            d_ops = jnp.sum(self.ops != other.ops)
            d_args = jnp.sum(self.args != other.args)
            return (d_ops + d_args).astype(jnp.float32)

        if metric in (DistanceMetric.EUCLIDEAN, DistanceMetric.MANHATTAN):
            d_ops = (self.ops - other.ops).astype(jnp.float32)
            d_args = (self.args - other.args).astype(jnp.float32)
            if metric == DistanceMetric.EUCLIDEAN:
                return jnp.sqrt(jnp.sum(d_ops**2) + jnp.sum(d_args**2))
            return jnp.sum(jnp.abs(d_ops)) + jnp.sum(jnp.abs(d_args))

        raise ValueError(f"Unsupported distance metric: {metric!r}")

    def render(self, config: LinearGenomeConfig, op_names: Optional[List[str]] = None) -> str:
        """
        Return a human-readable table of the genome, one line per row.

        Argument addresses below ``config.num_inputs`` are shown as inputs
        (``x_k``); all others are shown as the output of an earlier row
        (``v_j`` with ``j = address - num_inputs``).

        Args:
            config: supplies ``length`` (rows to print) and ``num_inputs``.
            op_names: optional opcode names. An opcode outside
                ``0 .. len(op_names)-1`` (or any opcode when ``op_names`` is
                not given) is shown as ``OP_<index>``.

        Example lines:
            0    | v_0 = ADD(x_1, x_0)            | [1 0] | ADD
            2    | v_2 = SIN(v_0, x_0)            | [2 0] | SIN
        """
        # Convert to NumPy once so the per-row loop below works on host-side values.
        ops_cpu = np.array(self.ops)
        args_cpu = np.array(self.args)

        lines = []
        lines.append(f"{'Row':<4} | {'Expression':<30} | {'Raw (Arg | Op)'}")
        lines.append("-" * 60)

        for i in range(config.length):
            # 1. Decode the operator. The lower bound keeps a negative opcode
            #    from wrapping around into the end of op_names.
            op_idx = int(ops_cpu[i])
            op_str = op_names[op_idx] if op_names and 0 <= op_idx < len(op_names) else f"OP_{op_idx}"

            # 2. Decode the arguments.
            #    Memory layout: [0 .. num_inputs-1] are inputs (x),
            #                   [num_inputs .. ]    are earlier rows (v).
            decoded_args = []
            for arg_idx in args_cpu[i]:
                arg_val = int(arg_idx)
                if arg_val < config.num_inputs:
                    decoded_args.append(f"x_{arg_val}")
                else:
                    # E.g. with num_inputs=2, address 2 is v_0 (row 0).
                    rel_idx = arg_val - config.num_inputs
                    decoded_args.append(f"v_{rel_idx}")

            # 3. Format, e.g. "v_0 = ADD(x_0, x_1)".
            args_str = ", ".join(decoded_args)
            current_var = f"v_{i}"

            expression = f"{current_var} = {op_str}({args_str})"
            raw_fmt = f"{args_cpu[i]} | {op_str}"

            lines.append(f"{i:<4} | {expression:<30} | {raw_fmt}")

        return "\n".join(lines)

    def __repr__(self):
        """
        Compact representation: the leading dimension of ``ops`` and its device.

        Kept short on purpose so logs are not flooded with full tables;
        use ``render`` for the detailed view.
        """
        return f"<LinearGenome(L={self.ops.shape[0]}, device={self.ops.device})>"

# Example usage: draw one random genome (10 rows, 5 inputs, 8 opcodes)
# from a fixed seed and show its raw arrays.
config = LinearGenomeConfig(num_ops=8, num_inputs=5, length=10)
key = jax.random.PRNGKey(42)
genome = LinearGenome.random_init(key, config)
print("Random Genome Ops:\n", genome.ops)
print("Random Genome Args:\n", genome.args)

Random Genome Ops:
 [4 3 3 1 6 3 4 5 4 3]
Random Genome Args:
 [[ 1  3]
 [ 4  1]
 [ 0  4]
 [ 4  3]
 [ 0  8]
 [ 0  1]
 [ 9  2]
 [ 2  9]
 [ 4  5]
 [10  8]]


In [4]:
# Operator vocabulary used to label opcodes in the rendered table.
op_names = ["ADD", "SUB", "MUL", "SIN"]

# Configuration whose opcode count matches that vocabulary (4 operators).
config = LinearGenomeConfig(length=10, num_inputs=2, num_ops=len(op_names))

# Random genome from a fixed seed.
key = jax.random.PRNGKey(42)
genome = LinearGenome.random_init(key, config)

# Human-readable table first ...
print(genome.render(config, op_names))
# ... then a visual gap and the compact default representation.
print("\n\n\n")
print(genome)

Row  | Expression                     | Raw (Arg | Op)
------------------------------------------------------------
0    | v_0 = ADD(x_1, x_0)            | [1 0] | ADD
1    | v_1 = SIN(x_1, x_1)            | [1 1] | SIN
2    | v_2 = SIN(v_0, x_0)            | [2 0] | SIN
3    | v_3 = SUB(x_0, x_1)            | [0 1] | SUB
4    | v_4 = MUL(v_1, v_0)            | [3 2] | MUL
5    | v_5 = SIN(v_1, v_0)            | [3 2] | SIN
6    | v_6 = ADD(v_2, v_3)            | [4 5] | ADD
7    | v_7 = SUB(v_6, v_4)            | [8 6] | SUB
8    | v_8 = ADD(v_6, v_0)            | [8 2] | ADD
9    | v_9 = SIN(v_7, v_3)            | [9 5] | SIN




<LinearGenome(L=10, device=TFRT_CPU_0)>


In [6]:
# Batched genome of 5 individuals, built with the `config` currently in scope.
population = LinearGenome.create_population(
    rng_key=jax.random.PRNGKey(0),
    config=config,
    pop_size=5,
)


In [None]:
# Batched argument array: (pop_size, length, max_arity).
population.args.shape

(5, 10, 2)

In [None]:
# 4. CONCRETE POPULATION (Linear Implementation)
@struct.dataclass
class LinearPopulation(BasePopulation[LinearGenome]):
    """Population of ``LinearGenome`` individuals with one fitness value each."""
    # Batched genome, as returned by LinearGenome.create_population
    genes: LinearGenome
    # Fitness per individual, shape (N,)
    fitness: chex.Array

    @property
    def size(self) -> int:
        """Number of individuals, read from the leading axis of ``fitness``."""
        return self.fitness.shape[0]

    @classmethod
    def init_random(cls, key: chex.PRNGKey, config: LinearGenomeConfig, size: int) -> "LinearPopulation":
        """Create ``size`` random genomes with every fitness initialised to ``-inf``.

        The genes come from ``LinearGenome.create_population`` (a vmapped
        ``random_init``).
        """
        return cls(
            genes=LinearGenome.create_population(key, config, size),
            fitness=jnp.full((size,), -jnp.inf),
        )