# In [None]:  (Jupyter cell prompt left over from the notebook export)
import time, math
import numpy as np
import matplotlib.pyplot as plt
import argparse  # <<< NEW IMPORT
from collections import deque
from scipy.signal import welch, butter, filtfilt
from pythonosc.udp_client import SimpleUDPClient

# Optional dependency: the Unicorn SDK binding. HAVE_UNICORN records whether
# the import succeeded; main() consults it before trying to open a device and
# uses one of the mock modes otherwise.
try:
    import UnicornPy

    HAVE_UNICORN = True
except Exception:
    # Any failure while importing is treated as "SDK not available".
    # NOTE(review): presumably broad on purpose so that a broken install
    # (not only a missing module) also falls back to mock - confirm.
    HAVE_UNICORN = False

# =========================
# FLAGS
# =========================
FORCE_MOCK = False  # force mock even if Unicorn is available
ORACLE_WHEN_MOCK = True  # when we're in mock mode, use the A/V oracle instead of mock EEG

# =========================
# CONFIG
# =========================
FS = 250  # sampling rate, samples per second
WIN_SEC = 2.0  # analysis window length (s)
HOP_SEC = 0.25  # nominal step between successive analysis windows (s)
N_WIN = int(WIN_SEC * FS)  # window length in samples (500)
N_HOP = int(HOP_SEC * FS)  # hop length in samples; 62.5 is truncated to 62

# OSC destination and the addresses used for the arousal/valence/artifact outputs.
OSC_IP, OSC_PORT = "127.0.0.1", 9000
ADDR_A = "/arousal"
ADDR_V = "/valence"
ADDR_ART = "/artifact"
# Names of the four music parameters; each one is sent to "/music/<name>".
OSC_PARAMS = ["drums", "pad", "tempo", "grain"]

# Frequency bands as (low, high) in Hz; both edges are inclusive in bandpower_1d.
THETA = (4, 8)
ALPHA = (8, 12)
BETA = (13, 30)
# Row order of the EEG buffers; features look channels up by name in this list.
CH_NAMES = ["Fz", "Cz", "F3", "F4", "FP1", "FP2", "PO7", "PO8"]

BASELINE_SEC = 60  # length of the baseline recording (s)
# Absolute-amplitude threshold on Fz above which a window is flagged as an artifact.
# NOTE(review): the name suggests microvolts - confirm the device's units.
FZ_SPIKE_UV = 1000.0

# Smoothing factor for the per-hop EMA of the raw features.
# EMA_ALPHA = 1 - exp(-EMA_RATE / EMA_TAU), with EMA_RATE = hops per second.
# NOTE(review): the previous comment called 3.0 "snappier" than 2.5, but a
# larger EMA_TAU yields a smaller alpha (more smoothing). Also, if EMA_TAU is
# meant as a time constant in seconds, the usual form would be
# 1 - exp(-HOP_SEC / EMA_TAU) - confirm which is intended.
EMA_TAU = 3.0
EMA_RATE = 1.0 / HOP_SEC
EMA_ALPHA = 1.0 - math.exp(-EMA_RATE / EMA_TAU)

# Target emotional state: named (arousal, valence) goal points in [0, 1] x [0, 1].
TARGETS = {
    "calm": {"A_star": 0.30, "V_star": 0.80},
    "focus": {"A_star": 0.45, "V_star": 0.55},
    "energize": {"A_star": 0.75, "V_star": 0.60},
}
ACTIVE_TARGET = "calm"
TARGET = TARGETS[ACTIVE_TARGET]  # read by distance(); main() rebinds it from --target

# Adaptive scaler for EEG mode
ADAPT_ALPHA = 0.01  # per-hop update rate of the running centre and spread
K_SCALE = 2.0  # default multiplier on the spread inside squash_tanh
MAD_MIN_LIMIT = 0.05  # lower bound for the spread estimates A_mad / V_mad
# SPSA Config
STUCK_P_EPS = 1e-5  # a parameter step with a smaller norm counts as "no movement"
STUCK_P_N = 5  # consecutive such steps before the optimizer applies a random kick


# =========================
# UTIL
# =========================
def ema(prev, x, alpha):
    """Advance an exponential moving average by one sample.

    With no history yet (``prev is None``) the new sample ``x`` is returned
    as-is; otherwise ``prev`` and ``x`` are blended with weight ``alpha`` on
    the new sample.
    """
    return x if prev is None else (1 - alpha) * prev + alpha * x


def distance(A, V):
    """Return the Euclidean distance from ``(A, V)`` to the active target.

    The target point is read from the module-level ``TARGET`` mapping
    (keys ``"A_star"`` and ``"V_star"``) at call time.
    """
    gap_a = A - TARGET["A_star"]
    gap_v = V - TARGET["V_star"]
    return math.sqrt(gap_a ** 2 + gap_v ** 2)


