## Section 1: Helper functions

- **Neighbor-count helpers** for cellular automata:
  - `neighbor_count_vn(...)` counts *Von Neumann* neighbors (up/down/left/right).
  - `moore_count_radius1(...)` counts *Moore* neighbors in a 3×3 area (8 neighbors).
  - `moore_count_radius2(...)` counts *Moore* neighbors in a 5×5 area (24 neighbors).

- **Stage bar drawing**:
  - `_draw_stage_bar(...)` draws a vertical black bar on the left side of the grid to represent the stage.

- **Main animation function**:
  - `make_animation_with_exhaustion(...)` creates a 3-panel animation:
    1) Crowd **state** grid (CHILL / BOUNCE / MOSH)
    2) Crowd **exhaustion** heatmap (0–10) with a colorbar and a red mean marker
    3) Time series of **mean exhaustion** and **danger fraction (≥ 5)**

In [None]:
# Imports
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import time

from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import Image, display
from tqdm import tqdm
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch, Rectangle

def neighbor_count_vn(bool_grid):
    """
    Tally, for every cell, how many of its four Von Neumann neighbors
    (left, right, above, below) are occupied.

    Cells outside the grid are treated as unoccupied.

    Parameters
    ----------
    bool_grid : np.ndarray (bool or 0/1)
        2D grid whose truthy entries mark "occupied" cells.

    Returns
    -------
    np.ndarray (int)
        Array with the same shape as `bool_grid`; each entry is the number
        of occupied Von Neumann neighbors of that cell.
    """
    cells = bool_grid.astype(np.int8)

    # A zero border lets edge cells use the same slicing as interior cells.
    padded = np.pad(cells, 1, mode="constant", constant_values=0)

    # Each view is the padded grid shifted by one cell in one direction.
    left = padded[1:-1, :-2]
    right = padded[1:-1, 2:]
    above = padded[:-2, 1:-1]
    below = padded[2:, 1:-1]

    return left + right + above + below


def moore_count_radius1(occ):
    """
    Count occupied Moore neighbors at radius 1 for every cell.

    The count covers the 8 cells surrounding each position in its 3x3 box;
    the cell itself is excluded. Positions outside the grid count as empty.

    Parameters
    ----------
    occ : np.ndarray (bool or 0/1)
        2D occupancy grid.

    Returns
    -------
    np.ndarray (int)
        Same shape as `occ`; number of occupied radius-1 Moore neighbors.
    """
    cells = occ.astype(np.int8)
    n_rows, n_cols = cells.shape
    padded = np.pad(cells, 1, mode="constant", constant_values=0)

    # Accumulate all nine shifted views of the padded grid (center included).
    box_total = np.zeros_like(cells)
    for row_shift in range(3):
        for col_shift in range(3):
            box_total += padded[row_shift:row_shift + n_rows,
                                col_shift:col_shift + n_cols]

    # Drop each cell's own contribution so only the 8 neighbors remain.
    return box_total - cells


def moore_count_radius2(occ):
    """
    Count occupied Moore neighbors at radius 2 for every cell.

    Every occupied cell in the 5x5 box around a position is counted and the
    position's own value is then removed, so the largest possible result is
    24 (25 cells minus the center). Positions outside the grid count as empty.

    Parameters
    ----------
    occ : np.ndarray (bool or 0/1)
        2D occupancy grid.

    Returns
    -------
    np.ndarray (int)
        Same shape as `occ`; number of occupied radius-2 Moore neighbors.
    """
    cells = occ.astype(np.int8)
    n_rows, n_cols = cells.shape

    # Two cells of zero padding so the 5x5 window fits at the borders.
    padded = np.pad(cells, 2, mode="constant", constant_values=0)

    # The 5x5 box sum is separable: first add the five vertical shifts ...
    # (int16 accumulators keep the running totals away from int8 limits)
    vertical = np.zeros((n_rows, n_cols + 4), dtype=np.int16)
    for row_shift in range(5):
        vertical += padded[row_shift:row_shift + n_rows, :]

    # ... then add the five horizontal shifts of that intermediate result.
    box_total = np.zeros((n_rows, n_cols), dtype=np.int16)
    for col_shift in range(5):
        box_total += vertical[:, col_shift:col_shift + n_cols]

    # Exclude the center cell so only neighbors are counted.
    return box_total - cells


def _draw_stage_bar(ax, n, stage_bar_cols=6, label="STAGE"):
    """
    Paint a black vertical "stage" strip immediately left of an n x n grid.

    Purely decorative: it widens the axis to the left, fills that strip with
    a black rectangle, and writes a rotated white label in the middle of it.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis that already shows (or will show) the grid image.
    n : int
        Grid size (n x n).
    stage_bar_cols : int
        Width of the strip, measured in grid columns.
    label : str
        Text written vertically inside the strip.
    """
    stage_bar_cols = int(stage_bar_cols)

    # Image pixels are centered on integer coordinates, hence the 0.5 offsets.
    left_edge = -stage_bar_cols - 0.5
    top_edge = -0.5

    # Widen the view to the left; y runs top-to-bottom like the image rows.
    ax.set_xlim(left_edge, n - 0.5)
    ax.set_ylim(n - 0.5, top_edge)

    # Solid black strip spanning the full grid height.
    stage_patch = Rectangle(
        (left_edge, top_edge),
        stage_bar_cols,
        n,
        facecolor="black",
        edgecolor="black",
        linewidth=0
    )
    ax.add_patch(stage_patch)

    # Label sits at the horizontal and vertical center of the strip.
    label_x = left_edge + stage_bar_cols / 2
    label_y = (n - 1) / 2
    ax.text(
        label_x,
        label_y,
        label,
        color="white",
        ha="center",
        va="center",
        fontsize=14,
        fontweight="bold",
        rotation=90,
        clip_on=False
    )


def make_animation_with_exhaustion(
    sim,
    total_steps,
    display_stride=1,
    interval=90,
    stage_bar_cols=9,
    show_song_sections=False,
    gif_path=None,
    fps=None,
    dpi=120
):
    """
    Run the simulation and save a 3-panel GIF animation of it:
      (1) Crowd state grid
      (2) Exhaustion heatmap
      (3) Time series: mean exhaustion + danger fraction (≥ 5)

    Parameters
    ----------
    sim : PitCASimulator
        Simulator instance. It has attributes like:
        - sim.grid (state grid)
        - sim.exh (exhaustion grid)
        - sim.exh_mean_hist, sim.exh_danger_hist (time series lists)
        - sim.cmap, sim.legend_handles (for colors/legend)
        - sim.steps_per_level, sim.concert_sequence
        and methods like initialize(), update(), _energy_level(), _song_index().
        The simulator is (re)initialized by this function.
    total_steps : int
        Total number of simulation steps to run.
    display_stride : int
        How many simulation steps to advance per animation frame (min 1).
    interval : int (ms)
        Milliseconds between frames (animation playback speed).
    stage_bar_cols : int
        Width of stage bar (visual only).
    show_song_sections : bool
        If True and sim.concert_sequence exists, shade the time-series plot
        by song sections and label songs.
    gif_path : str or None
        Output file. When None, "assets/<name>_animation.gif" is used, where
        <name> is "concert" for a concert run or "low"/"medium"/"high" for
        the current song energy. Missing parent directories are created.
    fps : int or None
        Frames per second of the GIF. When None it is derived from `interval`.
    dpi : int
        Resolution passed to the animation writer.

    Returns
    -------
    IPython.display.Image
        Image object pointing at the GIF that was written to `gif_path`.
    """
    import os

    # Convert inputs and compute number of frames
    total_steps = int(total_steps)
    display_stride = max(1, int(display_stride))
    total_frames = int(np.ceil(total_steps / display_stride))

    # 3 panels: state, exhaustion, time-series
    fig = plt.figure(figsize=(18, 5))
    pbar = None

    try:
        ax_state = fig.add_subplot(1, 3, 1)
        ax_exh   = fig.add_subplot(1, 3, 2)
        ax_line  = fig.add_subplot(1, 3, 3)

        # Remove tick labels for the image panels
        for ax in [ax_state, ax_exh]:
            ax.set_xticks([])
            ax.set_yticks([])

        # Initialize sim
        sim.initialize(ax=ax_state)

        #######################
        # Panel 1: state grid #
        #######################
        im_state = ax_state.imshow(
            sim.grid,
            vmin=0,
            vmax=3,
            cmap=sim.cmap,
            interpolation="nearest"
        )
        _draw_stage_bar(ax_state, sim.n, stage_bar_cols=stage_bar_cols, label="STAGE")
        ax_state.legend(handles=sim.legend_handles, loc="upper right", framealpha=0.95)

        ###############################
        # Panel 2: exhaustion heatmap #
        ###############################
        # Mask exhaustion where there is no person (empty cells)
        occ = sim.grid > 0
        exh_masked = np.ma.array(sim.exh, mask=~occ)
        im_exh = ax_exh.imshow(exh_masked, vmin=0, vmax=10, interpolation="nearest")
        _draw_stage_bar(ax_exh, sim.n, stage_bar_cols=stage_bar_cols, label="STAGE")
        ax_exh.set_title("Crowd Exhaustion Level (0–10)")

        # Add colorbar for exhaustion (0–10)
        cbar = fig.colorbar(im_exh, ax=ax_exh, fraction=0.046, pad=0.04)

        # Red marker showing current mean exhaustion on the colorbar
        init_mean = float(sim.exh_mean_hist[-1]) if len(sim.exh_mean_hist) else 0.0
        mean_bar = cbar.ax.axhline(init_mean, color="red", linewidth=3.0, alpha=1.0, zorder=50)

        ##############################
        # Panel 3: time-series chart #
        ##############################
        ax_line.set_title("Crowd Exhaustion Level Over Time")
        ax_line.set_xlabel("Time Step")
        ax_line.set_ylabel("Crowd Exhaustion Level (0–10)")
        ax_line.set_ylim(0, 10)

        # Second y-axis for danger fraction to show two scales at once.
        # The label states the same threshold (5) as the rest of the
        # documentation for this plot.
        ax_danger = ax_line.twinx()
        ax_danger.set_ylabel("Danger Fraction (≥ 5)")
        ax_danger.set_ylim(0, 1)

        # Start with whatever history already exists after initialization
        x = np.arange(len(sim.exh_mean_hist))
        l_mean, = ax_line.plot(x, sim.exh_mean_hist, label="Mean", color="black")
        l_dang, = ax_danger.plot(x, sim.exh_danger_hist, label="Danger fraction", color="red")

        # Concert shading by song sections
        # These lists store artists/patches so we can return them for animation updates
        song_lines = []
        song_labels = []
        song_spans = []
        energy_legend = []

        sections_shown = show_song_sections and (sim.concert_sequence is not None)

        if sections_shown:
            # Each song lasts L steps in the simulation
            L = int(sim.steps_per_level)

            # Colors by song energy level (0 low, 1 medium, 2 high)
            span_color = {0: "#6baed6", 1: "#fdae6b", 2: "#fb6a4a"}
            span_alpha = 0.22
            bar_gray = "0.55"

            for i, E in enumerate(sim.concert_sequence):
                start = i * L
                end = (i + 1) * L

                # Shade the background for that song section on both y-axes
                sp1 = ax_line.axvspan(start, end, color=span_color[int(E)], alpha=span_alpha, zorder=0)
                sp2 = ax_danger.axvspan(start, end, color=span_color[int(E)], alpha=span_alpha, zorder=0)
                song_spans.extend([sp1, sp2])

                # Draw vertical bars to show song boundaries
                ln1 = ax_line.axvline(start, color=bar_gray, linewidth=1, zorder=1)
                ln2 = ax_danger.axvline(start, color=bar_gray, linewidth=1, zorder=1)
                song_lines.extend([ln1, ln2])

                # Label each song in the time-series panel
                mid = start + 0.5 * L
                txt = ax_line.text(
                    mid, 0.15,
                    f"song {i+1}",
                    ha="center", va="bottom",
                    fontsize=8,
                    rotation=90,
                    color="black",
                    alpha=1.0,
                    zorder=50,
                    clip_on=True
                )
                song_labels.append(txt)

            # Give the y-limits a bit of margin so labels have space
            ax_line.margins(y=0.10)

            # Legend entries explaining the shaded song sections
            energy_legend = [
                Patch(facecolor=span_color[0], edgecolor="none", alpha=1.0, label="Low energy song"),
                Patch(facecolor=span_color[1], edgecolor="none", alpha=1.0, label="Medium energy song"),
                Patch(facecolor=span_color[2], edgecolor="none", alpha=1.0, label="High energy song"),
            ]

        # legend for time-series panel
        handles = [l_mean, l_dang] + energy_legend
        labels  = [h.get_label() for h in handles]

        # Make room on the right side so the legend can sit outside the plot
        fig.tight_layout(rect=[0, 0, 0.80, 1])

        # Place legend outside plot so it doesn't cover lines
        leg = ax_line.legend(
            handles=handles,
            labels=labels,
            loc="center left",
            bbox_to_anchor=(1.5, 0.5),
            framealpha=1.0
        )
        leg.get_frame().set_facecolor("white")
        leg.get_frame().set_edgecolor("0.2")
        leg.set_zorder(100)

        # Progress bar for rendering frames
        pbar = tqdm(total=total_frames)

        def _title_for(sim):
            """
            Build a title describing the current song energy + index.
            """
            E = sim._energy_level()
            name = {0: "Low energy song", 1: "Medium energy song", 2: "High energy song"}[E]
            if sim.concert_sequence is not None:
                song_i = min(sim._song_index() + 1, len(sim.concert_sequence))
                return f"{name} | song {song_i}/{len(sim.concert_sequence)} | t={sim.t}"
            else:
                return f"{name} | t={sim.t}"

        # Set initial title before animation starts
        ax_state.set_title(_title_for(sim))

        # Bookkeeping that makes _update idempotent per frame index.
        # Without an init_func, FuncAnimation draws the first frame once for
        # initialisation and then again when the frames are saved, so frame 0
        # is requested twice. Tracking what has already been done keeps the
        # simulation at exactly `total_steps` steps and the progress bar at
        # exactly `total_frames` ticks.
        steps_done = 0
        frames_done = 0

        def _update(frame_idx):
            """
            Update function called by FuncAnimation for each frame.
            It advances the simulation up to the step count this frame
            represents and then updates all 3 panels.
            """
            nonlocal steps_done, frames_done

            # Number of simulation steps that must have run once this frame
            # is shown (never more than total_steps)
            target_steps = min(total_steps, (frame_idx + 1) * display_stride)

            # Advance only by the steps that have not been simulated yet
            while steps_done < target_steps:
                sim.update()
                steps_done += 1

            # Panel 1 update: state grid
            im_state.set_data(sim.grid)
            ax_state.set_title(_title_for(sim))

            # Panel 2 update: exhaustion heatmap
            # Mask out empty cells again so only the crowd shows exhaustion
            occ = sim.grid > 0
            im_exh.set_data(np.ma.array(sim.exh, mask=~occ))

            # Move the red mean marker on the colorbar to current mean exhaustion
            m = float(sim.exh_mean_hist[-1]) if len(sim.exh_mean_hist) else 0.0
            mean_bar.set_ydata([m, m])

            # Panel 3 update: time series
            x = np.arange(len(sim.exh_mean_hist))
            l_mean.set_data(x, sim.exh_mean_hist)
            l_dang.set_data(x, sim.exh_danger_hist)

            # Keep x-limits sensible depending on whether we are showing song sections
            if sections_shown:
                ax_line.set_xlim(0, total_steps)
                ax_danger.set_xlim(0, total_steps)
            else:
                ax_line.set_xlim(0, max(10, len(x) - 1))
                ax_danger.set_xlim(0, max(10, len(x) - 1))

            # Tick the progress bar only for frames not rendered before
            if frame_idx + 1 > frames_done:
                pbar.update(frame_idx + 1 - frames_done)
                frames_done = frame_idx + 1

            # Return artists (useful if blitting, even though blit=False here)
            return [im_state, im_exh, mean_bar, l_mean, l_dang] + song_spans + song_lines + song_labels

        # Build animation object
        anim = FuncAnimation(fig, _update, frames=total_frames, interval=interval, blit=False)

        # Default FPS derived from interval if not provided
        # (a non-positive interval falls back to 1 ms to avoid dividing by zero)
        if fps is None:
            safe_interval = interval if interval > 0 else 1
            fps = max(1, int(round(1000 / safe_interval)))

        # Default output path
        if gif_path is None:
            # Build a readable filename based on whether this is a concert run
            base = 'concert' if (sim.concert_sequence is not None) else {0: 'low', 1: 'medium', 2: 'high'}[sim._energy_level()]
            gif_path = f"assets/{base}_animation.gif"

        os.makedirs(os.path.dirname(gif_path) or '.', exist_ok=True)

        # Save as GIF instead of inline JSHTML
        writer = PillowWriter(fps=fps)
        anim.save(gif_path, writer=writer, dpi=dpi)
    finally:
        # Always release the progress bar and the figure, even when rendering
        # fails; closing the figure also keeps it from being displayed as a
        # static image.
        if pbar is not None:
            pbar.close()
        plt.close(fig)

    return Image(filename=gif_path)


