In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
import chex
from typing import Any, Type, TypeVar, Generic, ClassVar, Union, Iterator, List, Optional, Tuple

# --- Types & Constants ---
G = TypeVar("G", bound="BaseGenome")

class DistanceMetric:
    HAMMING = "hamming"
    EUCLIDEAN = "euclidean"
    MANHATTAN = "manhattan"

# ==========================================
# 1. CONFIGURATION
# ==========================================
@struct.dataclass
class LinearGenomeConfig:
    length: int           # L: Number of instructions
    num_inputs: int       # N: Number of input features
    num_ops: int          # Number of available functions
    max_arity: int = 2    # Arguments per instruction

# ==========================================
# 2. BASE GENOME (Abstract Individual)
# ==========================================
@struct.dataclass
class BaseGenome:
    """
    Abstract base class for a Single Individual.
    """
    # --- Abstract Interface ---
    @classmethod
    def random_init(cls: Type[G], key: chex.PRNGKey, config: Any) -> G:
        """Create ONE random genome."""
        raise NotImplementedError

    def distance(self, other: "BaseGenome", metric: str) -> float:
        raise NotImplementedError

    def autocorrect(self, config: Any) -> "BaseGenome":
        raise NotImplementedError

    @property
    def size(self) -> int:
        raise NotImplementedError

    # --- Shared Logic (Vectorization) ---
    @classmethod
    def create_population(cls: Type[G], key: chex.PRNGKey, config: Any, pop_size: int) -> G:
        """
        Vectorizes random_init to create a raw batch of genomes.
        Returns object with shape (Pop, ...).
        """
        keys = jax.random.split(key, pop_size)
        return jax.vmap(cls.random_init, in_axes=(0, None))(keys, config)


# ==========================================
# 4. BASE POPULATION (Abstract Container)
# ==========================================
@struct.dataclass
class BasePopulation(Generic[G]):
    """
    Abstract Container. 
    Implements List-like behavior and automated Vectorization.
    """
    genes: G
    fitness: chex.Array
    
    GENOME_CLS: ClassVar[Type[G]]

    # --- The "Kebab" Interface (List Behavior) ---
    def __len__(self) -> int:
        return int(self.fitness.shape[0])

    def __getitem__(self, key: Union[int, slice, chex.Array]) -> Union[G, "BasePopulation[G]"]:
        # Slice the genes Pytree automatically
        sliced_genes = jax.tree_util.tree_map(lambda x: x[key], self.genes)
        
        if isinstance(key, int):
            # Return Single Genome
            return sliced_genes
        else:
            # Return Sliced Population
            return self.replace(genes=sliced_genes, fitness=self.fitness[key])

    def __iter__(self) -> Iterator[G]:
        for i in range(len(self)):
            yield self[i]

    # --- Automated Logic (Vectorized) ---
    def autocorrect(self, config: Any) -> "BasePopulation[G]":
        """Autocorrects entire batch."""
        new_genes = jax.vmap(lambda g: g.autocorrect(config))(self.genes)
        return self.replace(genes=new_genes)

    def distance_matrix(self, metric: str = "hamming") -> chex.Array:
        """Computes N x N distance matrix."""
        pair_fn = lambda g1, g2: g1.distance(g2, metric)
        return jax.vmap(jax.vmap(pair_fn, in_axes=(None, 0)), in_axes=(0, None))(self.genes, self.genes)



In [2]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from functools import partial
from typing import Any, Tuple, Callable, TypeVar, Generic

# --- B. The Abstract Evaluator ---
G = TypeVar("G", bound="BaseGenome")
C = TypeVar("C") # Config Type
D = TypeVar("D") # Data Type

@struct.dataclass
class BaseEvaluator(Generic[G, C, D]):
    config: C

    def evaluate(self, genome: G, data: D) -> float:
        """Abstract: Returns scalar fitness for one genome."""
        raise NotImplementedError

    def evaluate_population(self, population: "BasePopulation[G]", data: D) -> "BasePopulation[G]":
        """Auto-Vectorization: Pop x Data -> Pop (with Fitness)"""
        # vmap over genes (axis 0), keep self/data constant
        fitness_scores = jax.vmap(self.evaluate, in_axes=(0, None))(population.genes, data)
        return population.replace(fitness=fitness_scores)

# --- C. The Concrete Linear Evaluator ---
# Data Type: Tuple(Inputs, Targets)


In [3]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Any, TypeVar, Generic

# Generic Types
G = TypeVar("G", bound="BaseGenome")
C = TypeVar("C")