# =========================
# FEATURE EXTRACTION (EEG mode)
# Contains: bandpower_1d, HP, LP, prefilter_window, compute_features, squash_tanh
# =========================

def bandpower_1d(x, fs, fmin, fmax):
    """Return the mean Welch PSD of ``x`` over the band ``[fmin, fmax]``.

    The PSD is estimated with segments of at most 256 samples (shorter when
    ``x`` is shorter). Both band edges are inclusive. If no frequency bin
    falls inside the band the result is ``0.0``.
    """
    seg_len = min(len(x), 256)
    freqs, psd = welch(x, fs=fs, nperseg=seg_len)
    in_band = (freqs >= fmin) & (freqs <= fmax)
    if not np.any(in_band):
        return 0.0
    return float(psd[in_band].mean())


# Second-order Butterworth designs, computed once at import time and unpacked
# into filtfilt() by prefilter_window(). The cutoff is expressed as a fraction
# of the Nyquist frequency FS / 2: 1 Hz for the high-pass, 40 Hz for the low-pass.
HP = butter(2, 1.0 / (FS / 2), btype="highpass")
LP = butter(2, 40.0 / (FS / 2), btype="lowpass")


def prefilter_window(win):
    """Filter every row of ``win`` with HP then LP and return ``win``.

    Each row is passed through ``filtfilt`` (forward-backward filtering)
    with the high-pass design and then the low-pass design. The rows are
    overwritten in place, so callers that need the raw data must pass a copy.
    """
    for row_idx, row in enumerate(win):
        highpassed = filtfilt(*HP, row)
        win[row_idx] = filtfilt(*LP, highpassed)
    return win


def compute_features(window):
    """Return ``(arousal_raw, valence_raw, artifact_flag)`` for one window.

    ``window`` is indexed as ``window[channel]`` with channels in
    ``CH_NAMES`` order; it is copied before filtering, so the caller's
    array is left untouched.

    * artifact_flag: 1 when the filtered Fz trace exceeds ``FZ_SPIKE_UV`` in
      absolute value anywhere in the window, else 0.
    * arousal_raw: log(beta) - log(alpha + theta), each averaged over
      PO7, PO8 and Cz.
    * valence_raw: mean of the log alpha asymmetries F4 vs F3 and FP2 vs FP1.
    """
    filtered = prefilter_window(window.copy())

    # NOTE(review): the spike test runs on the filtered trace; a very short
    # spike may be attenuated by the low-pass before it is compared with the
    # threshold - confirm that this is intended.
    fz_trace = filtered[CH_NAMES.index("Fz")]
    artifact = int(np.max(np.abs(fz_trace)) > FZ_SPIKE_UV)

    n_channels = len(CH_NAMES)

    def per_channel_power(band):
        # One band-power value per channel, in CH_NAMES order.
        return np.array([bandpower_1d(filtered[ch], FS, *band) for ch in range(n_channels)])

    alpha = per_channel_power(ALPHA)
    theta = per_channel_power(THETA)
    beta = per_channel_power(BETA)

    eps = 1e-12  # keeps the logarithms finite when a band power is zero

    def log_ratio(numerator, denominator):
        return math.log(numerator + eps) - math.log(denominator + eps)

    # Arousal: beta relative to alpha + theta over the posterior/central sites.
    arousal_sites = [CH_NAMES.index(name) for name in ["PO7", "PO8", "Cz"]]
    a_bar = alpha[arousal_sites].mean()
    t_bar = theta[arousal_sites].mean()
    b_bar = beta[arousal_sites].mean()
    arousal_raw = log_ratio(b_bar, a_bar + t_bar)

    # Valence: right-minus-left log alpha, averaged over the two frontal pairs.
    frontal = log_ratio(alpha[CH_NAMES.index("F4")], alpha[CH_NAMES.index("F3")])
    prefrontal = log_ratio(alpha[CH_NAMES.index("FP2")], alpha[CH_NAMES.index("FP1")])
    valence_raw = 0.5 * (frontal + prefrontal)

    return arousal_raw, valence_raw, artifact


def squash_tanh(x, m, s, k=K_SCALE):
    """Map ``x`` into (0, 1) with a tanh centred on ``m``.

    The deviation ``x - m`` is divided by ``k * s`` (plus a tiny constant so
    the division is defined when ``s`` is 0) before the tanh is applied;
    ``x == m`` maps to 0.5.
    """
    scale = k * s + 1e-9
    return 0.5 + 0.5 * math.tanh((x - m) / scale)


# =========================
# EEG SOURCES (EEG mode)
# Contains: UnicornSource, MockEEGSource (mock_av_response is in the ORACLE section below)
# =========================

