## Predator-Prey Cellular Automaton Model

Based on: Cattaneo, Dennunzio, and Farina (2006) - "A Full Cellular Automaton to Simulate Predator-Prey Systems"

This implementation extends the original model to support:
- 1 prey species
- Up to 2 predator species (user-configurable)
- Fully local cellular automaton rules (no Monte Carlo steps)
- Configurable initial populations

Rajnil Mukherjee
\
MS21213
\
Department of Physical Sciences

### The Framework

In [3]:
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
%matplotlib notebook

from enum import IntEnum
from dataclasses import dataclass
from typing import Tuple, List, Optional
import seaborn as sns 

from scipy.optimize import minimize
from typing import Optional
from datetime import datetime
import os

In [4]:
# ============================================================================
# CONFIGURATION AND CONSTANTS
# ============================================================================

class CellState(IntEnum):
    """
    Cell states in the automaton:
    - EMPTY: No organism
    - PREY: Contains a prey organism
    - PREDATOR1: Contains a predator of type 1
    - PREDATOR2: Contains a predator of type 2
    - EMPTY_AFTER_ATTACK: Temporary state after prey is eaten
    - PREDATOR1_FED: Temporary state for predator 1 that successfully hunted
    - PREDATOR2_FED: Temporary state for predator 2 that successfully hunted
    """
    EMPTY = 0
    PREY = 1
    PREDATOR1 = 2
    PREDATOR2 = 3
    EMPTY_AFTER_ATTACK = 4
    PREDATOR1_FED = 5
    PREDATOR2_FED = 6


@dataclass
class SpeciesParameters:
    """
    Parameters for each species in the simulation.
    
    Attributes:
        birth_prob: Probability of reproduction (0 to 1)
        death_prob: Probability of natural death (0 to 1)
        hunt_success_prob: Probability of successful hunt per prey in neighborhood
    """
    birth_prob: float
    death_prob: float
    hunt_success_prob: float = 0.0  # Only for predators


@dataclass
class SimulationConfig:
    """
    Configuration for the entire simulation.
    
    Attributes:
        grid_size: Tuple of (rows, cols) for the lattice
        num_predator_types: Number of predator types (1 or 2)
        initial_prey: Initial number of prey organisms
        initial_predator1: Initial number of predator type 1
        initial_predator2: Initial number of predator type 2
        prey_params: Parameters for prey species
        predator1_params: Parameters for predator type 1
        predator2_params: Parameters for predator type 2
        movement_radius: Radius for Moore neighborhood in movement phase
        use_enhanced_model: Whether to use the enhanced model with oscillations
    """
    grid_size: Tuple[int, int]
    num_predator_types: int
    initial_prey: int
    initial_predator1: int
    initial_predator2: int
    prey_params: SpeciesParameters
    predator1_params: SpeciesParameters
    predator2_params: SpeciesParameters
    movement_radius: int = 2
    use_enhanced_model: bool = True
    enhancement_function: str = "cosine"  # "cosine" or "exponential"


# ============================================================================
# CELLULAR AUTOMATON CLASS
# ============================================================================