# ==========================================
# 1. ABSTRACT BASE: MUTATION
# ==========================================
@struct.dataclass
class BaseMutation(Generic[G, C]):
    """
    Abstract Mutation Functor.
    """
    # --- STATIC PARAMS (Re-compile if changed) ---
    # Determines the output shape: (N_Offspring, ...)
    num_offspring: int = struct.field(pytree_node=False, default=1)

    def __call__(self, key: chex.PRNGKey, genome: G, config: C) -> G:
        """
        Applies mutation to produce 'num_offspring' children.
        Output Shape: (Num_Offspring, Genome_Size...)
        """
        # 1. Split keys for the static number of children
        keys = jax.random.split(key, self.num_offspring)
        
        # 2. Vectorize the single mutation logic
        # vmap over keys (axis 0), keep genome & config constant
        return jax.vmap(
            lambda k, g, c: self._mutate_one(k, g, c), 
            in_axes=(0, None, None)
        )(keys, genome, config)

    def _mutate_one(self, key: chex.PRNGKey, genome: G, config: C) -> G:
        """Abstract: Logic to produce EXACTLY ONE mutant."""
        raise NotImplementedError


# ==========================================
# 2. ABSTRACT BASE: CROSSOVER
# ==========================================
@struct.dataclass
class BaseCrossover(Generic[G, C]):
    """
    Abstract Crossover Functor.
    """
    # --- STATIC PARAMS (Re-compile if changed) ---
    num_offspring: int = struct.field(pytree_node=False, default=1)

    def __call__(self, key: chex.PRNGKey, p1: G, p2: G, config: C) -> G:
        """
        Combines two parents to produce 'num_offspring' children.
        Output Shape: (Num_Offspring, Genome_Size...)
        """
        keys = jax.random.split(key, self.num_offspring)
        
        # vmap over keys (axis 0), keep parents & config constant
        return jax.vmap(
            lambda k, a, b, c: self._cross_one(k, a, b, c),
            in_axes=(0, None, None, None)
        )(keys, p1, p2, config)

    def _cross_one(self, key: chex.PRNGKey, p1: G, p2: G, config: C) -> G:
        """Abstract: Logic to produce EXACTLY ONE child."""
        raise NotImplementedError


In [4]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Generic

# Generic Genome Type not needed here, strictly math!

@struct.dataclass
class BaseSelection:
    """
    Abstract Selection Operator.
    Operates purely on Fitness Arrays to return Indices.
    """
    # STATIC: How many parents do we want to pick?
    num_selections: int = struct.field(pytree_node=False)

    def __call__(self, key: chex.PRNGKey, fitness: chex.Array) -> chex.Array:
        """
        Args:
            key: RNG Key
            fitness: Shape (Pop_Size, ...) - Can be 1D or 2D (Symbiotic)
            
        Returns:
            Selected Indices: Shape (num_selections,) int32
        """
        raise NotImplementedError

In [5]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Any, Tuple

# --- CONFIGURATION ---
@struct.dataclass
class SantaConfig:
    num_trees: int        # N
    min_pos: float = -100.0
    max_pos: float = 100.0
    
    # Precision scaling (optional, but good for stability)
    scale_factor: float = 1.0 

# --- THE CONTINUOUS GENOME ---
@struct.dataclass
class ContinuousGenome(BaseGenome):
    # Shape: (Num_Trees, 3) -> [x, y, theta]
    values: chex.Array  

    @classmethod
    def random_init(cls, key: chex.PRNGKey, config: SantaConfig) -> "ContinuousGenome":
        k1, k2 = jax.random.split(key)
        
        # 1. Generate X, Y in [-100, 100]
        pos = jax.random.uniform(
            k1, shape=(config.num_trees, 2), 
            minval=config.min_pos, maxval=config.max_pos
        )
        
        # 2. Generate Angle in [0, 360)
        deg = jax.random.uniform(
            k2, shape=(config.num_trees, 1), 
            minval=0.0, maxval=360.0
        )
        
        # 3. Combine
        values = jnp.concatenate([pos, deg], axis=1)
        return cls(values=values)

    def autocorrect(self, config: SantaConfig) -> "ContinuousGenome":
        """
        Enforces constraints:
        1. X, Y must be inside bounding box.
        2. Degrees must be wrapped modulo 360 (Periodic).
        """
        x = self.values[:, 0]
        y = self.values[:, 1]
        deg = self.values[:, 2]
        
        # Clip Position
        x = jnp.clip(x, config.min_pos, config.max_pos)
        y = jnp.clip(y, config.min_pos, config.max_pos)
        
        # Wrap Angle (0 to 360)
        deg = deg % 360.0
        
        # Reconstruct
        new_values = jnp.stack([x, y, deg], axis=1)
        return self.replace(values=new_values)

    def distance(self, other: "ContinuousGenome", metric: str = "euclidean") -> float:
        """Euclidean distance in the 3D configuration space."""
        diff = self.values - other.values
        # Note: Angle distance should technically be circular (min(|a-b|, 360-|a-b|))
        # But simple Euclidean is fine for a rough diversity metric.
        return jnp.linalg.norm(diff)

    @property
    def size(self) -> int:
        return self.values.shape[0]

    def __repr__(self):
        return f"<SantaGenome(N={self.values.shape[0]}, device={self.values.device})>"

