In [None]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal

import collections
import random
import time
from tqdm import tqdm


In [None]:
# Global run configuration: a single seed shared by every RNG used below.
SEED = 42
DEVICE = torch.device("cpu")  # tensors are created on the CPU

# Seed the stdlib, NumPy (legacy global state) and PyTorch generators.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
import numpy as np


class EpilepticNeuralMassJR:
    """
    Jansen–Rit-style neural mass model with an explicit DBS input.

    Populations:
        - P: pyramidal cells
        - E: excitatory interneurons
        - I: inhibitory interneurons

    State vector y = [y1, y2, y3, y4, y5, y6]
        y1: P average membrane potential
        y2: E average membrane potential
        y3: I average membrane potential
        y4: d/dt y1
        y5: d/dt y2
        y6: d/dt y3

    Dynamics:
        Second-order PSP operators for each population, driven by sigmoidal
        firing from the others and external input p(t) + DBS(t).

    DBS coupling:
        - DBS current enters as an additive term to the pyramidal input
          (you can change this if you prefer to target E or I).

    All parameters are exposed so you can push the model into
    more “epileptogenic” regimes if needed.
    """

    def __init__(
        self,
        A=3.25,      # mV, excitatory gain
        B=22.0,      # mV, inhibitory gain
        a=100.0,     # s^-1, excitatory inverse time constant
        b=50.0,      # s^-1, inhibitory inverse time constant
        C=135.0,     # average connectivity
        e0=2.5,      # max firing rate (s^-1)
        v0=6.0,      # sigmoidal threshold (mV)
        r=0.56,      # sigmoid slope (mV^-1)
        p_mean=120.0,  # baseline external input to P (s^-1)
        p_std=30.0,    # noise on external input
        epileptogenic_boost=1.3,  # scale on excitatory connectivity
        rng=None,
    ):
        """
        Store the model parameters and derive the connectivity constants.

        Parameters
        ----------
        A, B : excitatory / inhibitory synaptic gains.
        a, b : excitatory / inhibitory inverse time constants.
        C : base connectivity from which C1..C4 are derived.
        e0, v0, r : parameters of the sigmoid ``S``.
        p_mean, p_std : mean and standard deviation of the noisy external drive.
        epileptogenic_boost : factor applied to C1 and C2 only.
        rng : numpy Generator or None; a fresh ``default_rng()`` is used if None.
        """
        # Base connectivity is kept separately so that a boost can be
        # re-applied later without compounding.
        self._C_base = float(C)

        # Endpoints interpolated by set_disease_level().
        self.p_mean_normal = 90.0
        self.p_mean_ictal = 130.0
        self.boost_normal = 1.0
        self.boost_ictal = 1.3

        # Connectivity constants C1..C4 from the initial boost.
        self._set_connectivity(float(epileptogenic_boost))

        self.A = A
        self.B = B
        self.a = a
        self.b = b

        self.e0 = e0
        self.v0 = v0
        self.r = r

        # External drive statistics (assigned once; stored as floats).
        self.p_mean = float(p_mean)
        self.p_std = float(p_std)

        self.rng = np.random.default_rng() if rng is None else rng

        # --- Slow disease / excitability state --------------------------------
        # disease_level in [0,1]: 0 = normal, 1 = strongly epileptic
        self.disease_level = 0.0

        # how fast the system tends to get worse (per RL step) if untreated
        self.disease_drift = 0.01     # tune: 0.005–0.02

        # how strongly seizure suppression pushes back toward normal
        self.disease_control_gain = 0.1  # tune: 0.05–0.2

        # Background-drive endpoints used by _apply_jr_params_from_disease().
        self.jr_params_normal = {"p_drive": 90.0}
        self.jr_params_ictal = {"p_drive": 130.0}

    def _set_connectivity(self, boost: float) -> None:
        """Reset connectivity constants from base C and a boost factor (no accumulation)."""
        C = self._C_base
        boost = float(boost)

        # Only the first two constants are scaled by the boost.
        self.C1 = C * boost
        self.C2 = 0.8 * C * boost
        self.C3 = 0.25 * C
        self.C4 = 0.25 * C

    def set_disease_level(self, d: float) -> None:
        """
        d in [0,1]. 0 = healthy, 1 = epileptogenic.
        Maps disease to background drive + excitatory connectivity boost.

        Values outside [0, 1] are clipped. Note that this does not write
        ``self.disease_level``; it only updates ``p_mean`` and C1/C2.
        """
        d = float(np.clip(d, 0.0, 1.0))

        self.p_mean = (1.0 - d) * self.p_mean_normal + d * self.p_mean_ictal
        boost = (1.0 - d) * self.boost_normal + d * self.boost_ictal
        self._set_connectivity(boost)

    def _normalize_seizure_raw_safe(self, raw: float) -> float:
        """
        Map a raw seizure metric to [0, 1) via z / (1 + z), z = max(0, (raw - baseline) / scale).

        Returns 0.0 when ``seizure_baseline`` or ``seizure_scale`` is missing or
        None. This class never sets those attributes itself, so they must be
        assigned externally for the result to be non-zero.
        """
        # allow reset/calibration to run before baseline/scale are set
        if getattr(self, "seizure_baseline", None) is None or getattr(self, "seizure_scale", None) is None:
            return 0.0

        base = float(self.seizure_baseline)
        s = max(float(self.seizure_scale), 1e-3)  # floor avoids division by ~0

        z = (float(raw) - base) / s
        z = max(0.0, z)
        idx = z / (1.0 + z)
        return float(np.clip(idx, 0.0, 1.0))

    def S(self, v):
        """
        Sigmoidal firing-rate function: 2*e0 / (1 + exp(r * (v0 - v))).
        """
        return 2.0 * self.e0 / (1.0 + np.exp(self.r * (self.v0 - v)))

    def _deriv(self, y, t, dbs_val):
        """
        Compute time derivative dy/dt for state y at time t given DBS amplitude.

        ``t`` is accepted for interface symmetry but is not used. Every call
        draws one fresh Gaussian sample from ``self.rng`` for the external drive.

        Returns a length-6 array ordered like the state vector.
        """
        y1, y2, y3, y4, y5, y6 = y

        # External input p(t): noisy drive, re-sampled on every call.
        # NOTE(review): because the sample is redrawn per integration step and
        # then multiplied by dt by the caller, the effective noise intensity
        # depends on dt — confirm this is intended.
        p_t = self.p_mean + self.p_std * self.rng.standard_normal()

        # Firing rates
        # NOTE(review): this wiring (pyramidal block driven by S(C1*y1), i.e.
        # its own potential) appears to differ from the textbook Jansen–Rit
        # equations — confirm it is intentional before relying on it.
        S_p = self.S(y2 - y3)
        S_e = self.S(self.C1 * y1)
        S_i = self.S(self.C3 * y1)

        # DBS is added, unscaled, to the input of the first PSP block.
        I_dbs = dbs_val

        # Second-order PSP operators written as first-order pairs.
        dy1_dt = y4
        dy4_dt = self.A * self.a * (S_e + p_t + I_dbs) - 2.0 * self.a * y4 - (self.a ** 2) * y1

        dy2_dt = y5
        dy5_dt = self.A * self.a * (self.C2 * S_p) - 2.0 * self.a * y5 - (self.a ** 2) * y2

        dy3_dt = y6
        dy6_dt = self.B * self.b * (self.C4 * S_i) - 2.0 * self.b * y6 - (self.b ** 2) * y3

        return np.array([dy1_dt, dy2_dt, dy3_dt, dy4_dt, dy5_dt, dy6_dt])

    # ------------------------------------------------------------------
    # Map disease_level ∈ [0,1] into JR parameters
    # ------------------------------------------------------------------
    def _apply_jr_params_from_disease(self):
        """
        Map self.disease_level in [0,1] to a JR background drive between
        'normal' and 'ictal' and store it in ``self.p_mean``.

        The endpoints come from ``jr_params_normal`` / ``jr_params_ictal``.
        ``disease_level`` is clipped to [0, 1], matching set_disease_level().
        Connectivity is left untouched.
        """
        d = float(np.clip(self.disease_level, 0.0, 1.0))
        p_norm = self.jr_params_normal["p_drive"]
        p_ict = self.jr_params_ictal["p_drive"]

        # linear interpolation: 0 -> normal, 1 -> ictal
        self.p_mean = (1.0 - d) * p_norm + d * p_ict

    def simulate(
        self,
        T=10.0,
        dt=0.0005,
        dbs_fun=None,
        y0=None,
        record_downsample=10,
    ):
        """
        Simulate the neural mass for duration T with step dt.

        Parameters
        ----------
        T : float
            Total duration in seconds.
        dt : float
            Integration timestep in seconds. Must be positive.
        dbs_fun : callable or None
            Function u = dbs_fun(t) giving DBS drive at time t (in arbitrary units).
            If None, DBS = 0 at all times.
        y0 : array-like or None
            Initial state. If None, start at zeros.
        record_downsample : int
            Store every Nth sample to reduce output size. Must be >= 1.

        Returns
        -------
        t_rec : (N_rec,) array
            Recorded time points.
        y_rec : (N_rec, 6) array
            State time series.
        v_out : (N_rec,) array
            “EEG-like” output: pyramidal potential y1.

        If T is shorter than one step, three empty arrays with shapes
        (0,), (0, 6) and (0,) are returned.

        Raises
        ------
        ValueError
            If ``dt`` is not positive or ``record_downsample`` is below 1.
        """
        if not dt > 0.0:
            raise ValueError(f"dt must be positive, got {dt!r}")
        stride = int(record_downsample)
        if stride < 1:
            raise ValueError(f"record_downsample must be >= 1, got {record_downsample!r}")

        # Floor with a small tolerance so that ratios such as 0.3 / 0.1
        # (2.9999999999999996 in floating point) do not lose a step.
        n_steps = int(np.floor(T / dt + 1e-9))

        if y0 is None:
            y = np.zeros(6, dtype=float)
        else:
            y = np.array(y0, dtype=float)

        if n_steps <= 0:
            return np.empty(0, dtype=float), np.empty((0, 6), dtype=float), np.empty(0, dtype=float)

        if dbs_fun is None:
            def dbs_fun(t):
                return 0.0

        # Recorded samples are accumulated in lists and stacked at the end.
        t_list = []
        y_list = []
        v_list = []

        for k in range(n_steps):
            t = k * dt
            u = float(dbs_fun(t))

            # Explicit Euler step; the noise is drawn inside _deriv.
            dy = self._deriv(y, t, u)
            y = y + dt * dy

            # The stored state is the one obtained after the update at time t.
            if k % stride == 0:
                t_list.append(t)
                y_list.append(y.copy())
                v_list.append(y[0])  # pyramidal potential

        t_rec = np.asarray(t_list)
        y_rec = np.vstack(y_list)
        v_out = np.asarray(v_list)

        return t_rec, y_rec, v_out


