In [2]:
# @title Inner Speech BCI - Tutorial Notebook (Interactive)

# --- Imports ---
import mne
import warnings
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score, classification_report
from ipywidgets import interact, fixed, SelectMultiple, Dropdown, IntSlider
import ipywidgets as widgets

In [None]:
# --- Setup ---
# Reproducibility: fix NumPy's legacy global random generator.
np.random.seed(23)

# Keep tutorial output readable: MNE reports only warnings and above, and all
# DeprecationWarning / FutureWarning messages are suppressed.
mne.set_log_level(verbose="warning")
for _ignored_category in (DeprecationWarning, FutureWarning):
    warnings.filterwarnings(action="ignore", category=_ignored_category)
del _ignored_category

# --- Data Processing Functions ---
from Python_Processing.Data_extractions import (
    extract_data_from_subject,
)
from Python_Processing.Data_processing import (
    select_time_window,
    transform_for_classificator,
)

In [4]:
# --- Feature Extraction Function ---
def extract_band_power_features(epoch_data, sfreq, bands=None):
    """Extract mean Welch-PSD band-power features from epoched EEG.

    A Welch PSD is computed for every (trial, channel) signal in a single
    vectorised ``mne.time_frequency.psd_array_welch`` call. Then, for each band,
    the PSD values whose frequencies lie in ``[fmin, fmax]`` (inclusive) are
    averaged.

    Parameters
    ----------
    epoch_data : array-like
        Trials x channels x samples; must be 3-D.
    sfreq : float
        Sampling frequency in Hz.
    bands : dict[str, sequence[float]] or None, optional
        Ordered mapping ``band_name -> (fmin, fmax)`` in Hz. ``None`` (the
        default) means ``{"alpha": [8, 12], "beta": [13, 30]}``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_trials, n_channels * n_bands)``. Columns are ordered
        channel-major, band-minor (ch0_band0, ch0_band1, ..., ch1_band0, ...).
        A band that contains no PSD frequency bin contributes ``0.0``.

    Raises
    ------
    ValueError
        If ``epoch_data`` is not 3-D or contains no time samples.
    """
    if bands is None:
        # Built per call instead of using a shared mutable default argument.
        bands = {"alpha": [8, 12], "beta": [13, 30]}

    data = np.asarray(epoch_data)
    if data.ndim != 3:
        raise ValueError(
            "epoch_data must be 3-D (trials, channels, samples); "
            f"got shape {data.shape}"
        )
    n_trials, n_channels, n_times = data.shape
    n_bands = len(bands)
    if n_trials == 0 or n_channels == 0 or n_bands == 0:
        # Nothing to compute; keep a consistent 2-D feature-matrix shape.
        return np.zeros((n_trials, n_channels * n_bands))
    if n_times == 0:
        raise ValueError("epoch_data has no time samples; widen the time window.")

    # MNE's default n_fft is 256 and it rejects n_fft > n_times, so windows
    # shorter than 256 samples (< 1 s at 256 Hz) would raise; shrink to fit.
    n_fft = min(256, n_times)
    psds, freqs = mne.time_frequency.psd_array_welch(
        data, sfreq=sfreq, n_fft=n_fft, verbose=False
    )  # psds: (n_trials, n_channels, n_freqs)

    features = np.zeros((n_trials, n_channels, n_bands))
    for band_idx, (fmin, fmax) in enumerate(bands.values()):
        in_band = (freqs >= fmin) & (freqs <= fmax)
        # Test the boolean mask, not np.where(...) indices: an index array
        # holding only 0 is falsy under .any() and would drop the band.
        if in_band.any():
            features[:, :, band_idx] = psds[..., in_band].mean(axis=-1)
    return features.reshape(n_trials, n_channels * n_bands)