In [6]:
@struct.dataclass
class GaussianMutation(BaseMutation[ContinuousGenome, SantaConfig]):
    # STATIC
    num_offspring: int = struct.field(pytree_node=False)
    
    # DYNAMIC (The Sigma / Step Size)
    pos_sigma: float = 1.0   # Move trees by ~1.0 unit
    deg_sigma: float = 5.0   # Rotate trees by ~5.0 degrees
    prob: float = 0.5        # Probability to move a specific tree

    def _mutate_one(self, key: chex.PRNGKey, genome: ContinuousGenome, config: SantaConfig) -> ContinuousGenome:
        k_mask, k_pos, k_deg = jax.random.split(key, 3)
        
        # 1. Mask: Which trees do we move?
        # Shape: (N,) broadcasted to (N, 1)
        mask = jax.random.bernoulli(k_mask, self.prob, (config.num_trees, 1))
        
        # 2. Noise Generation
        noise_pos = jax.random.normal(k_pos, (config.num_trees, 2)) * self.pos_sigma
        noise_deg = jax.random.normal(k_deg, (config.num_trees, 1)) * self.deg_sigma
        
        full_noise = jnp.concatenate([noise_pos, noise_deg], axis=1)
        
        # 3. Apply
        # New = Old + (Mask * Noise)
        new_values = genome.values + (mask * full_noise)
        
        # 4. Repair (Clip bounds, Wrap angles)
        return genome.replace(values=new_values).autocorrect(config)

In [7]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Tuple, Any

# ==========================================
# 1. CONFIGURATION
# ==========================================
@struct.dataclass
class SantaConfig:
    num_trees: int
    min_pos: float = -100.0
    max_pos: float = 100.0
    scale_factor: float = 1.0