In [None]:
class EpilepsyDBSCombinedEnv(gym.Env):
    def __init__(
        self,
        jr_model=None,
        pac_model=None,
        step_T=0.1,
        dt=0.001,
        max_steps=100,
        amp_min=0.0,
        amp_max=5.0,
        freq_min=10.0,
        freq_max=200.0,
        pw_min=60.0,
        pw_max=450.0,
        amp_delta_max=0.25,
        freq_delta_max=5.0,
        pw_delta_max=10.0,
        w_seizure=3.0,
        w_energy=0.2,
        w_slew=0.005,
        w_disease = 0.0, # just initially, change later
        log_best_episodes=False,
        n_best_episodes=10,
        rng=None,
        default_regime="normal",   # which regime to use by default
        # --- seizure metric config ---
        seizure_metric="bandpower_ratio",   # or "line_length"
        seizure_band=(8.0, 30.0),           # adjust to your JR seizure oscillation
        total_band=(1.0, 80.0),             # total band for normalization
        seizure_norm="tanh",                # "tanh" or "logistic"
        seizure_scale=0.15,                 # will be auto-calibrated if you run calibrate()
        seizure_target=0.2,                # optional: for shaping / logging
    ):
        """
        Configure the DBS control environment.

        Parameters
        ----------
        jr_model : neural-mass model or None
            Model exposing ``_deriv``; an ``EpilepticNeuralMassJR()`` is built if None.
        pac_model : unused by this constructor (accepted for compatibility).
        step_T, dt : float
            Duration of one RL decision window and the integration step, in
            seconds. ``dt`` must be positive and ``step_T`` must span at least
            one integration step.
        max_steps : int
            Stored as ``self.max_steps``.
        amp_min..pw_max : float
            Bounds used to clip amplitude, frequency and pulse width.
        amp_delta_max, freq_delta_max, pw_delta_max : float
            Largest per-step change for an action component of magnitude 1.
        w_seizure, w_disease : float
            Stored reward weights.
        w_energy, w_slew : float
            Currently NOT honoured: both attributes are forced to 0.0 below.
        log_best_episodes, n_best_episodes
            Enable episode logging and bound the number of kept episodes.
        rng : numpy Generator or None
        default_regime : {"normal", "preictal", "ictal"}
        seizure_metric, seizure_band, total_band, seizure_norm,
        seizure_scale, seizure_target
            Seizure-metric configuration, stored as given.

        Raises
        ------
        ValueError
            For an unknown ``default_regime``, a non-positive ``dt``, or a
            ``step_T`` shorter than one integration step.
        """
        super().__init__()

        # --- JR regimes: background drive per discrete regime label ---
        self.jr_regimes = {
            "normal":   {"p_drive": 80.0},   # low background input
            "preictal": {"p_drive": 100.0},  # near bifurcation
            "ictal":    {"p_drive": 130.0},  # clearly epileptic
        }
        if default_regime not in self.jr_regimes:
            raise ValueError(f"Unknown default_regime '{default_regime}'")
        self.default_regime = default_regime
        self.current_regime = default_regime

        # --- timing ---
        self.dt = float(dt)
        self.step_T = float(step_T)
        if not self.dt > 0.0:
            raise ValueError(f"dt must be positive, got {dt!r}")
        self.fs = 1.0 / self.dt
        # Floor with a small tolerance so a ratio like 0.3 / 0.1
        # (2.9999999999999996 in floating point) does not lose a step.
        self.steps_per_step = int(np.floor(self.step_T / self.dt + 1e-9))
        if self.steps_per_step < 1:
            raise ValueError("step_T must cover at least one integration step of length dt")
        self.max_steps = int(max_steps)

        # --- running seizure statistics and severity squashing ---
        self.prev_seizure_index = None
        self.w_delta_seizure = 0.5

        self.seiz_mu = 0.0
        self.seiz_var = 1.0
        self.seiz_n = 0
        self.seiz_eps = 1e-6
        self.seiz_sigmoid_k = 2.0   # slope; 1–3 is typical
        self.seiz_update_stats = True  # during training; maybe False in eval

        # ---- Option B: latent burden dynamics ----
        self.use_burden_state = True

        # burden thresholds for regime labels
        self.burden_thr_normal = 0.20
        self.burden_thr_ictal  = 0.50

        # burden dynamics parameters (tune later)
        self.burden_drift_normal   = 0.002   # per step
        self.burden_drift_preictal = 0.006
        self.burden_drift_ictal    = 0.012

        self.burden_relief_gain = 0.020      # how strongly DBS reduces burden
        self.burden_noise_std   = 0.010      # stochasticity per step

        self.disease_up_rate = 0.010    # increase when burden high
        self.disease_down_rate = 0.002  # recovery when burden low
        self.disease_floor = 0.0

        # map stimulation params to "effective control"
        self.stim_k_amp  = 1.0
        self.stim_k_freq = 1.0
        self.stim_k_pw   = 1.0

        # optional: prevent unrealistic "infinite DBS = instant cure"
        self.burden_relief_cap = 0.05        # max reduction per step

        # ---- coupling of burden to physiology metric ----
        self.burden_phys_gain = 0.030   # tune later
        self.burden_phys_tau = 10.0
        self._phys_ema = 0.0

        self.phys_raw_lo = 0.70
        self.phys_raw_hi = 1.10

        # --- reward weights ---
        self.w_seizure = float(w_seizure)
        self.w_disease = float(w_disease)
        self.w_seiz = 1.0
        # NOTE(review): the w_energy / w_slew constructor arguments are
        # deliberately kept overridden to 0.0 here to preserve the existing
        # behaviour; decide whether the arguments should be honoured instead.
        self.w_energy = 0.0
        self.w_slew = 0.0
        # optional: reward scaling
        self.reward_scale = 1.0

        # --- seizure metric configuration ---
        self.seizure_metric = seizure_metric
        self.seizure_band = seizure_band
        self.total_band = total_band
        self.seizure_norm = seizure_norm
        self.seizure_scale = float(seizure_scale)
        self.seizure_target = float(seizure_target)

        # --- Seizure baseline (normal activity) ---
        self.seizure_baseline = None  # calibrated in reset()
        self.baseline_margin = 0.02

        self.jr = jr_model if jr_model is not None else EpilepticNeuralMassJR()

        # --- Disease state & progression ---
        self.disease_level = 0.0             # scalar in [0, 1]

        # how fast it drifts toward epilepsy if you do nothing (per step)
        self.disease_drift = 0.01          # tune as you like

        # how strongly seizure reduction pushes disease_level back down
        self.disease_gain = 0.05              # weight for delta_seiz effect

        # hard bounds
        self.disease_min = 0.0
        self.disease_max = 1.0

        # --- DBS ranges ---
        self.amp_min = float(amp_min)
        self.amp_max = float(amp_max)
        self.freq_min = float(freq_min)
        self.freq_max = float(freq_max)
        self.pw_min = float(pw_min)
        self.pw_max = float(pw_max)

        self.amp_delta_max = float(amp_delta_max)
        self.freq_delta_max = float(freq_delta_max)
        self.pw_delta_max = float(pw_delta_max)

        # --- logging ---
        self.log_best_episodes = bool(log_best_episodes)
        self.n_best_episodes = int(n_best_episodes)
        self.best_episodes = []
        self.current_episode_log = None
        self.episode_reward = 0.0

        self.rng = np.random.default_rng() if rng is None else rng

        # --- internal state (populated by reset()) ---
        self.t = 0.0
        self.jr_state = None
        self.pac_state = None
        self.current_params = None
        self.prev_params = None
        self.step_count = 0

        # --- reference scales ---
        self._seizure_scale = 0.25
        # Energy proxy evaluated at the maximum settings, using the same
        # guarded pulse-width ratio as _simulate_window.
        self._energy_ref = (self.amp_max ** 2) * self.freq_max * (self.pw_max / max(self.pw_max, 1e-6))

        # Observation: five components, each declared in [0, 1].
        self.observation_space = spaces.Box(
            low=np.zeros(5, dtype=np.float32),
            high=np.ones(5, dtype=np.float32),
            dtype=np.float32,
        )

        # Action: three components in [-1, 1], scaled by the *_delta_max values.
        self.action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(3,),
            dtype=np.float32,
        )


    # ------------------------------------------------------------------
    # Disease state / regime helpers
    # ------------------------------------------------------------------
    def _set_disease_from_regime(self, regime: str):
        """
        Translate a regime label into ``self.disease_level`` and push it to the JR model.

        Mapping: normal -> 0.0, preictal -> 0.5, ictal -> 1.0.

        Raises
        ------
        ValueError
            If ``regime`` is not one of the three labels.
        """
        labels = ("normal", "preictal", "ictal")
        levels = (0.0, 0.5, 1.0)

        if regime not in labels:
            raise ValueError(f"Unknown regime '{regime}'. Use 'normal', 'preictal', or 'ictal'.")

        self.disease_level = levels[labels.index(regime)]
        self._apply_jr_params_from_disease()


    def _update_seizure_stats(self, x: float):
        """
        Fold one sample into the running mean / variance (``seiz_mu``, ``seiz_var``).

        The first sample sets the mean to ``x`` and the variance to 1.0; that
        initial 1.0 is then carried through the recursive update below, so the
        estimate is not the plain population variance of the samples.
        """
        self.seiz_n += 1
        count = self.seiz_n

        if count == 1:
            self.seiz_mu, self.seiz_var = x, 1.0
            return

        # Incremental update: deviations before and after the mean moves.
        dev_before = x - self.seiz_mu
        mean_after = self.seiz_mu + dev_before / count
        dev_after = x - mean_after

        self.seiz_mu = mean_after
        self.seiz_var = ((count - 1) * self.seiz_var + dev_before * dev_after) / count


    def _seizure_severity_from_raw(self, raw: float) -> float:
        """
        Convert a raw metric into a severity in [0, 1] via a logistic of its z-score.

        When ``seiz_update_stats`` is true the running statistics are updated
        with ``raw`` first, so the z-score already includes the current sample.
        """
        if self.seiz_update_stats:
            self._update_seizure_stats(raw)

        # Variance is floored at seiz_eps to keep the division finite.
        variance = max(self.seiz_var, self.seiz_eps)
        std = float(np.sqrt(variance))
        z_score = (raw - float(self.seiz_mu)) / std

        logistic = 1.0 / (1.0 + np.exp(-self.seiz_sigmoid_k * z_score))
        return float(np.clip(logistic, 0.0, 1.0))



    def _apply_jr_params_from_disease(self):
        """
        Forward ``self.disease_level`` to the JR model.

        Does nothing when there is no model or when the model lacks a
        ``set_disease_level`` method.
        """
        model = self.jr
        if model is not None and hasattr(model, "set_disease_level"):
            model.set_disease_level(self.disease_level)


    def _bandpower(self, freqs, psd, f_lo, f_hi):
        """
        Integrate ``psd`` over the closed frequency interval [f_lo, f_hi].

        Uses trapezoidal integration over the bins that fall inside the band
        and returns 0.0 when no bin does.
        """
        mask = (freqs >= f_lo) & (freqs <= f_hi)
        if not np.any(mask):
            return 0.0

        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid;
        # fall back to the old name on releases that predate the rename.
        integrate = getattr(np, "trapezoid", None)
        if integrate is None:
            integrate = np.trapz
        return float(integrate(psd[mask], freqs[mask]))

    def _seizure_metric_bandpower_ratio(self, x: np.ndarray) -> float:
        """
        Fraction of spectral power that lies in ``seizure_band`` relative to ``total_band``.

        The signal is demeaned and Hann-windowed, and a periodogram is taken
        with bin spacing 1 / (n * dt). Inputs with fewer than 8 samples give 0.0.
        """
        signal = np.asarray(x, dtype=np.float64)
        n_samples = signal.size
        if n_samples < 8:
            return 0.0

        # Remove the DC component and taper to limit spectral leakage.
        window = np.hanning(n_samples)
        tapered = (signal - np.mean(signal)) * window

        spectrum = np.fft.rfft(tapered)
        freqs = np.fft.rfftfreq(n_samples, d=self.dt)

        # The absolute scale cancels in the ratio; it is kept for consistency.
        psd = (np.abs(spectrum) ** 2) / (np.sum(window ** 2) * self.fs + 1e-12)

        in_band = self._bandpower(freqs, psd, self.seizure_band[0], self.seizure_band[1])
        overall = self._bandpower(freqs, psd, self.total_band[0], self.total_band[1])

        # Small epsilon keeps the ratio defined when the total power is zero.
        return float(in_band / (overall + 1e-12))

    def _seizure_metric_line_length(self, x: np.ndarray) -> float:
        """
        Line length: summed absolute sample-to-sample change divided by (n * dt).

        Returns 0.0 for inputs with fewer than two samples.
        """
        samples = np.asarray(x, dtype=np.float64)
        if samples.size < 2:
            return 0.0

        total_variation = np.sum(np.abs(np.diff(samples)))
        duration = samples.size * self.dt + 1e-12  # epsilon guards dt == 0
        return float(total_variation / duration)

    def _normalize_seizure_raw(self, raw: float) -> float:
        """
        Map a raw metric to [0, 1) relative to the calibrated baseline.

        Computes z = max(0, (raw - baseline) / scale) and returns z / (1 + z).
        Returns 0.0 while either the baseline or the scale is still None.
        """
        baseline, scale = self.seizure_baseline, self.seizure_scale
        if baseline is None or scale is None:
            return 0.0

        baseline = float(baseline)
        scale = max(float(scale), 1e-3)  # floor avoids dividing by ~0

        # Values at or below the baseline contribute nothing.
        excess = max(0.0, (float(raw) - baseline) / scale)
        return float(np.clip(excess / (1.0 + excess), 0.0, 1.0))



    # ------------------------------------------------------------------
    # Parameter handling
    # ------------------------------------------------------------------
    def _clip_params(self, amp, freq, pw):
        """Clamp amplitude, frequency and pulse width to their configured [min, max] ranges."""
        return (
            np.clip(amp, self.amp_min, self.amp_max),
            np.clip(freq, self.freq_min, self.freq_max),
            np.clip(pw, self.pw_min, self.pw_max),
        )

    def _update_params_from_action(self, action):
        """
        Turn a continuous action into new, clipped stimulation parameters.

        ``action`` is flattened and must contain exactly three components, which
        are multiplied by ``amp_delta_max``, ``freq_delta_max`` and
        ``pw_delta_max`` respectively and added to the current parameters.

        Returns
        -------
        tuple
            (new_amp, new_freq, new_pw, d_amp, d_freq, d_pw); the deltas are the
            requested changes before clipping.

        Raises
        ------
        RuntimeError
            If no current parameters exist yet (reset() has not been called).
        ValueError
            If the flattened action does not have three components.
        """
        if self.current_params is None:
            raise RuntimeError("Environment not reset. Call reset() first.")

        act = np.asarray(action, dtype=float).flatten()
        if act.shape[0] != 3:
            raise ValueError("Action must have shape (3,)")

        # Scale each action component by its per-step limit.
        limits = (self.amp_delta_max, self.freq_delta_max, self.pw_delta_max)
        d_amp, d_freq, d_pw = (float(c) * float(lim) for c, lim in zip(act, limits))

        amp, freq, pw = self.current_params
        clipped = self._clip_params(amp + d_amp, freq + d_freq, pw + d_pw)

        return (*clipped, d_amp, d_freq, d_pw)


    # ------------------------------------------------------------------
    # DBS waveform
    # ------------------------------------------------------------------
    def _make_dbs_fun(self, amp, freq, pw_us, duty=1.0, duty_period_s=1.0):
        """
        Build a rectangular pulse-train waveform u(t).

        Parameters
        ----------
        amp : pulse height.
        freq : pulse repetition rate in Hz.
        pw_us : pulse width in microseconds.
        duty : fraction of each ``duty_period_s`` during which pulses are
            emitted; clipped to [0, 1]. 1.0 disables gating.
        duty_period_s : length of one on/off cycle in seconds.

        Returns
        -------
        callable
            ``u(t)`` returning ``amp`` inside a pulse and 0.0 otherwise. A
            constant-zero function is returned when the amplitude is zero, when
            the frequency or pulse width is not positive, or when ``duty`` is 0.

        Raises
        ------
        ValueError
            If gating is requested (``duty`` < 1) with a non-positive period.
        """
        pw_s = float(pw_us) * 1e-6
        f = float(freq)
        amp = float(amp)
        duty = float(np.clip(duty, 0.0, 1.0))

        # duty == 0 means "never on"; previously a pulse still leaked through
        # whenever t was an exact multiple of the duty period.
        if amp == 0.0 or f <= 0.0 or pw_s <= 0.0 or duty <= 0.0:
            return lambda t: 0.0

        pulse_period = 1.0 / f

        gated = duty < 1.0
        gate_period = float(duty_period_s)
        if gated and not gate_period > 0.0:
            # Fail at construction instead of with a modulo-by-zero later on.
            raise ValueError("duty_period_s must be positive when duty < 1")
        on_s = duty * gate_period

        def u(t):
            t = float(t)
            # duty gate: silent during the off part of each cycle
            if gated and (t % gate_period) > on_s:
                return 0.0
            # pulse train: high for the first pw_s seconds of each period
            return amp if (t % pulse_period) <= pw_s else 0.0

        return u


    def _simulate_window(self, amp, freq, pw):
        """
        Simulate one RL decision window.

        Uses an effective (time-averaged) DBS drive so that pw_us < dt does not cause
        pulses to be numerically missed. PAC is disabled (not computed).

        Parameters
        ----------
        amp, freq, pw : stimulation amplitude, frequency (Hz) and pulse width (µs).

        Returns
        -------
        tuple
            ``(seiz_raw, energy_norm, t_window, v_window, u_window)``.
            ``seiz_raw`` is the un-normalised seizure metric of the full-rate
            signal; the three arrays hold every 10th sample for logging.
            If the simulated signal contains a non-finite value, ``seiz_raw``
            is NaN and ``energy_norm`` is 1.0, so callers should test
            ``np.isfinite`` before using the metric.

        Raises
        ------
        ValueError
            If ``self.seizure_metric`` names an unknown metric.
        """
        MAX_STATE = 50.0      # hard clamp for all states
        MAX_INPUT = 10.0      # clamp on DBS drive

        amp = float(amp)
        freq = float(freq)
        pw_us = float(pw)     # microseconds

        # ---- Effective DBS drive (average of pulse train) ----
        # Average of rectangular pulse train: amp * (pw/T) = amp * pw_s * freq,
        # then multiplied by a fixed gain of 50.
        pw_s = pw_us * 1e-6
        u_avg = amp * pw_s * freq
        u_avg *= 50.0

        # The drive is constant over the window, so clamp it once.
        u_t = float(np.clip(u_avg, -MAX_INPUT, MAX_INPUT))

        t_list, v_list, u_list = [], [], []      # downsampled, for logging/plots
        v_metric = []                             # full-rate, for seizure metric

        if self.jr_state is None:
            self.jr_state = np.zeros(6, dtype=float)

        for k in range(self.steps_per_step):
            t = self.t + k * self.dt

            # Explicit Euler step of the JR model, followed by a hard clamp.
            dy = self.jr._deriv(self.jr_state, t, u_t)
            self.jr_state = self.jr_state + self.dt * dy
            self.jr_state = np.clip(self.jr_state, -MAX_STATE, MAX_STATE)

            v = float(self.jr_state[0])
            v_metric.append(v)

            # downsample logging
            if k % 10 == 0:
                t_list.append(t)
                v_list.append(v)
                u_list.append(u_t)

        self.t += self.step_T

        t_window = np.asarray(t_list)
        v_window = np.asarray(v_list)
        u_window = np.asarray(u_list)
        v_metric = np.asarray(v_metric, dtype=np.float64)

        # ------------------------------
        # Seizure metric (RAW) computed on FULL-RATE signal
        # ------------------------------
        if not np.all(np.isfinite(v_metric)):
            # Numerical blow-up: no metric can be computed. Previously this
            # branch referenced seiz_raw before assignment and crashed.
            self._last_seizure_metric_raw = float("nan")
            return (float("nan"), 1.0, t_window, v_window, u_window)

        if self.seizure_metric == "bandpower_ratio":
            seiz_raw = self._seizure_metric_bandpower_ratio(v_metric)
        elif self.seizure_metric == "line_length":
            seiz_raw = self._seizure_metric_line_length(v_metric)
        else:
            raise ValueError(f"Unknown seizure_metric: {self.seizure_metric}")

        # store for info/debug/calibration
        self._last_seizure_metric_raw = float(seiz_raw)

        # stimulation "energy" proxy, normalised by its value at the maximum settings
        energy = (amp ** 2) * freq * (pw_us / max(self.pw_max, 1e-6))
        energy_norm = float(np.clip(energy / max(self._energy_ref, 1e-6), 0.0, 1.0))

        return (float(seiz_raw), float(energy_norm), t_window, v_window, u_window)


    # ------------------------------------------------------------------
    # Episode logging helpers
    # ------------------------------------------------------------------
    def _start_new_episode_log(self):
        """
        Reset the per-episode reward and, if logging is enabled, open a fresh log.

        With logging disabled the current log is set to None. Otherwise it
        becomes a dict with one empty list per recorded series and a
        ``total_reward`` of 0.0.
        """
        self.episode_reward = 0.0

        if not self.log_best_episodes:
            self.current_episode_log = None
            return

        series_names = (
            "t", "v", "u", "theta", "amp_fast",
            "seizure_index", "pac_index", "energy_norm",
            "amp", "freq", "pw",
            "rewards",
        )
        fresh_log = {name: [] for name in series_names}
        fresh_log["total_reward"] = 0.0
        self.current_episode_log = fresh_log

    def _append_window_to_log(
        self,
        t_window,
        v_window,
        u_window,
        theta_window,
        amp_fast_window,
        seizure_index,
        pac_index,
        energy_norm,
        amp,
        freq,
        pw,
        reward,
    ):
        """
        Append one decision window to the current episode log.

        Window arrays are stored as given; scalar values are converted with
        ``float`` at the moment they are appended. No-op when no log is open.
        """
        log = self.current_episode_log
        if log is None:
            return

        # Per-window time series, stored untouched.
        windows = (
            ("t", t_window),
            ("v", v_window),
            ("u", u_window),
            ("theta", theta_window),
            ("amp_fast", amp_fast_window),
        )
        for key, window in windows:
            log[key].append(window)

        # Per-step scalars, coerced to float one at a time.
        scalars = (
            ("seizure_index", seizure_index),
            ("pac_index", pac_index),
            ("energy_norm", energy_norm),
            ("amp", amp),
            ("freq", freq),
            ("pw", pw),
            ("rewards", reward),
        )
        for key, value in scalars:
            log[key].append(float(value))

    def _finalise_episode_log(self):
        """
        Close the current episode log and keep it if it ranks among the best.

        Window series are concatenated, scalar series become float arrays, and
        the episode is inserted into ``best_episodes``, which is sorted by
        total reward (highest first) and truncated to ``n_best_episodes``.
        No-op when no log is open or logging is disabled.
        """
        log = self.current_episode_log
        if log is None or not self.log_best_episodes:
            return

        # Join the per-window chunks into one array per series.
        for name in ("t", "v", "u", "theta", "amp_fast"):
            chunks = log[name]
            log[name] = np.concatenate(chunks) if len(chunks) > 0 else np.array([], dtype=float)

        # Per-step scalars become float arrays.
        for name in ("seizure_index", "pac_index", "energy_norm",
                     "amp", "freq", "pw", "rewards"):
            log[name] = np.asarray(log[name], dtype=float)

        log["total_reward"] = float(self.episode_reward)

        # Rank in place, then keep only the top entries.
        ranked = self.best_episodes
        ranked.append(log)
        ranked.sort(key=lambda episode: episode["total_reward"], reverse=True)
        self.best_episodes = ranked[: self.n_best_episodes]

        # Start the next episode from a clean slate.
        self.current_episode_log = None
        self.episode_reward = 0.0

    # ------------------------------------------------------------------
    # Gymnasium API: reset / step
    # ------------------------------------------------------------------
    def _build_obs(self, seizure_index, params):
        """
        Assemble the 5-component float32 observation.

        Layout: [seizure_index, amp_norm, freq_norm, pw_norm, disease_level],
        where each *_norm is the parameter rescaled by its configured range.

        Parameters
        ----------
        seizure_index : scalar placed in the first slot unchanged.
        params : (amp, freq, pw) tuple of current stimulation parameters.
        """
        amp, freq, pw = params

        def _rescale(value, lo, hi):
            # Same guarded denominator as _stim_effect, so a degenerate range
            # (hi == lo) yields 0 instead of NaN or a ZeroDivisionError.
            return (value - lo) / max(hi - lo, 1e-6)

        return np.array(
            [
                seizure_index,
                _rescale(amp, self.amp_min, self.amp_max),
                _rescale(freq, self.freq_min, self.freq_max),
                _rescale(pw, self.pw_min, self.pw_max),
                float(self.disease_level),
            ],
            dtype=np.float32,
        )

    def _regime_from_burden(self, burden: float) -> str:
        """Label a burden value as "normal", "preictal" or "ictal" using the two thresholds."""
        if burden < self.burden_thr_normal:
            return "normal"
        if burden < self.burden_thr_ictal:
            return "preictal"
        return "ictal"

    def _stim_effect(self, amp: float, freq: float, pw: float) -> float:
        """
        Collapse the three DBS parameters into one control value in [0, 1].

        Each parameter is rescaled to [0, 1] by its configured range, the
        weighted sum (weights ``stim_k_*``) is formed, and the result is passed
        through a logistic centred at 1.0 with width 0.5.
        """
        def _unit(value, lo, hi):
            # Range-normalise and clamp; the floor avoids dividing by zero.
            return float(np.clip((value - lo) / max(hi - lo, 1e-6), 0.0, 1.0))

        amp_n = _unit(amp, self.amp_min, self.amp_max)
        freq_n = _unit(freq, self.freq_min, self.freq_max)
        pw_n = _unit(pw, self.pw_min, self.pw_max)

        drive = self.stim_k_amp * amp_n + self.stim_k_freq * freq_n + self.stim_k_pw * pw_n

        # Saturating squash into (0, 1).
        squashed = 1.0 / (1.0 + np.exp(-(drive - 1.0) / 0.5))
        return float(np.clip(squashed, 0.0, 1.0))


    def _update_burden(self, amp: float, freq: float, pw: float) -> dict:
        """
        Advance the latent burden by one step and report what drove the change.

        The new burden is the old one plus a regime-dependent drift and a
        physiology coupling term, minus a capped stimulation relief, plus
        Gaussian noise, all clipped to [0, 1]. The physiology term uses the raw
        metric left by the last simulated window, if it is finite.

        Returns
        -------
        dict
            Diagnostics: previous/next burden and regime, and each term.
        """
        burden_before = self.burden
        regime_now = self._regime_from_burden(burden_before)

        # Drift depends on the regime implied by the current burden.
        drift = {
            "normal": self.burden_drift_normal,
            "preictal": self.burden_drift_preictal,
        }.get(regime_now, self.burden_drift_ictal)

        eff = self._stim_effect(amp, freq, pw)

        # Physiology -> burden coupling; skipped when no finite metric is available.
        raw = float(getattr(self, "_last_seizure_metric_raw", np.nan))
        phys = float("nan")
        phys_drive = 0.0
        if np.isfinite(raw):
            # map raw to [0,1]
            span = max(self.phys_raw_hi - self.phys_raw_lo, 1e-6)
            phys = float(np.clip((raw - self.phys_raw_lo) / span, 0.0, 1.0))

            # exponential moving average with weight 1 / max(tau, 1)
            alpha = 1.0 / max(float(self.burden_phys_tau), 1.0)
            ema_before = float(getattr(self, "_phys_ema", 0.0))
            self._phys_ema = (1.0 - alpha) * ema_before + alpha * phys

            # centred at 0.5, so the term can push the burden either way
            phys_drive = self.burden_phys_gain * (self._phys_ema - 0.5)

        # Relief from stimulation, limited per step.
        relief = float(min(self.burden_relief_gain * eff, self.burden_relief_cap))

        noise = float(self.np_random.normal(0.0, self.burden_noise_std))

        b_next = float(np.clip(burden_before + drift + phys_drive - relief + noise, 0.0, 1.0))

        diag = {
            "burden_prev": float(burden_before),
            "burden_next": float(b_next),
            "burden_drift": float(drift),
            "burden_relief": float(relief),
            "stim_eff": float(eff),
            "burden_noise": float(noise),
            "regime_prev": regime_now,
            "regime_next": self._regime_from_burden(b_next),
            "phys_raw": raw,
            "phys_mapped": phys,
            "phys_ema": float(getattr(self, "_phys_ema", 0.0)),
            "phys_drive": float(phys_drive),
        }

        self.burden = b_next
        return diag


    def reset(self, *, seed=None, options=None):
        """
        Reset the environment and start a new episode.

        Parameters
        ----------
        seed : int or None
            Forwarded to ``super().reset(seed=seed)``.
        options : dict or None
            If it contains the key ``"regime"``, that value selects the
            regime for this episode; otherwise ``self.default_regime`` is
            used. The values ``"normal"`` and ``"preictal"`` get dedicated
            initial-burden distributions; any other value is treated as ictal.

        Returns
        -------
        obs : np.array, shape (5,)
        info : dict
            Keys: ``"regime"``, ``"disease_level"``, ``"burden"`` and
            ``"regime_label"`` (the label derived from the initial burden).
        """
        super().reset(seed=seed)

        self.t = 0.0
        self.step_count = 0

        # --- choose regime / disease state ----------------------------
        regime = self.default_regime
        if options is not None and "regime" in options:
            regime = options["regime"]
        self._set_disease_from_regime(regime)

        # Initial burden: one clipped Gaussian draw whose mean/std depend on
        # the requested regime.
        if regime == "normal":
            burden_mean, burden_std = 0.12, 0.04
        elif regime == "preictal":
            burden_mean, burden_std = 0.33, 0.06
        else:  # ictal
            burden_mean, burden_std = 0.70, 0.08
        self.burden = float(
            np.clip(self.np_random.normal(burden_mean, burden_std), 0.0, 1.0)
        )

        # reset underlying models
        # NOTE(review): the PAC state is drawn from ``self.rng`` while the
        # burden uses ``self.np_random``; confirm ``self.rng`` is re-seeded
        # somewhere if seeded resets are expected to be fully reproducible.
        self.jr_state = np.zeros(6, dtype=float)
        self.pac_state = np.array(
            [
                0.0,
                0.01 * self.rng.standard_normal(),
                0.01 * self.rng.standard_normal(),
            ],
            dtype=float,
        )

        # start from typical DBS settings
        init_amp = 2.0     # mA
        init_freq = 130.0  # Hz
        init_pw = 120.0    # µs
        init_amp, init_freq, init_pw = self._clip_params(init_amp, init_freq, init_pw)
        self.current_params = (init_amp, init_freq, init_pw)
        self.prev_params = self.current_params

        # new episode log
        self._start_new_episode_log()

        # No warm-up is simulated here and ``seizure_baseline`` /
        # ``seizure_scale`` are left untouched, so an external calibration
        # survives a reset.

        # ---------------------------------------------------------
        # initial observation at current params
        # ---------------------------------------------------------
        # NOTE(review): step() maps the first value returned by
        # _simulate_window through _seizure_severity_from_raw before using it;
        # here it is used directly as the seizure index. Confirm this is
        # intended.
        seizure_index, energy_norm, t_w, v_w, u_w = self._simulate_window(*self.current_params)

        # PAC disabled: provide zero arrays for logging compatibility
        theta_w = np.zeros_like(t_w, dtype=float)
        amp_fast_w = np.zeros_like(t_w, dtype=float)
        pac_index = 0.0

        self.prev_seizure_index = float(seizure_index)

        # log this initial window with zero reward
        self._append_window_to_log(
            t_w,
            v_w,
            u_w,
            theta_w,
            amp_fast_w,
            seizure_index,
            pac_index,
            energy_norm,
            amp=self.current_params[0],
            freq=self.current_params[1],
            pw=self.current_params[2],
            reward=0.0,
        )

        obs = self._build_obs(seizure_index, self.current_params)
        info = {
            "regime": regime,
            "disease_level": self.disease_level,
            "burden": self.burden,
            # Same key name as the one emitted by step().
            "regime_label": self._regime_from_burden(self.burden),
        }
        return obs, info



    def step(self, action):
        """
        Advance the environment by one RL step (one simulated window).

        Parameters
        ----------
        action : array-like with 3 elements
            Flattened and handed to ``_update_params_from_action``, which
            returns the new (amp, freq, pw) and the applied deltas.

        Returns
        -------
        obs : np.array, shape (5,)
             [seizure_index, amp_norm, freq_norm, pw_norm, disease_level]
        reward : float
        terminated : bool
            Always False; episodes only end through truncation.
        truncated : bool
            True once ``step_count >= max_steps``.
        info : dict
            Per-step diagnostics (seizure metrics, stimulation parameters,
            reward and its weighted cost components).

        Raises
        ------
        ValueError
            If the flattened action does not have exactly 3 elements.
        """
        action = np.asarray(action, dtype=float).flatten()
        if action.shape[0] != 3:
            raise ValueError("Action must have shape (3,)")

        self.step_count += 1

        # ---------------------------------------------------------
        # 1) Update DBS parameters from action
        # ---------------------------------------------------------
        new_amp, new_freq, new_pw, d_amp, d_freq, d_pw = self._update_params_from_action(action)
        self.prev_params = self.current_params
        self.current_params = (new_amp, new_freq, new_pw)

        # Burden is advanced before the window is simulated; the returned
        # diagnostics are not used by this method.
        if getattr(self, "use_burden_state", False):
            self._update_burden(new_amp, new_freq, new_pw)

        # ---------------------------------------------------------
        # 2) Simulate one window and map the raw metric to a severity index
        # ---------------------------------------------------------
        raw_seiz, energy_norm, t_w, v_w, u_w = self._simulate_window(new_amp, new_freq, new_pw)
        seizure_index = self._seizure_severity_from_raw(raw_seiz)

        # PAC disabled: provide zero arrays for logging compatibility
        theta_w = np.zeros_like(t_w, dtype=float)
        amp_fast_w = np.zeros_like(t_w, dtype=float)
        pac_index = 0.0

        # ---------------------------------------------------------
        # 3) Build observation (built before disease_level is updated below)
        # ---------------------------------------------------------
        obs = self._build_obs(seizure_index, self.current_params)

        # ---------------------------------------------------------
        # 4) Parameter slew penalty: mean of the per-parameter |delta|,
        #    each divided by its maximum allowed delta
        # ---------------------------------------------------------
        slew_amp = abs(d_amp) / max(self.amp_delta_max, 1e-6)
        slew_freq = abs(d_freq) / max(self.freq_delta_max, 1e-6)
        slew_pw = abs(d_pw) / max(self.pw_delta_max, 1e-6)
        slew_penalty = (slew_amp + slew_freq + slew_pw) / 3.0

        # ---------------------------------------------------------
        # 5) Seizure change (only improvements count, so delta_seiz >= 0)
        # ---------------------------------------------------------
        prev_seiz = self.prev_seizure_index
        if prev_seiz is None:
            delta_seiz = 0.0
        else:
            delta_seiz = max(0.0, prev_seiz - seizure_index)

        # ---------------------------------------------------------
        # 6) Disease dynamics (long-term plasticity)
        # ---------------------------------------------------------
        # NOTE(review): JR parameters are applied from the disease_level of
        # the previous step, because the level is only updated a few lines
        # below. Confirm this one-step lag is intended.
        self._apply_jr_params_from_disease()

        # now update prev for next step
        self.prev_seizure_index = float(seizure_index)

        # Disease rises while the index is above target and falls while below.
        # NOTE(review): the fallback target here (0.2) differs from the one
        # used for the reward/info (0.05); it only matters when the env has
        # no ``seizure_target`` attribute.
        idx = float(seizure_index)
        disease_target = float(getattr(self, "seizure_target", 0.2))
        up = self.disease_up_rate * max(0.0, idx - disease_target)
        down = self.disease_down_rate * max(0.0, disease_target - idx)
        self.disease_level = float(
            np.clip(self.disease_level + up - down, self.disease_floor, 1.0)
        )

        # ---------------------------------------------------------
        # 7) Reward components
        #     - Primary: suppress seizures
        #     - Secondary: energy, slew, disease
        # ---------------------------------------------------------
        target = float(getattr(self, "seizure_target", 0.05))
        seiz_cost = float(seizure_index)  # dense, always informative

        # Reported in info as "seizure_error_above_baseline". The original
        # code referenced this name without defining it (NameError); it is
        # defined here as the excess of the severity index over the target.
        seiz_error = max(0.0, seiz_cost - target)

        cost = (
            self.w_seiz * seiz_cost
            + self.w_energy * float(energy_norm)
            + self.w_slew * float(slew_penalty)
            + self.w_disease * float(self.disease_level)
        )
        reward = -cost + self.w_delta_seizure * float(delta_seiz)

        # ---------------------------------------------------------
        # 8) Logging
        # ---------------------------------------------------------
        self.episode_reward += reward
        self._append_window_to_log(
            t_w,
            v_w,
            u_w,
            theta_w,
            amp_fast_w,
            seizure_index,
            pac_index,      # 0.0 while PAC is disabled
            energy_norm,
            new_amp,
            new_freq,
            new_pw,
            reward,
        )

        # ---------------------------------------------------------
        # 9) Termination / truncation
        # ---------------------------------------------------------
        terminated = False
        truncated = self.step_count >= self.max_steps

        if truncated or terminated:
            self._finalise_episode_log()

        # ---------------------------------------------------------
        # 10) Diagnostics for the info dict
        # ---------------------------------------------------------
        base = float(self.seizure_baseline) if self.seizure_baseline is not None else np.nan
        sc = float(self.seizure_scale) if self.seizure_scale is not None else np.nan
        raw = float(getattr(self, "_last_seizure_metric_raw", np.nan))
        # z-score of the raw metric; NaN when calibration is missing/invalid.
        if np.isfinite(raw) and np.isfinite(base) and np.isfinite(sc) and sc > 0:
            z = (raw - base) / sc
        else:
            z = np.nan

        info = {
            "seizure_index": float(seizure_index),
            "seizure_metric_raw": raw,
            "seizure_error_above_baseline": float(seiz_error),
            "seizure_target": target,
            "seizure_baseline": base,
            "seizure_scale": sc,
            "burden": self.burden,
            "regime_label": self._regime_from_burden(self.burden),
            "seizure_z": float(z),
            "pac_index": float(pac_index),
            "energy_norm": float(energy_norm),
            "amp": float(new_amp),
            "freq": float(new_freq),
            "pw": float(new_pw),
            "slew_penalty": float(slew_penalty),
            "disease_level": float(self.disease_level),
            "reward": float(reward),
            "cost_components": {
                "seizure": float(self.w_seiz * seiz_cost),
                "energy": float(self.w_energy * energy_norm),
                "slew": float(self.w_slew * slew_penalty),
                "disease": float(self.w_disease * self.disease_level),
                "delta_seiz_bonus": float(self.w_delta_seizure * delta_seiz),
            },
        }

        return obs, reward, terminated, truncated, info

    def calibrate_seizure_scale(self, n_windows: int = 300, seed: int = 0, regime: str = "normal"):
        """
        Calibrate ``seizure_baseline`` / ``seizure_scale`` from zero-action windows.

        The env is reset into ``regime`` and stepped ``n_windows`` times with
        an all-zero action, collecting ``info["seizure_metric_raw"]``. The
        baseline is set to the median of the finite samples and the scale to
        ``max(p95 - p50, 1e-3)``.

        Only the first reset is seeded. Resets triggered by an episode ending
        mid-calibration pass no seed, so the existing RNG stream continues
        instead of replaying the first episode's random draws.

        Side effects: overwrites ``self.seizure_baseline`` and
        ``self.seizure_scale`` and leaves the env in whatever episode state
        the last step/reset produced.

        Parameters
        ----------
        n_windows : int
            Number of zero-action steps to sample.
        seed : int
            Seed passed to the initial reset.
        regime : str
            Regime passed through ``options={"regime": ...}`` on every reset.

        Returns
        -------
        dict
            Summary with the regime, number of finite samples, the two
            percentiles and the values written to the env.

        Raises
        ------
        RuntimeError
            If fewer than 10 finite raw samples were collected.
        """
        # Ensure we are in the requested regime (seeded once).
        self.reset(seed=seed, options={"regime": regime})

        zero_action = np.zeros(3, dtype=np.float32)
        raw_vals = []

        for _ in range(n_windows):
            _, _, terminated, truncated, info = self.step(zero_action)
            raw_vals.append(float(info["seizure_metric_raw"]))
            if terminated or truncated:
                # No seed here: re-seeding would restart the same random stream.
                self.reset(options={"regime": regime})

        raw = np.asarray(raw_vals, dtype=float)
        raw = raw[np.isfinite(raw)]
        if raw.size < 10:
            raise RuntimeError(f"Not enough finite raw seizure metric samples: {raw.size}")

        p50 = float(np.percentile(raw, 50))
        p95 = float(np.percentile(raw, 95))
        eps = 1e-3  # floor so the scale is never zero or negative

        self.seizure_baseline = p50
        self.seizure_scale = float(max(p95 - p50, eps))

        return {
            "regime": regime,
            "n_windows": int(raw.size),
            "raw_p50": p50,
            "raw_p95": p95,
            "seizure_baseline_set_to": float(self.seizure_baseline),
            "seizure_scale_set_to": float(self.seizure_scale),
        }

    # ------------------------------------------------------------------
    # Convenience
    # ------------------------------------------------------------------
    def get_best_episodes(self):
        """
        Return a new list holding the entries of ``self.best_episodes``.

        The copy is shallow: the list is new, the episode entries are shared.
        """
        return [*self.best_episodes]