class PredatorPreyCA:
    """
    Main Cellular Automaton class for predator-prey simulation.
    
    Implements the full CA model with:
    1. Reaction phase (attack and reproduction)
    2. Movement phase (fully local, no Monte Carlo)
    """
    
    def __init__(self, config: SimulationConfig):
        """
        Initialize the cellular automaton with given configuration.
        
        Args:
            config: SimulationConfig object with all parameters
        """
        self.config = config
        self.rows, self.cols = config.grid_size
        self.grid = np.zeros((self.rows, self.cols), dtype=int)
        
        # Population tracking
        self.prey_history = []
        self.predator1_history = []
        self.predator2_history = []
        self.time_step = 0
        
        # Initialize the grid with random positions
        self._initialize_populations()
        
    def _initialize_populations(self):
        """
        Randomly place initial populations on the grid.
        Ensures no overlap between organisms (exclusion principle).
        """
        total_cells = self.rows * self.cols
        total_organisms = (self.config.initial_prey + 
                          self.config.initial_predator1 + 
                          self.config.initial_predator2)
        
        if total_organisms > total_cells:
            raise ValueError("Total organisms exceed available cells!")
        
        # Get random positions without replacement
        positions = np.random.choice(total_cells, total_organisms, replace=False)
        
        # Place prey
        prey_positions = positions[:self.config.initial_prey]
        for pos in prey_positions:
            row, col = divmod(pos, self.cols)
            self.grid[row, col] = CellState.PREY
        
        # Place predator 1
        pred1_start = self.config.initial_prey
        pred1_end = pred1_start + self.config.initial_predator1
        pred1_positions = positions[pred1_start:pred1_end]
        for pos in pred1_positions:
            row, col = divmod(pos, self.cols)
            self.grid[row, col] = CellState.PREDATOR1
        
        # Place predator 2 if enabled
        if self.config.num_predator_types == 2:
            pred2_positions = positions[pred1_end:]
            for pos in pred2_positions:
                row, col = divmod(pos, self.cols)
                self.grid[row, col] = CellState.PREDATOR2
    
    # ========================================================================
    # NEIGHBORHOOD FUNCTIONS
    # ========================================================================
    
    def _get_von_neumann_neighbors(self, row: int, col: int) -> List[Tuple[int, int]]:
        """
        Get Von Neumann neighborhood (4 nearest neighbors) with periodic boundaries.
        
        Args:
            row, col: Cell coordinates
            
        Returns:
            List of (row, col) tuples for neighbors
        """
        neighbors = [
            ((row - 1) % self.rows, col),  # North
            ((row + 1) % self.rows, col),  # South
            (row, (col - 1) % self.cols),  # West
            (row, (col + 1) % self.cols),  # East
        ]
        return neighbors
    
    def _get_moore_neighborhood_quadrants(self, row: int, col: int, radius: int) \
            -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Get the four quadrants of Moore neighborhood (North, South, East, West).
        Used for determining movement direction.
        
        Args:
            row, col: Cell coordinates
            radius: Radius of Moore neighborhood
            
        Returns:
            Tuple of 4 arrays containing cell states in each quadrant
        """
        north_cells = []
        south_cells = []
        east_cells = []
        west_cells = []
        
        for i in range(1, radius + 1):
            for j in range(-i, i + 1):
                # North quadrant (above current row)
                r_north = (row - i) % self.rows
                c_north = (col + j) % self.cols
                north_cells.append(self.grid[r_north, c_north])
                
                # South quadrant (below current row)
                r_south = (row + i) % self.rows
                c_south = (col + j) % self.cols
                south_cells.append(self.grid[r_south, c_south])
                
                # East quadrant (right of current col)
                r_east = (row + j) % self.rows
                c_east = (col + i) % self.cols
                east_cells.append(self.grid[r_east, c_east])
                
                # West quadrant (left of current col)
                r_west = (row + j) % self.rows
                c_west = (col - i) % self.cols
                west_cells.append(self.grid[r_west, c_west])
        
        return (np.array(north_cells), np.array(south_cells), 
                np.array(east_cells), np.array(west_cells))
    
    def _count_neighbors(self, row: int, col: int, state: CellState, 
                        neighborhood: str = "von_neumann") -> int:
        """
        Count neighbors of a specific state in the neighborhood.
        
        Args:
            row, col: Cell coordinates
            state: CellState to count
            neighborhood: "von_neumann" or "moore"
            
        Returns:
            Count of neighbors with specified state
        """
        if neighborhood == "von_neumann":
            neighbors = self._get_von_neumann_neighbors(row, col)
            return sum(1 for r, c in neighbors if self.grid[r, c] == state)
        else:
            # Full Moore neighborhood
            count = 0
            for dr in [-1, 0, 1]:
                for dc in [-1, 0, 1]:
                    if dr == 0 and dc == 0:
                        continue
                    r = (row + dr) % self.rows
                    c = (col + dc) % self.cols
                    if self.grid[r, c] == state:
                        count += 1
            return count
    
    # ========================================================================
    # REACTION PHASE
    # ========================================================================
    
    def _attack_phase(self) -> np.ndarray:
        """
        Execute the attack sub-phase of the reaction step.
        
        Implements predator-prey interactions:
        - Prey can be killed by predators
        - Predators hunt prey based on neighborhood
        
        Returns:
            New grid state after attacks
        """
        new_grid = self.grid.copy()
        
        for row in range(self.rows):
            for col in range(self.cols):
                state = self.grid[row, col]
                
                # Process prey cells
                if state == CellState.PREY:
                    new_grid[row, col] = self._process_prey_attack(row, col)
                
                # Process predator 1 cells
                elif state == CellState.PREDATOR1:
                    new_grid[row, col] = self._process_predator_attack(
                        row, col, CellState.PREDATOR1,
                        self.config.predator1_params.hunt_success_prob
                    )
                
                # Process predator 2 cells
                elif state == CellState.PREDATOR2:
                    new_grid[row, col] = self._process_predator_attack(
                        row, col, CellState.PREDATOR2,
                        self.config.predator2_params.hunt_success_prob
                    )
        
        return new_grid
    
    def _process_prey_attack(self, row: int, col: int) -> CellState:
        """
        Process a prey cell during attack phase.
        """
        # Count predators in Von Neumann neighborhood
        num_pred1 = self._count_neighbors(row, col, CellState.PREDATOR1)
        num_pred2 = self._count_neighbors(row, col, CellState.PREDATOR2)
        total_predators = num_pred1 + num_pred2
        
        # Case 1: No predators nearby
        if total_predators == 0:
            # Enhanced model: check for overcrowding
            if self.config.use_enhanced_model:
                return self._process_prey_overcrowding(row, col)
            return CellState.PREY
        
        # Case 2: Predators present - check if prey survives attack
        # Prey survival probability decreases with more predators
        # Formula from paper: (1 - dp)^npt
        survival_prob = (1 - self.config.predator1_params.hunt_success_prob) ** num_pred1
        survival_prob *= (1 - self.config.predator2_params.hunt_success_prob) ** num_pred2
        
        if np.random.random() < survival_prob:
            # Prey survived the attack
            # Also check overcrowding even if survived attack!
            if self.config.use_enhanced_model:
                # Check if prey dies from overcrowding despite surviving predation
                return self._process_prey_overcrowding(row, col)
            return CellState.PREY
        else:
            # Prey killed by predator
            return CellState.EMPTY_AFTER_ATTACK
    
    def _process_prey_overcrowding(self, row: int, col: int) -> CellState:
        """
        Enhanced model: prey can die from overcrowding when no predators present.
        This helps create oscillatory dynamics similar to logistic equation.
        
        Args:
            row, col: Prey cell coordinates
            
        Returns:
            New state for the cell
        """
        # Count prey in Moore neighborhood of radius r
        radius = self.config.movement_radius
        total_prey = 0
        total_cells = (2 * radius + 1) ** 2
        
        for dr in range(-radius, radius + 1):
            for dc in range(-radius, radius + 1):
                r = (row + dr) % self.rows
                c = (col + dc) % self.cols
                if self.grid[r, c] == CellState.PREY:
                    total_prey += 1
        
        # Apply enhancement function
        bp = self.config.prey_params.birth_prob
        
        if self.config.enhancement_function == "cosine":
            f_bp = 1 - np.cos(np.pi / 2 * bp)
        else:  # exponential
            f_bp = 1 - np.exp(-np.e * bp)
        
        # Death probability increases with prey density
        death_prob = total_prey * f_bp / total_cells
        
        if np.random.random() < death_prob:
            return CellState.EMPTY_AFTER_ATTACK
        
        return CellState.PREY
    
    def _process_predator_attack(self, row: int, col: int, 
                                 predator_type: CellState,
                                 hunt_prob: float) -> CellState:
        """
        Process a predator cell during attack phase.
        
        Args:
            row, col: Predator cell coordinates
            predator_type: Type of predator (PREDATOR1 or PREDATOR2)
            hunt_prob: Hunting success probability
            
        Returns:
            New state for the cell
        """
        # Count prey in Von Neumann neighborhood
        num_prey = self._count_neighbors(row, col, CellState.PREY)
        
        if num_prey == 0:
            return predator_type
        
        # Hunt success probability increases with more prey
        hunt_success_prob = 1 - (1 - hunt_prob) ** num_prey
        
        if np.random.random() < hunt_success_prob:
            # Successfully hunted - mark as fed
            if predator_type == CellState.PREDATOR1:
                return CellState.PREDATOR1_FED
            else:
                return CellState.PREDATOR2_FED
        
        return predator_type
    
    def _reproduction_phase(self, grid: np.ndarray) -> np.ndarray:
        """
        Execute the reproduction sub-phase of the reaction step.
        
        Handles:
        - Prey reproduction in empty cells
        - Predator reproduction after successful hunt
        - Natural death of predators
        
        Args:
            grid: Grid state after attack phase
            
        Returns:
            New grid state after reproduction
        """
        new_grid = grid.copy()
        
        for row in range(self.rows):
            for col in range(self.cols):
                state = grid[row, col]
                
                # Prey cells remain prey
                if state == CellState.PREY:
                    new_grid[row, col] = CellState.PREY
                
                # Predator 1 cells
                elif state in [CellState.PREDATOR1, CellState.PREDATOR1_FED]:
                    new_grid[row, col] = self._process_predator_reproduction(
                        row, col, CellState.PREDATOR1,
                        self.config.predator1_params.death_prob
                    )
                
                # Predator 2 cells
                elif state in [CellState.PREDATOR2, CellState.PREDATOR2_FED]:
                    new_grid[row, col] = self._process_predator_reproduction(
                        row, col, CellState.PREDATOR2,
                        self.config.predator2_params.death_prob
                    )
                
                # Empty cells and cells emptied by attack
                elif state in [CellState.EMPTY, CellState.EMPTY_AFTER_ATTACK]:
                    new_grid[row, col] = self._process_empty_cell(row, col, grid)
        
        return new_grid
    
    def _process_predator_reproduction(self, row: int, col: int,
                                      predator_type: CellState,
                                      death_prob: float) -> CellState:
        """
        Process predator reproduction and death.
        
        Args:
            row, col: Cell coordinates
            predator_type: Type of predator
            death_prob: Natural death probability
            
        Returns:
            New state for the cell
        """
        # Predator can die naturally
        if np.random.random() < death_prob:
            return CellState.EMPTY
        
        return predator_type
    
    def _process_empty_cell(self, row: int, col: int, grid: np.ndarray) -> CellState:
        """
        Process empty cell for potential reproduction.
        
        Args:
            row, col: Cell coordinates
            grid: Current grid state
            
        Returns:
            New state for the cell
        """
        state = grid[row, col]
        neighbors = self._get_von_neumann_neighbors(row, col)
        
        # Count different types of neighbors
        num_prey = sum(1 for r, c in neighbors if grid[r, c] == CellState.PREY)
        num_pred1 = sum(1 for r, c in neighbors 
                       if grid[r, c] in [CellState.PREDATOR1, CellState.PREDATOR1_FED])
        num_pred2 = sum(1 for r, c in neighbors 
                       if grid[r, c] in [CellState.PREDATOR2, CellState.PREDATOR2_FED])
        
        # If cell was emptied by attack (prey was eaten)
        if state == CellState.EMPTY_AFTER_ATTACK:
            # Count fed predators in neighborhood
            num_fed_pred1 = sum(1 for r, c in neighbors 
                              if grid[r, c] == CellState.PREDATOR1_FED)
            num_fed_pred2 = sum(1 for r, c in neighbors 
                              if grid[r, c] == CellState.PREDATOR2_FED)
            total_fed = num_fed_pred1 + num_fed_pred2
            
            if total_fed == 0:
                return CellState.EMPTY
            
            # Predator reproduction based on successful hunts
            # Predator 1 has priority if both types fed
            if num_fed_pred1 > 0:
                birth_prob = self.config.predator1_params.birth_prob
                if np.random.random() > (1 - birth_prob) ** num_fed_pred1:
                    return CellState.PREDATOR1
            
            if num_fed_pred2 > 0:
                birth_prob = self.config.predator2_params.birth_prob
                if np.random.random() > (1 - birth_prob) ** num_fed_pred2:
                    return CellState.PREDATOR2
            
            return CellState.EMPTY
        
        # Regular empty cell - can have prey reproduction
        if num_pred1 + num_pred2 > 0 or num_prey == 0:
            return CellState.EMPTY
        
        # Prey reproduction
        birth_prob = self.config.prey_params.birth_prob
        if np.random.random() > (1 - birth_prob) ** num_prey:
            return CellState.PREY
        
        return CellState.EMPTY
    
    # ========================================================================
    # MOVEMENT PHASE
    # ========================================================================
    
    def _movement_phase(self, grid: np.ndarray) -> np.ndarray:
        """
        Execute the movement phase - fully local, no Monte Carlo.
        
        Movement rules:
        - Prey move away from predators
        - Predators move toward prey
        - Movement is to adjacent Von Neumann neighbors only
        
        Args:
            grid: Grid state after reaction phase
            
        Returns:
            New grid state after movement
        """
        # Normalize fed predator states
        normalized_grid = grid.copy()
        normalized_grid[grid == CellState.PREDATOR1_FED] = CellState.PREDATOR1
        normalized_grid[grid == CellState.PREDATOR2_FED] = CellState.PREDATOR2
        normalized_grid[grid == CellState.EMPTY_AFTER_ATTACK] = CellState.EMPTY
        
        # Calculate movement intentions for all cells
        movement_intentions = {}
        
        for row in range(self.rows):
            for col in range(self.cols):
                state = normalized_grid[row, col]
                
                if state == CellState.EMPTY:
                    continue
                
                # Determine preferred direction
                direction = self._get_preferred_direction(row, col, state, normalized_grid)
                
                if direction is not None:
                    movement_intentions[(row, col)] = direction
        
        # Execute movements (resolve conflicts)
        new_grid = self._execute_movements(normalized_grid, movement_intentions)
        
        return new_grid
    
    def _get_preferred_direction(self, row: int, col: int, 
                                 state: CellState, grid: np.ndarray) -> Optional[Tuple[int, int]]:
        """
        Determine preferred movement direction for an organism.
        
        Args:
            row, col: Cell coordinates
            state: Current cell state
            grid: Current grid
            
        Returns:
            Target (row, col) or None if no movement
        """
        radius = self.config.movement_radius
        
        # Get quadrant counts
        north, south, east, west = self._get_moore_neighborhood_quadrants(row, col, radius)
        
        if state == CellState.PREY:
            # Prey moves away from predators
            pred_counts = {
                'north': np.sum((north == CellState.PREDATOR1) | (north == CellState.PREDATOR2)),
                'south': np.sum((south == CellState.PREDATOR1) | (south == CellState.PREDATOR2)),
                'east': np.sum((east == CellState.PREDATOR1) | (east == CellState.PREDATOR2)),
                'west': np.sum((west == CellState.PREDATOR1) | (west == CellState.PREDATOR2)),
            }
            
            # No predators nearby - no movement
            if sum(pred_counts.values()) == 0:
                return None
            
            # Move to quadrant with fewest predators
            min_pred = min(pred_counts.values())
            candidates = [d for d, c in pred_counts.items() if c == min_pred]
            
        else:  # Predator
            # Predators move toward prey
            prey_counts = {
                'north': np.sum(north == CellState.PREY),
                'south': np.sum(south == CellState.PREY),
                'east': np.sum(east == CellState.PREY),
                'west': np.sum(west == CellState.PREY),
            }
            
            # No prey nearby - no movement
            if sum(prey_counts.values()) == 0:
                return None
            
            # Move to quadrant with most prey
            max_prey = max(prey_counts.values())
            candidates = [d for d, c in prey_counts.items() if c == max_prey]
        
        # Choose random direction among candidates
        chosen_direction = np.random.choice(candidates)
        
        # Map direction to actual neighbor cell
        direction_map = {
            'north': ((row - 1) % self.rows, col),
            'south': ((row + 1) % self.rows, col),
            'east': (row, (col + 1) % self.cols),
            'west': (row, (col - 1) % self.cols),
        }
        
        return direction_map[chosen_direction]
    
    def _execute_movements(self, grid: np.ndarray, 
                          intentions: dict) -> np.ndarray:
        """
        Execute all movements, resolving conflicts.
        
        If multiple organisms want to move to the same cell,
        one is chosen randomly.
        
        Args:
            grid: Current grid state
            intentions: Dictionary mapping (row, col) -> target (row, col)
            
        Returns:
            New grid after movements
        """
        new_grid = np.zeros_like(grid)
        
        # Group intentions by target cell
        target_to_sources = {}
        for source, target in intentions.items():
            # Only move to empty cells
            if grid[target] == CellState.EMPTY:
                if target not in target_to_sources:
                    target_to_sources[target] = []
                target_to_sources[target].append(source)
        
        # Track which cells have moved
        moved = set()
        
        # Resolve movements
        for target, sources in target_to_sources.items():
            # Choose random organism if multiple want same cell
            chosen_source = sources[np.random.randint(len(sources))]
            new_grid[target] = grid[chosen_source]
            moved.add(chosen_source)
        
        # Copy non-moving organisms
        for row in range(self.rows):
            for col in range(self.cols):
                if (row, col) not in moved and grid[row, col] != CellState.EMPTY:
                    if new_grid[row, col] == CellState.EMPTY:  # Cell not already occupied
                        new_grid[row, col] = grid[row, col]
        
        return new_grid
    
    # ========================================================================
    # SIMULATION CONTROL
    # ========================================================================
    
    def step(self):
        """
        Execute one complete time step of the CA.
        
        Consists of:
        1. Attack phase
        2. Reproduction phase
        3. Movement phase
        """
        # Reaction phase
        grid_after_attack = self._attack_phase()
        grid_after_reproduction = self._reproduction_phase(grid_after_attack)
        
        # Movement phase
        self.grid = self._movement_phase(grid_after_reproduction)
        
        # Update statistics
        self._update_statistics()
        self.time_step += 1
    
    def _update_statistics(self):
        """
        Update population statistics for current time step.
        """
        prey_count = np.sum(self.grid == CellState.PREY)
        pred1_count = np.sum(self.grid == CellState.PREDATOR1)
        pred2_count = np.sum(self.grid == CellState.PREDATOR2)
        
        self.prey_history.append(prey_count)
        self.predator1_history.append(pred1_count)
        self.predator2_history.append(pred2_count)
    
    def run(self, num_steps: int, verbose: bool = True):
        """
        Run the simulation for a specified number of steps.
        
        Args:
            num_steps: Number of time steps to simulate
            verbose: Whether to print progress
        """
        for step in range(num_steps):
            self.step()
            
            if verbose and (step + 1) % 10 == 0:
                print(f"Step {step + 1}/{num_steps}: "
                      f"Prey={self.prey_history[-1]}, "
                      f"Pred1={self.predator1_history[-1]}, "
                      f"Pred2={self.predator2_history[-1]}")
    
    def get_population_stats(self) -> dict:
        """
        Get current population statistics.
        
        Returns:
            Dictionary with population counts
        """
        return {
            'prey': self.prey_history[-1] if self.prey_history else 0,
            'predator1': self.predator1_history[-1] if self.predator1_history else 0,
            'predator2': self.predator2_history[-1] if self.predator2_history else 0,
            'time_step': self.time_step
        }
    



### Visualisation Functions

In [6]:
"""
Enhanced Visualization Module for Predator-Prey Cellular Automaton
Includes live animation with GIF export capability
"""

# ============================================================================
# ENHANCED VISUALIZATION FUNCTIONS
# ============================================================================

def create_live_animation_with_gif(ca: PredatorPreyCA, 
                                   num_frames: int = 300,
                                   interval: int = 50,
                                   save_gif: bool = True,
                                   gif_filename: Optional[str] = None,
                                   dpi: int = 100) -> FuncAnimation:
    """
    Create a comprehensive live animation of the CA evolution and save as GIF.
    
    Shows:
    1. Spatial grid evolution (left panel)
    2. Population time series (top right)
    3. Phase space trajectory (bottom right)
    
    Args:
        ca: PredatorPreyCA instance (will be reset)
        num_frames: Number of frames in animation
        interval: Milliseconds between frames
        save_gif: Whether to save animation as GIF
        gif_filename: Custom filename for GIF (auto-generated if None)
        dpi: Resolution of the animation
        
    Returns:
        FuncAnimation object
    """
    print("\n" + "="*70)
    print("CREATING LIVE ANIMATION")
    print("="*70)
    
    # Reset CA to initial state
    initial_config = ca.config
    ca.__init__(initial_config)
    
    # Create output directory if saving
    if save_gif:
        os.makedirs("output", exist_ok=True)
        if gif_filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            gif_filename = f"output/predator_prey_animation_{timestamp}.gif"
        print(f"Animation will be saved to: {gif_filename}")
    
    # Setup figure with custom layout
    fig = plt.figure(figsize=(18, 9))
    gs = GridSpec(2, 2, figure=fig, width_ratios=[1.2, 1], height_ratios=[1, 1],
                  hspace=0.3, wspace=0.3)
    
    # Create subplots
    ax_grid = fig.add_subplot(gs[:, 0])  # Left: full height for grid
    ax_timeseries = fig.add_subplot(gs[0, 1])  # Top right: time series
    ax_phase = fig.add_subplot(gs[1, 1])  # Bottom right: phase space
    
    # ========================================================================
    # GRID VISUALIZATION SETUP
    # ========================================================================
    
    # Enhanced color map
    colors = ['#F0F0F0',  # Empty - light gray
              '#2ECC71',  # Prey - green
              '#E74C3C',  # Predator 1 - red
              '#F39C12']  # Predator 2 - orange
    
    cmap = plt.cm.colors.ListedColormap(colors)
    bounds = [0, 1, 2, 3, 4]
    norm = plt.cm.colors.BoundaryNorm(bounds, cmap.N)
    
    # Initial grid display
    display_grid = _prepare_display_grid(ca.grid)
    im = ax_grid.imshow(display_grid, cmap=cmap, norm=norm, 
                       interpolation='nearest', aspect='equal')
    
    # Grid styling
    ax_grid.set_title('Spatial Distribution', fontsize=14, fontweight='bold', pad=10)
    ax_grid.set_xlabel('X Position', fontsize=11)
    ax_grid.set_ylabel('Y Position', fontsize=11)
    
    # Add custom legend
    legend_elements = [
        mpatches.Patch(facecolor=colors[0], edgecolor='black', label='Empty'),
        mpatches.Patch(facecolor=colors[1], edgecolor='black', label='Prey'),
        mpatches.Patch(facecolor=colors[2], edgecolor='black', label='Predator 1')
    ]
    if ca.config.num_predator_types == 2:
        legend_elements.append(
            mpatches.Patch(facecolor=colors[3], edgecolor='black', label='Predator 2')
        )
    
    ax_grid.legend(handles=legend_elements, loc='upper left', 
                  bbox_to_anchor=(0, -0.05), ncol=4, frameon=True, fontsize=10)
    
    # Add info text box
    info_text = ax_grid.text(0.02, 0.98, '', transform=ax_grid.transAxes,
                            fontsize=10, verticalalignment='top',
                            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
    
    # ========================================================================
    # TIME SERIES SETUP
    # ========================================================================
    
    line_prey, = ax_timeseries.plot([], [], 'o-', color='#2ECC71', 
                                    label='Prey', linewidth=2, markersize=3)
    line_pred1, = ax_timeseries.plot([], [], 's-', color='#E74C3C', 
                                     label='Predator 1', linewidth=2, markersize=3)
    line_pred2, = ax_timeseries.plot([], [], '^-', color='#F39C12', 
                                     label='Predator 2', linewidth=2, markersize=3)
    
    ax_timeseries.set_xlabel('Time Step', fontsize=11, fontweight='bold')
    ax_timeseries.set_ylabel('Population', fontsize=11, fontweight='bold')
    ax_timeseries.set_title('Population Dynamics', fontsize=14, fontweight='bold')
    ax_timeseries.legend(loc='best', fontsize=10, framealpha=0.9)
    ax_timeseries.grid(True, alpha=0.3, linestyle='--')
    
    # ========================================================================
    # PHASE SPACE SETUP
    # ========================================================================
    
    phase_line, = ax_phase.plot([], [], '-', color='#3498DB', 
                                linewidth=1.5, alpha=0.6)
    phase_start = ax_phase.scatter([], [], c='green', s=150, 
                                   marker='o', label='Start', 
                                   zorder=5, edgecolors='black', linewidths=2)
    phase_current = ax_phase.scatter([], [], c='red', s=150, 
                                     marker='*', label='Current', 
                                     zorder=6, edgecolors='black', linewidths=2)
    
    ax_phase.set_xlabel('Prey Population', fontsize=11, fontweight='bold')
    ax_phase.set_ylabel('Predator 1 Population', fontsize=11, fontweight='bold')
    ax_phase.set_title('Phase Space (Prey vs Predator 1)', 
                      fontsize=14, fontweight='bold')
    ax_phase.legend(loc='best', fontsize=10, framealpha=0.9)
    ax_phase.grid(True, alpha=0.3, linestyle='--')
    
    # ========================================================================
    # ANIMATION FUNCTIONS
    # ========================================================================
    
    def init():
        """Initialize animation."""
        # Set axis limits
        max_pop = max(ca.config.initial_prey, 
                     ca.config.initial_predator1,
                     ca.config.initial_predator2) * 1.5
        
        ax_timeseries.set_xlim(0, num_frames)
        ax_timeseries.set_ylim(0, max_pop)
        
        ax_phase.set_xlim(0, max_pop)
        ax_phase.set_ylim(0, max_pop)
        
        return (im, line_prey, line_pred1, line_pred2, 
                phase_line, phase_start, phase_current, info_text)
    
    def update(frame):
        """Update animation for each frame."""
        # Perform CA step (skip for first frame)
        if frame > 0:
            ca.step()
        
        # Update grid visualization
        display_grid = _prepare_display_grid(ca.grid)
        im.set_array(display_grid)
        
        # Update info text
        prey_count = ca.prey_history[-1] if ca.prey_history else 0
        pred1_count = ca.predator1_history[-1] if ca.predator1_history else 0
        pred2_count = ca.predator2_history[-1] if ca.predator2_history else 0
        
        info_str = (f'Step: {ca.time_step}\n'
                   f'Prey: {prey_count}\n'
                   f'Pred1: {pred1_count}')
        if ca.config.num_predator_types == 2:
            info_str += f'\nPred2: {pred2_count}'
        
        info_text.set_text(info_str)
        
        # Update time series
        time_steps = list(range(len(ca.prey_history)))
        line_prey.set_data(time_steps, ca.prey_history)
        line_pred1.set_data(time_steps, ca.predator1_history)
        if ca.config.num_predator_types == 2:
            line_pred2.set_data(time_steps, ca.predator2_history)
        
        # Auto-scale y-axis for time series if needed
        if ca.prey_history:
            max_current = max(max(ca.prey_history), 
                            max(ca.predator1_history))
            if ca.config.num_predator_types == 2:
                max_current = max(max_current, max(ca.predator2_history))
            
            current_ylim = ax_timeseries.get_ylim()[1]
            if max_current > current_ylim * 0.9:
                ax_timeseries.set_ylim(0, max_current * 1.2)
        
        # Update phase space
        if len(ca.prey_history) > 0:
            phase_line.set_data(ca.prey_history, ca.predator1_history)
            
            # Update start marker (only once)
            if frame == 0:
                phase_start.set_offsets([[ca.prey_history[0], 
                                        ca.predator1_history[0]]])
            
            # Update current position marker
            phase_current.set_offsets([[ca.prey_history[-1], 
                                      ca.predator1_history[-1]]])
            
            # Auto-scale phase space if needed
            max_prey = max(ca.prey_history)
            max_pred1 = max(ca.predator1_history)
            
            if max_prey > ax_phase.get_xlim()[1] * 0.9:
                ax_phase.set_xlim(0, max_prey * 1.2)
            if max_pred1 > ax_phase.get_ylim()[1] * 0.9:
                ax_phase.set_ylim(0, max_pred1 * 1.2)
        
        # Progress indicator
        if frame % 10 == 0:
            print(f"Frame {frame}/{num_frames} - "
                  f"Prey: {prey_count}, Pred1: {pred1_count}, Pred2: {pred2_count}")
        
        return (im, line_prey, line_pred1, line_pred2, 
                phase_line, phase_start, phase_current, info_text)
    
    # ========================================================================
    # CREATE ANIMATION
    # ========================================================================
    
    print(f"Generating {num_frames} frames...")
    
    anim = FuncAnimation(fig, update, frames=num_frames, 
                        init_func=init, interval=interval,
                        blit=True, repeat=True)
    
    # Save as GIF
    if save_gif:
        print(f"\nSaving animation as GIF (this may take a while)...")
        writer = PillowWriter(fps=1000//interval)
        anim.save(gif_filename, writer=writer, dpi=dpi)
        print(f"✓ Animation saved successfully to: {gif_filename}")
        
        # Get file size
        file_size = os.path.getsize(gif_filename) / (1024 * 1024)  # Convert to MB
        print(f"  File size: {file_size:.2f} MB")
    
    # Adjust layout
    plt.suptitle('Predator-Prey Cellular Automaton Evolution', 
                fontsize=16, fontweight='bold', y=0.98)
    

    
    return anim


def _prepare_display_grid(grid: np.ndarray) -> np.ndarray:
    """
    Prepare grid for display by mapping all states to display values.
    
    Args:
        grid: Raw CA grid
        
    Returns:
        Grid with states mapped to [0, 1, 2, 3]
    """
    display_grid = grid.copy()
    
    # Map predator 2 to display value 3
    display_grid[grid == CellState.PREDATOR2] = 3
    
    # Map temporary states back to their base states
    display_grid[grid == CellState.EMPTY_AFTER_ATTACK] = 0
    display_grid[grid == CellState.PREDATOR1_FED] = 2
    display_grid[grid == CellState.PREDATOR2_FED] = 3
    
    return display_grid


def create_snapshot_sequence(ca: PredatorPreyCA,
                            num_snapshots: int = 6,
                            steps_between: int = 50,
                            save_path: Optional[str] = None):
    """
    Create a sequence of snapshots showing evolution at different time points.
    
    Args:
        ca: PredatorPreyCA instance (will be reset)
        num_snapshots: Number of snapshots to take
        steps_between: Steps between each snapshot
        save_path: Path to save the figure
    """
    # Reset CA
    initial_config = ca.config
    ca.__init__(initial_config)
    
    # Setup figure
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()
    
    # Color map
    colors = ['#F0F0F0', '#2ECC71', '#E74C3C', '#F39C12']
    cmap = plt.cm.colors.ListedColormap(colors)
    bounds = [0, 1, 2, 3, 4]
    norm = plt.cm.colors.BoundaryNorm(bounds, cmap.N)
    
    for idx in range(num_snapshots):
        # Take snapshot
        display_grid = _prepare_display_grid(ca.grid)
        
        axes[idx].imshow(display_grid, cmap=cmap, norm=norm, interpolation='nearest')
        axes[idx].set_title(f'Time Step {ca.time_step}\n'
                          f'Prey: {ca.prey_history[-1] if ca.prey_history else 0}, '
                          f'Pred1: {ca.predator1_history[-1] if ca.predator1_history else 0}',
                          fontsize=11, fontweight='bold')
        axes[idx].axis('off')
        
        # Run steps for next snapshot
        if idx < num_snapshots - 1:
            for _ in range(steps_between):
                ca.step()
    
    plt.suptitle('Evolution Sequence', fontsize=16, fontweight='bold')
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Snapshot sequence saved to: {save_path}")
    
    plt.show()


def create_comparison_gif(configs: list, 
                         num_frames: int = 200,
                         interval: int = 50,
                         gif_filename: Optional[str] = None,
                         dpi: int = 80):
    """
    Create a comparison GIF showing multiple configurations side by side.
    
    Args:
        configs: List of SimulationConfig objects to compare
        num_frames: Number of frames
        interval: Milliseconds between frames
        gif_filename: Output filename
        dpi: Resolution
    """
    num_configs = len(configs)
    
    if gif_filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        gif_filename = f"output/comparison_{timestamp}.gif"
    
    os.makedirs("output", exist_ok=True)
    
    # Create CAs
    cas = [PredatorPreyCA(config) for config in configs]
    
    # Setup figure
    fig, axes = plt.subplots(1, num_configs, figsize=(6*num_configs, 6))
    if num_configs == 1:
        axes = [axes]
    
    # Color map
    colors = ['#F0F0F0', '#2ECC71', '#E74C3C', '#F39C12']
    cmap = plt.cm.colors.ListedColormap(colors)
    bounds = [0, 1, 2, 3, 4]
    norm = plt.cm.colors.BoundaryNorm(bounds, cmap.N)
    
    # Initialize images
    images = []
    info_texts = []
    
    for idx, (ax, ca) in enumerate(zip(axes, cas)):
        display_grid = _prepare_display_grid(ca.grid)
        im = ax.imshow(display_grid, cmap=cmap, norm=norm, interpolation='nearest')
        images.append(im)
        
        ax.set_title(f'Configuration {idx+1}', fontsize=12, fontweight='bold')
        ax.axis('off')
        
        info_text = ax.text(0.02, 0.98, '', transform=ax.transAxes,
                          fontsize=9, verticalalignment='top',
                          bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
        info_texts.append(info_text)
    
    def init():
        return images + info_texts
    
    def update(frame):
        if frame > 0:
            for ca in cas:
                ca.step()
        
        for idx, (ca, im, info_text) in enumerate(zip(cas, images, info_texts)):
            display_grid = _prepare_display_grid(ca.grid)
            im.set_array(display_grid)
            
            prey = ca.prey_history[-1] if ca.prey_history else 0
            pred1 = ca.predator1_history[-1] if ca.predator1_history else 0
            
            info_text.set_text(f'Step: {ca.time_step}\nPrey: {prey}\nPred1: {pred1}')
        
        if frame % 10 == 0:
            print(f"Frame {frame}/{num_frames}")
        
        return images + info_texts
    
    print(f"Creating comparison GIF with {num_configs} configurations...")
    
    anim = FuncAnimation(fig, update, frames=num_frames, 
                        init_func=init, interval=interval, blit=True)
    
    writer = PillowWriter(fps=1000//interval)
    anim.save(gif_filename, writer=writer, dpi=dpi)
    
    plt.close()
    
    file_size = os.path.getsize(gif_filename) / (1024 * 1024)
    print(f"✓ Comparison GIF saved to: {gif_filename}")
    print(f"  File size: {file_size:.2f} MB")
    
    return anim


# ============================================================================
# UPDATED MAIN FUNCTION WITH ENHANCED VISUALIZATION
# ============================================================================

def run_simulation():
    """
    Run a simulation with enhanced visualization options.
    """
    print("=" * 70)
    print("PREDATOR-PREY CELLULAR AUTOMATON SIMULATION")
    print("Based on Cattaneo, Dennunzio, and Farina (2006)")
    print("=" * 70)
    
    # Get user configuration
    print("\n--- SIMULATION CONFIGURATION ---")
    
    grid_size = int(input("Grid size (e.g., 100 for 100x100) [default=100]: ") or "100")
    
    num_predators = int(input("Number of predator types (1 or 2) [default=2]: ") or "2")
    
    if num_predators not in [1, 2]:
        print("Invalid input. Using 2 predator types.")
        num_predators = 2
    
    initial_prey = int(input(f"Initial prey population [default={grid_size**2 // 4}]: ") 
                      or str(grid_size**2 // 4))
    initial_pred1 = int(input(f"Initial predator 1 population [default={grid_size**2 // 10}]: ") 
                       or str(grid_size**2 // 10))
    
    if num_predators == 2:
        initial_pred2 = int(input(f"Initial predator 2 population [default={grid_size**2 // 20}]: ") 
                           or str(grid_size**2 // 20))
    else:
        initial_pred2 = 0
    
    num_steps = int(input("Number of simulation steps [default=200]: ") or "200")
    
    # Animation options
    print("\n--- ANIMATION OPTIONS ---")

    gif_frames = num_steps
    gif_interval = 50
    
    gif_frames = int(input(f"Number of frames for GIF [default={num_steps}]: ") 
                    or str(num_steps))
    gif_interval = int(input("Interval between frames in ms [default=50]: ") or "50")

    
    # Create configuration
    config = SimulationConfig(
        grid_size=(grid_size, grid_size),
        num_predator_types=num_predators,
        initial_prey=initial_prey,
        initial_predator1=initial_pred1,
        initial_predator2=initial_pred2,
        prey_params=SpeciesParameters(
            birth_prob=0.65,
            death_prob=0.0 #not considering prey natural death
        ),
        predator1_params=SpeciesParameters(
            birth_prob=0.35,
            death_prob=0.35,
            hunt_success_prob=0.65 #dp parameter
        ),
        predator2_params=SpeciesParameters(
            birth_prob=0.3,
            death_prob=0.4,
            hunt_success_prob=0.6
        ),
        movement_radius=2,
        use_enhanced_model=True,
        enhancement_function="cosine"
    )
    
    # Create CA
    ca = PredatorPreyCA(config)
    

    create_live_animation_with_gif(
        ca, 
        num_frames=gif_frames,
        interval=gif_interval,
        save_gif=True,
        dpi=100
    )

    
    # Display final statistics
    print("\n--- FINAL STATISTICS ---")
    stats = ca.get_population_stats()
    print(f"Prey: {stats['prey']}")
    print(f"Predator 1: {stats['predator1']}")
    if num_predators == 2:
        print(f"Predator 2: {stats['predator2']}")
    

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    create_snapshot_sequence(ca, num_snapshots=6, steps_between=50,
                            save_path=f"output/snapshots_{timestamp}.png")

    return ca


if __name__ == "__main__":
    # Run example simulation
    ca = run_simulation()
    
    print("\n" + "="*70)
    print("SIMULATION COMPLETE")
    print("="*70)

PREDATOR-PREY CELLULAR AUTOMATON SIMULATION
Based on Cattaneo, Dennunzio, and Farina (2006)

--- SIMULATION CONFIGURATION ---

--- ANIMATION OPTIONS ---

CREATING LIVE ANIMATION
Animation will be saved to: output/predator_prey_animation_20260219_051721.gif


<IPython.core.display.Javascript object>

Generating 200 frames...

Saving animation as GIF (this may take a while)...
Frame 0/200 - Prey: 0, Pred1: 0, Pred2: 0
Frame 10/200 - Prey: 36243, Pred1: 1483, Pred2: 231
Frame 20/200 - Prey: 40537, Pred1: 697, Pred2: 14
Frame 30/200 - Prey: 41669, Pred1: 363, Pred2: 0
Frame 40/200 - Prey: 41925, Pred1: 191, Pred2: 0


KeyboardInterrupt: 