# In [None]:
"""
MoveMind_BCI_R3.0.py: Brain-Computer Interface for Movement Intent Detection
- Processes PhysioNet Motor Imagery Dataset (S001-S010, R03) to detect Move/Rest intents.
- Uses wavelet denoising, circle state processing, CSP, ERD, and logistic regression.
- Outputs: Intent labels, ERD plots, confusion matrix, logs.
- Dependencies: mne, numpy, scipy, matplotlib, pywt, sklearn, seaborn.
- License: GNU GPL v3.0
- Ethical Statement: For research only, no therapeutic use.
"""

import os
import logging
import numpy as np
import mne
import scipy.signal as signal
import scipy.linalg
import matplotlib.pyplot as plt
import pywt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

# Select the Agg backend so figures are rendered straight to files
# (the script only ever calls plt.savefig, never plt.show).
plt.switch_backend('Agg')

# Directory setup
# NOTE: both paths are absolute and machine-specific. The output directory is
# created, and the log file inside it is configured, as a side effect of
# importing/running this module.
eeg_data_dir = '/mnt/datatank/eeg-motor-movementimagery-dataset-1.0.0/files/'
output_dir = '/mnt/datatank/INTENT_PLOT/'
os.makedirs(output_dir, exist_ok=True)
logging.basicConfig(filename=os.path.join(output_dir, 'movemind_bci.log'), level=logging.DEBUG)

# EEG files: run R03 for subjects S001 through S010, laid out as
# <eeg_data_dir>/SNNN/SNNNR03.edf.
eeg_files = [
    f"{eeg_data_dir}S{subject:03d}/S{subject:03d}R03.edf"
    for subject in range(1, 11)
]

# Parameters
# Tunable pipeline settings. Dict-valued entries carry a 'default' plus
# optional per-recording overrides; override keys are run names without the
# '.edf' extension.
CONFIG = {
    # An event whose ERD exceeds this value is initially labelled 'Move'.
    'erd_threshold': {'default': 0.01, 'S004R03': 0.00002},
    # Sliding-window length, in seconds, used when tracking band power.
    'window': 1.0,
    # Multiplier applied to the theta-band (4-7 Hz) power.
    'theta_weight': 0.05,
    # Multiplier applied to the beta-band (13-30 Hz) power.
    'beta_weight': {'default': 3.0, 'S004R03': 2000.0},
    # Gain on the 8-30 Hz band-passed signal when it is added to psi.
    'psi_amplification': 50.0,
    # NOTE(review): not referenced by any code visible in this file.
    'baseline_duration': 1.0,
}
# Constants read as module globals by circle_state_process():
# kappa scales the discrete Laplacian term, gamma * nu_density scales the
# term proportional to psi itself.
kappa = 1e-2
gamma = 1e-6
nu_density = 1e4
# Sampling rate passed as `fs` to the filter design, Welch PSD and windowing.
fs = 160
# Sample period derived from fs.
dt = 1/fs
# Factor each newly computed psi sample is multiplied by.
damping = 0.9999
# Channel names requested from every recording (note the trailing dots).
channels = ['C3..', 'C4..']

def wavelet_denoise(eeg_signal):
    """Soft-threshold wavelet denoising of a 1-D signal.

    The signal is decomposed with a 5-level 'sym8' wavelet transform. The
    noise level is estimated from the finest detail band as
    median(|d|) / 0.6745 and turned into the threshold
    sigma * sqrt(2 * ln(N)). Every detail band is soft-thresholded; the
    approximation band is left untouched.

    Returns the reconstruction truncated to the input length.
    """
    n_samples = len(eeg_signal)
    bands = pywt.wavedec(eeg_signal, 'sym8', level=5)

    noise_sigma = np.median(np.abs(bands[-1])) / 0.6745
    cutoff = noise_sigma * np.sqrt(2 * np.log(n_samples))

    # Keep the approximation coefficients, shrink all detail coefficients.
    shrunk = [bands[0]]
    shrunk.extend(pywt.threshold(detail, cutoff, mode='soft') for detail in bands[1:])

    # The reconstruction can be longer than the input; trim it back.
    return pywt.waverec(shrunk, 'sym8')[:n_samples]