# ==========================================
# 2. THE GEOMETRIC ENGINE
# ==========================================
@struct.dataclass
class JaxSantaEvaluator(BaseEvaluator[ContinuousGenome, SantaConfig, Any]):
    """
    A Differentiable Geometric Evaluator for JAX.
    Replaces Shapely with Vectorized Linear Algebra.
    """
    # STATIC: The Local Coordinates of the Tree Polygon (V, 2)
    # We bake these into the class so they are constant constants.
    base_vertices: chex.Array = struct.field(pytree_node=True)

    @classmethod
    def create(cls, config: SantaConfig) -> "JaxSantaEvaluator":
        """Factory to initialize the tree shape."""
        # Define the exact vertices from the Santa 2025 code
        # (Normalized, without the massive 1e18 scale for now)
        trunk_w, trunk_h = 0.15, 0.2
        base_w, mid_w, top_w = 0.7, 0.4, 0.25
        tip_y, tier1_y, tier2_y, base_y = 0.8, 0.5, 0.25, 0.0
        trunk_btm = -trunk_h

        # Shape definition (Counter-Clockwise winding)
        verts = [
            [0.0, tip_y],                          # Tip
            [top_w/2, tier1_y], [top_w/4, tier1_y], # Top Tier Right
            [mid_w/2, tier2_y], [mid_w/4, tier2_y], # Mid Tier Right
            [base_w/2, base_y],                    # Base Right
            [trunk_w/2, base_y], [trunk_w/2, trunk_btm], # Trunk Right
            [-trunk_w/2, trunk_btm], [-trunk_w/2, base_y], # Trunk Left
            [-base_w/2, base_y],                   # Base Left
            [-mid_w/4, tier2_y], [-mid_w/2, tier2_y], # Mid Tier Left
            [-top_w/4, tier1_y], [-top_w/2, tier1_y]  # Top Tier Left
        ]
        return cls(config=config, base_vertices=jnp.array(verts))

    # --- CORE MATH (Differentiable) ---

    def transform_vertices(self, genome: ContinuousGenome) -> chex.Array:
        """
        Transforms local tree vertices to global world coordinates.
        Output: (Num_Trees, Num_Vertices, 2)
        """
        # 1. Unpack Genome
        # pos: (N, 2), theta: (N, 1) - degrees
        pos = genome.values[:, :2]
        theta_deg = genome.values[:, 2:3]
        theta_rad = jnp.deg2rad(theta_deg)

        # 2. Build Rotation Matrices (N, 2, 2)
        # [[cos, -sin], [sin, cos]]
        c = jnp.cos(theta_rad)
        s = jnp.sin(theta_rad)
        # Construct rotation matrix for each tree
        # Shape manipulation to get (N, 2, 2)
        row1 = jnp.concatenate([c, -s], axis=1) # (N, 2)
        row2 = jnp.concatenate([s, c], axis=1)  # (N, 2)
        rot_mats = jnp.stack([row1, row2], axis=1) 

        # 3. Apply Rotation
        # Vertices (V, 2) -> (2, V) for matmul
        v_T = self.base_vertices.T 
        # (N, 2, 2) @ (2, V) -> (N, 2, V)
        rotated_T = rot_mats @ v_T
        # Transpose back -> (N, V, 2)
        rotated = jnp.transpose(rotated_T, (0, 2, 1))

        # 4. Apply Translation
        # Add position (N, 1, 2) broadcasted to (N, V, 2)
        global_verts = rotated + pos[:, None, :]
        
        return global_verts

    def calculate_bounding_score(self, vertices: chex.Array) -> float:
        """
        Score = (Max_Dim_Side^2) / N
        """
        # Flatten all vertices to find global min/max
        # vertices shape: (N, V, 2) -> (N*V, 2)
        all_points = vertices.reshape(-1, 2)
        
        min_xy = jnp.min(all_points, axis=0)
        max_xy = jnp.max(all_points, axis=0)
        
        diffs = max_xy - min_xy
        side_x, side_y = diffs[0], diffs[1]
        
        # Max dimension squared (Bounding Square Area)
        max_side = jnp.maximum(side_x, side_y)
        area = max_side ** 2
        
        return area / self.config.num_trees

    def check_collisions_simple(self, genome: ContinuousGenome) -> float:
        """
        A fast heuristic collision check using Bounding Circles.
        Returns: Penalty score (0.0 = clean, >0 = collision)
        """
        # Simple Logic: 
        # If distance between centers < (Radius A + Radius B), overlap likely.
        # Tree approximate radius ~ 0.5 (Base width 0.7, Height ~1.0)
        TREE_RADIUS = 0.45 
        
        pos = genome.values[:, :2] # (N, 2)
        
        # Calculate pairwise distances (N, N)
        # We use squared distance to avoid Sqrt derivative issues at 0
        diff = pos[:, None, :] - pos[None, :, :]
        dist_sq = jnp.sum(diff**2, axis=-1)
        
        # Threshold: (2 * Radius)^2
        min_dist_sq = (2 * TREE_RADIUS) ** 2
        
        # Mask out self-collision (diagonal)
        # Add a large number to diagonal so it's always > min_dist
        eye = jnp.eye(self.config.num_trees) * 1e6
        dist_sq_masked = dist_sq + eye
        
        # Violation: Max(0, Threshold - Actual)
        # If Actual < Threshold, we add penalty
        penalties = jax.nn.relu(min_dist_sq - dist_sq_masked)
        
        return jnp.sum(penalties)

    # --- IMPLEMENT ABSTRACT EVALUATE ---
    
    def evaluate(self, genome: ContinuousGenome, data: Any = None) -> float:
        """
        Combined Objective Function.
        Fitness = - (Score + Collision_Penalty * 1000)
        """
        # 1. Transform Geometry
        verts = self.transform_vertices(genome)
        
        # 2. Objective: Smallest Box
        score = self.calculate_bounding_score(verts)
        
        # 3. Constraint: No Collisions
        # Note: Ideally replace 'simple' with SAT for final version
        penalty = self.check_collisions_simple(genome)
        
        # We minimize score, so fitness is negative
        # Heavy weight on penalty to force valid configurations
        total_loss = score + (penalty * 1000.0)
        
        return -total_loss

In [8]:
@struct.dataclass
class SantaPopulation(BasePopulation[ContinuousGenome]):
    """
    Concrete container for the Santa 2025 Optimization.
    Wraps a batch of ContinuousGenomes (Tree Configurations).
    """
    # 1. The Data
    # Shape inside genes: (Pop_Size, Num_Trees, 3)
    genes: ContinuousGenome
    
    # 2. The Scores
    # Shape: (Pop_Size,)
    fitness: chex.Array
    
    # 3. Link to the Genome Class (For the Factory)
    GENOME_CLS: ClassVar[Type[ContinuousGenome]] = ContinuousGenome

    @classmethod
    def init_random(cls, key: chex.PRNGKey, config: SantaConfig, size: int) -> "SantaPopulation":
        """
        Creates a population of random tree configurations.
        """
        # 1. Use the Base Factory (Vectorized)
        # Calls ContinuousGenome.create_population -> vmap(random_init)
        batched_genes = ContinuousGenome.create_population(key, config, size)
        
        # 2. Initialize Fitness (Negative Infinity)
        initial_fitness = jnp.full((size,), -jnp.inf)
        
        return cls(genes=batched_genes, fitness=initial_fitness)