In [None]:
class MLP(nn.Module):
    """
    Plain fully connected network: Linear -> activation for each hidden
    size, followed by a final Linear with no activation.

    Parameters
    ----------
    in_dim : int
        Size of the input's last dimension.
    out_dim : int
        Size of the output's last dimension.
    hidden_sizes : sequence of int
        Width of each hidden layer; an empty sequence yields a single Linear.
    activation : callable returning an nn.Module
        Called with no arguments once per hidden layer.
    """

    def __init__(self, in_dim, out_dim, hidden_sizes=(256, 256), activation=nn.ReLU):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden_sizes:
            layers += [nn.Linear(last, h), activation()]
            last = h
        layers.append(nn.Linear(last, out_dim))
        self.net = nn.Sequential(*layers)
        # The stray ``prev_seizure_index`` attribute was removed: it is
        # environment state and is never read by this module.

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)


In [None]:
class GaussianPolicy(nn.Module):
    """
    Tanh-squashed diagonal Gaussian policy.

    An MLP maps observations to ``2 * act_dim`` outputs, split into a mean
    and a log-std. Samples are squashed with tanh and affinely rescaled into
    ``[act_low, act_high]``.
    """

    def __init__(self, obs_dim, act_dim, act_low, act_high, hidden_sizes=(256, 256), log_std_min=-20, log_std_max=2):
        super().__init__()
        self.base = MLP(obs_dim, 2 * act_dim, hidden_sizes)
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        # Bounds are buffers so they follow the module across devices.
        for buf_name, bound in (("act_low", act_low), ("act_high", act_high)):
            self.register_buffer(buf_name, torch.tensor(bound, dtype=torch.float32))

    def forward(self, obs):
        """Return the pre-squash Normal distribution for ``obs``."""
        mean, raw_log_std = self.base(obs).chunk(2, dim=-1)
        bounded_log_std = raw_log_std.clamp(self.log_std_min, self.log_std_max)
        return Normal(mean, bounded_log_std.exp())

    def sample(self, obs):
        """
        Draw a reparameterized action and its log-probability.

        Returns ``(action, log_prob)`` where ``log_prob`` keeps a trailing
        dimension of size 1. The log-probability includes the tanh
        change-of-variables term; the constant factor from the affine
        rescale to the action bounds is not included.
        """
        dist = self.forward(obs)
        pre_tanh = dist.rsample()
        squashed = torch.tanh(pre_tanh)

        gaussian_logp = dist.log_prob(pre_tanh).sum(-1, keepdim=True)
        # 1e-6 keeps the log finite when tanh saturates at +/-1.
        tanh_correction = torch.log(1 - squashed.pow(2) + 1e-6).sum(dim=-1, keepdim=True)
        log_prob = gaussian_logp - tanh_correction

        centre = (self.act_high + self.act_low) / 2.0
        half_range = (self.act_high - self.act_low) / 2.0
        return centre + half_range * squashed, log_prob