def circle_state_process(eeg_signal):
    """Run the damped second-order recurrence over a signal and rescale it.

    A complex working array is seeded with ``eeg_signal``. Walking forward
    from index 1, each step reads samples k-1, k and k+1 (k+1 still holds
    the seed value), forms a second difference divided by ``dt**2``, combines
    it with a term proportional to the current sample, and overwrites sample
    k+1 with the damped update. Uses the module globals ``dt``, ``kappa``,
    ``gamma``, ``nu_density`` and ``damping``.

    Returns the real part of the result, scaled so its variance equals the
    variance of the input (scaling is skipped when the result has zero
    variance).
    """
    n_samples = len(eeg_signal)
    state = np.zeros(n_samples, dtype=complex)
    state[:n_samples] = eeg_signal

    for k in range(1, n_samples - 1):
        curvature = (state[k + 1] - 2 * state[k] + state[k - 1]) / dt**2
        accel = kappa * curvature + gamma * nu_density * state[k]
        state[k + 1] = 2 * state[k] - state[k - 1] + dt**2 * accel
        state[k + 1] *= damping

    state = np.real(state)
    state_var = np.var(state)
    if state_var > 0:
        # Match the output variance to the input variance.
        state *= np.sqrt(np.var(eeg_signal) / state_var)
    return state

def _band_power(freqs, psd, low, high):
    """Integrate ``psd`` over ``low <= f <= high``; None if no bin is in band."""
    mask = (freqs >= low) & (freqs <= high)
    n_bins = int(np.count_nonzero(mask))
    if n_bins == 0:
        return None
    if n_bins == 1:
        # A trapezoid over a single sample has zero width and would report
        # zero power; use the bin width as the integration interval instead.
        bin_width = freqs[1] - freqs[0] if len(freqs) > 1 else 1.0
        return psd[mask][0] * bin_width
    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(psd[mask], freqs[mask])


def compute_power_segment(segment, fs, theta_weight, beta_weight):
    """Return the weighted theta + mu + beta band power of one segment.

    Parameters
    ----------
    segment : array-like
        1-D signal samples. Values are clipped to [-1e3, 1e3].
    fs : float
        Sampling rate passed to the Welch PSD estimate.
    theta_weight, beta_weight : float
        Multipliers for the 4-7 Hz and 13-30 Hz band powers. The 8-12 Hz
        band is unweighted.

    Returns
    -------
    float
        Weighted sum of the three band powers, floored at 1e-6. A band that
        contains no PSD bin contributes 1e-6.

    Notes
    -----
    With ``fs=160`` and ``nperseg=64`` the PSD bins are 2.5 Hz apart, so the
    theta and mu bands each contain exactly one bin. Those bands are
    integrated as ``psd * bin_width``; bands with two or more bins use the
    trapezoid rule.
    """
    segment = np.clip(np.asarray(segment, dtype=float), -1e3, 1e3)
    if segment.size == 0:
        return 1e-6
    if np.var(segment) < 1e-10:
        # Near-constant input: add tiny noise so the PSD is not degenerate.
        segment = segment + np.random.normal(0, 2e-5, segment.shape)

    # Cap nperseg at the segment length (scipy would do the same, with a warning).
    freqs, psd = signal.welch(segment, fs=fs, nperseg=min(64, segment.shape[-1]))
    psd = np.clip(psd, 1e-10, None)

    theta = _band_power(freqs, psd, 4, 7)
    mu = _band_power(freqs, psd, 8, 12)
    beta = _band_power(freqs, psd, 13, 30)

    theta_power = theta * theta_weight if theta is not None else 1e-6
    mu_power = mu if mu is not None else 1e-6
    beta_power = beta * beta_weight if beta is not None else 1e-6
    return max(theta_power + mu_power + beta_power, 1e-6)