In [9]:
# 1. Define the Challenge
config = SantaConfig(
    num_trees=50,       # We want to pack 50 trees
    min_pos=-100.0,     # The 100x100 meter bounds
    max_pos=100.0
)

# 2. Setup Randomness
master_key = jax.random.PRNGKey(2025)

# 3. Create the Instance
# This creates 10,000 distinct configurations of 50 trees each.
# Total Matrix Size: (10000, 50, 3) floats
xmas_tree_population = SantaPopulation.init_random(
    key=master_key, 
    config=config, 
    size=10000
)

# 4. Verify
print(f"Population Size: {len(xmas_tree_population)}")
print(f"Gene Shape: {xmas_tree_population.genes.values.shape}")
# Output: (10000, 50, 3)

Population Size: 10000
Gene Shape: (10000, 50, 3)


In [10]:
@struct.dataclass
class GaussianMutation(BaseMutation[ContinuousGenome, SantaConfig]):
    """
    Applies Gaussian noise to tree positions and rotations.
    """
    # STATIC
    num_offspring: int = struct.field(pytree_node=False, default=1)
    
    # DYNAMIC (Tunable Noise Levels)
    prob: float = 0.5        # Probability to nudge a specific tree
    pos_sigma: float = 2.0   # Standard deviation for X/Y (e.g., 2 meters)
    deg_sigma: float = 10.0  # Standard deviation for Angle (e.g., 10 degrees)

    def _mutate_one(self, key: chex.PRNGKey, genome: ContinuousGenome, config: SantaConfig) -> ContinuousGenome:
        k_mask, k_noise = jax.random.split(key)
        
        # 1. Mask: Which trees get nudged?
        # Shape: (N,) -> Broadcast to (N, 1) to cover x,y,theta columns
        mask = jax.random.bernoulli(k_mask, self.prob, (config.num_trees, 1))
        
        # 2. Noise: Generate different scales for Pos vs Angle
        # Noise Shape: (N, 3)
        # We split the key again to handle the noise generation
        noise_vals = jax.random.normal(k_noise, (config.num_trees, 3))
        
        # Scale columns independently: [x*pos_sigma, y*pos_sigma, deg*deg_sigma]
        # Create a scale vector: (1, 3)
        scales = jnp.array([[self.pos_sigma, self.pos_sigma, self.deg_sigma]])
        scaled_noise = noise_vals * scales
        
        # 3. Apply
        # New = Old + (Mask * Noise)
        new_values = genome.values + (mask * scaled_noise)
        
        # 4. Repair (Clip to bounds, wrap 360)
        return genome.replace(values=new_values).autocorrect(config)


@struct.dataclass
class ContinuousCrossover(BaseCrossover[ContinuousGenome, SantaConfig]):
    """
    Uniform Crossover for Continuous Genomes.
    Swaps entire trees (x, y, theta) between parents.
    """
    # STATIC
    num_offspring: int = struct.field(pytree_node=False, default=1)
    
    # DYNAMIC
    mixing_ratio: float = 0.5

    def _cross_one(self, key: chex.PRNGKey, p1: ContinuousGenome, p2: ContinuousGenome, config: SantaConfig) -> ContinuousGenome:
        # 1. Mask: Decide per tree (row)
        mask = jax.random.bernoulli(key, self.mixing_ratio, (config.num_trees,))
        
        # 2. Broadcast mask to columns (N,) -> (N, 3)
        mask_bcast = mask[:, None]
        
        # 3. Swap Rows
        child_values = jnp.where(mask_bcast, p2.values, p1.values)
        
        return p1.replace(values=child_values)