## Section 2: PitCASimulator

**Big picture**
- The grid is `n × n`. Each cell is either **empty** or contains one person.
- Each person has:
  - a **crowd state** (`CHILL`, `BOUNCE`, `MOSH`)
  - a **personality type** (`RESERVED`, `REGULAR`, `ACTIVE`)
  - an **exhaustion level** `exh` from **0 to 10**
  - a **personal space preference** radius (1 or 2)
  - an optional **instigator** flag (helps trigger energy in high-energy songs)

**Concert / songs**
- The concert is represented as a sequence of song energy levels (`0=low`, `1=medium`, `2=high`).
- Inside each song, the model also generates a smooth **within-song intensity** `I(t)` (like verse/chorus/breakdown).
- The simulation uses both:
  - the **song base energy** (the label 0/1/2)
  - an **effective energy** (can temporarily “dip” or “spike” within the song)

**Main mechanisms**
1) **State update**: people become more active when neighbors are active, with extra pressure near the stage.
2) **Movement**: people move locally using weighted choices:
   - stage attraction/repulsion (depends on state)
   - crowd avoidance / seeking (depends on state)
   - personal-space penalties
   - a “front must be denser than the back” density constraint
   - extra MOSH pushing + swirl effects
3) **Exhaustion update**: fatigue increases with activity and crowding, recovery increases when calm.
4) **Circles / aftershock** (only in high-energy songs): temporary “holes” and ring behavior that collapses inward.