In [5]:
# --- Main Interactive Function ---
def run_analysis(
    subject_number,
    selected_condition,
    class_choices,
    time_start,
    time_end,
    selected_bands,
    classifier_choice,
):
    """Load one subject, extract band-power features and report validation scores.

    Used as the callback of the ``interact`` cell; every argument is supplied by
    a widget. Results are printed and nothing is returned.

    Parameters
    ----------
    subject_number : int
        Subject ID passed to ``extract_data_from_subject``.
    selected_condition : str
        Single condition (e.g. ``"Inner"``) applied to every selected class.
    class_choices : sequence of str
        Class labels to discriminate; at least two are required.
    time_start, time_end : float
        Analysis window in seconds passed to ``select_time_window``;
        ``time_end`` must be greater than ``time_start``.
    selected_bands : dict[str, sequence[float]]
        Band definitions forwarded to ``extract_band_power_features``.
    classifier_choice : str
        Name of the classifier to train; only ``"LDA"`` is implemented.

    Raises
    ------
    ValueError
        If ``classifier_choice`` is not a supported classifier.
    """
    root_dir = "dataset"
    datatype = "EEG"
    fs = 256  # sampling frequency (Hz) used for windowing and the PSD

    # Validate the classifier before any expensive loading, instead of silently
    # training LDA regardless of the selected label.
    classifier_factories = {"LDA": LinearDiscriminantAnalysis}
    if classifier_choice not in classifier_factories:
        raise ValueError(
            f"Unsupported classifier {classifier_choice!r}; "
            f"choose one of {sorted(classifier_factories)}."
        )

    # Reject widget combinations that cannot form a valid experiment.
    class_choices = list(class_choices)
    if len(class_choices) < 2:
        print("Please select at least two classes to classify.")
        return
    if time_end <= time_start:
        print(
            f"Time End ({time_end} s) must be later than Time Start ({time_start} s)."
        )
        return

    Classes = [[c] for c in class_choices]
    # One condition group per class group, so both lists always have equal
    # length (the condition list used to be hard-coded to two groups, which
    # mismatched any selection other than exactly two classes).
    # NOTE(review): assumes transform_for_classificator pairs Conditions[i]
    # with Classes[i] — confirm against Python_Processing.Data_processing.
    Conditions = [[selected_condition] for _ in Classes]

    print(
        f"--- Analyzing Subject: {subject_number}, Conditions: {Conditions}, Classes: {Classes} ---"
    )

    # --- Data Loading and Preprocessing ---
    X, Y = extract_data_from_subject(root_dir, subject_number, datatype)
    X = select_time_window(X=X, t_start=time_start, t_end=time_end, fs=fs)
    X, Y = transform_for_classificator(X, Y, Classes, Conditions)

    print("Data shape:", X.shape)
    print("Labels shape:", Y.shape)

    # --- Feature Extraction ---
    features_X = extract_band_power_features(X, fs, bands=selected_bands)
    print("Feature matrix shape:", features_X.shape)

    # --- Data Splitting ---
    # 20 % is held out as a test set that this function deliberately does not
    # evaluate; the remainder is split again to obtain a validation set.
    X_train, X_test, y_train, y_test = train_test_split(
        features_X, Y, test_size=0.2, random_state=42, stratify=Y
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.15, random_state=42, stratify=y_train
    )

    # --- Classifier and Evaluation ---
    classifier = classifier_factories[classifier_choice]()
    classifier.fit(X_train, y_train)
    y_val_pred = classifier.predict(X_val)
    val_accuracy = accuracy_score(y_val, y_val_pred)

    print(f"\n--- Validation Results ({classifier_choice}) ---")
    print(f"Validation Accuracy: {val_accuracy:.4f}")
    print(
        "\nValidation Classification Report:\n",
        classification_report(y_val, y_val_pred, zero_division=0),
    )

In [6]:
# --- Interactive Widgets ---
# When these widgets drive run_analysis through interact, every value change
# re-runs the full pipeline (disk load + PSD features). Sliders therefore only
# report a value on release (continuous_update=False), not on every drag step.
# The "initial" description width keeps long labels such as "Subject Number"
# from being truncated.
label_style = {"description_width": "initial"}

subject_widget = IntSlider(
    min=1,
    max=10,
    step=1,
    value=1,
    description="Subject Number",
    continuous_update=False,
    style=label_style,
)
condition_widget = Dropdown(
    options=["Inner", "Pronounced", "Visualized"],
    value="Inner",
    description="Condition",
    style=label_style,
)
class_widget = SelectMultiple(
    options=["Up", "Down", "Left", "Right"],
    value=["Up", "Down"],
    description="Classes",
    style=label_style,
)
time_start_widget = widgets.FloatSlider(
    min=0,
    max=4,
    step=0.1,
    value=1.5,
    description="Time Start (s)",
    continuous_update=False,
    style=label_style,
)
time_end_widget = widgets.FloatSlider(
    min=0,
    max=4.5,
    step=0.1,
    value=3.5,
    description="Time End (s)",
    continuous_update=False,
    style=label_style,
)

band_options = {
    "Alpha (8-12Hz)": {"alpha": [8, 12]},
    "Beta (13-30Hz)": {"beta": [13, 30]},
    "Alpha + Beta": {"alpha": [8, 12], "beta": [13, 30]},
    "Theta (4-7Hz)": {"theta": [4, 7]},
    "Delta (1-3Hz)": {"delta": [1, 3]},
    "Gamma (30-45Hz)": {"gamma": [30, 45]},
    "All Bands": {
        "delta": [1, 3],
        "theta": [4, 7],
        "alpha": [8, 12],
        "beta": [13, 30],
        "gamma": [30, 45],
    },
}
bands_widget = Dropdown(
    options=band_options,
    value=band_options["Alpha + Beta"],
    description="Frequency Bands",
    style=label_style,
)
classifier_widget = Dropdown(
    options=["LDA"], value="LDA", description="Classifier", style=label_style
)

In [7]:
# --- Interactive Output ---
# Bind each run_analysis parameter to the widget that controls it; interact
# renders the controls and calls run_analysis with their current values.
interact(
    run_analysis,
    **{
        "subject_number": subject_widget,
        "selected_condition": condition_widget,
        "class_choices": class_widget,
        "time_start": time_start_widget,
        "time_end": time_end_widget,
        "selected_bands": bands_widget,
        "classifier_choice": classifier_widget,
    },
)

print(
    "\n--- Notebook Initialization Complete ---",
    "Use the interactive widgets above to explore the dataset and classifier performance.",
    sep="\n",
)

interactive(children=(IntSlider(value=1, description='Subject Number', max=10, min=1), Dropdown(description='C…


--- Notebook Initialization Complete ---
Use the interactive widgets above to explore the dataset and classifier performance.
