In [None]:
import numpy as np

import os
import math
import random
import numpy as np

# If you use torch/gymnasium, include them here too:
import torch
import gymnasium as gym

os.makedirs("runs", exist_ok=True)

def set_seed(seed: int = 0):
    """Seed Python's `random`, NumPy's legacy global RNG and PyTorch (CPU + all CUDA devices).

    Gymnasium environments and any `np.random.Generator` instances created
    elsewhere are not touched here; they need their own seeds.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(0)


In [None]:
from dataclasses import dataclass, field

@dataclass
class DBSBounds:
    """Lower/upper limits for the three DBS stimulation parameters.

    Units follow the field-name suffix: amplitude in mA, frequency in Hz and
    pulse width in ms. EnvConfig keeps an instance under ``dbs_bounds`` for
    action mapping.
    """
    amp_mA_min: float = 0.1
    amp_mA_max: float = 5.0
    freq_Hz_min: float = 5.0
    freq_Hz_max: float = 200.0
    pw_ms_min: float = 0.05
    pw_ms_max: float = 0.5

@dataclass
class PlantConfig:
    """Settings for the simulated plant: time step, window, conductivity and safety caps.

    Units follow the field-name suffix (ms, S/m, mA, Hz).
    """
    dt_ms: float = 0.05  # time step (ms)
    window_ms: float = 250.0  # window length (ms); NOTE(review): presumably simulated per env step — confirm
    sigma_S_per_m: float = 0.2  # tissue conductivity (S/m)

    onset_weight_mult: float = 50.0  # NOTE(review): multiplier for onset-site weights — confirm against the plant

    # safety bounds (also used by action mapping)
    # NOTE(review): these caps (10 mA / 1 ms / 250 Hz) are looser than the DBSBounds
    # defaults (5 mA / 0.5 ms / 200 Hz); confirm which limit is authoritative.
    max_amp_mA: float = 10.0
    max_pw_ms: float = 1.0
    max_freq_Hz: float = 250.0

@dataclass
class NetworkConfig:
    """Default network composition and connectivity settings for the plant.

    GM cells are split into excitatory (``gm_frac_E``), fast-inhibitory
    (``gm_frac_I_fast``) and slow-inhibitory (the remainder, exposed as
    ``gm_frac_I_slow``) populations. Call :meth:`validate` to check the
    configuration; construction itself never raises.
    """
    n_cells: int = 80
    frac_gm: float = 0.65

    gm_frac_E: float = 0.75
    gm_frac_I_fast: float = 0.15
    # remainder is I_slow (see gm_frac_I_slow)

    p_conn: float = 0.08
    w_exc: float = 0.002
    # NOTE(review): inhibitory weight is negative here, whereas ConnectivityVariantSpec
    # uses a positive w_inh — confirm the sign convention the plant expects.
    w_inh: float = -0.004

    # "spatial box" around focus center (mm)
    spatial_extent_mm: float = 1.5

    # WM axon geometry (um)
    wm_axon_L_um: float = 800.0
    wm_axon_diam_um: float = 2.0

    # tissue mode: "GM", "WM", "BOUNDARY"
    tissue_mode: str = "BOUNDARY"

    @property
    def gm_frac_I_slow(self) -> float:
        """Slow-inhibitory share of GM cells: whatever E and I_fast leave over."""
        return 1.0 - self.gm_frac_E - self.gm_frac_I_fast

    def validate(self) -> None:
        """Check counts, fractions and tissue mode.

        Raises:
            ValueError: if ``n_cells <= 0``; if any of ``frac_gm``, ``gm_frac_E``,
                ``gm_frac_I_fast`` or ``p_conn`` is outside [0, 1] (NaN included);
                if E + I_fast exceeds 1; or if ``tissue_mode`` is not one of
                "GM", "WM", "BOUNDARY".
        """
        if self.n_cells <= 0:
            raise ValueError("n_cells must be > 0.")
        for name in ("frac_gm", "gm_frac_E", "gm_frac_I_fast", "p_conn"):
            value = getattr(self, name)
            if not (0.0 <= value <= 1.0):
                raise ValueError(f"{name} must be in [0,1].")
        if self.gm_frac_E + self.gm_frac_I_fast > 1.0 + 1e-9:
            raise ValueError("gm_frac_E + gm_frac_I_fast must be <= 1 (remainder is I_slow).")
        if self.tissue_mode not in ("GM", "WM", "BOUNDARY"):
            raise ValueError(f"tissue_mode must be 'GM', 'WM' or 'BOUNDARY'; got {self.tissue_mode!r}.")

@dataclass
class EnvConfig:
    """Environment-level settings: episode length, observation clipping, reward scale and action mapping."""
    episode_steps: int = 40  # NOTE(review): presumably env steps per episode — confirm
    obs_clip: float = 1e6  # NOTE(review): presumably the symmetric bound passed to safe_clip — confirm
    reward_scale: float = 1.0
    # Lower amplitude limit (mA); 0.0 by default so stimulation may be switched OFF.
    # NOTE(review): dbs_bounds.amp_mA_min defaults to 0.1, so the two minima disagree;
    # confirm which one the action mapping actually uses.
    min_amp_mA: float = 0.0   # default allow OFF


    # how much baseline to run at reset (optional)
    baseline_windows: int = 1

    # action mapping
    dbs_bounds: DBSBounds = field(default_factory=DBSBounds)

    # observation composition toggles (keep simple)
    include_last_action_in_obs: bool = True




In [None]:
def map_action_to_dbs(action: np.ndarray, amp_min, amp_max, freq_min, freq_max, pw_min, pw_max):
    """
    Map a normalized policy action onto DBS stimulation parameters.

    action in [-1, 1]^3 -> (amp_mA, freq_Hz, pw_ms) within bounds.

    Args:
        action: array-like with at least 3 elements (amp, freq, pw), in any
            shape (e.g. a (1, 3) batch-of-one); it is flattened and only the
            first three entries are used. Values are clipped to [-1, 1].
        amp_min, amp_max, freq_min, freq_max, pw_min, pw_max: target ranges.

    Returns:
        (amp, freq, pw) as Python floats, each affinely mapped into [min, max].

    Raises:
        ValueError: if fewer than 3 elements are given, or any of the first
            three is NaN. A NaN would otherwise propagate into the stimulation
            parameters unnoticed.
    """
    flat = np.asarray(action, dtype=float).reshape(-1)
    if flat.size < 3:
        raise ValueError(f"action must have at least 3 elements (amp, freq, pw); got {flat.size}")
    head = flat[:3]
    if np.isnan(head).any():
        raise ValueError("action contains NaN; refusing to map it to stimulation parameters")
    a = np.clip(head, -1.0, 1.0)
    # affine map: [-1,1] -> [0,1] -> [min,max]
    u = 0.5 * (a + 1.0)
    amp = amp_min + u[0] * (amp_max - amp_min)
    freq = freq_min + u[1] * (freq_max - freq_min)
    pw = pw_min + u[2] * (pw_max - pw_min)
    return float(amp), float(freq), float(pw)

def safe_clip(x: np.ndarray, bound: float) -> np.ndarray:
    """Clip ``x`` elementwise into the symmetric interval [-bound, bound]."""
    lower, upper = -bound, bound
    return np.clip(x, lower, upper)


In [None]:
import numpy as np
from typing import List

def burst_fraction_from_spikes(
    spike_times_ms: np.ndarray,
    isi_thresh_ms: float = 10.0,
    min_spikes_in_burst: int = 3
) -> float:
    """Fraction of spikes that belong to an ISI-defined burst.

    Spikes are sorted and flattened. A burst is a maximal run of consecutive
    inter-spike intervals shorter than ``isi_thresh_ms``. A run of at least
    ``min_spikes_in_burst - 1`` fast ISIs contributes its ``run + 1`` spikes.

    Returns:
        Value in [0, 1]. It is 0.0 for empty input or when there are fewer than
        ``min_spikes_in_burst`` spikes. With ``min_spikes_in_burst <= 1`` every
        spike trivially counts, giving 1.0 for non-empty input. Empty input
        previously raised ZeroDivisionError when ``min_spikes_in_burst <= 0``.
    """
    st = np.sort(np.asarray(spike_times_ms, dtype=float).ravel())
    n_spikes = st.size
    if n_spikes == 0 or n_spikes < min_spikes_in_burst:
        return 0.0

    min_fast_isis = min_spikes_in_burst - 1
    if min_fast_isis <= 0:
        # Any run length (including 0) qualifies, so every spike is counted.
        return 1.0

    fast = np.diff(st) < float(isi_thresh_ms)
    # Run-length encode the boolean mask: +1 edges start a run, -1 edges end one.
    edges = np.diff(np.concatenate(([0], fast.astype(np.int8), [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_stops = np.flatnonzero(edges == -1)
    run_lengths = run_stops - run_starts

    qualifying = run_lengths[run_lengths >= min_fast_isis]
    count_in_bursts = float(np.sum(qualifying + 1))  # a run of k fast ISIs spans k+1 spikes
    return count_in_bursts / float(n_spikes)

def sync_from_spike_trains(
    spike_lists_ms: List[np.ndarray],
    window_ms: float,
    bin_ms: float = 5.0
) -> float:
    """
    Population synchrony proxy in [0, 1).

    All trains are pooled and binned into ceil(window_ms / bin_ms) bins of
    width bin_ms. Spikes outside [0, n_bins * bin_ms) are dropped. The result
    is tanh(CV), where CV = std / mean of the pooled bin counts. It returns 0.0
    when no trains are given, when there are fewer than two bins, or when the
    mean count is ~0.
    """
    if not len(spike_lists_ms):
        return 0.0

    width = float(bin_ms)
    n_bins = int(np.ceil(float(window_ms) / width))
    if n_bins < 2:
        return 0.0

    nonempty = [np.ravel(train) for train in spike_lists_ms if train.size]
    if nonempty:
        bin_idx = np.floor(np.concatenate(nonempty) / width).astype(int)
        bin_idx = bin_idx[(bin_idx >= 0) & (bin_idx < n_bins)]
    else:
        bin_idx = np.zeros(0, dtype=int)
    counts = np.bincount(bin_idx, minlength=n_bins).astype(float)

    mean_count = float(np.mean(counts))
    if mean_count < 1e-9:
        return 0.0
    return float(np.tanh(float(np.std(counts)) / mean_count))

In [None]:
# logging.py
from __future__ import annotations
import os, json
import numpy as np

class EpisodeLogger:
    """Buffer per-step rows for one episode and write them to ``out_dir`` at episode end.

    ``end_episode`` writes ``episode_{idx:06d}.jsonl`` (one JSON object per
    logged row) plus ``episode_{idx:06d}_summary.json``, then clears the buffer.
    NumPy scalars and arrays are converted to native Python values when
    serialized.
    """

    def __init__(self, out_dir: str):
        self.out_dir = out_dir
        os.makedirs(self.out_dir, exist_ok=True)
        self.reset()

    def reset(self):
        """Drop any buffered rows and episode info."""
        self.rows = []
        self.ep_info = {}

    def log_step(self, row: dict):
        """Buffer one step record.

        The row should contain JSON-serializable values (NumPy scalars/arrays
        are also accepted). A shallow copy is stored, so the caller may reuse
        or mutate its dict afterwards without corrupting the log.
        """
        self.rows.append(dict(row))

    @staticmethod
    def _json_default(obj):
        """``json`` fallback: convert NumPy scalars/arrays; reject anything else as json does."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def end_episode(self, ep_index: int, summary: dict):
        """Write buffered rows and ``summary`` for episode ``ep_index``, then reset.

        Raises:
            TypeError: if a row or the summary contains a value that is neither
                JSON-native nor a NumPy scalar/array.
        """
        # Save stepwise data
        step_path = os.path.join(self.out_dir, f"episode_{ep_index:06d}.jsonl")
        with open(step_path, "w", encoding="utf-8") as f:
            for r in self.rows:
                f.write(json.dumps(r, default=self._json_default) + "\n")

        # Save episode summary
        sum_path = os.path.join(self.out_dir, f"episode_{ep_index:06d}_summary.json")
        with open(sum_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, default=self._json_default)

        self.reset()


In [None]:
"""
Case specification + generators for the DBS-epilepsy RL environment.

This module is designed to be:
- Complete and reproducible (seeded generation, deterministic suites)
- Scalable (sample_n for large case batches)
- Explicit about anatomy/tissue modes, electrode geometry/placement, onset topology,
  severity/excitability, and per-case network/connectivity variants.

Intended usage:
- Training: CaseGenerator.sample() / sample_n()
- Evaluation: CaseSuite.grid(...) to produce deterministic grids/sweeps
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Sequence, Tuple
import hashlib
import math

import numpy as np

# 3-component coordinate/direction triple; units come from the field-name suffix (e.g. *_mm).
Vec3 = Tuple[float, float, float]
# Tissue model around the electrode; "BOUNDARY" requires a BoundaryPlane (enforced by TissueSpec).
TissueMode = Literal["GM", "WM", "BOUNDARY"]
# Connectivity regime names; CaseGeneratorConfig.p_conn_by_regime is keyed by these.
ConnectivityRegime = Literal["none", "sparse", "dense", "small_world", "clustered"]


# -----------------------------------------------------------------------------
# Core specs
# -----------------------------------------------------------------------------

@dataclass(frozen=True)
class ElectrodeSpec:
    """Electrode placement and coupling model selection.

    Only ``xyz_mm`` is required; the remaining fields default to a monopolar
    point source oriented along +z.
    """
    xyz_mm: Vec3  # electrode position (mm)
    orientation_unit: Vec3 = (0.0, 0.0, 1.0)  # NOTE(review): assumed unit-length; not normalized here
    model: str = "point_source"  # e.g., point_source, monopolar_contact, directional
    contact_radius_mm: float = 0.0
    reference: str = "monopolar"  # monopolar / bipolar, etc.

    @property
    def xyz(self) -> np.ndarray:
        # Position as a float ndarray, freshly built from the stored tuple on each access.
        return np.asarray(self.xyz_mm, dtype=float)


@dataclass(frozen=True)
class BoundaryPlane:
    """Planar WM–GM interface through ``point_mm`` with normal ``normal_unit``: (x - p)·n = 0."""
    point_mm: Vec3
    normal_unit: Vec3

    def signed_distance_mm(self, xyz_mm: Vec3) -> float:
        """Signed distance (mm) from ``xyz_mm`` to the plane.

        The result is positive on the side the normal points to. The stored
        normal is re-normalized (with a 1e-12 guard), so it need not be exactly
        unit length.
        """
        normal = np.asarray(self.normal_unit, dtype=float)
        normal = normal / (np.linalg.norm(normal) + 1e-12)
        anchor = np.asarray(self.point_mm, dtype=float)
        displacement = np.asarray(xyz_mm, dtype=float) - anchor
        return float(np.dot(displacement, normal))


@dataclass(frozen=True)
class AnisotropySpec:
    """Optional future extension. Kept for forward compatibility.

    Nothing in this module reads these fields yet; CaseGenerator always passes
    ``anisotropy=None``.
    """
    enabled: bool = False
    principal_dir_unit: Vec3 = (1.0, 0.0, 0.0)  # NOTE(review): assumed unit-length; not normalized here
    ratio_parallel_over_perp: float = 1.0  # 1.0 means isotropic


@dataclass(frozen=True)
class TissueSpec:
    """Tissue model: mode, GM/WM conductivities (S/m) and optional boundary/anisotropy.

    Raises:
        ValueError: at construction if ``mode`` is not "GM"/"WM"/"BOUNDARY",
            if a conductivity is not strictly positive (NaN included), or if
            the boundary plane is missing for "BOUNDARY" / present otherwise.
    """
    mode: TissueMode
    sigma_gm_S_per_m: float = 0.2
    sigma_wm_S_per_m: float = 0.14
    boundary: Optional[BoundaryPlane] = None
    anisotropy: Optional[AnisotropySpec] = None

    def __post_init__(self) -> None:
        if self.mode not in ("GM", "WM", "BOUNDARY"):
            raise ValueError(f"TissueSpec.mode must be 'GM', 'WM' or 'BOUNDARY'; got {self.mode!r}.")
        # Written as `not (x > 0)` so NaN conductivities are rejected too.
        if not (self.sigma_gm_S_per_m > 0.0 and self.sigma_wm_S_per_m > 0.0):
            raise ValueError("TissueSpec conductivities must be > 0 S/m.")
        if self.mode == "BOUNDARY" and self.boundary is None:
            raise ValueError("TissueSpec.mode='BOUNDARY' requires a BoundaryPlane.")
        if self.mode != "BOUNDARY" and self.boundary is not None:
            raise ValueError("BoundaryPlane provided but TissueSpec.mode is not 'BOUNDARY'.")


@dataclass(frozen=True)
class RegionSpec:
    """Axis-aligned simulation box: ``center_mm`` ± ``bounds_mm`` (half-extents) along each axis."""
    center_mm: Vec3
    bounds_mm: Vec3 = (3.0, 3.0, 3.0)  # half-extent box
    density_mode: str = "uniform"      # uniform / clustered / boundary_biased etc.

    def sample_point_mm(self, rng: np.random.Generator) -> Vec3:
        """Draw one point uniformly from the box.

        This uses a single vector ``rng.uniform`` call. ``density_mode`` is not
        consulted here.
        """
        origin = np.asarray(self.center_mm, dtype=float)
        half_extent = np.asarray(self.bounds_mm, dtype=float)
        point = origin + rng.uniform(low=-half_extent, high=half_extent)
        return tuple(float(point[axis]) for axis in range(3))


@dataclass(frozen=True)
class FocusSiteSpec:
    """One candidate onset site (primary or secondary).

    Fields:
        xyz_mm: site position (mm).
        baseline_strength: relative site strength. CaseGenerator draws it from
            U(0.9, 1.0) for primary sites and from a Beta distribution on
            [0, 1] for secondary/latent sites.
        baseline_drive_hz: baseline drive rate (Hz).
        baseline_weight: baseline drive weight. For generated cases it is
            derived from a severity-scaled primary weight.
        site_type: label. CaseGenerator emits "primary", "secondary" or
            "latent" (strength < 0.35).
    """
    xyz_mm: Vec3
    baseline_strength: float
    baseline_drive_hz: float
    baseline_weight: float
    site_type: str = "secondary"  # secondary / latent / primary / control

    @property
    def xyz(self) -> np.ndarray:
        # Position as a float ndarray, freshly built from the stored tuple on each access.
        return np.asarray(self.xyz_mm, dtype=float)


@dataclass(frozen=True)
class FocusClusterSpec:
    """Primary focus described by a cluster plus (optional) multiple sites within it.

    For generated cases, ``n_sites`` primary sites are placed uniformly inside
    the ball of ``radius_mm`` (mm) around ``center_mm``.
    """
    center_mm: Vec3
    radius_mm: float
    n_sites: int = 1


@dataclass(frozen=True)
class OnsetSpec:
    """Primary focus cluster + explicit secondary sites.

    For cases from CaseGenerator, ``primary_sites`` lie inside
    ``primary_cluster``. ``secondary_sites`` are placed in shells around the
    cluster center and are labeled "secondary" or "latent".
    """
    primary_cluster: FocusClusterSpec
    primary_sites: Tuple[FocusSiteSpec, ...]
    secondary_sites: Tuple[FocusSiteSpec, ...]


@dataclass(frozen=True)
class NetworkVariantSpec:
    """Per-case overrides for neuron composition / counts.

    The three ``*_in_gm`` fractions split the GM population into excitatory,
    fast-inhibitory and slow-inhibitory cells and must each lie in [0, 1] and
    sum to 1. NOTE(review): ``frac_gm`` is presumably the GM share of
    ``n_total`` (the rest being WM axons, cf. ``wm_axon_params``) — confirm
    against the plant.
    """
    n_total: int = 80
    frac_gm: float = 0.7
    frac_exc_in_gm: float = 0.75
    frac_inh_fast_in_gm: float = 0.15
    frac_inh_slow_in_gm: float = 0.10
    wm_axon_params: Dict[str, float] = field(default_factory=lambda: {"L_um": 800.0, "diam_um": 1.0, "nseg": 9})

    def validate(self) -> None:
        """Raise ValueError if the cell count or composition fractions are invalid."""
        if not (0.0 <= self.frac_gm <= 1.0):
            raise ValueError("frac_gm must be in [0,1].")
        gm_fracs = {
            "frac_exc_in_gm": self.frac_exc_in_gm,
            "frac_inh_fast_in_gm": self.frac_inh_fast_in_gm,
            "frac_inh_slow_in_gm": self.frac_inh_slow_in_gm,
        }
        for name, value in gm_fracs.items():
            # Per-fraction check: rejects negatives that still sum to 1, and NaN
            # (which makes the sum check below compare False and pass).
            if not (0.0 <= value <= 1.0):
                raise ValueError(f"{name} must be in [0,1].")
        s = self.frac_exc_in_gm + self.frac_inh_fast_in_gm + self.frac_inh_slow_in_gm
        if abs(s - 1.0) > 1e-6:
            raise ValueError("GM fractions must sum to 1.0 (exc + inh_fast + inh_slow).")
        if self.n_total <= 0:
            raise ValueError("n_total must be > 0.")


@dataclass(frozen=True)
class ConnectivityVariantSpec:
    """Connectivity knobs for one case (a minimal plant is free to ignore them)."""
    regime: ConnectivityRegime = "none"
    p_conn: float = 0.0
    w_exc: float = 0.01
    w_inh: float = 0.01
    delay_ms_range: Tuple[float, float] = (1.0, 5.0)
    e_i_balance_mode: str = "fixed"

    def validate(self) -> None:
        """Raise ValueError unless p_conn lies in [0, 1] and the delay range satisfies 0 <= lo <= hi."""
        prob_in_range = 0.0 <= self.p_conn <= 1.0
        if not prob_in_range:
            raise ValueError("p_conn must be in [0,1].")
        delay_lo, delay_hi = self.delay_ms_range
        if delay_lo < 0.0 or delay_lo > delay_hi:
            raise ValueError("delay_ms_range must satisfy 0 <= lo <= hi.")


@dataclass(frozen=True)
class CaseSpec:
    """A complete scenario instance.

    Bundles electrode placement, tissue model, simulated region, onset topology,
    severity and per-case network/connectivity variants, plus free-form
    descriptors/tags. CaseGenerator sets ``rng_seed`` to the per-case seed the
    case was sampled from.

    NOTE(review): with ``frozen=True`` (and the default ``eq=True``) dataclasses
    generates ``__hash__`` over all fields, so the dict/list fields below make
    ``hash(case)`` raise TypeError. Frozen also only blocks attribute
    reassignment; the contained dicts/lists remain mutable.
    """
    case_id: str
    rng_seed: int

    electrode: ElectrodeSpec
    tissue: TissueSpec
    region: RegionSpec
    onset: OnsetSpec

    severity: float
    baseline_burden: float
    excitability: Dict[str, float] = field(default_factory=dict)

    network_variant: NetworkVariantSpec = field(default_factory=NetworkVariantSpec)
    connectivity_variant: ConnectivityVariantSpec = field(default_factory=ConnectivityVariantSpec)

    descriptors: Dict[str, float] = field(default_factory=dict)
    tags: Dict[str, object] = field(default_factory=dict)

    # Plain-list copies of the onset site positions (mm), convenient for logging/serialization.
    primary_onsets_xyz_mm: List[List[float]] = field(default_factory=list)
    secondary_onsets_xyz_mm: List[List[float]] = field(default_factory=list)

    def validate(self) -> None:
        """Raise ValueError on out-of-range severity/burden or invalid network/connectivity variants.

        NOTE(review): a NaN baseline_burden passes (``nan < 0.0`` is False).
        """
        if not (0.0 <= self.severity <= 2.0):
            raise ValueError("severity expected in [0,2] (recommended [0,1]).")
        if self.baseline_burden < 0.0:
            raise ValueError("baseline_burden must be >= 0.")
        self.network_variant.validate()
        self.connectivity_variant.validate()

    @property
    def primary_center_mm(self) -> Vec3:
        # Center of the primary focus cluster (mm).
        return self.onset.primary_cluster.center_mm

    def dist_electrode_to_primary_mm(self) -> float:
        """Euclidean distance (mm) from the electrode to the primary cluster center."""
        e = self.electrode.xyz
        c = np.asarray(self.primary_center_mm, dtype=float)
        return float(np.linalg.norm(e - c))

    def boundary_signed_dist_primary_mm(self) -> Optional[float]:
        """Signed distance (mm) of the primary center to the boundary plane, or None if not BOUNDARY mode."""
        if self.tissue.mode != "BOUNDARY" or self.tissue.boundary is None:
            return None
        return self.tissue.boundary.signed_distance_mm(self.primary_center_mm)


# -----------------------------------------------------------------------------
# Helper functions
# -----------------------------------------------------------------------------

def _unit(v: np.ndarray) -> np.ndarray:
    n = float(np.linalg.norm(v))
    if n < 1e-12:
        return np.array([1.0, 0.0, 0.0], dtype=float)
    return v / n


def _stable_hash_to_int(s: str, mod: int = 2**31 - 1) -> int:
    h = hashlib.sha256(s.encode("utf-8")).hexdigest()
    return int(h[:16], 16) % mod


def _case_id_from_fields(prefix: str, fields: Dict[str, object]) -> str:
    # Deterministic and human-readable-ish
    parts = [prefix]
    for k in sorted(fields.keys()):
        v = fields[k]
        if isinstance(v, float):
            parts.append(f"{k}={v:.4g}")
        else:
            parts.append(f"{k}={v}")
    base = "|".join(parts)
    short = hashlib.md5(base.encode("utf-8")).hexdigest()[:10]
    return f"{prefix}-{short}"


def place_electrode_at_distance(
    primary_center_mm: Vec3,
    dist_mm: float,
    rng: np.random.Generator,
    ray_mode: Literal["random", "canonical"] = "random",
    canonical_axis: Optional[Literal["+x", "-x", "+y", "-y", "+z", "-z"]] = None,
) -> Tuple[Vec3, Vec3]:
    """
    Put the electrode exactly ``dist_mm`` from the primary center along a ray.

    ``ray_mode="canonical"`` uses the signed axis ``canonical_axis`` (default
    "+x") and draws nothing from ``rng``; an unknown axis raises KeyError.
    Any other mode uses an isotropic direction (a normalized 3-D Gaussian draw).

    Returns (electrode_xyz_mm, ray_dir_unit).
    """
    center = np.asarray(primary_center_mm, dtype=float)

    if ray_mode != "canonical":
        direction = _unit(rng.normal(size=3))
    else:
        axis_vectors = {
            "+x": np.array([1.0, 0.0, 0.0]),
            "-x": np.array([-1.0, 0.0, 0.0]),
            "+y": np.array([0.0, 1.0, 0.0]),
            "-y": np.array([0.0, -1.0, 0.0]),
            "+z": np.array([0.0, 0.0, 1.0]),
            "-z": np.array([0.0, 0.0, -1.0]),
        }
        direction = axis_vectors[canonical_axis or "+x"]

    tip = center + float(dist_mm) * direction
    electrode_xyz = (float(tip[0]), float(tip[1]), float(tip[2]))
    ray_dir = (float(direction[0]), float(direction[1]), float(direction[2]))
    return electrode_xyz, ray_dir


def sample_boundary_plane(
    rng: np.random.Generator,
    point_mm: Vec3,
    normal_mode: Literal["random", "canonical"] = "random",
    canonical_axis: Optional[Literal["+x", "-x", "+y", "-y", "+z", "-z"]] = None,
) -> BoundaryPlane:
    """Build a BoundaryPlane passing through ``point_mm``.

    ``normal_mode="canonical"`` takes the normal from the signed axis
    ``canonical_axis`` (default "+z") without touching ``rng``; an unknown axis
    raises KeyError. Any other mode draws an isotropic random unit normal via
    ``rng.normal(size=3)``.
    """
    if normal_mode != "canonical":
        drawn = _unit(rng.normal(size=3))
        normal = (float(drawn[0]), float(drawn[1]), float(drawn[2]))
    else:
        axis_normals = {
            "+x": (1.0, 0.0, 0.0),
            "-x": (-1.0, 0.0, 0.0),
            "+y": (0.0, 1.0, 0.0),
            "-y": (0.0, -1.0, 0.0),
            "+z": (0.0, 0.0, 1.0),
            "-z": (0.0, 0.0, -1.0),
        }
        normal = axis_normals[canonical_axis or "+z"]
    return BoundaryPlane(point_mm=point_mm, normal_unit=normal)


def _beta_strength(rng: np.random.Generator, a: float, b: float, lo: float = 0.0, hi: float = 1.0) -> float:
    x = float(rng.beta(a, b))
    return lo + (hi - lo) * x


# -----------------------------------------------------------------------------
# Generators
# -----------------------------------------------------------------------------

@dataclass
class CaseGeneratorConfig:
    """Ranges and options for training-time case sampling.

    Every ``*_range`` field is a (low, high) pair for a draw made by
    ``CaseGenerator.sample``. Float ranges are drawn uniformly; integer ranges
    are inclusive of ``high``.
    """

    # If set (e.g. "GM", "wm", " boundary "), every sampled case uses this tissue
    # mode; CaseGenerator upper-cases/strips it and a per-sample
    # tags["forced_tissue_mode"] overrides it. None or blank -> sample by p_* below.
    forced_tissue_mode: Optional[str] = None

    # Tissue mode proportions (normalized by CaseGenerator; need not sum to 1)
    p_gm: float = 0.34
    p_wm: float = 0.33
    p_boundary: float = 0.33

    # Region geometry (half-extents of the box around the origin, mm)
    region_bounds_mm: Vec3 = (3.0, 3.0, 3.0)

    # Electrode distance sweep range
    electrode_dist_mm_range: Tuple[float, float] = (0.5, 5.0)
    electrode_ray_mode: Literal["random", "canonical"] = "random"

    # Boundary options
    boundary_normal_mode: Literal["random", "canonical"] = "random"

    # Severity range
    severity_range: Tuple[float, float] = (0.1, 0.95)

    # Primary cluster
    primary_cluster_radius_mm_range: Tuple[float, float] = (0.4, 1.0)
    n_primary_sites_range: Tuple[int, int] = (1, 5)

    # Secondary foci
    n_secondary_range: Tuple[int, int] = (3, 12)
    secondary_mix: Tuple[float, float, float] = (0.5, 0.35, 0.15)  # near, mid, far fractions (normalized)

    # Baseline drive parameters
    # (In minimal HH+Exp2Syn plants, 0.01–0.03 is often subthreshold; widen for spiking regimes.)
    drive_rate_hz_range: Tuple[float, float] = (20.0, 80.0)
    drive_weight_range: Tuple[float, float] = (0.05, 0.40)


    # Secondary site baseline strength distribution (latent-leaning), Beta(a, b) parameters
    secondary_strength_beta: Tuple[float, float] = (2.0, 6.0)

    # Network composition variability
    n_total_range: Tuple[int, int] = (60, 120)
    frac_gm_range: Tuple[float, float] = (0.55, 0.85)
    frac_exc_in_gm_range: Tuple[float, float] = (0.65, 0.85)
    frac_inh_fast_in_gm_range: Tuple[float, float] = (0.10, 0.25)
    # inh_slow is 1 - exc - inh_fast (clamped)

    # Connectivity regime sampling
    connectivity_regimes: Tuple[ConnectivityRegime, ...] = ("none", "sparse", "dense")
    p_conn_by_regime: Dict[ConnectivityRegime, Tuple[float, float]] = field(
        default_factory=lambda: {
            "none": (0.0, 0.0),
            "sparse": (0.01, 0.05),
            "dense": (0.05, 0.15),
            "small_world": (0.03, 0.10),
            "clustered": (0.03, 0.10),
        }
    )

    # Conductivities (defaults, can be varied externally)
    sigma_gm_S_per_m: float = 0.2
    sigma_wm_S_per_m: float = 0.14


class CaseGenerator:
    """Stochastic training-time case sampler with reproducibility.

    A root ``np.random.Generator`` seeded with ``rng_seed`` hands out one fresh
    integer seed per :meth:`sample` call. Every random draw for that case then
    comes from a generator built from the per-case seed, which is stored as
    ``CaseSpec.rng_seed``. A generator built with the same ``rng_seed`` and
    config therefore reproduces the same sequence of cases.
    """

    def __init__(self, cfg: Optional[CaseGeneratorConfig] = None, rng_seed: int = 0):
        """
        Args:
            cfg: Sampling ranges/options; defaults to ``CaseGeneratorConfig()``.
            rng_seed: Seed of the root RNG that derives per-case seeds.

        Raises:
            ValueError: if any tissue-mode probability is negative or non-finite,
                or if they sum to <= 0.
        """
        self.cfg = cfg or CaseGeneratorConfig()
        self._root_seed = int(rng_seed)
        self._rng = np.random.default_rng(self._root_seed)

        # Normalize tissue probabilities. Reject negatives/non-finite values here
        # instead of letting rng.choice fail later, inside sample().
        raw_p = (self.cfg.p_gm, self.cfg.p_wm, self.cfg.p_boundary)
        if not all(math.isfinite(p) and p >= 0.0 for p in raw_p):
            raise ValueError("Invalid tissue probabilities: each must be finite and >= 0.")
        s = self.cfg.p_gm + self.cfg.p_wm + self.cfg.p_boundary
        if s <= 0:
            raise ValueError("Invalid tissue probabilities: sum must be > 0.")
        self._p = np.array(raw_p, dtype=float) / s
        self._modes: Tuple[TissueMode, ...] = ("GM", "WM", "BOUNDARY")

    def _spawn_case_rng(self) -> Tuple[int, np.random.Generator]:
        """Draw a per-case seed from the root RNG and return (seed, generator built from it)."""
        case_seed = int(self._rng.integers(1, 2**31 - 1))
        return case_seed, np.random.default_rng(case_seed)

    def _resolve_forced_mode(self, tags: Dict[str, object]) -> Optional[str]:
        """Return the forced tissue mode (upper-cased, stripped) or None. Consumes no RNG.

        ``tags["forced_tissue_mode"]`` overrides ``cfg.forced_tissue_mode``;
        values that are None or blank after stripping are ignored.

        Raises:
            ValueError: if the resolved mode is not GM / WM / BOUNDARY.
        """
        def _normalize(value: object) -> Optional[str]:
            if value is None:
                return None
            text = str(value).upper().strip()
            return text if text != "" else None

        forced_mode = _normalize(getattr(self.cfg, "forced_tissue_mode", None))
        tag_forced = _normalize(tags.get("forced_tissue_mode", None))
        if tag_forced is not None:
            forced_mode = tag_forced

        if forced_mode is not None and forced_mode not in ("GM", "WM", "BOUNDARY"):
            raise ValueError(f"Invalid forced_tissue_mode: {forced_mode}")
        return forced_mode

    @staticmethod
    def _split_secondary_counts(n_secondary: int, secondary_mix) -> Tuple[int, int, int]:
        """Split ``n_secondary`` into (near, mid, far) counts by the normalized mix.

        The counts are clamped so that they are non-negative and sum exactly to
        ``n_secondary``, even if rounding would overshoot.
        """
        mix = np.asarray(secondary_mix, dtype=float)
        mix = mix / (mix.sum() + 1e-12)
        n_near = min(n_secondary, int(round(n_secondary * mix[0])))
        n_mid = min(n_secondary - n_near, int(round(n_secondary * mix[1])))
        n_far = max(0, n_secondary - n_near - n_mid)
        return n_near, n_mid, n_far

    @staticmethod
    def _renormalize_gm_fractions(frac_exc: float, frac_inh_fast: float) -> Tuple[float, float, float]:
        """Return (exc, inh_fast, inh_slow) with inh_slow = 1 - exc - inh_fast.

        If that remainder is below 0.05, exc and inh_fast are rescaled so that
        inh_slow is exactly 0.05.
        """
        frac_inh_slow = max(0.0, 1.0 - frac_exc - frac_inh_fast)
        if frac_inh_slow < 0.05:
            target_slow = 0.05
            remaining = 1.0 - target_slow
            # scale exc and inh_fast to sum to remaining
            s2 = frac_exc + frac_inh_fast
            if s2 < 1e-9:
                frac_exc, frac_inh_fast = remaining * 0.75, remaining * 0.25
            else:
                frac_exc = frac_exc / s2 * remaining
                frac_inh_fast = frac_inh_fast / s2 * remaining
            frac_inh_slow = target_slow
        return frac_exc, frac_inh_fast, frac_inh_slow

    def sample(self, *, tags: Optional[Dict[str, object]] = None) -> CaseSpec:
        """Sample one randomized case.

        Args:
            tags: Optional free-form tags, copied into ``CaseSpec.tags``. A
                ``"forced_tissue_mode"`` entry forces this case's tissue mode
                (overriding ``cfg.forced_tissue_mode``).

        Returns:
            A validated :class:`CaseSpec`.

        Raises:
            ValueError: on an invalid forced tissue mode, or if the sampled case
                fails ``CaseSpec.validate()``.

        The order of RNG draws below defines the case stream for a given seed;
        reordering any draw changes every generated case.
        """
        tags = dict(tags or {})

        # 1) Resolve and validate any forced tissue mode (tags override config).
        forced_mode = self._resolve_forced_mode(tags)

        # 2) Spawn RNG for this case
        case_seed, rng = self._spawn_case_rng()
        cfg = self.cfg

        # 3) Select tissue mode (no RNG draw when forced)
        if forced_mode is not None:
            tissue_mode: TissueMode = forced_mode  # type: ignore[assignment]
            tags["forced_tissue_mode"] = forced_mode  # persist into CaseSpec.tags
        else:
            tissue_mode = self._modes[int(rng.choice(len(self._modes), p=self._p))]

        # Severity and baseline burden
        severity = float(rng.uniform(*cfg.severity_range))
        baseline_burden = float(max(0.0, severity))  # explicit; can be a nonlinear map later

        # Region & primary center
        region_center = (0.0, 0.0, 0.0)
        region = RegionSpec(center_mm=region_center, bounds_mm=cfg.region_bounds_mm, density_mode="uniform")
        primary_center = region.sample_point_mm(rng)

        # Primary cluster
        r_primary = float(rng.uniform(*cfg.primary_cluster_radius_mm_range))
        n_primary_sites = int(rng.integers(cfg.n_primary_sites_range[0], cfg.n_primary_sites_range[1] + 1))
        primary_cluster = FocusClusterSpec(center_mm=primary_center, radius_mm=r_primary, n_sites=n_primary_sites)

        # Electrode at a controlled distance
        dist_mm = float(rng.uniform(*cfg.electrode_dist_mm_range))
        electrode_xyz, _ray_dir = place_electrode_at_distance(
            primary_center_mm=primary_center,
            dist_mm=dist_mm,
            rng=rng,
            ray_mode=cfg.electrode_ray_mode,
            canonical_axis=None,
        )
        electrode = ElectrodeSpec(xyz_mm=electrode_xyz, orientation_unit=(0.0, 0.0, 1.0), model="point_source")

        # Tissue spec (+ boundary plane if needed)
        boundary = None
        if tissue_mode == "BOUNDARY":
            # Place boundary plane through primary center for now (can shift later)
            boundary = sample_boundary_plane(
                rng=rng,
                point_mm=primary_center,
                normal_mode=cfg.boundary_normal_mode,
            )
        tissue = TissueSpec(
            mode=tissue_mode,
            sigma_gm_S_per_m=cfg.sigma_gm_S_per_m,
            sigma_wm_S_per_m=cfg.sigma_wm_S_per_m,
            boundary=boundary,
            anisotropy=None,
        )

        # Baseline drive params (primary)
        primary_drive_hz = float(rng.uniform(*cfg.drive_rate_hz_range))
        primary_weight = float(rng.uniform(*cfg.drive_weight_range))

        # Couple severity into the generated onset drive so higher-severity cases are
        # genuinely more excitable. The same scales are reported in `excitability`.
        drive_scale = 0.8 + 0.6 * severity
        syn_weight_scale = 0.9 + 0.4 * severity

        primary_drive_hz *= drive_scale
        primary_weight *= syn_weight_scale

        # Primary sites (within cluster)
        primary_sites: List[FocusSiteSpec] = []
        for _ in range(n_primary_sites):
            # Uniform in ball: sample direction * radius * u^(1/3)
            d = _unit(rng.normal(size=3))
            u = float(rng.uniform(0.0, 1.0)) ** (1.0 / 3.0)
            offset = d * (u * r_primary)
            xyz = tuple((np.asarray(primary_center) + offset).tolist())
            primary_sites.append(
                FocusSiteSpec(
                    xyz_mm=(float(xyz[0]), float(xyz[1]), float(xyz[2])),
                    baseline_strength=float(rng.uniform(0.9, 1.0)),
                    baseline_drive_hz=primary_drive_hz,
                    baseline_weight=primary_weight,
                    site_type="primary",
                )
            )

        # Secondary sites mixture: near/mid/far shells around primary
        n_secondary = int(rng.integers(cfg.n_secondary_range[0], cfg.n_secondary_range[1] + 1))
        n_near, n_mid, n_far = self._split_secondary_counts(n_secondary, cfg.secondary_mix)

        # Distance shells relative to the primary radius:
        # near: [0.8r, 2r], mid: [2r, 4r], far: [4r, 7r].
        # NOTE: sites are NOT clipped to the RegionSpec box and may fall outside it.
        shells = (
            (0.8 * r_primary, 2.0 * r_primary, n_near),
            (2.0 * r_primary, 4.0 * r_primary, n_mid),
            (4.0 * r_primary, 7.0 * r_primary, n_far),
        )

        a_beta, b_beta = cfg.secondary_strength_beta
        secondary_sites: List[FocusSiteSpec] = []
        for r_lo, r_hi, count in shells:
            for _ in range(count):
                d = _unit(rng.normal(size=3))
                rad = float(rng.uniform(r_lo, r_hi))
                xyz = np.asarray(primary_center) + d * rad
                strength = _beta_strength(rng, a_beta, b_beta, lo=0.0, hi=1.0)

                # Drive draw precedes weight draw; keep this order for reproducibility.
                sec_drive_hz = float(primary_drive_hz * (0.4 + 0.6 * strength) * rng.uniform(0.7, 1.1))
                sec_weight = float(primary_weight * (0.4 + 0.6 * strength) * rng.uniform(0.7, 1.1))

                # Floors to avoid secondaries becoming completely inert
                sec_drive_hz = max(5.0, sec_drive_hz)
                sec_weight = max(0.02, sec_weight)

                site_type = "latent" if strength < 0.35 else "secondary"
                secondary_sites.append(
                    FocusSiteSpec(
                        xyz_mm=(float(xyz[0]), float(xyz[1]), float(xyz[2])),
                        baseline_strength=float(strength),
                        baseline_drive_hz=sec_drive_hz,
                        baseline_weight=sec_weight,
                        site_type=site_type,
                    )
                )

        onset = OnsetSpec(
            primary_cluster=primary_cluster,
            primary_sites=tuple(primary_sites),
            secondary_sites=tuple(secondary_sites),
        )

        # Per-case network variant
        n_total = int(rng.integers(cfg.n_total_range[0], cfg.n_total_range[1] + 1))
        frac_gm = float(rng.uniform(*cfg.frac_gm_range))
        frac_exc = float(rng.uniform(*cfg.frac_exc_in_gm_range))
        frac_inh_fast = float(rng.uniform(*cfg.frac_inh_fast_in_gm_range))
        frac_exc, frac_inh_fast, frac_inh_slow = self._renormalize_gm_fractions(frac_exc, frac_inh_fast)

        network_variant = NetworkVariantSpec(
            n_total=n_total,
            frac_gm=frac_gm,
            frac_exc_in_gm=frac_exc,
            frac_inh_fast_in_gm=frac_inh_fast,
            frac_inh_slow_in_gm=frac_inh_slow,
            wm_axon_params={"L_um": 800.0, "diam_um": 1.0, "nseg": 9},
        )

        # Connectivity variant (p_conn is drawn even for "none", whose range is (0, 0))
        regime = str(rng.choice(cfg.connectivity_regimes))
        p_lo, p_hi = cfg.p_conn_by_regime.get(regime, (0.0, 0.0))
        p_conn = float(rng.uniform(p_lo, p_hi))
        connectivity_variant = ConnectivityVariantSpec(
            regime=regime,  # type: ignore[arg-type]
            p_conn=p_conn,
            w_exc=float(rng.uniform(0.005, 0.02)),
            w_inh=float(rng.uniform(0.005, 0.02)),
            delay_ms_range=(1.0, 5.0),
            e_i_balance_mode="fixed",
        )

        # Excitability mapping (reuses the scales applied to the primary drive above)
        excitability = {
            "drive_scale": drive_scale,
            "noise_scale": 0.8 + 0.8 * severity,
            "syn_weight_scale": syn_weight_scale,
        }

        # Descriptors for optional policy conditioning
        descriptors: Dict[str, float] = {
            "severity": float(severity),
            "dist_electrode_to_primary_mm": float(dist_mm),
            "n_secondary": float(n_secondary),
            "frac_gm": float(frac_gm),
        }
        if tissue_mode == "BOUNDARY" and boundary is not None:
            descriptors["boundary_signed_dist_primary_mm"] = float(boundary.signed_distance_mm(primary_center))
            # also store boundary normal components for analysis if desired
            descriptors["boundary_nx"] = float(boundary.normal_unit[0])
            descriptors["boundary_ny"] = float(boundary.normal_unit[1])
            descriptors["boundary_nz"] = float(boundary.normal_unit[2])

        # Stable deterministic ID for this sampled case (seed included)
        id_fields = {
            "seed": case_seed,
            "mode": tissue_mode,
            "sev": severity,
            "dist": dist_mm,
            "ns": n_secondary,
            "ntot": n_total,
            "reg": regime,
        }
        case_id = _case_id_from_fields("case", id_fields)

        primary_onsets_xyz_mm = [list(s.xyz_mm) for s in onset.primary_sites]
        secondary_onsets_xyz_mm = [list(s.xyz_mm) for s in onset.secondary_sites]

        case = CaseSpec(
            case_id=case_id,
            rng_seed=case_seed,
            electrode=electrode,
            tissue=tissue,
            region=region,
            onset=onset,
            severity=severity,
            baseline_burden=baseline_burden,
            excitability=excitability,
            network_variant=network_variant,
            connectivity_variant=connectivity_variant,
            descriptors=descriptors,
            tags=tags,
            primary_onsets_xyz_mm=primary_onsets_xyz_mm,
            secondary_onsets_xyz_mm=secondary_onsets_xyz_mm,
        )
        case.validate()
        return case

    def sample_n(self, n: int, *, tags: Optional[Dict[str, object]] = None) -> List[CaseSpec]:
        """Sample ``n`` cases in sequence (empty list for n <= 0); ``tags`` is copied per case."""
        if n <= 0:
            return []
        return [self.sample(tags=tags) for _ in range(n)]


# -----------------------------------------------------------------------------
# Deterministic case suites
# -----------------------------------------------------------------------------

@dataclass(frozen=True)
class SuiteAxis:
    """One named sweep axis of a deterministic evaluation suite: its name and the values it takes."""
    name: str
    values: Tuple[object, ...]


class CaseSuite:
    """
    Deterministic evaluation suite builder.

    Primary use: grid over tissue_mode × distance × severity, with optional
    additional axes for boundary normals, connectivity regimes, and composition presets.
    """
    @staticmethod
    def grid(
        *,
        suite_seed: int,
        tissue_modes: Sequence[TissueMode] = ("GM", "WM", "BOUNDARY"),
        distances_mm: Sequence[float] = (0.5, 1.0, 2.0, 3.5, 5.0),
        severities: Sequence[float] = (0.1, 0.3, 0.6, 0.9),
        # Optional axes:
        boundary_normals: Optional[Sequence[Vec3]] = None,
        connectivity_regimes: Sequence[ConnectivityRegime] = ("none", "sparse", "dense"),
        # Secondary / cluster settings (fixed or simple sweeps):
        primary_cluster_radius_mm: float = 0.7,
        n_primary_sites: int = 3,
        n_secondary: int = 8,
        # Network presets (optional; if None, use a default)
        network_presets: Optional[Sequence[NetworkVariantSpec]] = None,
        connectivity_presets: Optional[Sequence[ConnectivityVariantSpec]] = None,
        # Conductivities:
        sigma_gm_S_per_m: float = 0.2,
        sigma_wm_S_per_m: float = 0.14,
        # Ray mode for electrode placement:
        ray_mode: Literal["canonical", "random"] = "canonical",
        canonical_rays: Sequence[Literal["+x", "-x", "+y", "-y", "+z", "-z"]] = ("+x", "+y", "+z"),
        # Secondary strength distribution:
        secondary_strength_beta: Tuple[float, float] = (2.0, 6.0),
        # Drive defaults:
        drive_rate_hz: float = 40,
        drive_weight: float = 0.15,
        tag_overrides: Optional[Dict[str, object]] = None,
    ) -> List[CaseSpec]:
        """
        Build the full grid tissue_modes × distances_mm × severities ×
        network_presets × connectivity_presets as a list of validated CaseSpecs.

        Determinism:
        - Electrode placement: canonical mode cycles through ``canonical_rays``
          across the whole grid; random mode draws from a generator seeded with
          ``suite_seed`` and consumed in grid order.
        - Onset sites: a per-case generator seeded from a hash of the case_id,
          so primary/secondary sites depend only on the grid coordinates.
        - BOUNDARY cases pick a normal from ``boundary_normals`` (default +z, +x, +y)
          by a deterministic index of the ray counter, distance and severity.

        Each case also records ``primary_onsets_xyz_mm`` / ``secondary_onsets_xyz_mm``
        (lists of [x, y, z] in mm) exactly as the random case generator does, so
        code that reads those fields sees the suite's onset sites.

        Severity drives ``baseline_burden`` (``max(0, sev)``) and linear
        excitability scales (drive, noise, synaptic weight).

        Returns:
            List of CaseSpec, one per grid point, in nested-loop order.
        """
        rng = np.random.default_rng(int(suite_seed))

        # Default boundary normals if boundary included and none provided
        if boundary_normals is None:
            boundary_normals = [(0.0, 0.0, 1.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]

        if network_presets is None:
            network_presets = [
                NetworkVariantSpec(n_total=80, frac_gm=0.7, frac_exc_in_gm=0.75, frac_inh_fast_in_gm=0.15, frac_inh_slow_in_gm=0.10),
            ]

        # If explicit connectivity presets are not provided, synthesize from regimes
        if connectivity_presets is None:
            presets: List[ConnectivityVariantSpec] = []
            for reg in connectivity_regimes:
                if reg == "none":
                    presets.append(ConnectivityVariantSpec(regime="none", p_conn=0.0))
                elif reg == "sparse":
                    presets.append(ConnectivityVariantSpec(regime="sparse", p_conn=0.03))
                elif reg == "dense":
                    presets.append(ConnectivityVariantSpec(regime="dense", p_conn=0.10))
                else:
                    presets.append(ConnectivityVariantSpec(regime=reg, p_conn=0.06))
            connectivity_presets = presets

        cases: List[CaseSpec] = []
        ray_cycle = list(canonical_rays)
        ray_idx = 0

        # Fixed region for suite (keep comparable)
        region_center = (0.0, 0.0, 0.0)
        region_bounds = (3.0, 3.0, 3.0)
        region = RegionSpec(center_mm=region_center, bounds_mm=region_bounds, density_mode="uniform")

        # Fixed primary center (can also be swept if desired)
        primary_center = (0.0, 0.0, 0.0)

        # Primary sites: uniform inside a ball of radius primary_cluster_radius_mm
        # (random direction, radius ~ u**(1/3)), drawn from the per-case generator.
        def make_primary_sites(local_rng: np.random.Generator) -> Tuple[FocusSiteSpec, ...]:
            sites: List[FocusSiteSpec] = []
            for _ in range(n_primary_sites):
                d = _unit(local_rng.normal(size=3))
                u = float(local_rng.uniform(0.0, 1.0)) ** (1.0 / 3.0)
                offset = d * (u * primary_cluster_radius_mm)
                xyz = tuple((np.asarray(primary_center) + offset).tolist())
                sites.append(
                    FocusSiteSpec(
                        xyz_mm=(float(xyz[0]), float(xyz[1]), float(xyz[2])),
                        baseline_strength=1.0,
                        baseline_drive_hz=float(drive_rate_hz),
                        baseline_weight=float(drive_weight),
                        site_type="primary",
                    )
                )
            return tuple(sites)

        # Prepare secondary sites deterministically per case (seeded by case_id)
        def make_secondary_sites(local_rng: np.random.Generator) -> Tuple[FocusSiteSpec, ...]:
            a_beta, b_beta = secondary_strength_beta
            sites: List[FocusSiteSpec] = []
            # distribute secondary points in shells around the primary center
            for i in range(n_secondary):
                d = _unit(local_rng.normal(size=3))
                # shell radius cycles from near to far
                shell = i % 3
                if shell == 0:
                    rad = float(local_rng.uniform(0.8 * primary_cluster_radius_mm, 2.0 * primary_cluster_radius_mm))
                elif shell == 1:
                    rad = float(local_rng.uniform(2.0 * primary_cluster_radius_mm, 4.0 * primary_cluster_radius_mm))
                else:
                    rad = float(local_rng.uniform(4.0 * primary_cluster_radius_mm, 7.0 * primary_cluster_radius_mm))
                xyz = np.asarray(primary_center) + d * rad
                strength = _beta_strength(local_rng, a_beta, b_beta, lo=0.0, hi=1.0)
                sec_drive_hz = float(drive_rate_hz * (0.4 + 0.6 * strength))
                sec_weight = float(drive_weight * (0.4 + 0.6 * strength))
                site_type = "latent" if strength < 0.35 else "secondary"
                sites.append(
                    FocusSiteSpec(
                        xyz_mm=(float(xyz[0]), float(xyz[1]), float(xyz[2])),
                        baseline_strength=float(strength),
                        baseline_drive_hz=sec_drive_hz,
                        baseline_weight=sec_weight,
                        site_type=site_type,
                    )
                )
            return tuple(sites)

        for mode in tissue_modes:
            for dist in distances_mm:
                for sev in severities:
                    for net in network_presets:
                        for conn in connectivity_presets:
                            # Deterministic ray direction selection
                            if ray_mode == "canonical":
                                axis = ray_cycle[ray_idx % len(ray_cycle)]
                                ray_idx += 1
                                electrode_xyz, ray_dir = place_electrode_at_distance(
                                    primary_center_mm=primary_center,
                                    dist_mm=float(dist),
                                    rng=rng,  # rng not used in canonical path except fallback
                                    ray_mode="canonical",
                                    canonical_axis=axis,
                                )
                            else:
                                electrode_xyz, ray_dir = place_electrode_at_distance(
                                    primary_center_mm=primary_center,
                                    dist_mm=float(dist),
                                    rng=rng,
                                    ray_mode="random",
                                )

                            electrode = ElectrodeSpec(xyz_mm=electrode_xyz, orientation_unit=(0.0, 0.0, 1.0), model="point_source")

                            boundary = None
                            if mode == "BOUNDARY":
                                # Choose a deterministic boundary normal cycling through provided normals
                                n = boundary_normals[int((ray_idx + int(dist * 10) + int(sev * 100)) % len(boundary_normals))]
                                boundary = BoundaryPlane(point_mm=primary_center, normal_unit=n)

                            tissue = TissueSpec(
                                mode=mode,
                                sigma_gm_S_per_m=float(sigma_gm_S_per_m),
                                sigma_wm_S_per_m=float(sigma_wm_S_per_m),
                                boundary=boundary,
                                anisotropy=None,
                            )

                            primary_cluster = FocusClusterSpec(center_mm=primary_center, radius_mm=float(primary_cluster_radius_mm), n_sites=int(n_primary_sites))

                            # Deterministic per-case seed derived from suite_seed + grid coords.
                            # NOTE(review): these fields omit net.frac_gm and other preset details,
                            # so two network presets sharing n_total (or two connectivity presets
                            # sharing regime and p_conn) produce the same case_id and case_seed.
                            id_fields = {
                                "suite": int(suite_seed),
                                "mode": mode,
                                "dist": float(dist),
                                "sev": float(sev),
                                "ntot": int(net.n_total),
                                "reg": conn.regime,
                                "p": float(conn.p_conn),
                            }
                            case_id = _case_id_from_fields("suitecase", id_fields)
                            case_seed = _stable_hash_to_int(case_id, mod=2**31 - 1)
                            local_rng = np.random.default_rng(case_seed)

                            # Primary sites are drawn before secondary sites from local_rng.
                            onset = OnsetSpec(
                                primary_cluster=primary_cluster,
                                primary_sites=make_primary_sites(local_rng),
                                secondary_sites=make_secondary_sites(local_rng),
                            )

                            # Flat onset coordinate lists, built the same way as in the
                            # random case generator, so consumers reading
                            # primary_onsets_xyz_mm / secondary_onsets_xyz_mm get suite sites too.
                            primary_onsets_xyz_mm = [list(s.xyz_mm) for s in onset.primary_sites]
                            secondary_onsets_xyz_mm = [list(s.xyz_mm) for s in onset.secondary_sites]

                            baseline_burden = float(max(0.0, float(sev)))

                            descriptors: Dict[str, float] = {
                                "severity": float(sev),
                                "dist_electrode_to_primary_mm": float(dist),
                                "n_secondary": float(n_secondary),
                                "frac_gm": float(net.frac_gm),
                            }
                            if mode == "BOUNDARY" and boundary is not None:
                                descriptors["boundary_signed_dist_primary_mm"] = float(boundary.signed_distance_mm(primary_center))
                                descriptors["boundary_nx"] = float(boundary.normal_unit[0])
                                descriptors["boundary_ny"] = float(boundary.normal_unit[1])
                                descriptors["boundary_nz"] = float(boundary.normal_unit[2])

                            excitability = {
                                "drive_scale": 0.8 + 0.6 * float(sev),
                                "noise_scale": 0.8 + 0.8 * float(sev),
                                "syn_weight_scale": 0.9 + 0.4 * float(sev),
                            }

                            tags = dict(tag_overrides or {})
                            tags.update({"suite_seed": int(suite_seed), "mode": mode})

                            case = CaseSpec(
                                case_id=case_id,
                                rng_seed=int(case_seed),
                                electrode=electrode,
                                tissue=tissue,
                                region=region,
                                onset=onset,
                                severity=float(sev),
                                baseline_burden=baseline_burden,
                                excitability=excitability,
                                network_variant=net,
                                connectivity_variant=conn,
                                descriptors=descriptors,
                                tags=tags,
                                primary_onsets_xyz_mm=primary_onsets_xyz_mm,
                                secondary_onsets_xyz_mm=secondary_onsets_xyz_mm,
                            )
                            case.validate()
                            cases.append(case)

        return cases


# Public names exported by ``from <module> import *`` when this cell is used as a
# module. Underscore-prefixed helpers (e.g. ``_case_id_from_fields``) stay private.
__all__ = [
    "Vec3",
    "TissueMode",
    "ConnectivityRegime",
    "ElectrodeSpec",
    "BoundaryPlane",
    "AnisotropySpec",
    "TissueSpec",
    "RegionSpec",
    "FocusSiteSpec",
    "FocusClusterSpec",
    "OnsetSpec",
    "NetworkVariantSpec",
    "ConnectivityVariantSpec",
    "CaseSpec",
    "CaseGeneratorConfig",
    "CaseGenerator",
    "CaseSuite",
    "SuiteAxis",
    "place_electrode_at_distance",
]


In [None]:
import numpy as np

def burst_fraction_from_spikes(
    spike_times_ms: np.ndarray,
    isi_thresh_ms: float = 10.0,
    min_spikes_in_burst: int = 3
) -> float:
    """
    Fraction of a spike train's spikes that fall inside bursts.

    A burst is a maximal run of consecutive inter-spike intervals (ISIs) that are
    strictly shorter than ``isi_thresh_ms``. A run of k such ISIs covers k + 1
    spikes and counts as a burst when it spans at least ``min_spikes_in_burst``
    spikes.

    Args:
        spike_times_ms: spike times (ms) of one train, in any order. Any
            array-like accepted by ``np.asarray`` works (ndarray, list, NEURON
            Vector); it is flattened.
        isi_thresh_ms: strict upper bound for an intra-burst ISI.
        min_spikes_in_burst: minimum number of spikes for a run to count.

    Returns:
        (spikes in bursts) / (total spikes), in [0, 1]. Returns 0.0 for an empty
        train or one with fewer than ``min_spikes_in_burst`` spikes.
    """
    st = np.sort(np.asarray(spike_times_ms, dtype=float).ravel())
    # The explicit empty check also avoids 0/0 when min_spikes_in_burst <= 0.
    if st.size == 0 or st.size < min_spikes_in_burst:
        return 0.0

    fast = np.diff(st) < float(isi_thresh_ms)
    min_run = min_spikes_in_burst - 1  # number of fast ISIs a qualifying run needs

    count_in_bursts = 0
    run_len = 0
    for is_fast in fast:
        if is_fast:
            run_len += 1
            continue
        if run_len >= min_run:
            count_in_bursts += run_len + 1
        run_len = 0
    # Close a run that extends to the last spike.
    if run_len >= min_run:
        count_in_bursts += run_len + 1

    return float(count_in_bursts) / float(st.size)

def sync_from_spike_trains(
    spike_lists_ms: List[np.ndarray],
    window_ms: float,
    bin_ms: float = 5.0,
    t0_ms: float = 0.0,
) -> float:
    """
    Synchrony proxy:
    - Bin population spikes in bin_ms windows covering [t0_ms, t0_ms + window_ms)
    - Compute CV = std/mean of population bin counts
    - Return tanh(CV) to keep in [0,1)

    Args:
        spike_lists_ms: one spike-time array (ms) per cell; any array-like
            accepted by ``np.asarray`` works (flattened).
        window_ms: window length (ms).
        bin_ms: bin width (ms); must be positive.
        t0_ms: window start (ms). The default 0.0 keeps the original behavior;
            pass the window's absolute start time when spike times are absolute.

    Returns:
        tanh(CV) of the population histogram. Returns 0.0 when there are no
        trains, fewer than two bins, or no spikes inside the window.

    Raises:
        ValueError: if ``bin_ms <= 0``.
    """
    if len(spike_lists_ms) == 0:
        return 0.0

    bin_w = float(bin_ms)
    if not bin_w > 0.0:
        raise ValueError("bin_ms must be positive")

    n_bins = int(np.ceil(float(window_ms) / bin_w))
    if n_bins <= 1:
        return 0.0

    t0 = float(t0_ms)
    pop = np.zeros(n_bins, dtype=float)
    for train in spike_lists_ms:
        st = np.asarray(train, dtype=float).ravel()
        if st.size == 0:
            continue
        idx = np.floor((st - t0) / bin_w).astype(int)
        # Spikes outside the window are dropped rather than clipped into edge bins.
        idx = idx[(idx >= 0) & (idx < n_bins)]
        if idx.size:
            pop += np.bincount(idx, minlength=n_bins).astype(float)

    mu = float(np.mean(pop))
    if mu < 1e-9:
        return 0.0

    cv = float(np.std(pop)) / mu
    return float(np.tanh(cv))


In [None]:
"""
Level-1 NEURON plant (biophysical layer) for DBS parameter optimization in epilepsy.

Design goals (project-aligned):
- Spatially embedded GM (HH soma) + WM (passive cable axons) populations
- Explicit CaseSpec-driven build (electrode position, tissue mode, boundary plane, onset sites)
- Distance-based extracellular coupling (baseline point-source 1/r), extensible to sigma differences and boundary effects
- Waveform generation per RL step (rectangular + optional biphasic/burst/duty cycle)
- Primary + secondary seizure-onset drive via NetStim -> Exp2Syn, dynamically adjustable each step
- Feature extraction per window: population rate, synchrony proxy, burst proxy, crude LFP proxy
- Clean interface:
    build_from_case(case, net_cfg=None)
    precompute_coupling(stim_xyz_mm)
    update_focus_drive(feedback)
    run_window(stim_params) -> features dict

Notes:
- This file assumes NEURON is installed (pip install neuron).
- If your environment is Colab, ensure you install NEURON and have a working compiler runtime.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import math
import numpy as np


# -----------------------------------------------------------------------------
# NEURON import (hard requirement to run)
# -----------------------------------------------------------------------------

try:
    from neuron import h  # type: ignore
except Exception as e:  # pragma: no cover
    h = None  # type: ignore
    _NEURON_IMPORT_ERROR = e
else:
    _NEURON_IMPORT_ERROR = None

# Only touch the simulator when the import succeeded. Otherwise keep the deferred
# error in _NEURON_IMPORT_ERROR, which NeuronPlant.__init__ re-raises with context,
# instead of crashing here with an AttributeError on None.
if h is not None:
    h.load_file("stdrun.hoc")



# -----------------------------------------------------------------------------
# Types / Params
# -----------------------------------------------------------------------------

Vec3 = Tuple[float, float, float]


@dataclass
class StimParams:
    """
    DBS stimulation settings applied for one decision window.

    Required (positional, in this order):
        amp_mA   -- pulse amplitude in mA
        freq_Hz  -- pulse repetition frequency in Hz
        pw_ms    -- pulse width in ms

    Optional shaping (all have defaults):
        waveform    -- "rect", "biphasic" or "burst"
        duty_cycle  -- fraction of the window, 0..1, during which stimulation is ON

    Biphasic-only settings:
        phase2_amp_ratio   -- phase-2 amplitude relative to phase 1
                              (-1 gives a charge-balanced pulse)
        interphase_gap_ms  -- gap between the two phases (ms)

    Burst-only settings:
        burst_Hz             -- bursts per second
        pulses_per_burst     -- number of pulses in each burst
        intra_burst_freq_Hz  -- pulse rate inside a burst (Hz)
    """
    # --- core pulse train ---
    amp_mA: float
    freq_Hz: float
    pw_ms: float

    # --- waveform selection / gating ---
    waveform: str = "rect"
    duty_cycle: float = 1.0

    # --- biphasic shaping ---
    phase2_amp_ratio: float = -1.0
    interphase_gap_ms: float = 0.0

    # --- burst shaping ---
    burst_Hz: float = 5.0
    pulses_per_burst: int = 5
    intra_burst_freq_Hz: float = 100.0


def _as_stim_params(x: Union[StimParams, Dict[str, Any], Sequence[float]]) -> StimParams:
    """
    Normalize the accepted stimulation-parameter inputs into a StimParams.

    - StimParams: returned unchanged (same object).
    - dict: expanded as keyword arguments to StimParams.
    - list / tuple / ndarray: the first three entries become (amp_mA, freq_Hz, pw_ms)
      as floats; any further entries are ignored and the optional fields keep
      their defaults.

    Raises:
        ValueError: a sequence with fewer than three entries.
        TypeError: any other input type.
    """
    if isinstance(x, StimParams):
        return x
    if isinstance(x, dict):
        return StimParams(**x)  # type: ignore[arg-type]
    if not isinstance(x, (list, tuple, np.ndarray)):
        raise TypeError("stim_params must be StimParams, dict, or sequence.")
    if len(x) < 3:
        raise ValueError("stim_params sequence must have at least (amp_mA, freq_Hz, pw_ms).")
    amp, freq, pw = (float(value) for value in x[:3])
    return StimParams(amp, freq, pw)


def _unit(v: np.ndarray) -> np.ndarray:
    n = float(np.linalg.norm(v))
    if n < 1e-12:
        return np.array([1.0, 0.0, 0.0], dtype=float)
    return v / n


# -----------------------------------------------------------------------------
# Base Plant
# -----------------------------------------------------------------------------

class NeuronPlant:
    """
    Abstract base for NEURON-backed plants driven by the RL environment.

    The constructor refuses to run without NEURON and sets up a seeded NumPy
    generator. Subclasses implement the four interface methods below; the
    base versions raise NotImplementedError.
    """

    def __init__(self, cfg: PlantConfig, rng_seed: int = 0):
        # Fail fast, chaining the original import failure for diagnosis.
        if _NEURON_IMPORT_ERROR is not None:
            raise RuntimeError(
                "NEURON is not available. Install with `pip install neuron` and ensure a working runtime."
            ) from _NEURON_IMPORT_ERROR

        seed = int(rng_seed)
        self.cfg = cfg
        self.rng_seed = seed
        self.rng = np.random.default_rng(seed)

    # ---- interface expected by the environment ------------------------------

    def build_from_case(self, case: CaseSpec, net_cfg: Optional[Any] = None) -> None:
        """Build the simulated network for ``case``."""
        raise NotImplementedError

    def precompute_coupling(self, stim_xyz_mm: Vec3) -> None:
        """Precompute electrode-to-tissue coupling for an electrode at ``stim_xyz_mm``."""
        raise NotImplementedError

    def update_focus_drive(self, feedback: Dict[str, Any]) -> None:
        """Adjust seizure-onset drive from environment ``feedback``."""
        raise NotImplementedError

    def run_window(self, stim_params: Union[StimParams, Dict[str, Any], Sequence[float]]) -> Dict[str, float]:
        """Simulate one decision window under ``stim_params`` and return features."""
        raise NotImplementedError


# -----------------------------------------------------------------------------
# Full NEURON plant
# -----------------------------------------------------------------------------

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import math
import numpy as np
from neuron import h

Vec3 = Tuple[float, float, float]


class FullNeuronPlant(NeuronPlant):
    """
    Level-1 NEURON plant for DBS parameter optimization.

    Implements:
    - build_from_case(case, net_cfg=None)
    - precompute_coupling(stim_xyz_mm)
    - update_focus_drive(feedback)
    - run_window(stim_params) -> Dict[str,float]
    """

    def __init__(self, cfg: PlantConfig, rng_seed: int = 0):
        """
        Create an empty plant; build_from_case() must be called before use.

        Args:
            cfg: plant configuration; ``cfg.dt_ms`` is written to NEURON's
                global time step ``h.dt``.
            rng_seed: seed for the NumPy generator used to sample cell
                positions and GM cell types.
        """
        super().__init__(cfg, rng_seed=rng_seed)

        # Simulator-wide settings (they affect the global NEURON instance).
        h.load_file("stdrun.hoc")
        h.dt = float(self.cfg.dt_ms)

        def empty_xyz() -> np.ndarray:
            # A fresh (0, 3) float array of xyz coordinates (mm).
            return np.zeros((0, 3), dtype=float)

        # ---- currently built case ----
        self.case: Optional[CaseSpec] = None
        self.net_cfg: Optional[Any] = None
        # Same seed as the base class, so this restarts an identical stream.
        self.rng = np.random.default_rng(int(rng_seed))

        # ---- populations: GM somata and WM axons ----
        self.gm_secs: List[Any] = []
        self.wm_secs: List[Any] = []
        self.gm_xyz_mm: np.ndarray = empty_xyz()
        self.wm_xyz_mm: np.ndarray = empty_xyz()
        self._gm_types: List[str] = []

        # ---- extracellular stimulation targets ----
        self._stim_segs: List[Any] = []
        self._stim_seg_xyz_mm: np.ndarray = empty_xyz()
        self._stim_coupling_mV_per_mA: np.ndarray = np.zeros((0,), dtype=float)
        self._stim_seg_types: List[str] = []
        self._stim_play_vecs: List[Any] = []  # holds (vvec, tvec) so NEURON keeps them alive

        # ---- primary onset drive ----
        self._primary_syns: List[Any] = []
        self._primary_netstims: List[Any] = []
        self._primary_netcons: List[Any] = []
        self._primary_base_hz: List[float] = []
        self._primary_base_w: List[float] = []

        # ---- secondary onset drive ----
        self._secondary_syns: List[Any] = []
        self._secondary_netstims: List[Any] = []
        self._secondary_netcons: List[Any] = []
        self._secondary_base_hz: List[float] = []
        self._secondary_base_w: List[float] = []

        # ---- feedback scales; kept on the instance so they also apply at build time ----
        self._drive_rate_scale = self._primary_rate_scale = self._secondary_rate_scale = 1.0
        self._primary_weight_scale = self._secondary_weight_scale = 1.0
        self._primary_site_scales = self._secondary_site_scales = None

        # ---- recording ----
        self._record_t: Any = None
        self._record_vs: List[Any] = []
        self._spike_vecs: List[Any] = []
        self._spike_netcons: List[Any] = []
        self._spike_cell_types: List[str] = []

        # ---- results of the most recent window ----
        self._last_stim: Optional[StimParams] = None
        self._last_features: Dict[str, float] = {}

    # -------------------------------------------------------------------------
    # Build / reset
    # -------------------------------------------------------------------------

    def build_from_case(self, case: CaseSpec, net_cfg: Optional[Any] = None) -> None:
        """
        Rebuild the NEURON model from scratch for a new CaseSpec.

        Steps:
        1. Release every Python-held NEURON object from the previous build, then
           delete all sections.
        2. Sample GM/WM positions uniformly inside ``case.region``.
        3. Create GM somata (``hh`` + ``extracellular``) and WM axons
           (``pas`` + ``extracellular``).
        4. Collect stimulation segments, attach onset drives and recorders.
        5. Initialize at -65 mV.

        Feedback scales (``_drive_rate_scale`` etc.) are intentionally kept
        across rebuilds. ``_last_stim`` / ``_last_features`` are reset so they
        never describe the previous case.

        Args:
            case: case to build (region, excitability, network variant, onsets).
            net_cfg: optional network config, resolved by _resolve_network_params().
        """
        self.case = case
        self.net_cfg = net_cfg

        # Drop Python references to the previous build's segments, point processes,
        # NetCons and Vectors *before* deleting sections, so none of them outlives
        # the section it refers to.
        self._reset_build_state()

        # Hard reset: ``forall`` iterates every section in the process, not only
        # the ones this plant created.
        h("forall delete_section()")

        # Resolve network parameters
        (n_total, frac_gm, frac_exc, frac_inh_fast, frac_inh_slow,
         ax_L_um, ax_diam_um, ax_nseg) = self._resolve_network_params(case, net_cfg)

        n_gm = max(1, int(round(n_total * frac_gm)))
        n_wm = max(1, n_total - n_gm)

        # Sample positions uniformly in center ± bounds (bounds act as half-widths, mm).
        # NOTE(review): positions come from the plant-level self.rng, not case.rng_seed,
        # so rebuilding the same case yields a different cell layout each time.
        region_center = np.array(case.region.center_mm, dtype=float)
        bounds = np.array(case.region.bounds_mm, dtype=float)

        self.gm_xyz_mm = region_center + self.rng.uniform(-bounds, bounds, size=(n_gm, 3))
        self.wm_xyz_mm = region_center + self.rng.uniform(-bounds, bounds, size=(n_wm, 3))

        # Assign GM types
        self._gm_types = self._assign_gm_types(n_gm, frac_exc, frac_inh_fast, frac_inh_slow)

        # Build GM: HH soma + extracellular
        excitability = getattr(case, "excitability", {}) or {}
        for i in range(n_gm):
            sec = h.Section(name=f"gm_soma_{i}")
            sec.L = 20.0
            sec.diam = 20.0
            sec.nseg = 1
            sec.insert("hh")
            sec.insert("extracellular")
            self._apply_excitability_to_hh(sec, excitability, self._gm_types[i])
            self.gm_secs.append(sec)

        # Build WM: passive axon + extracellular
        for i in range(n_wm):
            sec = h.Section(name=f"wm_axon_{i}")
            sec.L = float(ax_L_um)
            sec.diam = float(ax_diam_um)
            sec.nseg = int(max(1, ax_nseg))
            sec.insert("pas")
            sec.g_pas = 1e-4
            sec.e_pas = -65.0
            sec.insert("extracellular")
            self.wm_secs.append(sec)

        # Collect segments to stimulate
        self._stim_segs, self._stim_seg_xyz_mm = self._collect_stim_segments()

        # Build onset drives (uses case.primary_onsets_xyz_mm / secondary_onsets_xyz_mm)
        self._build_onset_drives(case)

        # Setup recording (GM somas only)
        self._setup_recording()

        # Init
        h.finitialize(-65.0)
        h.t = 0.0

    def _reset_build_state(self) -> None:
        """Empty all per-build containers and per-window results; feedback scales are kept."""
        for container in (
            self.gm_secs,
            self.wm_secs,
            self._stim_segs,
            self._stim_play_vecs,
            self._stim_seg_types,
            self._primary_syns,
            self._primary_netstims,
            self._primary_netcons,
            self._primary_base_hz,
            self._primary_base_w,
            self._secondary_syns,
            self._secondary_netstims,
            self._secondary_netcons,
            self._secondary_base_hz,
            self._secondary_base_w,
            self._record_vs,
            self._spike_vecs,
            self._spike_netcons,
            self._spike_cell_types,
        ):
            container.clear()
        self._stim_seg_xyz_mm = np.zeros((0, 3), dtype=float)
        self._stim_coupling_mV_per_mA = np.zeros((0,), dtype=float)
        self._record_t = None
        # Assign fresh objects (not .clear()) in case a caller still holds the
        # previous features dict.
        self._last_stim = None
        self._last_features = {}

    def _resolve_network_params(self, case: CaseSpec, net_cfg: Optional[Any]):
        """
        Resolve network size, GM composition and WM axon geometry for a build.

        Precedence:
        1. ``case.network_variant`` (and its ``wm_axon_params`` dict);
        2. ``net_cfg`` (a fresh ``NetworkConfig()`` when None);
        3. hard-coded defaults.

        ``net_cfg`` may use either naming scheme:
        - variant-style names: ``n_total``, ``frac_exc_in_gm``,
          ``frac_inh_fast_in_gm``, ``frac_inh_slow_in_gm``, ``axon_L_um``,
          ``axon_diam_um``, ``axon_nseg``;
        - NetworkConfig field names: ``n_cells``, ``gm_frac_E``,
          ``gm_frac_I_fast``, ``wm_axon_L_um``, ``wm_axon_diam_um``.
          I_slow is then the remainder of the GM population.

        Sanitization:
        - ``frac_gm`` is clipped to [0.05, 0.95].
        - Negative GM type fractions are clipped to 0 and the mix is normalized
          to sum to 1; an all-zero mix falls back to 0.75 / 0.15 / 0.10.
        - ``n_total >= 2`` and ``ax_nseg >= 1``.

        Returns:
            (n_total, frac_gm, frac_exc, frac_inh_fast, frac_inh_slow,
             ax_L_um, ax_diam_um, ax_nseg)
        """
        cfg = NetworkConfig() if net_cfg is None else net_cfg

        def from_cfg(names: Tuple[str, ...], default: Any) -> Any:
            # The first attribute present on cfg wins; otherwise the default.
            for name in names:
                if hasattr(cfg, name):
                    return getattr(cfg, name)
            return default

        # NetworkConfig uses different field names; read them too so its values
        # are not silently replaced by the hard-coded defaults.
        cfg_n_total = from_cfg(("n_total", "n_cells"), 80)
        cfg_frac_gm = from_cfg(("frac_gm",), 0.7)
        cfg_exc = from_cfg(("frac_exc_in_gm", "gm_frac_E"), 0.75)
        cfg_inh_fast = from_cfg(("frac_inh_fast_in_gm", "gm_frac_I_fast"), 0.15)
        cfg_inh_slow = from_cfg(("frac_inh_slow_in_gm", "gm_frac_I_slow"), None)
        if cfg_inh_slow is None:
            if hasattr(cfg, "gm_frac_E") or hasattr(cfg, "gm_frac_I_fast"):
                # NetworkConfig has no I_slow field: I_slow is the remainder.
                cfg_inh_slow = max(0.0, 1.0 - float(cfg_exc) - float(cfg_inh_fast))
            else:
                cfg_inh_slow = 0.10
        cfg_ax_L = from_cfg(("axon_L_um", "wm_axon_L_um"), 800.0)
        cfg_ax_diam = from_cfg(("axon_diam_um", "wm_axon_diam_um"), 1.0)
        cfg_ax_nseg = from_cfg(("axon_nseg",), 9)

        nv = getattr(case, "network_variant", None)
        if nv is not None:
            n_total = int(getattr(nv, "n_total", cfg_n_total))
            frac_gm = float(getattr(nv, "frac_gm", cfg_frac_gm))
            frac_exc = float(getattr(nv, "frac_exc_in_gm", cfg_exc))
            frac_inh_fast = float(getattr(nv, "frac_inh_fast_in_gm", cfg_inh_fast))
            frac_inh_slow = float(getattr(nv, "frac_inh_slow_in_gm", cfg_inh_slow))
            wm_params = getattr(nv, "wm_axon_params", None) or {}
            ax_L_um = float(wm_params.get("L_um", cfg_ax_L))
            ax_diam_um = float(wm_params.get("diam_um", cfg_ax_diam))
            ax_nseg = int(wm_params.get("nseg", cfg_ax_nseg))
        else:
            n_total = int(cfg_n_total)
            frac_gm = float(cfg_frac_gm)
            frac_exc = float(cfg_exc)
            frac_inh_fast = float(cfg_inh_fast)
            frac_inh_slow = float(cfg_inh_slow)
            ax_L_um = float(cfg_ax_L)
            ax_diam_um = float(cfg_ax_diam)
            ax_nseg = int(cfg_ax_nseg)

        frac_gm = float(np.clip(frac_gm, 0.05, 0.95))

        # Negative fractions are meaningless; an all-zero mix would divide by zero.
        fracs = [max(0.0, f) for f in (frac_exc, frac_inh_fast, frac_inh_slow)]
        total = sum(fracs)
        if total <= 0.0:
            fracs, total = [0.75, 0.15, 0.10], 1.0
        if abs(total - 1.0) > 1e-6:
            fracs = [f / total for f in fracs]
        frac_exc, frac_inh_fast, frac_inh_slow = fracs

        n_total = max(2, n_total)
        ax_nseg = max(1, ax_nseg)

        return n_total, frac_gm, frac_exc, frac_inh_fast, frac_inh_slow, ax_L_um, ax_diam_um, ax_nseg

    def _assign_gm_types(self, n_gm: int, frac_exc: float, frac_inh_fast: float, frac_inh_slow: float) -> List[str]:
        """
        Randomly label GM cells as "E", "I_fast" or "I_slow".

        Counts are ``round(n_gm * frac_exc)`` "E" and ``round(n_gm * frac_inh_fast)``
        "I_fast"; every remaining cell is "I_slow". ``frac_inh_slow`` is accepted
        for symmetry but not used, because the slow count is whatever remains.
        Cells are chosen through a random permutation drawn from ``self.rng``
        (which advances the generator). If rounding overshoots ``n_gm``, the
        later categories are truncated.
        """
        labels = ["E"] * n_gm
        n_e = int(round(n_gm * frac_exc))
        fast_end = n_e + int(round(n_gm * frac_inh_fast))

        order = np.arange(n_gm)
        self.rng.shuffle(order)

        # Positions [0, n_e) of the permutation stay "E".
        for label, members in (("I_fast", order[n_e:fast_end]), ("I_slow", order[fast_end:])):
            for cell in members:
                labels[int(cell)] = label
        return labels

    def _apply_excitability_to_hh(self, sec: Any, excitability: Dict[str, float], cell_type: str) -> None:
        """
        Scale the HH Na/K conductances of ``sec`` according to case excitability.

        Scale factor:
        - ``scale = clip(0.5 * drive_scale + 0.5 * syn_weight_scale, 0.7, 1.5)``;
        - reduced by 5 % for inhibitory cells (``cell_type`` starting with "I").

        Conductances:
        - ``gnabar_hh`` is multiplied by ``scale``;
        - ``gkbar_hh`` is multiplied by ``0.9 + 0.1 * scale``.

        Other excitability keys (e.g. ``noise_scale``) are not used here.
        No-op when ``excitability`` is empty or ``hh`` is not inserted in ``sec``.
        """
        if not excitability:
            return
        drive_scale = float(excitability.get("drive_scale", 1.0))
        syn_scale = float(excitability.get("syn_weight_scale", 1.0))
        scale = float(np.clip(0.5 * drive_scale + 0.5 * syn_scale, 0.7, 1.5))
        if str(cell_type).startswith("I"):
            scale *= 0.95

        # Explicit check instead of swallowing every exception: only a missing
        # hh mechanism is a legitimate reason to skip.
        if not sec.has_membrane("hh"):
            return
        k_scale = 0.9 + 0.1 * scale
        for seg in sec:
            # Classic range-variable accessors; same values as seg.hh.gnabar / seg.hh.gkbar.
            seg.gnabar_hh *= scale
            seg.gkbar_hh *= k_scale

    # -------------------------------------------------------------------------
    # Segment collection + coupling
    # -------------------------------------------------------------------------

    def _collect_stim_segments(self):
        """
        Gather the segments that receive extracellular stimulation.

        - GM: the soma center ``sec(0.5)`` of each GM cell, at its sampled xyz,
          typed with its GM cell type ("GM" if no type is recorded).
        - WM: every segment of each WM axon, all placed at the axon's single
          sampled xyz, typed "WM".

        Side effect: sets ``self._stim_seg_types``, parallel to the returned lists.

        Returns:
            (segments, xyz): ``xyz`` is a float array of shape (n_segments, 3)
            in mm, and (0, 3) when there are no sections.
        """
        segs: List[Any] = []
        xyz: List[List[float]] = []
        types: List[str] = []

        # GM soma centers
        for i, sec in enumerate(self.gm_secs):
            segs.append(sec(0.5))
            xyz.append(self.gm_xyz_mm[i].tolist())
            types.append(self._gm_types[i] if i < len(self._gm_types) else "GM")

        # WM: all segments of an axon share the axon's sampled point.
        # NOTE(review): identical xyz gives every segment of an axon the same coupling.
        # A spatially uniform extracellular potential along a cable presumably drives
        # little net axial current, so WM axons may barely respond to stimulation.
        # Consider spreading segments along an axon direction.
        for i, sec in enumerate(self.wm_secs):
            axon_xyz = self.wm_xyz_mm[i].tolist()
            for seg in sec:
                segs.append(seg)
                xyz.append(list(axon_xyz))
                types.append("WM")

        self._stim_seg_types = types
        # reshape keeps the documented (n, 3) shape even when nothing was collected.
        return segs, np.asarray(xyz, dtype=float).reshape(-1, 3)

    def precompute_coupling(self, stim_xyz_mm: Vec3) -> None:
        """
        Cache the electrode-to-segment transfer coefficient for every stimulated segment.

        Point-source model:
        - ``coeff = 1 / (4 * pi * sigma * r)``;
        - ``r`` is the segment–electrode distance, floored at
          ``cfg.coupling_min_r_mm`` (default 0.15 mm);
        - ``sigma`` comes per segment from _sigma_for_segments(), floored at 1e-6.

        Units: with sigma in S/m and r in m the coefficient is in V/A, which is
        numerically equal to mV per mA, so it is stored as
        ``_stim_coupling_mV_per_mA``.

        Optional type selectivity (applied only when the segment-type list has
        the same length):
        - "E" segments are multiplied by ``cfg.stim_selectivity_kE``;
        - types starting with "I" by ``cfg.stim_selectivity_kI``;
        - both default to 1.0.

        Raises:
            RuntimeError: if build_from_case() has not been called yet.
        """
        if self.case is None:
            raise RuntimeError("Call build_from_case() before precompute_coupling().")

        electrode = np.asarray(stim_xyz_mm, dtype=float).reshape(1, 3)
        seg_xyz = self._stim_seg_xyz_mm
        if seg_xyz.size == 0:
            self._stim_coupling_mV_per_mA = np.zeros((0,), dtype=float)
            return

        dist_mm = np.maximum(
            np.linalg.norm(seg_xyz - electrode, axis=1),
            float(getattr(self.cfg, "coupling_min_r_mm", 0.15)),
        )
        dist_m = dist_mm * 1e-3
        sigma = np.maximum(self._sigma_for_segments(self.case.tissue, seg_xyz), 1e-6)

        coupling = (1.0 / (4.0 * math.pi * sigma * dist_m)).astype(float)

        # 8.2 Option A: type-dependent selectivity (optional)
        k_exc = float(getattr(self.cfg, "stim_selectivity_kE", 1.0))
        k_inh = float(getattr(self.cfg, "stim_selectivity_kI", 1.0))
        seg_types = self._stim_seg_types
        if len(seg_types) == len(coupling):
            coupling *= np.array(
                [k_exc if tp == "E" else (k_inh if str(tp).startswith("I") else 1.0) for tp in seg_types],
                dtype=float,
            )
        self._stim_coupling_mV_per_mA = coupling

    def _sigma_for_segments(self, tissue: Any, seg_xyz_mm: np.ndarray) -> np.ndarray:
        """
        Return one conductivity (S/m) per segment for the tissue model.

        By mode:
        - "GM" / "WM": uniform ``sigma_gm`` / ``sigma_wm``.
        - Any other mode (e.g. "BOUNDARY") without ``tissue.boundary``: the mean
          of the two.
        - Any other mode with a boundary plane: ``sigma_gm`` where the signed
          distance along the normal is >= 0, ``sigma_wm`` elsewhere.

        Defaults when ``tissue`` lacks a field: ``cfg.sigma_S_per_m`` if the
        plant config has it, else 0.2 (GM) / 0.14 (WM).
        """
        mode = getattr(tissue, "mode", "GM")
        sigma_gm = float(getattr(tissue, "sigma_gm_S_per_m", getattr(self.cfg, "sigma_S_per_m", 0.2)))
        sigma_wm = float(getattr(tissue, "sigma_wm_S_per_m", getattr(self.cfg, "sigma_S_per_m", 0.14)))

        if mode == "GM":
            uniform_sigma = sigma_gm
        elif mode == "WM":
            uniform_sigma = sigma_wm
        else:
            boundary = getattr(tissue, "boundary", None)
            if boundary is not None:
                origin = np.asarray(boundary.point_mm, dtype=float)
                normal = _unit(np.asarray(boundary.normal_unit, dtype=float))
                signed = ((seg_xyz_mm - origin.reshape(1, 3)) @ normal.reshape(3, 1)).reshape(-1)
                return np.where(signed >= 0.0, sigma_gm, sigma_wm).astype(float)
            uniform_sigma = 0.5 * (sigma_gm + sigma_wm)

        return np.full((seg_xyz_mm.shape[0],), uniform_sigma, dtype=float)

    # -------------------------------------------------------------------------
    # Onset drives (NetStim -> Exp2Syn)
    # -------------------------------------------------------------------------

    def _build_onset_drives(self, case: CaseSpec) -> None:
        """
        (Re)create the NetStim -> Exp2Syn onset drives for the primary and secondary foci.

        For every onset site in ``case.primary_onsets_xyz_mm`` /
        ``case.secondary_onsets_xyz_mm`` (rows of x, y, z in mm), one Exp2Syn is placed at
        the midpoint of the GM section whose ``self.gm_xyz_mm`` entry is nearest to the
        site. The synapse is driven by its own NetStim through a zero-delay NetCon.

        Per-site base rate (Hz) and weight are sampled once, with multiplicative Gaussian
        jitter, from an RNG seeded with ``case.seed``. Primary sites are sampled first, then
        secondary, drawing the rate before the weight. The samples are stored in
        ``self._{primary,secondary}_base_{hz,w}`` so ``update_focus_drive`` can rescale them.
        The feedback scales last stored by ``update_focus_drive`` (default 1.0) are applied
        at build time.

        Previously built drives are always cleared first, so repeated calls never double
        them.
        """
        # Drop existing drives before anything else, so a rebuild (even one that finds no
        # GM sections) never leaves stale NetStims/NetCons attached.
        for store in (
            self._primary_syns,
            self._primary_netstims,
            self._primary_netcons,
            self._primary_base_hz,
            self._primary_base_w,
            self._secondary_syns,
            self._secondary_netstims,
            self._secondary_netcons,
            self._secondary_base_hz,
            self._secondary_base_w,
        ):
            store.clear()

        if len(self.gm_secs) == 0:
            return

        gm_xyz = np.asarray(self.gm_xyz_mm, dtype=float)

        def nearest_gm_index(site_xyz_mm: np.ndarray) -> int:
            # Index of the GM section whose stored coordinate is closest (squared Euclidean).
            d2 = np.sum((gm_xyz - site_xyz_mm[None, :]) ** 2, axis=1)
            return int(np.argmin(d2))

        def read_sites(attr: str) -> np.ndarray:
            # Onset coordinates as an (n, 3) float array; missing/None/empty -> (0, 3).
            # Deliberately avoids `raw or []`: that raises for multi-element numpy arrays.
            raw = getattr(case, attr, None)
            if raw is None:
                return np.zeros((0, 3))
            arr = np.asarray(raw, dtype=float)
            if arr.size == 0:
                return np.zeros((0, 3))
            return arr.reshape(-1, 3)

        # Build-time scales (use stored feedback values if env already stepped)
        drive_rate_scale = float(np.clip(getattr(self, "_drive_rate_scale", 1.0), 0.05, 5.0))
        primary_rate_scale = float(np.clip(getattr(self, "_primary_rate_scale", 1.0), 0.05, 5.0))
        secondary_rate_scale = float(np.clip(getattr(self, "_secondary_rate_scale", 1.0), 0.05, 5.0))
        primary_weight_scale = float(np.clip(getattr(self, "_primary_weight_scale", 1.0), 0.0, 10.0))
        secondary_weight_scale = float(np.clip(getattr(self, "_secondary_weight_scale", 1.0), 0.0, 10.0))

        p_site_scales = getattr(self, "_primary_site_scales", None)
        s_site_scales = getattr(self, "_secondary_site_scales", None)

        # Synapse / NetStim params (shared by every drive)
        tau1 = float(getattr(self.cfg, "onset_syn_tau1_ms", 0.5))
        tau2 = float(getattr(self.cfg, "onset_syn_tau2_ms", 3.0))
        e_rev = float(getattr(self.cfg, "onset_syn_e_rev_mV", 0.0))
        noise = float(getattr(self.cfg, "onset_noise", 1.0))

        seed = getattr(case, "seed", 0)
        rng = np.random.default_rng(0 if seed is None else int(seed))

        def sample_base(focus: str, hz_default: float, w_default: float, jitter_default: float):
            # Jittered (base_hz, base_w) for one site of `focus`; rate is drawn before weight.
            base_hz = float(getattr(case, f"{focus}_base_hz", hz_default))
            base_w = float(getattr(case, f"{focus}_base_w", w_default))
            hz_jit = float(getattr(case, f"{focus}_hz_jitter", jitter_default))
            w_jit = float(getattr(case, f"{focus}_w_jitter", jitter_default))
            base_hz = max(0.1, base_hz * (1.0 + hz_jit * rng.normal()))
            base_w = max(0.0, base_w * (1.0 + w_jit * rng.normal()))
            return base_hz, base_w

        def make_drive(sec, interval_ms: float, weight: float):
            # One Exp2Syn at the section midpoint fed by a NetStim through a zero-delay NetCon.
            syn = h.Exp2Syn(sec(0.5))
            syn.tau1 = tau1
            syn.tau2 = tau2
            syn.e = e_rev

            ns = h.NetStim()
            ns.number = 1e9
            ns.start = float(h.t)
            ns.noise = noise
            ns.interval = float(interval_ms)

            nc = h.NetCon(ns, syn)
            nc.delay = 0.0
            nc.weight[0] = float(weight)
            return syn, ns, nc

        def attach_sites(focus, sites, rate_scale, weight_scale, site_scales,
                         hz_default, w_default, jitter_default):
            # Create one drive per site of `focus` and append it to the self._<focus>_* lists.
            syns = getattr(self, f"_{focus}_syns")
            netstims = getattr(self, f"_{focus}_netstims")
            netcons = getattr(self, f"_{focus}_netcons")
            base_hz_store = getattr(self, f"_{focus}_base_hz")
            base_w_store = getattr(self, f"_{focus}_base_w")

            for i_site, site in enumerate(sites):
                sec = self.gm_secs[nearest_gm_index(site)]

                base_hz, base_w = sample_base(focus, hz_default, w_default, jitter_default)
                hz = base_hz * drive_rate_scale * rate_scale
                interval = 1000.0 / max(1e-6, hz)

                w_scale = weight_scale
                if site_scales is not None and i_site < len(site_scales):
                    w_scale *= float(np.clip(site_scales[i_site], 0.0, 10.0))
                w_eff = base_w * w_scale

                syn, ns, nc = make_drive(sec, interval, w_eff)
                syns.append(syn)
                netstims.append(ns)
                netcons.append(nc)
                base_hz_store.append(float(base_hz))
                base_w_store.append(float(base_w))

        primary_sites = read_sites("primary_onsets_xyz_mm")
        secondary_sites = read_sites("secondary_onsets_xyz_mm")

        print("[onset] primary_sites:", primary_sites.shape, "secondary_sites:", secondary_sites.shape)

        # Primary first, then secondary: keeps the RNG draw order stable across rebuilds.
        attach_sites("primary", primary_sites, primary_rate_scale, primary_weight_scale,
                     p_site_scales, 25.0, 5.0, 0.25)
        attach_sites("secondary", secondary_sites, secondary_rate_scale, secondary_weight_scale,
                     s_site_scales, 10.0, 0.25, 0.30)

        print("[onset] primary drives:", len(self._primary_netstims), "secondary drives:", len(self._secondary_netstims))

    def update_focus_drive(self, feedback: Dict[str, Any]) -> None:
        """
        Live-update the onset drives from a Level-2 feedback dict and remember the scales.

        Scalar keys (default 1.0):
        - ``drive_rate_scale``, ``primary_rate_scale`` and ``secondary_rate_scale`` are
          clipped to [0.05, 5].
        - ``primary_weight_scale`` and ``secondary_weight_scale`` are clipped to [0, 10].

        The optional ``primary_site_scales`` / ``secondary_site_scales`` multiply individual
        site weights; each entry is clipped to [0, 10].

        Everything is stored on ``self`` so ``_build_onset_drives`` reapplies it on a
        rebuild. Existing NetStim intervals and NetCon weights are recomputed from the
        per-site base rates/weights sampled at build time.
        """
        scalar_keys = (
            "drive_rate_scale",
            "primary_rate_scale",
            "secondary_rate_scale",
            "primary_weight_scale",
            "secondary_weight_scale",
        )
        # Convert every scalar before touching state, so a bad value raises up front.
        requested = {key: float(feedback.get(key, 1.0)) for key in scalar_keys}
        p_site = feedback.get("primary_site_scales", None)
        s_site = feedback.get("secondary_site_scales", None)

        self._drive_rate_scale = float(np.clip(requested["drive_rate_scale"], 0.05, 5.0))
        self._primary_rate_scale = float(np.clip(requested["primary_rate_scale"], 0.05, 5.0))
        self._secondary_rate_scale = float(np.clip(requested["secondary_rate_scale"], 0.05, 5.0))
        self._primary_weight_scale = float(np.clip(requested["primary_weight_scale"], 0.0, 10.0))
        self._secondary_weight_scale = float(np.clip(requested["secondary_weight_scale"], 0.0, 10.0))
        self._primary_site_scales = p_site
        self._secondary_site_scales = s_site

        def retune(netstims, netcons, base_hz, base_w, rate_scale, weight_scale, site_scales):
            # Recompute interval (ms) and weight for every live drive of one focus.
            for idx, (netstim, netcon) in enumerate(zip(netstims, netcons)):
                hz = float(base_hz[idx]) * self._drive_rate_scale * rate_scale
                netstim.interval = 1000.0 / max(1e-6, hz)

                w = weight_scale
                if site_scales is not None and idx < len(site_scales):
                    w *= float(np.clip(site_scales[idx], 0.0, 10.0))
                netcon.weight[0] = float(base_w[idx]) * float(w)

        retune(
            self._primary_netstims, self._primary_netcons,
            self._primary_base_hz, self._primary_base_w,
            self._primary_rate_scale, self._primary_weight_scale, p_site,
        )
        retune(
            self._secondary_netstims, self._secondary_netcons,
            self._secondary_base_hz, self._secondary_base_w,
            self._secondary_rate_scale, self._secondary_weight_scale, s_site,
        )

    # -------------------------------------------------------------------------
    # Recording + spike feature helpers
    # -------------------------------------------------------------------------

    def _setup_recording(self) -> None:
        """
        Set up time, membrane-voltage and spike recording for a sample of GM cells.

        Up to ``cfg.n_record_cells`` (default 8) GM sections are picked with ``self.rng``.
        The selection aims for up to half inhibitory (labels starting with "I") and the
        rest excitatory ("E"), and tops up from any remaining GM cells when either pool is
        too small. For each picked cell, ``v`` is recorded at the section midpoint, and
        spike times are recorded via a threshold NetCon (``cfg.spike_threshold_mV``,
        default -20 mV). ``self._spike_cell_types`` is index-aligned with
        ``self._spike_vecs``.
        """
        self._record_t = h.Vector()
        self._record_t.record(h._ref_t)

        self._record_vs = []
        self._spike_vecs = []
        self._spike_netcons = []
        self._spike_cell_types = []

        thr_mV = float(getattr(self.cfg, "spike_threshold_mV", -20.0))
        max_record = int(getattr(self.cfg, "n_record_cells", 8))
        n_gm = len(self.gm_secs)

        # --- robust labels (default: every GM cell treated as excitatory) ---
        gm_types = getattr(self, "_gm_types", ["E"] * n_gm)
        E_idx = [i for i, t in enumerate(gm_types) if str(t) == "E"]
        I_idx = [i for i, t in enumerate(gm_types) if str(t).startswith("I")]

        # --- balanced selection: up to half inhibitory, the rest excitatory ---
        nI = min(len(I_idx), max_record // 2)
        nE = min(len(E_idx), max_record - nI)

        pick = []
        if nE > 0:
            pick += self.rng.choice(E_idx, size=nE, replace=False).tolist()
        if nI > 0:
            pick += self.rng.choice(I_idx, size=nI, replace=False).tolist()

        # If still under-filled (not enough E or I), top up from the remaining GM cells.
        n_target = min(max_record, n_gm)
        if len(pick) < n_target:
            picked = set(pick)  # built once so each membership test is O(1)
            remaining = [i for i in range(n_gm) if i not in picked]
            n_fill = min(len(remaining), n_target - len(pick))
            if n_fill > 0:
                pick += self.rng.choice(remaining, size=n_fill, replace=False).tolist()

        self.rng.shuffle(pick)

        # convenience for your external tests
        self._record_gm_indices = pick
        self._record_secs = [self.gm_secs[i] for i in pick]

        print(
            f"[record] n={len(pick)} thr={thr_mV} "
            f"E={sum(str(gm_types[i])=='E' for i in pick)} "
            f"I={sum(str(gm_types[i]).startswith('I') for i in pick)}"
        )

        # --- create recording objects ---
        for i in pick:
            sec = self.gm_secs[i]

            vvec = h.Vector()
            vvec.record(sec(0.5)._ref_v)
            self._record_vs.append(vvec)

            spk = h.Vector()
            nc = h.NetCon(sec(0.5)._ref_v, None, sec=sec)
            nc.threshold = thr_mV
            nc.record(spk)

            self._spike_vecs.append(spk)
            self._spike_netcons.append(nc)

            # label aligned to the spike vector
            self._spike_cell_types.append(str(gm_types[i]))

    def _split_spikes_by_type(self, t0: float, t1: float):
        E_lists, I_lists, all_lists = [], [], []
        for i, sp in enumerate(self._spike_vecs):
            st_all = np.asarray(list(sp), dtype=float)
            st_win = st_all[(st_all >= t0) & (st_all < t1)] - float(t0)
            all_lists.append(st_win)

            ctype = self._spike_cell_types[i]
            if ctype == "E":
                E_lists.append(st_win)
            elif str(ctype).startswith("I"):
                I_lists.append(st_win)
            else:
                E_lists.append(st_win)
        return E_lists, I_lists, all_lists

    @staticmethod
    def _rate_from_lists(spike_lists, window_ms: float) -> float:
        n = int(sum(len(x) for x in spike_lists))
        return float(n) / max(1e-9, float(window_ms) / 1000.0)

    # -------------------------------------------------------------------------
    # Waveform + stimulation application
    # -------------------------------------------------------------------------

    def _make_waveform_mA(self, stim: StimParams, t0_ms: float, t1_ms: float) -> np.ndarray:
        dt = float(self.cfg.dt_ms)
        n = int(np.round((t1_ms - t0_ms) / dt))
        n = max(1, n)

        t = np.arange(n, dtype=float) * dt
        I = np.zeros(n, dtype=float)

        amp = float(stim.amp_mA)
        freq = float(stim.freq_Hz)
        pw = float(stim.pw_ms)
        duty = float(np.clip(getattr(stim, "duty_cycle", 1.0), 0.0, 1.0))

        if duty <= 0.0 or amp == 0.0 or freq <= 0.0 or pw <= 0.0:
            return I

        gate = 1.0
        if duty < 1.0:
            on_ms = duty * float(self.cfg.window_ms)
            gate = (t < on_ms).astype(float)

        waveform = (getattr(stim, "waveform", "rect") or "rect").lower()

        if waveform == "rect":
            period_ms = 1000.0 / freq
            phase = np.mod(t, period_ms)
            I = np.where(phase < pw, amp, 0.0) * gate

        elif waveform == "biphasic":
            period_ms = 1000.0 / freq
            phase = np.mod(t, period_ms)
            p1 = (phase < pw)
            gap = float(max(0.0, getattr(stim, "interphase_gap_ms", 0.0)))
            p2_start = pw + gap
            p2 = (phase >= p2_start) & (phase < p2_start + pw)
            ratio = float(getattr(stim, "phase2_amp_ratio", -1.0))
            I = (amp * p1 + amp * ratio * p2) * gate

        elif waveform == "burst":
            burst_Hz = float(max(1e-6, getattr(stim, "burst_Hz", 5.0)))
            pulses_per_burst = int(max(1, getattr(stim, "pulses_per_burst", 5)))
            intra_Hz = float(max(1e-6, getattr(stim, "intra_burst_freq_Hz", 100.0)))

            burst_period_ms = 1000.0 / burst_Hz
            intra_period_ms = 1000.0 / intra_Hz

            t_in_burst = np.mod(t, burst_period_ms)
            pulse_idx = np.floor(t_in_burst / intra_period_ms).astype(int)
            in_train = pulse_idx < pulses_per_burst
            phase_in_pulse = t_in_burst - pulse_idx * intra_period_ms
            in_pw = phase_in_pulse < pw
            I = np.where(in_train & in_pw, amp, 0.0) * gate

        return I

    def _apply_extracellular_waveform(self, I_mA: np.ndarray, t0_ms: float) -> None:
        """
        Play a stimulation current waveform into ``e_extracellular`` of every stim segment.

        Sample ``j`` of ``I_mA`` is placed at ``t0_ms + j * cfg.dt_ms``. Segment ``s``
        receives ``k_s * I_mA`` (mV), where ``k_s`` is its entry in
        ``self._stim_coupling_mV_per_mA``. The (value, time) Vector pairs are kept in
        ``self._stim_play_vecs``; pairs from the previous call are dropped first.

        Raises:
            RuntimeError: if the coupling array does not have one entry per stim segment.
        """
        self._stim_play_vecs.clear()

        segments = self._stim_segs
        if len(segments) == 0:
            return

        coupling = self._stim_coupling_mV_per_mA
        if coupling.size != len(segments):
            raise RuntimeError("Call precompute_coupling() before applying stimulation (coupling mismatch).")

        dt_ms = float(self.cfg.dt_ms)
        sample_times_ms = (t0_ms + np.arange(len(I_mA)) * dt_ms).astype(float)
        time_vec = h.Vector(sample_times_ms.tolist())

        for seg, k_mV_per_mA in zip(segments, coupling):
            value_vec = h.Vector((float(k_mV_per_mA) * I_mA).astype(float).tolist())
            # continuous=1: NEURON interpolates between the samples (see Vector.play docs).
            value_vec.play(seg._ref_e_extracellular, time_vec, 1)
            self._stim_play_vecs.append((value_vec, time_vec))

    # -------------------------------------------------------------------------
    # Main simulation step
    # -------------------------------------------------------------------------

    def run_window(self, stim_params: Union[StimParams, Dict[str, Any], Sequence[float]]) -> Dict[str, float]:
        """
        Simulate one control window under ``stim_params`` and return spike features.

        The stimulus is converted with ``_as_stim_params``, rendered to a current waveform
        starting at the current ``h.t``, and played into the extracellular field. The
        simulation is then advanced by ``cfg.window_ms``, and spikes recorded in [t0, t1)
        are summarised.

        Returns:
            dict with ``n_spikes``, ``rate_hz`` (spikes pooled over all recorded cells,
            per second), ``E_rate_hz``, ``I_rate_hz`` (pooled the same way),
            ``logEI`` (log((E+1e-3)/(I+1e-3)) clipped to [-3, 3]), ``burst_E`` and
            ``sync_E``. The dict is also stored as ``self._last_features``.

        Raises:
            RuntimeError: if ``build_from_case()`` has not been called (``self.case`` is None).
        """
        if self.case is None:
            raise RuntimeError("Call build_from_case() before run_window().")

        stim = _as_stim_params(stim_params)
        self._last_stim = stim

        window_ms = float(self.cfg.window_ms)
        t0 = float(h.t)
        t1 = t0 + window_ms

        I_mA = self._make_waveform_mA(stim, t0_ms=t0, t1_ms=t1)
        self._apply_extracellular_waveform(I_mA, t0_ms=t0)

        h.continuerun(t1)

        E_lists, I_lists, all_lists = [], [], []
        if hasattr(self, "_spike_vecs") and len(self._spike_vecs) > 0:
            E_lists, I_lists, all_lists = self._split_spikes_by_type(t0, t1)

        n_spikes = int(sum(len(x) for x in all_lists))
        rate_hz = self._rate_from_lists(all_lists, window_ms)

        E_rate = self._rate_from_lists(E_lists, window_ms)
        I_rate = self._rate_from_lists(I_lists, window_ms)
        eps = 1e-3
        logEI = float(np.log((E_rate + eps) / (I_rate + eps)))   # ~0 when E≈I, >0 when E>I
        logEI = float(np.clip(logEI, -3.0, 3.0))

        # Pool E spikes across cells. np.concatenate leaves each cell's train back-to-back,
        # so sort to get a single time-ordered train (otherwise ISIs at the seams between
        # cells are negative and look like short, "bursty" intervals).
        E_all = np.sort(np.concatenate(E_lists)) if len(E_lists) else np.array([], dtype=float)
        burst_E = burst_fraction_from_spikes(E_all, isi_thresh_ms=10.0, min_spikes_in_burst=3)
        sync_E = sync_from_spike_trains(E_lists, window_ms=window_ms, bin_ms=5.0)

        feats = {
            "n_spikes": float(n_spikes),
            "rate_hz": float(rate_hz),
            "E_rate_hz": float(E_rate),
            "I_rate_hz": float(I_rate),
            "logEI": float(logEI),
            "burst_E": float(burst_E),
            "sync_E": float(sync_E),
        }
        self._last_features = feats
        return feats



# Names exported by `from <module> import *` if this cell is saved as a .py module.
__all__ = ["StimParams", "NeuronPlant", "FullNeuronPlant"]


In [None]:
import inspect
# Sanity check: show the call signature of the synchrony helper used by run_window().
print(str(inspect.signature(sync_from_spike_trains)))


(spike_lists_ms: 'List[np.ndarray]', window_ms: 'float', bin_ms: 'float' = 5.0) -> 'float'


In [None]:
# Sanity check: run_window should be defined directly on FullNeuronPlant, not inherited.
print("Has FullNeuronPlant.run_window?", "run_window" in vars(FullNeuronPlant))
print("FullNeuronPlant.run_window:", getattr(FullNeuronPlant.run_window, "__qualname__"))


Has FullNeuronPlant.run_window? True
FullNeuronPlant.run_window: FullNeuronPlant.run_window


In [None]:
# __qualname__ is prefixed with the name of the class whose body defines the method.
print("FullNeuronPlant.run_window defined on:", getattr(FullNeuronPlant, "run_window").__qualname__)


FullNeuronPlant.run_window defined on: FullNeuronPlant.run_window


In [None]:
# model_l2.py
"""
Level-2 epilepsy state + plasticity model (lower-dimensional state layer).

Project-aligned objectives:
- Maintain a stable, differentiable-ish (but hand-coded) state update capturing:
    * seizure burden / propensity (scalar)
    * onset strengths (primary + secondary) with option for per-site vectors
    * long-term plasticity / drift term
- Consume Level-1 features + stimulation parameters each RL step
- Emit feedback that modulates Level-1 onset drive:
    * drive_rate_scale (global)
    * primary_weight_scale / secondary_weight_scale
    * optional per-site scales (primary_site_scales, secondary_site_scales)

This file is intentionally self-contained and conservative (bounded updates, smooth dynamics),
so training remains numerically stable.

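Expected feature inputs (from Level-1 plant; the excitatory-population keys are read
first, with the generic keys as fallbacks):
- E_rate_hz (fallback: rate_hz): float
- sync_E (fallback: sync): float
- burst_E (fallback: burst): float
- lfp: float (optional, defaults to 0)
- logEI: float (optional log E/I rate ratio; only used when cfg.w_ei > 0)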

Expected stim inputs:
- amp_mA, freq_Hz, pw_ms
Optionally additional waveform descriptors (ignored safely if present).

Integration keys returned in feedback dict:
- drive_rate_scale
- primary_rate_scale
- secondary_rate_scale
- primary_weight_scale
- secondary_weight_scale
- primary_site_scales (optional vector)
- secondary_site_scales (optional vector)
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np


# -----------------------------------------------------------------------------
# Config + State
# -----------------------------------------------------------------------------

from dataclasses import dataclass

from dataclasses import dataclass

@dataclass
class L2Config:
    """
    Hyper-parameters for the Level-2 epilepsy state model (EpilepsyStateModel).

    Every field has a default, so ``L2Config()`` is a complete configuration. Each field is
    declared exactly once; the field order (and therefore positional construction) is
    unchanged.
    """

    # -------- time step (discrete-time L2) --------
    # Multiplier on burden updates in EpilepsyStateModel.
    # For "one update per env step", dt=1.0 is the correct default.
    dt: float = 1.0

    # -------- state clamps --------
    burden_min: float = 0.0
    burden_max: float = 10.0

    strength_min: float = 0.0
    strength_max: float = 5.0

    plasticity_min: float = -2.0
    plasticity_max: float = 2.0

    # -------- feature references (normalization) --------
    ref_rate_hz: float = 20.0
    ref_sync: float = 0.5
    ref_burst: float = 0.2
    ref_lfp: float = 1.0

    # Both spellings are kept because notebook code reads cfg.ref_EI_ratio.
    ref_ei_ratio: float = 1.0
    ref_EI_ratio: float = 1.0

    # Notebook code reads cfg.ref_I_rate (capital I).
    ref_I_rate: float = 10.0

    # -------- baseline pathology / burden decay --------
    baseline_gain: float = 0.02      # start small; tune
    # Linear burden decay (cfg.burden_decay * burden), per L2 update when dt == 1.0.
    # 0.02 is the value this config has always resolved to; a second, shadowed
    # declaration (0.10) used to sit here.
    burden_decay: float = 0.02

    rebound_gain: float = 0.05
    rebound_threshold: float = 0.2

    # -------- weights (burden pathology mixture) --------
    w_rate: float = 0.60
    w_sync: float = 0.25
    w_burst: float = 0.15
    # Weight of the excitatory-dominance term (logEI > 0). 0.0 disables it; try 0.05–0.2.
    w_ei: float = 0.0

    w_lfp: float = 0.10
    w_plasticity: float = 0.10

    # stimulation suppression weight
    w_stim: float = 0.25

    # -------- burden dynamics --------
    # plasticity contributes to burden_dot via cfg.plasticity_burden_gain * plasticity
    plasticity_burden_gain: float = 0.05

    # -------- strength dynamics --------
    # Used in both the scalar and the per-site strength updates.
    primary_strength_rate: float = 0.02
    secondary_strength_rate: float = 0.02

    strength_burden_gain: float = 0.10
    strength_stim_gain: float = 0.10

    # -------- plasticity dynamics --------
    plasticity_rate: float = 0.05
    plasticity_decay: float = 0.01

    # -------- acute suppression controls --------
    acute_stim_suppress_k: float = 1.0
    acute_min_scale: float = 0.2

    # -------- per-site modelling --------
    # Gates both primary and secondary per-site vectors (see reset/get_drive).
    use_per_site_secondary: bool = True
    per_site_noise_std: float = 0.02

    # -------- feedback gains (L2 -> L1) --------
    # Read by EpilepsyStateModel.get_drive.
    feedback_drive_gain: float = 0.50
    feedback_weight_gain: float = 0.30

    # Backward-compatible aliases (older configs used fb_drive_gain/fb_weight_gain)
    fb_drive_gain: float = 0.50
    fb_weight_gain: float = 0.30

    # Reference log(E/I) used to normalize the excitatory-dominance term.
    ref_logEI: float = 1.0

@dataclass
class L2State:
    """
    Mutable Level-2 state carried between RL steps.

    Scalars hold the seizure burden, the plasticity/drift term and the aggregate onset
    strengths. The per-site strength arrays default to empty float arrays.
    """

    # Seizure burden / propensity and slow plasticity drift.
    burden: float = 0.0
    plasticity: float = 0.0

    # Aggregate onset strengths for the primary and secondary foci.
    primary_strength: float = 1.0
    secondary_strength: float = 0.5

    # Optional per-site strength vectors (a fresh empty array per instance).
    primary_site_strengths: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=float))
    secondary_site_strengths: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=float))

    # Exponential moving average of stimulation intensity, and its fixed smoothing factor.
    stim_ema: float = 0.0
    stim_ema_alpha: float = 0.05


# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------

def _clip(x: float, lo: float, hi: float) -> float:
    return float(min(max(float(x), float(lo)), float(hi)))


def _safe_float(d: Dict[str, Any], key: str, default: float = 0.0) -> float:
    try:
        return float(d.get(key, default))
    except Exception:
        return float(default)


def _normalize_feature(x: float, ref: float) -> float:
    # Simple bounded normalization. Keeps magnitudes manageable.
    ref = max(1e-6, float(ref))
    return float(np.tanh(x / ref))

def _stim_intensity(cfg, amp_mA: float, freq_Hz: float, pw_ms: float) -> float:
    """
    Scalar stimulation 'dose' proxy for cost / adaptation dynamics.

    Uses charge-per-second (mA * ms * Hz) as the base, then normalizes by
    configurable reference values.

    Returns a non-negative float, typically O(0..a few).
    """
    amp = float(max(0.0, amp_mA))
    freq = float(max(0.0, freq_Hz))
    pw = float(max(0.0, pw_ms))

    # charge proxy (mA * ms per pulse * pulses/s) = mA*ms/s
    q = amp * pw * freq

    # normalization refs (safe fallbacks)
    ref_amp = float(getattr(cfg, "ref_amp_mA", 1.0))
    ref_freq = float(getattr(cfg, "ref_freq_Hz", 100.0))
    ref_pw = float(getattr(cfg, "ref_pw_ms", 0.2))

    q_ref = max(1e-9, ref_amp * ref_pw * ref_freq)
    return float(q / q_ref)


def _normalize_stim(amp_mA: float, freq_Hz: float, pw_ms: float, cfg: L2Config) -> float:
    """
    Normalize stimulation intensity into ~[0,1] range using a conservative product-like metric:
      stim_intensity ~ (amp/amax) * (freq/fmax) * (pw/pwmax)
    This roughly correlates with charge-per-second-like effects.
    """
    a = float(np.clip(amp_mA / max(1e-6, cfg.amp_max_mA), 0.0, 2.0))
    f = float(np.clip(freq_Hz / max(1e-6, cfg.freq_max_Hz), 0.0, 2.0))
    p = float(np.clip(pw_ms / max(1e-6, cfg.pw_max_ms), 0.0, 2.0))
    # Keep it smooth and bounded
    return float(np.clip(a * f * p, 0.0, 2.0))


# -----------------------------------------------------------------------------
# Model
# -----------------------------------------------------------------------------

class EpilepsyStateModel:
    """
    Level-2 state + plasticity model with optional per-secondary-site vector dynamics.
    """

    def __init__(self, cfg: Optional[L2Config] = None, rng_seed: int = 0):
        """
        Create the model with a fresh default state.

        Args:
            cfg: model configuration; ``L2Config()`` is used when omitted (or falsy).
            rng_seed: seed for ``self.rng`` (a numpy ``Generator``).
        """
        self.cfg = cfg if cfg else L2Config()
        seed = int(rng_seed)
        self.rng_seed = seed
        self.rng = np.random.default_rng(seed)
        self.state = L2State()
        # Latent seizure drive (0..1 recommended); cfg.z0 overrides the 0.6 default if present.
        self.z = float(getattr(self.cfg, "z0", 0.6))
        # Cached last drive dict; None until something stores one.
        self.last_drive = None


    # -------------------------
    # Reset / init
    # -------------------------

    def reset(self, case: Optional[CaseSpec] = None, *, severity: Optional[float] = None) -> L2State:
        """
        Reset the Level-2 state at the start of an episode.

        Initialization rules (attributes that are missing or None count as absent):
        - severity: the ``severity`` argument, else ``case.severity``, else 0.0.
        - burden: ``case.baseline_burden``, else ``max(0, severity)``; clamped to
          [cfg.burden_min, cfg.burden_max].
        - onset strengths: the mean ``baseline_strength`` of
          ``case.onset.primary_sites`` / ``case.onset.secondary_sites`` (per-site defaults
          1.0 / 0.3), else 1.0 / 0.5; clamped to [cfg.strength_min, cfg.strength_max].
        - per-site strength vectors are kept only when ``cfg.use_per_site_secondary``.
        - plasticity: ``0.2 * (severity - 0.5)`` clamped to the plasticity bounds.

        A non-None ``case`` is also stored as ``self.case``. ``step()`` reads
        ``self.case.severity`` for its baseline pathology term and otherwise falls back
        to 1.0.

        Returns:
            The new state (also stored as ``self.state``).
        """
        cfg = self.cfg
        s = L2State()

        case_severity = getattr(case, "severity", None) if case is not None else None
        if severity is not None:
            sev = float(severity)
        elif case_severity is not None:
            sev = float(case_severity)
        else:
            sev = 0.0

        # Baseline burden: prefer explicit case.baseline_burden; fallback to severity
        case_burden = getattr(case, "baseline_burden", None) if case is not None else None
        if case_burden is not None:
            s.burden = float(case_burden)
        else:
            s.burden = float(max(0.0, sev))

        s.burden = _clip(s.burden, cfg.burden_min, cfg.burden_max)

        # Initialize strengths from case if possible
        primary_strength_init = 1.0
        secondary_strength_init = 0.5

        primary_site_strengths = np.zeros((0,), dtype=float)
        secondary_site_strengths = np.zeros((0,), dtype=float)

        if case is not None and hasattr(case, "onset"):
            onset = getattr(case, "onset")
            raw_primary = getattr(onset, "primary_sites", None)
            raw_secondary = getattr(onset, "secondary_sites", None)
            primary_sites = list(raw_primary) if raw_primary is not None else []
            secondary_sites = list(raw_secondary) if raw_secondary is not None else []

            if len(primary_sites) > 0:
                primary_site_strengths = np.array(
                    [float(getattr(ps, "baseline_strength", 1.0)) for ps in primary_sites],
                    dtype=float,
                )
                primary_strength_init = float(np.mean(primary_site_strengths))

            if len(secondary_sites) > 0:
                secondary_site_strengths = np.array(
                    [float(getattr(ss, "baseline_strength", 0.3)) for ss in secondary_sites],
                    dtype=float,
                )
                secondary_strength_init = float(np.mean(secondary_site_strengths))

        # If no sites provided, fallback conservative init
        s.primary_strength = _clip(primary_strength_init, cfg.strength_min, cfg.strength_max)
        s.secondary_strength = _clip(secondary_strength_init, cfg.strength_min, cfg.strength_max)

        if cfg.use_per_site_secondary:
            s.primary_site_strengths = primary_site_strengths.copy()
            s.secondary_site_strengths = secondary_site_strengths.copy()
        else:
            s.primary_site_strengths = np.zeros((0,), dtype=float)
            s.secondary_site_strengths = np.zeros((0,), dtype=float)

        # Plasticity starts near 0, slightly biased by severity
        s.plasticity = _clip(0.2 * (sev - 0.5), cfg.plasticity_min, cfg.plasticity_max)

        # History
        s.stim_ema = 0.0

        if case is not None:
            # step() scales its baseline pathology by self.case.severity; remember the case
            # (only when given, so a case attached externally is not wiped by reset()).
            self.case = case

        self.state = s
        return self.state


    def get_drive(self) -> Dict[str, Any]:
        st = self.state
        cfg = self.cfg

        burden_level = st.burden / max(1e-6, cfg.burden_max)
        strength_level = (0.5 * st.primary_strength + 0.5 * st.secondary_strength) / max(1e-6, cfg.strength_max)

        drive_rate_scale = 1.0 + cfg.feedback_drive_gain * (2.0 * burden_level - 1.0) + 0.05 * st.plasticity
        drive_rate_scale = float(np.clip(drive_rate_scale, 0.25, 3.0))

        primary_rate_scale = float(np.clip(1.0 + 0.10 * (2.0 * burden_level - 1.0), 0.5, 2.0))
        secondary_rate_scale = float(np.clip(1.0 + 0.15 * (2.0 * burden_level - 1.0), 0.5, 2.5))

        primary_weight_scale = 1.0 + cfg.feedback_weight_gain * (2.0 * strength_level - 1.0)
        secondary_weight_scale = 1.0 + cfg.feedback_weight_gain * (2.0 * strength_level - 1.0)

        primary_weight_scale += 0.05 * (2.0 * burden_level - 1.0)
        secondary_weight_scale += 0.08 * (2.0 * burden_level - 1.0)

        primary_weight_scale = float(np.clip(primary_weight_scale, 0.25, 3.0))
        secondary_weight_scale = float(np.clip(secondary_weight_scale, 0.25, 3.0))

        # --- ACUTE DBS SUPPRESSION (key addition) ---
        # st.stim_ema is already in ~[0,2] given your normalize
        acute = 1.0 / (1.0 + cfg.acute_stim_suppress_k * float(st.stim_ema))
        acute = float(np.clip(acute, cfg.acute_min_scale, 1.0))

        drive_rate_scale *= acute
        primary_rate_scale *= acute
        secondary_rate_scale *= acute
        primary_weight_scale *= acute
        secondary_weight_scale *= acute

        feedback: Dict[str, Any] = {
            "drive_rate_scale": float(np.clip(drive_rate_scale, 0.25, 3.0)),
            "primary_rate_scale": float(np.clip(primary_rate_scale, 0.5, 2.0)),
            "secondary_rate_scale": float(np.clip(secondary_rate_scale, 0.5, 2.5)),
            "primary_weight_scale": float(np.clip(primary_weight_scale, 0.25, 3.0)),
            "secondary_weight_scale": float(np.clip(secondary_weight_scale, 0.25, 3.0)),
        }

        # Per-site scales as before...
        if cfg.use_per_site_secondary:
            if st.primary_site_strengths is not None and st.primary_site_strengths.size > 0:
                p = st.primary_site_strengths
                feedback["primary_site_scales"] = np.clip(
                    0.5 + 1.5 * (p / max(1e-6, cfg.strength_max)), 0.25, 3.0
                ).astype(float)
            if st.secondary_site_strengths is not None and st.secondary_site_strengths.size > 0:
                s = st.secondary_site_strengths
                feedback["secondary_site_scales"] = np.clip(
                    0.5 + 1.5 * (s / max(1e-6, cfg.strength_max)), 0.25, 3.0
                ).astype(float)

        return feedback

    def _stim_intensity_norm(stim: dict, amp_max=3.0, freq_max=150.0, pw_max=0.4) -> float:
        amp = float(stim.get("amp_mA", 0.0))
        freq = float(stim.get("freq_Hz", 0.0))
        pw = float(stim.get("pw_ms", 0.0))
        duty = float(stim.get("duty_cycle", 1.0))

        x = (amp / max(1e-9, amp_max)) * (freq / max(1e-9, freq_max)) * (pw / max(1e-9, pw_max)) * duty
        return float(np.clip(x, 0.0, 2.0))


    # -------------------------
    # Update step
    # -------------------------

    def step(
        self,
        features: Dict[str, float],
        stim: Union[Dict[str, Any], Sequence[float], Tuple[float, float, float]],
        case: Optional[CaseSpec] = None,
    ) -> Tuple[L2State, Dict[str, Any], Dict[str, float]]:
        """
        Advance the Level-2 epilepsy state by one plant window.

        Args:
            features: Level-1 features for the window just simulated. Excitatory keys
                ("E_rate_hz", "sync_E", "burst_E") are preferred, with population keys
                ("rate_hz", "sync", "burst") as fallbacks. "lfp" falls back to
                "lfp_power". "logEI" falls back to log("EI_ratio").
            stim: stimulation for the window, either a dict with "amp_mA", "freq_Hz",
                "pw_ms" or a positional sequence (amp_mA, freq_Hz, pw_ms).
            case: case descriptor used for "severity". When None, ``self.case`` is
                used if it is set.

        Returns:
            (state, feedback, info):
              state    -- ``self.state``, mutated in place
              feedback -- ``self.get_drive()`` plus a "burden" entry
              info     -- float diagnostics (pathology, suppress, stim_intensity,
                          the raw feature values actually used, burden)
        """
        cfg = self.cfg
        st = self.state

        # Parse stim (dict or positional amp/freq/pw)
        if isinstance(stim, dict):
            amp = _safe_float(stim, "amp_mA", 0.0)
            freq = _safe_float(stim, "freq_Hz", 0.0)
            pw = _safe_float(stim, "pw_ms", 0.0)
        else:
            amp, freq, pw = float(stim[0]), float(stim[1]), float(stim[2])

        stim_int = _stim_intensity(cfg, amp_mA=amp, freq_Hz=freq, pw_ms=pw)
        st.stim_ema = (1.0 - st.stim_ema_alpha) * st.stim_ema + st.stim_ema_alpha * stim_int

        # 8.3: use excitatory features preferentially. Read each feature ONCE so the
        # values reported in `info` are exactly the ones that drive pathology.
        rate_raw = _safe_float(features, "E_rate_hz", _safe_float(features, "rate_hz", 0.0))
        sync_raw = _safe_float(features, "sync_E", _safe_float(features, "sync", 0.0))
        burst_raw = _safe_float(features, "burst_E", _safe_float(features, "burst", 0.0))
        lfp_raw = _safe_float(features, "lfp", _safe_float(features, "lfp_power", 0.0))
        ei_raw = _safe_float(features, "EI_ratio", 1.0)
        # Prefer an explicit log(E/I) feature. Otherwise derive it from EI_ratio.
        # EI_ratio defaults to 1.0, so a missing feature still yields logEI = 0.0.
        logEI_raw = _safe_float(features, "logEI", float(np.log(max(ei_raw, 1e-9))))

        rate_n = _normalize_feature(rate_raw, cfg.ref_rate_hz)
        sync_n = _normalize_feature(sync_raw, cfg.ref_sync)
        burst_n = _normalize_feature(burst_raw, cfg.ref_burst)
        lfp_n = _normalize_feature(lfp_raw, cfg.ref_lfp)

        # Optional E/I penalty: only active when cfg.w_ei > 0
        w_ei = float(getattr(cfg, "w_ei", 0.0))
        ref_logei = float(getattr(cfg, "ref_logEI", 1.0))

        # Only penalize excitatory dominance: logEI > 0
        ei_drive = max(0.0, logEI_raw)
        ei_n = _normalize_feature(ei_drive, ref_logei) if w_ei > 0 else 0.0

        pathology = float(
            cfg.w_rate * rate_n
            + cfg.w_sync * sync_n
            + cfg.w_burst * burst_n
            + cfg.w_lfp * lfp_n
            + w_ei * ei_n
        )

        # Suppression increases with stim intensity but drops as foci strengthen
        efficacy = float(np.clip(1.0 / (0.7 + 0.6 * st.primary_strength + 0.4 * st.secondary_strength), 0.2, 1.2))
        w_stim = float(getattr(cfg, "w_stim", 0.25))
        suppress = float(w_stim * stim_int * efficacy)

        # --- Baseline pathology (C1): prefer the explicit case argument ---
        case_obj = case if case is not None else getattr(self, "case", None)
        sev = float(getattr(case_obj, "severity", 1.0)) if case_obj is not None else 1.0
        baseline = float(getattr(cfg, "baseline_gain", 0.0)) * sev

        # --- Optional rebound (C2): pushes burden up when it falls below threshold ---
        rebound_gain = float(getattr(cfg, "rebound_gain", 0.0))
        rebound_thresh = float(getattr(cfg, "rebound_threshold", 0.2))
        rebound = rebound_gain * max(0.0, rebound_thresh - st.burden)

        # --- Decay ---
        burden_decay = float(getattr(cfg, "burden_decay", 0.0))

        # --- Final burden dynamics (single Euler step) ---
        burden_dot = (
            baseline
            + pathology
            - suppress
            + rebound                    # zero when rebound_gain == 0
            + float(cfg.w_plasticity * st.plasticity)
            - burden_decay * st.burden
        )
        st.burden = _clip(
            st.burden + float(cfg.dt) * burden_dot,
            cfg.burden_min,
            cfg.burden_max,
        )

        strength_drive = (
            cfg.strength_burden_gain * (st.burden - 1.0)   # burden above baseline worsens
            - cfg.strength_stim_gain * st.stim_ema         # sustained stim improves (anti-kindling)
        )

        # Strength updates (slow)
        st.primary_strength = _clip(
            st.primary_strength + cfg.primary_strength_rate * strength_drive,
            cfg.strength_min, cfg.strength_max
        )
        st.secondary_strength = _clip(
            st.secondary_strength + cfg.secondary_strength_rate * strength_drive,
            cfg.strength_min, cfg.strength_max
        )

        # Plasticity (slow drift with burden)
        st.plasticity = _clip(
            st.plasticity + cfg.plasticity_rate * (st.burden - 0.5),
            cfg.plasticity_min, cfg.plasticity_max
        )

        # Feedback mapping (8.4 output)
        feedback = self.get_drive()
        feedback["burden"] = float(st.burden)

        info = {
            "pathology": pathology,
            "suppress": suppress,
            "stim_intensity": stim_int,
            "rate_raw": rate_raw,
            "sync_raw": sync_raw,
            "burst_raw": burst_raw,
            "lfp_raw": lfp_raw,
            "logEI_raw": logEI_raw,
            "burden": float(st.burden),
        }
        return st, feedback, info

    # -------------------------
    # Observation helper
    # -------------------------

    def observation_vector(
        self,
        features: Optional[Dict[str, float]] = None,
        *,
        include_site_strengths: bool = False,
        max_sites: int = 16,
    ) -> np.ndarray:
        """
        Produce a float32 observation vector from the current state and optional features.

        Base vector (always):
          [burden, plasticity, primary_strength, secondary_strength]

        If features provided, append:
          [rate, sync, burst, lfp]
        Each entry reads the population key ("rate_hz", "sync", "burst", "lfp").
        When that key is absent it falls back to the excitatory/alternate key
        ("E_rate_hz", "sync_E", "burst_E", "lfp_power"), and then to 0.0.

        If include_site_strengths:
          append primary_site_strengths (zero-padded/truncated to max_sites)
          append secondary_site_strengths (zero-padded/truncated to max_sites)
        A negative max_sites is treated as 0.
        """
        st = self.state
        x = [float(st.burden), float(st.plasticity), float(st.primary_strength), float(st.secondary_strength)]

        if features is not None:
            # (primary key, fallback key): the fallbacks are the keys the plant/L2 use
            for primary, fallback in (
                ("rate_hz", "E_rate_hz"),
                ("sync", "sync_E"),
                ("burst", "burst_E"),
                ("lfp", "lfp_power"),
            ):
                x.append(float(features.get(primary, features.get(fallback, 0.0))))

        if include_site_strengths:
            n_sites = max(0, int(max_sites))
            p = st.primary_site_strengths if st.primary_site_strengths.size else np.zeros((0,), dtype=float)
            s = st.secondary_site_strengths if st.secondary_site_strengths.size else np.zeros((0,), dtype=float)

            def pad(v: np.ndarray) -> np.ndarray:
                # Flatten, then truncate or zero-pad to exactly n_sites entries
                v = v.astype(float).reshape(-1)
                if v.size >= n_sites:
                    return v[:n_sites]
                out = np.zeros((n_sites,), dtype=float)
                out[:v.size] = v
                return out

            x.extend(pad(p).tolist())
            x.extend(pad(s).tolist())

        return np.asarray(x, dtype=np.float32)


# Public names exported by this cell/module.
__all__ = ["L2Config", "L2State", "EpilepsyStateModel"]


In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import numpy as np

import gymnasium as gym
from gymnasium import spaces


def _lin_map(x: float, lo: float, hi: float) -> float:
    x = float(np.clip(x, -1.0, 1.0))
    return lo + (x + 1.0) * 0.5 * (hi - lo)


def _waveform_from_code(x: float) -> str:
    x = float(x)
    if x < -0.333:
        return "rect"
    if x < 0.333:
        return "biphasic"
    return "burst"


def _int_from_code(x: float, lo: int, hi: int) -> int:
    """
    Decode a scalar action in [-1, 1] into an integer in [lo, hi].

    The action is mapped onto [lo - 0.49, hi + 0.49] so every integer gets a bin of
    roughly equal width, then rounded and clamped to [lo, hi].
    """
    scaled = _lin_map(float(x), lo - 0.49, hi + 0.49)
    nearest = int(round(scaled))
    return int(min(max(nearest, lo), hi))


@dataclass
class RewardConfig:
    """
    Weights for the per-step reward of the DBS environment.

    Intended reward per step:
        delta_weight * (prev_burden - burden) - burden_weight * burden - stim_weight * stim_cost

    NOTE(review): confirm the environment's step() reads these exact field names.
    """
    # Weight on the absolute burden penalty.
    burden_weight: float = 1.0
    # Weight on the one-step burden improvement (prev_burden - burden).
    delta_weight: float = 1.0
    # Weight on the stimulation-cost term.
    stim_weight: float = 0.02


class DBSGymEnv(gym.Env):
    """
    Gymnasium environment coupling a case generator, a Level-1 plant and a Level-2
    epilepsy-state model.

    Each step:
      1. maps an 8-D action in [-1, 1] to a stimulation dict,
      2. runs one plant window with that stimulation,
      3. updates the L2 model with the resulting features,
      4. feeds L2 feedback back into the plant,
      5. returns a reward built from the L2 burden and stimulation intensity.

    Observation (17,): L2 observation vector (8) + last [amp, freq, pw] (3)
    + waveform one-hot (3) + case descriptor [severity, baseline_burden, primary_dist_mm] (3).
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        *,
        case_gen,
        plant,
        l2_model,
        env_cfg: Optional[Any] = None,
        seed: int = 0,
        reward_cfg: Optional[RewardConfig] = None,
        forced_tissue_mode: Optional[str] = None,
    ):
        super().__init__()
        self.case_gen = case_gen
        self.plant = plant
        self.l2 = l2_model

        self.env_cfg = env_cfg
        # Alias for helpers that read `self.cfg`; must be the config itself, not a tuple.
        self.cfg = env_cfg
        self.reward_cfg = reward_cfg or RewardConfig()
        self.forced_tissue_mode = forced_tissue_mode

        self.seed_value = int(seed)
        self.rng = np.random.default_rng(self.seed_value)

        # ---- episode length (single source of truth) ----
        # getattr(None, name, default) returns default, so env_cfg=None yields 40 / 1.
        self.episode_len = int(getattr(env_cfg, "episode_len", getattr(env_cfg, "episode_steps", 40)))
        # keep horizon_steps as an alias if other code uses it
        self.horizon_steps = self.episode_len
        self.baseline_windows = int(getattr(env_cfg, "baseline_windows", 1))

        # Action bounds from env_cfg.dbs_bounds if available (defaults otherwise)
        b = getattr(env_cfg, "dbs_bounds", None)
        self.amp_min_mA = float(getattr(b, "amp_mA_min", 0.0))
        self.amp_max_mA = float(getattr(b, "amp_mA_max", 5.0))
        self.freq_min_Hz = float(getattr(b, "freq_Hz_min", 5.0))
        self.freq_max_Hz = float(getattr(b, "freq_Hz_max", 200.0))
        self.pw_min_ms = float(getattr(b, "pw_ms_min", 0.05))
        self.pw_max_ms = float(getattr(b, "pw_ms_max", 0.5))

        # 8D continuous action
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)

        # Observation vector length 17 (see class docstring)
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float32)

        self._step_i = 0
        self._case = None
        self._last_stim: Dict[str, Any] = {"amp_mA": 0.0, "freq_Hz": 0.0, "pw_ms": 0.0, "waveform": "rect"}
        self._last_features: Dict[str, float] = {"rate_hz": 0.0, "sync": 0.0, "burst": 0.0, "lfp": 0.0}
        self._prev_burden: float = 0.0

        # Optional network config, set externally (e.g. by build_env) and passed to
        # plant.build_from_case() in reset().
        self.net_cfg = None

    # ----------------------------
    # Action/obs helpers
    # ----------------------------

    def _stim_to_wave_onehot(self, stim: Dict[str, Any]) -> np.ndarray:
        """One-hot [rect, biphasic, other] encoding of stim["waveform"] (default "rect")."""
        w = str(stim.get("waveform", "rect"))
        if w == "rect":
            return np.array([1.0, 0.0, 0.0], dtype=np.float32)
        if w == "biphasic":
            return np.array([0.0, 1.0, 0.0], dtype=np.float32)
        return np.array([0.0, 0.0, 1.0], dtype=np.float32)

    def _action_to_stim(self, action: np.ndarray) -> Dict[str, Any]:
        """
        Map an 8-D action in [-1, 1] to a stimulation dict.

        a[0..2] -> amp/freq/pw, a[3] -> waveform, a[4] -> duty cycle,
        a[5..7] -> waveform-specific parameters.
        """
        a = np.asarray(action, dtype=float).reshape(-1)
        a = np.clip(a, -1.0, 1.0)

        stim: Dict[str, Any] = {
            "amp_mA": float(_lin_map(a[0], self.amp_min_mA, self.amp_max_mA)),
            "freq_Hz": float(_lin_map(a[1], self.freq_min_Hz, self.freq_max_Hz)),
            "pw_ms": float(_lin_map(a[2], self.pw_min_ms, self.pw_max_ms)),
            "waveform": _waveform_from_code(a[3]),
            "duty_cycle": float(_lin_map(a[4], 0.2, 1.0)),
        }

        # waveform-specific fields
        if stim["waveform"] == "biphasic":
            stim["phase2_amp_ratio"] = float(_lin_map(a[5], -1.2, -0.8))
            stim["interphase_gap_ms"] = float(_lin_map(a[6], 0.0, 0.2))
        elif stim["waveform"] == "burst":
            # NOTE(review): pulses_per_burst and intra_burst_freq_Hz both read a[7], so
            # they are perfectly correlated and a[6] is unused for bursts. Confirm intent.
            stim["pulses_per_burst"] = int(_int_from_code(a[7], 2, 10))
            stim["intra_burst_freq_Hz"] = float(_lin_map(a[7], 50.0, 250.0))
            stim["burst_Hz"] = float(_lin_map(a[5], 2.0, 20.0))

        # hard safety clamps
        stim["amp_mA"] = float(np.clip(stim["amp_mA"], self.amp_min_mA, self.amp_max_mA))
        stim["freq_Hz"] = float(np.clip(stim["freq_Hz"], self.freq_min_Hz, self.freq_max_Hz))
        stim["pw_ms"] = float(np.clip(stim["pw_ms"], self.pw_min_ms, self.pw_max_ms))

        # Enforce the optional minimum amplitude, but never exceed the hard maximum.
        min_amp = float(getattr(self.env_cfg, "min_amp_mA", 0.0))
        stim["amp_mA"] = float(min(max(stim["amp_mA"], min_amp), self.amp_max_mA))

        return stim

    def _case_desc(self, case) -> np.ndarray:
        """[severity, baseline_burden, primary_dist_mm] for a case (0.0 for missing fields)."""
        sev = float(getattr(case, "severity", 0.0))
        base = float(getattr(case, "baseline_burden", 0.0))
        try:
            dist = float(case.descriptors.get("primary_dist_mm", 0.0))
        except Exception:
            # case may have no `descriptors` mapping
            dist = 0.0
        return np.array([sev, base, dist], dtype=np.float32)

    def _obs(self) -> np.ndarray:
        """Assemble the 17-D observation (see class docstring)."""
        # l2.observation_vector(features) must return length 8
        l2v = self.l2.observation_vector(self._last_features, include_site_strengths=False).astype(np.float32)

        last = np.array(
            [
                float(self._last_stim.get("amp_mA", 0.0)),
                float(self._last_stim.get("freq_Hz", 0.0)),
                float(self._last_stim.get("pw_ms", 0.0)),
            ],
            dtype=np.float32,
        )
        w = self._stim_to_wave_onehot(self._last_stim)
        cdesc = self._case_desc(self._case) if self._case is not None else np.zeros((3,), dtype=np.float32)

        return np.concatenate([l2v, last, w, cdesc], dtype=np.float32)

    # ----------------------------
    # Gymnasium API
    # ----------------------------

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None):
        """
        Sample (or, with ``self.freeze_case``, reuse) a case, rebuild the plant, reset L2
        and run ``baseline_windows`` zero-stimulation windows.

        Returns:
            (observation, {"case_id": ...})
        """
        super().reset(seed=seed)
        if seed is not None:
            self.seed_value = int(seed)
            self.rng = np.random.default_rng(self.seed_value)

        self._step_i = 0

        # Sample (or reuse) case
        reuse = bool(getattr(self, "freeze_case", False)) and (getattr(self, "_case", None) is not None)
        if reuse:
            case = self._case
        else:
            if self.forced_tissue_mode:
                case = self.case_gen.sample(tags={"forced_tissue_mode": self.forced_tissue_mode})
            else:
                case = self.case_gen.sample()
            self._case = case

        # Rebuild plant
        self.plant.build_from_case(case, net_cfg=getattr(self, "net_cfg", None))

        # Coupling: pass Vec3 (shape (3,)), not (1,3)
        stim_xyz_mm = np.asarray(case.electrode.xyz_mm, dtype=float).reshape(3,)
        self.plant.precompute_coupling(tuple(stim_xyz_mm.tolist()))

        # Reset L2
        self.l2.reset(case=case)

        # Default zero stimulation for baselining
        self._last_stim = {"amp_mA": 0.0, "freq_Hz": 0.0, "pw_ms": 0.0, "waveform": "rect", "duty_cycle": 1.0}

        # Initial features (safe defaults until the first plant window runs)
        self._last_features = {
            "n_spikes": 0.0,
            "rate_hz": 0.0,
            "E_rate_hz": 0.0,
            "I_rate_hz": 0.0,
            "EI_ratio": 1.0,
            "burst_E": 0.0,
            "sync_E": 0.0,
            "lfp": 0.0,
        }

        # Baseline windows: run plant, update L2, apply feedback to plant
        for _ in range(max(0, int(self.baseline_windows))):
            feats = self.plant.run_window(self._last_stim)
            self._last_features = feats

            _st, feedback, _info = self.l2.step(features=feats, stim=self._last_stim, case=case)
            self.plant.update_focus_drive(feedback if isinstance(feedback, dict) else {})

        # Set prev burden defensively
        burden = 0.0
        if getattr(self.l2, "state", None) is not None:
            burden = float(getattr(self.l2.state, "burden", 0.0))
        self._prev_burden = burden

        return self._obs(), {"case_id": getattr(case, "case_id", None)}

    def step(self, action: np.ndarray):
        """
        Apply one action for one plant window.

        Reward (weights read from ``self.reward_cfg``):
            delta_weight * (prev_burden - burden)
            - burden_weight * burden
            - stim_weight * stim_intensity
        A config without ``stim_weight`` falls back to a ``stim_cost`` attribute, then 0.05.
        A config without ``burden_weight`` uses 0.0.

        Returns:
            (observation, reward, terminated, truncated, info)
        """
        self._step_i += 1

        # 1) action -> stim dict (copied so the env owns the dict it stores/reports)
        raw_stim = self._action_to_stim(action)
        if hasattr(raw_stim, "keys"):
            stim: Dict[str, Any] = dict(raw_stim)
        else:
            # fallback for overrides that return a positional (amp, freq, pw) sequence
            stim = {
                "amp_mA": float(raw_stim[0]),
                "freq_Hz": float(raw_stim[1]),
                "pw_ms": float(raw_stim[2]),
            }
        stim.setdefault("waveform", "rect")
        stim.setdefault("duty_cycle", 1.0)

        self._last_stim = stim

        # 2) run Level-1 plant
        feats = self.plant.run_window(stim)
        self._last_features = feats

        # 3) Level-2 update
        st, feedback, l2_info = self.l2.step(features=feats, stim=stim, case=self._case)

        # 4) 8.4: apply feedback back into plant every step (non-dict -> empty feedback)
        self.plant.update_focus_drive(feedback if isinstance(feedback, dict) else {})

        # 5) reward from RewardConfig
        bur = float(getattr(st, "burden", 0.0))
        rc = self.reward_cfg
        delta_w = float(getattr(rc, "delta_weight", 1.0))
        burden_w = float(getattr(rc, "burden_weight", 0.0))
        stim_w = float(getattr(rc, "stim_weight", getattr(rc, "stim_cost", 0.05)))
        stim_int = float(l2_info.get("stim_intensity", 0.0)) if isinstance(l2_info, dict) else 0.0

        reward = delta_w * (self._prev_burden - bur) - burden_w * bur - stim_w * stim_int
        self._prev_burden = bur

        # 6) termination / truncation
        terminated = False
        truncated = bool(self._step_i >= int(self.episode_len))

        # 7) info
        info: Dict[str, Any] = {
            "case_id": getattr(self._case, "case_id", None),
            "stim": stim,
            "features": feats,
            "l2_state": st,
            "feedback": feedback,
        }
        if isinstance(l2_info, dict):
            info.update(l2_info)

        return self._obs(), float(reward), terminated, truncated, info



In [None]:
def build_env(
    seed: int = 0,
    episode_len: int = 50,
    forced_tissue_mode: str | None = None,
    freeze_case: bool = False,
    baseline_windows: int = 3,
    reward_cfg: "RewardConfig | None" = None,
):
    """
    Top-level factory for a fully wired DBSGymEnv.

    Must be defined/run in a cell BELOW DBSGymEnv and all config/component classes
    (EnvConfig, DBSBounds, PlantConfig, NetworkConfig, CaseGenerator,
    CaseGeneratorConfig, FullNeuronPlant, EpilepsyStateModel).

    Args:
        seed: seed passed to the case generator, plant, L2 model and env.
        episode_len: steps per episode before truncation.
        forced_tissue_mode: tissue mode for the network config and case sampling.
            None means "BOUNDARY".
        freeze_case: if True, env.reset() reuses the first sampled case.
        baseline_windows: zero-stimulation windows run inside env.reset().
        reward_cfg: optional RewardConfig for the env. None uses the env's default.

    Returns:
        DBSGymEnv with ``net_cfg``, ``episode_len`` and ``freeze_case`` set.
    """
    mode = forced_tissue_mode or "BOUNDARY"

    # ---- Env config ----
    env_cfg = EnvConfig(
        episode_steps=int(episode_len),
        baseline_windows=int(baseline_windows),
        dbs_bounds=DBSBounds(
            amp_mA_min=0.1, amp_mA_max=3.0,
            freq_Hz_min=5.0, freq_Hz_max=150.0,
            pw_ms_min=0.05, pw_ms_max=0.4,
        ),
        include_last_action_in_obs=True,
    )

    # ---- Plant config ----
    plant_cfg = PlantConfig(dt_ms=0.05, window_ms=250.0, sigma_S_per_m=0.2)

    # ---- Network config (passed into plant via env.net_cfg -> reset -> build_from_case) ----
    net_cfg = NetworkConfig(n_cells=60, frac_gm=0.7, tissue_mode=mode)

    # ---- Components ----
    case_gen = CaseGenerator(cfg=CaseGeneratorConfig(), rng_seed=seed)
    plant = FullNeuronPlant(plant_cfg, rng_seed=seed)
    l2_model = EpilepsyStateModel(rng_seed=seed)

    # ---- Sanity checks (guard against stale/shadowed notebook definitions) ----
    assert type(plant).__name__ == "FullNeuronPlant", f"Using {type(plant)} not FullNeuronPlant"
    assert callable(getattr(plant, "run_window", None)), "plant.run_window missing"

    # ---- Env ----
    env = DBSGymEnv(
        case_gen=case_gen,
        plant=plant,
        l2_model=l2_model,
        env_cfg=env_cfg,
        seed=seed,
        reward_cfg=reward_cfg,
        forced_tissue_mode=mode,
    )

    # reset() calls build_from_case(case, net_cfg=getattr(self, "net_cfg", None))
    env.net_cfg = net_cfg

    # Single authoritative episode length (keep the horizon alias in sync)
    env.episode_len = int(episode_len)
    env.horizon_steps = env.episode_len

    env.freeze_case = bool(freeze_case)

    return env



In [None]:
# sac_agent.py
"""
SAC agent wrapper for the DBS-epilepsy project.

This module prioritizes Stable-Baselines3 (SB3) SAC for correctness and speed of iteration.
If SB3 is not installed, it raises a clear error with installation instructions.

Assumptions:
- Your Gymnasium environment is implemented in env.py as DBSGymEnv (or similar).
- Action space is continuous Box, observation space is Box.

Install (recommended):
  pip install "stable-baselines3[extra]" gymnasium torch

Notes:
- SAC-MAML/meta-learning is NOT implemented here; this is a clean SAC baseline.
- This wrapper standardizes configuration, seeding, saving, evaluation rollouts, etc.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import os
import time

import numpy as np

try:
    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import CheckpointCallback
    from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
    from stable_baselines3.common.monitor import Monitor
except Exception as e:  # pragma: no cover
    SAC = None  # type: ignore
    CheckpointCallback = None  # type: ignore
    DummyVecEnv = None  # type: ignore
    VecMonitor = None  # type: ignore
    Monitor = None  # type: ignore
    _SB3_IMPORT_ERROR = e
else:
    _SB3_IMPORT_ERROR = None


@dataclass(frozen=True)
class SACConfig:
    """
    Immutable hyper-parameters for SACAgent.

    Most fields are forwarded verbatim to ``stable_baselines3.SAC``. ``total_timesteps``
    and ``save_every_steps`` are used by ``SACAgent.train()``, and ``log_dir`` is where
    monitor.csv and run directories are written.
    """
    # Training
    total_timesteps: int = 200_000
    learning_rate: float = 3e-4
    buffer_size: int = 200_000
    batch_size: int = 256
    gamma: float = 0.99
    tau: float = 0.005
    train_freq: int = 1
    gradient_steps: int = 1
    learning_starts: int = 5_000  # env steps collected before gradient updates begin

    # Exploration / entropy
    ent_coef: str = "auto"  # or float (SB3 also accepts strings like "auto_0.1")

    # Network architecture: hidden layer sizes, converted to a list for policy_kwargs
    net_arch: Tuple[int, ...] = (256, 256)

    # Evaluation / logging
    seed: int = 0
    log_dir: str = "runs/sac"
    save_every_steps: int = 50_000  # CheckpointCallback save_freq

    # Device ("auto", "cpu", "cuda", ...)
    device: str = "auto"


class SACAgent:
    """
    Thin wrapper around Stable-Baselines3 SAC.

    It standardizes env wrapping/logging, checkpointing, loading and simple
    evaluation rollouts on the raw (non-vectorized) env.
    """

    def __init__(self, env, cfg: Optional[SACConfig] = None):
        """
        Args:
            env: a Gymnasium env (e.g. DBSGymEnv) with Box action/observation spaces.
            cfg: hyper-parameters. None uses ``SACConfig()``.

        Raises:
            RuntimeError: if Stable-Baselines3 could not be imported.
        """
        if _SB3_IMPORT_ERROR is not None:
            raise RuntimeError(
                "Stable-Baselines3 is not installed or failed to import.\n\n"
                "Install with:\n"
                "  pip install \"stable-baselines3[extra]\" gymnasium torch\n\n"
                f"Import error: {_SB3_IMPORT_ERROR}"
            )

        self.env = env
        self.cfg = cfg or SACConfig()

        os.makedirs(self.cfg.log_dir, exist_ok=True)

        # VecMonitor records episode returns/lengths at the vec-env level and writes
        # monitor.csv. The raw env is deliberately NOT also wrapped in Monitor: when
        # both are present SB3's VecMonitor warns and overwrites the Monitor statistics.
        def _make():
            return self.env

        venv = DummyVecEnv([_make])
        venv = VecMonitor(venv, filename=os.path.join(self.cfg.log_dir, "monitor.csv"))
        self.venv = venv

        policy_kwargs = dict(net_arch=list(self.cfg.net_arch))

        self.model = SAC(
            policy="MlpPolicy",
            env=self.venv,
            learning_rate=self.cfg.learning_rate,
            buffer_size=self.cfg.buffer_size,
            batch_size=self.cfg.batch_size,
            tau=self.cfg.tau,
            gamma=self.cfg.gamma,
            train_freq=self.cfg.train_freq,
            gradient_steps=self.cfg.gradient_steps,
            learning_starts=self.cfg.learning_starts,
            ent_coef=self.cfg.ent_coef,
            policy_kwargs=policy_kwargs,
            verbose=1,
            device=self.cfg.device,
            seed=self.cfg.seed,
        )

    def train(self) -> str:
        """Train SAC with periodic checkpoints and return the final model path."""
        ts = int(time.time())
        run_name = f"sac_{ts}"
        run_dir = os.path.join(self.cfg.log_dir, run_name)
        os.makedirs(run_dir, exist_ok=True)

        cb = CheckpointCallback(
            save_freq=self.cfg.save_every_steps,
            save_path=run_dir,
            name_prefix="checkpoint",
            save_replay_buffer=True,
            save_vecnormalize=False,
        )

        self.model.learn(total_timesteps=self.cfg.total_timesteps, callback=cb)

        final_path = os.path.join(run_dir, "final_model.zip")
        self.model.save(final_path)
        return final_path

    def save(self, path: str) -> None:
        """Save the underlying SB3 model to ``path``."""
        self.model.save(path)

    @staticmethod
    def load(path: str, env) -> "SACAgent":
        """
        Load a trained SAC model and attach it to ``env``.

        The returned agent has a default ``SACConfig`` and ``venv = None``.

        Raises:
            RuntimeError: if Stable-Baselines3 could not be imported.
        """
        if _SB3_IMPORT_ERROR is not None:
            raise RuntimeError("SB3 not available.") from _SB3_IMPORT_ERROR

        agent = SACAgent.__new__(SACAgent)
        agent.env = env
        agent.cfg = SACConfig()
        agent.model = SAC.load(path, env=env)
        agent.venv = None
        return agent

    def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
        """Single-step action for a raw (non-vec) observation."""
        action, _ = self.model.predict(obs, deterministic=deterministic)
        return action

    def rollout(self, n_episodes: int = 5, deterministic: bool = True) -> Dict[str, float]:
        """
        Evaluate on the underlying raw env (not vectorized).

        Returns:
            {"return_mean", "return_std", "len_mean"}; all 0.0 when n_episodes <= 0.
        """
        returns = []
        lengths = []
        for _ in range(n_episodes):
            obs, _ = self.env.reset()
            done = False
            ep_ret = 0.0
            ep_len = 0
            while not done:
                action = self.act(obs, deterministic=deterministic)
                obs, reward, terminated, truncated, _info = self.env.step(action)
                done = bool(terminated or truncated)
                ep_ret += float(reward)
                ep_len += 1
            returns.append(ep_ret)
            lengths.append(ep_len)
        return {
            "return_mean": float(np.mean(returns)) if returns else 0.0,
            "return_std": float(np.std(returns)) if returns else 0.0,
            "len_mean": float(np.mean(lengths)) if lengths else 0.0,
        }


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


In [None]:
def main():
    """
    Smoke test: wire up all components directly, reset once, then step with random
    actions, printing reward, burden, stim and features per step.
    """
    # ---- configs ----
    plant_cfg = PlantConfig(dt_ms=0.05, window_ms=250.0, sigma_S_per_m=0.2)
    net_cfg = NetworkConfig(n_cells=60, frac_gm=0.7, tissue_mode="BOUNDARY")

    bounds = DBSBounds(
        amp_mA_min=0.0, amp_mA_max=3.0,
        freq_Hz_min=5.0, freq_Hz_max=150.0,
        pw_ms_min=0.05, pw_ms_max=0.4,
    )
    env_cfg = EnvConfig(
        episode_steps=20,
        baseline_windows=1,
        dbs_bounds=bounds,
        include_last_action_in_obs=True,
    )

    # ---- runtime objects (all seeded with 0) ----
    seed = 0
    case_gen = CaseGenerator(cfg=CaseGeneratorConfig(), rng_seed=seed)
    plant = FullNeuronPlant(cfg=plant_cfg, rng_seed=seed)
    l2_model = EpilepsyStateModel(rng_seed=seed)

    env = DBSGymEnv(
        case_gen=case_gen,
        plant=plant,
        l2_model=l2_model,
        env_cfg=env_cfg,
        seed=seed,
        forced_tissue_mode="BOUNDARY",
    )
    # reset() forwards this to plant.build_from_case()
    env.net_cfg = net_cfg

    # ---- smoke test ----
    obs, _ = env.reset()
    print("Reset obs shape:", obs.shape)

    for step_idx in range(env_cfg.episode_steps):
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        print(
            f"t={step_idx:02d} r={reward:.3f} burden={info['l2_state'].burden:.3f} "
            f"stim={info['stim']} feats={info['features']}"
        )
        if terminated or truncated:
            break


import inspect

# In a notebook __name__ == "__main__", so running this cell executes the smoke test.
# (The leftover debug print of EpilepsyStateModel.__init__'s signature was removed.)
if __name__ == "__main__":
    main()


(self, cfg: 'Optional[L2Config]' = None, rng_seed: 'int' = 0)
[onset] primary_sites: (5, 3) secondary_sites: (8, 3)
[onset] primary drives: 5 secondary drives: 8
[record] n=8 thr=-20.0 E=4 I=4


KeyError: 'feedback'

In [None]:
import numpy as np
import copy

def _deepcopy_l2_state(l2):
    """
    Robust snapshot/restore for your L2 model state.
    L2State contains numpy arrays; copy.deepcopy is fine here.
    """
    return copy.deepcopy(l2.state)

def _restore_l2_state(l2, state_snapshot):
    l2.state = copy.deepcopy(state_snapshot)

def rollout_fixed_stim(env, stim_dict, steps=40, *, reset_case=False, seed=0):
    """
    Run ``steps`` plant/L2 windows with one fixed stimulation, bypassing env.step().

    Mirrors env.step()'s L1 -> L2 -> L1 loop: run_window, l2.step, then
    update_focus_drive. Non-dict feedback is replaced by {}, as in env.step().
    No reward or step counter is updated.

    Args:
        env: a DBSGymEnv-like object exposing plant, l2 and _case.
        stim_dict: stimulation dict. It is copied once, so env._last_stim is not
            aliased to the caller's dict.
        steps: number of windows to run.
        reset_case: if True, call env.reset(seed=seed) first. Whether that resamples the
            case depends on env.freeze_case. For Fix A, keep False to stay on the same case.
        seed: seed for the optional reset.

    Returns:
        dict of 1-D float arrays: "burden", "E_rate_hz", "I_rate_hz", "sync_E", "n_spikes".
    """
    if reset_case:
        env.reset(seed=seed)

    stim = dict(stim_dict)
    traces = {"burden": [], "E_rate_hz": [], "I_rate_hz": [], "sync_E": [], "n_spikes": []}

    for _ in range(steps):
        feats = env.plant.run_window(stim)
        env._last_features = feats
        env._last_stim = stim

        st, feedback, _l2_info = env.l2.step(features=feats, stim=stim, case=env._case)

        # Keep the L2 -> L1 coupling consistent with env.step() (non-dict -> {})
        env.plant.update_focus_drive(feedback if isinstance(feedback, dict) else {})

        traces["burden"].append(float(getattr(st, "burden", np.nan)))
        traces["E_rate_hz"].append(float(feats.get("E_rate_hz", feats.get("rate_hz", 0.0))))
        traces["I_rate_hz"].append(float(feats.get("I_rate_hz", 0.0)))
        traces["sync_E"].append(float(feats.get("sync_E", feats.get("sync", 0.0))))
        traces["n_spikes"].append(float(feats.get("n_spikes", 0.0)))

    return {key: np.asarray(vals, dtype=float) for key, vals in traces.items()}

def compare_off_on_same_case(
    seed=0,
    episode_len=50,
    forced_tissue_mode="BOUNDARY",
    steps=40,
    stim_on=None,
):
    """
    Fix A: OFF vs ON comparison on the SAME case, from the same starting state.

    Steps:
    - Builds the env and resets once (samples one case).
    - Snapshots the L2 state (after the baseline windows reset() ran).
    - For each of OFF (amp=0) and ON:
        * resets again with the case frozen, so the plant is rebuilt on the same case
          and the plant state left by the previous rollout does not leak into the next,
        * restores the L2 snapshot and re-syncs the plant focus drive with it,
        * runs rollout_fixed_stim().
    - Prints summaries and deltas (ON - OFF).

    NOTE(review): the plant's internal RNG is not reseeded between the two rebuilds,
    so some stochastic differences may remain. Confirm whether that matters.

    Returns:
        (env, off, on), where off/on are rollout_fixed_stim() result dicts.
        env.freeze_case is restored to its previous value before returning.
    """
    env = build_env(seed=seed, episode_len=episode_len, forced_tissue_mode=forced_tissue_mode)
    env.reset(seed=seed)

    # Default ON stimulation (reasonable DBS-like starting point)
    if stim_on is None:
        stim_on = {
            "amp_mA": 1.5,
            "freq_Hz": 130.0,
            "pw_ms": 0.15,
            "waveform": "rect",
            "duty_cycle": 1.0,
        }

    stim_off = {
        "amp_mA": 0.0,
        "freq_Hz": float(stim_on.get("freq_Hz", 130.0)),
        "pw_ms": float(stim_on.get("pw_ms", 0.15)),
        "waveform": str(stim_on.get("waveform", "rect")),
        "duty_cycle": float(stim_on.get("duty_cycle", 1.0)),
    }

    # Snapshot L2 state after reset (and after the baseline windows env.reset() ran)
    snap = _deepcopy_l2_state(env.l2)

    def _rollout_from_snapshot(stim):
        # Re-reset with the case frozen: same case, freshly rebuilt plant.
        prev_freeze = getattr(env, "freeze_case", False)
        env.freeze_case = True
        try:
            env.reset(seed=seed)
        finally:
            env.freeze_case = prev_freeze

        # Rewind L2 and push a matching feedback dict (get_drive() + "burden") to the plant
        _restore_l2_state(env.l2, snap)
        get_drive = getattr(env.l2, "get_drive", None)
        if callable(get_drive):
            feedback = get_drive()
            if isinstance(feedback, dict):
                feedback = dict(feedback)
                feedback["burden"] = float(getattr(env.l2.state, "burden", 0.0))
                env.plant.update_focus_drive(feedback)

        return rollout_fixed_stim(env, stim, steps=steps, reset_case=False, seed=seed)

    off = _rollout_from_snapshot(stim_off)
    on = _rollout_from_snapshot(stim_on)

    def _mean(a):
        # nan (without a RuntimeWarning) when the rollout ran zero steps
        return float(np.mean(a)) if a.size else float("nan")

    def summarize(tag, d):
        return {
            "tag": tag,
            "burden_end": float(d["burden"][-1]) if d["burden"].size else float("nan"),
            "burden_mean": _mean(d["burden"]),
            "E_rate_mean": _mean(d["E_rate_hz"]),
            "I_rate_mean": _mean(d["I_rate_hz"]),
            "syncE_mean": _mean(d["sync_E"]),
            "spikes_mean": _mean(d["n_spikes"]),
        }

    s_off = summarize("OFF", off)
    s_on = summarize("ON", on)

    print("Case:", getattr(env._case, "case_id", None), "mode:", forced_tissue_mode)
    print("OFF summary:", s_off)
    print("ON  summary:", s_on)

    # Deltas (ON - OFF): negative is suppressive for burden/E_rate/sync/spikes
    print("\nDeltas (ON - OFF):")
    for key in ("burden_end", "burden_mean", "E_rate_mean", "I_rate_mean", "syncE_mean", "spikes_mean"):
        print(f"  {key}:", s_on[key] - s_off[key])

    return env, off, on


In [None]:
# Fix A experiment: OFF vs a 1.5 mA / 130 Hz / 0.15 ms rect stimulus on one BOUNDARY case.
env, off, on = compare_off_on_same_case(
    seed=0,
    forced_tissue_mode="BOUNDARY",
    steps=40,
    stim_on={"amp_mA": 1.5, "freq_Hz": 130.0, "pw_ms": 0.15, "waveform": "rect", "duty_cycle": 1.0},
)


In [None]:
# Diagnostic: does FullNeuronPlant define run_window itself, or inherit it?
# The MRO shows which base classes could supply an inherited implementation.
print("Has FullNeuronPlant.run_window?", "run_window" in FullNeuronPlant.__dict__)
print("FullNeuronPlant.run_window is:", FullNeuronPlant.__dict__.get("run_window", None))
print("MRO:", [c.__name__ for c in FullNeuronPlant.mro()])


In [None]:
# Diagnostic: __qualname__ names the class whose run_window is actually resolved
# (e.g. "BasePlant.run_window" if it is inherited).
print("Has FullNeuronPlant.run_window?", "run_window" in FullNeuronPlant.__dict__)
print("Bound method comes from:", FullNeuronPlant.run_window.__qualname__)


In [None]:
# Pre-flight: confirm build_env exists. If it does not, re-run the cells that define
# DBSGymEnv, the config classes and build_env before continuing.
print("build_env defined?", "build_env" in globals())

# List globals whose names mention "build" or "env" to spot near-matches or stale names.
print([k for k in globals().keys() if "build" in k.lower() or "env" in k.lower()])



In [None]:
import numpy as np

def _rollout_window_ms(env, default_ms):
    """Return ``plant.cfg.window_ms`` from ``env`` or ``env.unwrapped``, else ``default_ms``."""
    for obj in (env, getattr(env, "unwrapped", None)):
        plant = getattr(obj, "plant", None)
        if plant is not None and hasattr(plant, "cfg"):
            return float(getattr(plant.cfg, "window_ms", default_ms))
    return float(default_ms)


def _rollout_step_burden(env, info):
    """Return this step's burden as a float, or None if unavailable or non-numeric.

    Lookup order: ``env.l2.state.burden``, then ``env.l2_state.burden``,
    then ``info["burden"]``.
    """
    l2_state = getattr(getattr(env, "l2", None), "state", None)
    if hasattr(l2_state, "burden"):
        raw = l2_state.burden
    elif hasattr(getattr(env, "l2_state", None), "burden"):
        raw = env.l2_state.burden
    elif isinstance(info, dict) and "burden" in info:
        raw = info["burden"]
    else:
        return None
    try:
        return float(raw)
    except (TypeError, ValueError):
        return None


def rollout_burden(env, steps=50, seed=0, default_window_ms=250.0):
    """Roll out random actions and collect per-step burden and spike counts.

    Args:
        env: Gymnasium-style env (``reset``/``step``/``action_space.sample``).
        steps: Maximum number of steps; stops early on ``terminated``/``truncated``.
        seed: Seed passed to ``env.reset``.
        default_window_ms: Window length used to turn ``features["rate_hz"]`` into
            a per-window spike count when no ``plant.cfg.window_ms`` is found on
            ``env`` or ``env.unwrapped``. Defaults to 250 ms, matching the
            ``PlantConfig.window_ms`` default.

    Returns:
        ``(burdens, spike_counts)`` as 1-D numpy arrays of equal length.

    Raises:
        RuntimeError: if a step yields no finite burden value.
    """
    env.reset(seed=seed)
    burdens, spike_counts = [], []

    window_s = _rollout_window_ms(env, default_window_ms) / 1000.0

    for t in range(steps):
        _, _, term, trunc, info = env.step(env.action_space.sample())

        b_val = _rollout_step_burden(env, info)
        if b_val is None or not np.isfinite(b_val):
            # Describe info without assuming it is a dict (avoids masking the real error).
            keys = list(info.keys()) if isinstance(info, dict) else type(info).__name__
            raise RuntimeError(f"Burden missing/invalid at step {t}. info keys={keys}")
        burdens.append(b_val)

        feats = (info.get("features") if isinstance(info, dict) else None) or {}
        if "n_spikes" in feats:
            spike_counts.append(float(feats["n_spikes"]))
        else:
            # No explicit count: approximate spikes in this window from the rate.
            rate_hz = float(feats.get("rate_hz", 0.0) or 0.0)
            spike_counts.append(rate_hz * window_s)

        if term or trunc:
            break

    return np.array(burdens), np.array(spike_counts)

# Random-action rollout on a fresh BOUNDARY env; report burden and spike-count ranges.
env = build_env(
    seed=0,
    episode_len=50,
    forced_tissue_mode="BOUNDARY",
)
b, s = rollout_burden(env, steps=50, seed=0)

for label, arr in (("Burden min/max:", b), ("SpikeCount(win) min/max:", s)):
    print(label, np.nanmin(arr), np.nanmax(arr))



In [None]:
def rollout_fixed_stim(env_builder, steps=100, seed=0, mode="BOUNDARY", case_id=None, stim_on=None):
    """
    Run the same case/seed twice -- stimulation OFF vs ON -- and summarize both.

    Args:
        env_builder: Callable like ``build_env(seed=..., episode_len=..., forced_tissue_mode=...)``
            returning a fresh env. If its signature declares a ``forced_case_id``
            parameter, both envs are pinned to the same case.
        steps: Max steps per run (also passed as ``episode_len``).
        seed: Seed used for both builds and both resets.
        mode: Tissue mode passed as ``forced_tissue_mode``.
        case_id: Case to reuse. If None, it is read from the OFF env after reset
            (``env.case.case_id`` / ``env._case.case_id``, else ``info["case_id"]``).
        stim_on: Stim dict used as the action for the ON run; OFF uses a copy with
            ``amp_mA=0``. Defaults to 1 mA / 130 Hz / 0.1 ms rect, duty 1.0.

    Returns:
        dict with ``case_id``, ``OFF``/``ON`` summaries (``burden_end``, ``burden_mean``,
        ``spikes_mean``; NaN when a run produced no steps) and the raw arrays
        ``b_off``, ``b_on``, ``s_off``, ``s_on``.

    Raises:
        KeyError: if a step's info dict has no ``"burden"`` entry (L2 must set it).
    """
    import inspect

    def _accepts_forced_case_id(builder):
        # inspect.signature works for functools.partial and other callables and lists
        # only real parameters (``__code__.co_varnames`` also includes locals).
        try:
            return "forced_case_id" in inspect.signature(builder).parameters
        except (TypeError, ValueError):
            return False

    can_force_case = _accepts_forced_case_id(env_builder)

    def _build(cid):
        kwargs = dict(seed=seed, episode_len=steps, forced_tissue_mode=mode)
        if can_force_case:
            kwargs["forced_case_id"] = cid
        return env_builder(**kwargs)

    # OFF env: pinned only if the caller gave an explicit case_id; otherwise its case
    # is captured after reset and reused for the ON env.
    env_off = _build(case_id) if case_id is not None else env_builder(
        seed=seed, episode_len=steps, forced_tissue_mode=mode
    )
    _, info_off = env_off.reset(seed=seed)
    info_off = info_off if isinstance(info_off, dict) else {}

    if case_id is None:
        case_obj = getattr(env_off, "case", None) or getattr(env_off, "_case", None)
        case_id = getattr(case_obj, "case_id", None)
        if case_id is None:
            case_id = info_off.get("case_id", None)

    # Without forced_case_id support the ON env may draw a different case;
    # build_env/build_case must accept a fixed case_id for a true paired comparison.
    env_on = _build(case_id) if can_force_case else env_builder(
        seed=seed, episode_len=steps, forced_tissue_mode=mode
    )
    env_on.reset(seed=seed)

    if stim_on is None:
        stim_on = {"amp_mA": 1.0, "freq_Hz": 130.0, "pw_ms": 0.1, "waveform": "rect", "duty_cycle": 1.0}

    stim_off = dict(stim_on)
    stim_off["amp_mA"] = 0.0

    def _run(env, stim):
        burdens, spikes = [], []
        for t in range(steps):
            _, _, term, trunc, info = env.step(stim)
            if "burden" not in info:
                raise KeyError(f"'burden' missing from info at step {t}; keys={list(info.keys())}")
            burdens.append(float(info["burden"]))
            feats = info.get("features", {}) or {}
            spikes.append(float(feats.get("n_spikes", 0.0)))
            if term or trunc:
                break
        return np.asarray(burdens, dtype=float), np.asarray(spikes, dtype=float)

    def _summary(burdens, spikes):
        if burdens.size == 0:
            nan = float("nan")
            return {"burden_end": nan, "burden_mean": nan, "spikes_mean": nan}
        return {
            "burden_end": float(burdens[-1]),
            "burden_mean": float(np.mean(burdens)),
            "spikes_mean": float(np.mean(spikes)),
        }

    b_off, s_off = _run(env_off, stim_off)
    b_on, s_on = _run(env_on, stim_on)

    return {
        "case_id": case_id,
        "OFF": _summary(b_off, s_off),
        "ON": _summary(b_on, s_on),
        "b_off": b_off, "b_on": b_on, "s_off": s_off, "s_on": s_on,
    }


In [None]:
import numpy as np
from neuron import h

def quick_spike_test(
    sec,
    amp_nA=0.2,
    delay_ms=5.0,
    dur_ms=5.0,
    tstop_ms=30.0,
    v_init=-65.0,
    spike_threshold_mV=0.0,
):
    """
    Inject a brief current pulse into a NEURON Section and report the Vm excursion.

    Also reports whether Vm reached ``spike_threshold_mV`` at any recorded sample.
    This is a simple threshold test, not a full spike detector.

    Side effects to keep in mind when running this on a live plant:
      * the ``hh`` mechanism is inserted into ``sec`` and stays inserted afterwards;
      * ``h.finitialize`` / ``h.continuerun`` re-initialize and advance the global
        NEURON simulation, not only ``sec``.

    Args:
        sec: NEURON Section to stimulate and record at ``sec(0.5)``.
        amp_nA: IClamp amplitude (nA).
        delay_ms, dur_ms: IClamp onset and duration (ms).
        tstop_ms: Simulation end time (ms).
        v_init: Initial membrane potential for ``h.finitialize`` (mV).
        spike_threshold_mV: Threshold used for the spike-like check (mV).

    Returns:
        ``(t_ms, v_mV)`` numpy arrays of recorded time and membrane potential.
    """
    # continuerun() comes from NEURON's standard run library. load_file() skips files
    # that are already loaded, so this is safe to call repeatedly.
    h.load_file("stdrun.hoc")

    # Ensure active excitability (HH). Best-effort: report a failure instead of hiding it.
    try:
        sec.insert("hh")
    except Exception as exc:
        print(f"quick_spike_test: could not insert 'hh' into {sec}: {exc}")

    stim = h.IClamp(sec(0.5))
    stim.delay = float(delay_ms)
    stim.dur   = float(dur_ms)
    stim.amp   = float(amp_nA)   # nA

    v = h.Vector(); v.record(sec(0.5)._ref_v)
    t = h.Vector(); t.record(h._ref_t)

    h.finitialize(float(v_init))
    h.continuerun(float(tstop_ms))

    t_np = np.asarray(t, dtype=float)
    v_np = np.asarray(v, dtype=float)

    vmax = float(np.max(v_np)) if v_np.size else float("nan")
    vmin = float(np.min(v_np)) if v_np.size else float("nan")
    spiked = bool(np.any(v_np >= float(spike_threshold_mV)))

    print(
        f"Vm min/max: {vmin:.3f} / {vmax:.3f} mV "
        f"(IClamp amp={amp_nA} nA) | spike>= {spike_threshold_mV} mV: {spiked}"
    )
    return t_np, v_np


def pick_test_section(env):
    """
    Robustly find a reasonable GM soma section to test, across evolving plant code.

    The plant is looked up as ``env.plant`` and, if that is absent, as
    ``env.unwrapped.plant``, because wrapped envs may not forward attributes.

    Priority:
      1) First entry of a recorded-section list on the plant
         (``_record_secs`` / ``record_secs`` / ``_recorded_secs``), if non-empty
      2) ``plant.gm_secs[0]``
      3) First section in ``h.allsec()`` whose name starts with ``gm_soma_``
      4) Fallback: first NEURON section in the model

    Returns:
        ``(section, description)``, where description names the rule that matched.

    Raises:
        RuntimeError: if no plant rule matched and NEURON has no sections.
    """
    plant, prefix = getattr(env, "plant", None), "env.plant"
    if plant is None:
        plant = getattr(getattr(env, "unwrapped", None), "plant", None)
        prefix = "env.unwrapped.plant"

    # 1) + 2): explicit recorded sections first, then the GM section list.
    if plant is not None:
        for attr in ("_record_secs", "record_secs", "_recorded_secs", "gm_secs"):
            secs = getattr(plant, attr, None)
            if isinstance(secs, (list, tuple)) and len(secs) > 0:
                return secs[0], f"{prefix}.{attr}[0]"

    # 3) Try to find a gm_soma_* section among all sections.
    all_secs = list(h.allsec())
    for sec in all_secs:
        try:
            if str(sec.name()).startswith("gm_soma_"):
                return sec, "first h.allsec() with name gm_soma_*"
        except Exception:
            # Unnamed or invalid section: skip it.
            continue

    # 4) Final fallback: first section in the model.
    if not all_secs:
        raise RuntimeError("No NEURON sections exist. Build the plant first (call env.reset()).")
    return all_secs[0], "h.allsec()[0]"


# --------- usage ----------
# The NEURON model must already exist (e.g. after `env.reset()`), otherwise
# pick_test_section may find no sections.
sec0, how = pick_test_section(env)
print("Testing section: {} (picked via {})".format(sec0.name(), how))

t_ms, v_mV = quick_spike_test(sec0, amp_nA=0.2, spike_threshold_mV=0.0)



In [None]:
# Smoke test: one random step after a seeded reset.
obs, info = env.reset(seed=0)
sampled_action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(sampled_action)


In [None]:
# The reward from the smoke-test step must be a finite number.
assert np.isfinite(reward), f"non-finite reward from env.step: {reward!r}"


In [None]:
import time

N_TIMING_STEPS = 20
# Projected run length; keep in sync with total_timesteps in the SAC training cell.
N_TRAIN_STEPS = 5_000

obs, info = env.reset(seed=0)

# perf_counter is the monotonic, high-resolution clock intended for benchmarking.
t0 = time.perf_counter()
for _ in range(N_TIMING_STEPS):
    obs, r, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        # Mirror training: a finished episode is reset (reset cost is included).
        obs, info = env.reset()
t1 = time.perf_counter()

step_time = (t1 - t0) / N_TIMING_STEPS
print(f"env.step() ≈ {step_time*1000:.1f} ms")

print(f"Estimated time for {N_TRAIN_STEPS:,} steps:",
      f"{step_time * N_TRAIN_STEPS / 60:.1f} minutes")


In [None]:
# --- SAC training cell with:
# 1) TensorBoard logging (reward + SAC losses)
# 2) Custom logging of burden/spikes/E_rate/sync_E via SB3 logger
# 3) monitor.csv episode reward logging (already handled by Monitor/VecMonitor)
# 4) Optional: quick post-train evaluation rollout

# DO NOT import sac_agent – it's already defined in this notebook

from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback

# ---------- Custom callback to log burden + neural features ----------
class DBSMetricsCallback(BaseCallback):
    """SB3 callback that logs L2 burden, neural features and stim parameters.

    Values are read from every env's ``info`` dict in ``self.locals["infos"]`` and
    recorded with ``logger.record_mean``. The value written at each log dump is
    therefore the mean over all steps (and envs) since the previous dump. With
    ``logger.record`` it would be only the last step's sample. ``None`` and
    non-finite values are skipped so a single NaN cannot poison the interval mean.
    """

    FEATURE_KEYS = ("n_spikes", "rate_hz", "E_rate_hz", "I_rate_hz", "sync_E", "burst_E", "logEI")
    STIM_KEYS = ("amp_mA", "freq_Hz", "pw_ms", "duty_cycle")

    def __init__(self, verbose: int = 0):
        super().__init__(verbose=verbose)

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", None)
        if not infos:
            return True
        if isinstance(infos, dict):
            infos = [infos]
        elif not isinstance(infos, (list, tuple)):
            return True

        for info in infos:
            if isinstance(info, dict):
                self._log_info(info)
        return True

    def _log_info(self, info: dict) -> None:
        """Record burden, features and stim parameters from one env's info dict."""
        # Level-2 metric
        self._record("custom/burden", info.get("burden"))

        # Neural features
        feats = info.get("features", {}) or {}
        for k in self.FEATURE_KEYS:
            self._record(f"custom/{k}", feats.get(k))

        # Stimulation parameters (optional)
        stim = info.get("stim", None)
        if isinstance(stim, dict):
            for k in self.STIM_KEYS:
                self._record(f"custom/stim_{k}", stim.get(k))

    def _record(self, key: str, value) -> None:
        """Accumulate ``value`` into the running mean for ``key`` if it is a finite number."""
        if value is None:
            return
        value = float(value)
        if np.isfinite(value):
            self.logger.record_mean(key, value)


# ---------- Build environment ----------
env = build_env(
    seed=0,
    episode_len=100,
    forced_tissue_mode="BOUNDARY",
    freeze_case=True,  # NOTE(review): presumably reuses one case across resets; confirm in build_env
)

# ---------- SAC configuration ----------
TOTAL_TIMESTEPS = 5_000
# CheckpointCallback saves every `save_freq` calls; keep it below the run length,
# otherwise no intermediate checkpoint is ever written.
SAVE_EVERY_STEPS = max(1, TOTAL_TIMESTEPS // 5)

cfg = SACConfig(
    total_timesteps=TOTAL_TIMESTEPS,
    learning_rate=3e-4,
    buffer_size=200_000,
    batch_size=256,
    gamma=0.99,
    tau=0.005,
    learning_starts=250,
    train_freq=1,
    gradient_steps=1,
    seed=0,
    log_dir="runs/sac_boundary",
    save_every_steps=SAVE_EVERY_STEPS,
)

# ---------- Create agent ----------
agent = SACAgent(env, cfg)

# ---------- Train with callbacks ----------
checkpoint_cb = CheckpointCallback(
    save_freq=cfg.save_every_steps,
    save_path=cfg.log_dir,
    name_prefix="checkpoint",
    save_replay_buffer=True,
    save_vecnormalize=False,
)

metrics_cb = DBSMetricsCallback()

agent.model.learn(
    total_timesteps=cfg.total_timesteps,
    callback=[checkpoint_cb, metrics_cb],
)

# ---------- Save final model ----------
final_model_path = os.path.join(cfg.log_dir, "final_model.zip")
agent.model.save(final_model_path)
print("Saved model to:", final_model_path)

# ---------- Quick evaluation ----------
print("Eval:", agent.rollout(n_episodes=3, deterministic=True))


In [None]:
# Inspect what reset/step expose in `info`, including both places burden may live.
obs, info = env.reset(seed=0)
print("reset info:", info)

action = env.action_space.sample()
obs, r, terminated, truncated, info = env.step(action)

print("reward:", r, "done:", terminated or truncated)
print("info keys:", list(info.keys()))
print("info['features']:", info.get("features", None))
print("burden (info['burden']):", info.get("burden", None))

# l2_state may be absent, a dict, or an object with a .burden attribute.
l2_state = info.get("l2_state", None)
if isinstance(l2_state, dict):
    l2_burden = l2_state.get("burden", None)
else:
    l2_burden = getattr(l2_state, "burden", None)
print("burden (info['l2_state']['burden']):", l2_burden)



In [None]:
%load_ext tensorboard
%tensorboard --logdir runs