In [None]:
class QNetwork(nn.Module):
    """State-action value network: an MLP over ``[obs, act]`` with one output."""

    def __init__(self, obs_dim, act_dim, hidden_sizes=(256, 256)):
        super().__init__()
        joint_dim = obs_dim + act_dim
        self.q = MLP(joint_dim, 1, hidden_sizes)

    def forward(self, obs, act):
        """Concatenate ``obs`` and ``act`` on the last dim and score the pair."""
        return self.q(torch.cat((obs, act), dim=-1))


In [None]:
# Record type for a single (obs, act, rew, next_obs, done) transition.
# NOTE(review): ReplayBuffer stores these fields in separate arrays and does
# not reference this tuple within this part of the file.
Transition = collections.namedtuple("Transition", ["obs", "act", "rew", "next_obs", "done"])

class ReplayBuffer:
    """
    Fixed-capacity ring buffer of transitions held in float32 NumPy arrays.

    Once ``capacity`` entries have been written, new entries overwrite the
    oldest slot.
    """

    def __init__(self, capacity, obs_dim, act_dim):
        self.capacity = capacity

        def _storage(width):
            return np.zeros((capacity, width), dtype=np.float32)

        self.obs_buf = _storage(obs_dim)
        self.next_obs_buf = _storage(obs_dim)
        self.act_buf = _storage(act_dim)
        self.rew_buf = _storage(1)
        self.done_buf = _storage(1)
        self.ptr = 0   # next slot to write
        self.size = 0  # number of valid entries

    def add(self, obs, act, rew, next_obs, done):
        """Write one transition at the current slot and advance the pointer."""
        slot = self.ptr
        self.obs_buf[slot] = obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.next_obs_buf[slot] = next_obs
        self.done_buf[slot] = done

        self.ptr = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """
        Draw ``batch_size`` indices uniformly with replacement from the
        filled part of the buffer and return
        ``(obs, act, rew, next_obs, done)`` as float32 tensors on DEVICE.
        """
        idxs = np.random.randint(0, self.size, size=batch_size)
        fields = (
            self.obs_buf,
            self.act_buf,
            self.rew_buf,
            self.next_obs_buf,
            self.done_buf,
        )
        return tuple(
            torch.tensor(buf[idxs], dtype=torch.float32, device=DEVICE)
            for buf in fields
        )