In [None]:
class PitCASimulator:
    """
    Cellular automaton for crowd dynamics in a concert pit.

    Grid representation
    -------------------
    - The pit is a 2D grid of size n x n.
    - Each cell is either EMPTY or contains one person in a crowd "state".

    Person properties stored in grids
    ---------------------------------
    - grid[r,c]         : crowd state (EMPTY / CHILL / BOUNCE / MOSH)
    - personality[r,c]  : personality type (NONE / RESERVED / REGULAR / ACTIVE)
    - ps_radius[r,c]    : personal-space radius preference (1 or 2)
    - instigator[r,c]   : whether this person is an instigator (used in high energy songs)
    - exh[r,c]          : exhaustion level in [0, 10]

    Concert logic
    -------------
    - Each song has a base energy level E_song in {0,1,2}.
    - Within a song, there is a smooth intensity multiplier I(t) around 1.0
      to mimic verse/chorus/breakdown variation.
    - The model converts (E_song, I(t)) into an "effective energy" E_eff
      used for state transitions inside a song.
    """

    # Crowd state labels (what each occupied cell is "doing")
    EMPTY, CHILL, BOUNCE, MOSH = 0, 1, 2, 3

    # Personality labels (what kind of person it is)
    NONE, RESERVED, REGULAR, ACTIVE = 0, 1, 2, 3

    def __init__(
        self,
        n=100,
        density=0.8,

        # emptiness gradient
        front_dense_frac=0.35,
        front_empty_weight=15.0,
        back_empty_boost=1000.0,
        back_ramp_power=1.0,

        # preserve gradient during motion
        profile_strength=0.12,
        profile_strength_mosh=0.06,

        # personality mix
        personality_probs=(0.20, 0.60, 0.20),

        # personal space
        stage_relax_cols=6,
        space_weight=1.0,
        desired_empty_r1=1,
        desired_empty_r2=3,

        # non-mosh movement
        beta_stage=(0.6, 0.9, 1.2),
        crowd_avoid=0.18,

        # MOSH movement/pressure toward stage
        mosh_abs_stage=12.0,
        mosh_left_boost=150.0,
        mosh_edge_cols=2,
        mosh_edge_boost=320.0,
        mosh_seek_crowd=2.4,
        mosh_push_prob=0.90,
        mosh_swirl=1.05,

        # movement noise
        min_move=0.25,
        swap_prob=0.22,
        move_noise=0.26,
        stay_bias=0.15,
        random_step_prob=0.08,

        # instigators (E=2 only)
        instigator_count=(3, 8),
        instigator_period=7,
        instigator_influence_r1=True,

        # circles (E=2 only)
        circle_enable=True,
        circle_start_prob=0.040,
        circle_max_count=2,
        circle_cooldown_steps=100,
        circle_min_sep=12,

        circle_open_steps=16,
        circle_hold_steps=26,
        circle_close_steps=16,

        circle_r_max=9,
        circle_ring_width=3,
        circle_ring_strength=3.5,
        circle_collapse_inward=3.0,

        # aftershock (post-collapse center MOSH)
        aftershock_steps=22,
        aftershock_r_frac=0.45,
        aftershock_pull=2.4,

        circle_center_col_frac=(0.22, 0.66),
        circle_center_row_frac=(0.14, 0.86),

        # song length + schedule
        steps_per_level=75,
        fixed_energy=None,
        concert_sequence=None,

        seed=None
    ):
        """
        Store model parameters and set up constants and runtime bookkeeping.

        No grids are created here; this only records configuration.

        Parameters (high level)
        -----------------------
        n, density
            Grid size and overall crowd density.

        emptiness gradient parameters
            These control how the crowd is denser near the stage and emptier in the back.

        personality_probs
            Relative weights for (RESERVED, REGULAR, ACTIVE). They are
            normalized to sum to 1.

        movement parameters
            These control how likely people move and what directions/locations they prefer.

        instigators / circles
            Special events/mechanisms that only activate in high-energy songs (E=2).

        concert_sequence, steps_per_level
            Defines the concert: each song lasts steps_per_level steps,
            and each song has a base energy in {0,1,2}.

        fixed_energy
            If not None, this energy label is used for every song.

        seed
            Seed passed to np.random.default_rng.

        Raises
        ------
        ValueError
            If personality_probs contains a negative weight or does not have
            a positive, finite sum (it could not be normalized).
        """
        # basic setup
        self.n = int(n)
        self.density = float(density)

        # initial occupancy gradient (stage side is left, back is right)
        self.front_dense_frac = float(front_dense_frac)
        self.front_empty_weight = float(front_empty_weight)
        self.back_empty_boost = float(back_empty_boost)
        self.back_ramp_power = float(back_ramp_power)

        # how strongly movement tries to preserve the target density-by-column profile
        self.profile_strength = float(profile_strength)
        self.profile_strength_mosh = float(profile_strength_mosh)

        # personality probabilities: (RESERVED, REGULAR, ACTIVE)
        # Reject weights that cannot be normalized instead of silently
        # storing NaN or negative "probabilities".
        probs = np.asarray(personality_probs, dtype=float)
        probs_total = probs.sum()
        if np.any(probs < 0) or not np.isfinite(probs_total) or probs_total <= 0:
            raise ValueError(
                "personality_probs must be non-negative with a positive, finite sum"
            )
        self.personality_probs = probs / probs_total

        # personal space behavior
        self.stage_relax_cols = int(stage_relax_cols)  # near-stage zone where personal space is relaxed
        self.space_weight = float(space_weight)
        self.desired_empty_r1 = int(desired_empty_r1)  # desired empty neighbors in radius-1
        self.desired_empty_r2 = int(desired_empty_r2)  # desired empty neighbors in radius-2

        # "stage attraction" for non-mosh movement, depends on state
        self.beta_stage = {self.CHILL: float(beta_stage[0]),
                           self.BOUNCE: float(beta_stage[1]),
                           self.MOSH: float(beta_stage[2])}
        self.crowd_avoid = float(crowd_avoid)

        # MOSH movement: stronger stage pull, pushing, and edge effects
        self.mosh_abs_stage = float(mosh_abs_stage)
        self.mosh_left_boost = float(mosh_left_boost)
        self.mosh_edge_cols = int(mosh_edge_cols)
        self.mosh_edge_boost = float(mosh_edge_boost)
        self.mosh_seek_crowd = float(mosh_seek_crowd)
        self.mosh_push_prob = float(mosh_push_prob)
        self.mosh_swirl = float(mosh_swirl)

        # stochastic movement controls
        self.min_move = float(min_move)
        self.swap_prob = float(swap_prob)
        self.move_noise = float(move_noise)
        self.stay_bias = float(stay_bias)
        self.random_step_prob = float(random_step_prob)

        # instigators (only matter during high energy songs)
        self.instigator_count = instigator_count
        self.instigator_period = int(instigator_period)
        self.instigator_influence_r1 = bool(instigator_influence_r1)

        # circle pit settings (only in high energy songs)
        self.circle_enable = bool(circle_enable)
        self.circle_start_prob = float(circle_start_prob)
        self.circle_max_count = int(circle_max_count)
        self.circle_cooldown_steps = int(circle_cooldown_steps)
        self.circle_min_sep = float(circle_min_sep)

        self.circle_open_steps = int(circle_open_steps)
        self.circle_hold_steps = int(circle_hold_steps)
        self.circle_close_steps = int(circle_close_steps)

        self.circle_r_max = float(circle_r_max)
        self.circle_ring_width = float(circle_ring_width)
        self.circle_ring_strength = float(circle_ring_strength)
        self.circle_collapse_inward = float(circle_collapse_inward)

        # aftershock: temporary central MOSH after circle collapse
        self.aftershock_steps = int(aftershock_steps)
        self.aftershock_r_frac = float(aftershock_r_frac)
        self.aftershock_pull = float(aftershock_pull)

        # allowed region for circle centers (fractions of grid size)
        self.circle_center_col_frac = circle_center_col_frac
        self.circle_center_row_frac = circle_center_row_frac

        # song schedule
        self.steps_per_level = int(steps_per_level)
        self.fixed_energy = fixed_energy
        self.concert_sequence = None if concert_sequence is None else [int(x) for x in concert_sequence]

        # random number generator for ALL randomness in the sim
        self.rng = np.random.default_rng(seed)

        # Base movement tendency by crowd state (later modulated by within-song intensity)
        self.p_move = {self.CHILL: 0.30, self.BOUNCE: 0.60, self.MOSH: 0.88}

        # Colors for visualization (state grid)
        self.cmap = ListedColormap(["white", "#6baed6", "#fdae6b", "#fb6a4a"])
        self.legend_handles = [
            Patch(facecolor="#6baed6", edgecolor="k", label="CHILL"),
            Patch(facecolor="#fdae6b", edgecolor="k", label="BOUNCE"),
            Patch(facecolor="#fb6a4a", edgecolor="k", label="MOSH"),
        ]

        # GLOBAL stage MOSH behavior
        # These constants shape how strongly effects fade with distance from stage
        self._STAGE_DECAY_COLS = 0.22 * self.n
        self._BARRICADE_DECAY_COLS = 0.06 * self.n

        # Center-of-pit shaping for MOSH (stronger in the center rows)
        self._CENTER_SIGMA_FRAC = 0.18
        self._CENTER_BOOST = 0.90
        self._CENTER_PULL = 1.6

        # Minimum neighbor "energetic" count needed to allow crowd-driven MOSH
        self._MOSH_MIN_ENERGETIC = 3

        # Base per-song-level probabilities (later modulated by intensity I(t))
        # - _SEED_PROB : chance MOSH "seeds" near stage
        # - _CROWD_PROB: chance of switching to MOSH when surrounded by energetic neighbors
        self._SEED_PROB  = {0: 0.0012, 1: 0.010, 2: 0.017}
        self._CROWD_PROB = {0: 0.008,  1: 0.14,  2: 0.21}

        # Barricade effects in energetic songs: "stick" to MOSH and "force" MOSH near stage
        self._BARR_FORCE = {0: 0.00, 1: 0.45, 2: 0.58}
        self._BARR_STICK = {0: 0.00, 1: 0.91, 2: 0.94}

        # In high-energy songs, prevent calm states in the front fraction of columns
        self._E2_NO_CALM_COL_FRAC = 0.50

        # Circles can ONLY start after this many steps into a high-energy song
        self._CIRCLE_MIN_START_IN_SONG = 15

        # within-song intensity variation
        # I(t) is a multiplier around 1.0 that changes each step within a song:
        # - Low energy song:   I(t) ~ 0.75 .. 1.05  (mild verse/chorus)
        # - Medium energy song:I(t) ~ 0.80 .. 1.20  (more swing)
        # - High energy song:  I(t) ~ 0.85 .. 1.35  (breakdown spikes)
        self._INTRASONG_RANGE = {0: (0.75, 1.05), 1: (0.80, 1.20), 2: (0.85, 1.35)}
        self._INTRASONG_KNOTS = (6, 9)  # how many "sections" to stitch (verse/chorus/etc)
        self._INTRASONG_NOISE_SD = {0: 0.012, 1: 0.017, 2: 0.022}  # small jitter
        self._INTRASONG_SPIKES_COUNT = {0: (0, 1), 1: (1, 2), 2: (2, 3)}  # chorus/breakdown peaks
        self._INTRASONG_SPIKE_AMP = {0: (0.03, 0.08), 1: (0.06, 0.14), 2: (0.08, 0.20)}
        self._song_profiles = {}  # song_idx -> np.array(length steps_per_level)

        # circle roughness
        # adds an irregular “organic” boundary to circle holes/rings
        self._CIRCLE_ROUGH_AMP = 1.2
        self._CIRCLE_ROUGH_SMOOTH_ITERS = 4
        self.circle_rough = None

        # CROWD EXHAUSTION LEVEL model (0–10)
        # Exhaustion increases with activity + crowding and decreases with recovery when calm
        self._EXH_BASE = {self.CHILL: 0.0006, self.BOUNCE: 0.0075, self.MOSH: 0.0180}
        self._EXH_K_OCC   = 0.006
        self._EXH_K_HIGH  = 0.010
        self._EXH_RECOVER = 0.075
        self._EXH_NOISE   = 0.0015
        self._EXH_DANGER  = 5.0

        # Per-song-level multipliers (global):
        # - E2: still tires faster than E1
        # - E0: recovers faster than E1
        self._EXH_FATIGUE_MULT_BY_E = {0: 0.2, 1: 0.4, 2: 1.5}
        self._EXH_RECOVER_MULT_BY_E = {0: 5, 1: 3, 2: 1.5}

        # runtime: simulation clock
        # The song helpers (_song_index / _song_phase) read self.t, so give
        # it a defined value as soon as the object exists.
        self.t = 0

        # runtime: circles (these are updated while sim runs)
        self.circle_state = 0
        self.circle_k = 0
        self.circle_centers = []
        self.forbidden = None
        self.ring = None
        self.min_dist2 = None
        self.circle_cooldown = 0

        # runtime: aftershock (also updated during run)
        self.aftershock_k = 0
        self.aftershock_centers = []
        self.aftershock_min_dist2 = None

    # song helpers
    def _song_index(self):
        """Which song number we are currently in (0-based)."""
        return self.t // max(1, self.steps_per_level)

    def _song_phase(self):
        """Step index within the current song (0 .. steps_per_level-1)."""
        return self.t % max(1, self.steps_per_level)

    def _energy_for_song_idx(self, idx):
        """
        Get the base energy label (0/1/2) for a given song index.

        Priority order:
        1) fixed_energy (if set)
        2) concert_sequence (if provided)
        3) fallback: cycle 0,1,2 repeating
        """
        if self.fixed_energy is not None:
            return int(self.fixed_energy)
        if self.concert_sequence is not None and len(self.concert_sequence) > 0:
            if idx >= len(self.concert_sequence):
                return int(self.concert_sequence[-1])
            return int(self.concert_sequence[idx])
        return int(idx % 3)

    def _generate_song_profile(self, E_song):
        """
        Generate a smooth within-song intensity profile I(t) for ONE song.

        Intuition:
        - Start with a piecewise-linear curve with random “section” values.
        - Add a few gaussian spikes (chorus/breakdown peaks).
        - Smooth the profile a bit so it feels structured.
        - Add small jitter noise.
        - Clip to the allowed range for that song energy.

        Returns
        -------
        np.ndarray length steps_per_level
            Values roughly around 1.0, with song-dependent range.
        """
        L = max(1, self.steps_per_level)
        lo, hi = self._INTRASONG_RANGE[int(E_song)]

        k0, k1 = self._INTRASONG_KNOTS
        K = int(self.rng.integers(k0, k1 + 1))

        xs = np.linspace(0, L - 1, K)
        ys = self.rng.uniform(lo, hi, size=K)

        prof = np.interp(np.arange(L), xs, ys)

        # add a few "chorus/breakdown" spikes
        c0, c1 = self._INTRASONG_SPIKES_COUNT[int(E_song)]
        n_spikes = int(self.rng.integers(c0, c1 + 1))
        a0, a1 = self._INTRASONG_SPIKE_AMP[int(E_song)]
        for _ in range(n_spikes):
            center = int(self.rng.integers(int(0.18 * L), max(int(0.82 * L), int(0.18 * L) + 1)))
            width = max(2, int(0.06 * L))
            amp = float(self.rng.uniform(a0, a1))
            x = np.arange(L)
            bump = amp * np.exp(-0.5 * ((x - center) / width) ** 2)
            prof = prof + bump

        # smooth a bit so it feels like sections
        for _ in range(3):
            P = np.pad(prof, 1, mode="edge")
            prof = (P[0:-2] + P[1:-1] + P[2:]) / 3.0

        # tiny jitter
        prof = prof * (1.0 + self.rng.normal(0.0, self._INTRASONG_NOISE_SD[int(E_song)], size=L))

        return np.clip(prof, lo, hi)

    def _ensure_profile(self, song_idx):
        """
        Make sure we have already generated and cached the intensity profile
        for this song index.
        """
        if song_idx in self._song_profiles:
            return self._song_profiles[song_idx]
        E_song = self._energy_for_song_idx(song_idx)
        prof = self._generate_song_profile(E_song)
        self._song_profiles[song_idx] = prof
        return prof

    def _energy_level(self):
        """Song base energy level (0/1/2) for the current song."""
        return self._energy_for_song_idx(self._song_index())

    def _song_intensity(self):
        """Within-song intensity I(t) (around 1.0) for the current step."""
        idx = self._song_index()
        phase = self._song_phase()
        prof = self._ensure_profile(idx)
        return float(prof[phase])

    def _effective_energy(self):
        """
        Convert within-song intensity into a temporary "effective" energy E_eff.

        Even inside one song, crowds can calm down or spike up.
        This creates more realistic within-song micro-dynamics.

        Returns
        -------
        int in {0,1,2}
            Temporary energy used for choosing default crowd states.
        """
        E_song = self._energy_level()
        I = self._song_intensity()

        E_eff = E_song
        if E_song == 0:
            if I > 1.02:
                E_eff = 1
        elif E_song == 1:
            if I < 0.90:
                E_eff = 0
            elif I > 1.10:
                E_eff = 2
        else:  # E_song == 2
            if I < 0.95:
                E_eff = 1  # high-energy songs can "dip" but not fully calm
        return int(E_eff)

    def _default_state_from_energy(self, E_eff):
        """Default state for REGULAR people given effective energy."""
        return self.CHILL if E_eff == 0 else self.BOUNCE

    def _record_exhaustion(self):
        """
        Append current exhaustion summary statistics to history lists:
        - mean exhaustion
        - 90th percentile exhaustion
        - danger fraction (exh >= 5)
        """
        occ = self.grid > 0
        if not occ.any():
            self.exh_mean_hist.append(0.0)
            self.exh_p90_hist.append(0.0)
            self.exh_danger_hist.append(0.0)
            return
        vals = self.exh[occ]
        self.exh_mean_hist.append(float(vals.mean()))
        self.exh_p90_hist.append(float(np.percentile(vals, 90)))
        self.exh_danger_hist.append(float((vals >= self._EXH_DANGER).mean()))

    def initialize(self, ax=None):
        """
        Initialize / reset the simulation state for a fresh run.

        This method:
        - builds the initial occupancy using an emptiness gradient (front denser)
        - assigns personalities and personal-space radii
        - chooses instigators
        - initializes exhaustion
        - resets circle/aftershock state
        - performs one initial update pass so histories start non-empty

        Parameters
        ----------
        ax : optional
            Axis-like object stored as ``self.ax``. When given, its ticks are
            cleared with ``set_xticks([])`` / ``set_yticks([])``.

        Notes
        -----
        Column weights can sum to zero (for example ``front_empty_weight == 0``
        when only front columns still have room for empties, or when
        ``front_dense_frac`` covers every column). In that case the empties are
        spread uniformly over the eligible columns instead of dividing by zero
        and handing NaN probabilities to the multinomial sampler.
        """

        def _as_probs(weights):
            # Normalize weights into probabilities; fall back to uniform when
            # the total is zero (or not a positive number).
            total_w = weights.sum()
            if total_w > 0:
                return weights / total_w
            return np.full(len(weights), 1.0 / max(1, len(weights)))

        # Time step counter (global time across songs)
        self.t = 0

        # Reset song profiles so each run gets fresh within-song variation
        self._song_profiles = {}

        # Core grids (state + traits)
        self.grid = np.zeros((self.n, self.n), dtype=np.int8)
        self.personality = np.zeros((self.n, self.n), dtype=np.int8)
        self.ps_radius = np.zeros((self.n, self.n), dtype=np.int8)
        self.instigator = np.zeros((self.n, self.n), dtype=bool)

        # Precompute row/col index grids for fast vector math
        self.R = np.tile(np.arange(self.n)[:, None], (1, self.n))
        self.C = np.tile(np.arange(self.n)[None, :], (self.n, 1))
        self.col_idx = self.C.copy()
        self.cy = (self.n - 1) / 2.0  # "center row" used in center-weighting

        # Total target number of people
        total = self.n * self.n
        self.N_people = int(round(self.density * total))
        N_empty = total - self.N_people

        # Build emptiness gradient by column (more empty in the back)
        front_cut = int(self.n * self.front_dense_frac)

        # Higher weight => more empty cells assigned to that column
        w_col = np.full(self.n, self.front_empty_weight, dtype=float)
        if front_cut < self.n:
            x = np.linspace(0.0, 1.0, self.n - front_cut)
            w_col[front_cut:] = self.front_empty_weight + 1.0 + self.back_empty_boost * (x ** self.back_ramp_power)

        # Convert weights into probabilities (uniform if all weights are zero)
        p_col = _as_probs(w_col)

        # Target occupancy per column
        exp_empty = N_empty * p_col
        self.target_occ_per_col = np.clip(self.n - exp_empty, 0.0, float(self.n))

        # Sample how many empties go into each column
        empty_counts = self.rng.multinomial(N_empty, p_col)

        # Enforce per-column cap (can't have more than n empty cells in a column)
        cap = self.n
        empty_counts = np.minimum(empty_counts, cap)

        # If multinomial + cap lost some empties, redistribute leftovers safely.
        # The guard bounds the number of redistribution rounds.
        leftover = N_empty - int(empty_counts.sum())
        guard = 0
        while leftover > 0 and guard < 200:
            avail = np.where(empty_counts < cap)[0]
            if len(avail) == 0:
                break
            # Columns that still have room may all carry zero weight; the
            # helper then spreads the leftovers uniformly among them.
            p = _as_probs(w_col[avail].astype(float))
            add = self.rng.multinomial(leftover, p)
            room = cap - empty_counts[avail]
            add = np.minimum(add, room)
            empty_counts[avail] += add
            leftover = N_empty - int(empty_counts.sum())
            guard += 1

        # Build an explicit empty mask by sampling rows inside each column
        empty_mask = np.zeros((self.n, self.n), dtype=bool)
        for c in range(self.n):
            k = int(empty_counts[c])
            if k <= 0:
                continue
            rows = self.rng.choice(self.n, size=k, replace=False)
            empty_mask[rows, c] = True

        # All positions that are NOT empty become occupied
        occ_positions = np.argwhere(~empty_mask)
        rr, cc = occ_positions[:, 0], occ_positions[:, 1]

        # Assign personality to each occupied cell
        pers_vals = self.rng.choice(
            [self.RESERVED, self.REGULAR, self.ACTIVE],
            size=len(rr),
            p=self.personality_probs
        )
        self.personality[rr, cc] = pers_vals.astype(np.int8)

        # Assign personal-space radius
        # RESERVED: radius 2
        # ACTIVE: radius 1
        # REGULAR: radius 1 with probability 0.75, radius 2 with probability 0.25
        ps = np.ones(len(rr), dtype=np.int8)
        ps[pers_vals == self.RESERVED] = 2
        ps[pers_vals == self.ACTIVE] = 1
        reg = np.where(pers_vals == self.REGULAR)[0]
        ps[reg] = self.rng.choice([1, 2], size=len(reg), p=[0.75, 0.25])
        self.ps_radius[rr, cc] = ps

        # Initialize all occupied as CHILL (state rules will adjust right after)
        self.grid[rr, cc] = self.CHILL

        # Choose instigators (subset of occupied)
        lo, hi = int(self.instigator_count[0]), int(self.instigator_count[1])
        n_inst = int(self.rng.integers(lo, hi + 1))
        n_inst = min(n_inst, len(rr))
        pick = self.rng.choice(len(rr), size=n_inst, replace=False)
        self.instigator[rr[pick], cc[pick]] = True

        # Exhaustion grid
        self.exh = np.zeros((self.n, self.n), dtype=np.float64)
        self.exh[rr, cc] = self.rng.uniform(0.0, 0.6, size=len(rr))  # start low but nonzero

        # History lists for plots/animation
        self.exh_mean_hist = []
        self.exh_p90_hist = []
        self.exh_danger_hist = []

        # Reset circle-related runtime state
        self.circle_state = 0
        self.circle_k = 0
        self.circle_centers = []
        self.circle_rough = None
        self.forbidden = None
        self.ring = None
        self.min_dist2 = None
        self.circle_cooldown = 0

        # Reset aftershock-related runtime state
        self.aftershock_k = 0
        self.aftershock_centers = []
        self.aftershock_min_dist2 = None

        # Ensure first song profile exists so intensity calls won't fail
        self._ensure_profile(0)

        # Do one initial update pass so states/exhaustion/histories are consistent at t=0
        self._apply_state_rules()
        self._update_exhaustion()
        self._record_exhaustion()

        # If we pass an axis (for animation), store it and hide ticks
        self.ax = ax
        if self.ax is not None:
            self.ax.set_xticks([])
            self.ax.set_yticks([])

    def _apply_state_rules(self):
        """
        Recompute the crowd state (CHILL/BOUNCE/MOSH) of every occupied cell.

        Steps, in order:
        1. personality-based baseline (RESERVED, REGULAR, ACTIVE)
        2. contagion from Von Neumann neighbours of the baseline
        3. "no calm" front zone in high-energy songs
        4. barricade stick/force (only when effective energy >= 1)
        5. stage seeding and crowd-driven flipping into MOSH
        6. instigator influence (high-energy songs, every ``instigator_period`` steps)
        7. aftershock zone forced to MOSH
        The result is written back into ``self.grid`` for occupied cells only.

        The random draws happen in a fixed order (stick, force, seed, flip);
        changing that order changes the simulation for a given seed.
        """
        occupied = self.grid > 0
        if not occupied.any():
            return

        song_energy = self._energy_level()      # base energy label of the song
        eff_energy = self._effective_energy()   # energy after within-song dips/spikes
        intensity = self._song_intensity()      # within-song intensity multiplier

        # --- 1) baseline by personality -----------------------------------
        regular_state = self._default_state_from_energy(eff_energy)
        if eff_energy == 0:
            active_state = self.CHILL
        elif eff_energy == 1:
            active_state = self.BOUNCE
        else:
            active_state = self.MOSH

        base = np.zeros_like(self.grid)
        baseline_by_trait = (
            (self.RESERVED, self.CHILL),
            (self.REGULAR, regular_state),
            (self.ACTIVE, active_state),
        )
        for trait, state in baseline_by_trait:
            base[occupied & (self.personality == trait)] = state

        # --- 2) neighbour contagion (4-neighbourhood) ---------------------
        bouncing_nb = neighbor_count_vn(base == self.BOUNCE)
        moshing_nb = neighbor_count_vn(base == self.MOSH)
        energetic = bouncing_nb + moshing_nb

        updated = base.copy()
        updated[occupied & (moshing_nb >= 2)] = self.MOSH
        lift_to_bounce = occupied & (moshing_nb < 2) & (energetic >= 2)
        updated[lift_to_bounce] = np.maximum(updated[lift_to_bounce], self.BOUNCE)

        # --- 3) high-energy songs: front columns are at least BOUNCE ------
        if song_energy == 2:
            no_calm_cols = self.col_idx < int(self._E2_NO_CALM_COL_FRAC * self.n)
            front_zone = occupied & no_calm_cols
            updated[front_zone] = np.maximum(updated[front_zone], self.BOUNCE)

        # Spatial weights: exponential decay with column index, and a
        # Gaussian bump around the centre row blended in by _CENTER_BOOST.
        stage_w = np.exp(-self.col_idx / max(1.0, self._STAGE_DECAY_COLS))
        barr_w = np.exp(-self.col_idx / max(1.0, self._BARRICADE_DECAY_COLS))
        sigma = max(1.0, self._CENTER_SIGMA_FRAC * self.n)
        row_w = np.exp(-0.5 * ((self.R - self.cy) / sigma) ** 2)
        shape = (1.0 - self._CENTER_BOOST) + self._CENTER_BOOST * row_w

        # All MOSH probabilities are scaled by the current intensity.
        gain = float(intensity)
        grid_shape = (self.n, self.n)

        # --- 4) barricade effects -----------------------------------------
        if eff_energy >= 1:
            stick_p = np.clip(self._BARR_STICK[song_energy] * gain, 0.0, 0.999)
            force_p = np.clip(self._BARR_FORCE[song_energy] * gain, 0.0, 1.0)

            # stick: cells already in MOSH may stay in MOSH
            stick_draw = self.rng.random(grid_shape) < (stick_p * barr_w * shape)
            sticking = (self.grid == self.MOSH) & stick_draw
            updated[occupied & sticking] = self.MOSH

            # force: any occupied cell may be pushed into MOSH
            forced = self.rng.random(grid_shape) < (force_p * barr_w * shape)
            updated[occupied & forced] = self.MOSH

        # --- 5) stage seeding + crowd-driven flips ------------------------
        seed_p = np.clip(self._SEED_PROB[song_energy] * gain, 0.0, 1.0)
        seeded = self.rng.random(grid_shape) < (seed_p * stage_w * shape)
        updated[occupied & seeded] = self.MOSH

        crowd_p = np.clip(self._CROWD_PROB[song_energy] * gain, 0.0, 1.0)
        has_energetic_nb = energetic >= self._MOSH_MIN_ENERGETIC
        flipped = self.rng.random(grid_shape) < (crowd_p * stage_w * shape)
        updated[occupied & has_energetic_nb & flipped] = self.MOSH

        # --- 6) instigators (depends on the song label, not effective energy)
        instigators_fire = (
            song_energy == 2
            and (self.instigator_period > 0)
            and (self.t % self.instigator_period == 0)
        )
        if instigators_fire:
            inst_mask = self.instigator & occupied
            if self.instigator_influence_r1:
                inst_nb = moore_count_radius1(inst_mask)
            else:
                inst_nb = moore_count_radius2(inst_mask)
            near_one = occupied & (inst_nb >= 1)
            updated[near_one] = np.maximum(updated[near_one], self.BOUNCE)
            near_two = occupied & (inst_nb >= 2)
            updated[near_two] = np.maximum(updated[near_two], self.MOSH)

        # --- 7) aftershock zone -------------------------------------------
        if song_energy == 2 and self.aftershock_k > 0 and self.aftershock_min_dist2 is not None:
            shock_radius = self.aftershock_r_frac * self.circle_r_max
            in_shock = occupied & (self.aftershock_min_dist2 <= (shock_radius ** 2))
            updated[in_shock] = self.MOSH

        # Commit: only occupied cells receive a state
        self.grid[occupied] = updated[occupied]

    def _update_exhaustion(self):
        """
        Advance the exhaustion value of every occupied cell by one step.

        Fatigue (positive contribution), multiplied by the per-energy fatigue
        factor and the within-song intensity I:
        - a per-state base amount (``_EXH_BASE``)
        - crowding: occupied Moore neighbours / 8, scaled by ``_EXH_K_OCC``
        - energetic neighbours (BOUNCE or MOSH) / 8, scaled by ``_EXH_K_HIGH``

        Recovery (negative contribution), CHILL cells only:
        - ``_EXH_RECOVER`` reduced linearly with the number of occupied neighbours
        - multiplied by the per-energy recovery factor and ``max(1.0, 1.6 - 0.6 * I)``

        Gaussian noise with standard deviation ``_EXH_NOISE`` is added, and the
        result is clipped to [0, 10].
        """
        occupied = self.grid > 0
        if not occupied.any():
            return

        song_energy = self._energy_level()
        intensity = self._song_intensity()

        # Fatigue grows linearly with intensity. The recovery multiplier is
        # floored at 1.0: calm passages (I < 1) boost recovery, while intense
        # passages leave it at the per-energy baseline.
        fatigue_mult = float(self._EXH_FATIGUE_MULT_BY_E.get(song_energy, 1.0)) * float(intensity)
        recover_mult = float(self._EXH_RECOVER_MULT_BY_E.get(song_energy, 1.0)) * float(
            max(1.00, 1.6 - 0.6 * intensity)
        )

        # Moore (radius 1) neighbourhood counts
        n_occ = moore_count_radius1(occupied)
        energetic_cells = (self.grid == self.BOUNCE) | (self.grid == self.MOSH)
        n_high = moore_count_radius1(energetic_cells)

        grid_shape = (self.n, self.n)

        # Per-state base fatigue
        base = np.zeros(grid_shape, dtype=np.float64)
        for state in (self.CHILL, self.BOUNCE, self.MOSH):
            base[self.grid == state] = self._EXH_BASE[state]

        # Neighbour-driven fatigue, normalised by the 8 possible neighbours
        crowd_term = self._EXH_K_OCC * (n_occ / 8.0)
        high_term = self._EXH_K_HIGH * (n_high / 8.0)

        # Recovery applies to CHILL cells only and shrinks as neighbours fill up
        recover = np.zeros(grid_shape, dtype=np.float64)
        resting = (self.grid == self.CHILL) & occupied
        recover[resting] = self._EXH_RECOVER * (1.0 - (n_occ[resting] / 8.0))

        # Per-cell noise so exhaustion does not evolve deterministically
        noise = self.rng.normal(0.0, self._EXH_NOISE, size=grid_shape)

        gain = (base + crowd_term + high_term) * fatigue_mult
        loss = recover * recover_mult
        delta = gain + noise - loss

        self.exh[occupied] = np.clip(self.exh[occupied] + delta[occupied], 0.0, 10.0)

    # circle helpers
    def _make_circle_rough(self):
        """
        Make a smooth random field used to "roughen" the circle boundary.
        This makes circle holes/rings less perfectly circular.
        """
        z = self.rng.normal(0.0, 1.0, size=(self.n, self.n))
        for _ in range(self._CIRCLE_ROUGH_SMOOTH_ITERS):
            P = np.pad(z, 1, mode="edge")
            z = (
                P[0:-2, 0:-2] + P[0:-2, 1:-1] + P[0:-2, 2:] +
                P[1:-1, 0:-2] + P[1:-1, 1:-1] + P[1:-1, 2:] +
                P[2:  , 0:-2] + P[2:  , 1:-1] + P[2:  , 2:]
            ) / 9.0
        m = np.max(np.abs(z))
        if m > 1e-12:
            z = z / m
        return z

    def _choose_circle_centers(self, K):
        """
        Choose K circle centers.

        Strategy:
        1) Prefer instigator positions inside an allowed region.
        2) If not enough, sample random positions in the allowed region.
        3) Enforce a minimum separation between centers.
        """
        occ = self.grid > 0
        inst_pos = np.argwhere(self.instigator & occ)

        c0 = int(self.circle_center_col_frac[0] * (self.n - 1))
        c1 = int(self.circle_center_col_frac[1] * (self.n - 1))
        r0 = int(self.circle_center_row_frac[0] * (self.n - 1))
        r1 = int(self.circle_center_row_frac[1] * (self.n - 1))

        centers = []
        min_sep2 = float(self.circle_min_sep) ** 2

        # First: try instigators inside the allowed rectangle
        if len(inst_pos) > 0:
            mask = (inst_pos[:, 1] >= c0) & (inst_pos[:, 1] <= c1) & (inst_pos[:, 0] >= r0) & (inst_pos[:, 0] <= r1)
            candidates = inst_pos[mask]
            self.rng.shuffle(candidates)
            for rr, cc in candidates:
                if all(((rr - pr) ** 2 + (cc - pc) ** 2) >= min_sep2 for pr, pc in centers):
                    centers.append((int(rr), int(cc)))
                    if len(centers) >= K:
                        return centers

        # Fallback: random sampling in allowed region
        tries = 0
        while len(centers) < K and tries < 500:
            rr = int(self.rng.integers(r0, r1 + 1))
            cc = int(self.rng.integers(c0, c1 + 1))
            if all(((rr - pr) ** 2 + (cc - pc) ** 2) >= min_sep2 for pr, pc in centers):
                centers.append((rr, cc))
            tries += 1
        return centers

    def _compute_min_dist2_to_centers(self, centers):
        """Compute per-cell minimum squared distance to the nearest center."""
        if len(centers) == 0:
            return None
        min_d2 = np.full((self.n, self.n), np.inf, dtype=np.float64)
        for cr, cc in centers:
            d2 = (self.R - cr) ** 2 + (self.C - cc) ** 2
            min_d2 = np.minimum(min_d2, d2)
        return min_d2

    def _update_circle_masks(self, r_hole):
        """
        Update boolean masks for:
        - forbidden: the empty "hole" region (should be evacuated)
        - ring     : the ring region around the hole

        r_hole controls the current radius of the hole (changes over time).
        """
        if len(self.circle_centers) == 0:
            self.forbidden = None
            self.ring = None
            self.min_dist2 = None
            return

        # Precompute min distances to centers (used everywhere)
        self.min_dist2 = self._compute_min_dist2_to_centers(self.circle_centers)

        # Inner hole radius r0, outer ring radius r1
        r0 = float(r_hole)
        r1 = float(r_hole + self.circle_ring_width)

        # Optionally add roughness so boundaries are irregular
        if self.circle_rough is not None:
            r0_eff = np.clip(r0 + self._CIRCLE_ROUGH_AMP * self.circle_rough, 0.0, None)
            r1_eff = np.clip(r1 + 0.6 * self._CIRCLE_ROUGH_AMP * self.circle_rough, 0.0, None)
            r1_eff = np.maximum(r1_eff, r0_eff + 0.6)
        else:
            r0_eff = r0
            r1_eff = r1

        # forbidden: inside the hole
        self.forbidden = self.min_dist2 <= (r0_eff ** 2)

        # ring: between hole radius and outer ring radius
        self.ring = (self.min_dist2 > (r0_eff ** 2)) & (self.min_dist2 <= (r1_eff ** 2))

    def _evacuate_holes(self):
        """
        Move people out of the forbidden (hole) region into empty spots outside.
        This keeps the hole empty while the circle is forming/holding/closing.
        """
        if self.forbidden is None:
            return
        occ = self.grid > 0
        inside = np.argwhere(occ & self.forbidden)
        if len(inside) == 0:
            return
        outside_empties = np.argwhere((~occ) & (~self.forbidden))
        if len(outside_empties) == 0:
            return

        # Shuffle so moves are random
        self.rng.shuffle(inside)
        self.rng.shuffle(outside_empties)
        m = min(len(inside), len(outside_empties))
        inside = inside[:m]
        outside_empties = outside_empties[:m]

        # Move each inside person to an outside empty cell (copy traits and clear original)
        for (sr, sc), (tr, tc) in zip(inside, outside_empties):
            self.grid[tr, tc] = self.grid[sr, sc]
            self.personality[tr, tc] = self.personality[sr, sc]
            self.ps_radius[tr, tc] = self.ps_radius[sr, sc]
            self.instigator[tr, tc] = self.instigator[sr, sc]
            self.exh[tr, tc] = self.exh[sr, sc]

            self.grid[sr, sc] = self.EMPTY
            self.personality[sr, sc] = self.NONE
            self.ps_radius[sr, sc] = 0
            self.instigator[sr, sc] = False
            self.exh[sr, sc] = 0.0

    def _circle_step(self):
        """
        Advance circle pit logic by 1 step.

        Circle state machine:
        - 0: inactive
        - 1: opening (hole radius grows)
        - 2: hold (hole stays open)
        - 3: closing (hole radius shrinks, then aftershock triggers)
        """
        if not self.circle_enable:
            self.circle_state = 0
            self.circle_k = 0
            self.circle_centers = []
            self.circle_rough = None
            self._update_circle_masks(0.0)
            return

        E_song = self._energy_level()
        phase = self._song_phase()

        # Only HIGH-energy SONGS can have circles, and not in the first 15 steps
        if E_song != 2:
            self.circle_state = 0
            self.circle_k = 0
            self.circle_centers = []
            self.circle_rough = None
            self._update_circle_masks(0.0)
            return

        # Start circle with some probability, only if we have instigators present
        if self.circle_state == 0 and self.circle_cooldown == 0:
            if phase >= self._CIRCLE_MIN_START_IN_SONG:
                if self.rng.random() < self.circle_start_prob and (self.instigator & (self.grid > 0)).any():
                    K = int(self.rng.integers(1, self.circle_max_count + 1))
                    self.circle_centers = self._choose_circle_centers(K)
                    self.circle_rough = self._make_circle_rough()
                    self.circle_state = 1
                    self.circle_k = 0

        if self.circle_state == 1:  # opening
            frac = self.circle_k / max(1, self.circle_open_steps)
            r = self.circle_r_max * frac
            self._update_circle_masks(r)
            self._evacuate_holes()
            self.circle_k += 1
            if self.circle_k >= self.circle_open_steps:
                self.circle_state = 2
                self.circle_k = 0

        elif self.circle_state == 2:  # hold
            r = self.circle_r_max
            self._update_circle_masks(r)
            self._evacuate_holes()
            self.circle_k += 1
            if self.circle_k >= self.circle_hold_steps:
                self.circle_state = 3
                self.circle_k = 0

        elif self.circle_state == 3:  # closing
            frac = self.circle_k / max(1, self.circle_close_steps)
            r = self.circle_r_max * (1.0 - frac)
            self._update_circle_masks(r)
            self._evacuate_holes()
            self.circle_k += 1
            if self.circle_k >= self.circle_close_steps:
                # Trigger aftershock near the circle centers (temporary)
                self.aftershock_k = self.aftershock_steps
                self.aftershock_centers = list(self.circle_centers)
                self.aftershock_min_dist2 = self._compute_min_dist2_to_centers(self.aftershock_centers)

                # Reset circle
                self.circle_state = 0
                self.circle_k = 0
                self.circle_centers = []
                self.circle_rough = None
                self._update_circle_masks(0.0)
                self.circle_cooldown = self.circle_cooldown_steps

    def update(self):
        """
        Perform one full simulation time step.

        High-level order:
        1) Ensure current song intensity profile exists
        2) Update circle / aftershock state machine
        3) Apply state rules (CHILL/BOUNCE/MOSH updates)
        4) Apply ring forcing (ring becomes MOSH during E=2)
        5) Move people cell-by-cell (stochastic weighted movement)
        6) Evacuate forbidden holes again if needed
        7) Update exhaustion + record time-series stats
        8) Increment time
        """
        # ensure a profile exists for current song
        self._ensure_profile(self._song_index())

        E_song = self._energy_level()

        # circles/aftershock only active for high-energy SONGS
        if E_song != 2:
            self.circle_state = 0
            self.circle_k = 0
            self.circle_centers = []
            self.circle_rough = None
            self._update_circle_masks(0.0)
            self.circle_cooldown = 0

            self.aftershock_k = 0
            self.aftershock_centers = []
            self.aftershock_min_dist2 = None
        else:
            if self.circle_cooldown > 0:
                self.circle_cooldown -= 1

            if self.aftershock_k > 0:
                self.aftershock_k -= 1
                if self.aftershock_k == 0:
                    self.aftershock_centers = []
                    self.aftershock_min_dist2 = None

            self._circle_step()

        # Update crowd states based on rules + neighbors + stage effects
        self._apply_state_rules()

        # If in high-energy song and ring exists, force ring area into MOSH
        if E_song == 2 and self.ring is not None:
            occ = self.grid > 0
            self.grid[occ & self.ring] = self.MOSH

        # Precompute occupancy and neighborhood counts for movement weighting
        occ = self.grid > 0
        crowd_r1 = moore_count_radius1(occ)
        crowd_r2 = moore_count_radius2(occ)

        # Per-column occupancy count (used to preserve density profile)
        col_occ = occ.sum(axis=0).astype(np.int16)

        # Iterate people in random order each step (avoids directional bias)
        positions = np.argwhere(occ)
        self.rng.shuffle(positions)

        # Distances to circle/aftershock centers for collapse/pull effects
        d_curr_circle = np.sqrt(self.min_dist2) if (E_song == 2 and self.circle_state == 3 and self.min_dist2 is not None) else None
        d_curr_shock  = np.sqrt(self.aftershock_min_dist2) if (E_song == 2 and self.aftershock_k > 0 and self.aftershock_min_dist2 is not None) else None

        # Stage weight for center-pull shaping in MOSH movement
        stage_w_move = np.exp(-self.col_idx / max(1.0, self._STAGE_DECAY_COLS))

        # within-song: movement gets a bit more active in intense sections
        I = self._song_intensity()
        move_gain = float(np.clip(0.85 + 0.30 * I, 0.75, 1.25))

        # Main movement loop: one person at a time
        for (r, c) in positions:
            state = int(self.grid[r, c])
            if state == self.EMPTY:
                continue

            # Don't allow movement inside forbidden circle holes
            if (self.forbidden is not None) and self.forbidden[r, c]:
                continue

            # Effective move probability for this person this step
            p_eff = max(self.p_move[state] * move_gain, self.min_move)
            if self.rng.random() > p_eff:
                continue

            # Candidate lists for movement:
            # - empty_cands: empty neighbor cells to move into
            # - swap_cands : occupied neighbor cells to swap with (pushing)
            empty_cands, swap_cands = [], []

            def add_candidate(rr, cc):
                # Skip forbidden cells (circle hole)
                if self.forbidden is not None and self.forbidden[rr, cc]:
                    return
                if self.grid[rr, cc] == self.EMPTY:
                    empty_cands.append((rr, cc))
                else:
                    swap_cands.append((rr, cc))

            # Consider Von Neumann neighbors as possible movement targets
            if c > 0: add_candidate(r, c-1)
            if c < self.n - 1: add_candidate(r, c+1)
            if r > 0: add_candidate(r-1, c)
            if r < self.n - 1: add_candidate(r+1, c)

            # Always allow "stay in place" candidate
            candidates = [(r, c)]

            # MOSH: can push into occupied cells with high probability (mosh_push_prob)
            if state == self.MOSH:
                candidates += empty_cands
                if swap_cands and (self.rng.random() < self.mosh_push_prob):
                    candidates += swap_cands
            else:
                # Non-mosh: prefer empty spots, otherwise may swap with some probability
                if empty_cands:
                    candidates += empty_cands
                else:
                    if swap_cands and (self.rng.random() < self.swap_prob):
                        candidates += swap_cands
                    else:
                        continue

            # If only "stay" is available, nothing to do
            if len(candidates) == 1:
                continue

            # Start weights uniformly, then multiply by different preference factors
            weights = np.ones(len(candidates), dtype=np.float64)

            # Staying gets a bias multiplier (so people don't jitter too much)
            weights[0] *= self.stay_bias

            beta = self.beta_stage[state]
            rps = int(self.ps_radius[r, c])

            # Score each candidate cell
            for i, (rr, cc) in enumerate(candidates):
                # stage attraction / pushing
                if state == self.MOSH:
                    # MOSH strongly prefers moving toward stage (smaller column index)
                    weights[i] *= np.exp(np.clip(-self.mosh_abs_stage * (cc / (self.n - 1)), -50, 50))

                    # Extra boost for left move (toward stage)
                    if cc == c - 1:
                        weights[i] *= self.mosh_left_boost

                    # Edge boost near stage edge columns (pile-up effect)
                    if cc <= self.mosh_edge_cols:
                        weights[i] *= self.mosh_edge_boost
                else:
                    # CHILL/BOUNCE stage tendency based on beta_stage
                    weights[i] *= np.exp(np.clip(-beta * (cc - c), -50, 50))

                # crowd seeking / avoidance
                if state == self.MOSH:
                    # MOSH seeks crowd more strongly near stage
                    stage_factor = max(0.0, 1.0 - (cc / (self.n - 1)))
                    weights[i] *= np.exp(np.clip(+self.mosh_seek_crowd * stage_factor * crowd_r1[rr, cc], -50, 50))
                else:
                    # Non-mosh tends to avoid crowded neighborhoods
                    weights[i] *= np.exp(np.clip(-self.crowd_avoid * crowd_r1[rr, cc], -50, 50))

                # personal space penalty (outside stage-relax zone)
                if state != self.MOSH and cc > self.stage_relax_cols:
                    if rps == 1:
                        allowed_occ = 8 - self.desired_empty_r1
                        overcrowd = max(0, int(crowd_r1[rr, cc]) - allowed_occ)
                    else:
                        allowed_occ = 24 - self.desired_empty_r2
                        overcrowd = max(0, int(crowd_r2[rr, cc]) - allowed_occ)
                    weights[i] *= np.exp(np.clip(-self.space_weight * overcrowd, -50, 50))

                # preserve the desired density gradient by columns
                # Encourage moves that bring each column occupancy closer to its target.
                if cc != c:
                    delta_to = float(col_occ[cc]) - float(self.target_occ_per_col[cc])
                    delta_fr = float(col_occ[c])  - float(self.target_occ_per_col[c])
                    k = (self.profile_strength_mosh * (cc / (self.n - 1))) if state == self.MOSH else self.profile_strength
                    weights[i] *= np.exp(np.clip(-k * (delta_to - delta_fr), -50, 50))

                # center pull during MOSH (keeps motion near center rows)
                if state == self.MOSH:
                    w = float(stage_w_move[r, c])
                    weights[i] *= np.exp(np.clip(-self._CENTER_PULL * w * (abs(rr - self.cy) - abs(r - self.cy)), -50, 50))

                # circle collapse inward force (during closing state)
                if d_curr_circle is not None:
                    weights[i] *= np.exp(np.clip(-self.circle_collapse_inward * (float(d_curr_circle[rr, cc]) - float(d_curr_circle[r, c])), -50, 50))

                # aftershock pull toward aftershock centers
                if d_curr_shock is not None:
                    weights[i] *= np.exp(np.clip(-self.aftershock_pull * (float(d_curr_shock[rr, cc]) - float(d_curr_shock[r, c])), -50, 50))

            # Add extra randomness / noise to weights
            weights *= (1.0 + self.move_noise * self.rng.random(len(weights)))

            # Occasionally do a random step (ignoring weights), for unpredictability
            if len(candidates) > 1 and (self.rng.random() < self.random_step_prob):
                idx = int(self.rng.integers(1, len(candidates)))
            else:
                s = weights.sum()
                idx = 0 if (s <= 0 or not np.isfinite(s)) else int(self.rng.choice(len(candidates), p=weights / s))

            tr, tc = candidates[idx]
            if (tr, tc) == (r, c):
                continue

            # Apply the move
            if self.grid[tr, tc] == self.EMPTY:
                # Move into empty cell: copy attributes, clear source cell
                self.grid[tr, tc] = self.grid[r, c]
                self.personality[tr, tc] = self.personality[r, c]
                self.ps_radius[tr, tc] = self.ps_radius[r, c]
                self.instigator[tr, tc] = self.instigator[r, c]
                self.exh[tr, tc] = self.exh[r, c]

                self.grid[r, c] = self.EMPTY
                self.personality[r, c] = self.NONE
                self.ps_radius[r, c] = 0
                self.instigator[r, c] = False
                self.exh[r, c] = 0.0

                # Maintain col_occ counts if the column changed
                if tc != c:
                    col_occ[tc] += 1
                    col_occ[c] -= 1
            else:
                # Swap with an occupied cell (pushing / crowd displacement)
                self.grid[tr, tc], self.grid[r, c] = self.grid[r, c], self.grid[tr, tc]
                self.personality[tr, tc], self.personality[r, c] = self.personality[r, c], self.personality[tr, tc]
                self.ps_radius[tr, tc], self.ps_radius[r, c] = self.ps_radius[r, c], self.ps_radius[tr, tc]
                self.instigator[tr, tc], self.instigator[r, c] = self.instigator[r, c], self.instigator[tr, tc]
                self.exh[tr, tc], self.exh[r, c] = self.exh[r, c], self.exh[tr, tc]

        # After movement, enforce hole evacuation again if needed
        if self.forbidden is not None:
            self._evacuate_holes()

        # Exhaustion update and record time-series stats
        self._update_exhaustion()
        self._record_exhaustion()

        # Advance global time
        self.t += 1