class UnicornSource:
    """EEG source backed by the first Unicorn device that UnicornPy reports."""

    def __init__(self):
        """Open the first available device and start acquisition.

        Raises:
            RuntimeError: if UnicornPy reports no devices.
        """
        # NOTE(review): the meaning of the boolean arguments to
        # GetAvailableDevices / StartAcquisition is defined by the SDK - verify there.
        available = UnicornPy.GetAvailableDevices(True)
        if not available:
            raise RuntimeError("No Unicorn devices found.")
        self.name = available[0]
        self.dev = UnicornPy.Unicorn(self.name)
        self.n_ch = self.dev.GetNumberOfAcquiredChannels()
        self.dev.StartAcquisition(False)
        print("Connected:", self.name, "with", self.n_ch, "channels.")

    def pull_chunk(self, n_samples, frame_len=8):
        """Read ``n_samples`` samples and return the first ``len(CH_NAMES)`` rows.

        Data is requested from the device in pieces of at most ``frame_len``
        samples. Each piece is decoded as float32 values laid out sample by
        sample and transposed into a channels-by-samples array.
        """
        samples = np.zeros((self.n_ch, n_samples), dtype=np.float32)
        filled = 0
        while filled < n_samples:
            count = min(frame_len, n_samples - filled)
            raw = bytearray(count * self.n_ch * 4)  # 4 bytes per float32 value
            self.dev.GetData(count, raw, len(raw))
            frame = np.frombuffer(raw, dtype=np.float32).reshape((count, self.n_ch))
            samples[:, filled:filled + count] = frame.T
            filled += count
        return samples[:len(CH_NAMES), :]

    def set_current_params(self, params_dict):
        """Accept the current music parameters; a real device ignores them."""
        pass

    def close(self):
        """Stop acquisition and drop the device handle, even if stopping fails."""
        try:
            self.dev.StopAcquisition()
        finally:
            del self.dev
            print("Device closed.")


class MockEEGSource:
    """Synthetic EEG-like signal; used only in mock EEG mode when not using the oracle.

    Every channel is a weighted sum of 6 Hz, 10 Hz and 20 Hz sinusoids plus a
    0.05 Hz drift and Gaussian noise. The alpha and beta weights are shifted
    by the music parameters passed to ``set_current_params``.
    """

    def __init__(self, fs=FS, n_ch=len(CH_NAMES)):
        self.fs = fs
        self.n_ch = n_ch
        self.t = 0  # running sample counter, keeps the sinusoids continuous across chunks
        self._bias_alpha = 0.0
        self._bias_beta = 0.0
        # Fixed seed: per-channel phases and base amplitudes are the same on every run.
        rng = np.random.default_rng(42)
        self.phase = rng.uniform(0, 2 * np.pi, size=(n_ch, 3))
        self.alpha_base = rng.uniform(0.6, 1.0, size=n_ch)
        self.beta_base = rng.uniform(0.4, 0.9, size=n_ch)
        self.theta_base = rng.uniform(0.5, 1.0, size=n_ch)
        print("Mock EEG source active.")

    def set_current_params(self, params_dict):
        """Derive the alpha/beta amplitude offsets from the music parameters.

        Missing parameters default to 0.5, which contributes no offset.
        """
        drums = params_dict.get("drums", 0.5)
        pad = params_dict.get("pad", 0.5)
        tempo = params_dict.get("tempo", 0.5)
        grain = params_dict.get("grain", 0.5)
        # stronger, plausible coupling (so you can observe effects)
        self._bias_beta = 0.35 * (drums - 0.5) + 0.35 * (tempo - 0.5) + 0.10 * (grain - 0.5)
        self._bias_alpha = 0.35 * (pad - 0.5) - 0.08 * (grain - 0.5)

    def pull_chunk(self, n_samples, frame_len=8):
        """Return the next ``n_samples`` samples as an ``(n_ch, n_samples)`` float32 array.

        ``frame_len`` is accepted for interface parity with ``UnicornSource``
        and is not used. With probability 0.02 a single-sample +1500 spike is
        added to the Fz row.
        """
        ts = (self.t + np.arange(n_samples)) / self.fs
        self.t += n_samples

        def tone(freq):
            # Unit-amplitude sinusoid as a (1, n_samples) row.
            return np.sin(2 * np.pi * freq * ts)[None, :]

        theta_wave = tone(6.0)
        alpha_wave = tone(10.0)
        beta_wave = tone(20.0)
        slow_drift = tone(0.05)

        out = np.zeros((self.n_ch, n_samples), dtype=np.float32)
        for ch in range(self.n_ch):
            theta_gain = self.theta_base[ch]
            alpha_gain = self.alpha_base[ch] + self._bias_alpha
            beta_gain = self.beta_base[ch] + self._bias_beta
            th = theta_gain * (theta_wave * np.cos(self.phase[ch, 0]))
            al = alpha_gain * (alpha_wave * np.cos(self.phase[ch, 1]))
            be = beta_gain * (beta_wave * np.cos(self.phase[ch, 2]))
            mix = 15.0 * (th + 1.2 * al + 0.8 * be) + 5.0 * slow_drift
            noise = np.random.normal(0, 3.0, size=(1, n_samples))
            out[ch] = (mix + noise).astype(np.float32)

        # Occasional artifact so the spike detector has something to catch.
        if np.random.rand() < 0.02:
            fz_row = CH_NAMES.index("Fz")
            spike_at = np.random.randint(0, n_samples)
            out[fz_row, spike_at] += 1500.0
        return out

    def close(self):
        """Nothing to release; just report shutdown."""
        print("Mock EEG closed.")


