In [2]:
import time
import sys
import os
import requests
import numpy as np
import torch
from scipy.signal import welch
import mne

# --- IMPORTS (STRICT) ---
try:
    from brainaccess.utils import acquisition
    from brainaccess.core.eeg_manager import EEGManager
    from emotion_model import AdvancedEmotionModel
except ImportError as e:
    print(f"CRITICAL ERROR: Missing dependency. {e}")
    sys.exit(1)

# --- CONFIGURATION ---
DEVICE_NAME = "BA HALO"
API_URL = "http://localhost:8000/update-emotion"
WINDOW_SECONDS = 4
SAMPLING_RATE = 250
MODEL_PATH = "advanced_eeg_model.pth"

# --- HELPER: DATA PROCESSOR ---
def process_window_for_nn(data, sampling_rate):
    """
    Splits 4s window into 8 chunks -> Returns Tensor for Model
    Returns None if signal is flat/zero.
    """
    # Check for dead signal (all zeros or NaNs)
    if np.all(data == 0) or np.isnan(data).any():
        return None

    # Split into 8 segments (0.5s each)
    n_splits = 8
    try:
        splits = np.array_split(data, n_splits, axis=1)
    except:
        return None

    sequence = []

    for segment in splits:
        avg_data = np.mean(segment, axis=0)
        # Handle empty segments
        if len(avg_data) == 0: return None

        freqs, psd = welch(avg_data, fs=sampling_rate, nperseg=len(avg_data))

        def get_power(low, high):
            mask = (freqs >= low) & (freqs <= high)
            if not np.any(mask): return 0.0
            return np.mean(psd[mask])

        alpha = get_power(8, 13)
        beta  = get_power(13, 30)
        theta = get_power(4, 8)

        sequence.append([alpha, beta, theta])

    # Return Shape: (1, 8, 3)
    return torch.tensor([sequence], dtype=torch.float32)

# --- HELPER: MAPPING ---
def map_to_label(valence, energy):
    if valence >= 0.5 and energy >= 0.5: return "Happy"
    elif valence < 0.5 and energy >= 0.5: return "Stressed"
    elif valence < 0.5 and energy < 0.5: return "Sad"
    else: return "Calm"

def send_to_backend(mood):
    try:
        requests.post(API_URL, json={"label": mood}, timeout=0.2)
    except:
        pass # Backend might be down, keep running

# --- MAIN EXECUTION ---
def main():
    # 1. LOAD MODEL
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Initializing Model on {device}...")

    model = AdvancedEmotionModel(input_features=3).to(device)

    if not os.path.exists(MODEL_PATH):
        print(f"ERROR: Model weights '{MODEL_PATH}' not found.")
        print("Please run the training/weight creation script first.")
        sys.exit(1)

    try:
        model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
        model.eval()
    except Exception as e:
        print(f"Model Load Error: {e}")
        sys.exit(1)

    # 2. HARDWARE CONNECTION
    print(f"Connecting to {DEVICE_NAME}...")
    eeg = acquisition.EEG()

    with EEGManager() as mgr:
        try:
            # Map channels (Ensure this matches your physical cap setup)
            cap = {0: "Fp1", 1: "Fp2", 2: "O1", 3: "O2"}
            eeg.setup(mgr, device_name=DEVICE_NAME, cap=cap)
            eeg.start_acquisition()
            print(">>> EEG STREAM STARTED <<<")
        except Exception as e:
            print(f"CONNECTION FAILED: {e}")
            print("Check bluetooth/USB and try again.")
            sys.exit(1)

        # 3. REAL-TIME LOOP
        try:
            while True:
                time.sleep(1) # Wait for buffer to fill

                # Fetch MNE Raw Object
                raw = eeg.get_mne(tim=WINDOW_SECONDS)

                # --- SAFETY CHECKS ---
                if raw is None or len(raw) == 0:
                    print("No Data Received -> Sending Neutral")
                    send_to_backend("Neutral")
                    continue

                data = raw.get_data()

                # Check dimensions (Must have enough samples)
                if data.shape[1] < (SAMPLING_RATE * 0.5):
                    continue

                # --- PROCESSING ---
                input_tensor = process_window_for_nn(data, SAMPLING_RATE)

                if input_tensor is None:
                    # Signal is flat/bad
                    print("Bad Signal/Noise -> Sending Neutral")
                    send_to_backend("Neutral")
                    continue

                # --- INFERENCE ---
                input_tensor = input_tensor.to(device)

                with torch.no_grad():
                    v_score, e_score = model(input_tensor)
                    v = v_score.item()
                    e = e_score.item()

                mood = map_to_label(v, e)
                print(f"Signal OK | V:{v:.2f} E:{e:.2f} -> {mood}")

                send_to_backend(mood)

        except KeyboardInterrupt:
            print("\nStopping Stream...")
        finally:
            eeg.stop_acquisition()
            # mgr context manager handles disconnect automatically

if __name__ == "__main__":
    main()

Initializing Model on cpu...
Connecting to BA HALO...


BrainAccessException: Library already initialized