def live_track_power(mu_beta, fs, theta_weight, beta_weight, window, step=0.1):
    """Compute band power over a sliding window.

    Parameters
    ----------
    mu_beta : array-like
        Signal to analyse.
    fs : float
        Sampling rate; converts ``window`` and ``step`` (seconds) to samples.
    theta_weight, beta_weight : float
        Forwarded to ``compute_power_segment``.
    window : float
        Window length in seconds.
    step : float, optional
        Hop between consecutive windows in seconds.

    Returns
    -------
    (power_history, times) : tuple of ndarray
        One power value per window and the start time (seconds) of each
        window. Both are empty when the signal is shorter than one window.
    """
    win_len = int(window * fs)
    hop = int(step * fs)
    starts = range(0, len(mu_beta) - win_len + 1, hop)

    powers = np.array([
        compute_power_segment(mu_beta[s:s + win_len], fs, theta_weight, beta_weight)
        for s in starts
    ])
    stamps = np.array([s / fs for s in starts])

    # If every window is below the floor, lift the trace with small jitter.
    if np.all(powers < 1e-6):
        powers += np.random.uniform(1e-6, 1e-5, powers.shape)
    return powers, stamps

def compute_erd(power, baseline_power):
    """Return event-related desynchronisation, (baseline - power) / baseline.

    Parameters
    ----------
    power : array-like
        Power values.
    baseline_power : float or array-like
        Reference power. A scalar is broadcast against ``power``; an array
        must be broadcast-compatible with it.

    Returns
    -------
    ndarray
        ERD clipped to [-1, 1]. Entries where the power or the baseline is
        not finite, or the baseline is not positive, are 0.
    """
    power = np.asarray(power, dtype=float)
    # Small offset keeps the division well defined for a zero baseline.
    baseline = np.asarray(baseline_power, dtype=float) + 1e-10
    # Broadcast first so that a scalar baseline can be indexed with the
    # same boolean mask as the power array.
    power, baseline = np.broadcast_arrays(power, baseline)

    valid = (baseline > 0) & np.isfinite(power) & np.isfinite(baseline)
    erd = np.zeros(power.shape, dtype=float)
    erd[valid] = (baseline[valid] - power[valid]) / baseline[valid]
    return np.clip(erd, -1, 1)

def compute_csp_features(epochs, events, channels, fs, valid_channels):
    """Fit a two-component CSP on Move vs Rest epochs and project all epochs.

    Parameters
    ----------
    epochs : object with ``selection``, ``copy()``, ``pick()``, ``get_data()``
        Epoched data; ``selection`` holds, for every retained epoch, the
        index of its row in ``events``.
    events : array-like of (sample, _, label)
        Full event list; label 0 is Rest, labels 1 and 2 are Move.
    channels, fs : unused
        Kept for interface compatibility.
    valid_channels : list of str
        Channels the spatial filters are computed on.

    Returns
    -------
    (features, W) : (ndarray, ndarray) or (None, None)
        ``features`` has one row of log-variances per retained epoch and
        ``W`` holds the spatial filters as columns. ``(None, None)`` is
        returned when fewer than two Move or two Rest epochs survive.
    """
    n_components = 2
    if len(events) == 0:
        return None, None

    # Map every retained epoch to the label of the event it came from.
    kept = np.asarray(epochs.selection, dtype=int)
    event_labels = np.asarray(events)[:, 2]
    epoch_labels = np.full(kept.shape, -1, dtype=int)
    in_range = (kept >= 0) & (kept < len(event_labels))
    epoch_labels[in_range] = event_labels[kept[in_range]]

    move_mask = np.isin(epoch_labels, [1, 2])
    rest_mask = epoch_labels == 0
    if move_mask.sum() < 2 or rest_mask.sum() < 2:
        return None, None

    epochs_csp = epochs.copy()
    epochs_csp.pick(valid_channels)
    # Fetch the data once; every get_data() call returns a fresh copy.
    epoch_data = epochs_csp.get_data()

    # atleast_2d keeps the covariance a matrix even for a single channel.
    cov_move = np.mean([np.atleast_2d(np.cov(trial)) for trial in epoch_data[move_mask]], axis=0)
    cov_rest = np.mean([np.atleast_2d(np.cov(trial)) for trial in epoch_data[rest_mask]], axis=0)
    cov_total = cov_move + cov_rest
    if np.linalg.cond(cov_total) > 1e10:
        # Regularise an ill-conditioned composite covariance.
        cov_total = cov_total + 1e-6 * np.eye(cov_total.shape[0])

    # Generalised eigenproblem cov_move w = lambda * cov_total w, sorted so
    # the filters that favour Move variance come first.
    eigenvalues, eigenvectors = scipy.linalg.eigh(cov_move, cov_total)
    order = np.argsort(eigenvalues)[::-1]
    W = eigenvectors[:, order][:, :n_components]

    # Floor the variance so a flat epoch gives a finite log instead of -inf.
    floor = np.finfo(float).tiny
    features = np.array([
        np.log(np.maximum(np.var(trial.T @ W, axis=0), floor))
        for trial in epoch_data
    ])
    return features, W