# =========================
# ORACLE: causal mapping params->[A,V] for verification
# =========================
def mock_av_response(params, noise_sigma=0.05):
    """Return a synthetic ``(arousal, valence)`` response to the music parameters.

    Both outputs are fixed linear combinations of the four parameters plus
    zero-mean Gaussian noise, clipped to [0, 1].

    Args:
        params: mapping with the keys "drums", "pad", "tempo" and "grain".
        noise_sigma: standard deviation of the noise added to *both* outputs.
            (Previously only arousal honoured this argument; valence always
            used 0.05. The default keeps the old behaviour.)

    Returns:
        Tuple of two Python floats, ``(arousal, valence)``.

    Raises:
        KeyError: if one of the four parameters is missing.
    """
    d = float(params["drums"])
    p = float(params["pad"])
    t = float(params["tempo"])
    g = float(params["grain"])
    # Draw order (arousal first, then valence) is kept so seeded runs are unchanged.
    arousal = 0.6 * d + 0.7 * t - 0.2 * p - 0.1 * g + np.random.normal(0, noise_sigma)
    valence = 0.5 * p + 0.2 * t - 0.3 * d + 0.1 * g + np.random.normal(0, noise_sigma)
    return float(np.clip(arousal, 0.0, 1.0)), float(np.clip(valence, 0.0, 1.0))


# =========================
# SPSA (continuous 4D in [0,1]) with exploration & “stuck” kick
# =========================
class SPSAOptimizer:
    """Two-sided SPSA search inside the box ``[low, high]^dim``.

    One iteration is driven from outside in this order:
    ``begin_iteration`` -> apply ``action_plus()`` -> ``observe_reward_plus``
    -> apply ``action_minus()`` -> ``observe_reward_minus`` -> ``update``.

    A reward is the improvement in distance relative to the value given to
    ``begin_iteration`` (previous minus new), clipped to [-0.4, 0.4]; the
    update therefore moves the centre ``p`` towards smaller distance. On top
    of plain SPSA the optimizer adds momentum, a decaying random jitter of the
    centre, and a random "kick" when it appears stuck.
    """

    def __init__(self, dim=4, low=0.0, high=1.0,
                 alpha0=0.15, c0=0.12, momentum=0.8,
                 sigma0=0.03, sigma_decay=0.995,
                 min_c=0.03, stuck_eps=1e-3, stuck_kick=0.08,
                 stuck_p_eps=STUCK_P_EPS, stuck_p_n=STUCK_P_N,
                 seed=123):
        # Search space
        self.dim = dim
        self.low, self.high = low, high
        # Step-size and perturbation schedules
        self.alpha0 = alpha0
        self.c0 = c0
        self.min_c = min_c
        self.momentum = momentum
        # Exploration jitter applied to the centre at the start of each iteration
        self.sigma = sigma0
        self.sigma_decay = sigma_decay
        # Stuck detection
        self.stuck_eps = stuck_eps
        self.stuck_kick = stuck_kick
        self.stuck_p_eps = stuck_p_eps
        self.stuck_p_n = stuck_p_n
        self.stuck_count = 0
        # State
        self.rng = np.random.default_rng(seed)
        self.p = np.full(dim, 0.5)  # current centre
        self.v = np.zeros(dim)  # momentum-averaged gradient estimate
        self.t = 0  # iteration counter, drives the schedules
        # Per-iteration scratch values, cleared again by update()
        self._Delta = None
        self._r_plus = None
        self._r_minus = None
        self._prev_distance = None

    def _proj(self, x):
        """Clip ``x`` back into the box."""
        return np.clip(x, self.low, self.high)

    def _alpha(self):
        """Step size, decaying with the iteration count."""
        return self.alpha0 / (1.0 + 0.02 * self.t)

    def _c(self):
        """Perturbation size, decaying with the iteration count but never below ``min_c``."""
        decayed = self.c0 / math.sqrt(1.0 + 0.01 * self.t)
        return max(self.min_c, decayed)

    def _gain(self, new_distance):
        """Clipped improvement of ``new_distance`` over the stored reference distance."""
        if self._prev_distance is None:
            return 0.0
        return float(np.clip(self._prev_distance - new_distance, -0.4, 0.4))

    def begin_iteration(self, current_distance):
        """Start an iteration: store the reference distance, draw a +/-1
        direction and jitter the centre."""
        self.t += 1
        self._prev_distance = current_distance
        self._Delta = self.rng.choice([-1.0, 1.0], size=self.dim)
        # exploration on the center
        jitter = self.rng.normal(0.0, self.sigma, size=self.dim)
        self.p = self._proj(self.p + jitter)
        self.sigma = max(0.005, self.sigma * self.sigma_decay)

    def action_plus(self):
        """Parameters for the '+' probe: centre plus the perturbation, clipped."""
        return self._proj(self.p + self._c() * self._Delta)

    def action_minus(self):
        """Parameters for the '-' probe: centre minus the perturbation, clipped."""
        return self._proj(self.p - self._c() * self._Delta)

    def observe_reward_plus(self, new_distance):
        """Record the reward of the '+' probe."""
        self._r_plus = self._gain(new_distance)

    def observe_reward_minus(self, new_distance):
        """Record the reward of the '-' probe."""
        self._r_minus = self._gain(new_distance)

    def update(self):
        """Apply one SPSA step from the two recorded rewards.

        Does nothing unless a direction and both rewards are available.
        Afterwards the per-iteration scratch values are cleared.
        """
        if self._Delta is None or self._r_plus is None or self._r_minus is None:
            return

        reward_gap = self._r_plus - self._r_minus
        g_hat = (reward_gap / (2.0 * self._c())) * self._Delta

        # Momentum-smoothed step on the reward.
        before = self.p.copy()
        self.v = self.momentum * self.v + (1.0 - self.momentum) * g_hat
        self.p = self._proj(self.p + self._alpha() * self.v)

        # Count consecutive iterations in which the centre barely moved.
        moved = np.linalg.norm(self.p - before)
        self.stuck_count = self.stuck_count + 1 if moved < self.stuck_p_eps else 0

        # Kick when the two probes were indistinguishable or the centre has
        # been static for stuck_p_n iterations in a row.
        probes_indistinct = abs(reward_gap) < self.stuck_eps
        if probes_indistinct or self.stuck_count >= self.stuck_p_n:
            self.p = self._proj(self.p + self.rng.normal(0.0, self.stuck_kick, size=self.dim))
            self.stuck_count = 0

        self._Delta = self._r_plus = self._r_minus = self._prev_distance = None

    def current_params_dict(self):
        """Return the centre as a dict of Python floats keyed by parameter name."""
        names = ("drums", "pad", "tempo", "grain")
        return {name: float(self.p[i]) for i, name in enumerate(names)}