In [None]:
class SACAgent:
    """
    Soft Actor-Critic agent with twin Q-networks, target networks and a
    learned entropy temperature.

    Parameters
    ----------
    obs_dim, act_dim : int
        Observation and action sizes.
    act_low, act_high : array-like
        Action bounds forwarded to the policy.
    gamma : float
        Discount factor.
    tau : float
        Soft-update coefficient for the target critics.
    alpha : float
        Initial entropy temperature (stored as ``log_alpha``).
    actor_lr, critic_lr : float
        Adam learning rates; the temperature optimizer uses ``actor_lr``.
    target_entropy : float or None
        Entropy target for temperature tuning; defaults to ``-act_dim``.
    """

    def __init__(self, obs_dim, act_dim, act_low, act_high,
                 gamma=0.99, tau=0.005, alpha=0.2,
                 actor_lr=1e-4 , critic_lr=3e-4, target_entropy=None):
        self.gamma = gamma
        self.tau = tau

        self.actor = GaussianPolicy(obs_dim, act_dim, act_low, act_high).to(DEVICE)
        self.q1 = QNetwork(obs_dim, act_dim).to(DEVICE)
        self.q2 = QNetwork(obs_dim, act_dim).to(DEVICE)
        self.q1_target = QNetwork(obs_dim, act_dim).to(DEVICE)
        self.q2_target = QNetwork(obs_dim, act_dim).to(DEVICE)
        # Targets start as exact copies of the online critics.
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())

        self.actor_opt = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.q1_opt = optim.Adam(self.q1.parameters(), lr=critic_lr)
        self.q2_opt = optim.Adam(self.q2.parameters(), lr=critic_lr)

        self.log_alpha = torch.tensor(np.log(alpha), device=DEVICE, requires_grad=True)
        self.alpha_opt = optim.Adam([self.log_alpha], lr=actor_lr)
        self.target_entropy = target_entropy if target_entropy is not None else -act_dim

        # Per-update training diagnostics (one entry per update() call).
        self.critic_losses = []
        self.actor_losses = []
        self.alpha_hist = []
        self.entropy_hist = []

    @property
    def alpha(self):
        """Current entropy temperature, ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def select_action(self, obs, eval_mode=False):
        """
        Return an action for a single observation as a NumPy array.

        With ``eval_mode=True`` the tanh-squashed mean is used; otherwise an
        action is sampled from the policy. Runs under ``torch.no_grad()``
        because the result is only converted to NumPy.
        """
        obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        with torch.no_grad():
            if eval_mode:
                squashed = torch.tanh(self.actor(obs_t).mean)
                act_mid = (self.actor.act_high + self.actor.act_low) / 2.0
                act_half = (self.actor.act_high - self.actor.act_low) / 2.0
                action = act_mid + act_half * squashed
            else:
                action, _ = self.actor.sample(obs_t)
        return action.cpu().numpy()[0]

    def _soft_update(self, online, target):
        """Polyak-average ``online`` parameters into ``target`` using ``tau``."""
        with torch.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.data.mul_(1 - self.tau)
                target_param.data.add_(self.tau * param.data)

    def update(self, replay_buffer, batch_size):
        """
        Run one SAC update from a sampled batch: both critics, the actor,
        the temperature, then the soft target update. Appends diagnostics
        to the ``*_losses`` / ``*_hist`` lists. Returns None.
        """
        obs, act, rew, next_obs, done = replay_buffer.sample(batch_size)

        # Soft Bellman target from the target critics (no gradients needed).
        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_obs)
            q1_next = self.q1_target(next_obs, next_action)
            q2_next = self.q2_target(next_obs, next_action)
            q_next = torch.min(q1_next, q2_next) - self.alpha * next_log_prob
            target_q = rew + (1 - done) * self.gamma * q_next

        # Q1, Q2 losses
        q1_loss = ((self.q1(obs, act) - target_q) ** 2).mean()
        q2_loss = ((self.q2(obs, act) - target_q) ** 2).mean()

        self.q1_opt.zero_grad()
        q1_loss.backward()
        self.q1_opt.step()

        self.q2_opt.zero_grad()
        q2_loss.backward()
        self.q2_opt.step()

        # Actor update. Alpha is detached: the temperature is trained only
        # through alpha_loss, so back-propagating into log_alpha here would
        # be wasted work (that gradient was discarded anyway).
        new_actions, log_prob = self.actor.sample(obs)
        q_new = torch.min(self.q1(obs, new_actions), self.q2(obs, new_actions))
        actor_loss = (self.alpha.detach() * log_prob - q_new).mean()

        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        # Temperature update towards the target entropy.
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_opt.zero_grad()
        alpha_loss.backward()
        self.alpha_opt.step()

        self.critic_losses.append(float((q1_loss + q2_loss).item()))
        self.actor_losses.append(float(actor_loss.item()))
        self.alpha_hist.append(float(self.alpha.item()))

        # policy entropy estimate: -E[log pi(a|s)]
        self.entropy_hist.append(float((-log_prob).mean().item()))

        # Soft target updates
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import welch

def compute_psd(signal, fs, nperseg=None):
    """
    Estimate the power spectral density of ``signal`` with Welch's method.

    When ``nperseg`` is falsy (None or 0) the segment length defaults to
    ``min(1024, len(signal))``. Returns ``(freqs, psd)`` as produced by
    ``scipy.signal.welch`` with ``scaling="density"``.
    """
    segment_len = nperseg
    if not segment_len:
        segment_len = min(1024, len(signal))
    return welch(signal, fs=fs, nperseg=segment_len, scaling="density")


def collect_regime_psds(
    env,
    regime,
    n_windows=50,
    fs=1000.0,
    action=None
):
    """
    Step ``env`` in ``regime`` and collect one PSD per window.

    Parameters
    ----------
    env : environment
        Must provide ``reset(options=...)``, ``step(action)`` and
        ``action_space.shape``; each step's info dict must carry the window
        signal under ``"lfp"``.
    regime : str
        Passed as ``options={"regime": regime}`` on every reset.
    n_windows : int
        Number of steps (windows) to collect; must be at least 1.
    fs : float
        Sampling rate handed to ``compute_psd``.
    action : array-like or None
        Action applied on every step; defaults to an all-zero action.

    Returns
    -------
    freqs : array
        Frequency grid of the first window.
    psds : np.ndarray
        One row per window (``np.vstack`` of the individual PSDs).

    Raises
    ------
    ValueError
        If ``n_windows`` is less than 1.
    RuntimeError
        If a step's info dict has no ``"lfp"`` entry.
    """
    if n_windows < 1:
        # Fail before touching the env; np.vstack([]) would fail later with
        # a far less helpful message.
        raise ValueError(f"n_windows must be >= 1, got {n_windows}")

    if action is None:
        action = np.zeros(env.action_space.shape)

    psds = []
    freqs_ref = None

    env.reset(options={"regime": regime})

    for _ in range(n_windows):
        _, _, terminated, truncated, info = env.step(action)

        lfp = info.get("lfp")
        if lfp is None:
            raise RuntimeError(
                "LFP not found in info; adjust extraction. "
                f"Available info keys: {list(info.keys())}"
            )

        freqs, psd = compute_psd(lfp, fs)
        if freqs_ref is None:
            freqs_ref = freqs
        psds.append(psd)

        if terminated or truncated:
            env.reset(options={"regime": regime})

    return freqs_ref, np.vstack(psds)


In [None]:
# Smoke test: take one zero-action step in the "normal" regime and print the
# info keys and the observation type.
# NOTE(review): `env2` is not created in this cell; it must already exist in
# the notebook session.
env2.reset(options={"regime": "normal"})
obs, reward, terminated, truncated, info = env2.step(
    np.zeros(env2.action_space.shape)
)

print("info keys:", info.keys())
print("obs type:", type(obs))
# Only dict observations have keys to list.
if isinstance(obs, dict):
    print("obs keys:", obs.keys())


info keys: dict_keys(['seizure_index', 'seizure_metric_raw', 'seizure_error_above_baseline', 'seizure_target', 'seizure_baseline', 'seizure_scale', 'burden', 'regime_label', 'seizure_z', 'pac_index', 'energy_norm', 'amp', 'freq', 'pw', 'slew_penalty', 'disease_level', 'reward', 'cost_components'])
obs type: <class 'numpy.ndarray'>


  return float(np.trapz(psd[idx], freqs[idx]))


In [None]:
import inspect
# Print the source of the normalization method currently bound on the class,
# to check which version is live in this notebook session.
print(inspect.getsource(EpilepsyDBSCombinedEnv._normalize_seizure_raw))


    def _normalize_seizure_raw(self, raw: float) -> float:
        if self.seizure_baseline is None or self.seizure_scale is None:
            return 0.0

        base = float(self.seizure_baseline)
        s = max(float(self.seizure_scale), 1e-3)

        z = (float(raw) - base) / s
        z = max(0.0, z)
        idx = z / (1.0 + z)
        return float(np.clip(idx, 0.0, 1.0))



In [None]:
import numpy as np

# 1) Create env
env = EpilepsyDBSCombinedEnv()

episode_summaries = []


# Seizure metric config
env.seizure_metric = "bandpower_ratio"   # or "line_length"
env.seizure_band = (8.0, 30.0)           # only used if bandpower_ratio
env.total_band = (1.0, 80.0)

# Calibrate baseline/scale of the raw seizure metric before training.
cal = env.calibrate_seizure_scale(n_windows=500, regime="normal", seed=0)
print(cal)

# 2) Get dimensions from Gymnasium spaces
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
act_low = env.action_space.low
act_high = env.action_space.high

# 3) Create SAC agent
agent = SACAgent(obs_dim, act_dim, act_low, act_high)

# 4) Create replay buffer
buffer_capacity = 100_000
replay_buffer = ReplayBuffer(buffer_capacity, obs_dim, act_dim)

# 5) Training hyperparameters
num_episodes = 400
max_steps_per_episode = env.max_steps
batch_size = 128
warmup_steps = 5000  # env steps of uniformly random actions before SAC acts

episode_returns = []
global_step = 0

for episode in range(num_episodes):
    # Reward-weight curriculum: seizure-only first, then energy and slew
    # penalties are phased in.
    if episode < 30:
        env.w_seiz, env.w_energy, env.w_slew = 1.0, 0.0, 0.0
    elif episode < 70:
        env.w_seiz, env.w_energy, env.w_slew = 1.0, 0.1, 0.05
    else:
        env.w_seiz, env.w_energy, env.w_slew = 1.0, 0.3, 0.1

    # Every episode is trained in the ictal regime.
    regime = "ictal"

    obs, info = env.reset(options={"regime": regime})
    done = False
    ep_return = 0.0
    steps = 0

    # ----- accumulators for printed diagnostics -----
    seizure_cost_sum = 0.0
    pac_err_sum = 0.0  # stays 0 unless you add it
    energy_cost_sum = 0.0
    slew_cost_sum = 0.0

    amp_sum = 0.0
    freq_sum = 0.0
    pw_sum = 0.0

    # seizure index diagnostics (index, not weighted cost)
    seiz_idx_sum = 0.0
    seiz_idx_max = 0.0

    while (not done) and (steps < max_steps_per_episode):
        # 1) Action selection
        if global_step < warmup_steps:
            action = env.action_space.sample()
        else:
            action = agent.select_action(obs)

        # 2) Step environment
        next_obs, reward, terminated, truncated, info = env.step(action)
        done = bool(terminated or truncated)

        # 3) Replay buffer guard + add
        if (
            np.all(np.isfinite(obs)) and
            np.all(np.isfinite(action)) and
            np.isfinite(reward) and
            np.all(np.isfinite(next_obs))
        ):
            # Store only true termination as the done flag. A time-limit
            # truncation is not a terminal state, so the critic target must
            # still bootstrap from next_obs in that case.
            replay_buffer.add(
                obs,
                action,
                np.array([reward], dtype=np.float32),
                next_obs,
                np.array([float(terminated)], dtype=np.float32),
            )

            if (global_step >= warmup_steps) and (replay_buffer.size >= batch_size):
                agent.update(replay_buffer, batch_size)
        else:
            print(f"[Warning] Non-finite transition at step {steps}, terminating episode.")
            done = True

        # 4) Accumulate metrics (per step)
        # Cost components reported by the env are already weighted.
        seizure_cost_sum += float(info["cost_components"]["seizure"])
        energy_cost_sum += float(info["cost_components"]["energy"])
        slew_cost_sum += float(info["cost_components"]["slew"])
        # pac_err_sum += float(info.get("pac_error", 0.0))  # only if you ever add it

        amp_sum += float(info["amp"])
        freq_sum += float(info["freq"])
        pw_sum += float(info["pw"])

        seiz_idx = float(info["seizure_index"])
        seiz_idx_sum += seiz_idx
        if seiz_idx > seiz_idx_max:
            seiz_idx_max = seiz_idx

        # 5) Bookkeeping
        obs = next_obs
        ep_return += float(reward)
        steps += 1
        global_step += 1

    # ----- episode averages -----
    if steps > 0:
        avg_seizure_cost = seizure_cost_sum / steps
        avg_pac_err = pac_err_sum / steps
        avg_energy_cost = energy_cost_sum / steps
        avg_slew_cost = slew_cost_sum / steps

        avg_amp = amp_sum / steps
        avg_freq = freq_sum / steps
        avg_pw = pw_sum / steps

        mean_seiz_idx = seiz_idx_sum / steps
        max_seiz_idx = seiz_idx_max
    else:
        avg_seizure_cost = avg_pac_err = avg_energy_cost = avg_slew_cost = 0.0
        avg_amp = avg_freq = avg_pw = 0.0
        mean_seiz_idx = max_seiz_idx = 0.0

    # If ictal episodes still show ~0 seizure penalty AND the index never
    # exceeds the target, the scale is likely too large: shrink it.
    if (regime == "ictal") and (steps > 0):
        target = float(getattr(env, "seizure_target", 0.05))
        if (avg_seizure_cost < 1e-6) and (max_seiz_idx < target + 1e-3):
            env.seizure_scale = max(env.seizure_scale * 0.25, 1e-12)
            print(f"[Adjust] ictal seizure penalty ~0; reducing env.seizure_scale to {env.seizure_scale:.3e}")

    episode_returns.append(ep_return)

    episode_summaries.append({
        "episode": episode + 1,
        "regime": regime,
        "return": float(ep_return),
        "steps": int(steps),
        "avg_seizure_cost": float(avg_seizure_cost),
        "avg_energy_cost": float(avg_energy_cost),
        "avg_slew_cost": float(avg_slew_cost),
        "mean_seiz_idx": float(mean_seiz_idx),
        "max_seiz_idx": float(max_seiz_idx),
        "avg_amp": float(avg_amp),
        "avg_freq": float(avg_freq),
        "avg_pw": float(avg_pw),
    })

    # ----- prints -----
    print(f"\nEpisode {episode+1}/{num_episodes}  (regime={regime})")
    print(f"Return: {ep_return:.3f} over {steps} steps\n")

    print("=== Average Cost Components (weighted) ===")
    print(f"  Seizure term         : {avg_seizure_cost:.4f}")
    print(f"  PAC error term       : {avg_pac_err:.4f}")
    print(f"  Energy term          : {avg_energy_cost:.4f}")
    print(f"  Slew term            : {avg_slew_cost:.4f}")

    print("\n=== Seizure Index Diagnostics (raw index) ===")
    print(f"  Mean seizure_index   : {mean_seiz_idx:.4f}")
    print(f"  Max seizure_index    : {max_seiz_idx:.4f}")
    print(f"  seizure_target       : {float(getattr(env, 'seizure_target', 0.05)):.4f}")

    print("\n=== Average DBS Parameters ===")
    print(f"  Mean amplitude (mA)  : {avg_amp:.3f}")
    print(f"  Mean frequency (Hz)  : {avg_freq:.3f}")
    print(f"  Mean pulse width (µs): {avg_pw:.3f}")
    print("\n" + "-"*60 + "\n")


  return float(np.trapz(psd[idx], freqs[idx]))


NameError: name 'seiz_error' is not defined

In [None]:
# Fresh env configured to use the band-power-ratio seizure metric.
env = EpilepsyDBSCombinedEnv()
env.seizure_metric = "bandpower_ratio"
env.seizure_band = (8.0, 30.0)
env.total_band = (1.0, 80.0)

# 1) Reset into the regime you want calibrated
# (calibrate_seizure_scale also resets the env itself before sampling.)
env.reset(seed=0, options={"regime": "normal"})

# 2) Calibrate WITHOUT regime kwarg (its default regime is "normal")
cal = env.calibrate_seizure_scale(n_windows=50, seed=0)
print("cal:", cal)

# Zero action for every probe step.
a0 = np.array([0.0, 0.0, 0.0], dtype=np.float32)

# Collect the raw seizure metric over 20 further steps to inspect its range.
raws = []
for _ in range(20):
    obs, r, term, trunc, info = env.step(a0)
    raws.append(info["seizure_metric_raw"])

print("raw min/max:", float(np.min(raws)), float(np.max(raws)))



In [None]:
# Show the current normalization configuration of the env.
print("seizure_metric:", getattr(env, "seizure_metric", None))
print("seizure_baseline:", getattr(env, "seizure_baseline", None))
print("seizure_scale:", getattr(env, "seizure_scale", None))

def debug_norm_once(regime="normal", steps=10):
    """
    Step the module-level ``env`` with zero actions in ``regime`` and print
    the raw seizure metric next to the normalized index for each step.

    A missing or None info entry is printed as ``nan`` instead of raising a
    TypeError in the ``:.6f`` format.
    """
    def _as_float(value):
        return float("nan") if value is None else float(value)

    obs, info = env.reset(options={"regime": regime})
    for i in range(steps):
        obs, r, term, trunc, info = env.step(np.zeros(3, dtype=np.float32))
        raw = _as_float(info.get("seizure_metric_raw"))
        idx = _as_float(info.get("seizure_index"))
        print(f"{regime:8s} step {i:02d}: raw={raw:.6f}  idx={idx:.6f}")
        if term or trunc:
            obs, info = env.reset(options={"regime": regime})

debug_norm_once("normal", steps=10)
debug_norm_once("ictal", steps=10)


In [None]:
# Re-calibrate baseline/scale in the "normal" regime and show what was set.
cal = env.calibrate_seizure_scale(n_windows=300, regime="normal", seed=0)
print("CAL:", cal)
print("baseline:", getattr(env, "seizure_baseline", None))
print("scale   :", getattr(env, "seizure_scale", None))

# quick probe of raw range under the same regime, zero action
obs, info = env.reset(seed=0, options={"regime": "normal"})
raws, idxs = [], []
for _ in range(50):
    obs, r, term, trunc, info = env.step(np.zeros(3, dtype=np.float32))
    raws.append(info["seizure_metric_raw"])
    idxs.append(info["seizure_index"])
    # Start a new episode (same seed and regime) if this one ended.
    if term or trunc:
        obs, info = env.reset(seed=0, options={"regime": "normal"})

# Summary of the raw metric and of the normalized index over the probe.
print("raw min/mean/max:", float(np.min(raws)), float(np.mean(raws)), float(np.max(raws)))
print("idx min/mean/max:", float(np.min(idxs)), float(np.mean(idxs)), float(np.max(idxs)))


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def run_diagnostic_episode(env, regime="normal", n_steps=150, action=None, seed=0, capture_step=10):
    """
    Reset `env` into `regime` and step it up to `n_steps` times, logging
    per-step diagnostics. Stops early when a step reports terminated/truncated.

    action: None -> zero action every step; a callable -> action(step) is
            evaluated each step; anything else -> used as a constant action.
    capture_step: step index at which one window trace (t/v/u) is captured,
            taken from `info` when it carries t_window/v_window/u_window,
            otherwise from the first matching attribute triplet on `env`.

    Returns a dict with "regime", the arrays "raw", "idx", "rew", "amp",
    "freq", "pw", and "cap" (the captured trace; its entries stay None when
    nothing could be captured).
    """
    info_keys = ("t_window", "v_window", "u_window")
    attr_triplets = (
        ("_last_t_window", "_last_v_window", "_last_u_window"),
        ("_t_window", "_v_window", "_u_window"),
    )

    obs, info = env.reset(seed=seed, options={"regime": regime})

    series = {key: [] for key in ("raw", "idx", "rew", "amp", "freq", "pw")}
    snapshot = {"t": None, "v": None, "u": None, "step": None}

    for step in range(n_steps):
        if action is None:
            chosen = np.zeros(3, dtype=np.float32)
        else:
            chosen = np.asarray(action(step) if callable(action) else action, dtype=np.float32)

        obs, reward, terminated, truncated, info = env.step(chosen)

        series["raw"].append(float(info.get("seizure_metric_raw", np.nan)))
        series["idx"].append(float(info.get("seizure_index", np.nan)))
        series["rew"].append(float(reward))

        # Stimulation parameters: prefer info["stim_params"], then
        # env.current_params, and fall back to NaN when neither is available.
        stim = info.get("stim_params", None)
        if stim is not None:
            amp = stim.get("amp_mA", np.nan)
            freq = stim.get("freq_Hz", np.nan)
            pw = stim.get("pw_us", stim.get("pw_ms", np.nan))  # either unit key
        elif getattr(env, "current_params", None) is not None:
            amp, freq, pw = env.current_params
        else:
            amp = freq = pw = np.nan
        for key, value in (("amp", amp), ("freq", freq), ("pw", pw)):
            series[key].append(float(value))

        if step == capture_step:
            if all(key in info for key in info_keys):
                traces = [info[key] for key in info_keys]
            else:
                # First attribute triplet that is fully present on the env.
                traces = next(
                    (
                        [getattr(env, name) for name in names]
                        for names in attr_triplets
                        if all(hasattr(env, name) for name in names)
                    ),
                    None,
                )
            if traces is not None:
                snapshot["t"], snapshot["v"], snapshot["u"] = (np.asarray(x) for x in traces)
                snapshot["step"] = step

        if terminated or truncated:
            break

    result = {"regime": regime}
    result.update({key: np.asarray(values) for key, values in series.items()})
    result["cap"] = snapshot
    return result


def plot_diagnostics(result, show_hist=True, psd=True):
    """
    Plot and summarise one result dict produced by run_diagnostic_episode.

    result: dict with "regime", the per-step arrays "raw", "idx", "amp",
            "freq", "pw", and "cap" (captured window with "t", "v", "u", "step").
    show_hist: also draw histograms of the finite raw-metric and index values.
    psd: also draw a Welch PSD of the captured v(t) trace when it holds at
            least 32 samples.

    Every figure is shown immediately with plt.show(). Finally the
    min/mean/max of each series is printed; an empty or all-NaN series is
    reported as nan instead of raising or emitting NumPy warnings.
    """
    regime = result["regime"]
    raw = result["raw"]
    idx = result["idx"]
    amp = result["amp"]
    freq = result["freq"]
    pw = result["pw"]
    cap = result["cap"]

    steps = np.arange(len(raw))

    def _line_figure(curves, title, xlabel, ylabel=None):
        """Show one 10x4 line figure; `curves` holds (x, y, label-or-None) tuples."""
        plt.figure(figsize=(10, 4))
        has_label = False
        for x, y, label in curves:
            if label is None:
                plt.plot(x, y)
            else:
                plt.plot(x, y, label=label)
                has_label = True
        plt.title(title)
        plt.xlabel(xlabel)
        if ylabel is not None:
            plt.ylabel(ylabel)
        if has_label:
            plt.legend()
        plt.show()

    def _hist_figure(values, title, xlabel):
        """Show a 30-bin histogram of the finite entries of `values`."""
        plt.figure(figsize=(10, 4))
        plt.hist(values[np.isfinite(values)], bins=30)
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel("count")
        plt.show()

    def _min_mean_max(values):
        """NaN-ignoring (min, mean, max); all nan when no usable sample exists."""
        arr = np.asarray(values, dtype=float)
        if arr.size == 0 or np.all(np.isnan(arr)):
            return (float("nan"),) * 3
        return float(np.nanmin(arr)), float(np.nanmean(arr)), float(np.nanmax(arr))

    # 1) raw seizure metric
    _line_figure([(steps, raw, None)], f"{regime}: raw seizure metric", "step", "raw")

    # 2) seizure index
    _line_figure([(steps, idx, None)], f"{regime}: seizure_index", "step", "index")

    # 3) DBS params
    _line_figure(
        [
            (steps, amp, "amp (mA)"),
            (steps, freq, "freq (Hz)"),
            (steps, pw, "pw (us or ms)"),
        ],
        f"{regime}: stimulation parameters",
        "step",
    )

    # 4) example window traces
    if cap["t"] is not None and cap["v"] is not None:
        curves = [(cap["t"], cap["v"], "v(t)")]
        if cap["u"] is not None:
            curves.append((cap["t"], cap["u"], "u(t)"))
        _line_figure(
            curves,
            f"{regime}: example window trace (step {cap['step']})",
            "time (s or internal units)",
        )

        # 5) PSD (optional)
        if psd:
            x = np.asarray(cap["v"], dtype=float)
            x = x - np.mean(x)
            if x.size >= 32:
                # Estimate fs from the time axis when it is strictly increasing
                # and has at least two samples; otherwise use the fallback
                # (a shorter axis would yield a NaN sample rate).
                t = np.asarray(cap["t"], dtype=float)
                if t.size >= 2 and np.all(np.diff(t) > 0):
                    dt = float(np.median(np.diff(t)))
                    fs = 1.0 / dt
                else:
                    fs = 1000.0  # fallback

                from scipy.signal import welch
                f, Pxx = welch(x, fs=fs, nperseg=min(256, x.size))
                plt.figure(figsize=(10, 4))
                plt.semilogy(f, Pxx)
                plt.title(f"{regime}: PSD of v(t) (step {cap['step']})")
                plt.xlabel("Hz")
                plt.ylabel("PSD")
                plt.xlim(0, 80)
                plt.show()

    # 6) histograms
    if show_hist:
        _hist_figure(raw, f"{regime}: raw metric histogram", "raw")
        _hist_figure(idx, f"{regime}: index histogram", "index")

    # Print summary stats
    print(f"== {regime} summary ==")
    for name, values in (("raw", raw), ("idx", idx), ("amp", amp), ("freq", freq), ("pw", pw)):
        lo, mean, hi = _min_mean_max(values)
        print(f"{name:<4} min/mean/max:", lo, mean, hi)


# ---- Usage ----
# This cell mutates and steps the shared notebook-level `env`; the statement
# order (calibrate -> reset -> diagnostics) is kept as written.
# 1) Calibrate
cal = env.calibrate_seizure_scale(n_windows=300, regime="normal", seed=0)
print("Calibration:", cal)

# 2) Reset after calibration
env.reset(seed=0, options={"regime": "normal"})

# 3) Diagnostics per regime
# Each regime is rolled out for up to 150 zero-action steps (action=None) with
# seed 0, a window trace is captured at step 10, and all figures are shown.
for reg in ["normal", "preictal", "ictal"]:
    res = run_diagnostic_episode(env, regime=reg, n_steps=150, action=None, seed=0, capture_step=10)
    plot_diagnostics(res, show_hist=True, psd=True)


In [None]:
# Sanity check: seizure_index under no stim in different regimes
def mean_idx(regime, n=200):
    """
    Step the global `env` `n` times with a zero action in `regime` and return
    (mean, std) of info["seizure_index"]. The env is reset into the same
    regime whenever an episode terminates or is truncated.
    """
    def restart():
        env.reset(options={"regime": regime})

    restart()
    samples = []
    for _ in range(n):
        step_out = env.step(np.zeros(3, dtype=np.float32))
        _, _, term, trunc, info = step_out
        samples.append(info["seizure_index"])
        if term or trunc:
            restart()
    mean_val = float(np.mean(samples))
    std_val = float(np.std(samples))
    return mean_val, std_val

# Report (mean, std) of the no-stim seizure index for each regime.
for _regime in ("normal", "preictal", "ictal"):
    print(f"{_regime:<7}:", mean_idx(_regime))


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Training curve: per-episode return plus an optional moving average.
_fig, _ax = plt.subplots(figsize=(8, 4))
_ax.plot(episode_returns, label="Episode return")

# optional: moving average (only once at least `window` episodes exist)
window = 10
if len(episode_returns) >= window:
    _kernel = np.ones(window) / window
    ma = np.convolve(episode_returns, _kernel, mode="valid")
    # "valid" drops the first window-1 points, so the average starts there.
    _ma_x = range(window - 1, len(episode_returns))
    _ax.plot(_ma_x, ma, linestyle="--", label=f"{window}-ep moving avg")

_ax.set_xlabel("Episode")
_ax.set_ylabel("Return")
_ax.set_title("DBS RL training: episode returns")
_ax.legend()
_fig.tight_layout()
plt.show()


In [None]:
# Roll the shared `env` through one full episode starting in the ictal regime,
# logging the per-step seizure index and disease level reported in `info`.
obs, info = env.reset(options={"regime": "ictal"})  # start ictal
seiz_traj = []
dis_traj = []

# Step until the env reports terminated or truncated.
done = False
while not done:
    action = np.zeros(3, dtype=float)  # or your trained policy(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    seiz_traj.append(info["seizure_index"])
    dis_traj.append(info["disease_level"])
    done = terminated or truncated

# Both trajectories share one axis, indexed by RL step.
plt.figure(figsize=(8, 4))
plt.plot(seiz_traj, label="seizure_index")
plt.plot(dis_traj, label="disease_level")
# NOTE(review): other cells read this attribute defensively via
# getattr(env, "seizure_baseline", ...); confirm it is always set here and
# that it is on the same scale as seizure_index.
plt.axhline(env.seizure_baseline, linestyle="--", label="baseline")
plt.legend()
plt.xlabel("RL step")
plt.ylabel("Value")
plt.title("Seizure and disease dynamics in one episode")
plt.tight_layout()
plt.show()


In [None]:
# Start from an epileptic regime
obs, info = env.reset(options={"regime": "ictal"})

seiz_list = []
disease_list = []
reward_list = []
time_list = []

# Run the SAC agent for at most env.max_steps steps, logging per-step values
# against elapsed time (step index times env.step_T).
for step in range(env.max_steps):
    # Use the SAC agent's learned policy
    action = agent.select_action(obs)

    obs, reward, terminated, truncated, info = env.step(action)

    time_list.append(step * env.step_T)
    seiz_list.append(info["seizure_index"])
    disease_list.append(info["disease_level"])
    reward_list.append(reward)

    if terminated or truncated:
        break

import matplotlib.pyplot as plt

# One stacked panel per logged series, all on the shared time axis.
fig, axes = plt.subplots(3, 1, figsize=(8, 8), sharex=True)

_panels = (
    (seiz_list, "Seizure index"),
    (disease_list, "Disease level"),
    (reward_list, "Reward"),
)
for _ax, (_series, _ylabel) in zip(axes, _panels):
    _ax.plot(time_list, _series)
    _ax.set_ylabel(_ylabel)
axes[2].set_xlabel("Time (s)")

plt.tight_layout()
plt.show()


In [None]:
# Rebind the notebook-level `env` to a fresh environment that records its
# best episode; later cells that use `env` see this instance.
env = EpilepsyDBSCombinedEnv(
    default_regime="ictal",
    log_best_episodes=True,   # <--- crucial
    n_best_episodes=1
)

obs, info = env.reset(options={"regime": "ictal"})

# Run one complete episode with a zero action at every step.
done = False
while not done:
    # for now, just keep DBS parameters fixed:
    action = np.zeros(3, dtype=float)
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

# Fetch the logged episode(s); only one is kept because n_best_episodes=1.
episodes = env.get_best_episodes()
episode = episodes[0]   # only one


In [None]:
import matplotlib.pyplot as plt
import numpy as np

t = episode["t"]       # shape (N_samples,)
v = episode["v"]       # JR output
u = episode["u"]       # DBS drive signal

fig, (ax_stim, ax_seiz) = plt.subplots(
    2, 1, figsize=(10, 6), sharex=True,
    gridspec_kw={"height_ratios": [1, 2]}
)

# Top: stimulation signal (continuous DBS drive), smoothed with a moving
# average of |u|. The window is capped at len(u): with mode="same" np.convolve
# returns max(len(u), len(kernel)) samples, so a 200-tap kernel on a shorter
# signal would no longer line up with `t`.
_smooth_n = min(200, len(u))
if _smooth_n > 0:
    u_env = np.convolve(np.abs(u), np.ones(_smooth_n) / _smooth_n, mode="same")
else:
    u_env = np.abs(np.asarray(u, dtype=float))  # nothing to smooth
ax_stim.plot(t, u_env)
ax_stim.set_ylabel("DBS |u(t)| envelope")
ax_stim.set_title("DBS stimulation")

# Bottom: 'seizure activity' – here using JR voltage
ax_seiz.plot(t, v)
ax_seiz.set_xlabel("Time (s)")
ax_seiz.set_ylabel("JR output v(t)")
ax_seiz.set_title("Network activity")

plt.tight_layout()
plt.show()


In [None]:
seiz_idx = episode["seizure_index"]   # shape (num_steps,)
rewards  = episode["rewards"]         # shape (num_steps,)

# Disease level isn't in the episode log yet; we read it during the episode.
# For now, do a fresh run with explicit logging outside the env:
env2 = EpilepsyDBSCombinedEnv(default_regime="ictal")
obs, info = env2.reset(options={"regime": "ictal"})

seiz_traj = []
dis_traj  = []
rew_traj  = []
t_steps   = []

done = False
step_idx = 0
while not done:
    action = np.zeros(3, dtype=float)
    obs, reward, terminated, truncated, info = env2.step(action)

    seiz_traj.append(info["seizure_index"])
    dis_traj.append(info["disease_level"])
    # Log the reward returned by step(), as every other cell does, rather
    # than depending on an optional info["reward"] entry.
    rew_traj.append(reward)
    t_steps.append(step_idx * env2.step_T)

    step_idx += 1
    done = terminated or truncated

seiz_traj = np.array(seiz_traj)
dis_traj  = np.array(dis_traj)
rew_traj  = np.array(rew_traj)
t_steps   = np.array(t_steps)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

ax1.plot(t_steps, seiz_traj, label="seizure_index")
ax1.plot(t_steps, dis_traj, label="disease_level")
# Falls back to 0.7 when the env does not expose seizure_baseline.
ax1.axhline(getattr(env2, "seizure_baseline", 0.7),
            linestyle="--", label="baseline")
ax1.set_ylabel("Value")
ax1.set_title("RL-level seizure & disease dynamics")
ax1.legend()

ax2.plot(t_steps, rew_traj)
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("Reward")
ax2.set_title("Reward per step")

plt.tight_layout()
plt.show()


In [None]:
best = env.get_best_episodes()

# Save each best episode as a separate .npz file
import numpy as np

# Episode-log entries written to disk, in archive order.
_EPISODE_FIELDS = (
    "t",
    "v",
    "u",
    "theta",
    "amp_fast",
    "seizure_index",
    "pac_index",
    "energy_norm",
    "amp",
    "freq",
    "pw",
    "rewards",
    "total_reward",
)

for i, ep in enumerate(best):
    _arrays = {field: ep[field] for field in _EPISODE_FIELDS}
    np.savez_compressed(f"best_episode_{i}.npz", **_arrays)


In [None]:
def evaluate_fixed_dbs(
    env,
    regime="ictal",
    amp=5.0,
    freq=145.0,
    pw=90.0,
    n_steps=200,
):
    """
    Reset `env` into `regime`, pin its DBS parameters to the clipped
    (amp, freq, pw) and step it with zero actions for up to `n_steps` steps
    (stopping early on terminated/truncated).

    Returns a dict of float arrays: "rewards", "seizure", "pac", "energy"
    (read from the step info) and "disease" (env.disease_level after each step).
    """
    obs, info = env.reset(options={"regime": regime})

    # Pin both the current and the previous parameter set to the fixed DBS.
    clipped_amp, clipped_freq, clipped_pw = env._clip_params(amp, freq, pw)
    env.current_params = (clipped_amp, clipped_freq, clipped_pw)
    env.prev_params = env.current_params

    log = {name: [] for name in ("rewards", "seizure", "pac", "energy", "disease")}

    remaining = n_steps
    while remaining > 0:
        remaining -= 1
        # zero action = keep params unchanged
        hold = np.zeros(3, dtype=float)
        obs, reward, terminated, truncated, info_step = env.step(hold)

        log["rewards"].append(reward)
        log["seizure"].append(info_step["seizure_index"])
        log["pac"].append(info_step["pac_index"])
        log["energy"].append(info_step["energy_norm"])
        log["disease"].append(env.disease_level)

        if terminated or truncated:
            break

    return {name: np.array(values, dtype=float) for name, values in log.items()}


In [None]:
# Fixed DBS settings to compare; every entry holds the amp/freq/pw keyword
# arguments of evaluate_fixed_dbs.
param_sets = {
    "SANTE_like_high": dict(amp=5.0, freq=145.0, pw=90.0),
    "SANTE_like_low": dict(amp=2.5, freq=145.0, pw=90.0),
    "init_130Hz": dict(amp=2.0, freq=130.0, pw=120.0),
    "no_stim": dict(amp=0.0, freq=130.0, pw=120.0),
}

# Evaluate each setting over one full-length episode that starts ictal.
results = {}
for name, p in param_sets.items():
    out = evaluate_fixed_dbs(env, regime="ictal", n_steps=env.max_steps, **p)
    results[name] = out
    total_R = out["rewards"].sum()
    mean_R = out["rewards"].mean()
    print(f"{name}: total reward = {total_R:.3f}, mean reward/step = {mean_R:.3f}")


In [None]:
import matplotlib.pyplot as plt

best = env.get_best_episodes()[-1]   # last run
t = best["t"]
v = best["v"]
u = best["u"]
seiz = best["seizure_index"]         # per-step values – you can upsample in time if you want
rewards = best["rewards"]

fig, axes = plt.subplots(3, 1, sharex=True, figsize=(10, 8))

axes[0].plot(t, u)
axes[0].set_ylabel("DBS drive u(t)")
axes[0].set_title("DBS waveform")

axes[1].plot(t, v)
axes[1].set_ylabel("JR output v(t)")

# The three panels share one x-axis (sharex=True) and the upper two are drawn
# against `t`, so the per-step seizure index must be placed in time as well
# instead of at raw step indices. Each step is drawn at its start time.
# NOTE(review): assumes `t` is in seconds and each RL step spans env.step_T
# seconds, as in the other time-axis plots in this notebook — confirm.
_t0 = float(t[0]) if len(t) else 0.0
_seiz_t = [_t0 + k * env.step_T for k in range(len(seiz))]
axes[2].step(_seiz_t, seiz, where="post")
axes[2].set_ylabel("Seizure index")
axes[2].set_xlabel("Time (s)")

plt.tight_layout()
plt.show()


In [None]:
def mean_seiz(env, regime, n=200):
    """
    Step `env` `n` times with a zero action in `regime` and return
    (mean, std) of info["seizure_index"].

    The regime is selected through reset(options={"regime": regime}), the
    same call form used by the other cells. This avoids the previous
    `reset(regime=...)` / `except TypeError` probing, which overwrote
    env.default_regime as a side effect and could hide a genuine TypeError
    raised inside reset. The env is reset into the same regime whenever an
    episode terminates or is truncated.
    """
    env.reset(options={"regime": regime})

    zs = []
    for _ in range(n):
        _, _, term, trunc, info = env.step(np.array([0., 0., 0.], dtype=np.float32))
        zs.append(info["seizure_index"])
        if term or trunc:
            env.reset(options={"regime": regime})
    return float(np.mean(zs)), float(np.std(zs))

# Report (mean, std) of the zero-action seizure index for each regime.
for _regime in ("normal", "preictal", "ictal"):
    print(f"{_regime:<7}:", mean_seiz(env, _regime))


In [None]:
import pandas as pd

# One summary row per fixed-DBS setting evaluated into `results`.
rows = [
    {
        "policy": name,
        "mean_reward": out["rewards"].mean(),
        "total_reward": out["rewards"].sum(),
        "mean_seiz": out["seizure"].mean(),
        "final_disease": out["disease"][-1],
        "mean_energy": out["energy"].mean(),
    }
    for name, out in results.items()
]

df = pd.DataFrame(rows)
print(df)


In [None]:
import numpy as np
import pandas as pd

def _get_policy_action(policy, obs, deterministic=True):
    """
    Query `policy` for an action and return it as a flat float array.

    Supported interfaces, tried in this order:
    1. policy.predict(obs, deterministic=...) returning (action, state)
       (Stable-Baselines3 style);
    2. the first of policy.act / policy.select_action / policy.get_action
       that exists, called with `deterministic=` when it accepts that keyword
       and with `obs` only otherwise.

    The method signature is inspected instead of probing with try/except, so
    a TypeError raised inside the policy is not swallowed and the policy is
    not invoked a second time. Probing is kept only for callables whose
    signature cannot answer the question (not introspectable, or **kwargs).

    Raises AttributeError when none of the interfaces is available.
    """
    import inspect

    # 1) Stable-baselines3 style: policy.predict(obs, deterministic=True)
    if hasattr(policy, "predict"):
        a, _ = policy.predict(obs, deterministic=deterministic)
        return np.asarray(a, dtype=float).reshape(-1)

    # 2) Common custom: policy.act / policy.select_action / policy.get_action
    for name in ("act", "select_action", "get_action"):
        if not hasattr(policy, name):
            continue
        fn = getattr(policy, name)

        try:
            params = inspect.signature(fn).parameters
        except (TypeError, ValueError):
            params = None  # builtin or otherwise opaque callable

        if params is not None and "deterministic" in params:
            a = fn(obs, deterministic=deterministic)
        elif params is None or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        ):
            # Cannot tell from the signature: probe as before.
            try:
                a = fn(obs, deterministic=deterministic)
            except TypeError:
                a = fn(obs)
        else:
            a = fn(obs)
        return np.asarray(a, dtype=float).reshape(-1)

    raise AttributeError("Could not infer how to get actions from `policy`. Provide a wrapper with a .predict() or .act().")


def eval_policy(env, policy=None, regime="normal", n_episodes=5, max_steps=100, deterministic=True, kind="learned"):
    """
    Roll out `n_episodes` episodes of at most `max_steps` steps in `regime`
    and return one DataFrame row per episode.

    kind: "no_action" -> zero action every step; "random_action" -> uniform
          samples in [-1, 1) from NumPy's global RNG; any other value ->
          the action comes from `policy` via _get_policy_action.

    Columns: regime, policy (the `kind` string), episode_return, NaN-ignoring
    means of seizure_index / energy_norm / slew_penalty read from the step
    info (NaN is logged for a missing key), the final disease_level, and the
    number of steps taken.
    """
    tracked = ("seizure_index", "energy_norm", "slew_penalty", "disease_level")

    def choose_action(observation):
        if kind == "no_action":
            return np.zeros(3, dtype=float)
        if kind == "random_action":
            return np.random.uniform(-1.0, 1.0, size=(3,))
        return _get_policy_action(policy, observation, deterministic=deterministic)

    records = []
    for _episode in range(n_episodes):
        obs, info = env.reset(options={"regime": regime})
        total = 0.0
        history = {key: [] for key in tracked}

        for _step in range(max_steps):
            obs, reward, terminated, truncated, info = env.step(choose_action(obs))
            total += float(reward)
            for key in tracked:
                history[key].append(float(info.get(key, np.nan)))
            if terminated or truncated:
                break

        disease = history["disease_level"]
        records.append({
            "regime": regime,
            "policy": kind,
            "episode_return": total,
            "mean_seizure": float(np.nanmean(history["seizure_index"])),
            "mean_energy": float(np.nanmean(history["energy_norm"])),
            "mean_slew": float(np.nanmean(history["slew_penalty"])),
            "final_disease": float(disease[-1]) if disease else np.nan,
            "steps": len(history["seizure_index"]),
        })
    return pd.DataFrame(records)


# ---- Run evaluation ----
# Set `policy` to your trained SAC policy/agent object.
# Example:
# policy = agent
policy = None  # <-- CHANGE THIS to your trained SAC agent/policy

regimes = ["normal", "preictal", "ictal"]
all_rows = []

# Per regime: the two reference behaviours first, then the learned policy
# when one has been provided above.
for r in regimes:
    for _kind in ("no_action", "random_action"):
        all_rows.append(
            eval_policy(env, policy=None, regime=r, n_episodes=5, max_steps=env.max_steps, kind=_kind)
        )
    if policy is not None:
        all_rows.append(
            eval_policy(
                env,
                policy=policy,
                regime=r,
                n_episodes=5,
                max_steps=env.max_steps,
                kind="learned",
                deterministic=True,
            )
        )

df_eval = pd.concat(all_rows, ignore_index=True)

# Output column -> (source column, aggregation) for the per-(regime, policy) summary.
_SUMMARY_SPEC = {
    "mean_return": ("episode_return", "mean"),
    "std_return": ("episode_return", "std"),
    "mean_seizure": ("mean_seizure", "mean"),
    "mean_energy": ("mean_energy", "mean"),
    "mean_slew": ("mean_slew", "mean"),
    "final_disease": ("final_disease", "mean"),
    "mean_steps": ("steps", "mean"),
}
summary = df_eval.groupby(["regime", "policy"]).agg(**_SUMMARY_SPEC).reset_index()

summary


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(episode_summaries)

# choose a regime to rank within (ictal is most meaningful)
df_ictal = df[df["regime"] == "ictal"].copy()

# best episodes by return
top = df_ictal.sort_values("return", ascending=False).head(10)

print(top[["episode", "return", "mean_seiz_idx", "avg_amp", "avg_freq", "avg_pw"]])

# Scatter plots of param sets for top episodes: one figure per parameter pair.
# column -> (short name used in the title, axis label)
_PARAM_AXES = {
    "avg_amp": ("amp", "Mean amplitude (mA)"),
    "avg_freq": ("freq", "Mean frequency (Hz)"),
    "avg_pw": ("pw", "Mean pulse width (us)"),
}
for _xcol, _ycol in (("avg_amp", "avg_freq"), ("avg_amp", "avg_pw"), ("avg_freq", "avg_pw")):
    _xshort, _xlabel = _PARAM_AXES[_xcol]
    _yshort, _ylabel = _PARAM_AXES[_ycol]
    plt.figure()
    plt.scatter(top[_xcol], top[_ycol])
    plt.xlabel(_xlabel)
    plt.ylabel(_ylabel)
    plt.title(f"Top-10 ictal episodes: mean {_xshort} vs mean {_yshort}")
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Episode return curve
_fig, _ax = plt.subplots()
_ax.plot(episode_returns)
_ax.set_xlabel("Episode")
_ax.set_ylabel("Return")
_ax.set_title("Episode returns")
plt.show()

# SAC internal curves (if available)
def safe_plot(arr, title, ylabel):
    """
    Plot `arr` against the update step in its own figure, or print a
    "[Missing]" notice when `arr` is None or empty.
    """
    has_data = arr is not None and len(arr) != 0
    if has_data:
        plt.figure()
        plt.plot(arr)
        plt.xlabel("Update step")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.show()
    else:
        print(f"[Missing] {title} not logged")

# Plot whichever training histories the agent exposes: (attribute, title, y-label).
for _attr, _title, _ylabel in (
    ("critic_losses", "Critic loss", "loss"),
    ("actor_losses", "Actor loss", "loss"),
    ("alpha_hist", "Alpha (temperature)", "alpha"),
    ("entropy_hist", "Policy entropy estimate", "entropy"),
):
    safe_plot(getattr(agent, _attr, None), _title, _ylabel)


In [None]:
# Fixed stimulation settings used as baselines: name -> amp / freq / pw.
BASELINES = {
    label: {"amp": amp, "freq": freq, "pw": pw}
    for label, amp, freq, pw in (
        ("SANTE_low", 1.0, 145.0, 90.0),
        ("SANTE_high", 3.0, 145.0, 90.0),
        ("standard_130Hz", 2.0, 130.0, 90.0),
        ("low_freq", 2.0, 50.0, 90.0),
    )
}

import numpy as np
import pandas as pd

def rollout_fixed(env, params, regime="ictal", seed=0):
    """
    Run one episode (at most env.max_steps steps) in `regime` with the DBS
    parameters in `params` ({"amp", "freq", "pw"}) and return summary metrics.

    The parameters are written to env.current_params / env.prev_params after
    the reset, clipped with env._clip_params when the env provides it (the
    same set-up evaluate_fixed_dbs uses). Every step then sends
    env.params_to_action(amp, freq, pw) when the env defines that helper,
    and a zero action otherwise.
    NOTE(review): the zero-action fallback assumes a zero action leaves the
    parameters unchanged, as evaluate_fixed_dbs also assumes — confirm.

    Energy and slew costs are read from info["cost_components"] when present,
    falling back to info["energy_norm"] / info["slew_penalty"] and finally NaN.

    Returns a dict with "return", "mean_seizure_index", "max_seizure_index",
    "mean_energy_cost", "mean_slew_cost" and "steps"; the statistics are NaN
    when no step was taken.
    """
    obs, info = env.reset(seed=seed, options={"regime": regime})

    amp, freq, pw = params["amp"], params["freq"], params["pw"]
    clip = getattr(env, "_clip_params", None)
    if callable(clip):
        amp, freq, pw = clip(amp, freq, pw)
    env.current_params = (amp, freq, pw)
    env.prev_params = env.current_params

    to_action = getattr(env, "params_to_action", None)

    rets = 0.0
    seiz = []
    energy = []
    slew = []

    for t in range(env.max_steps):
        if callable(to_action):
            a = to_action(params["amp"], params["freq"], params["pw"])
        else:
            a = np.zeros(3, dtype=float)
        obs, r, term, trunc, info = env.step(a)
        rets += float(r)
        seiz.append(float(info["seizure_index"]))
        costs = info.get("cost_components") or {}
        energy.append(float(costs.get("energy", info.get("energy_norm", np.nan))))
        slew.append(float(costs.get("slew", info.get("slew_penalty", np.nan))))
        if term or trunc:
            break

    def _stat(fn, values):
        # np.max raises on an empty sequence; report NaN for an empty rollout.
        return float(fn(values)) if values else float("nan")

    return {
        "return": rets,
        "mean_seizure_index": _stat(np.mean, seiz),
        "max_seizure_index": _stat(np.max, seiz),
        "mean_energy_cost": _stat(np.mean, energy),
        "mean_slew_cost": _stat(np.mean, slew),
        "steps": len(seiz),
    }


In [None]:
# Roll out every fixed baseline from an ictal start (seed 0) and tabulate the
# results, best return first.
rows = []
for name, p in BASELINES.items():
    out = rollout_fixed(env, p, regime="ictal", seed=0)
    out["policy"] = name
    rows.append(out)

df_base = pd.DataFrame(rows)
_ranked = df_base.sort_values(by="return", ascending=False)
print(_ranked)


In [None]:
# Switch the shared env to the line-length seizure metric, then re-run the
# calibration on the normal regime and print what it returns.
# NOTE(review): an earlier cell calls calibrate_seizure_scale "WITHOUT regime
# kwarg" — confirm the method accepts `regime=` in the current env version.
env.seizure_metric = "line_length"  # use this if bandpower seems unstable at 100 ms
print(env.calibrate_seizure_scale(n_windows=300, regime="normal", seed=0))

def mean_index(regime, n=100):
    """
    Step the global `env` `n` times with a zero action in `regime` and return
    (mean, std) of info["seizure_index"].

    Uses the Gymnasium call forms found in the rest of the notebook: the
    regime goes through reset(options={"regime": ...}) and step() is unpacked
    as (obs, reward, terminated, truncated, info). The env is reset into the
    same regime whenever an episode terminates or is truncated.
    """
    env.reset(options={"regime": regime})
    idxs = []
    for _ in range(n):
        obs, r, term, trunc, info = env.step(np.array([0.0, 0.0, 0.0], dtype=np.float32))
        idxs.append(info["seizure_index"])
        if term or trunc:
            env.reset(options={"regime": regime})
    return float(np.mean(idxs)), float(np.std(idxs))

# Report (mean, std) of the zero-action seizure index under the new metric.
for _regime in ("normal", "ictal", "preictal"):
    print(f"{_regime:<6} mean/std:", mean_index(_regime))


In [None]:
import matplotlib.pyplot as plt

def _maybe_get(obj, names):
    """Return the value of the first attribute in `names` that `obj` has, else None."""
    first_hit = next((candidate for candidate in names if hasattr(obj, candidate)), None)
    if first_hit is None:
        return None
    return getattr(obj, first_hit)

# Replace `agent` with your SAC agent variable if it exists in your notebook.
agent = sac_agent  # or whatever your instance variable is called

if agent is None:
    print("Set `agent = <your SAC agent>` for this cell.")
else:
    # Common history field names
    critic_hist = _maybe_get(agent, ["critic_losses", "q_losses", "critic_loss_hist", "loss_q_hist"])
    actor_hist  = _maybe_get(agent, ["actor_losses", "actor_loss_hist", "loss_pi_hist"])
    alpha_hist  = _maybe_get(agent, ["alpha_hist", "temperature_hist", "log_alpha_hist"])
    ent_hist    = _maybe_get(agent, ["entropy_hist", "policy_entropy_hist"])

    found_any = any(x is not None for x in [critic_hist, actor_hist, alpha_hist, ent_hist])

    if not found_any:
        print(
            "I could not find stored loss/alpha/entropy histories on `agent`.\n\n"
            "To log learning, you need to append scalars during each SAC update, e.g.:\n"
            "  self.critic_losses.append(float(q_loss))\n"
            "  self.actor_losses.append(float(pi_loss))\n"
            "  self.alpha_hist.append(float(alpha))   # if autotuning\n"
            "  self.entropy_hist.append(float(entropy))\n\n"
            "Where to add:\n"
            "- right after critic backward/step\n"
            "- right after actor backward/step\n"
            "- right after alpha update (if any)\n"
        )
    else:
        # Plot whatever exists. A figure is opened only for a history that was
        # actually found, so a missing critic/actor history no longer leaves
        # an empty figure behind.
        for _hist, _title, _ylabel in (
            (critic_hist, "Critic loss history", "Loss"),
            (actor_hist, "Actor loss history", "Loss"),
            (alpha_hist, "Alpha / temperature history", "Alpha"),
            (ent_hist, "Policy entropy history", "Entropy"),
        ):
            if _hist is None:
                continue
            plt.figure()
            plt.plot(np.asarray(_hist, dtype=float))
            plt.title(_title)
            plt.xlabel("Update")
            plt.ylabel(_ylabel)
            plt.show()


In [None]:
import numpy as np
import pandas as pd

# Set your trained policy/agent object here
# NOTE(review): `SACAgent` reads like a class name, not a trained instance;
# the functions below call methods on `policy`, and the `policy is None` guard
# further down never fires with this value. Other cells use `agent` for the
# trained object — confirm which name is intended.
policy = SACAgent  # <-- set to your trained SAC policy/agent

def collect_actions(env, policy, regime="normal", max_steps=100, deterministic=True):
    """
    Run `policy` in `env` for one episode of at most `max_steps` steps in
    `regime`, stopping early on terminated/truncated.

    Returns (actions, infos): the actions taken as one float array and the
    list of step `info` dicts, both in step order.
    """
    obs, info = env.reset(options={"regime": regime})
    trajectory = []  # (action, info) per step
    remaining = max_steps
    while remaining > 0:
        remaining -= 1
        chosen = _get_policy_action(policy, obs, deterministic=deterministic)
        obs, reward, terminated, truncated, info = env.step(chosen)
        trajectory.append((chosen, info))
        if terminated or truncated:
            break
    actions = [step_action for step_action, _ in trajectory]
    step_infos = [step_info for _, step_info in trajectory]
    return np.asarray(actions, dtype=float), step_infos

def summarize_actions(actions):
    """
    Summarise a (steps, dims) action array per action dimension.

    For each dimension the returned DataFrame holds the mean, the std, and
    the fraction of steps whose value lies below -1/3 ("p_neg"), inside
    [-1/3, 1/3] ("p_zero") or above 1/3 ("p_pos").

    An empty input (e.g. from a zero-step rollout) yields an empty frame with
    the same columns, and a single 1-D action vector is treated as one step;
    both previously raised IndexError on `actions.shape[1]`.
    """
    columns = ["dim", "mean", "std", "p_neg", "p_zero", "p_pos"]
    actions = np.asarray(actions, dtype=float)
    if actions.size == 0:
        return pd.DataFrame(columns=columns)
    actions = np.atleast_2d(actions)

    # proportions relative to your quantisation thresholds
    thr = 1.0 / 3.0
    rows = []
    for dim in range(actions.shape[1]):
        x = actions[:, dim]
        rows.append({
            "dim": dim,
            "mean": float(np.mean(x)),
            "std": float(np.std(x)),
            "p_neg": float(np.mean(x < -thr)),
            "p_zero": float(np.mean((x >= -thr) & (x <= thr))),
            "p_pos": float(np.mean(x > thr)),
        })
    return pd.DataFrame(rows, columns=columns)

# Collect one deterministic episode per regime and show the action statistics.
if policy is not None:
    for r in ("normal", "preictal", "ictal"):
        acts, infos = collect_actions(
            env, policy, regime=r, max_steps=env.max_steps, deterministic=True
        )
        print(f"\nRegime: {r}, steps: {len(acts)}")
        display(summarize_actions(acts))
else:
    print("Set `policy = <your trained SAC policy/agent>` first.")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def run_fixed_policy(env, action_vec, regime="ictal", n_steps=100, label="policy"):
    """
    Reset `env` into `regime` and send the same `action_vec` at every step
    for up to `n_steps` steps, stopping early on terminated/truncated.

    Returns a dict with `label` and float arrays "seizure", "energy", "amp",
    "freq", "pw" read from info["seizure_index"], info["energy_norm"],
    info["amp"], info["freq"] and info["pw"].
    """
    # output key -> info key, in output order
    sources = (
        ("seizure", "seizure_index"),
        ("energy", "energy_norm"),
        ("amp", "amp"),
        ("freq", "freq"),
        ("pw", "pw"),
    )
    traces = {out_key: [] for out_key, _ in sources}

    obs, info = env.reset(options={"regime": regime})
    for _ in range(n_steps):
        obs, reward, terminated, truncated, info = env.step(np.array(action_vec, dtype=float))
        for out_key, info_key in sources:
            traces[out_key].append(float(info[info_key]))
        if terminated or truncated:
            break

    summary = {"label": label}
    for out_key, values in traces.items():
        summary[out_key] = np.asarray(values)
    return summary

regime = "ictal"
N = env.max_steps

# One full-length rollout per constant action vector, in this order.
runs = [
    run_fixed_policy(env, _vec, regime=regime, n_steps=N, label=_tag)
    for _vec, _tag in (
        ([0, 0, 0], "no_change"),
        ([+1, +1, +1], "always_increase"),
        ([-1, -1, -1], "always_decrease"),
    )
]

# One figure per logged quantity, overlaying all fixed-policy runs:
# (key in the run dict, title prefix, y-axis label).
for _key, _heading, _ylabel in (
    ("seizure", "Seizure index", "Seizure index"),
    ("energy", "Energy norm", "Energy norm"),
    ("amp", "Amp (mA)", "mA"),
    ("freq", "Freq (Hz)", "Hz"),
    ("pw", "PW (us)", "us"),
):
    plt.figure()
    for r in runs:
        plt.plot(r[_key], label=r["label"])
    plt.title(f"{_heading} under fixed policies ({regime})")
    plt.xlabel("Step")
    plt.ylabel(_ylabel)
    plt.legend()
    plt.show()