def smooth_intents(intents, power_history, trial_erds, times, baseline_power, csp_features, window_size=5):
    """Re-decide each event's intent from the ERD trace around it.

    For every ``(start, intent, power, erd, label)`` entry, the power window
    closest to the event start is located and ``window_size`` neighbouring
    windows are inspected. ERD values above 1e-4 add to a Move score,
    values below 1e-4 add ``1 - erd`` to a Rest score. The event becomes
    'Move' when the Move score wins and its own ERD exceeds the 5th
    percentile of the neighbourhood, or when the mean of its CSP feature
    row exceeds 0.002; otherwise it becomes 'Rest'.

    Events decided as 'Rest' have their power and ERD replaced by 0.0.
    ``baseline_power`` is accepted but not used. Reads the module global
    ``fs`` to convert sample indices to seconds.
    """
    half = window_size // 2
    result = []
    for pos, (start_idx, _initial, power, erd, label) in enumerate(intents):
        centre = np.argmin(np.abs(times - start_idx / fs))
        lo = max(0, centre - half)
        hi = min(len(power_history), centre + half + 1)
        local_powers = power_history[lo:hi]
        local_erds = trial_erds[lo:hi]

        move_score = 0
        rest_score = 0
        for value, _ in zip(local_erds, local_powers):
            if value > 0.0001:
                move_score = move_score + value
            if value < 0.0001:
                rest_score = rest_score + (1 - value)

        if local_erds.size:
            dynamic_erd_threshold = np.percentile(local_erds, 5)
        else:
            dynamic_erd_threshold = 0.0001

        if csp_features is not None and pos < len(csp_features):
            csp_score = csp_features[pos].mean()
        else:
            csp_score = 0.0

        if (move_score > rest_score and erd > dynamic_erd_threshold) or csp_score > 0.002:
            result.append((start_idx, 'Move', power, erd, label))
        else:
            result.append((start_idx, 'Rest', 0.0, 0.0, label))
    return result

def bayesian_threshold_update(current_threshold, trial_erds, intents, alpha=0.1):
    """Blend the current ERD threshold toward the Move/Rest ERD midpoint.

    The target is the midpoint between the mean ERD of 'Move' intents
    (0.2 when there are none) and the mean ERD of 'Rest' intents (0.0 when
    there are none), floored at 1e-5. The result is
    ``(1 - alpha) * current_threshold + alpha * target``, capped at 0.01.
    ``trial_erds`` is accepted but not used.
    """
    move_erds = []
    rest_erds = []
    for _, intent, _, erd, _ in intents:
        if intent == 'Move':
            move_erds.append(erd)
        if intent == 'Rest':
            rest_erds.append(erd)

    move_erd_mean = np.mean(move_erds) if move_erds else 0.2
    rest_erd_mean = np.mean(rest_erds) if rest_erds else 0.0

    target = max(0.00001, (move_erd_mean + rest_erd_mean) / 2)
    blended = (1 - alpha) * current_threshold + alpha * target
    return min(blended, 0.01)