# =========================
# LIVE PLOT (3s averages only)
# Contains: LivePlot
# =========================
from collections import deque


class LivePlot:
    """Interactive two-panel figure: rolling A/V averages on top, music
    parameters below, both against seconds since construction.

    Args:
        A_star: optional arousal target, drawn as a dashed horizontal line.
        V_star: optional valence target, drawn as a dashed horizontal line.
    """

    def __init__(self, A_star=None, V_star=None):
        self.times = []
        # Rolling window over the most recent int(15 / HOP_SEC) calls to update().
        # NOTE(review): this equals 15 s only if update() is called once per
        # hop; main() in this file calls it once per SPSA phase - confirm the
        # intended window length.
        self.A_vals = deque(maxlen=int(15 / HOP_SEC))  # rolling window of 15 s
        self.V_vals = deque(maxlen=int(15 / HOP_SEC))
        # History of the rolling means, one entry per update(), for the traces.
        self.A_means = []
        self.V_means = []
        self.param_means = {k: [] for k in ["drums", "pad", "tempo", "grain"]}
        self.t0 = time.time()

        plt.ion()
        self.fig, self.axs = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

        # A/V subplot
        self.lA, = self.axs[0].plot([], [], "C0-", label="Arousal (15s avg)")
        self.lV, = self.axs[0].plot([], [], "C1-", label="Valence (15s avg)")
        if A_star is not None:
            self.axs[0].axhline(A_star, linestyle="--", linewidth=1, color="C2", label="A* target")
        if V_star is not None:
            self.axs[0].axhline(V_star, linestyle="--", linewidth=1, color="C3", label="V* target")
        self.axs[0].set_ylabel("Arousal / Valence")
        self.axs[0].legend(loc="upper right")

        # Params subplot
        self.lines_params = {k: self.axs[1].plot([], [], "-o", label=k)[0]
                             for k in self.param_means}
        self.axs[1].legend(loc="upper right")
        self.axs[1].set_ylabel("Music Params")
        self.axs[1].set_xlabel("Time (s)")

        plt.tight_layout()
        plt.show(block=False)

    def update(self, A_new, V_new, params_mean):
        """Append one data point and redraw.

        Args:
            A_new: arousal value to fold into the rolling average.
            V_new: valence value to fold into the rolling average.
            params_mean: mapping of parameter name to value; names missing
                from it are plotted as NaN.
        """
        self.times.append(time.time() - self.t0)
        self.A_vals.append(A_new)
        self.V_vals.append(V_new)

        # Record the rolling mean at this time step. The earlier version
        # redrew the whole trace at the *current* mean, which produced a flat
        # line and discarded the history.
        self.A_means.append(float(np.mean(self.A_vals)))
        self.V_means.append(float(np.mean(self.V_vals)))

        self.lA.set_data(self.times, self.A_means)
        self.lV.set_data(self.times, self.V_means)

        for k, history in self.param_means.items():
            history.append(params_mean.get(k, np.nan))
            self.lines_params[k].set_data(self.times, history)

        for ax in self.axs:
            ax.relim()
            ax.autoscale_view()

        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()


# =========================
# ARG PARSING (NEW)
# =========================
def parse_args(argv=None):
    """Parse the command-line options that configure the BCMI.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``mock_force``, ``oracle``, ``target``,
        ``osc_ip`` and ``osc_port``.
    """
    parser = argparse.ArgumentParser(description="Unicorn DJ: BCMI for Music Control.")

    parser.add_argument("--mock-force", action="store_true", default=FORCE_MOCK,
                        help="Force mock mode, even if Unicorn is available.")

    # --oracle alone cannot switch the oracle off when its default is already
    # True, so a negative flag writing to the same destination is provided.
    parser.add_argument("--oracle", dest="oracle", action="store_true", default=ORACLE_WHEN_MOCK,
                        help="Use A/V oracle when in mock mode (Mock AV -> SPSA).")
    parser.add_argument("--no-oracle", dest="oracle", action="store_false", default=ORACLE_WHEN_MOCK,
                        help="Use the mock EEG pipeline instead of the A/V oracle when in mock mode.")

    parser.add_argument("--target", type=str, choices=list(TARGETS), default=ACTIVE_TARGET,
                        help=f"Target emotional state. Choices: {list(TARGETS.keys())}")

    parser.add_argument("--osc-ip", type=str, default=OSC_IP,
                        help="OSC target IP address.")

    parser.add_argument("--osc-port", type=int, default=OSC_PORT,
                        help="OSC target port.")

    return parser.parse_args(argv)


# =========================
# MAIN (Modified)
# =========================
def _push_params(client, mode, src, params, A_out, V_out):
    """Send ``params`` to the synth and let the mock back-ends react to them.

    Returns the (A, V) pair to use from now on: a fresh oracle sample (also
    sent over OSC) in MOCK_ORACLE mode, otherwise the values passed in.
    """
    for name, value in params.items():
        client.send_message(f"/music/{name}", float(value))
    if mode == "MOCK_EEG":
        src.set_current_params(params)
    elif mode == "MOCK_ORACLE":
        A_out, V_out = mock_av_response(params)
        client.send_message(ADDR_A, float(A_out))
        client.send_message(ADDR_V, float(V_out))
    return A_out, V_out


def _collect_baseline(src):
    """Record BASELINE_SEC seconds from ``src`` and derive the initial scaler state.

    Returns ``(buf, write_idx, (A_center, V_center, A_mad, V_mad), (ema_A, ema_V))``
    where ``buf`` is the ring buffer of the last N_WIN samples per channel and
    ``write_idx`` the total number of samples written into it.
    """
    print("Collecting baseline...")
    buf = np.zeros((len(CH_NAMES), N_WIN))
    write_idx = 0
    A_hist, V_hist = [], []
    baseline_samples = int(BASELINE_SEC * FS)
    samples_collected, hop_accum = 0, 0
    while samples_collected < baseline_samples:
        chunk = src.pull_chunk(min(N_HOP, baseline_samples - samples_collected))
        for i in range(chunk.shape[1]):
            buf[:, write_idx % N_WIN] = chunk[:, i]
            write_idx += 1
            samples_collected += 1
            hop_accum += 1
            if hop_accum == N_HOP:
                # Oldest-to-newest view of the ring buffer.
                # NOTE(review): until N_WIN samples have arrived the window
                # still contains the initial zeros - consider skipping these.
                idx = np.arange(write_idx - N_WIN, write_idx) % N_WIN
                a, v, _ = compute_features(buf[:, idx])
                A_hist.append(a)
                V_hist.append(v)
                hop_accum = 0

    A_arr = np.asarray(A_hist, dtype=float)
    V_arr = np.asarray(V_hist, dtype=float)
    A_center = float(np.mean(A_arr))
    V_center = float(np.mean(V_arr))
    # Median absolute deviation, floored like the running estimate in the main
    # loop. (`x or MAD_MIN_LIMIT` only replaced an exact zero.)
    A_mad = max(MAD_MIN_LIMIT, float(np.median(np.abs(A_arr - np.median(A_arr)))))
    V_mad = max(MAD_MIN_LIMIT, float(np.median(np.abs(V_arr - np.median(V_arr)))))
    ema_A, ema_V = float(A_hist[-1]), float(V_hist[-1])
    print("Baseline ready. Center A =", round(A_center, 4), "V =", round(V_center, 4))
    return buf, write_idx, (A_center, V_center, A_mad, V_mad), (ema_A, ema_V)


def main():
    """Run the closed loop until interrupted with Ctrl-C.

    Picks a mode (real EEG, mock EEG, or the A/V oracle), optionally records a
    baseline, then loops forever: estimate arousal/valence, send them over
    OSC, and let the SPSA optimizer alternate a '+' probe and a '-' probe of
    ``epoch_hops`` hops each before updating the music parameters.
    The EEG source, if any, is closed on every exit path.
    """
    args = parse_args()
    force_mock = args.mock_force
    oracle_when_mock = args.oracle

    # Command-line choices replace the module-level defaults; distance() reads TARGET.
    global TARGET, ACTIVE_TARGET, OSC_IP, OSC_PORT
    ACTIVE_TARGET = args.target
    TARGET = TARGETS[ACTIVE_TARGET]
    OSC_IP = args.osc_ip
    OSC_PORT = args.osc_port

    print(f"--- ACTIVE TARGET: {ACTIVE_TARGET} (A*={TARGET['A_star']:.2f}, V*={TARGET['V_star']:.2f}) ---")

    client = SimpleUDPClient(OSC_IP, OSC_PORT)

    # Decide mode: real device when allowed and available, otherwise a mock mode.
    mode = "MOCK_ORACLE" if oracle_when_mock else "MOCK_EEG"
    src = None
    if not force_mock and HAVE_UNICORN:
        try:
            src = UnicornSource()
            mode = "EEG_REAL"
        except Exception as e:
            print(f"Unicorn init failed ({e}). Falling back to mock.")
            src = None

    try:
        # Params start at center
        current_params = {name: 0.5 for name in OSC_PARAMS}
        for k, v in current_params.items():
            client.send_message(f"/music/{k}", float(v))

        if mode == "MOCK_EEG":
            src = MockEEGSource()
            src.set_current_params(current_params)

        A_out = 0.5
        V_out = 0.5
        art = 0

        # If oracle, initialize outputs immediately
        if mode == "MOCK_ORACLE":
            A_out, V_out = mock_av_response(current_params)
            client.send_message(ADDR_A, float(A_out))
            client.send_message(ADDR_V, float(V_out))
            client.send_message(ADDR_ART, 0)

        # EEG baseline/scaler only when in real/mock EEG modes
        if mode in ("EEG_REAL", "MOCK_EEG"):
            buf, write_idx, scaler, smoothed = _collect_baseline(src)
            A_center, V_center, A_mad, V_mad = scaler
            ema_A, ema_V = smoothed
        else:
            # placeholders; the oracle path never reads them
            buf = None
            write_idx = 0
            A_center = V_center = 0.0
            A_mad = V_mad = 1.0
            ema_A = ema_V = 0.0

        live = LivePlot(A_star=TARGET["A_star"], V_star=TARGET["V_star"])
        opt = SPSAOptimizer(dim=4, low=0.0, high=1.0)

        # Per-phase A/V samples used to average the distance of each probe.
        plus_A, plus_V = [], []
        minus_A, minus_V = [], []
        settle_hops = 4  # hops ignored after each parameter change
        phase_hop_index = 0  # hops since current phase started

        hop_counter = 0
        epoch_hops = 12  # hops per probe phase
        phase = 0  # 0=idle, 1=plus, 2=minus

        # NOTE(review): nothing in the MOCK_ORACLE path blocks or sleeps, so
        # that mode iterates as fast as the machine allows - confirm intended.
        while True:
            if mode != "MOCK_ORACLE":
                # REAL EEG or MOCK EEG pipeline
                hop = src.pull_chunk(N_HOP)
                for i in range(hop.shape[1]):
                    buf[:, write_idx % N_WIN] = hop[:, i]
                    write_idx += 1

                idx = np.arange(write_idx - N_WIN, write_idx) % N_WIN
                a_raw, v_raw, art = compute_features(buf[:, idx])
                ema_A = ema(ema_A, a_raw, EMA_ALPHA)
                ema_V = ema(ema_V, v_raw, EMA_ALPHA)

                A_scaled = squash_tanh(ema_A, A_center, A_mad)
                V_scaled = squash_tanh(ema_V, V_center, V_mad)

                # On an artifact the previous outputs are held and the scaler
                # is not adapted.
                if art == 0:
                    A_out, V_out = A_scaled, V_scaled

                    # Update centers
                    A_center = (1 - ADAPT_ALPHA) * A_center + ADAPT_ALPHA * ema_A
                    V_center = (1 - ADAPT_ALPHA) * V_center + ADAPT_ALPHA * ema_V

                    # Update spreads, never letting them drop below the floor
                    A_mad = max(MAD_MIN_LIMIT,
                                (1 - ADAPT_ALPHA) * A_mad + ADAPT_ALPHA * abs(ema_A - A_center))
                    V_mad = max(MAD_MIN_LIMIT,
                                (1 - ADAPT_ALPHA) * V_mad + ADAPT_ALPHA * abs(ema_V - V_center))

                client.send_message(ADDR_A, float(A_out))
                client.send_message(ADDR_V, float(V_out))
                client.send_message(ADDR_ART, int(art))

            # Collect this hop's A/V for the current phase once it has settled.
            if (art == 0 or mode == "MOCK_ORACLE") and phase_hop_index >= settle_hops:
                if phase == 1:
                    plus_A.append(A_out)
                    plus_V.append(V_out)
                elif phase == 2:
                    minus_A.append(A_out)
                    minus_V.append(V_out)

            hop_counter += 1

            # ===== SPSA schedule =====
            # Start a new iteration only from idle. Without the phase check
            # this branch also fired on the first hop of the '-' phase and
            # restarted the '+' probe, so the '-' probe was never measured.
            if phase == 0 and hop_counter % epoch_hops == 1:
                if mode == "MOCK_ORACLE":
                    # refresh A,V for the current (centre) parameters
                    A_out, V_out = mock_av_response(current_params)
                    client.send_message(ADDR_A, float(A_out))
                    client.send_message(ADDR_V, float(V_out))

                opt.begin_iteration(distance(A_out, V_out))
                current_params = dict(zip(OSC_PARAMS, opt.action_plus()))
                A_out, V_out = _push_params(client, mode, src, current_params, A_out, V_out)

                # start '+' phase buffers
                plus_A.clear()
                plus_V.clear()
                phase_hop_index = 0
                phase = 1

            if hop_counter % epoch_hops == 0 and phase == 1:
                # reward for '+' phase from the distance averaged over the phase;
                # with no usable samples, fall back to the reference distance
                # (i.e. a reward of zero)
                if plus_A and plus_V:
                    d_plus = float(np.mean([distance(a, v) for a, v in zip(plus_A, plus_V)]))
                else:
                    d_plus = opt._prev_distance if opt._prev_distance is not None else 0.0
                opt.observe_reward_plus(d_plus)

                # plot the phase average (no instantaneous points)
                A_mean_plus = float(np.mean(plus_A)) if plus_A else A_out
                V_mean_plus = float(np.mean(plus_V)) if plus_V else V_out
                live.update(A_mean_plus, V_mean_plus, current_params)

                # now set '-' params
                current_params = dict(zip(OSC_PARAMS, opt.action_minus()))
                A_out, V_out = _push_params(client, mode, src, current_params, A_out, V_out)

                # prepare '-' phase buffers
                minus_A.clear()
                minus_V.clear()
                phase_hop_index = 0
                phase = 2

            if hop_counter % (2 * epoch_hops) == 0 and phase == 2:
                # reward for '-' phase from averaged distance
                if minus_A and minus_V:
                    d_minus = float(np.mean([distance(a, v) for a, v in zip(minus_A, minus_V)]))
                else:
                    d_minus = opt._prev_distance if opt._prev_distance is not None else 0.0
                opt.observe_reward_minus(d_minus)

                A_mean_minus = float(np.mean(minus_A)) if minus_A else A_out
                V_mean_minus = float(np.mean(minus_V)) if minus_V else V_out
                live.update(A_mean_minus, V_mean_minus, current_params)

                # apply SPSA update and push new center params
                opt.update()
                current_params = opt.current_params_dict()
                A_out, V_out = _push_params(client, mode, src, current_params, A_out, V_out)
                phase = 0

            # status line once every int(1 / HOP_SEC) hops
            if hop_counter % int(1.0 / HOP_SEC) == 0:
                ps = ", ".join([f"{k}={v:.2f}" for k, v in current_params.items()])
                d_show = distance(A_out, V_out)
                print(f"[{mode}] A={A_out:.3f} V={V_out:.3f} dist={d_show:.3f} | {ps}", end="\r")

            # bump per-phase hop counter at end of loop body
            if phase in (1, 2):
                phase_hop_index += 1

    except KeyboardInterrupt:
        print("\nStopped by user.")
    finally:
        # Release the device (or mock source) however the loop ended.
        if src is not None:
            src.close()


if __name__ == "__main__":
    main()