## Section 3: Demo configuration (one song per energy level + full concert)

1) **Defines a baseline parameter set** (`base_kwargs`) for the PitCASimulator. These values control density, movement “jitter,” circle behavior, exhaustion dynamics, etc.

2) **Runs animations** (renders the sim).  
   - `display_stride` and `interval_ms` change *how the animation is displayed* (speed / sampling).

Outputs:
- Three short demos: **one low**, **one medium**, **one high** energy song.
- One full concert animation using a **20-song energy sequence** (`concert`), with song sections shaded on the time-series plot.

---



⏱️ This section takes approximately 40 minutes to run in Google Colab.

In [None]:
# Baseline keyword arguments for PitCASimulator.
# The demo cells below pass this dict as PitCASimulator(**base_kwargs, ...);
# the validation and Monte Carlo cells copy it and override n / steps_per_level
# (and density) to get smaller, faster runs.
base_kwargs = dict(
    n=100,  # grid size: n x n

    # Occupancy / sparseness
    density=0.8,             # fraction of cells that start occupied
    front_dense_frac=0.35,   # left portion (near stage) that stays denser
    front_empty_weight=15,   # baseline emptiness weight for stage-side columns
    back_empty_boost=1000.0, # how strongly emptiness increases toward the back
    back_ramp_power=1,       # shape of the “emptier toward the back” ramp

    # How strongly movement tries to preserve the target column density profile
    profile_strength=0.12,        # non-mosh: how hard to maintain the density gradient
    profile_strength_mosh=0.06,   # mosh: weaker preservation (mosh is more chaotic)

    # Movement randomness / “jitter”
    min_move=0.25,           # minimum per-step chance a person attempts to move
    swap_prob=0.22,          # if no empty neighbors: chance of swapping with an occupied neighbor
    move_noise=0.26,         # multiplicative random noise applied to movement weights
    stay_bias=0.15,          # extra weight for staying in place (prevents too much jitter)
    random_step_prob=0.08,   # sometimes ignore weights and take a random non-stay step

    # Instigators (only meaningful in E=2 songs)
    instigator_count=(3, 8), # number of instigators drawn uniformly in this range at init
    instigator_period=7,     # instigators “trigger” influence every N steps

    # Circle pits (E=2 only)
    circle_enable=True,          # turn circle mechanic on/off
    circle_start_prob=0.040,     # chance per step to start a circle (when allowed)
    circle_max_count=2,          # max simultaneous circles per event
    circle_cooldown_steps=40,    # cooldown after a full circle event (prevents nonstop circles)
    circle_min_sep=12,           # minimum separation between circle centers

    # Timing of a circle event (number of steps spent in each phase)
    circle_open_steps=3,
    circle_hold_steps=4,
    circle_close_steps=3,

    # Circle geometry / strength
    circle_r_max=9,              # max hole radius
    circle_ring_width=3,         # ring thickness around the hole
    circle_ring_strength=3.5,    # ring “forcing” effect (used elsewhere in code)
    circle_collapse_inward=3.0,  # how strongly closing pulls inward toward center

    # Aftershock = post-collapse MOSH near the center(s)
    aftershock_steps=2,
    aftershock_r_frac=0.45,      # aftershock radius as a fraction of circle_r_max
    aftershock_pull=2.4,         # extra inward pull during aftershock

    # Where circles are allowed to form (fractions of grid extents)
    circle_center_col_frac=(0.22, 0.66),
    circle_center_row_frac=(0.14, 0.86),

    # MOSH movement bias toward the stage (left side)
    mosh_abs_stage=12.0,      # overall exponential pull toward stage
    mosh_left_boost=150.0,    # special boost for stepping left (toward stage)
    mosh_edge_cols=2,         # near-stage edge zone columns
    mosh_edge_boost=320.0,    # extra boost near the very front/edge columns
    mosh_seek_crowd=2.4,      # mosh likes moving into dense areas near stage
    mosh_push_prob=0.90,      # chance mosh can push/swap into occupied cells
    mosh_swirl=1.05,          # swirl-ish parameter

    # Song length
    steps_per_level=75,       # how many CA updates per song (per energy level)

    # Seed controls randomness
    seed=7
)