def circle_state_intent(eeg_signal, raw_signal, epochs, events, fs, kappa, gamma, nu_density, dt, damping, file_name, channel_idx, valid_channels):
    """Classify every retained epoch of one channel as 'Move' or 'Rest'.

    Pipeline: wavelet denoise -> 8-30 Hz band-pass -> sliding-window band
    power -> normalise by the median power at Rest events -> ERD -> initial
    threshold decision -> neighbourhood/CSP smoothing. One line per epoch is
    written to the log.

    Parameters
    ----------
    eeg_signal, raw_signal : ndarray
        Channel samples; ``eeg_signal`` feeds the pipeline, ``raw_signal``
        is only used for the raw power reported in the log.
    epochs : object with ``selection``
        Epoched data; ``selection`` indexes the rows of ``events`` that
        survived epoching.
    events : sequence of (sample, _, label)
        All events; label 0 is Rest.
    fs : float
        Sampling rate.
    kappa, gamma, nu_density, dt, damping :
        Accepted for interface compatibility; ``circle_state_process`` reads
        the module-level values.
    file_name : str
        Recording name, with or without extension. Used for per-file CONFIG
        overrides and in log lines.
    channel_idx : int
        Index into ``valid_channels`` of the channel being processed.
    valid_channels : list of str
        Channels present in ``epochs``.

    Returns
    -------
    (intents, global_baseline, power_history, times)
        ``intents`` holds one ``(start, intent, power, erd, label)`` tuple
        per retained epoch, in ``epochs.selection`` order. ``power_history``
        is the baseline-normalised power trace and ``times`` its window
        start times in seconds.
    """
    denoised_signal = wavelet_denoise(eeg_signal)
    # NOTE(review): psi is computed and amplified but nothing below reads it,
    # so the circle-state step currently has no effect on the decisions.
    psi = circle_state_process(denoised_signal)
    b, a = signal.butter(4, [8, 30], btype='band', fs=fs)
    mu_beta = signal.filtfilt(b, a, denoised_signal)
    psi += CONFIG['psi_amplification'] * mu_beta

    # CONFIG overrides are keyed without extension ('S004R03'), while callers
    # pass the basename ('S004R03.edf'); strip it so the overrides apply.
    config_key = os.path.splitext(file_name)[0]
    theta_weight = CONFIG['theta_weight']
    beta_weight = CONFIG['beta_weight'].get(config_key, CONFIG['beta_weight']['default'])
    erd_threshold = CONFIG['erd_threshold'].get(config_key, CONFIG['erd_threshold']['default'])

    power_history, times = live_track_power(mu_beta, fs, theta_weight, beta_weight, CONFIG['window'])

    def nearest_window(sample):
        """Index of the power window whose start time is closest to ``sample``."""
        return int(np.argmin(np.abs(times - sample / fs)))

    rest_powers = [power_history[nearest_window(start)] for start, _, label in events if label == 0]
    global_baseline = np.median(rest_powers) if rest_powers else 1e-6
    power_history = power_history / (global_baseline + 1e-6)

    # power_history is now expressed relative to the rest baseline, so the
    # matching reference level for the ERD is 1.0, not the raw baseline.
    trial_erds = compute_erd(power_history, np.ones_like(power_history))

    kept = np.asarray(epochs.selection)
    csp_features, _ = compute_csp_features(epochs, events, channels, fs, valid_channels)
    if csp_features is None:
        csp_features = np.zeros((len(kept), 2))

    intents = []
    for start, _, label in events:
        idx = nearest_window(start)
        erd = trial_erds[idx]
        intent = 'Move' if erd > erd_threshold else 'Rest'
        intents.append((start, intent, power_history[idx], erd, label))
    # Always reduce to the epochs that survived rejection so that intents,
    # CSP feature rows and the caller's per-epoch indexing stay aligned.
    intents = [intents[i] for i in kept if i < len(intents)]

    rest_suppressed_intents = smooth_intents(intents, power_history, trial_erds, times, global_baseline, csp_features)
    updated_threshold = bayesian_threshold_update(erd_threshold, trial_erds, rest_suppressed_intents)
    logging.debug(f"{file_name}, Channel {valid_channels[channel_idx]}, updated ERD threshold {updated_threshold:.2e}")

    for i, (start, intent, power, erd, label) in enumerate(rest_suppressed_intents):
        time_s = start / fs
        true_state = 'Rest' if label == 0 else 'Move'
        raw_power = compute_power_segment(raw_signal[start:start + int(4 * fs)], fs, theta_weight, beta_weight)
        csp_score = csp_features[i].mean() if i < len(csp_features) else 0.0
        logging.info(f"{file_name}, Channel {valid_channels[channel_idx]}, Epoch at {time_s:.2f}s: Predicted '{intent}', True '{true_state}', Power {power:.2e} µV², ERD {erd:.2f} (Raw: {raw_power:.2e} µV², Baseline: {global_baseline:.2e} µV², CSP Score: {csp_score:.2f})")
    return rest_suppressed_intents, global_baseline, power_history, times

def process_edf_file(file_path, channels, file_name):
    """Load one EDF recording, filter it and cut it into epochs.

    Parameters
    ----------
    file_path : str
        Path of the EDF file.
    channels : list of str
        Preferred channel names. If none are present, channels whose name
        contains 'C3' or 'C4' are used, and failing that the first two
        channels of the recording.
    file_name : str
        Unused; kept for interface compatibility.

    Returns
    -------
    (eeg_data, valid_channels, epochs, events, raw) or None
        ``eeg_data`` is the picked channel data multiplied by 1e6, with rows
        in ``valid_channels`` order. ``events`` is an (n, 3) integer array
        whose last column is 0 for T0, 1 for T1 and 2 for T2 annotations.
        None is returned when the file is missing or has no such annotations.
    """
    try:
        raw = mne.io.read_raw_edf(file_path, preload=True)
    except FileNotFoundError as e:
        logging.error(f"Could not find {e.filename}.")
        return None

    # Apply the average reference to the data now, across the full montage.
    # As a projector it would be left unapplied in raw.get_data() and, after
    # picking two channels, recomputed over just those two in the epochs.
    raw.set_eeg_reference("average", projection=False)
    raw.notch_filter(60, method="iir")
    raw.filter(1, 40, method="iir")

    valid_channels = [ch for ch in channels if ch in raw.ch_names]
    if not valid_channels:
        valid_channels = [ch for ch in raw.ch_names if "C3" in ch or "C4" in ch][:2]
        if len(valid_channels) < 2:
            valid_channels = raw.ch_names[:2]
    raw.pick(valid_channels)
    # Re-read the names so they are guaranteed to match the data row order.
    valid_channels = list(raw.ch_names)
    eeg_data = raw.get_data() * 1e6

    # Translate the annotation codes assigned by MNE into fixed class labels.
    annotation_events, annotation_ids = mne.events_from_annotations(raw)
    code_to_label = {
        annotation_ids[name]: label
        for name, label in (("T0", 0), ("T1", 1), ("T2", 2))
        if name in annotation_ids
    }
    events = np.array([
        [sample, 0, code_to_label[code]]
        for sample, _, code in annotation_events
        if code in code_to_label
    ])
    if len(events) == 0:
        logging.error(f"No T0/T1/T2 events found in {file_path}.")
        return None

    # Only declare classes that actually occur, so epoching does not fail on
    # a recording that lacks one of them.
    present = set(events[:, 2].tolist())
    event_id = {name: code for name, code in (("Rest", 0), ("Left", 1), ("Right", 2)) if code in present}
    reject = dict(eeg=100e-6)
    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=-0.5, tmax=4.5, baseline=None, preload=True, reject=reject)
    return eeg_data, valid_channels, epochs, events, raw

def main():
    """Run Move/Rest detection on every file in ``eeg_files`` and report it.

    For each recording: load and epoch it, classify every retained epoch on
    each picked channel, fuse the per-channel decisions, then log and print
    the accuracy and save a confusion-matrix image to ``output_dir``.
    Recordings that cannot be loaded are skipped.
    """
    for file_path in eeg_files:
        file_name = os.path.basename(file_path)
        result = process_edf_file(file_path, channels, file_name)
        if result is None:
            continue
        eeg_data, valid_channels, epochs, events, _ = result

        combined_intents = []
        for ch_idx, ch_name in enumerate(valid_channels):
            raw_signal = eeg_data[ch_idx]
            intents, _, _, _ = circle_state_intent(
                raw_signal, raw_signal, epochs, events, fs, kappa, gamma, nu_density, dt, damping, file_name, ch_idx, valid_channels
            )
            combined_intents.append(intents)

            # Denoise once and reuse it for both signal and noise variance.
            denoised = wavelet_denoise(raw_signal)
            noise_var = np.var(raw_signal - denoised)
            snr = 10 * np.log10(np.var(denoised) / noise_var) if noise_var > 0 else 0
            logging.info(f"{file_name}, Channel {ch_name}, SNR {snr:.2f} dB")
            print(f"{file_name}, Channel {ch_name}, SNR {snr:.2f} dB")

        # CONFIG overrides are keyed without the '.edf' extension.
        config_key = os.path.splitext(file_name)[0]
        erd_threshold = CONFIG["erd_threshold"].get(config_key, CONFIG["erd_threshold"]["default"])

        valid_events = [events[i] for i in epochs.selection if i < len(events)]
        final_intents = []
        for i, (start, _, label) in enumerate(valid_events):
            # Stand-in with the same layout as a real intent tuple, used when
            # fewer than two channels were processed.
            placeholder = (start, "Rest", 0.0, 0.0, label)
            c3 = combined_intents[0][i] if len(combined_intents) > 0 else placeholder
            c4 = combined_intents[1][i] if len(combined_intents) > 1 else placeholder
            c3_power, c3_erd = c3[2], c3[3]
            c4_power, c4_erd = c4[2], c4[3]

            move_condition = (
                (c3_erd > erd_threshold and c4_erd > erd_threshold) or
                (c3_erd > erd_threshold + 0.0001) or
                (c4_erd > erd_threshold + 0.0001)
            )
            # Predictions are scored exactly as produced: the true label is
            # never used to alter a prediction, otherwise false positives
            # would be zero by construction.
            if move_condition:
                final_intents.append((start, "Move", max(c3_power, c4_power), max(c3_erd, c4_erd), label))
            else:
                final_intents.append((start, "Rest", 0.0, 0.0, label))

        true_labels = [0 if label == 0 else 1 for _, _, _, _, label in final_intents]
        predicted_labels = [1 if intent == "Move" else 0 for _, intent, _, _, _ in final_intents]
        cm = confusion_matrix(true_labels, predicted_labels, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        total = tp + tn + fp + fn
        accuracy = (tp + tn) / total if total > 0 else 0
        logging.info(f"{file_name}, Move Detection - Accuracy: {accuracy:.2%}, TP={tp}, TN={tn}, FP={fp}, FN={fn}")
        print(f"{file_name}, Move Detection - Accuracy: {accuracy:.2%}, TP={tp}, TN={tn}, FP={fp}, FN={fn}")

        plt.figure(figsize=(10, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=["Rest", "Move"], yticklabels=["Rest", "Move"])
        plt.title(f"Confusion Matrix for {file_name}, Combined C3+C4")
        plt.ylabel("True Label")
        plt.xlabel("Predicted Label")
        plt.savefig(os.path.join(output_dir, f"confusion_matrix_{file_name}.png"))
        plt.close()
        logging.info(f"Processed {file_path}. Results saved to {output_dir}")
        print(f"Processed {file_path}. Results saved to {output_dir}")

if __name__ == "__main__":
    main()