In [11]:
def run_continuous_variation_demo():
    print("=== Santa 2025 Operator Demo ===\n")
    
    # 1. SETUP (Tiny Config for readability)
    # 5 Trees, strict bounds [-10, 10] to test clipping easily
    config = SantaConfig(num_trees=5, min_pos=-10.0, max_pos=10.0)
    key = jax.random.PRNGKey(2025)

    # 2. INIT PARENT
    print("1. Creating Parent...")
    parent = ContinuousGenome.random_init(key, config)
    print(f"   Parent Values (First 2 trees):\n{parent.values[:2]}")

    # 3. TEST MUTATION (High Noise to force clipping)
    print("\n2. Applying High-Noise Mutation...")
    # pos_sigma=20.0 means noise can easily push values outside [-10, 10]
    mutator = GaussianMutation(num_offspring=1, prob=1.0, pos_sigma=20.0, deg_sigma=45.0)
    
    @jax.jit
    def run_mut(k, g): return mutator(k, g, config)
    
    # We use slice [0] because num_offspring=1 returns a batch of 1
    mutant = jax.tree_util.tree_map(lambda x: x[0], run_mut(jax.random.PRNGKey(1), parent))
    
    print(f"   Mutant Values:\n{mutant.values[:2]}")
    
    # VERIFY BOUNDS
    max_val = jnp.max(jnp.abs(mutant.values[:, :2]))
    print(f"   Max Position Absolute Value: {max_val:.2f} (Should be <= 10.0)")
    assert max_val <= 10.001, "Autocorrect Failed! Bounds violated."
    print("   [PASS] Autocorrect enforced bounds.")

    # 4. TEST CROSSOVER
    print("\n3. Applying Crossover...")
    parent2 = ContinuousGenome.random_init(jax.random.PRNGKey(2), config)
    crossover = ContinuousCrossover(num_offspring=1, mixing_ratio=0.5)
    
    @jax.jit
    def run_cross(k, a, b): return crossover(k, a, b, config)
    
    child = jax.tree_util.tree_map(lambda x: x[0], run_cross(jax.random.PRNGKey(3), parent, parent2))
    
    # Check if we have a mix
    # We count how many rows match P1 exactly
    matches_p1 = jnp.sum(jnp.all(child.values == parent.values, axis=1))
    print(f"   Rows inherited from Parent 1: {matches_p1} / 5")
    print("   [PASS] Crossover mixed parents.")

if __name__ == "__main__":
    run_continuous_variation_demo()

=== Santa 2025 Operator Demo ===

1. Creating Parent...
   Parent Values (First 2 trees):
[[-2.1969557e-01 -5.7836461e+00  3.3885193e+02]
 [-6.0277677e+00  3.3743358e+00  1.6650845e+02]]

2. Applying High-Noise Mutation...
   Mutant Values:
[[ 10.       -10.       300.43704 ]
 [-10.        -9.035362 209.87344 ]]
   Max Position Absolute Value: 10.00 (Should be <= 10.0)
   [PASS] Autocorrect enforced bounds.

3. Applying Crossover...
   Rows inherited from Parent 1: 1 / 5
   [PASS] Crossover mixed parents.


In [12]:
@struct.dataclass
class TournamentSelection(BaseSelection):
    """
    Standard Tournament Selection.
    Selects 'num_selections' parents by holding 'num_selections' tournaments.
    """
    # STATIC CONFIG
    tournament_size: int = struct.field(pytree_node=False, default=3)

    def __call__(self, key: chex.PRNGKey, fitness: chex.Array) -> chex.Array:
        """
        Args:
            key: RNG Key
            fitness: Shape (Pop_Size,) - Scalar fitness per individual
            
        Returns:
            indices: Shape (num_selections,) - Indices of the winning GENOMES.
        """
        pop_size = fitness.shape[0]
        
        # 1. Generate Random Contenders
        # Shape: (Num_Selections, Tournament_Size)
        # We need 'num_selections' tournaments, each with 'k' participants
        k_tourn = jax.random.split(key)[0]
        contestants = jax.random.randint(
            k_tourn, 
            shape=(self.num_selections, self.tournament_size), 
            minval=0, 
            maxval=pop_size
        )
        
        # 2. Look up Fitness
        # Gather the scores of every contestant
        # Shape: (Num_Selections, Tournament_Size)
        scores = jnp.take(fitness, contestants)
        
        # 3. Find Winners
        # Argmax gives us the LOCAL index (0..k-1) of the winner in each tournament
        winner_local_idx = jnp.argmax(scores, axis=1)
        
        # 4. Map back to Global Index
        # We select the actual Population Index of the winner
        winner_indices = jax.vmap(
            lambda row, idx: row[idx]
        )(contestants, winner_local_idx)
        
        return winner_indices

In [13]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Type, TypeVar, Generic, Any

# Generic variable for the Genome (Linear, Continuous, Tree...)
G = TypeVar("G", bound="BaseGenome")