# Rendering controls (does NOT change the sim rules)
# NOTE(review): display_stride presumably means "draw every Nth step" — confirm
# against make_animation_with_exhaustion.
display_stride = 1   # one-song GIFs
display_stride_concert = 5  # fewer frames for full-concert GIF (smaller/faster)
interval_ms = 90     # delay between frames in the HTML animation (visual speed)

song_steps = base_kwargs["steps_per_level"]  # convenience: number of steps in a single song

# one-song outputs
# These run the SAME model, but force the entire run to be one fixed energy level.
# All three share base_kwargs (including seed=7), so they differ only in fixed_energy.
sim_low  = PitCASimulator(**base_kwargs, fixed_energy=0)  # low energy song
sim_med  = PitCASimulator(**base_kwargs, fixed_energy=1)  # medium energy song
sim_high = PitCASimulator(**base_kwargs, fixed_energy=2)  # high energy song

# Build animations for each one-song run (3 panels: state, exhaustion map, time-series)
# NOTE(review): gif_path points into assets/ — assumes that directory exists or is
# created by make_animation_with_exhaustion; confirm before running on a fresh runtime.
anim_low  = make_animation_with_exhaustion(sim_low,  total_steps=song_steps, display_stride=display_stride, interval=interval_ms, stage_bar_cols=9, gif_path="assets/low_energy.gif", fps=10, dpi=120)
anim_med  = make_animation_with_exhaustion(sim_med,  total_steps=song_steps, display_stride=display_stride, interval=interval_ms, stage_bar_cols=9, gif_path="assets/medium_energy.gif", fps=10, dpi=120)
anim_high = make_animation_with_exhaustion(sim_high, total_steps=song_steps, display_stride=display_stride, interval=interval_ms, stage_bar_cols=9, gif_path="assets/high_energy.gif", fps=10, dpi=120)

# Display the three animations in the notebook output
display(anim_low)
display(anim_med)
display(anim_high)

# full concert output (song sections drawn on the chart)
# Each entry is a SONG energy label:
# 0 = low, 1 = medium, 2 = high
# This 20-song list is also the baseline setlist for the validation and
# Monte Carlo cells, which read the global `concert`.
concert = [1, 2, 1, 2, 1, 1, 0, 1, 2, 1, 1, 0, 1, 2, 1, 0, 1, 2, 1, 2]

# The sim switches songs automatically based on concert_sequence
sim_concert = PitCASimulator(**base_kwargs, concert_sequence=concert)

# Total simulation steps = steps per song * number of songs
concert_steps = song_steps * len(concert)

# Make a full-concert animation, with shaded sections for each song on the time-series plot
anim_concert = make_animation_with_exhaustion(
    sim_concert,
    total_steps=concert_steps,
    display_stride=display_stride_concert,
    interval=interval_ms,
    stage_bar_cols=9,
    show_song_sections=True,
    gif_path="assets/concert_baseline.gif",
    fps=10,
    dpi=110
)

display(anim_concert)

## Section 4: Model validation

Before running Monte Carlo comparisons, we run two fast tests:

1) **Basic checks (correctness):**  
   Verifies the simulation produces valid arrays, exhaustion stays in bounds (0–10), and the number of agents stays constant (moves/swaps should not create or destroy people).

2) **Repeatability check (determinism):**  
   Runs the exact same setlist twice with the same seed and confirms every metric matches (to within a 1e-12 floating-point tolerance).  
   This matters because Monte Carlo comparisons only make sense if “same inputs → same outputs.”

These tests use a **small grid** and **short songs** so they run in a few seconds.

In [None]:
def basic_checks(sim_kwargs, seed=123, steps=50):
    """
    Quick correctness checks on a tiny one-song simulation.

    What we check, right after initialization and again after EVERY update:
    - grid/exhaustion arrays exist and have the expected shape
    - exhaustion values are finite (no NaN/inf) and stay within [0, 10]
    - the number of occupied cells equals the count right after initialization
      (assuming the model only moves/swaps agents and doesn’t delete/create people)

    Checking every step (rather than only around one final update) means a
    transient violation in the middle of the run cannot slip through.

    Parameters
    ----------
    sim_kwargs : dict
        Keyword arguments for PitCASimulator. "seed" and "concert_sequence"
        are overridden here (the run is a single medium-energy song, [1]).
    seed : int
        Seed passed to the simulator.
    steps : int
        Number of updates to run and verify.

    Raises
    ------
    AssertionError
        If any invariant is violated; the message names the failing step.
    """
    # Run a tiny 1-song simulation so this stays fast
    sim = PitCASimulator(**{**sim_kwargs, "seed": int(seed), "concert_sequence": [1]})
    sim.initialize()

    # Reference head count: grid > 0 means a person is in that cell
    expected_occ = int((sim.grid > 0).sum())

    def _assert_invariants(when):
        # Basic invariants about shapes
        assert sim.grid.shape == (sim.n, sim.n), f"grid has the wrong shape ({when})"
        assert sim.exh.shape == (sim.n, sim.n), f"exh has the wrong shape ({when})"

        # Exhaustion should always be real + bounded
        assert np.isfinite(sim.exh).all(), f"exhaustion contains NaN or inf ({when})"
        assert sim.exh.min() >= -1e-6, f"exhaustion went below 0 (should be clipped) ({when})"
        assert sim.exh.max() <= 10.0 + 1e-6, f"exhaustion exceeded 10 (should be clipped) ({when})"

        # Occupancy conservation: moves/swaps must not create or destroy people
        occ_now = int((sim.grid > 0).sum())
        assert occ_now == expected_occ, f"Occupancy changed: {expected_occ} -> {occ_now} ({when})"

    _assert_invariants("after initialize")

    # Step a bit, with a progress bar so it’s obvious it’s running
    for step in tqdm(range(steps), desc="Basic checks (updates)", unit="step", leave=False):
        sim.update()
        _assert_invariants(f"after update {step + 1}")

    print("✅ Basic checks passed: shapes, exhaustion bounds, and occupancy conservation.")

def pace_setlist_cooldown_after_high(seq, cooldown_level=0):
    """
    Return a same-length copy of a setlist with a cooldown after each high-energy song.

    The list is scanned left to right. Whenever the song just before a position
    is high energy (2), the song at that position is overwritten with
    `cooldown_level`. The scan reads the list it is editing, so an overwritten
    song is judged by its new level (e.g. [2, 2, 1] -> [2, 0, 1]). The final
    song never triggers a change because nothing follows it.

    Why this policy:
    - It is easy to interpret (“always cool down after a banger”)
    - It gives a clear alternative strategy to compare against the baseline

    Parameters
    ----------
    seq : iterable
        Song energy levels; every entry is converted with int().
    cooldown_level : int
        Level written after a high-energy song (default 0).

    Returns
    -------
    list of int
        A new list; the input is not modified.
    """
    paced = [int(level) for level in seq]  # fresh list, input left untouched
    for pos in range(1, len(paced)):
        if paced[pos - 1] == 2:
            paced[pos] = int(cooldown_level)
    return paced

def run_trial_metrics(sim_kwargs, concert_sequence, seed, exh_thr=5.0, sample_stride=10):
    """
    Run ONE simulated concert and compute a few simple scalar metrics.

    Metrics we compute:
    - danger_exposure: time-average fraction of people with exhaustion >= exh_thr
    - danger_peak: max fraction of people with exhaustion >= exh_thr at any time
    - mean_exh_exposure: time-average of mean exhaustion
    - mean_exh_peak: max mean exhaustion observed
    - mosh_mean: average MOSH fraction among occupied cells (sampled every sample_stride steps)

    Speed trick:
    - We override sim._record_exhaustion so the sim doesn't waste time calculating percentiles
      and storing long histories we don't need for Monte Carlo.

    Parameters
    ----------
    sim_kwargs : dict
        Keyword arguments for PitCASimulator; "seed" and "concert_sequence"
        are overridden by the arguments below.
    concert_sequence : iterable of int
        Song energy levels. Materialized once, so any iterable works.
    seed : int
        Seed passed to the simulator.
    exh_thr : float
        Exhaustion value at or above which a person counts as "in danger".
    sample_stride : int
        MOSH fraction is sampled on steps where t % sample_stride == 0.
        Must be non-zero.

    Returns
    -------
    dict of str -> float
        The five metrics listed above. An empty setlist runs zero steps and
        yields 0.0 for every metric instead of dividing by zero.
    """
    # Materialize once so the simulator and the step count see the same list
    setlist = list(concert_sequence)

    sim = PitCASimulator(**{**sim_kwargs, "seed": int(seed), "concert_sequence": setlist})
    sim.initialize()

    # Big speed win: disable percentile/history tracking during the trial
    sim._record_exhaustion = lambda: None

    steps = sim.steps_per_level * len(setlist)

    danger_sum = 0.0
    danger_peak = 0.0
    mean_exh_sum = 0.0
    mean_exh_peak = 0.0

    mosh_sum = 0.0
    mosh_count = 0

    for t in range(steps):
        sim.update()

        # occ = which cells currently contain a person
        occ = (sim.grid > 0)
        has_people = bool(occ.any())

        # Exhaustion metrics use occupied cells only; an empty grid contributes
        # zero to the sums (very defensive: the grid should never be empty)
        if has_people:
            vals = sim.exh[occ]

            mean_exh = float(vals.mean())
            mean_exh_sum += mean_exh
            mean_exh_peak = max(mean_exh_peak, mean_exh)

            danger = float((vals >= exh_thr).mean())
            danger_sum += danger
            danger_peak = max(danger_peak, danger)

        # Sample MOSH fraction less frequently to save time
        if (t % sample_stride) == 0:
            if has_people:
                mosh_frac = float((sim.grid == sim.MOSH).sum() / occ.sum())
            else:
                mosh_frac = 0.0
            mosh_sum += mosh_frac
            mosh_count += 1

    # Convert sums into averages; max(1, ...) guards the zero-step case the
    # same way the MOSH average is guarded
    n_steps = max(1, steps)
    return dict(
        danger_exposure=float(danger_sum / n_steps),
        danger_peak=float(danger_peak),
        mean_exh_exposure=float(mean_exh_sum / n_steps),
        mean_exh_peak=float(mean_exh_peak),
        mosh_mean=float(mosh_sum / max(1, mosh_count)),
    )

def repeatability_check(sim_kwargs, concert_sequence, seed=777, exh_thr=5.0):
    """
    Repeatability (determinism) test.

    Runs run_trial_metrics twice with the same parameters, setlist and seed,
    then asserts that every metric from the first run is within 1e-12 of the
    same metric from the second run.

    Monte Carlo comparisons are only fair if "same inputs -> same outputs",
    which is what this verifies.

    Raises
    ------
    AssertionError
        Naming the first metric that differs between the two runs.
    """
    first_run, second_run = (
        run_trial_metrics(sim_kwargs, concert_sequence, seed=seed, exh_thr=exh_thr)
        for _ in range(2)
    )

    for name, first_value in first_run.items():
        second_value = second_run[name]
        gap = abs(first_value - second_value)
        assert gap < 1e-12, f"Repeatability failed on {name}: {first_value} vs {second_value}"

    print("✅ Repeatability check passed: same seed → identical metrics.")

# --- Validation driver: small grid + short songs so both checks finish in seconds ---
# Start from the baseline parameters and shrink only the grid and the song length.
test_kwargs = {**base_kwargs, "n": 25, "steps_per_level": 12}

# 1) Correctness: shapes, exhaustion bounds, occupancy conservation
basic_checks(sim_kwargs=test_kwargs, seed=321, steps=25)

# 2) Determinism: first six songs of the concert, run twice with one seed
test_setlist = list(concert[:6])
repeatability_check(
    sim_kwargs=test_kwargs,
    concert_sequence=test_setlist,
    seed=777,
    exh_thr=5.0,
)

## Section 5: Monte Carlo comparison

This cell runs a **Monte Carlo experiment** to compare two setlist strategies:

- **Baseline:** the original sequence of song energy levels (0/1/2).
- **Paced:** a pacing policy that edits the same-length setlist (here: after a high-energy song, force a cooldown song next).

For each strategy, we run many randomized simulations (different seeds) and measure outcomes like:
- average and peak fraction of the crowd at or above a **moderate exhaustion threshold** (≥ EXH_THR)
- mean exhaustion exposure over time
- peak mean exhaustion
- mosh participation rate


---


⏱️ This section takes approximately 30 minutes to run in Google Colab.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from matplotlib.ticker import PercentFormatter
import matplotlib.patheffects as pe
from matplotlib.lines import Line2D

# ---- Experiment knobs ----
N_TRIALS = 20          # trials per strategy (baseline + paced)
BASE_SEED = 1000       # master seed start
EXH_THR = 5.0          # "moderate exhaustion" threshold for danger metrics
SAMPLE_STRIDE = 10     # sample MOSH participation every N steps (speed tradeoff)

# ---- Histogram appearance ----
HIST_GAP = 0.12
HIST_BINS = "fd"       # Freedman–Diaconis
USE_DENSITY = False    # False = counts (cleaner for small N)

# ---- Simulation settings used only for Monte Carlo ----
# Same as base_kwargs except for a smaller grid and shorter songs (faster trials);
# density is restated to keep it high and stress-test exhaustion.
mc_kwargs = {**base_kwargs, "n": 30, "steps_per_level": 30, "density": 0.80}

# ---- The two setlists being compared (same length) ----
baseline_seq = list(concert)
paced_seq = pace_setlist_cooldown_after_high(baseline_seq, cooldown_level=0)

print("Baseline sequence:", baseline_seq)
print("Paced sequence   :", paced_seq)
print(f"Using density={mc_kwargs.get('density', 'N/A')} for stress test.")
print(f"Number of Monte Carlo trials per strategy: {N_TRIALS}")

def run_mc(baseline_seq, paced_seq, sim_kwargs, n_trials, base_seed=1000, exh_thr=5.0, sample_stride=10):
    """
    Run the Monte Carlo experiment for both strategies.

    Baseline trials use seeds base_seed + i; paced trials use
    base_seed + 10_000 + i. All baseline trials run first, then all paced ones.

    Returns
    -------
    pd.DataFrame
        One row per trial: strategy, seed, plus every metric returned by
        run_trial_metrics(...).
    """
    # Create both progress bars up front so they stack (baseline above paced)
    bar_base = tqdm(total=n_trials, desc="Monte Carlo — Baseline", unit="trial", position=0, leave=True)
    bar_paced = tqdm(total=n_trials, desc="Monte Carlo — Paced", unit="trial", position=1, leave=True)

    # (strategy label, setlist, first seed, progress bar), in execution order
    plan = (
        ("baseline", baseline_seq, base_seed, bar_base),
        ("paced", paced_seq, base_seed + 10_000, bar_paced),
    )

    records = []
    try:
        for label, setlist, first_seed, bar in plan:
            for offset in range(n_trials):
                trial_seed = first_seed + offset
                metrics = run_trial_metrics(
                    sim_kwargs,
                    setlist,
                    seed=trial_seed,
                    exh_thr=exh_thr,
                    sample_stride=sample_stride,
                )
                records.append({"strategy": label, "seed": trial_seed, **metrics})
                bar.update(1)
    finally:
        # Close the bars whether or not a trial raised
        bar_base.close()
        bar_paced.close()

    return pd.DataFrame(records)

def mean_ci95(x):
    """
    Mean with a normal-approximation 95% confidence interval.

    Non-finite entries (NaN, ±inf) are dropped first. The interval is
    mean ± 1.96 * SE, where SE is the sample standard deviation (ddof=1)
    divided by sqrt(n).

    Returns
    -------
    tuple
        (mean, lower_95, upper_95). With fewer than two finite values the
        bounds are NaN; with none, the mean is NaN as well.
    """
    finite = np.asarray(x, dtype=float)
    finite = finite[np.isfinite(finite)]
    n = finite.size

    # Too few points to estimate a spread
    if n == 0:
        return np.nan, np.nan, np.nan
    if n == 1:
        return float(np.nanmean(finite)), np.nan, np.nan

    center = float(finite.mean())
    std_err = float(finite.std(ddof=1)) / np.sqrt(n)
    margin = 1.96 * std_err
    return center, center - margin, center + margin

def summarize_ci(df):
    """
    Build a long-format summary of the Monte Carlo results.

    Every column of `df` other than "strategy" and "seed" is treated as a
    metric. For each strategy ("baseline", then "paced") and each metric, one
    output row holds the mean and its 95% CI from mean_ci95.

    Returns
    -------
    pd.DataFrame
        Columns: strategy | metric | mean | ci95_low | ci95_high | n,
        where n is the number of rows for that strategy.
    """
    metric_cols = [name for name in df.columns if name not in ("strategy", "seed")]

    records = []
    for strategy in ("baseline", "paced"):
        group = df[df["strategy"] == strategy]
        n_runs = len(group)
        for metric in metric_cols:
            center, lower, upper = mean_ci95(group[metric].values)
            records.append(
                dict(
                    strategy=strategy,
                    metric=metric,
                    mean=center,
                    ci95_low=lower,
                    ci95_high=upper,
                    n=n_runs,
                )
            )
    return pd.DataFrame(records)

def shared_bins(data, bins="fd"):
    """
    Compute one set of histogram bin edges for pooled baseline + paced data,
    so both histograms are drawn on the same bins and can be compared.

    Parameters
    ----------
    data : array-like
        Pooled values; non-finite entries are ignored.
    bins : "fd" or int-like
        "fd" picks the bin count with the Freedman–Diaconis rule (clamped to
        10..40, or 20 when the IQR is zero). Anything else is converted with
        int() and used as a fixed bin count.

    Returns
    -------
    np.ndarray
        Evenly spaced bin edges. With no finite data: 0..1 in 10 bins. When
        all values are (nearly) equal: a ±1e-6 window around them.
    """
    values = np.asarray(data, dtype=float)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return np.linspace(0, 1, 11)

    use_fd = bins == "fd"
    fixed_count = None if use_fd else int(bins)

    lo = float(values.min())
    hi = float(values.max())

    # Degenerate range: widen slightly so the edges are still increasing
    if np.allclose(lo, hi):
        count = 10 if use_fd else fixed_count
        return np.linspace(lo - 1e-6, hi + 1e-6, count + 1)

    if not use_fd:
        return np.linspace(lo, hi, fixed_count + 1)

    # Freedman–Diaconis: bin width = 2 * IQR / n^(1/3)
    upper_q, lower_q = np.percentile(values, [75, 25])
    spread = upper_q - lower_q
    if spread <= 0:
        count = 20
    else:
        width = 2 * spread / (values.size ** (1/3))
        count = int(np.ceil((hi - lo) / width)) if width > 0 else 20
        count = int(np.clip(count, 10, 40))
    return np.linspace(lo, hi, count + 1)

def plot_hist_compare(df, metric, title, xlabel, bins="fd", as_percent=False, density=False, gap=0.12):
    """
    Draw baseline vs paced histograms of one metric on shared bins.

    Baseline is drawn as filled bars, paced as a step outline with a white
    halo, and each strategy's mean as a dashed vertical line in the matching
    colour.

    Parameters
    ----------
    df : pd.DataFrame
        Monte Carlo results with a "strategy" column ("baseline" / "paced")
        and a column named `metric`.
    metric : str
        Column to plot.
    title, xlabel : str
        Axes title and x-axis label.
    bins : "fd" or int
        Passed to shared_bins(...) to get bin edges common to both strategies.
    as_percent : bool
        Format the x-axis as percentages (for metrics that are fractions in [0, 1]).
    density : bool
        Plot probability density instead of run counts.
    gap : float
        Visual gap between baseline bars; the bar width is max(0.5, 1 - gap)
        of the bin width.
    """
    base = df.loc[df["strategy"] == "baseline", metric].dropna().to_numpy()
    paced = df.loc[df["strategy"] == "paced", metric].dropna().to_numpy()

    both = np.concatenate([base, paced])
    edges = shared_bins(both, bins=bins)

    fig, ax = plt.subplots(figsize=(7.7, 4.4), dpi=150)

    # Baseline histogram (filled)
    rwidth = max(0.5, 1.0 - gap)
    _, _, base_patches = ax.hist(
        base,
        bins=edges,
        density=density,
        alpha=0.45,
        label="Baseline Strategy",
        rwidth=rwidth,
        edgecolor="white",
        linewidth=1.1,
    )

    # Paced histogram (outline only)
    _, _, paced_patches = ax.hist(
        paced,
        bins=edges,
        density=density,
        histtype="step",
        linewidth=2.4,
        label="Paced Strategy",
    )

    # A step histogram is returned as Polygon patches (it never appears in
    # ax.lines), so take the outline artist from the hist() return value.
    paced_outline = paced_patches[0] if len(paced_patches) else None

    # Add a white halo so the paced outline stands out on top of bars
    if paced_outline is not None:
        paced_outline.set_path_effects([
            pe.Stroke(linewidth=paced_outline.get_linewidth() + 2.2, foreground="white"),
            pe.Normal()
        ])

    # Mean markers, coloured to match their histogram
    base_mean = float(base.mean()) if len(base) else np.nan
    paced_mean = float(paced.mean()) if len(paced) else np.nan

    base_color = base_patches[0].get_facecolor() if len(base_patches) else "black"
    paced_color = paced_outline.get_edgecolor() if paced_outline is not None else "black"

    # Skip the marker when a strategy has no data (mean is NaN)
    if np.isfinite(base_mean):
        ax.axvline(base_mean, linestyle="--", linewidth=2.0, color=base_color)
    if np.isfinite(paced_mean):
        ax.axvline(paced_mean, linestyle="--", linewidth=2.0, color=paced_color)

    # Titles / labels
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Probability Density" if density else "Number of Runs")

    # Format x-axis as percent when the metric is a fraction in [0,1]
    if as_percent:
        ax.xaxis.set_major_formatter(PercentFormatter(xmax=1.0, decimals=0))

    ax.grid(True, alpha=0.22)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Add some padding around the data range; with no finite data keep the
    # default limits instead of failing on an empty min/max
    finite = both[np.isfinite(both)]
    if finite.size:
        xmin, xmax = float(finite.min()), float(finite.max())
        pad = 0.06 * (xmax - xmin + 1e-12)
        ax.set_xlim(xmin - pad, xmax + pad)

    mean_handles = [
        Line2D([0], [0], color=base_color, linestyle="--", linewidth=2.0, label="Baseline Mean"),
        Line2D([0], [0], color=paced_color, linestyle="--", linewidth=2.0, label="Paced Mean"),
    ]
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles + mean_handles, labels + [h.get_label() for h in mean_handles],
              frameon=False, loc="best")

    fig.tight_layout()
    plt.show()

# ---- 1) Run both strategies ----
df_mc = run_mc(
    baseline_seq,
    paced_seq,
    mc_kwargs,
    N_TRIALS,
    base_seed=BASE_SEED,
    exh_thr=EXH_THR,
    sample_stride=SAMPLE_STRIDE,
)

# ---- 2) Full summary: mean + 95% CI for every metric and strategy ----
summary = summarize_ci(df_mc)
display(summary.sort_values(["metric", "strategy"]))

# ---- 3) Headline table: baseline vs paced, with percent change ----
# (metric column, human-readable label) pairs, in display order
KEY_METRICS = [
    ("danger_exposure",      f"Average fraction with exhaustion ≥ {EXH_THR}"),
    ("danger_peak",          f"Peak fraction with exhaustion ≥ {EXH_THR}"),
    ("mean_exh_exposure",    "Time-averaged mean exhaustion"),
    ("mean_exh_peak",        "Peak mean exhaustion"),
    ("mosh_mean",            "Average mosh participation fraction"),
]

headline_rows = []
for metric, label in KEY_METRICS:
    b = df_mc.loc[df_mc["strategy"] == "baseline", metric].to_numpy()
    p = df_mc.loc[df_mc["strategy"] == "paced", metric].to_numpy()
    b_m, b_lo, b_hi = mean_ci95(b)
    p_m, p_lo, p_hi = mean_ci95(p)

    # Relative change of paced vs baseline; undefined when the baseline mean
    # is missing or (numerically) zero
    baseline_usable = np.isfinite(b_m) and abs(b_m) > 1e-12
    pct = 100.0 * (p_m - b_m) / b_m if baseline_usable else np.nan

    headline_rows.append({
        "Outcome Metric": label,
        "Baseline mean (95% CI)": f"{b_m:.4f} [{b_lo:.4f}, {b_hi:.4f}]",
        "Paced mean (95% CI)":    f"{p_m:.4f} [{p_lo:.4f}, {p_hi:.4f}]",
        "Percent change (paced vs baseline)": f"{pct:.1f}%",
    })

headline = pd.DataFrame(headline_rows)
display(headline)

# ---- 4) One comparison histogram per headline metric ----
# Metrics that are fractions in [0, 1] get a percent-formatted x-axis
PERCENT_METRICS = {"danger_exposure", "danger_peak", "mosh_mean"}

for metric, label in tqdm(KEY_METRICS, desc="Plotting key distributions", unit="plot", leave=False):
    plot_hist_compare(
        df_mc,
        metric=metric,
        title=f"{label}: Baseline vs Paced",
        xlabel=label,
        bins=HIST_BINS,
        as_percent=metric in PERCENT_METRICS,
        density=USE_DENSITY,
        gap=HIST_GAP,
    )

## Section 6: Paced concert sequence simulation

Uses the same baseline parameter set (`base_kwargs`) from Section 3, but applies it to the **paced** concert sequence.

---

⏱️ This section takes approximately 30 minutes to run in Google Colab.

In [None]:
# Rendering controls
display_stride = 1   # (unused; kept for reference)
display_stride_concert = 5  # fewer frames for full-concert GIF (smaller/faster)
interval_ms = 90     # delay between frames in the HTML animation (visual speed)

song_steps = base_kwargs["steps_per_level"]

# Full PACED concert sequence (0=low, 1=medium, 2=high)
# This literal equals pace_setlist_cooldown_after_high(<Section 3 setlist>, cooldown_level=0).
# NOTE(review): this rebinds the notebook-global `concert`, which Section 5 reads as
# its BASELINE setlist (`baseline_seq = list(concert)`). Re-running Section 5 after
# this cell would silently use the paced list as the baseline; re-run Section 3 first,
# or consider a separate name here.
concert = [1, 2, 0, 2, 0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 0, 0, 1, 2, 0, 2]

# The sim switches songs automatically based on concert_sequence
# (also rebinds sim_concert / anim_concert from Section 3)
sim_concert = PitCASimulator(**base_kwargs, concert_sequence=concert)

# Total simulation steps = steps per song * number of songs
concert_steps = song_steps * len(concert)

# Make a full-concert animation, with shaded sections for each song on the time-series plot
# NOTE(review): assumes the assets/ directory exists or is created by
# make_animation_with_exhaustion — confirm.
anim_concert = make_animation_with_exhaustion(
    sim_concert,
    total_steps=concert_steps,
    display_stride=display_stride_concert,
    interval=interval_ms,
    stage_bar_cols=9,
    show_song_sections=True,
    gif_path="assets/concert_paced.gif",
    fps=10,
    dpi=110
)

display(anim_concert)