BCI Motor Imagery Classification Demo
=====================================

This script demonstrates a complete pipeline for motor imagery classification:
1. Data loading and visualization
2. Preprocessing (re-referencing, band-pass filtering, epoching)
3. Feature extraction with Common Spatial Patterns (CSP)
4. Classification with multiple algorithms (LDA, SVM, k-NN, Random Forest)
5. Evaluation with cross-validation and visualization
6. Real-time simulation with a sliding-window classifier

Based on the PhysioNet EEG Motor Movement/Imagery dataset (downloaded through MNE) and extending the MNE example at:
https://mne.tools/stable/auto_examples/decoding/decoding_csp_eeg.html


# PhysioNet EEG Motor Movement/Imagery Dataset Guide

## Database Description

The PhysioNet EEG Motor Movement/Imagery Dataset contains EEG recordings from 109 subjects performing motor execution and motor imagery tasks. Each subject performed 14 experimental runs with different movement paradigms:

### Runs and Tasks
- **Runs 1-2**: Baseline recordings (eyes open, eyes closed)
- **Runs 3, 5, 7, 9, 11, 13**: Motor execution (actually performing movements)
- **Runs 4, 6, 8, 10, 12, 14**: Motor imagery (imagining movements)

### Motor Paradigms
The dataset includes different movement types, coded by event markers:
- **T0**: Rest (baseline)
- **T1**: 
  - Left fist movement/imagery (runs 3, 4, 7, 8, 11, 12)
  - Both fists movement/imagery (runs 5, 6, 9, 10, 13, 14)
- **T2**: 
  - Right fist movement/imagery (runs 3, 4, 7, 8, 11, 12)
  - Both feet movement/imagery (runs 5, 6, 9, 10, 13, 14)
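
Because the meaning of T1 and T2 depends on the run, it can help to resolve the labels programmatically. The helper below is a minimal sketch of the mapping listed above. The function name `event_labels_for_run` and the label strings are illustrative assumptions and are not part of MNE or the dataset.

```python
# Hypothetical helper: names and label strings are illustrative only.
FIST_RUNS = {3, 4, 7, 8, 11, 12}          # T1 = left fist, T2 = right fist
FISTS_FEET_RUNS = {5, 6, 9, 10, 13, 14}   # T1 = both fists, T2 = both feet


def event_labels_for_run(run):
    """Return the meaning of T0/T1/T2 for a task run (3-14)."""
    if run in FIST_RUNS:
        return {"T0": "rest", "T1": "left_fist", "T2": "right_fist"}
    if run in FISTS_FEET_RUNS:
        return {"T0": "rest", "T1": "both_fists", "T2": "both_feet"}
    raise ValueError(f"run {run} is not a task run (expected 3-14)")


print(event_labels_for_run(4))
print(event_labels_for_run(6))
```

A dictionary of this shape could be passed to `raw.annotations.rename(...)` after dropping the `T0` entry, or used only as a sanity check before renaming.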

### Recording Setup
- 64-channel EEG using the international 10-10 system
- Sampling rate: 160 Hz
- Recordings stored in EDF+ format
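
The sampling rate fixes how time windows translate into sample counts, which is useful when reading the loader and epoching output later in this notebook. The snippet below is a small arithmetic sketch. The helper name `n_samples` is an assumption made for illustration, and the figure of 20000 samples per run is taken from the loader output shown further down.

```python
sfreq = 160.0  # Hz, from the recording setup above


def n_samples(tmin, tmax, sfreq):
    """Number of samples in [tmin, tmax], counting both end points."""
    return int(round((tmax - tmin) * sfreq)) + 1


epoch_len = n_samples(-1.0, 4.0, sfreq)  # epoch window used later in this notebook
run_duration = 20000 / sfreq             # seconds covered by 20000 samples
nyquist = sfreq / 2                      # highest frequency that can be represented

print(epoch_len, run_duration, nyquist)
```

An epoch from -1 s to 4 s therefore holds 801 samples, and any analysis band has to stay below 80 Hz.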

## Modifying the Analysis

To experiment with different aspects of the dataset, you can modify the following parameters:

### 1. To change subjects:
```python
# Change the subject number (1-109)
subject = 88  # Subject 88 often shows clear motor imagery patterns
```

### 2. To switch between left/right hand imagery:
```python
# Use these runs for left vs. right hand motor imagery
runs = [4, 8, 12]  # Motor imagery: left hand vs right hand

# Event labels
raw.annotations.rename(dict(T1="left_hand", T2="right_hand"))
```

### 3. To switch to hands/feet imagery:
```python
# Use these runs for both hands vs both feet motor imagery
runs = [6, 10, 14]  # Motor imagery: both hands vs both feet

# Event labels
raw.annotations.rename(dict(T1="both_hands", T2="both_feet"))
```

### 4. To experiment with different frequency bands:
```python
# Change the filter bands
raw.filter(8.0, 12.0, fir_design="firwin")  # Mu rhythm only
# or
raw.filter(13.0, 30.0, fir_design="firwin")  # Beta rhythm only
```

### 5. To modify CSP components:
```python
# Adjust number of CSP components (more components = more features)
csp = CSP(n_components=6, reg=None, log=True, norm_trace=False)
```

## Subject Variability

There is significant variability in how clearly different subjects display motor imagery patterns. Some subjects (like 88, 8, 84, and 100) typically show more pronounced and classifiable patterns, while others may show minimal differences between conditions (the "BCI illiteracy" phenomenon).

Recommended experiment: Run the same analysis on 3-4 different subjects and compare:
1. The CSP patterns between subjects
2. The classification accuracy across subjects
3. The time course of classification performance

This will provide insight into the individual differences that impact BCI performance and the challenges of creating subject-independent BCI systems.
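
One way to organize that comparison is to wrap the per-subject analysis in a function and collect the fold accuracies in a small table. The sketch below assumes a user-supplied callable `evaluate(subject)` that returns a list of cross-validation accuracies. Both `compare_subjects` and `fake_evaluate` are hypothetical names, and the numbers produced by `fake_evaluate` are placeholders, not results.

```python
import statistics


def compare_subjects(subjects, evaluate):
    """Call evaluate(subject) for each subject and rank subjects by mean accuracy."""
    rows = []
    for subject in subjects:
        scores = evaluate(subject)
        rows.append({
            "subject": subject,
            "mean": statistics.mean(scores),
            "std": statistics.pstdev(scores),  # population std, like np.std
        })
    return sorted(rows, key=lambda row: row["mean"], reverse=True)


# Placeholder evaluator with made-up numbers, only to show the call shape.
def fake_evaluate(subject):
    return [0.5 + 0.01 * (subject % 5)] * 3


for row in compare_subjects([1, 2, 3], fake_evaluate):
    print(row)
```

In practice `evaluate` would load the subject's runs, build the epochs, and return the `cross_val_score` array for the chosen pipeline.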

In [295]:
%matplotlib qt


import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.pipeline import Pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit, cross_val_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier

import mne
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import tfr_morlet


# Set random seed for reproducibility
np.random.seed(42)

# Set visualization style
#plt.style.use('ggplot')
#plt.rcParams.update({'font.size': 14})


In [296]:


# Define helper functions
def plot_confusion_matrix(y_true, y_pred, classes, title):
    """Plot confusion matrix."""
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    fig, ax = plt.subplots(figsize=(8, 8))
    disp.plot(ax=ax, values_format='.2f', xticks_rotation=45)
    plt.title(title)
    plt.tight_layout()
    return fig

def plot_feature_histograms(features, labels, class_names, feature_names=None):
    """Plot histograms of features by class."""
    if feature_names is None:
        feature_names = [f'CSP Component {i+1}' for i in range(features.shape[1])]
    
    n_features = min(6, features.shape[1])  # Show at most 6 features
    fig, axes = plt.subplots(n_features, 1, figsize=(10, 2*n_features), sharex=True)
    
    for i in range(n_features):
        ax = axes[i] if n_features > 1 else axes
        for label_idx, label in enumerate(np.unique(labels)):
            label_data = features[labels == label, i]
            sns.histplot(label_data, label=class_names[label_idx], 
                         kde=True, ax=ax, alpha=0.5)
        ax.set_title(feature_names[i])
        ax.legend()
    
    plt.tight_layout()
    return fig

def sliding_window_analysis(epochs_data, epochs_data_train, labels, cv_splits,
                           window_length=0.5, window_step=0.1):
    """
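    Perform sliding window analysis to simulate real-time BCI performance.

    CSP and LDA are fit on the training window of each CV split and then
    scored on successive windows of the full-length epochs. The sampling
    rate and epoch start time are read from the module-level `epochs`
    object, so that object must exist before this function is called.
    """
    sfreq = epochs.info['sfreq']  # taken from the global `epochs`
    w_length = int(sfreq * window_length)  # window length in samples
    w_step = int(sfreq * window_step)  # window step in samples
    w_start = np.arange(0, epochs_data.shape[2] - w_length, w_step)

    scores_windows = []

    # One LDA and one CSP instance are reused; each fit() call refits them from scratch
    lda = LinearDiscriminantAnalysis()
    csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)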
    
    for train_idx, test_idx in cv_splits:
        y_train, y_test = labels[train_idx], labels[test_idx]
        
        # Train on the specific training window (1-2s)
        X_train = csp.fit_transform(epochs_data_train[train_idx], y_train)
        
        # Fit classifier
        lda.fit(X_train, y_train)
        
        # Test on sliding windows
        score_this_window = []
        
        for n in w_start:
            X_test = csp.transform(epochs_data[test_idx][:, :, n:(n + w_length)])
            window_score = lda.score(X_test, y_test)
            score_this_window.append(window_score)
            
        scores_windows.append(score_this_window)
    
    # Calculate window time points relative to epoch start
    w_times = (w_start + w_length / 2.0) / sfreq + epochs.tmin
    
    return w_times, np.array(scores_windows)

### Part 1 - Loading the data from the dataset
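
The loading cell below adds an average reference through `raw.set_eeg_reference(projection=True)`. As a rough illustration of what an average reference does to the data, the NumPy sketch below subtracts the mean over channels at every time point and shows that the result loses one rank. The array sizes are toy values chosen for the example, and the subtraction is applied directly rather than through an MNE projector.

```python
import numpy as np

rng = np.random.default_rng(0)
n_channels, n_times = 8, 1000  # toy sizes, not the real recording
data = rng.standard_normal((n_channels, n_times))

# Average reference: subtract the instantaneous mean over channels.
referenced = data - data.mean(axis=0, keepdims=True)

rank_before = np.linalg.matrix_rank(data)
rank_after = np.linalg.matrix_rank(referenced)
print(rank_before, rank_after)
```

After re-referencing, the channels sum to zero at every sample, so one channel is a linear combination of the others. This is consistent with the rank of 63 for 64 channels reported in the CSP output later on.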

In [297]:

print("BCI Motor Imagery Classification Demo")
print("=====================================")
print("\nPart 1: Data Loading and Preprocessing")
print("-----------------------------------")

# Define subjects and runs
subject = 1  # We'll use data from subject 1

# Motor imagery: left hand vs right hand (runs 4, 8, 12)
runs = [4, 8, 12]  

# Motor imagery: both hands vs both feet (runs 6, 10, 14)
#runs = [6, 10, 14]  

# Load the data
print(f"Loading data for subject {subject}, runs {runs}...")
raw_fnames = eegbci.load_data(subject, runs)
raw = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames])

# Standardize channel names
eegbci.standardize(raw)  # This ensures consistent channel naming

# Apply standard montage
print("Setting up electrode montage...")
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)

# Rename events for better clarity
raw.annotations.rename(dict(T1="left_hand", T2="right_hand"))  # as documented on PhysioNet

# Set EEG reference
raw.set_eeg_reference(projection=True)  # Add an average-reference projector (applied when epochs are built with proj=True)

# Print information about the dataset
info = raw.info
print(f"\nDataset information:")
print(f"- {len(info['ch_names'])} channels")
print(f"- Sampling frequency: {info['sfreq']} Hz")
print(f"- Recording length: {raw.times[-1]:.1f} seconds")


BCI Motor Imagery Classification Demo

Part 1: Data Loading and Preprocessing
-----------------------------------
Loading data for subject 1, runs [4, 8, 12]...
Extracting EDF parameters from /home/koutras/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/koutras/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Extracting EDF parameters from /home/koutras/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Setting up electrode montage...
EEG channel type selected for re-referencing
Adding avera

### Part 2 - Data Visualization and Preprocessing

In [298]:

print("\nPart 2: Data Visualization and Preprocessing")
print("------------------------------------------")

# Plot raw data (first 10 seconds)
fig_raw = raw.plot(duration=10, n_channels=20, title='Raw EEG Data (First 10s)')

# Plot power spectral density
fig_psd = raw.plot_psd(fmax=40, average=True)

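# Apply bandpass filter with an 8-30 Hz passband. The printed message says 7-30 Hz;
# 7 Hz is the -6 dB lower cutoff reported in the filter log, not the passband edge.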
print("Applying bandpass filter (7-30 Hz)...")
raw.filter(8.0, 30.0, fir_design="firwin", skip_by_annotation="edge")

# Select EEG channels
picks = pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, exclude="bads")

# Create epochs
print("Creating epochs...")
tmin, tmax = -1.0, 4.0  # Start 1s before event, end 4s after event
epochs = Epochs(
    raw,
    event_id=["left_hand", "right_hand"],
    tmin=tmin,
    tmax=tmax,
    proj=True,
    picks=picks,
    baseline=None,
    preload=True,
)

# Create a cropped version for training (1-2s after cue)
epochs_train = epochs.copy().crop(tmin=1.0, tmax=2.0)

# Plot average evoked responses for each class
fig_evoked = mne.viz.plot_compare_evokeds(
    dict(left_hand=epochs['left_hand'].average(), right_hand=epochs['right_hand'].average()), 
    picks=['C3', 'Cz', 'C4'],  # Central electrodes over motor cortex
    title='Average ERP at Motor Cortex Electrodes'
)



Part 2: Data Visualization and Preprocessing
------------------------------------------




NOTE: plot_psd() is a legacy function. New code should use .compute_psd().plot().
Effective window size : 12.800 (s)
Plotting power spectral density (dB=True).
Applying bandpass filter (7-30 Hz)...
Filtering raw data in 3 contiguous segments
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 265 samples (1.656 s)



[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Creating epochs...
Used Annotations descriptions: [np.str_('T0'), np.str_('left_hand'), np.str_('right_hand')]
Ignoring annotation durations and creating fixed-duration epochs around annotation onsets.
Not setting metadata
45 matching events found
No baseline correction applied
Created an SSP operator (subspace dimension = 1)
1 projection items activated
Using data from preloaded Raw for 45 events and 801 original time points ...
0 bad epochs dropped
combining channels using GFP (eeg channels)
combining channels using GFP (eeg channels)


### Part 3 - Time-Frequency Analysis
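
The plots in this part use `baseline=(-1, 0)` with `mode='percent'`, which expresses power as a relative change from the pre-cue interval. The NumPy sketch below illustrates that normalization on a toy power trace. The function name `percent_change` and the synthetic values are assumptions made for the example, and the formula reflects my understanding of the `'percent'` mode (subtract the baseline mean, then divide by it).

```python
import numpy as np


def percent_change(power, times, baseline=(-1.0, 0.0)):
    """Relative power change with respect to the mean over the baseline interval."""
    mask = (times >= baseline[0]) & (times <= baseline[1])
    base = power[..., mask].mean(axis=-1, keepdims=True)
    return (power - base) / base


times = np.linspace(-1.0, 4.0, 11)       # -1.0, -0.5, ..., 4.0 s
power = np.where(times <= 0, 2.0, 1.0)   # toy trace: power halves after the cue
erd = percent_change(power, times)
print(erd)
```

Negative values indicate event-related desynchronization (ERD, a power drop relative to baseline) and positive values indicate event-related synchronization (ERS).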

In [299]:

print("\nPart 3: Time-Frequency Analysis (ERD/ERS)")
print("----------------------------------------")

# Calculate time-frequency representation to visualize ERD/ERS
print("Computing time-frequency representations...")
freqs = np.arange(8, 30, 2)  # 8 to 28 Hz in 2 Hz steps (the stop value 30 is excluded)
n_cycles = freqs / 2  # Different number of cycles per frequency

# Compute TFR for each class
tfr_left = tfr_morlet(epochs['left_hand'], freqs, n_cycles, return_itc=False, average=True)
tfr_right = tfr_morlet(epochs['right_hand'], freqs, n_cycles, return_itc=False, average=True)

# Plot TFR for key channels
fig_tfr, axes = plt.subplots(2, 3, figsize=(15, 8))
plt.suptitle('Time-Frequency Analysis of Motor Imagery')

# Plot left hand imagery
tfr_left.plot([epochs.ch_names.index(ch) for ch in ['C3', 'Cz', 'C4']], 
              baseline=(-1, 0), mode='percent', axes=axes[0], 
              colorbar=False, show=False)

# Plot right hand imagery
tfr_right.plot([epochs.ch_names.index(ch) for ch in ['C3', 'Cz', 'C4']], 
             baseline=(-1, 0), mode='percent', axes=axes[1],
             colorbar=False, show=False)

# Add correct titles manually
titles = [['Left Hand - C3 (Left Hem)', 'Left Hand - Cz', 'Left Hand - C4 (Right Hem)'],
          ['Right Hand - C3 (Left Hem)', 'Right Hand - Cz', 'Right Hand - C4 (Right Hem)']]

for i in range(2):
    for j in range(3):
        axes[i, j].set_title(titles[i][j])

plt.tight_layout()


Part 3: Time-Frequency Analysis (ERD/ERS)
----------------------------------------
Computing time-frequency representations...
NOTE: tfr_morlet() is a legacy function. New code should use .compute_tfr(method="morlet").


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.5s


NOTE: tfr_morlet() is a legacy function. New code should use .compute_tfr(method="morlet").


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.4s


Applying baseline correction (mode: percent)
Applying baseline correction (mode: percent)


### Part 4 - Feature extraction with Common Spatial Patterns
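
The cell below relies on `mne.decoding.CSP`. To make the idea behind it concrete, here is a simplified NumPy sketch of two-class CSP: whiten the sum of the class covariances, then rotate so that components with high variance for one class have low variance for the other. This is not MNE's implementation. It assumes full-rank covariances, uses a simple eigenvalue-based component ordering, and the names `fit_csp` and `csp_features` are hypothetical.

```python
import numpy as np


def fit_csp(X, y, n_components=4):
    """X: (n_trials, n_channels, n_times); y: array of 0/1 labels."""
    # Class-wise spatial covariance, averaged over trials.
    covs = [
        np.mean([trial @ trial.T / trial.shape[1] for trial in X[y == c]], axis=0)
        for c in (0, 1)
    ]
    # Whiten the composite covariance (assumes it is full rank).
    evals, evecs = np.linalg.eigh(covs[0] + covs[1])
    whitener = evecs / np.sqrt(evals)
    # Eigenvalues of the whitened class-0 covariance lie in [0, 1].
    lam, rot = np.linalg.eigh(whitener.T @ covs[0] @ whitener)
    filters = (whitener @ rot).T                  # one spatial filter per row
    order = np.argsort(np.abs(lam - 0.5))[::-1]   # most discriminative first
    keep = order[:n_components]
    return filters[keep], lam[keep]


def csp_features(X, filters):
    """Log of the mean power of each spatially filtered trial."""
    projected = np.einsum("fc,nct->nft", filters, X)
    return np.log(np.mean(projected ** 2, axis=2))


# Synthetic demo: each class has extra variance on a different channel.
rng = np.random.default_rng(0)
X = rng.standard_normal((40, 4, 200))
y = np.repeat([0, 1], 20)
X[y == 0, 0] *= 3.0
X[y == 1, 1] *= 3.0

filters, lam = fit_csp(X, y, n_components=2)
feats = csp_features(X, filters)
print(feats[y == 0].mean(axis=0), feats[y == 1].mean(axis=0))
```

With real EEG after average referencing the covariance is rank deficient, which is why the MNE output below reports a rank reduction before estimating the class covariances. The sketch above does not handle that case.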

In [None]:

print("\nPart 4: Feature Extraction with CSP")
print("---------------------------------")

# Prepare data for classification
epochs_data = epochs.get_data()
epochs_data_train = epochs_train.get_data()
labels = epochs.events[:, -1] - 2  # Event codes 2 (left_hand) and 3 (right_hand) become labels 0 and 1

# Create class labels for visualization
class_names = ["Left Hand", "Right Hand"]

# Set up CSP
csp = CSP(n_components=4, reg=None, log=True, norm_trace=False)

# Fit CSP on all data for pattern visualization
csp.fit_transform(epochs_data, labels)

# Plot CSP patterns
print("Plotting CSP patterns...")
fig_patterns = csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)

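# Extract CSP features from the training window (1-2 s). The filters were fit on all
# trials above, so these features are for visualization only, not for estimating accuracy.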
X_csp = csp.transform(epochs_data_train)

# Plot feature distributions
print("Plotting feature distributions...")
fig_features = plot_feature_histograms(X_csp, labels, class_names)



Part 4: Feature Extraction with CSP
---------------------------------
Computing rank from data with rank=None


    Using tolerance 0.00016 (2.2e-16 eps * 64 dim * 1.1e+10  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Done.
Plotting CSP patterns...
Plotting feature distributions...


### Part 5 - Machine Learning Classification using different classifiers
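
Before reading the scores it helps to know how coarse they are. The cell below uses `ShuffleSplit(10, test_size=0.2)` on the 45 epochs found earlier. The arithmetic sketch below works out the resulting fold sizes, assuming the test fraction is rounded up to a whole number of trials, which is my understanding of scikit-learn's behaviour.

```python
import math

n_trials = 45     # "45 matching events found" in the epoching output
test_size = 0.2
n_splits = 10

n_test = math.ceil(test_size * n_trials)   # test fraction rounded up
n_train = n_trials - n_test
accuracy_step = 1 / n_test                 # smallest possible change in a fold's accuracy
total_test_predictions = n_splits * n_test

print(n_train, n_test, round(accuracy_step, 3), total_test_predictions)
```

With 9 test trials per fold, a single trial changes a fold's accuracy by about 0.11, so differences of a few percentage points between classifiers should be read with caution.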

In [301]:

print("\nPart 5: Machine Learning Classification")
print("------------------------------------")

# Set up cross-validation
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
cv_split = list(cv.split(epochs_data_train))

# Initialize classifiers
classifiers = {
    'LDA': LinearDiscriminantAnalysis(),
    'SVM': SVC(kernel='linear'),
    'k-NN': KNeighborsClassifier(n_neighbors=5),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42)
}

# Evaluate each classifier
print("Classifier performance (cross-validation):")
cv_results = {}

for name, clf in classifiers.items():
    # Create a pipeline with CSP and the classifier
    pipeline = Pipeline([("CSP", CSP(n_components=4, reg=None, log=True, norm_trace=False)), (name, clf)])
    
    # Evaluate with cross-validation
    scores = cross_val_score(pipeline, epochs_data_train, labels, cv=cv)
    cv_results[name] = scores
    
    print(f"- {name}: {scores.mean():.3f} ± {scores.std():.3f}")

# Plot CV results
cv_means = [np.mean(scores) for scores in cv_results.values()]
cv_stds = [np.std(scores) for scores in cv_results.values()]

fig_cv = plt.figure(figsize=(10, 6))
plt.bar(range(len(classifiers)), cv_means, yerr=cv_stds, tick_label=list(classifiers.keys()))
plt.title('Cross-Validation Accuracy by Classifier')
plt.ylabel('Accuracy')
plt.axhline(0.5, color='k', linestyle='--', label='Chance level')
plt.ylim(0.4, 1.0)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)





Part 5: Machine Learning Classification
------------------------------------
Classifier performance (cross-validation):
Computing rank from data with rank=None
    Using tolerance 7.2e-05 (2.2e-16 eps * 64 dim * 5e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 7.2e-05 (2.2e-16 eps * 64 dim * 5.1e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using toleran

### Part 6 - Real-Time Classification Simulation
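
The simulation scores the classifier on short windows that slide across each epoch. The sketch below reproduces only the window bookkeeping from `sliding_window_analysis`, using the values that apply in this notebook (160 Hz, epochs from -1 s with 801 samples, 0.5 s windows, 0.1 s step). The constants are copied into the snippet by hand so that it runs on its own.

```python
import numpy as np

sfreq, tmin = 160.0, -1.0   # sampling rate and epoch start time
n_times = 801               # samples per epoch
window_length, window_step = 0.5, 0.1

w_length = int(sfreq * window_length)   # window length in samples
w_step = int(sfreq * window_step)       # step between windows in samples
w_start = np.arange(0, n_times - w_length, w_step)
w_times = (w_start + w_length / 2.0) / sfreq + tmin   # window centres in seconds

print(len(w_start), w_times[0], w_times[-1])
```

This gives 46 windows whose centres run from -0.75 s to 3.75 s relative to the cue, so each point on the accuracy curve summarizes the 0.5 s of data around it.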

In [302]:


print("\nPart 6: Real-Time Classification Simulation")
print("----------------------------------------")

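# Pick the best classifier from cross-validation for reporting.
# Note: sliding_window_analysis() below always fits CSP + LDA, so best_clf is not the model being simulated.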
best_clf_name = max(cv_results.keys(), key=lambda k: np.mean(cv_results[k]))
#best_clf_name="LDA"
print(f"Using {best_clf_name} for real-time simulation")

best_clf = classifiers[best_clf_name]

# Perform sliding window analysis
print("Simulating real-time classification with sliding window...")
w_times, scores_windows = sliding_window_analysis(
    epochs_data, epochs_data_train, labels, cv_split, 
    window_length=0.5, window_step=0.1
)

# Plot results
fig_time = plt.figure(figsize=(12, 6))
plt.plot(w_times, np.mean(scores_windows, 0), label="Classification Score")
plt.axvline(0, linestyle="--", color="k", label="Cue Onset")
plt.axhline(0.5, linestyle="-", color="k", label="Chance Level")
plt.axvspan(1.0, 2.0, color='lightgray', alpha=0.3, label="Training Window")
plt.xlabel("Time (s)")
plt.ylabel("Classification Accuracy")
plt.title("Real-time Classification Performance Over Time")
plt.legend(loc="lower right")
plt.grid(True, linestyle='--', alpha=0.7)


Part 6: Real-Time Classification Simulation
----------------------------------------
Using k-NN for real-time simulation
Simulating real-time classification with sliding window...
Computing rank from data with rank=None
    Using tolerance 7.2e-05 (2.2e-16 eps * 64 dim * 5e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Done.
Computing rank from data with rank=None
    Using tolerance 7.2e-05 (2.2e-16 eps * 64 dim * 5.1e+09  max singular value)
    Estimated rank (data): 63
    data: rank 63 computed from 64 data channels with 0 projectors
    Setting small data eigenvalues to zero (without PCA)
Reducing data rank from 64 -> 63
Estimating class=0 covariance using EMPIRICAL
Done.
Estimating class=1 covariance using EMPIRICAL
Don

### Part 7 - Summary of Results
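
The chance level reported in the summary is the accuracy of always predicting the more frequent class. The sketch below repeats that calculation on an illustrative label vector. The 23/22 split is an assumption chosen because it reproduces the 0.511 printed below. The actual class counts are not shown in the output.

```python
import numpy as np

labels = np.array([0] * 23 + [1] * 22)  # illustrative 23/22 split of 45 trials

majority_fraction = np.mean(labels == labels[0])
chance = max(majority_fraction, 1.0 - majority_fraction)
print(round(chance, 3))
```

Because the classes are nearly balanced, this value is close to the 0.5 line drawn in the plots.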

In [303]:


print("\nPart 7: Summary of Results")
print("------------------------")

# Calculate the chance level
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1.0 - class_balance)

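# Peak accuracy and peak time come from the CSP + LDA sliding-window run;
# the training-window accuracy is the cross-validated mean of the best classifier.
print("Classification results:")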
print(f"- Best classifier: {best_clf_name}")
print(f"- Peak accuracy: {np.max(np.mean(scores_windows, 0)):.3f}")
print(f"- Chance level: {class_balance:.3f}")
print(f"- Training window accuracy: {np.mean(cv_results[best_clf_name]):.3f}")

# Compute peak time
peak_time_idx = np.argmax(np.mean(scores_windows, 0))
peak_time = w_times[peak_time_idx]
print(f"- Peak performance time: {peak_time:.2f}s")



Part 7: Summary of Results
------------------------
Classification results:
- Best classifier: k-NN
- Peak accuracy: 0.656
- Chance level: 0.511
- Training window accuracy: 0.689
- Peak performance time: 1.65s


Channels marked as bad:
none