@struct.dataclass
class EvolutionState(Generic[G]):
    """
    The State of the World.
    This Pytree is passed in/out of the JAX scan loop.
    """
    # 1. The Current Generation
    # Contains (N) Genomes and (N) Fitness scores
    population: "BasePopulation[G]"

    # 2. The Hall of Fame (Global Best)
    # We track this separately because the current population might 
    # drift away from the optimum due to high mutation/exploration.
    # This is a SINGLE genome (unbatched).
    best_genome: G
    best_fitness: chex.Array  # Scalar (float32)

    # 3. Randomness State
    # The key for the *next* generation's operations
    key: chex.PRNGKey

    # 4. Metadata
    gen_counter: chex.Array   # Scalar (int32)

    @classmethod
    def create(
        cls, 
        pop_class: Type["BasePopulation[G]"], 
        config: Any, 
        key: chex.PRNGKey, 
        pop_size: int
    ) -> "EvolutionState[G]":
        """
        Factory to initialize the state at Generation 0.
        """
        k_pop, k_state = jax.random.split(key)

        # A. Initialize Random Population
        # We rely on the Abstract Factory we built in BasePopulation
        initial_pop = pop_class.init_random(k_pop, config, pop_size)

        # B. Initialize Global Best (Placeholder)
        # We grab the first individual from the batch to serve as the placeholder "best".
        # We rely on the 'Kebab' __getitem__[0] we implemented earlier.
        first_genome = initial_pop[0]

        return cls(
            population=initial_pop,
            best_genome=first_genome,
            # Start with -inf so the first evaluation will overwrite this
            best_fitness=jnp.array(-jnp.inf, dtype=jnp.float32),
            key=k_state,
            gen_counter=jnp.array(0, dtype=jnp.int32)
        )

In [16]:
import jax
import jax.numpy as jnp
from flax import struct
import chex
from typing import Any, Tuple
import time

# --- ASSUMING PREVIOUS CLASSES ARE IMPORTED ---
# (In a real project, these would be: from core import ...)
# For this script to be standalone copy-paste, I will include the missing 
# GeneticEngine wiring we discussed.

@struct.dataclass
class GeneticEngine:
    # 1. The Logic Units
    evaluator: Any # Typed as BaseEvaluator
    selection: Any # Typed as BaseSelection
    crossover: Any # Typed as BaseCrossover
    mutation: Any  # Typed as BaseMutation
    
    # 2. Hyperparameters
    elitism_rate: float = 0.05

    def init_state(self, config: Any, key: chex.PRNGKey, pop_size: int) -> EvolutionState:
        return EvolutionState.create(SantaPopulation, config, key, pop_size)

    @jax.jit
    def step(self, state: EvolutionState) -> EvolutionState:
        """
        One Generation Step: Select -> Cross -> Mutate -> Evaluate -> Merge
        """
        pop = state.population
        pop_size = pop.fitness.shape[0]
        key = state.key
        
        # 1. Manage Randomness
        k_sel, k_cross, k_mut, k_next = jax.random.split(key, 4)

        # 2. Elitism (Preserve top 5%)
        num_elites = int(pop_size * self.elitism_rate)
        # We need negative fitness for top_k because we want largest values (least negative)
        _, elite_indices = jax.lax.top_k(pop.fitness, num_elites)
        elites = pop[elite_indices]

        # 3. Selection (Tournament)
        # We perform tournament selection to find parents for the next generation
        parent_indices = self.selection(k_sel, pop.fitness)
        parents = pop[parent_indices]

        # 4. Crossover (Mix Parents)
        # Shuffle parents to create pairs
        p1 = parents
        p2_indices = jax.random.permutation(k_cross, jnp.arange(pop_size))
        p2 = parents[p2_indices]
        
        offspring_genes = self.crossover(k_cross, p1.genes, p2.genes, self.evaluator.config)
        # Flatten batch dimension if necessary (here assuming num_offspring=1)
        offspring_genes = jax.tree_util.tree_map(lambda x: x.squeeze(1), offspring_genes)

        # 5. Mutation (Add Noise)
        # Mutate the crossed-over children
        mutant_genes = self.mutation(k_mut, offspring_genes, self.evaluator.config)
        mutant_genes = jax.tree_util.tree_map(lambda x: x.squeeze(1), mutant_genes)

        # 6. Merge & Truncate
        # We have Pop_Size mutants, but we need to make space for Elites
        num_mutants_needed = pop_size - num_elites
        
        # Take the first N mutants
        mutants_to_keep = jax.tree_util.tree_map(lambda x: x[:num_mutants_needed], mutant_genes)
        
        # Concatenate Elites + Mutants
        new_genes = jax.tree_util.tree_map(
            lambda e, m: jnp.concatenate([e, m], axis=0),
            elites.genes.values, # Access inner array
            mutants_to_keep.values
        )
        # Wrap back into Genome object
        new_genome_batch = ContinuousGenome(values=new_genes)
        
        # Create new Population container (Fitness not calculated yet)
        next_gen_pop = SantaPopulation(
            genes=new_genome_batch, 
            fitness=jnp.full((pop_size,), -jnp.inf)
        )

        # 7. Evaluate
        evaluated_pop = self.evaluator.evaluate_population(next_gen_pop, data=None)

        # 8. Update Hall of Fame (Global Best)
        best_idx = jnp.argmax(evaluated_pop.fitness)
        current_best_fit = evaluated_pop.fitness[best_idx]
        current_best_genome = evaluated_pop[best_idx]

        is_new_record = current_best_fit > state.best_fitness
        
        new_best_genome = jax.tree_util.tree_map(
            lambda old, new: jnp.where(is_new_record, new, old),
            state.best_genome,
            current_best_genome
        )
        new_best_fit = jnp.maximum(state.best_fitness, current_best_fit)

        return EvolutionState(
            population=evaluated_pop,
            best_genome=new_best_genome,
            best_fitness=new_best_fit,
            key=k_next,
            gen_counter=state.gen_counter + 1
        )

