# BME 590 — Synthetic Microbiome Data Generator v3

Generalized Lotka-Volterra (gLV) dynamics with regime switching, mechanistic extensions, a realistic sequencing observation model, and full benchmark infrastructure.

### Fixes Applied
- **#9** Absorbing-zero: immigration (propagule rain) in ALL RHS functions
- **#10** Hidden/cumulative trigger heterogeneity: per-community parameter jitter
- **#11** Single-gLV baseline: nonlinear least squares (NLS) replaces the finite-difference ridge regression fit
- **#12** Ground-truth regime parameters exported
- **#13** Per-community switch-time estimation in metadata

### New Functionality
- Environmental drift scenario (non-stationary parameters)
- Global carrying capacity constraint
- Allee effects (cooperative growth at low density)
- Shannon diversity tracking
- Comprehensive validation suite
- 12 registered scenarios

## 0. Imports & Global Configuration

In [None]:
#!/usr/bin/env python3
"""
BME 590 — Synthetic Microbiome Data Generator v3
=================================================
gLV with regime switching, mechanistic extensions, realistic sequencing
observation model, and full benchmark infrastructure.

Fixes applied (from issue list):
  - Absorbing-zero: immigration (propagule rain) in ALL RHS functions
  - Hidden/cumulative trigger heterogeneity: per-community parameter jitter
  - Single-gLV baseline: gradient-matching replaced with proper NLS
  - Ground-truth regime parameters exported
  - Per-community switch-time estimation in metadata
  - Environmental drift scenario added
  - Carrying-capacity constraint added
  - Comprehensive validation suite

Additions beyond checklist:
  - Allee effects (cooperative growth at low density)
  - pH / environmental coupling
  - Lotka-Volterra with functional response (Type II Holling)
  - Compositional (CLR) export for downstream analysis
  - Shannon diversity tracking
  - Pairwise community distance matrix export
"""

## 0. Global Configuration

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §0  IMPORTS & GLOBAL CONFIGURATION
# ══════════════════════════════════════════════════════════════════════

from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Tuple, Union
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import warnings
import json
import numpy as np
import pandas as pd
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

Number = Union[float, int]


@dataclass
class GeneratorConfig:
    """Central configuration object governing all generator behaviour."""

    # ── Reproducibility ──────────────────────────────────────────────
    seed: int = 42
    dataset_version: str = "v3"

    # ── Core simulation ──────────────────────────────────────────────
    n_species: int = 5
    n_metabolites: int = 1

    t_start: float = 0.0
    t_end: float = 20.0
    n_timepoints: int = 101

    # ── Scenario ─────────────────────────────────────────────────────
    scenario: str = "baseline_glv_stable"

    # ── Ecological matrix structure ──────────────────────────────────
    A_structure: str = "sparse"
    A_structure_kwargs: dict = field(default_factory=lambda: dict(
        p=0.15, offdiag_std=0.03, diag_range=(-0.6, -0.2)))
    A_structure_kwargs2: Optional[dict] = None
    stability_margin: float = 0.03

    # ── Hierarchical variation ───────────────────────────────────────
    enable_hierarchical: bool = True
    sigma_r: float = 0.15
    sigma_A: float = 0.05

    # ── Regime switching ─────────────────────────────────────────────
    regime_distance: float = 1.0
    t_switch: float = 10.0
    epsilon: float = 1.0
    smooth_k: float = 10.0
    per_comm_switch: bool = False
    switch_spread: float = 3.0

    # ── Hidden trigger ───────────────────────────────────────────────
    k_u: float = 0.4
    u0: float = 0.0
    theta: float = 0.5
    # Per-community heterogeneity for hidden/cumulative triggers
    trigger_sigma_theta: float = 0.0    # std of per-community theta jitter
    trigger_sigma_k_u: float = 0.0      # std of per-community k_u jitter

    # ── Cumulative trigger ───────────────────────────────────────────
    c1: float = 1.0
    M_init: float = 0.05
    a0: float = 0.0
    idx_M: Optional[List[int]] = None

    # ── Resource dynamics ────────────────────────────────────────────
    enable_forcing: bool = False
    diet_shocks: Optional[List[Dict]] = None
    export_resources: bool = False

    # ── Bistability ──────────────────────────────────────────────────
    bistable_saddle_dist: float = 0.3
    bistable_comp_strength: float = 0.25
    bistable_fac_strength: float = 0.02

    # ── Dormancy ─────────────────────────────────────────────────────
    dormancy_rate: float = 0.05
    revival_rate: float = 0.10
    dormancy_threshold: float = 0.01

    # ── Antibiotic pulse ─────────────────────────────────────────────
    enable_antibiotic: bool = False
    antibiotic_start: float = 8.0
    antibiotic_duration: float = 3.0
    antibiotic_kill_rates: Optional[List[float]] = None

    # ── Bloom ────────────────────────────────────────────────────────
    enable_bloom: bool = False
    bloom_species: int = 0
    bloom_start: float = 5.0
    bloom_boost: float = 2.0
    bloom_duration: float = 2.0

    # ── Environmental drift ──────────────────────────────────────────
    enable_drift: bool = False
    drift_rate_r: float = 0.01       # magnitude of r drift per unit time
    drift_rate_A: float = 0.005      # magnitude of A drift per unit time

    # ── Carrying capacity ────────────────────────────────────────────
    enable_carrying_cap: bool = False
    carrying_capacity: float = 2.0   # max total biomass

    # ── Allee effects ────────────────────────────────────────────────
    enable_allee: bool = False
    allee_threshold: float = 0.005   # half-saturation density: positive growth scaled by x/(x + threshold)

    # ── Observation model ────────────────────────────────────────────
    observation_mode: str = "continuous"
    library_size_mean: float = 1e4
    library_size_sigma: float = 0.6
    dm_alpha_scale: float = 100.0
    enable_dropout: bool = False
    detection_limit: float = 2.0
    n_replicates: int = 1

    # ── Sampling design ──────────────────────────────────────────────
    irregular_sampling: bool = False
    irregular_keep_frac: float = 0.6
    missing_rate: float = 0.0

    # ── Immigration (propagule rain) ─────────────────────────────────
    immigration_rate: float = 1e-4
    immigration_scale: float = 1.0

    # ── Exports ──────────────────────────────────────────────────────
    export_latent: bool = True
    export_regimes: bool = True
    export_ground_truth_params: bool = True


CONFIG = GeneratorConfig()
print("CONFIG initialised. Scenario:", CONFIG.scenario)
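Scenario variants can be derived without mutating the shared `CONFIG` by using the standard-library `dataclasses.replace`. The following minimal sketch uses a hypothetical `MiniConfig` that mirrors a few `GeneratorConfig` fields, so it runs on its own:

```python
from dataclasses import asdict, dataclass, field, replace
import json


@dataclass
class MiniConfig:
    """Hypothetical stand-in for GeneratorConfig (subset of fields only)."""
    seed: int = 42
    scenario: str = "baseline_glv_stable"
    t_switch: float = 10.0
    A_structure_kwargs: dict = field(default_factory=lambda: dict(p=0.15))


BASE = MiniConfig()

# replace() builds a new instance with the given overrides; BASE is untouched.
variant = replace(BASE, scenario="time_switch", t_switch=8.0)

# asdict() converts recursively to plain Python containers for JSON metadata.
meta_json = json.dumps(asdict(variant), sort_keys=True)
print(meta_json)
```

`replace` passes unchanged field values straight through. As a result, mutable fields such as `A_structure_kwargs` are shared between the original and the copy, so pass a fresh dict when a variant needs different structure kwargs.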

## 1. RNG Hierarchy

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §1  RNG HIERARCHY
# ══════════════════════════════════════════════════════════════════════

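import zlib


def make_rng(stream_name: str, base_seed: Optional[int] = None) -> np.random.Generator:
    """
    Independent, reproducible Generator for a named stream.

    The stream name is mapped to an integer with CRC32 rather than the
    built-in hash(), which is salted per Python process and would give
    different seeds (and different data) in every new session.
    """
    if base_seed is None:
        base_seed = CONFIG.seed
    offset = zlib.crc32(stream_name.encode("utf-8")) % (2**31)
    return np.random.default_rng(base_seed + offset)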


RNG = dict(
    params      = make_rng("params"),
    simulation  = make_rng("simulation"),
    observation = make_rng("observation"),
    sampling    = make_rng("sampling"),
    splits      = make_rng("splits"),
)

def get_rng(name: str = "simulation") -> np.random.Generator:
    return RNG[name]

# Determinism check
assert np.allclose(make_rng("test").normal(size=3), make_rng("test").normal(size=3))
print("RNG hierarchy OK. Streams:", list(RNG.keys()))
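Python's built-in `hash()` for strings is salted per interpreter process unless `PYTHONHASHSEED` is set. It therefore cannot be relied on to give the same per-stream seed across sessions. Note that the determinism check above only compares two calls within one process.

Below is a sketch of a session-stable alternative. It assumes CRC32 from `zlib` as the name-to-integer map and uses `np.random.SeedSequence` to mix that value with the base seed. The name `stable_stream_rng` is hypothetical.

```python
import zlib

import numpy as np


def stable_stream_rng(stream_name: str, base_seed: int = 42) -> np.random.Generator:
    """Hypothetical helper: per-stream Generator identical across Python sessions."""
    # CRC32 of the UTF-8 name is deterministic, unlike the salted built-in hash().
    key = zlib.crc32(stream_name.encode("utf-8"))
    # SeedSequence mixes both integers into well-spread generator state.
    return np.random.default_rng(np.random.SeedSequence([base_seed, key]))


a = stable_stream_rng("simulation").normal(size=3)
b = stable_stream_rng("simulation").normal(size=3)
c = stable_stream_rng("observation").normal(size=3)
print(a, c)
```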

## 2. Parameter Samplers

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §2  PARAMETER SAMPLERS
# ══════════════════════════════════════════════════════════════════════

def _to_array(x, shape):
    arr = np.asarray(x)
    if arr.ndim == 0:
        return np.broadcast_to(arr, shape).astype(float)
    if arr.ndim == 1 and arr.shape[0] == shape[0]:
        return np.broadcast_to(
            arr.reshape(shape[0], *([1]*(len(shape)-1))), shape
        ).astype(float)
    return np.broadcast_to(arr, shape).astype(float)


def generate_gaussian_params_with_diag(
    n_species, vec_mean=0., vec_std=1., vec_bounds=(-np.inf, np.inf),
    mat_mean=0., mat_std=1., mat_bounds=(-np.inf, np.inf),
    diag_mean=0., diag_std=1., diag_bounds=(-np.inf, np.inf),
    seed=None,
):
    rng = np.random.default_rng(seed)
    n = n_species
    v = np.clip(
        rng.normal(_to_array(vec_mean, (n,)), _to_array(vec_std, (n,)), n),
        *vec_bounds
    )
    M = np.clip(
        rng.normal(_to_array(mat_mean, (n, n)), _to_array(mat_std, (n, n)), (n, n)),
        *mat_bounds
    )
    d_mean = np.asarray(diag_mean)
    d_std = np.asarray(diag_std)
    dm = d_mean * np.ones(n) if d_mean.ndim == 0 else d_mean
    ds = d_std * np.ones(n) if d_std.ndim == 0 else d_std
    np.fill_diagonal(M, np.clip(rng.normal(dm, ds, n), *diag_bounds))
    return v, M


print("Parameter samplers defined.")
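`_to_array` lays a length-n vector along the first axis, which differs from NumPy's default broadcasting (trailing axis). For `mat_mean` and `mat_std`, this means a per-species vector sets the mean of row i, i.e. the interactions acting on species i in dx_i/dt. The small check below reuses the helper's logic verbatim:

```python
import numpy as np


def _to_array(x, shape):
    # Same logic as the helper above: a length-shape[0] vector goes along axis 0.
    arr = np.asarray(x)
    if arr.ndim == 0:
        return np.broadcast_to(arr, shape).astype(float)
    if arr.ndim == 1 and arr.shape[0] == shape[0]:
        return np.broadcast_to(
            arr.reshape(shape[0], *([1] * (len(shape) - 1))), shape
        ).astype(float)
    return np.broadcast_to(arr, shape).astype(float)


row_means = np.array([0.0, 1.0, 2.0])
M_mean = _to_array(row_means, (3, 3))           # row i filled with row_means[i]
numpy_default = np.broadcast_to(row_means, (3, 3))  # each row is [0, 1, 2]
print(M_mean)
print(numpy_default)
```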

## 3. Structured Matrices & Stable Regime Builder

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §3  STRUCTURED MATRICES & STABLE REGIME BUILDER
# ══════════════════════════════════════════════════════════════════════

def generate_A_sparse(n, p=0.15, diag_range=(-0.6, -0.15),
                      offdiag_mean=0., offdiag_std=0.03, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    A = np.zeros((n, n))
    A[np.diag_indices(n)] = rng.uniform(*diag_range, size=n)
    mask = rng.random((n, n)) < p
    np.fill_diagonal(mask, False)
    A[mask] = rng.normal(offdiag_mean, offdiag_std, mask.sum())
    return A


def generate_A_modular(n, n_blocks=2, p_in=0.25, p_out=0.05,
                        diag_range=(-0.6, -0.15), offdiag_std=0.03, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    A = np.zeros((n, n))
    A[np.diag_indices(n)] = rng.uniform(*diag_range, size=n)
    blks = np.floor(np.linspace(0, n_blocks, n, endpoint=False)).astype(int)
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            p = p_in if blks[i] == blks[j] else p_out
            if rng.random() < p:
                A[i, j] = rng.normal(0, offdiag_std)
    return A


def generate_A_lowrank_sparse(n, rank=2, p_sparse=0.10,
                               diag_range=(-0.6, -0.15),
                               lowrank_scale=0.02, sparse_std=0.02, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    A = lowrank_scale * (rng.normal(0, 1, (n, rank)) @ rng.normal(0, 1, (rank, n)))
    mask = rng.random((n, n)) < p_sparse
    np.fill_diagonal(mask, False)
    A[mask] += rng.normal(0, sparse_std, mask.sum())
    A[np.diag_indices(n)] = rng.uniform(*diag_range, size=n)
    return A


def construct_stable_regime(n_species, structure="sparse", x_star=None,
                             stability_margin=0.03, rng=None,
                             structure_kwargs=None,
                             eps_shift=1e-3, x_floor=1e-12):
    """
    Build (r, A, x*) such that:
      - r = -A x*  (x* is equilibrium)
      - max Re(eig(J(x*))) < -stability_margin
    Diagonal shift applied if necessary.
    """
    rng = np.random.default_rng() if rng is None else rng
    kw = {} if structure_kwargs is None else dict(structure_kwargs)

    if x_star is None:
        x_star = rng.lognormal(-3.5, 0.7, n_species)
    x_star = np.asarray(x_star, dtype=float)

    builders = {
        "sparse": generate_A_sparse,
        "modular": generate_A_modular,
        "lowrank_sparse": generate_A_lowrank_sparse,
    }
    if structure not in builders:
        raise ValueError(f"Unknown structure '{structure}'")
    A = builders[structure](n_species, rng=rng, **kw)

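    r = -(A @ x_star)
    J = np.diag(x_star) @ A
    max_real = float(np.max(np.real(np.linalg.eigvals(J))))

    # Shifting A by -delta*I moves J by -delta*diag(x*). For non-normal J the
    # eigenvalues need not move by the full delta*min(x*), so re-check after
    # every shift and repeat until the margin actually holds.
    min_x = max(float(np.min(x_star)), x_floor)
    n_shifts = 0
    while max_real >= -stability_margin:
        if n_shifts >= 50:
            raise RuntimeError("Diagonal shift failed to reach the stability margin")
        delta = (max_real + stability_margin + eps_shift) / min_x
        A = A - delta * np.eye(n_species)
        J = np.diag(x_star) @ A
        max_real = float(np.max(np.real(np.linalg.eigvals(J))))
        n_shifts += 1
    r = -(A @ x_star)

    return r, A, x_star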


def sample_hierarchical_community_params(r_base, A_base,
                                          sigma_r=0.15, sigma_A=0.05,
                                          rng=None, present_mask=None):
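    """
    Draw community-specific (r, A) around shared base parameters.

    Self-interaction terms are forced negative afterwards, but the jittered
    (r, A) is not re-checked for a feasible or stable interior equilibrium.
    Species outside present_mask receive no jitter.
    """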
    rng = np.random.default_rng() if rng is None else rng
    r_base = np.asarray(r_base, dtype=float)
    A_base = np.asarray(A_base, dtype=float)
    n = r_base.shape[0]
    r_n = rng.normal(0, sigma_r, n)
    A_n = rng.normal(0, sigma_A, (n, n))
    if present_mask is not None:
        mask = np.asarray(present_mask, dtype=bool)
        r_n[~mask] = 0.
        A_n[~mask, :] = 0.
        A_n[:, ~mask] = 0.
    r_c = r_base + r_n
    A_c = A_base + A_n
    # Ensure self-regulation remains negative
    A_c[np.diag_indices(n)] = -np.abs(np.diag(A_c)) - 1e-6
    return r_c, A_c


print("Stable regime builder and hierarchical sampler defined.")
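The stability test in `construct_stable_regime` uses J = diag(x*) A. Differentiating f_i(x) = x_i (r_i + Σ_j A_ij x_j) gives ∂f_i/∂x_k = δ_ik (r_i + Σ_j A_ij x_j) + x_i A_ik, and the first term vanishes at x* because r = -A x*. The code below checks that identity numerically with central differences. The matrix draw here is illustrative rather than one of the generator's structures.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
x_star = rng.lognormal(-3.5, 0.7, n)
# Illustrative matrix: negative diagonal plus small off-diagonal noise
A = np.diag(rng.uniform(-0.6, -0.2, n))
A += 0.03 * rng.normal(size=(n, n)) * (1 - np.eye(n))
r = -(A @ x_star)                       # makes x_star an equilibrium


def glv_rhs(x):
    return x * (r + A @ x)


J = np.diag(x_star) @ A                 # analytic Jacobian at x_star

# Central differences are exact (up to rounding) for this quadratic RHS
h = 1e-7
J_fd = np.column_stack([
    (glv_rhs(x_star + h * e) - glv_rhs(x_star - h * e)) / (2 * h)
    for e in np.eye(n)
])
max_real = float(np.max(np.real(np.linalg.eigvals(J))))
print("max Re(eig(J)) =", max_real)
```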

## 4. Community Table — Sparse Initialisation

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §4  COMMUNITY TABLE — SPARSE INITIALISATION
# ══════════════════════════════════════════════════════════════════════

def generate_community_dataframe(
    n_species, size_counts, mean_abundance=0.01,
    bounds=(0.001, 0.1), sigma_log=1.0, seed=None,
    time_value=0.0, comm_name_prefix="comm",
):
    """
    Generate initial communities with sparse composition.
    Absent species have abundance EXACTLY 0.0.
    Coverage guarantee: if size-1 communities are requested, every species
    appears as a singleton at least once (their count is raised to n_species).

    Returns
    -------
    df   : pd.DataFrame   [Comm_name, Time, sp1..spN]
    meta : dict           comm_name -> list of present species indices (0-based)
    """
    rng = np.random.default_rng(seed)
    mu_log = np.log(mean_abundance) - 0.5 * sigma_log**2
    rows, meta = [], {}
    counter = 1

    for size, count in sorted(size_counts, key=lambda x: x[0]):
        size, count = int(size), int(count)
        if not (1 <= size <= n_species):
            raise ValueError(f"Community size {size} not in [1, {n_species}]")

        if size == 1:
            count = max(count, n_species)  # coverage guarantee
            species_seq = list(range(n_species))
            if count > n_species:
                species_seq += rng.integers(0, n_species, count - n_species).tolist()
            for sp_idx in species_seq:
                abund = np.zeros(n_species)
                abund[sp_idx] = float(np.clip(
                    rng.lognormal(mu_log, sigma_log), *bounds
                ))
                name = f"{comm_name_prefix}{counter}"
                rows.append([name, time_value] + abund.tolist())
                meta[name] = [sp_idx]
                counter += 1
        else:
            for _ in range(count):
                present = sorted(
                    rng.choice(n_species, size, replace=False).tolist()
                )
                abund = np.zeros(n_species)
                sampled = np.clip(rng.lognormal(mu_log, sigma_log, size), *bounds)
                for k, sp in enumerate(present):
                    abund[sp] = float(sampled[k])
                name = f"{comm_name_prefix}{counter}"
                rows.append([name, time_value] + abund.tolist())
                meta[name] = present
                counter += 1

    cols = ["Comm_name", "Time"] + [f"sp{i+1}" for i in range(n_species)]
    return pd.DataFrame(rows, columns=cols), meta


print("generate_community_dataframe defined.")
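The initial abundances use `mu_log = log(mean_abundance) - sigma_log**2 / 2`. This makes the lognormal mean exactly `mean_abundance`, because E[exp(N(μ, σ²))] = exp(μ + σ²/2). The quick Monte Carlo check below uses the default arguments. Note that the subsequent clip to `bounds` shifts the realised mean slightly.

```python
import numpy as np

mean_abundance, sigma_log = 0.01, 1.0
bounds = (0.001, 0.1)
mu_log = np.log(mean_abundance) - 0.5 * sigma_log**2

rng = np.random.default_rng(1)
draws = rng.lognormal(mu_log, sigma_log, 200_000)

unclipped_mean = draws.mean()                 # should be close to mean_abundance
clipped_mean = np.clip(draws, *bounds).mean() # what the generator actually uses
print(f"unclipped {unclipped_mean:.5f}, clipped {clipped_mean:.5f}")
```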

## 5. Core gLV Simulator

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §5  CORE gLV SIMULATOR (with immigration in ALL RHS)
# ══════════════════════════════════════════════════════════════════════

def _get_sp_cols(df):
    return [c for c in df.columns if c.startswith("sp")]


def _build_immigration_vec(config, r_base):
    """
    Per-species immigration (propagule rain) rates.
    m_i = immigration_rate * |r_i| * immigration_scale

    Biology: even absent taxa arrive at low frequency from the environment.
    Every species with r_i != 0 therefore receives a strictly positive influx
    regardless of its current abundance, which removes the absorbing state at zero.
    """
    rate  = float(getattr(config, "immigration_rate", 1e-4))
    scale = float(getattr(config, "immigration_scale", 1.0))
    m = rate * scale * np.abs(np.asarray(r_base, dtype=float))
    return m


def _apply_carrying_cap(dxdt, x, config):
    """
    Global carrying capacity: when total biomass approaches K,
    apply logistic suppression to all species proportionally.
    """
    if not getattr(config, "enable_carrying_cap", False):
        return dxdt
    K = float(getattr(config, "carrying_capacity", 2.0))
    total = float(np.sum(np.maximum(x, 0.)))
    if total > 0 and K > 0:
        suppression = max(0.0, 1.0 - total / K)
        # Only suppress positive growth; allow decline
        dxdt = np.where(dxdt > 0, dxdt * suppression, dxdt)
    return dxdt


def _apply_allee(dxdt, x, config):
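    """
    Allee effect: positive growth is scaled by x / (x + threshold), so species
    well below the threshold grow much more slowly. Models cooperative
    behaviours (quorum sensing, biofilm formation).

    Note: the factor scales the whole positive dxdt, including the
    immigration term, so a species at exactly zero receives no influx
    while Allee effects are enabled.
    """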
    if not getattr(config, "enable_allee", False):
        return dxdt
    threshold = float(getattr(config, "allee_threshold", 0.005))
    # Smooth Allee: multiply growth by x/(x + threshold)
    allee_factor = x / (x + threshold + 1e-15)
    # Only modulate positive growth rates
    dxdt = np.where(dxdt > 0, dxdt * allee_factor, dxdt)
    return dxdt


def simulate_gLV_dataframe(df_init, r, A, timepoints,
                            species_prefix="sp", atol=1e-8, rtol=1e-6,
                            enforce_nonnegative=True,
                            antibiotic_fn=None, bloom_fn=None, m=None,
                            config=None):
    """
    Simulate gLV for every community in df_init.
    dx_i/dt = x_i*(r_i + Σ A_ij x_j) + m_i  [+modifiers]

    m : per-species immigration (propagule rain) vector.
    """
    r = np.asarray(r, dtype=float)
    A = np.asarray(A, dtype=float)
    n = r.shape[0]
    sp = [f"{species_prefix}{i+1}" for i in range(n)]
    t0, tf = float(timepoints[0]), float(timepoints[-1])
    m_ = np.zeros(n) if m is None else np.asarray(m, dtype=float)

    def rhs(t, x):
        xp = np.maximum(np.nan_to_num(x), 0.)
        re = r.copy()
        if bloom_fn:
            re = re + bloom_fn(t)
        dxdt = xp * (re + A @ xp) + m_
        if antibiotic_fn:
            dxdt -= antibiotic_fn(t) * xp
        if config is not None:
            dxdt = _apply_carrying_cap(dxdt, xp, config)
            dxdt = _apply_allee(dxdt, xp, config)
        return dxdt

    out = []
    for comm in df_init["Comm_name"].unique():
        df_c = df_init[df_init["Comm_name"] == comm]
        init = df_c.sort_values("Time").iloc[0]
        x0 = np.maximum(init[sp].to_numpy(dtype=float), 0.)
        sol = solve_ivp(rhs, (t0, tf), x0, t_eval=timepoints,
                        atol=atol, rtol=rtol, method="LSODA")
        if not sol.success:
            raise RuntimeError(f"ODE failed for {comm}: {sol.message}")
        X = np.maximum(sol.y.T, 0.) if enforce_nonnegative else sol.y.T
        for i, t in enumerate(timepoints):
            out.append([comm, float(t)] + X[i].tolist())

    return (pd.DataFrame(out, columns=["Comm_name", "Time"] + sp)
              .sort_values(["Comm_name", "Time"]).reset_index(drop=True))


def _glv_integrate(r, A, x0, timepoints, atol=1e-8, rtol=1e-6,
                   m=None, config=None):
    """Integrate single community (no obs model). Returns X array (T, S)."""
    m_ = np.zeros(len(r)) if m is None else np.asarray(m, dtype=float)

    def rhs(t, x):
        x = np.maximum(np.nan_to_num(x), 0.)
        dxdt = x * (r + A @ x) + m_
        if config is not None:
            dxdt = _apply_carrying_cap(dxdt, x, config)
            dxdt = _apply_allee(dxdt, x, config)
        return dxdt

    sol = solve_ivp(rhs, (float(timepoints[0]), float(timepoints[-1])),
                    x0, t_eval=timepoints, atol=atol, rtol=rtol, method="LSODA")
    if not sol.success:
        raise RuntimeError(f"ODE solver failed: {sol.message}")
    return np.maximum(sol.y.T, 0.)


print("Core gLV simulator defined (with immigration, carrying cap, Allee).")
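The following one-species sketch shows how immigration removes the absorbing zero. It also illustrates an interaction worth knowing about: `_apply_allee` scales the whole positive `dxdt`, which already includes the immigration term. As a result, a species sitting at exactly zero receives no influx when Allee effects are enabled. Parameter values are illustrative assumptions.

```python
import numpy as np
from scipy.integrate import solve_ivp

r, a, m = 0.5, -1.0, 1e-4      # growth, self-limitation, immigration (illustrative)
allee_threshold = 0.005


def rhs(t, x, m_rate, allee):
    x = np.maximum(x, 0.0)
    dxdt = x * (r + a * x) + m_rate
    if allee:
        # Same form as _apply_allee: damp positive dxdt by x / (x + threshold)
        factor = x / (x + allee_threshold + 1e-15)
        dxdt = np.where(dxdt > 0, dxdt * factor, dxdt)
    return dxdt


t_span, x0 = (0.0, 60.0), [0.0]
runs = {
    "no_immigration": solve_ivp(rhs, t_span, x0, args=(0.0, False), rtol=1e-8, atol=1e-12),
    "immigration": solve_ivp(rhs, t_span, x0, args=(m, False), rtol=1e-8, atol=1e-12),
    "immigration+allee": solve_ivp(rhs, t_span, x0, args=(m, True), rtol=1e-8, atol=1e-12),
}
final = {name: float(sol.y[0, -1]) for name, sol in runs.items()}

# Positive root of a*x**2 + r*x + m = 0 (the immigration steady state)
x_eq = (r + np.sqrt(r**2 - 4 * a * m)) / (-2 * a)
print(final, x_eq)
```

With immigration alone, the species climbs from zero to the positive root of a x² + r x + m = 0 (about 0.5002 here). With the Allee factor added, it stays at zero, exactly as in the run without immigration.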

## 6. Scenario Registry & Dispatcher

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §6  SCENARIO REGISTRY & DISPATCHER
# ══════════════════════════════════════════════════════════════════════

SCENARIO_REGISTRY = {}

def register_scenario(name):
    def wrapper(func):
        SCENARIO_REGISTRY[name] = func
        return func
    return wrapper


def simulate_community(config, df_init, timepoints, comm_meta=None, **kwargs):
    """Master entry: routes to scenario, applies sampling + observation model."""
    sc = config.scenario
    if sc not in SCENARIO_REGISTRY:
        raise ValueError(
            f"Scenario '{sc}' not registered. "
            f"Available: {sorted(SCENARIO_REGISTRY)}"
        )
    df_latent = SCENARIO_REGISTRY[sc](
        config, df_init=df_init, timepoints=timepoints,
        comm_meta=comm_meta, **kwargs
    )
    df_latent = apply_sampling_design(df_latent, config)
    return apply_observation_model(df_latent, config)


def list_scenarios():
    print("Registered scenarios:", sorted(SCENARIO_REGISTRY))


print("Dispatcher ready.")
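New scenarios plug in by decorating a function that accepts `(config, df_init, timepoints, comm_meta=None, **kwargs)` and returns a latent DataFrame. Below is the same registry-and-dispatch pattern in stripped-down form. It uses a hypothetical `toy_constant` scenario and a plain dict standing in for the config.

```python
from typing import Callable, Dict

REGISTRY: Dict[str, Callable] = {}


def register(name: str):
    """Decorator: store the function under `name` and return it unchanged."""
    def wrapper(func: Callable) -> Callable:
        REGISTRY[name] = func
        return func
    return wrapper


@register("toy_constant")
def toy_constant(config, **kwargs):
    # Hypothetical scenario: returns a trivial payload instead of a DataFrame
    return {"scenario": config["scenario"], "value": 1.0}


def dispatch(config, **kwargs):
    name = config["scenario"]
    if name not in REGISTRY:
        raise ValueError(f"Scenario '{name}' not registered. Available: {sorted(REGISTRY)}")
    return REGISTRY[name](config, **kwargs)


print(dispatch({"scenario": "toy_constant"}))
```

Because registration simply assigns into a dict, re-running a notebook cell overwrites the earlier entry instead of raising, which keeps iterative editing painless.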

## 6b. Shared Helpers

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §6b  SHARED HELPERS
# ══════════════════════════════════════════════════════════════════════

# Store ground-truth parameters for export
_GROUND_TRUTH_PARAMS = {}

def _build_regime_pair(config, rng=None):
    """
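    Returns (r1, A1, r2, A2).
    Regime 2 is a linear blend controlled by alpha = regime_distance
    (0 = identical to regime 1, 1 = an independently drawn stable regime).
    For 0 < alpha < 1 the blended (r2, A2) is not guaranteed to have a
    feasible or stable interior equilibrium.
    Also stores parameters in _GROUND_TRUTH_PARAMS for export.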
    """
    rng = np.random.default_rng(config.seed) if rng is None else rng
    n   = config.n_species
    kw1 = getattr(config, "A_structure_kwargs",
                   dict(p=0.15, offdiag_std=0.03, diag_range=(-0.6, -0.2)))
    kw2 = getattr(config, "A_structure_kwargs2", None) or kw1
    st  = getattr(config, "A_structure", "sparse")
    sm  = getattr(config, "stability_margin", 0.03)

    r1, A1, _ = construct_stable_regime(
        n, structure=st, stability_margin=sm, rng=rng,
        structure_kwargs=kw1
    )
    r2r, A2r, _ = construct_stable_regime(
        n, structure=st, stability_margin=sm,
        rng=np.random.default_rng(config.seed + 10_000),
        structure_kwargs=kw2
    )
    alpha = float(getattr(config, "regime_distance", 1.0))
    r2 = (1 - alpha) * r1 + alpha * r2r
    A2 = (1 - alpha) * A1 + alpha * A2r

    # Store for export (FIX #12: ground-truth regime parameters)
    _GROUND_TRUTH_PARAMS.update({
        "r1": r1.tolist(), "A1": A1.tolist(),
        "r2": r2.tolist(), "A2": A2.tolist(),
    })

    return r1, A1, r2, A2


def _build_antibiotic_fn(config, n_species):
    if not getattr(config, "enable_antibiotic", False):
        return None
    t0 = float(getattr(config, "antibiotic_start", 8.))
    dur = float(getattr(config, "antibiotic_duration", 3.))
    k = getattr(config, "antibiotic_kill_rates", None)
    k = np.full(n_species, 0.5) if k is None else np.asarray(k, float)
    return lambda t: k if (t0 <= t <= t0 + dur) else np.zeros(n_species)


def _build_bloom_fn(config, n_species):
    if not getattr(config, "enable_bloom", False):
        return None
    t0 = float(getattr(config, "bloom_start", 5.))
    dur = float(getattr(config, "bloom_duration", 2.))
    sp = int(getattr(config, "bloom_species", 0))
    boost = float(getattr(config, "bloom_boost", 2.))
    delta = np.zeros(n_species)
    delta[sp] = boost
    return lambda t: delta if (t0 <= t <= t0 + dur) else np.zeros(n_species)


def _get_community_trigger_params(config, comm_name):
    """
    FIX #10: Per-community heterogeneity for hidden/cumulative triggers.
    Jitters theta and k_u per community using comm_name as seed source.
    """
    theta_base = float(getattr(config, "theta", 0.5))
    k_u_base = float(getattr(config, "k_u", 0.4))
    sig_theta = float(getattr(config, "trigger_sigma_theta", 0.0))
    sig_k_u = float(getattr(config, "trigger_sigma_k_u", 0.0))

    if sig_theta == 0.0 and sig_k_u == 0.0:
        return theta_base, k_u_base

    rng = make_rng(f"trigger_{comm_name}", config.seed)
    theta_c = float(np.clip(rng.normal(theta_base, sig_theta), 0.05, 0.95))
    k_u_c = float(np.clip(rng.normal(k_u_base, sig_k_u), 0.05, 2.0))
    return theta_c, k_u_c


print("Shared helpers defined.")
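The antibiotic and bloom helpers return piecewise-constant forcing, so the ODE right-hand side jumps at the window edges. Adaptive solvers such as LSODA choose their own step sizes and can, in principle, step across a short window without ever evaluating inside it.

One common remedy is to integrate segment by segment, with breakpoints at the pulse edges. The sketch below applies this to a single logistic species with illustrative parameters. Passing `max_step` to `solve_ivp` is a cruder alternative.

```python
import numpy as np
from scipy.integrate import solve_ivp

r, a = 0.5, -1.0    # illustrative logistic parameters; equilibrium x = 0.5


def rhs(t, x, kill):
    x = np.maximum(x, 0.0)
    return x * (r + a * x) - kill * x


# (start, end, kill rate): a short, strong pulse on [8.0, 8.05]
segments = [(0.0, 8.0, 0.0), (8.0, 8.05, 20.0), (8.05, 10.0, 0.0)]


def integrate_piecewise(x0, segments):
    """Integrate one segment at a time so no solver step crosses a forcing edge."""
    x = np.asarray(x0, dtype=float)
    ends = []
    for t_a, t_b, kill in segments:
        sol = solve_ivp(rhs, (t_a, t_b), x, args=(kill,), rtol=1e-8, atol=1e-10)
        if not sol.success:
            raise RuntimeError(sol.message)
        x = sol.y[:, -1]
        ends.append(float(x[0]))
    return ends


x_before, x_after_pulse, x_end = integrate_piecewise([0.5], segments)
print(x_before, x_after_pulse, x_end)
```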

## 7. Scenarios — Baseline gLV

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §7  SCENARIOS — BASELINE gLV
# ══════════════════════════════════════════════════════════════════════

@register_scenario("baseline_glv")
def run_baseline_glv(config, df_init, timepoints, comm_meta=None, **kwargs):
    """Plain gLV with caller-supplied parameters; r and A must be passed via kwargs."""
    return simulate_gLV_dataframe(df_init=df_init, timepoints=timepoints,
                                   config=config, **kwargs)


@register_scenario("baseline_glv_stable")
def run_baseline_glv_stable(config, df_init, timepoints, comm_meta=None, **kwargs):
    rng = np.random.default_rng(config.seed)
    n = config.n_species
    r_base, A_base, x_star = construct_stable_regime(
        n, structure=getattr(config, "A_structure", "sparse"),
        stability_margin=getattr(config, "stability_margin", 0.03), rng=rng,
        structure_kwargs=getattr(config, "A_structure_kwargs",
                                  dict(p=0.15, offdiag_std=0.03,
                                       diag_range=(-0.6, -0.2))))

    _GROUND_TRUTH_PARAMS.update({
        "r1": r_base.tolist(), "A1": A_base.tolist(),
        "x_star": x_star.tolist(),
    })

    hier = getattr(config, "enable_hierarchical", True)
    sig_r = getattr(config, "sigma_r", 0.15)
    sig_A = getattr(config, "sigma_A", 0.05)
    ab_fn = _build_antibiotic_fn(config, n)
    bl_fn = _build_bloom_fn(config, n)

    out = []
    for comm in sorted(df_init["Comm_name"].unique()):
        df_c = df_init[df_init["Comm_name"] == comm].copy()
        pm = None
        if comm_meta and comm in comm_meta:
            pm = np.zeros(n, dtype=bool)
            pm[comm_meta[comm]] = True
        r_c, A_c = (
            sample_hierarchical_community_params(
                r_base, A_base, sig_r, sig_A, rng, pm)
            if hier else (r_base.copy(), A_base.copy())
        )
        m = _build_immigration_vec(config, r_c)
        out.append(simulate_gLV_dataframe(
            df_c, r_c, A_c, timepoints,
            antibiotic_fn=ab_fn, bloom_fn=bl_fn, m=m, config=config
        ))
    return pd.concat(out, ignore_index=True)


print("Baseline scenarios registered.")
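`sample_hierarchical_community_params` perturbs r and A independently. Each community's interior equilibrium x_c = -A_c⁻¹ r_c therefore moves away from the shared x* and is not guaranteed to stay feasible (all entries positive). The quick diagnostic below assumes a diagonal base matrix for simplicity and uses the default `sigma_r` and `sigma_A`.

```python
import numpy as np

rng = np.random.default_rng(3)
n = 5
x_star = rng.lognormal(-3.5, 0.7, n)
A_base = np.diag(rng.uniform(-0.6, -0.2, n))   # simplifying assumption: diagonal
r_base = -(A_base @ x_star)


def jitter(r_base, A_base, sigma_r, sigma_A, rng):
    # Mirrors sample_hierarchical_community_params without a presence mask
    n = r_base.shape[0]
    r_c = r_base + rng.normal(0, sigma_r, n)
    A_c = A_base + rng.normal(0, sigma_A, (n, n))
    A_c[np.diag_indices(n)] = -np.abs(np.diag(A_c)) - 1e-6
    return r_c, A_c


n_draws, feasible = 200, 0
for _ in range(n_draws):
    r_c, A_c = jitter(r_base, A_base, 0.15, 0.05, rng)
    x_c = np.linalg.solve(A_c, -r_c)            # solves r_c + A_c x = 0
    feasible += bool(np.all(x_c > 0))

frac_feasible = feasible / n_draws
print("typical |r_base|:", np.abs(r_base).mean(), "feasible fraction:", frac_feasible)
```

With x* drawn from `lognormal(-3.5, 0.7)` and diagonal entries around -0.4, |r_base| is of order 10⁻². That is well below the default `sigma_r = 0.15`, so the jittered growth rates are dominated by noise. Comparing the two scales is a useful sanity check when tuning `sigma_r`.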

## 8. Scenarios — Regime Switching

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §8  SCENARIOS — REGIME SWITCHING
# ══════════════════════════════════════════════════════════════════════

In [None]:
# ─── Soft-switch utilities ───────────────────────────────────────────

def _soft_weight(t, t_switch, epsilon):
    return 1. / (1. + np.exp(-(t - t_switch) / epsilon))


def _soft_rhs(t, x, r1, A1, r2, A2, t_switch, epsilon, m, config=None):
    x = np.maximum(x, 0.)
    w = _soft_weight(t, t_switch, epsilon)
    dxdt = (1 - w) * x * (r1 + A1 @ x) + w * x * (r2 + A2 @ x) + m
    if config is not None:
        dxdt = _apply_carrying_cap(dxdt, x, config)
        dxdt = _apply_allee(dxdt, x, config)
    return dxdt
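In `_soft_weight`, `epsilon` sets the transition width. The weight passes 0.5 exactly at `t_switch`. Moving from q to 1 - q takes 2ε·ln((1 - q)/q) time units, which is about 5.9ε for the 5%-95% band. A short check:

```python
import numpy as np


def soft_weight(t, t_switch, epsilon):
    # Same logistic form as _soft_weight above
    return 1.0 / (1.0 + np.exp(-(t - t_switch) / epsilon))


t_switch, epsilon = 10.0, 1.0
offsets = np.array([-2.0, 0.0, 2.0]) * epsilon
w = soft_weight(t_switch + offsets, t_switch, epsilon)   # ~[0.119, 0.5, 0.881]

# Time for w to rise from q to 1 - q: 2 * epsilon * ln((1 - q) / q)
q = 0.05
width_5_95 = 2 * epsilon * np.log((1 - q) / q)
print(w, width_5_95)
```

With the default `epsilon = 1.0`, the 5%-95% transition spans roughly 6 time units of the 20-unit horizon.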

In [None]:
# ─── TIME_SWITCH ─────────────────────────────────────────────────────

@register_scenario("time_switch")
def run_time_switch(config, df_init, timepoints, comm_meta=None, **kwargs):
    """Hard switch at t_switch. Supports per-community jitter."""
    rng = np.random.default_rng(config.seed)
    r1, A1, r2, A2 = _build_regime_pair(config, rng)
    t0, tf = float(timepoints[0]), float(timepoints[-1])
    t_sw_global = float(getattr(config, "t_switch", 10.))
    per_comm = bool(getattr(config, "per_comm_switch", False))
    spread = float(getattr(config, "switch_spread", 3.))
    m = _build_immigration_vec(config, r1)
    out = []

    for comm in sorted(df_init["Comm_name"].unique()):
        df_c = df_init[df_init["Comm_name"] == comm].copy()
        if per_comm:
            t_sw = float(np.clip(
                make_rng(f"sw_{comm}", config.seed).normal(t_sw_global, spread),
                t0 + 1e-3, tf - 1e-3
            ))
        else:
            t_sw = t_sw_global

        tp1 = timepoints[timepoints <= t_sw]
        if len(tp1) == 0 or tp1[-1] != t_sw:
            tp1 = np.sort(np.unique(np.append(tp1, t_sw)))
        tp2 = timepoints[timepoints >= t_sw]
        if len(tp2) == 0 or tp2[0] != t_sw:
            tp2 = np.sort(np.unique(np.append(t_sw, tp2)))

        seg1 = simulate_gLV_dataframe(df_c, r1, A1, tp1, m=m, config=config)
        sp = _get_sp_cols(seg1)
        init2 = seg1[np.isclose(seg1["Time"].astype(float), t_sw)][
            ["Comm_name", "Time"] + sp
        ].copy()
        init2["Time"] = t_sw
        seg2 = simulate_gLV_dataframe(init2, r2, A2, tp2, m=m, config=config)

        df_out = pd.concat([
            seg1,
            seg2[~np.isclose(seg2["Time"].astype(float), t_sw)]
        ], ignore_index=True)
        df_out["regime"] = (df_out["Time"] > t_sw).astype(int)
        df_out["t_switch_true"] = t_sw
        out.append(df_out)

    return pd.concat(out, ignore_index=True)
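`run_time_switch` integrates the two regimes on separate time grids that share the switch time, so the state at `t_sw` from regime 1 seeds regime 2. The grid split can be written compactly with `np.union1d`, which sorts and de-duplicates. `split_grid` below is a hypothetical helper, equivalent in intent to the `tp1`/`tp2` construction above.

```python
import numpy as np


def split_grid(timepoints, t_sw):
    """Hypothetical helper: (tp1, tp2), both containing t_sw as an endpoint."""
    timepoints = np.asarray(timepoints, dtype=float)
    tp1 = np.union1d(timepoints[timepoints <= t_sw], [t_sw])
    tp2 = np.union1d(timepoints[timepoints >= t_sw], [t_sw])
    return tp1, tp2


grid = np.linspace(0.0, 20.0, 11)           # 0, 2, ..., 20
tp1_off, tp2_off = split_grid(grid, 5.0)    # switch between grid points
tp1_on, tp2_on = split_grid(grid, 6.0)      # switch on a grid point
print(tp1_off, tp2_off)
```

When `t_sw` falls between grid points, the concatenated scenario output contains one extra row at `t_sw` that is not in the requested `timepoints`. The per-community row count can therefore differ by one from the nominal grid; filter on the original grid if downstream code expects equal lengths.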

In [None]:
# ─── SOFT_SWITCH ─────────────────────────────────────────────────────

@register_scenario("soft_switch")
def run_soft_switch(config, df_init, timepoints, comm_meta=None, **kwargs):
    """Logistic interpolation between two regimes."""
    rng = np.random.default_rng(config.seed)
    r1, A1, r2, A2 = _build_regime_pair(config, rng)
    t_sw = float(getattr(config, "t_switch", 10.))
    eps = float(getattr(config, "epsilon", 1.))
    m = _build_immigration_vec(config, r1)
    sp = _get_sp_cols(df_init)
    out = []

    for comm, sub in df_init.groupby("Comm_name"):
        x0 = np.maximum(sub.sort_values("Time")[sp].iloc[0].values.astype(float), 0.)
        sol = solve_ivp(
            lambda t, x: _soft_rhs(t, x, r1, A1, r2, A2, t_sw, eps, m, config),
            (timepoints[0], timepoints[-1]), x0,
            t_eval=timepoints, rtol=1e-6, atol=1e-9, method="LSODA"
        )
        if not sol.success:
            raise RuntimeError(f"soft_switch failed for {comm}: {sol.message}")
        df_c = pd.DataFrame(np.maximum(sol.y.T, 0.), columns=sp)
        df_c["Time"] = sol.t
        df_c["Comm_name"] = comm
        df_c["w"] = _soft_weight(sol.t, t_sw, eps)
        df_c["regime"] = (df_c["w"] > 0.5).astype(int)
        out.append(df_c)

    return pd.concat(out, ignore_index=True)


print("time_switch and soft_switch registered.")
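In the hidden-trigger scenario, u follows du/dt = k_u (1 - u). Starting from u(t0) = u0, it has the closed form u(t) = 1 - (1 - u0) e^{-k_u (t - t0)}, so the switch fires at t* = t0 + ln((1 - u0)/(1 - θ)) / k_u (for u0 < θ < 1). The code below checks that formula against a `solve_ivp` terminal event, using the `GeneratorConfig` defaults.

```python
import numpy as np
from scipy.integrate import solve_ivp

k_u, u0, theta = 0.4, 0.0, 0.5    # GeneratorConfig defaults


def crossing_time(k_u, u0, theta, t0=0.0):
    # Solve 1 - (1 - u0) * exp(-k_u * (t - t0)) = theta for t
    return t0 + np.log((1.0 - u0) / (1.0 - theta)) / k_u


def u_rhs(t, y):
    return [k_u * (1.0 - y[0])]


def hit_theta(t, y):
    return y[0] - theta


hit_theta.terminal = True     # stop integration at the first crossing
hit_theta.direction = 1       # only upward crossings count

sol = solve_ivp(u_rhs, (0.0, 20.0), [u0], events=hit_theta, rtol=1e-10, atol=1e-12)
t_event = float(sol.t_events[0][0])
t_analytic = crossing_time(k_u, u0, theta)
print(t_event, t_analytic)
```

With the defaults, the switch fires at about t = 1.73, early in the 20-unit horizon. Per-community jitter clips θ to [0.05, 0.95] and k_u to [0.05, 2.0]. With u0 = 0, combinations near θ = 0.95 and small k_u give t* beyond `t_end` (for example ln(20)/0.05 ≈ 60), in which case the community stays in the first regime (`regime = 0`).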

In [None]:
# ─── HIDDEN TRIGGER ──────────────────────────────────────────────────

def _glv_u_rhs(t, y, r, A, k_u, m, config=None):
    n = len(r)
    x = np.maximum(y[:n], 0.)
    u = y[n]
    dxdt = x * (r + A @ x) + m
    if config is not None:
        dxdt = _apply_carrying_cap(dxdt, x, config)
        dxdt = _apply_allee(dxdt, x, config)
    dudt = k_u * (1. - u)
    return np.concatenate([dxdt, [dudt]])


def _event_u(theta):
    e = lambda t, y: y[-1] - theta
    e.terminal = True
    e.direction = 1
    return e


def _w_from_u(u, theta, epsilon):
    return 1. / (1. + np.exp(-(u - theta) / epsilon))


@register_scenario("hidden_trigger")
def run_hidden_trigger(config, df_init, timepoints, comm_meta=None, **kwargs):
    rng = np.random.default_rng(config.seed)
    r1, A1, r2, A2 = _build_regime_pair(config, rng)
    eps = float(getattr(config, "epsilon", 0.1))
    u0_ = float(getattr(config, "u0", 0.))
    m = _build_immigration_vec(config, r1)
    sp = _get_sp_cols(df_init)
    t0, tf = float(timepoints[0]), float(timepoints[-1])
    out = []

    for comm, sub in df_init.groupby("Comm_name"):
        # FIX #10: per-community trigger heterogeneity
        theta_c, k_u_c = _get_community_trigger_params(config, comm)

        x0 = np.maximum(sub[sp].iloc[0].values.astype(float), 0.)
        y0 = np.concatenate([x0, [u0_]])

        sol1 = solve_ivp(
            lambda t, y: _glv_u_rhs(t, y, r1, A1, k_u_c, m, config),
            (t0, tf), y0,
            events=_event_u(theta_c), dense_output=True,
            rtol=1e-6, atol=1e-9, method="LSODA"
        )
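        if not sol1.success:
            raise RuntimeError(
                f"hidden_trigger (pre-switch) failed for {comm}: {sol1.message}")
        t_sw = (float(sol1.t_events[0][0])
                if sol1.t_events[0].size > 0 else tf)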

        tp1 = timepoints[timepoints <= t_sw]
        if not len(tp1) or tp1[-1] != t_sw:
            tp1 = np.sort(np.unique(np.append(tp1, t_sw)))
        Y1 = sol1.sol(tp1).T
        n_sp = len(sp)
        df1 = pd.DataFrame(np.maximum(Y1[:, :n_sp], 0.), columns=sp)
        df1["u"] = Y1[:, n_sp]
        df1["w"] = _w_from_u(Y1[:, n_sp], theta_c, eps)
        df1["regime"] = 0
        df1["Time"] = tp1
        df1["Comm_name"] = comm

        if t_sw >= tf:
            # No switch before the final time point: record it explicitly
            df1["t_switch_true"] = np.nan
            out.append(df1[df1["Time"] <= tf])
            continue

        y_sw = sol1.sol(t_sw).reshape(-1)
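        sol2 = solve_ivp(
            lambda t, y: _glv_u_rhs(t, y, r2, A2, k_u_c, m, config),
            (t_sw, tf), y_sw,
            dense_output=True, rtol=1e-6, atol=1e-9, method="LSODA"
        )
        if not sol2.success:
            raise RuntimeError(
                f"hidden_trigger (post-switch) failed for {comm}: {sol2.message}")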
        tp2 = timepoints[timepoints >= t_sw]
        if not len(tp2) or tp2[0] != t_sw:
            tp2 = np.sort(np.unique(np.append(t_sw, tp2)))
        Y2 = sol2.sol(tp2).T
        df2 = pd.DataFrame(np.maximum(Y2[:, :n_sp], 0.), columns=sp)
        df2["u"] = Y2[:, n_sp]
        df2["w"] = _w_from_u(Y2[:, n_sp], theta_c, eps)
        df2["regime"] = 1
        df2["Time"] = tp2
        df2["Comm_name"] = comm
        df2 = df2[df2["Time"] != t_sw].copy()

        df_out = pd.concat([df1, df2], ignore_index=True)
        df_out["t_switch_true"] = t_sw
        out.append(df_out)

    return pd.concat(out, ignore_index=True)


print("hidden_trigger registered.")
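
The hidden-trigger scenario relies on `solve_ivp`'s terminal-event mechanism: integration stops the first time the event function crosses zero in the requested direction, and that crossing time becomes `t_switch_true`. The sketch below isolates the mechanism on the trigger variable alone (du/dt = k_u (1 - u), u(0) = 0), where the crossing time has the closed form t* = -ln(1 - theta) / k_u, so the event location can be checked. The values of `theta` and `k_u` are arbitrary toy choices.

```python
import numpy as np
from scipy.integrate import solve_ivp

theta, k_u = 0.6, 0.25   # toy threshold and relaxation rate


def trigger_rhs(t, y):
    # Hidden trigger relaxes toward 1: du/dt = k_u * (1 - u)
    return [k_u * (1.0 - y[0])]


def hit_threshold(t, y):
    # Event function: crosses zero when u reaches theta
    return y[0] - theta


hit_threshold.terminal = True   # stop at the first crossing
hit_threshold.direction = 1     # count upward crossings only

sol = solve_ivp(trigger_rhs, (0.0, 30.0), [0.0], events=hit_threshold,
                rtol=1e-8, atol=1e-10)
t_event = sol.t_events[0][0] if sol.t_events[0].size > 0 else np.inf
t_analytic = -np.log(1.0 - theta) / k_u   # closed form for u(0) = 0
print(f"event t = {t_event:.6f}, analytic t = {t_analytic:.6f}")
```

If `theta >= 1`, u only approaches 1 asymptotically, `t_events[0]` is empty, and `run_hidden_trigger` falls back to t_switch = tf, i.e. no switch.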

In [None]:
# ─── CUMULATIVE TRIGGER ──────────────────────────────────────────────

def _glv_integral_rhs(t, y, r, A, idx_M, m, config=None):
    n = len(r)
    x = np.maximum(y[:n], 0.)
    dxdt = x * (r + A @ x) + m
    if config is not None:
        dxdt = _apply_carrying_cap(dxdt, x, config)
        dxdt = _apply_allee(dxdt, x, config)
    da = float(np.sum(x[idx_M]))
    return np.concatenate([dxdt, [da]])


def _event_integral(c1, M_init):
    e = lambda t, y: c1 * y[-1] - M_init
    e.terminal = True
    e.direction = 1
    return e


def _w_from_integral(a, c1, M_init, epsilon):
    return 1. / (1. + np.exp(-(c1 * a - M_init) / epsilon))


@register_scenario("cumulative_trigger")
def run_cumulative_trigger(config, df_init, timepoints, comm_meta=None, **kwargs):
    rng = np.random.default_rng(config.seed)
    r1, A1, r2, A2 = _build_regime_pair(config, rng)

    c1_val = float(getattr(config, "c1", 1.))
    M_init = float(getattr(config, "M_init", 0.05))
    eps = float(getattr(config, "epsilon", 0.01))
    idx_M = getattr(config, "idx_M", None)
    idx_M = (np.arange(config.n_species, dtype=int)
             if idx_M is None else np.array(list(idx_M), dtype=int))
    a0_ = float(getattr(config, "a0", 0.))
    m = _build_immigration_vec(config, r1)
    sp = _get_sp_cols(df_init)
    t0, tf = float(timepoints[0]), float(timepoints[-1])
    ev = _event_integral(c1_val, M_init)
    out = []

    for comm, sub in df_init.groupby("Comm_name"):
        x0 = np.maximum(sub[sp].iloc[0].values.astype(float), 0.)
        y0 = np.concatenate([x0, [a0_]])

        sol1 = solve_ivp(
            lambda t, y: _glv_integral_rhs(t, y, r1, A1, idx_M, m, config),
            (t0, tf), y0, events=ev, dense_output=True,
            rtol=1e-6, atol=1e-9, method="LSODA"
        )
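        if not sol1.success:
            raise RuntimeError(
                f"cumulative_trigger (pre-switch) failed for {comm}: {sol1.message}")
        t_sw = (float(sol1.t_events[0][0])
                if sol1.t_events[0].size > 0 else tf)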

        tp1 = timepoints[timepoints <= t_sw]
        if not len(tp1) or tp1[-1] != t_sw:
            tp1 = np.sort(np.unique(np.append(tp1, t_sw)))
        Y1 = sol1.sol(tp1).T
        n_sp = len(sp)
        df1 = pd.DataFrame(np.maximum(Y1[:, :n_sp], 0.), columns=sp)
        df1["A_M"] = Y1[:, -1]
        df1["w"] = _w_from_integral(Y1[:, -1], c1_val, M_init, eps)
        df1["regime"] = 0
        df1["Time"] = tp1
        df1["Comm_name"] = comm

        if t_sw >= tf:
            # No switch before the final time point: record it explicitly
            df1["t_switch_true"] = np.nan
            out.append(df1[df1["Time"] <= tf])
            continue

        y_sw = sol1.sol(t_sw).reshape(-1)
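        sol2 = solve_ivp(
            lambda t, y: _glv_integral_rhs(t, y, r2, A2, idx_M, m, config),
            (t_sw, tf), y_sw, dense_output=True,
            rtol=1e-6, atol=1e-9, method="LSODA"
        )
        if not sol2.success:
            raise RuntimeError(
                f"cumulative_trigger (post-switch) failed for {comm}: {sol2.message}")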
        tp2 = timepoints[timepoints >= t_sw]
        if not len(tp2) or tp2[0] != t_sw:
            tp2 = np.sort(np.unique(np.append(t_sw, tp2)))
        Y2 = sol2.sol(tp2).T
        df2 = pd.DataFrame(np.maximum(Y2[:, :n_sp], 0.), columns=sp)
        df2["A_M"] = Y2[:, -1]
        df2["w"] = _w_from_integral(Y2[:, -1], c1_val, M_init, eps)
        df2["regime"] = 1
        df2["Time"] = tp2
        df2["Comm_name"] = comm
        df2 = df2[df2["Time"] != t_sw].copy()

        df_out = pd.concat([df1, df2], ignore_index=True)
        df_out["t_switch_true"] = t_sw
        out.append(df_out)

    return pd.concat(out, ignore_index=True)


print("cumulative_trigger registered.")
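
The cumulative trigger appends an extra state a(t), the time integral of the summed abundance of the species in `idx_M`, so that the event c1 * a - M_init = 0 fires once enough biomass-time has accumulated. A minimal check on a single exponentially growing species (dx/dt = r x, an assumption chosen so the integral has a closed form): a(t) = x0 (e^{rt} - 1) / r, so the event should fire at t* = ln(1 + r M_init / (c1 x0)) / r. All values are toy choices.

```python
import numpy as np
from scipy.integrate import solve_ivp

r, x_init, c1, M_init = 0.5, 0.01, 1.0, 0.05   # toy values


def augmented_rhs(t, y):
    x, a = y
    # x grows exponentially; a accumulates the time integral of x
    return [r * x, x]


def crossing(t, y):
    # Switch when c1 * a(t) reaches M_init
    return c1 * y[1] - M_init


crossing.terminal = True
crossing.direction = 1

sol = solve_ivp(augmented_rhs, (0.0, 50.0), [x_init, 0.0], events=crossing,
                rtol=1e-8, atol=1e-12)
t_event = sol.t_events[0][0]
a_at_event = sol.y_events[0][0][1]
t_analytic = np.log(1.0 + r * M_init / (c1 * x_init)) / r   # = ln(3.5) / 0.5
print(f"event t = {t_event:.6f}, analytic t = {t_analytic:.6f}")
```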

## 9. Scenarios — Resource / Cross-Feeding

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §9  SCENARIOS — RESOURCE / CROSS-FEEDING DYNAMICS
# ══════════════════════════════════════════════════════════════════════

def build_inflow_fn(config, base_inflow):
    base_inflow = np.asarray(base_inflow, dtype=float)
    if not getattr(config, "enable_forcing", False):
        return lambda t: base_inflow.copy()
    shocks = getattr(config, "diet_shocks", None) or []
    parsed = [(float(s["start"]), float(s["end"]),
               np.asarray(s["delta"], float)) for s in shocks]
    def fn(t):
        v = base_inflow.copy()
        for s, e, d in parsed:
            if s <= t <= e:
                v = v + d
        return v
    return fn


def simulate_resource_coupled_dataframe(
    df_init, r, A, timepoints, R0, B, consume, inflow,
    delta_R, K_half, config,
    produce=None, clip_nonneg=True, solver_method="LSODA",
):
    """
    Species + resource ODE:
      dx_i/dt = x_i*(r_i + (A x)_i + (B R)_i) + m_i
                (then passed through the carrying-capacity and Allee terms)
      dR_k/dt = inflow_k(t) - delta_R_k*R_k
                - sum_i consume_{ik}*x_i*R_k/(K_half_k + R_k)
                [+ sum_i produce_{ik}*x_i]   (cross-feeding only)
    """
    sp = _get_sp_cols(df_init)
    x0 = df_init.loc[
        df_init["Time"] == df_init["Time"].min(), sp
    ].iloc[0].values.astype(float)
    S, K = len(x0), len(R0)
    r = np.asarray(r, float).reshape(S)
    A = np.asarray(A, float).reshape(S, S)
    R0 = np.asarray(R0, float).reshape(K)
    B = np.asarray(B, float).reshape(S, K)
    consume = np.asarray(consume, float).reshape(S, K)
    delta_R = np.asarray(delta_R, float).reshape(K)
    K_half = np.asarray(K_half, float).reshape(K)
    prod_ = (np.zeros((S, K)) if produce is None
             else np.asarray(produce, float).reshape(S, K))
    inflow_t = build_inflow_fn(config, np.asarray(inflow, float).reshape(K))
    y0 = np.concatenate([x0, R0])
    m_res = _build_immigration_vec(config, r)

    def rhs(t, y):
        x = np.maximum(y[:S], 0.) if clip_nonneg else y[:S]
        R = np.maximum(y[S:], 0.) if clip_nonneg else y[S:]
        dx = x * (r + A @ x + B @ R) + m_res
        dx = _apply_carrying_cap(dx, x, config)
        dx = _apply_allee(dx, x, config)
        uptake = (x[:, None] * consume) * (R / (K_half + R + 1e-12))[None, :]
        dR = (inflow_t(t) - delta_R * R
              - uptake.sum(0) + (x[:, None] * prod_).sum(0))
        return np.concatenate([dx, dR])

    sol = solve_ivp(rhs, (timepoints.min(), timepoints.max()), y0,
                    t_eval=timepoints, method=solver_method,
                    rtol=1e-6, atol=1e-9)  # match the other scenarios' tolerances
    if not sol.success:
        raise RuntimeError(f"Resource ODE failed: {sol.message}")
    X, Rmat = sol.y.T[:, :S], sol.y.T[:, S:]
    df_out = pd.DataFrame({
        "Comm_name": df_init["Comm_name"].iloc[0], "Time": sol.t
    })
    for i in range(S):
        df_out[f"sp{i+1}"] = np.maximum(X[:, i], 0.)
    if getattr(config, "export_resources", False):
        for k in range(K):
            df_out[f"R{k+1}"] = np.maximum(Rmat[:, k], 0.)
    return df_out


def _run_resource_scenario(config, df_init, timepoints, crossfeeding):
    rng = np.random.default_rng(getattr(config, "seed", 42))
    n = config.n_species
    K = int(getattr(config, "n_metabolites", 1))
    r_base, A_base, _ = construct_stable_regime(
        n, structure=getattr(config, "A_structure", "sparse"),
        stability_margin=getattr(config, "stability_margin", 0.03), rng=rng,
        structure_kwargs=getattr(config, "A_structure_kwargs",
                                  dict(p=0.15, offdiag_std=0.03,
                                       diag_range=(-0.6, -0.2))))
    hier = getattr(config, "enable_hierarchical", True)
    sig_r = getattr(config, "sigma_r", 0.15)
    sig_A = getattr(config, "sigma_A", 0.05)
    R0 = np.asarray(getattr(config, "R0", np.ones(K)), float).reshape(K)
    B = np.asarray(getattr(config, "B", rng.normal(0, .2, (n, K))),
                   float).reshape(n, K)
    consume = np.asarray(
        getattr(config, "consume", np.abs(rng.normal(0, .2, (n, K)))),
        float).reshape(n, K)
    inflow = np.asarray(
        getattr(config, "inflow", np.ones(K) * .05), float).reshape(K)
    delta_R = np.asarray(
        getattr(config, "delta_R", np.ones(K) * .1), float).reshape(K)
    K_half = np.asarray(
        getattr(config, "K_half", np.ones(K) * .5), float).reshape(K)
    produce = None
    if crossfeeding:
        sc = float(getattr(config, "produce_scale", 0.05))
        produce = np.asarray(
            getattr(config, "produce",
                    np.abs(rng.normal(0, 1, (n, K))) * sc),
            float).reshape(n, K)
    solver = getattr(config, "resource_solver_method", "LSODA")
    clip_nn = bool(getattr(config, "clip_nonneg", True))

    out = []
    for comm in sorted(df_init["Comm_name"].unique()):
        df_c = df_init[df_init["Comm_name"] == comm].copy()
        r_c, A_c = (
            sample_hierarchical_community_params(
                r_base, A_base, sig_r, sig_A, rng)
            if hier else (r_base, A_base)
        )
        out.append(simulate_resource_coupled_dataframe(
            df_c, r_c, A_c, timepoints, R0, B, consume, inflow,
            delta_R, K_half, config,
            produce=produce, clip_nonneg=clip_nn, solver_method=solver
        ))
    return pd.concat(out, ignore_index=True)


@register_scenario("resource_coupled")
def run_resource_coupled(config, df_init, timepoints, comm_meta=None, **kwargs):
    return _run_resource_scenario(config, df_init, timepoints, crossfeeding=False)


@register_scenario("resource_crossfeeding")
def run_resource_crossfeeding(config, df_init, timepoints,
                               comm_meta=None, **kwargs):
    return _run_resource_scenario(config, df_init, timepoints, crossfeeding=True)


print("resource_coupled and resource_crossfeeding registered.")
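
To make the resource equation concrete, here is one hand-checkable evaluation of the Monod uptake and resource balance used in `simulate_resource_coupled_dataframe`, for two species and one resource; the numbers are arbitrary. Because R equals the half-saturation constant, each species takes up exactly half its maximal rate.

```python
import numpy as np

x = np.array([1.0, 2.0])               # species abundances, S = 2
R = np.array([0.5])                    # resource concentration, K = 1
consume = np.array([[0.2], [0.1]])     # maximal uptake per unit biomass, (S, K)
K_half = np.array([0.5])               # half-saturation constants
inflow = np.array([0.05])
delta_R = np.array([0.1])
produce = np.zeros((2, 1))             # no cross-feeding in this toy

# Monod factor R / (K_half + R) is 0.5 because R equals K_half
monod = R / (K_half + R + 1e-12)
uptake = (x[:, None] * consume) * monod[None, :]          # (S, K): 0.1 and 0.1
dR = (inflow - delta_R * R
      - uptake.sum(axis=0) + (x[:, None] * produce).sum(axis=0))
print("uptake per species:", uptake.ravel(), "dR/dt:", dR)   # 0.05 - 0.05 - 0.2
```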

## 10. Scenarios — Bistability, Dormancy, Antibiotic, Drift

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §10  SCENARIOS — BISTABILITY, DORMANCY, ANTIBIOTIC, DRIFT
# ══════════════════════════════════════════════════════════════════════

@register_scenario("bistable")
def run_bistable(config, df_init, timepoints, comm_meta=None, **kwargs):
    """Priority-effects bistability: two guilds with strong inter-guild
    competition, so whichever guild dominates initially tends to exclude
    the other."""
    rng = np.random.default_rng(config.seed)
    n = config.n_species
    half = n // 2
    grp1 = list(range(half))
    grp2 = list(range(half, n))
    comp = float(getattr(config, "bistable_comp_strength", 0.25))
    fac = float(getattr(config, "bistable_fac_strength", 0.02))

    A = np.zeros((n, n))
    A[np.diag_indices(n)] = rng.uniform(-0.6, -0.3, n)
    for i in grp1:
        for j in grp2:
            A[i, j] = -comp * (1 + rng.normal(0, .05))
            A[j, i] = -comp * (1 + rng.normal(0, .05))
    for g in [grp1, grp2]:
        for i in g:
            for j in g:
                if i != j and rng.random() < 0.4:
                    A[i, j] = fac * rng.normal(1, .1)

    x_eq1 = np.zeros(n)
    x_eq1[grp1] = rng.lognormal(-3.5, .5, half)
    x_eq2 = np.zeros(n)
    x_eq2[grp2] = rng.lognormal(-3.5, .5, n - half)
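    # Choose r so the mixed midpoint 0.5*(x_eq1 + x_eq2) is an interior
    # equilibrium of the full system (r + A x* = 0)
    r = -(A @ (0.5 * (x_eq1 + x_eq2)))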

    m = _build_immigration_vec(config, r)
    out = []
    for comm in sorted(df_init["Comm_name"].unique()):
        df_c = df_init[df_init["Comm_name"] == comm].copy()
        sp = _get_sp_cols(df_c)
        x0 = df_c[sp].iloc[0].values.astype(float)
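        # Heuristic label: the guild with more initial biomass (not checked
        # against the simulated end state)
        basin = 1 if x0[grp1].sum() >= x0[grp2].sum() else 2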
        df_s = simulate_gLV_dataframe(df_c, r, A, timepoints, m=m, config=config)
        df_s["basin"] = basin
        out.append(df_s)
    return pd.concat(out, ignore_index=True)


@register_scenario("dormancy")
def run_dormancy(config, df_init, timepoints, comm_meta=None, **kwargs):
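    """Active <-> dormant transitions triggered by low abundance.

    Active cells of species below `dormancy_threshold` become dormant at
    `dormancy_rate`; dormant cells revive at `revival_rate` regardless of
    abundance.
    """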
    rng = np.random.default_rng(config.seed)
    n = config.n_species
    r, A, _ = construct_stable_regime(
        n, structure=getattr(config, "A_structure", "sparse"),
        stability_margin=getattr(config, "stability_margin", 0.03), rng=rng,
        structure_kwargs=getattr(config, "A_structure_kwargs",
                                  dict(p=0.15, offdiag_std=0.03,
                                       diag_range=(-0.6, -0.2))))
    gamma = float(getattr(config, "dormancy_rate", 0.05))
    delta = float(getattr(config, "revival_rate", 0.10))
    thr = float(getattr(config, "dormancy_threshold", 0.01))
    sp = _get_sp_cols(df_init)
    m_dorm = _build_immigration_vec(config, r)
    out = []

    for comm, sub in df_init.groupby("Comm_name"):
        x0 = np.maximum(sub[sp].iloc[0].values.astype(float), 0.)
        y0 = np.concatenate([x0, np.zeros(n)])

        def rhs(t, y, _r=r, _A=A, _n=n):
            x = np.maximum(y[:_n], 0.)
            d = np.maximum(y[_n:], 0.)
            sl = (x < thr).astype(float)
            dx = x * (_r + _A @ x) - gamma * x * sl + delta * d + m_dorm
            dd = gamma * x * sl - delta * d
            return np.concatenate([dx, dd])

        sol = solve_ivp(rhs, (timepoints[0], timepoints[-1]), y0,
                        t_eval=timepoints, rtol=1e-6, atol=1e-9,
                        method="LSODA")
        if not sol.success:
            raise RuntimeError(f"dormancy ODE failed for {comm}: {sol.message}")
        X = np.maximum(sol.y[:n].T, 0.)
        D = np.maximum(sol.y[n:].T, 0.)
        df_c = pd.DataFrame(X, columns=sp)
        for i in range(n):
            df_c[f"d{i+1}"] = D[:, i]
        df_c["Time"] = sol.t
        df_c["Comm_name"] = comm
        out.append(df_c)
    return pd.concat(out, ignore_index=True)


@register_scenario("antibiotic_pulse")
def run_antibiotic_pulse(config, df_init, timepoints,
                          comm_meta=None, **kwargs):
    """Stable gLV + species-specific antibiotic kill."""
    rng = np.random.default_rng(config.seed)
    n = config.n_species
    r_base, A_base, _ = construct_stable_regime(
        n, structure=getattr(config, "A_structure", "sparse"),
        stability_margin=getattr(config, "stability_margin", 0.03), rng=rng,
        structure_kwargs=getattr(config, "A_structure_kwargs",
                                  dict(p=0.15, offdiag_std=0.03,
                                       diag_range=(-0.6, -0.2))))
    hier = getattr(config, "enable_hierarchical", True)
    sig_r = getattr(config, "sigma_r", 0.15)
    sig_A = getattr(config, "sigma_A", 0.05)
    ab_fn = _build_antibiotic_fn(config, n)
    out = []
    for comm in sorted(df_init["Comm_name"].unique()):
        df_c = df_init[df_init["Comm_name"] == comm].copy()
        r_c, A_c = (
            sample_hierarchical_community_params(
                r_base, A_base, sig_r, sig_A, rng)
            if hier else (r_base, A_base)
        )
        m = _build_immigration_vec(config, r_c)
        df_s = simulate_gLV_dataframe(
            df_c, r_c, A_c, timepoints,
            antibiotic_fn=ab_fn, m=m, config=config
        )
        t0_ab = float(getattr(config, "antibiotic_start", 8.))
        dur_ab = float(getattr(config, "antibiotic_duration", 3.))
        df_s["antibiotic_on"] = (
            (df_s["Time"] >= t0_ab) & (df_s["Time"] <= t0_ab + dur_ab)
        ).astype(int)
        out.append(df_s)
    return pd.concat(out, ignore_index=True)


@register_scenario("environmental_drift")
def run_environmental_drift(config, df_init, timepoints,
                             comm_meta=None, **kwargs):
    """
    gLV with slowly drifting parameters (non-stationary environment).
    r(t) = r_base + drift_r * W_r(t)
    A(t) = A_base + drift_A * W_A(t)
    W_r, W_A are standard Wiener processes pre-sampled on `timepoints` and
    shared by all communities (one common environment).
    """
    rng = np.random.default_rng(config.seed)
    n = config.n_species
    r_base, A_base, _ = construct_stable_regime(
        n, structure=getattr(config, "A_structure", "sparse"),
        stability_margin=getattr(config, "stability_margin", 0.03), rng=rng,
        structure_kwargs=getattr(config, "A_structure_kwargs",
                                  dict(p=0.15, offdiag_std=0.03,
                                       diag_range=(-0.6, -0.2))))
    drift_r = float(getattr(config, "drift_rate_r", 0.01))
    drift_A = float(getattr(config, "drift_rate_A", 0.005))
    m = _build_immigration_vec(config, r_base)
    sp = _get_sp_cols(df_init)

    # Pre-sample Wiener increments for r and A
    n_t = len(timepoints)
    dt_arr = np.diff(timepoints)
    W_r = np.zeros((n_t, n))
    W_A = np.zeros((n_t, n, n))
    for k in range(1, n_t):
        sqrt_dt = np.sqrt(dt_arr[k-1])
        W_r[k] = W_r[k-1] + drift_r * sqrt_dt * rng.normal(0, 1, n)
        W_A[k] = W_A[k-1] + drift_A * sqrt_dt * rng.normal(0, 1, (n, n))

    out = []
    for comm, sub in df_init.groupby("Comm_name"):
        x0 = np.maximum(sub[sp].iloc[0].values.astype(float), 0.)
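        X = np.zeros((n_t, n))
        X[0] = x0
        for k in range(1, n_t):
            # Hold parameters at their left-endpoint values on
            # [t_{k-1}, t_k] and integrate with an adaptive solver; a single
            # forward-Euler step per sampling interval is only first-order
            # accurate when the grid is coarse.
            r_k = r_base + W_r[k-1]
            A_k = A_base + W_A[k-1]
            # Ensure diagonal (self-limitation) remains negative
            A_k[np.diag_indices(n)] = -np.abs(np.diag(A_k)) - 1e-6

            def rhs(t, x, _r=r_k, _A=A_k):
                x = np.maximum(x, 0.)
                dxdt = x * (_r + _A @ x) + m
                dxdt = _apply_carrying_cap(dxdt, x, config)
                return _apply_allee(dxdt, x, config)

            sol = solve_ivp(rhs, (timepoints[k-1], timepoints[k]), X[k-1],
                            rtol=1e-6, atol=1e-9, method="LSODA")
            if not sol.success:
                raise RuntimeError(
                    f"environmental_drift failed for {comm}: {sol.message}")
            X[k] = np.maximum(sol.y[:, -1], 0.)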

        df_c = pd.DataFrame(X, columns=sp)
        df_c["Time"] = timepoints
        df_c["Comm_name"] = comm
        out.append(df_c)

    return pd.concat(out, ignore_index=True)


print("bistable, dormancy, antibiotic_pulse, environmental_drift registered.")
list_scenarios()
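
The drift scenario builds W_r by accumulating increments `drift_r * sqrt(dt) * N(0, 1)`, the standard discretisation of a scaled Wiener process, so Var[W_r(t)] = drift_r^2 * t however irregular the time grid is. A quick Monte Carlo check of that scaling (grid and rate below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
drift_r = 0.01
timepoints = np.array([0.0, 0.5, 1.0, 2.5, 4.0, 7.0, 10.0])   # irregular grid
dt = np.diff(timepoints)
n_paths = 20_000

# Same increment rule as the scenario: drift_r * sqrt(dt) * N(0, 1)
increments = drift_r * np.sqrt(dt)[None, :] * rng.normal(0.0, 1.0, (n_paths, dt.size))
W = np.concatenate([np.zeros((n_paths, 1)), np.cumsum(increments, axis=1)], axis=1)

empirical_var = W[:, -1].var()
theoretical_var = drift_r**2 * timepoints[-1]   # 1e-3
print(empirical_var, theoretical_var)
```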

## 11. Observation Model

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §11  OBSERVATION MODEL
# ══════════════════════════════════════════════════════════════════════

def apply_observation_model(df_latent: pd.DataFrame, config) -> pd.DataFrame:
    """
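    Convert latent abundances → sequencing observations.
    Modes: continuous | multinomial | dirichlet_multinomial
    `continuous` returns the latent values (one copy per replicate). The
    count modes return raw counts and relative abundances stacked together,
    distinguished by the `ObservationType` column; rows with zero total
    latent abundance are dropped.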
    """
    mode = getattr(config, "observation_mode", "continuous")
    n_reps = max(int(getattr(config, "n_replicates", 1)), 1)

    if mode == "continuous":
        if n_reps == 1:
            return df_latent.copy()
        out = []
        for rep in range(n_reps):
            tmp = df_latent.copy()
            tmp["ReplicateID"] = rep
            out.append(tmp)
        return pd.concat(out, ignore_index=True)

    rng = np.random.default_rng(getattr(config, "seed", 42))
    sp = _get_sp_cols(df_latent)
    records = []

    for _, row in df_latent.iterrows():
        abund = row[sp].values.astype(float)
        total = abund.sum()
        if total <= 0:
            continue
        p = abund / total

        for rep in range(n_reps):
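            # Log-normal depth: `library_size_mean` is the median read depth
            # (exp of the log-scale mean); clamp to at least one read
            depth = max(1, int(np.exp(rng.normal(
                np.log(max(getattr(config, "library_size_mean", 10000), 1)),
                max(getattr(config, "library_size_sigma", 0.6), 1e-6)
            ))))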

            if mode == "multinomial":
                counts = rng.multinomial(depth, p)
            elif mode == "dirichlet_multinomial":
                alpha = np.maximum(
                    p * float(getattr(config, "dm_alpha_scale", 100.)), 1e-12
                )
                counts = rng.multinomial(depth, rng.dirichlet(alpha))
            else:
                raise ValueError(f"Unknown observation_mode: {mode}")

            if getattr(config, "enable_dropout", False):
                counts[counts < int(
                    getattr(config, "detection_limit", 2)
                )] = 0

            nr = row.copy()
            for i, c in enumerate(sp):
                nr[c] = float(counts[i])
            nr["LibrarySize"] = depth
            nr["ReplicateID"] = rep
            nr["ObservationType"] = "counts"
            records.append(nr)

    df_counts = pd.DataFrame(records)
    df_rel = df_counts.copy()
    tot = df_rel[sp].sum(axis=1).replace(0, np.nan)
    for c in sp:
        df_rel[c] = df_rel[c] / tot
    df_rel["ObservationType"] = "relative_abundance"
    return pd.concat([df_counts, df_rel], ignore_index=True)


def apply_sampling_design(df_latent: pd.DataFrame, config) -> pd.DataFrame:
    """Apply irregular sampling and/or random missingness."""
    irregular = getattr(config, "irregular_sampling", False)
    miss_rate = float(getattr(config, "missing_rate", 0.))
    if not irregular and miss_rate <= 0.:
        return df_latent.copy()
    rng = np.random.default_rng(getattr(config, "seed", 42))
    df = df_latent.copy()

    if irregular:
        kf = min(max(float(
            getattr(config, "irregular_keep_frac", 0.6)
        ), 0.05), 1.)
        kept = []
        for _, g in df.groupby("Comm_name", sort=False):
            g = g.sort_values("Time")
            m = len(g)
            if m <= 2:
                kept.append(g)
                continue
            k = max(2, int(round(kf * m)))
            mid = np.arange(1, m - 1)
            ch = (rng.choice(mid, min(k - 2, len(mid)), replace=False)
                  if len(mid) and k > 2
                  else np.array([], int))
            kept.append(g.iloc[np.sort(np.concatenate([[0, m - 1], ch]))])
        df = pd.concat(kept, ignore_index=True)

    if miss_rate > 0.:
        df = df.loc[
            rng.random(len(df)) >= min(max(miss_rate, 0.), 0.95)
        ].reset_index(drop=True)

    return df


print("Observation model and sampling design defined.")
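
The two count modes differ in overdispersion. For a taxon with proportion p at depth D, multinomial counts have variance D p (1 - p), while the Dirichlet-multinomial with total concentration alpha0 (here `dm_alpha_scale`, since alpha = p * dm_alpha_scale) inflates this by (D + alpha0) / (1 + alpha0), about 100-fold for D = 10,000 and alpha0 = 100. A short simulation illustrating the gap (values arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
p = np.array([0.5, 0.3, 0.2])
depth, alpha0, n_draws = 10_000, 100.0, 4_000   # alpha0 plays the role of dm_alpha_scale

mult = rng.multinomial(depth, p, size=n_draws)
# Dirichlet-multinomial: draw proportions, then counts given those proportions
dm = np.array([rng.multinomial(depth, rng.dirichlet(p * alpha0))
               for _ in range(n_draws)])

var_mult, var_dm = mult[:, 0].var(), dm[:, 0].var()
theory_mult = depth * p[0] * (1 - p[0])                      # 2500
theory_dm = theory_mult * (depth + alpha0) / (1 + alpha0)     # 250000
print(var_mult, var_dm)
```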

## 12. Auxiliary Metabolite Target

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §12  AUXILIARY METABOLITE TARGET
# ══════════════════════════════════════════════════════════════════════

def compute_metabolite_trajectories(df_latent, C, sigma_obs=0.05,
                                     rng=None, n_metabolites=None):
    """
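    Cumulative metabolite readout per community:
      y(t) = integral from t_0 to t of C @ x(s) ds  (trapezoidal rule on the
      sampled grid) + i.i.d. Gaussian noise with sd `sigma_obs`.
    C shape: (n_metabolites, n_species); a 1-D C is treated as one metabolite.
    `n_metabolites` is unused (inferred from C).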
    """
    rng = np.random.default_rng() if rng is None else rng
    sp = _get_sp_cols(df_latent)
    n_s = len(sp)
    C = np.asarray(C, float)
    if C.ndim == 1:
        C = C.reshape(1, n_s)
    n_m = C.shape[0]
    df_out = df_latent.copy()

    for comm, g in df_out.groupby("Comm_name"):
        g = g.sort_values("Time")
        X = g[sp].values
        t = g["Time"].values
        Yr = (C @ X.T).T  # (T, M)
        Y = np.zeros_like(Yr)
        for k in range(1, len(t)):
            dt = t[k] - t[k-1]
            Y[k] = Y[k-1] + 0.5 * dt * (Yr[k-1] + Yr[k])
        Y += rng.normal(0, sigma_obs, Y.shape)
        for m_idx in range(n_m):
            df_out.loc[g.index, f"met{m_idx+1}"] = Y[:, m_idx]

    return df_out


def compute_shannon_diversity(df, species_prefix="sp"):
    """Shannon diversity H = -sum p_i ln p_i per row (natural log).

    Negative abundances are clipped to 0; rows with zero total get H = 0.
    """
    sp = [c for c in df.columns if c.startswith(species_prefix)]
    X = df[sp].values.astype(float)
    X = np.maximum(X, 0.)
    totals = X.sum(axis=1, keepdims=True)
    totals = np.where(totals == 0, 1, totals)
    p = X / totals
    p = np.where(p > 0, p, 1)  # avoid log(0)
    H = -np.sum(p * np.log(p) * (X > 0), axis=1)
    return H


print("Metabolite trajectories and Shannon diversity defined.")
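
The loop in `compute_metabolite_trajectories` is a cumulative trapezoidal rule starting at 0, so it can be cross-checked against SciPy's `cumulative_trapezoid(..., initial=0)` (available from SciPy 1.6); likewise, the Shannon index should match `scipy.stats.entropy` (natural log) row by row for rows with a positive total. A minimal check on toy arrays:

```python
import numpy as np
from scipy.integrate import cumulative_trapezoid
from scipy.stats import entropy

# Metabolite readout: production rate C @ x on an irregular time grid
t = np.array([0.0, 0.5, 1.5, 3.0, 5.0])
X = np.column_stack([np.linspace(0.1, 0.5, t.size), np.full(t.size, 0.2)])  # (T, S)
C = np.array([[1.0, 2.0]])                                                  # (M, S)
Yr = (C @ X.T).T                                                            # (T, M)

Y_loop = np.zeros_like(Yr)            # same recursion as the function above
for k in range(1, t.size):
    Y_loop[k] = Y_loop[k - 1] + 0.5 * (t[k] - t[k - 1]) * (Yr[k - 1] + Yr[k])
Y_scipy = cumulative_trapezoid(Yr, t, axis=0, initial=0)

# Shannon index per row, natural log; zero entries contribute 0
ab = np.array([[1.0, 1.0, 1.0, 1.0], [5.0, 0.0, 0.0, 0.0], [2.0, 1.0, 1.0, 0.0]])
p = ab / ab.sum(axis=1, keepdims=True)
H_manual = -np.sum(p * np.log(np.where(p > 0, p, 1.0)), axis=1)
H_scipy = entropy(ab, axis=1)         # normalises each row internally
print(H_manual)
```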

## 13. Benchmark Splits

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §13  BENCHMARK SPLITS
# ══════════════════════════════════════════════════════════════════════

def generate_benchmark_splits(df, config, train_frac=0.7, val_frac=0.15,
                               seed=None):
    """Community-level train/val/test splits + regime_ood split."""
    # `seed or default` would silently ignore seed=0, so test for None
    rng_s = np.random.default_rng(
        seed if seed is not None else getattr(config, "seed", 42) + 999
    )
    comms = sorted(df["Comm_name"].unique())
    idx = rng_s.permutation(len(comms))
    n_tr = int(np.floor(train_frac * len(comms)))
    n_va = int(np.floor(val_frac * len(comms)))
    tr = [comms[i] for i in idx[:n_tr]]
    va = [comms[i] for i in idx[n_tr:n_tr + n_va]]
    te = [comms[i] for i in idx[n_tr + n_va:]]
    splits = {
        "train": df[df["Comm_name"].isin(tr)].copy(),
        "val":   df[df["Comm_name"].isin(va)].copy(),
        "test":  df[df["Comm_name"].isin(te)].copy(),
    }
    if "regime" in df.columns:
        # Row-level split: all post-switch rows, including those from
        # training communities
        splits["regime_ood"] = df[df["regime"] == 1].copy()
    manifest = {
        "train_communities": tr,
        "val_communities": va,
        "test_communities": te,
        "n_train": len(tr),
        "n_val": len(va),
        "n_test": len(te),
    }
    return splits, manifest


print("Benchmark splits defined.")
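
Two properties are worth checking for community-level splits: the train/val/test community sets must be disjoint and together cover every community, so no community leaks across splits; and an explicit `seed=0` must be honoured, which a `seed or default` expression would not do because `0` is falsy. A self-contained sketch; `split_communities` is a hypothetical helper that mirrors the permutation logic above.

```python
import numpy as np


def split_communities(comms, train_frac=0.7, val_frac=0.15, seed=None,
                      default_seed=42):
    """Hypothetical helper mirroring the community-level permutation split."""
    # Compare against None so that an explicit seed=0 is honoured
    rng = np.random.default_rng(seed if seed is not None else default_seed + 999)
    comms = sorted(comms)
    idx = rng.permutation(len(comms))
    n_tr = int(np.floor(train_frac * len(comms)))
    n_va = int(np.floor(val_frac * len(comms)))
    tr = [comms[i] for i in idx[:n_tr]]
    va = [comms[i] for i in idx[n_tr:n_tr + n_va]]
    te = [comms[i] for i in idx[n_tr + n_va:]]
    return tr, va, te


comms = [f"C{i:02d}" for i in range(20)]
tr, va, te = split_communities(comms, seed=0)
print(len(tr), len(va), len(te))
```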

## 14. Single gLV Fit (NLS)

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §14  SINGLE gLV FIT (FIX #11: proper NLS instead of finite diffs)
# ══════════════════════════════════════════════════════════════════════

def fit_single_glv(df, n_species, method="nls", lam=1e-3):
    """
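    Fit a single gLV model to (potentially switching) data.

    method='nls': Nonlinear least squares on trajectories; fits r, A by
        minimising || x_data(t) - x_sim(t; r, A) ||^2 over all communities,
        with x_sim integrated from each community's first observation.
        Avoids the noise amplification of finite-difference derivatives.

    method='ridge': Ridge regression on central-difference per-capita
        growth rates (faster; also used to initialise NLS).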
    """
    sp = [f"sp{i+1}" for i in range(n_species)]

    if method == "ridge":
        return _fit_single_glv_ridge(df, n_species, lam)

    # ── NLS method ───────────────────────────────────────────────────
    # Collect all trajectories
    trajs = []
    for _, g in df.groupby("Comm_name"):
        g = g.sort_values("Time")
        X = g[sp].values.astype(float)
        t = g["Time"].values.astype(float)
        if len(t) < 3:
            continue
        trajs.append((t, X))

    if not trajs:
        return np.zeros(n_species), np.zeros((n_species, n_species))

    # Pack parameters: r (n) + A (n*n) = n + n^2
    n = n_species

    def residual(params):
        r_hat = params[:n]
        A_hat = params[n:].reshape(n, n)
        resids = []
        for t_arr, X_data in trajs:
            x0 = np.maximum(X_data[0], 1e-8)
            try:
                X_pred = _glv_integrate(r_hat, A_hat, x0, t_arr)
                resids.append((X_pred - X_data).ravel())
            except RuntimeError:
                # If ODE blows up, penalise heavily
                resids.append(np.full(X_data.size, 1e3))
        return np.concatenate(resids)

    # Initial guess via ridge
    r0, A0 = _fit_single_glv_ridge(df, n_species, lam)
    p0 = np.concatenate([r0, A0.ravel()])

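    # Note: method="lm" requires at least as many residuals as parameters
    # (n + n^2) and counts finite-difference Jacobian evaluations toward
    # max_nfev, so for large n only a few iterations fit in 500 evaluations.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        result = least_squares(residual, p0, method="lm",
                               max_nfev=500, ftol=1e-6, xtol=1e-6)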

    r_hat = result.x[:n]
    A_hat = result.x[n:].reshape(n, n)
    return r_hat, A_hat


def _fit_single_glv_ridge(df, n_species, lam=1e-3):
    """Ridge baseline: per species, regress central-difference per-capita
    growth rates on [1, x] (intercept -> r_s, slopes -> A[s, :])."""
    sp = [f"sp{i+1}" for i in range(n_species)]
    r_hat = np.zeros(n_species)
    A_hat = np.zeros((n_species, n_species))

    for s in range(n_species):
        rows_X, rows_y = [], []
        for _, g in df.groupby("Comm_name"):
            g = g.sort_values("Time")
            X_ = g[sp].values.astype(float)
            t_ = g["Time"].values.astype(float)
            for k in range(1, len(t_) - 1):
                dt = t_[k+1] - t_[k-1]
                if dt < 1e-10 or X_[k, s] < 1e-9:
                    continue
                rows_y.append((X_[k+1, s] - X_[k-1, s]) / dt / X_[k, s])
                rows_X.append(np.concatenate([[1.], X_[k]]))
        if not rows_X:
            continue
        Xs = np.array(rows_X)
        ys = np.array(rows_y)
        th = np.linalg.solve(
            Xs.T @ Xs + lam * np.eye(Xs.shape[1]), Xs.T @ ys
        )
        r_hat[s] = th[0]
        A_hat[s, :] = th[1:]

    return r_hat, A_hat


print("Single-gLV fit (NLS + ridge) defined.")
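
Why trajectory matching helps: the ridge baseline regresses central-difference estimates of the per-capita growth rate, whose truncation error grows with the sampling interval, whereas NLS compares simulated and observed trajectories directly. A one-species sketch, assuming logistic growth dx/dt = x (r + a x) with arbitrary r = 0.8 and a = -1, that fits (r, a) with `scipy.optimize.least_squares` on a coarse grid using `solve_ivp` inside the residual (the notebook's NLS branch uses its own `_glv_integrate` helper instead):

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import least_squares

r_true, a_true, x_init = 0.8, -1.0, 0.05
t = np.arange(0.0, 11.0, 1.0)          # coarse, unit-spaced sampling


def simulate(r, a):
    """Integrate dx/dt = x (r + a x); return None if the solver fails."""
    sol = solve_ivp(lambda _t, x: x * (r + a * x), (t[0], t[-1]), [x_init],
                    t_eval=t, rtol=1e-10, atol=1e-12)
    return sol.y[0] if sol.success and sol.y.shape[1] == t.size else None


x_obs = simulate(r_true, a_true)       # noise-free synthetic observations


def residual(theta):
    x_sim = simulate(*theta)
    # Penalise parameter values whose trajectories blow up, as the NLS branch does
    return np.full(t.size, 1e3) if x_sim is None else x_sim - x_obs


# diff_step keeps finite-difference Jacobian steps well above ODE solver error
fit = least_squares(residual, x0=[0.6, -0.6], diff_step=1e-4)
r_hat, a_hat = fit.x

# Baseline: central-difference per-capita rates regressed on [1, x]
g = (x_obs[2:] - x_obs[:-2]) / (t[2:] - t[:-2]) / x_obs[1:-1]
design = np.column_stack([np.ones(g.size), x_obs[1:-1]])
r_fd, a_fd = np.linalg.lstsq(design, g, rcond=None)[0]
print(f"NLS: r={r_hat:.4f}, a={a_hat:.4f} | finite differences: r={r_fd:.4f}, a={a_fd:.4f}")
```

With noise-free data the NLS estimate should sit essentially on the true values, while the finite-difference estimate carries whatever bias the unit sampling interval induces.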

## 15. Visualization

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §15  VISUALIZATION
# ══════════════════════════════════════════════════════════════════════

def _colors(n):
    cyc = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["C0"] * n)
    return (cyc * ((n // len(cyc)) + 1))[:n]


def plot_regime_switching(df, n_communities=3, trigger_col=None,
                           trigger_threshold=None, regime_col="regime",
                           species_prefix="sp", figsize_per_comm=(12, 4),
                           title=""):
    """
    Plot community trajectories with regime background shading.
    Optional bottom panel shows hidden trigger variable.
    """
    sp_cols = [c for c in df.columns if c.startswith(species_prefix)]
    n_sp = len(sp_cols)
    has_trig = trigger_col and trigger_col in df.columns
    has_reg = regime_col in df.columns
    colors = _colors(n_sp)
    comms = sorted(df["Comm_name"].unique())[:n_communities]
    n_panels = 2 if has_trig else 1

    fig, axes = plt.subplots(
        len(comms) * n_panels, 1,
        figsize=(figsize_per_comm[0],
                 figsize_per_comm[1] * len(comms)),
        squeeze=False
    )

    for ci, comm in enumerate(comms):
        g = df[df["Comm_name"] == comm].sort_values("Time")
        t = g["Time"].values
        ax = axes[ci * n_panels, 0]

        for si, col in enumerate(sp_cols):
            ax.plot(t, g[col].values, color=colors[si], lw=1.5, label=col)

        if has_reg:
            reg = g[regime_col].values
            ts = t[0]
            cr = reg[0]
            for k in range(1, len(t)):
                if reg[k] != cr or k == len(t) - 1:
                    color = "#FFD580" if cr == 1 else "#E8F4E8"
                    ax.axvspan(ts, t[k], alpha=0.22, color=color, zorder=0)
                    ts = t[k]
                    cr = reg[k]
        if "t_switch_true" in g.columns:
            t_sw = g["t_switch_true"].iloc[0]
            if pd.notna(t_sw):  # communities that never switch carry NaN
                ax.axvline(t_sw, color="k", lw=1, ls="--", alpha=0.6)

        ax.set_ylabel("Abundance")
        ax.grid(True, lw=0.4, alpha=0.5)
        ax.set_title(
            f"{comm}" + (f" — {title}" if title and ci == 0 else ""),
            loc="left"
        )
        if ci == 0:
            ax.legend(fontsize=7, ncol=min(3, n_sp), loc="upper right")

        if has_trig:
            ax2 = axes[ci * n_panels + 1, 0]
            ax2.plot(t, g[trigger_col].values, color="steelblue", lw=1.5)
            if trigger_threshold is not None:
                ax2.axhline(trigger_threshold, color="crimson", ls="--",
                             lw=1., label=f"threshold={trigger_threshold}")
                ax2.legend(fontsize=7)
            ax2.set_ylabel(trigger_col)
            ax2.set_xlabel("Time")
            ax2.grid(True, lw=0.4, alpha=0.5)
        else:
            ax.set_xlabel("Time")

    plt.tight_layout()
    return fig


def plot_single_vs_switching(df_true, n_species, timepoints,
                              community=None, method="nls", lam=1e-3):
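    """
    Fit one gLV model to all switching data, then compare its prediction
    with the true trajectory of a single community.
    Returns (fig, per-species RMSE on `timepoints`).
    """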
    sp = [f"sp{i+1}" for i in range(n_species)]
    r_hat, A_hat = fit_single_glv(df_true, n_species, method=method, lam=lam)
    if community is None:
        community = sorted(df_true["Comm_name"].unique())[0]
    g = df_true[df_true["Comm_name"] == community].sort_values("Time")
    x0 = np.maximum(g[sp].iloc[0].values.astype(float), 1e-8)
    X_pred = _glv_integrate(r_hat, A_hat, x0, timepoints)
    colors = _colors(n_species)

    fig, axes = plt.subplots(1, 2, figsize=(14, 4), sharey=True)

    # Left: true switching
    ax = axes[0]
    for si, col in enumerate(sp):
        ax.plot(g["Time"], g[col], color=colors[si], lw=1.5, label=col)
    if "regime" in g.columns:
        reg = g["regime"].values
        t_ = g["Time"].values
        ts = t_[0]
        cr = reg[0]
        for k in range(1, len(t_)):
            if reg[k] != cr or k == len(t_) - 1:
                ax.axvspan(ts, t_[k], alpha=0.2,
                           color="#FFD580" if cr == 1 else "#E8F4E8",
                           zorder=0)
                ts = t_[k]
                cr = reg[k]
    ax.set_title(f"True 2-Regime ({community})", loc="left")
    ax.set_xlabel("Time")
    ax.set_ylabel("Abundance")
    ax.legend(fontsize=7, ncol=2)
    ax.grid(True, lw=0.4, alpha=0.5)

    # Right: single gLV prediction
    ax = axes[1]
    for si, col in enumerate(sp):
        ax.plot(timepoints, X_pred[:, si], color=colors[si],
                lw=1.5, ls="--", label=col)
        ax.scatter(g["Time"], g[col], color=colors[si],
                   s=20, alpha=0.6, zorder=4)
    g_interp = np.array([
        np.interp(timepoints, g["Time"].values, g[col].values)
        for col in sp
    ]).T
    rmse = np.sqrt(np.mean((X_pred - g_interp)**2, axis=0))
    ax.set_title(f"Single gLV fit — mean RMSE={rmse.mean():.4f}", loc="left")
    ax.set_xlabel("Time")
    ax.grid(True, lw=0.4, alpha=0.5)
    plt.tight_layout()
    return fig, rmse


print("Visualization functions defined.")
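Both plotting functions shade the background by walking the regime labels and emitting one `axvspan` per contiguous run. The sketch below isolates that run-length segmentation so it can be checked without matplotlib; `regime_spans` is a hypothetical helper, not part of the generator, and it assumes integer regime labels as used above.

```python
import numpy as np


def regime_spans(t, regime):
    """Return (t_start, t_end, label) for each contiguous run of `regime`.

    Hypothetical helper mirroring the shading loop: each run spans from its
    first sample to the first sample of the next run (or the last sample).
    Zero-width runs are dropped, as the loop never draws them.
    """
    t = np.asarray(t, dtype=float)
    regime = np.asarray(regime)
    if len(t) == 0:
        return []
    # Positions where the label differs from the previous sample
    change = np.flatnonzero(regime[1:] != regime[:-1]) + 1
    starts = np.concatenate(([0], change))
    ends = np.concatenate((change, [len(t) - 1]))
    return [(float(t[s]), float(t[e]), int(regime[s]))
            for s, e in zip(starts, ends) if e > s]


t = np.linspace(0.0, 4.0, 5)
print(regime_spans(t, [0, 0, 1, 1, 1]))  # [(0.0, 2.0, 0), (2.0, 4.0, 1)]
```

Each tuple maps directly onto `ax.axvspan(t_start, t_end, color=...)` with the colour chosen from the label.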

## 16. Export System

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §16  EXPORT (FIX #12: ground-truth params; FIX #13: per-comm switch)
# ══════════════════════════════════════════════════════════════════════

def export_config_snapshot(config_obj, output_dir=".", verbose=True):
    """Save GeneratorConfig as JSON."""
    p = Path(output_dir)
    p.mkdir(parents=True, exist_ok=True)
    snap = asdict(config_obj)
    snap["timestamp"] = datetime.now().isoformat()
    fp = p / "config.json"
    with open(fp, "w") as f:
        json.dump(snap, f, indent=2, default=str)
    if verbose:
        print(f"Config saved → {fp}")
    return str(fp)


def export_dataset(config, df_init, timepoints, output_dir="test_data",
                   comm_meta=None, compute_metabolites=False,
                   metabolite_C=None):
    """
    Run scenario and export full dataset bundle:
      latent_truth.csv, observed_counts.csv, config.json,
      dataset_metadata.json, split_manifest.json,
      ground_truth_params.json (FIX #12),
      split_train/val/test/regime_ood.csv,
      [optional] metabolites_truth.csv
    Returns a dict mapping each artefact name to its file path.
    """
    outdir = Path(output_dir)
    outdir.mkdir(parents=True, exist_ok=True)
    sc = config.scenario
    if sc not in SCENARIO_REGISTRY:
        raise ValueError(f"Scenario '{sc}' not registered.")

    _GROUND_TRUTH_PARAMS.clear()

    df_latent_raw = SCENARIO_REGISTRY[sc](
        config, df_init=df_init, timepoints=timepoints,
        comm_meta=comm_meta
    )
    df_latent = apply_sampling_design(df_latent_raw, config)
    df_observed = apply_observation_model(df_latent, config)

    paths = {}

    def _save(df, name):
        p = outdir / name
        df.to_csv(p, index=False)
        paths[name.replace(".csv", "")] = str(p)

    _save(df_latent, "latent_truth.csv")
    _save(df_observed, "observed_counts.csv")

    # Config (same JSON snapshot as export_config_snapshot writes)
    paths["config"] = export_config_snapshot(config, outdir, verbose=False)

    # FIX #12: ground-truth regime parameters
    if getattr(config, "export_ground_truth_params", True) and _GROUND_TRUTH_PARAMS:
        with open(outdir / "ground_truth_params.json", "w") as f:
            json.dump(_GROUND_TRUTH_PARAMS, f, indent=2)
        paths["ground_truth_params"] = str(outdir / "ground_truth_params.json")

    # Metabolites
    if compute_metabolites:
        n_sp = config.n_species
        C = (metabolite_C if metabolite_C is not None
             else np.abs(np.random.default_rng(config.seed).normal(
                 0, .3, (1, n_sp))))
        df_met = compute_metabolite_trajectories(df_latent, C=C)
        _save(df_met, "metabolites_truth.csv")

    # Splits
    splits, manifest = generate_benchmark_splits(df_observed, config)
    for sname, df_s in splits.items():
        _save(df_s, f"split_{sname}.csv")
    with open(outdir / "split_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    paths["split_manifest"] = str(outdir / "split_manifest.json")

    # Dataset metadata
    latent_vars = [c for c in ["u", "A_M", "w", "regime", "t_switch_true"]
                   if c in df_latent.columns]
    meta = {
        "scenario": sc,
        "seed": config.seed,
        "regime_distance": getattr(config, "regime_distance", None),
        "n_rows_observed": len(df_observed),
        "n_communities": df_observed["Comm_name"].nunique(),
        "time_min": float(df_observed["Time"].min()),
        "time_max": float(df_observed["Time"].max()),
        "n_timepoints": int(df_observed["Time"].nunique()),
        "has_replicates": "ReplicateID" in df_observed.columns,
        "latent_variables": latent_vars,
    }

    # Per-community switch times (FIX #13)
    if "t_switch_true" in df_latent.columns:
        per_comm_switches = {}
        for comm, g in df_latent.groupby("Comm_name"):
            t_sw = g["t_switch_true"].dropna().unique()
            if len(t_sw) > 0:
                per_comm_switches[comm] = float(t_sw[0])
        meta["per_community_switch_times"] = per_comm_switches
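    elif "w" in df_latent.columns:
        # Soft switch has no discrete switch time: report the sampled time
        # at which the switching weight w is closest to 0.5
        per_comm_switches = {}
        for comm, g in df_latent.groupby("Comm_name"):
            g = g.sort_values("Time").reset_index(drop=True)
            idx_w = (g["w"] - 0.5).abs().idxmin()  # unique label; skips NaN
            per_comm_switches[comm] = float(g.loc[idx_w, "Time"])
        meta["per_community_switch_times"] = per_comm_switches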

    with open(outdir / "dataset_metadata.json", "w") as f:
        json.dump(meta, f, indent=2)
    paths["metadata"] = str(outdir / "dataset_metadata.json")

    print(f"✅ Export complete → {outdir.resolve()}")
    for k, v in paths.items():
        print(f"  {k:30s}: {v}")
    return paths


print("Export system defined (with ground-truth params & per-comm switches).")
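When a scenario exports only the soft weight `w` (no `t_switch_true`), `export_dataset` records, per community, the sampled time at which `w` is closest to 0.5. The self-contained sketch below applies that rule to toy data; the logistic form `w = 1 / (1 + exp(-(t - t_switch) / epsilon))`, the community names and `switch_times_from_w` are illustrative assumptions, not read from the generator.

```python
import numpy as np
import pandas as pd


def switch_times_from_w(df, level=0.5):
    """Per-community sampled time at which w is closest to `level` (hypothetical)."""
    out = {}
    for comm, g in df.groupby("Comm_name"):
        g = g.sort_values("Time").reset_index(drop=True)
        idx = (g["w"] - level).abs().idxmin()   # row of the nearest sample
        out[comm] = float(g.loc[idx, "Time"])
    return out


# Toy data: two communities whose assumed logistic weights centre at 8 and 12
t = np.linspace(0.0, 20.0, 101)                 # sampling interval 0.2
frames = []
for comm, t_sw in [("commA", 8.0), ("commB", 12.0)]:
    w = 1.0 / (1.0 + np.exp(-(t - t_sw) / 0.5))
    frames.append(pd.DataFrame({"Comm_name": comm, "Time": t, "w": w}))
df = pd.concat(frames, ignore_index=True)

print(switch_times_from_w(df))
```

The estimate is only as fine as the sampling grid, and if `w` never approaches 0.5 inside the time window the rule still returns the nearest sample, so a consumer of the metadata may want to check `w` at the reported time.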

## 17. Dataset Sweep Factory

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §17  DATASET SWEEP FACTORY
# ══════════════════════════════════════════════════════════════════════

DIFFICULTY_PRESETS = dict(
    easy   = dict(regime_distance=1.5, seed_offset=0),
    medium = dict(regime_distance=1.0, seed_offset=1000),
    hard   = dict(regime_distance=0.3, seed_offset=2000),
)

SWITCHING_SCENARIOS = [
    "hidden_trigger", "cumulative_trigger", "time_switch", "soft_switch"
]


def run_dataset_sweep(df_init, timepoints, base_config, root_dir="datasets",
                      comm_meta=None, scenarios=None, difficulties=None,
                      verbose=True):
    """Export one dataset per scenario × difficulty to root_dir/<scenario>/<difficulty>/."""
    ROOT = Path(root_dir)
    ROOT.mkdir(parents=True, exist_ok=True)
    scenarios = scenarios or SWITCHING_SCENARIOS
    difficulties = difficulties or DIFFICULTY_PRESETS

    for scenario in scenarios:
        for difficulty, params in difficulties.items():
            cfg = deepcopy(base_config)
            cfg.scenario = scenario
            cfg.regime_distance = params["regime_distance"]
            cfg.seed = base_config.seed + params["seed_offset"]
            cfg.observation_mode = "continuous"
            cfg.n_replicates = 1
            for k, v in dict(
                t_switch=10., theta=.5, k_u=.4, u0=0., epsilon=.1,
                idx_M=None, c1=1., M_init=.05, a0=0.
            ).items():
                setattr(cfg, k, v)
            if verbose:
                print(f"\n=== {scenario}/{difficulty} "
                      f"(rd={params['regime_distance']}) ===")
            try:
                export_dataset(cfg, df_init, timepoints,
                               output_dir=str(ROOT / scenario / difficulty),
                               comm_meta=comm_meta)
            except Exception as e:
                print(f"  ⚠️ {scenario}/{difficulty} failed: {e}")

    print(f"\n✅ Sweep complete → {ROOT.resolve()}")


print("Dataset sweep factory defined.")
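`run_dataset_sweep` expands every scenario against every difficulty preset and offsets the seed per difficulty. The sketch below reproduces only that expansion on a stand-in dataclass; `MiniConfig` and `expand_sweep` are hypothetical names, and `dataclasses.replace` is used here in place of the original `deepcopy` plus attribute assignment.

```python
from dataclasses import dataclass, replace
from pathlib import Path


@dataclass
class MiniConfig:               # hypothetical stand-in for GeneratorConfig
    scenario: str = "baseline_glv_stable"
    regime_distance: float = 1.0
    seed: int = 42


PRESETS = dict(
    easy=dict(regime_distance=1.5, seed_offset=0),
    medium=dict(regime_distance=1.0, seed_offset=1000),
    hard=dict(regime_distance=0.3, seed_offset=2000),
)


def expand_sweep(base, scenarios, presets, root="datasets"):
    """Map output directory -> config for every scenario x difficulty."""
    jobs = {}
    for scenario in scenarios:
        for difficulty, p in presets.items():
            cfg = replace(base, scenario=scenario,
                          regime_distance=p["regime_distance"],
                          seed=base.seed + p["seed_offset"])
            jobs[str(Path(root) / scenario / difficulty)] = cfg
    return jobs


jobs = expand_sweep(MiniConfig(), ["time_switch", "soft_switch"], PRESETS)
for path, cfg in jobs.items():
    print(path, cfg.regime_distance, cfg.seed)
```

As in `run_dataset_sweep`, the seed depends only on the difficulty, so two scenarios at the same difficulty share a seed; `replace` also leaves the base config untouched.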

## 18. Validation Suite

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §18  VALIDATION SUITE
# ══════════════════════════════════════════════════════════════════════

def validate_dataset(df, config, verbose=True):
    """
    Comprehensive validation of a generated dataset.
    Returns dict of {check_name: passed_bool}.
    """
    sp = _get_sp_cols(df)
    results = {}

    # 1. No NaN / Inf in species columns
    has_nan = df[sp].isna().any().any()
    has_inf = np.isinf(df[sp].values.astype(float)).any()
    results["no_nan"] = not has_nan
    results["no_inf"] = not has_inf

    # 2. Non-negative abundances
    results["nonnegative"] = (df[sp].values.astype(float) >= -1e-12).all()

    # 3. Community count
    results["communities_present"] = df["Comm_name"].nunique() > 0

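    # 4. Time monotonicity per community (and per replicate / observation
    #    type when those columns exist, so stacked series are not mixed)
    group_cols = [c for c in ["Comm_name", "ReplicateID", "ObservationType"]
                  if c in df.columns]
    time_ok = True
    for _, g in df.groupby(group_cols):
        t = g["Time"].values.astype(float)
        if not np.all(np.diff(t) >= -1e-10):
            time_ok = False
            break
    results["time_monotonic"] = time_ok

    # 5. No blow-up (abundances < 100 as sanity bound); raw counts are
    #    excluded because they scale with library size
    if "ObservationType" in df.columns:
        abund = df.loc[df["ObservationType"] != "counts", sp]
    else:
        abund = df[sp]
    results["no_blowup"] = bool(
        abund.empty or abund.values.astype(float).max() < 100.)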

    # 6. Regime labels valid (if present)
    if "regime" in df.columns:
        results["valid_regimes"] = set(df["regime"].unique()).issubset({0, 1})

    # 7. Observation model checks (if counts present)
    if "ObservationType" in df.columns:
        counts_df = df[df["ObservationType"] == "counts"]
        if len(counts_df) > 0:
            counts_arr = counts_df[sp].values.astype(float)
            results["counts_nonneg"] = (counts_arr >= 0).all()
            results["counts_integer"] = np.allclose(
                counts_arr, np.round(counts_arr)
            )

    if verbose:
        for check, passed in results.items():
            status = "✅" if passed else "❌"
            print(f"  {status} {check}")

    return results


print("Validation suite defined.")
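Check 4 of `validate_dataset` asks whether `Time` is non-decreasing within each community. The toy sketch below shows why the grouping keys matter once replicates are stacked in one frame; `time_monotonic_by_group` is a hypothetical helper, and it assumes `Comm_name` is always present.

```python
import pandas as pd


def time_monotonic_by_group(df, keys=("Comm_name",)):
    """True if Time is non-decreasing within every group defined by `keys`."""
    keys = [k for k in keys if k in df.columns]
    return bool(
        df.groupby(keys)["Time"]
          .apply(lambda s: s.is_monotonic_increasing)
          .all()
    )


toy = pd.DataFrame({
    "Comm_name":   ["c1"] * 4,
    "ReplicateID": [0, 0, 1, 1],
    "Time":        [0.0, 1.0, 0.0, 1.0],   # two replicates stacked
})
print(time_monotonic_by_group(toy))                                # False
print(time_monotonic_by_group(toy, ("Comm_name", "ReplicateID")))  # True
```

Grouping by community alone flags a perfectly valid stacked-replicate frame as non-monotonic, because time restarts at each replicate boundary.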

## 19. Tests

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §19  TESTS
# ══════════════════════════════════════════════════════════════════════

print("\n" + "="*60)
print("RUNNING TESTS")
print("="*60)

# ── Setup shared test objects ────────────────────────────────────────
n_species   = 5
size_counts = [(1, 3), (2, 4), (3, 2)]
timepoints  = np.linspace(0., 20., 101)

df_init, COMM_META = generate_community_dataframe(
    n_species=n_species, size_counts=size_counts,
    mean_abundance=0.01, bounds=(0.001, 0.1), sigma_log=1.0, seed=42
)

sp_cols = [f"sp{i+1}" for i in range(n_species)]

CONFIG.n_species           = n_species
CONFIG.A_structure         = "sparse"
CONFIG.A_structure_kwargs  = dict(p=0.15, offdiag_std=0.03,
                                   diag_range=(-0.6, -0.2))
CONFIG.stability_margin    = 0.03
CONFIG.enable_hierarchical = True
CONFIG.sigma_r = 0.15
CONFIG.sigma_A = 0.05

# Test 0: stable regime construction
r_t, A_t, x_t = construct_stable_regime(
    n_species, rng=np.random.default_rng(2026)
)
assert np.max(np.abs(r_t + A_t @ x_t)) < 1e-10
J_t = np.diag(x_t) @ A_t
assert float(np.max(np.real(np.linalg.eigvals(J_t)))) < -0.03
print("✅ Test 0: construct_stable_regime passed.")

# Test 1: community dataframe
for _, row in df_init.iterrows():
    cn = row["Comm_name"]
    absent = [i for i in range(n_species) if i not in COMM_META[cn]]
    for a in absent:
        assert row[f"sp{a+1}"] == 0.
covered = {i for present in COMM_META.values() for i in present}
assert covered == set(range(n_species))
print(f"✅ Test 1: community df: {len(df_init)} rows, all species covered.")

# Test 2: baseline_glv_stable
CONFIG.scenario = "baseline_glv_stable"
CONFIG.observation_mode = "continuous"
CONFIG.n_replicates = 1

df_stable = simulate_community(
    CONFIG, df_init=df_init, timepoints=timepoints, comm_meta=COMM_META
)
assert (df_stable[_get_sp_cols(df_stable)] >= -1e-12).all().all()
print(f"✅ Test 2: baseline_glv_stable: {df_stable.shape}")

# Test 3: All switching scenarios
for scenario, params in [
    ("time_switch",        {"t_switch": 10., "regime_distance": 1.}),
    ("soft_switch",        {"t_switch": 10., "epsilon": 1.5,
                            "regime_distance": 1.}),
    ("hidden_trigger",     {"theta": .5, "k_u": .4, "u0": 0.,
                            "regime_distance": 1.}),
    ("cumulative_trigger", {"idx_M": [0, 2], "c1": 1., "M_init": .05,
                            "a0": 0., "regime_distance": 1.}),
]:
    CONFIG.scenario = scenario
    for k, v in params.items():
        setattr(CONFIG, k, v)
    CONFIG.observation_mode = "continuous"
    df_s = simulate_community(
        CONFIG, df_init=df_init, timepoints=timepoints, comm_meta=COMM_META
    )
    sp_c = _get_sp_cols(df_s)
    assert (df_s[sp_c] >= -1e-12).all().all()
    assert "regime" in df_s.columns or "w" in df_s.columns
    print(f"  ✅ {scenario}: {df_s.shape}")
print("✅ Test 3: All switching scenarios passed.")

# Test 4: observation model
CONFIG.scenario = "baseline_glv_stable"
CONFIG.observation_mode = "dirichlet_multinomial"
CONFIG.library_size_mean = 10000
CONFIG.library_size_sigma = 0.6
CONFIG.dm_alpha_scale = 100.
CONFIG.enable_dropout = True
CONFIG.detection_limit = 2
CONFIG.n_replicates = 2

df_obs = simulate_community(
    CONFIG, df_init=df_init, timepoints=timepoints, comm_meta=COMM_META
)
assert set(df_obs["ObservationType"].unique()) == {"counts", "relative_abundance"}
print(f"✅ Test 4: observation model: {df_obs.shape}")

# Test 5: immigration — absent species can colonise
CONFIG.scenario = "baseline_glv_stable"
CONFIG.observation_mode = "continuous"
CONFIG.n_replicates = 1
CONFIG.immigration_rate = 1e-3
CONFIG.immigration_scale = 1.0

df_imm_init = pd.DataFrame(
    [["single_sp1", 0.0, 0.05] + [0.0] * (n_species - 1)],
    columns=["Comm_name", "Time"] + sp_cols
)
df_imm = simulate_community(
    CONFIG, df_init=df_imm_init, timepoints=timepoints,
    comm_meta={"single_sp1": [0]}
)
# sp1 was seeded at t=0, so count colonisation among the absent species only
final_absent = df_imm.loc[np.isclose(df_imm["Time"], timepoints[-1]),
                          sp_cols[1:]].values[0]
n_colonised = int(np.sum(final_absent > 1e-6))
assert n_colonised >= 1, "Immigration should let absent species colonise"
print(f"✅ Test 5: immigration: {n_colonised}/{n_species - 1} "
      f"absent species colonised by t_final")
CONFIG.immigration_rate = 1e-4  # reset

# Test 6: hidden trigger heterogeneity
CONFIG.scenario = "hidden_trigger"
CONFIG.trigger_sigma_theta = 0.1
CONFIG.trigger_sigma_k_u = 0.05
CONFIG.regime_distance = 1.0
CONFIG.observation_mode = "continuous"
CONFIG.n_replicates = 1

df_het = simulate_community(
    CONFIG, df_init=df_init, timepoints=timepoints, comm_meta=COMM_META
)
if "t_switch_true" in df_het.columns:
    switch_times = df_het.groupby("Comm_name")["t_switch_true"].first().dropna()
    if len(switch_times) > 1:
        assert switch_times.std() > 0, "Trigger heterogeneity should spread switch times"
        print(f"  Switch time spread: {switch_times.std():.3f}")
print("✅ Test 6: hidden trigger heterogeneity passed.")
CONFIG.trigger_sigma_theta = 0.0  # reset
CONFIG.trigger_sigma_k_u = 0.0

# Test 7: validation suite
CONFIG.scenario = "time_switch"
CONFIG.t_switch = 10.
CONFIG.regime_distance = 1.
CONFIG.observation_mode = "continuous"
CONFIG.n_replicates = 1
df_val = simulate_community(
    CONFIG, df_init=df_init, timepoints=timepoints, comm_meta=COMM_META
)
print("Validation results:")
vr = validate_dataset(df_val, CONFIG)
assert all(vr.values()), "Validation failed!"
print("✅ Test 7: validation suite passed.")

# Test 8: environmental drift
CONFIG.scenario = "environmental_drift"
CONFIG.enable_drift = True
CONFIG.drift_rate_r = 0.01
CONFIG.drift_rate_A = 0.005
df_drift = simulate_community(
    CONFIG, df_init=df_init, timepoints=timepoints, comm_meta=COMM_META
)
assert (df_drift[_get_sp_cols(df_drift)] >= -1e-12).all().all()
print(f"✅ Test 8: environmental_drift: {df_drift.shape}")

print("\n" + "="*60)
print("ALL TESTS PASSED")
print("="*60)
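Test 0 checks two properties of `construct_stable_regime`: `x*` is an interior equilibrium (`r + A x* = 0`), and the gLV Jacobian there, `J = diag(x*) A`, has eigenvalues with negative real part. One standard way to satisfy both is to choose `x* > 0` and a strictly row-diagonally-dominant `A` with negative diagonal, then set `r = -A x*`; by Gershgorin's theorem every eigenvalue of `diag(x*) A` then lies in the open left half-plane. The sketch below (hypothetical `make_stable_glv`) illustrates that construction only; it is not necessarily how `construct_stable_regime` works, and it does not target a specific stability margin.

```python
import numpy as np


def make_stable_glv(n, rng, diag_range=(-0.6, -0.2), offdiag_frac=0.5):
    """Hypothetical (r, A, x_star) with a stable interior gLV equilibrium.

    Off-diagonal absolute row sums are scaled to offdiag_frac * |A_ii|,
    so A is strictly row-diagonally dominant with a negative diagonal.
    """
    x_star = rng.uniform(0.05, 0.5, size=n)          # positive equilibrium
    d = rng.uniform(*diag_range, size=n)              # self-limitation < 0
    off = rng.normal(0.0, 1.0, size=(n, n))
    np.fill_diagonal(off, 0.0)
    row_abs = np.abs(off).sum(axis=1, keepdims=True)
    off *= offdiag_frac * np.abs(d)[:, None] / np.maximum(row_abs, 1e-12)
    A = off + np.diag(d)
    r = -A @ x_star                                   # forces r + A x* = 0
    return r, A, x_star


rng = np.random.default_rng(0)
r, A, x = make_stable_glv(5, rng)
J = np.diag(x) @ A                                    # Jacobian at x*
print(np.max(np.abs(r + A @ x)), np.max(np.linalg.eigvals(J).real))
```

Row i of `J` is row i of `A` scaled by `x_i > 0`, so each Gershgorin disc is centred at `x_i A_ii` with radius `offdiag_frac * x_i |A_ii|`; the largest real part is therefore at most `max_i (1 - offdiag_frac) * x_i * A_ii < 0`.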

## 20. Demo: Ground Truth vs Observed

In [None]:
# ══════════════════════════════════════════════════════════════════════
# §20  COMPREHENSIVE DEMO FIGURE: Ground Truth vs Observed
# ══════════════════════════════════════════════════════════════════════

print("\n" + "="*60)
print("GENERATING COMPREHENSIVE DEMO FIGURE")
print("="*60)

# Use more communities with richer composition
n_species_demo = 5
size_counts_demo = [(2, 5), (3, 5), (4, 3), (5, 2)]
timepoints_demo = np.linspace(0., 20., 101)

df_init_demo, COMM_META_DEMO = generate_community_dataframe(
    n_species=n_species_demo, size_counts=size_counts_demo,
    mean_abundance=0.01, bounds=(0.001, 0.1), sigma_log=1.0, seed=99)

CONFIG.n_species = n_species_demo
CONFIG.scenario = "hidden_trigger"
CONFIG.regime_distance = 1.0
CONFIG.theta = 0.5; CONFIG.k_u = 0.4; CONFIG.u0 = 0.0; CONFIG.epsilon = 0.1
CONFIG.trigger_sigma_theta = 0.1; CONFIG.trigger_sigma_k_u = 0.05
CONFIG.immigration_rate = 1e-4; CONFIG.immigration_scale = 1.0
CONFIG.enable_carrying_cap = False; CONFIG.enable_allee = False
CONFIG.observation_mode = "continuous"; CONFIG.n_replicates = 1

df_latent = SCENARIO_REGISTRY["hidden_trigger"](
    CONFIG, df_init=df_init_demo, timepoints=timepoints_demo,
    comm_meta=COMM_META_DEMO)

# Observed
obs_config = deepcopy(CONFIG)
obs_config.observation_mode = "dirichlet_multinomial"
obs_config.library_size_mean = 5000; obs_config.library_size_sigma = 0.5
obs_config.dm_alpha_scale = 80.0; obs_config.enable_dropout = True
obs_config.detection_limit = 3; obs_config.n_replicates = 3
df_observed = apply_observation_model(df_latent, obs_config)

# Metabolite + diversity
C_met = np.abs(np.random.default_rng(77).normal(0, .3, (1, n_species_demo)))
df_latent_met = compute_metabolite_trajectories(df_latent, C=C_met, sigma_obs=0.01)
df_latent["Shannon"] = compute_shannon_diversity(df_latent)

# Pick 3 communities with >= 3 species
multi_sp_comms = [c for c, idxs in COMM_META_DEMO.items() if len(idxs) >= 3]
comms_to_plot = sorted(multi_sp_comms)[:3]
sp = [f"sp{i+1}" for i in range(n_species_demo)]
colors = _colors(n_species_demo)
n_comms = len(comms_to_plot)

def _shade(ax, g, t_arr, alpha=0.22):
    """Shade contiguous regime runs (regime 1 orange, regime 0 green)."""
    if "regime" not in g.columns:
        return
    reg = g["regime"].values
    ts, cr = t_arr[0], reg[0]
    for k in range(1, len(t_arr)):
        if reg[k] != cr or k == len(t_arr) - 1:
            ax.axvspan(ts, t_arr[k], alpha=alpha,
                       color="#FFD580" if cr == 1 else "#E8F4E8", zorder=0)
            ts, cr = t_arr[k], reg[k]


fig = plt.figure(figsize=(18, 4.2 * n_comms))
outer_gs = gridspec.GridSpec(n_comms, 1, hspace=0.38)

for ci, comm in enumerate(comms_to_plot):
    inner_gs = gridspec.GridSpecFromSubplotSpec(
        4, 1, subplot_spec=outer_gs[ci],
        height_ratios=[3, 3, 1.5, 1.5], hspace=0.3)

    g_lat = df_latent[df_latent["Comm_name"] == comm].sort_values("Time")
    g_met = df_latent_met[df_latent_met["Comm_name"] == comm].sort_values("Time")
    t_lat = g_lat["Time"].values.astype(float)

    # Panel A: Latent truth
    ax_a = fig.add_subplot(inner_gs[0])
    for si, col in enumerate(sp):
        ax_a.plot(t_lat, g_lat[col].values, color=colors[si], lw=1.8,
                  label=col, zorder=3)
    _shade(ax_a, g_lat, t_lat)
    if "t_switch_true" in g_lat.columns:
        t_sw = g_lat["t_switch_true"].iloc[0]
        if pd.notna(t_sw):
            ax_a.axvline(t_sw, color="k", lw=1.2, ls="--", alpha=0.6,
                         label=f"switch t={t_sw:.1f}")
    ax_a.set_ylabel("Abundance\n(latent truth)", fontsize=9)
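    # COMM_META indices are 0-based; label them with the sp1..spN column names
    present = ", ".join(f"sp{i+1}" for i in COMM_META_DEMO[comm])
    ax_a.set_title(f"{comm} — Ground Truth vs Observed  "
                   f"(species present: {present})",
                   loc="left", fontweight="bold", fontsize=11)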
    ax_a.legend(fontsize=7, ncol=min(4, n_species_demo + 1), loc="upper right")
    ax_a.grid(True, lw=0.3, alpha=0.4); ax_a.tick_params(labelbottom=False)

    # Panel B: Observed
    ax_b = fig.add_subplot(inner_gs[1])
    df_obs_comm = df_observed[
        (df_observed["Comm_name"] == comm) &
        (df_observed["ObservationType"] == "relative_abundance")]
    markers = ["o", "s", "^"]
    rep_ids = sorted(df_obs_comm["ReplicateID"].unique())
    for ri, rep_id in enumerate(rep_ids):  # marker by position, not ID value
        df_rep = df_obs_comm[df_obs_comm["ReplicateID"] == rep_id].sort_values("Time")
        t_obs = df_rep["Time"].values.astype(float)
        for si, col in enumerate(sp):
            ax_b.scatter(t_obs, df_rep[col].values.astype(float),
                         color=colors[si], marker=markers[ri % len(markers)],
                         s=14, alpha=0.45, edgecolors="none", zorder=3)
    # latent compositional reference
    lat_total = np.maximum(g_lat[sp].values.astype(float).sum(axis=1), 1e-15)
    for si, col in enumerate(sp):
        ax_b.plot(t_lat, g_lat[col].values.astype(float) / lat_total,
                  color=colors[si], lw=0.8, ls="--", alpha=0.5)
    _shade(ax_b, g_lat, t_lat, alpha=0.15)
    ax_b.set_ylabel("Rel. Abundance\n(DM observed)", fontsize=9)
    ax_b.set_ylim(-0.05, 1.05)
    ax_b.grid(True, lw=0.3, alpha=0.4); ax_b.tick_params(labelbottom=False)

    # Panel C: Hidden u(t)
    ax_c = fig.add_subplot(inner_gs[2])
    if "u" in g_lat.columns:
        ax_c.plot(t_lat, g_lat["u"].values, color="steelblue", lw=1.5)
        # Nominal threshold; per-community θ is jittered (trigger_sigma_theta)
        ax_c.axhline(CONFIG.theta, color="crimson", ls="--", lw=1,
                      label=f"nominal θ={CONFIG.theta}")
        ax_c.legend(fontsize=7, loc="lower right")
    ax_c.set_ylabel("u(t)", fontsize=9)
    ax_c.grid(True, lw=0.3, alpha=0.4); ax_c.tick_params(labelbottom=False)

    # Panel D: Metabolite + diversity
    ax_d = fig.add_subplot(inner_gs[3])
    if "met1" in g_met.columns:
        ax_d.plot(g_met["Time"].values, g_met["met1"].values,
                  color="darkorange", lw=1.5, label="Metabolite")
    ax_d2 = ax_d.twinx()
    ax_d2.plot(t_lat, g_lat["Shannon"].values, color="purple", lw=1.2,
               ls=":", label="Shannon H")
    ax_d.set_ylabel("Metabolite", fontsize=9, color="darkorange")
    ax_d2.set_ylabel("Shannon H", fontsize=9, color="purple")
    ax_d.set_xlabel("Time", fontsize=10)
    ax_d.grid(True, lw=0.3, alpha=0.4)
    h1, l1 = ax_d.get_legend_handles_labels()
    h2, l2 = ax_d2.get_legend_handles_labels()
    ax_d.legend(h1 + h2, l1 + l2, fontsize=7, loc="upper left")

fig.suptitle(
    "Synthetic Microbiome Generator v3\n"
    "Ground Truth (latent gLV) vs Sequencing Observations (Dirichlet-Multinomial)",
    fontsize=13, fontweight="bold", y=1.02)

fig.savefig("demo_figure.png", dpi=150, bbox_inches="tight",
            facecolor="white")
plt.close(fig)
print("✅ Demo figure saved.")
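Panel B sets Dirichlet-multinomial relative abundances against the latent composition `x / sum(x)`, and panel D tracks Shannon diversity. The NumPy sketch below shows one such sequencing draw and the entropy under stated assumptions: the Dirichlet concentration is `alpha_scale * p` and the library size is log-normal with median `library_size_mean`. These are assumptions about what `apply_observation_model` and `compute_shannon_diversity` do, not their actual code; dropout and the detection limit are omitted, and `dm_observe` and `shannon` are hypothetical names.

```python
import numpy as np


def dm_observe(x, library_size_mean=5000, library_size_sigma=0.5,
               alpha_scale=80.0, rng=None):
    """One hypothetical Dirichlet-multinomial sequencing draw from abundances x."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.asarray(x, dtype=float)
    p = x / x.sum()                                   # latent composition
    present = p > 0                                   # Dirichlet needs alpha > 0
    q = np.zeros_like(p)
    q[present] = rng.dirichlet(alpha_scale * p[present])  # overdispersed
    # Assumed log-normal library size, median library_size_mean
    scale = np.exp(library_size_sigma * rng.standard_normal())
    n_reads = max(1, int(round(library_size_mean * scale)))
    counts = rng.multinomial(n_reads, q)
    return counts, counts / counts.sum()


def shannon(x):
    """Shannon entropy H = -sum p log p over nonzero proportions (natural log)."""
    p = np.asarray(x, dtype=float)
    p = p[p > 0] / p.sum()
    return float(-(p * np.log(p)).sum())


rng = np.random.default_rng(1)
x = np.array([0.20, 0.05, 0.0, 0.10, 0.15])           # latent abundances
counts, rel = dm_observe(x, rng=rng)
print(counts, rel.round(3), round(shannon(x), 3))
```

A smaller `alpha_scale` widens the scatter of observed proportions around the latent composition, which is the overdispersion visible as spread between replicate markers in panel B.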