# --- EXECUTION SCRIPT ---

print("=== ðŸŽ… Santa 2025: Continuous Optimization Solver ===\n")

# 1. SETUP
# Challenge: Pack 50 trees into the smallest square possible
config = SantaConfig(num_trees=50, min_pos=-100.0, max_pos=100.0)

# Engine Hyperparameters
POP_SIZE = 2000       # 2,000 Parallel Configurations
GENERATIONS = 1000    # 1,000 Iterations

print(f"Configuration: {config.num_trees} Trees")
print(f"Population:    {POP_SIZE}")
print(f"Generations:   {GENERATIONS}")
print("-" * 40)

# 2. INITIALIZE COMPONENTS
# Evaluator (The Geometry Engine)
evaluator = JaxSantaEvaluator.create(config)

# Operators (The Evolution Logic)
# Using 'Explosive' parameters? No, let's use standard evolution first.
selection = TournamentSelection(num_selections=POP_SIZE, tournament_size=5)
crossover = ContinuousCrossover(num_offspring=1, mixing_ratio=0.5)
# Adaptive Mutation: 50% chance to move a tree, sigma=1.0m, rotation=5deg
mutation = GaussianMutation(num_offspring=1, prob=0.5, pos_sigma=1.0, deg_sigma=5.0)

engine = GeneticEngine(
    evaluator=evaluator,
    selection=selection,
    crossover=crossover,
    mutation=mutation,
    elitism_rate=0.05
)

# 3. INITIALIZE STATE
print("Initializing Population on GPU...", end=" ")
master_key = jax.random.PRNGKey(2025)
state = engine.init_state(config, master_key, POP_SIZE)

# Run one evaluation to get initial best
state = state.replace(population=evaluator.evaluate_population(state.population, None))
# Update best manually for gen 0
best_idx = jnp.argmax(state.population.fitness)
state = state.replace(
    best_genome=state.population[best_idx],
    best_fitness=state.population.fitness[best_idx]
)
print(f"Done.\nInitial Score: {-state.best_fitness:.4f}")

# 4. COMPILE THE LOOP
print("Compiling Evolution Loop (XLA)...", end=" ")

# We use scan to run the loop entirely on the accelerator
def scan_body(carry_state, _):
    new_state = engine.step(carry_state)
    # Log metrics (we return them to host)
    return new_state, new_state.best_fitness

# Run compilation
start_compile = time.time()
final_state, history = jax.lax.scan(scan_body, state, None, length=GENERATIONS)
# Block until ready to measure compile time
_ = final_state.best_fitness.block_until_ready()
print(f"Done ({time.time() - start_compile:.2f}s)")

# 5. ANALYSIS
print("-" * 40)
print("Optimization Finished.")

final_score = -final_state.best_fitness
print(f"Final Score: {final_score:.4f}")

# Calculate improvement
initial_score = -state.best_fitness
improvement = ((initial_score - final_score) / initial_score) * 100
print(f"Improvement: {improvement:.2f}%")

# 6. VISUALIZATION (Text Based)
best_genome = final_state.best_genome
print("\nBest Configuration (First 3 Trees):")
print("   X      |    Y     |  Angle")
print("-" * 30)
for i in range(3):
    row = best_genome.values[i]
    print(f"{row[0]:7.2f} | {row[1]:7.2f} | {row[2]:6.1f}Â°")



=== ðŸŽ… Santa 2025: Continuous Optimization Solver ===

Configuration: 50 Trees
Population:    2000
Generations:   1000
----------------------------------------
Initializing Population on GPU... Done.
Initial Score: 621.5767
Compiling Evolution Loop (XLA)... 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[]
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
The error occurred while tracing the function step at /var/folders/n8/08b2nd114jdfnsydb_4mj4fw0000gn/T/ipykernel_38319/581938587.py:27 for jit. This concrete value was not available in Python because it depends on the value of the argument self.elitism_rate.

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError