# EEG Motor Imagery Classification & Neural State Prediction

**Author:** Alexy Louis  
**Project:** Brain-Computer Interface (BCI) Signal Analysis

---

## Overview

This notebook demonstrates a complete pipeline for analyzing EEG signals from motor imagery tasks. We will:

1. **Load and explore** the PhysioNet EEG Motor Movement/Imagery Dataset
2. **Preprocess** the signals (band-pass filtering, re-referencing, epoching)
3. **Extract features** (time-domain, frequency-domain, and spatial features)
4. **Classify** motor imagery tasks using classical ML (LDA, SVM, Random Forest, CSP + LDA)
5. **Compare** classification accuracy across subjects
6. **Visualize** brain activity patterns

### Dataset

- **Source:** PhysioNet EEG Motor Movement/Imagery Dataset
- **Subjects:** 109 volunteers (this notebook analyzes subjects 1-5)
- **Channels:** 64 EEG electrodes (international 10-10 system)
- **Tasks:** Motor imagery (left/right hand), Real movement, Rest
- **Sampling Rate:** 160 Hz

In [None]:
# Core imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# EEG analysis
import mne
from mne.datasets import eegbci
from mne.io import read_raw_edf, concatenate_raws
from mne.decoding import CSP

# ML imports
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Set up paths
import sys
sys.path.append('..')
Path('../images').mkdir(parents=True, exist_ok=True)  # figures below are saved here

# Local modules
from src.preprocessing import (
    load_subject_data, set_montage, apply_filters,
    set_reference, create_epochs, preprocess_pipeline
)
from src.features import (
    extract_time_features, extract_band_power,
    extract_frequency_features, fit_csp,
    extract_features_from_epochs, scale_features, FREQ_BANDS
)

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
mne.set_log_level('WARNING')

print(f"MNE version: {mne.__version__}")
print(f"NumPy version: {np.__version__}")

---

## 1. Data Loading and Exploration

### 1.1 Load EEG Data

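We'll load data from the PhysioNet EEG Motor Movement/Imagery Dataset. Each subject has 14 runs; the left/right-hand runs are listed below, and this notebook loads only the motor imagery runs: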
- **Runs 4, 8, 12:** Motor imagery (left/right hand)
- **Runs 3, 7, 11:** Real movement (left/right hand)

In [None]:
# Configuration
SUBJECT = 1  # Subject ID (1-109)
RUNS_IMAGERY = [4, 8, 12]  # Motor imagery runs
RUNS_MOVEMENT = [3, 7, 11]  # Real movement runs

# Load motor imagery data
print(f"Loading data for Subject {SUBJECT}...")
raw = load_subject_data(SUBJECT, RUNS_IMAGERY)

print(f"\nData loaded successfully!")
print(f"  Channels: {len(raw.ch_names)}")
print(f"  Sampling rate: {raw.info['sfreq']} Hz")
print(f"  Duration: {raw.times[-1]:.1f} seconds")
print(f"  Channel names: {raw.ch_names[:10]}...")

In [None]:
# Display raw data info
print(raw.info)

### 1.2 Visualize Raw Signals

In [None]:
# Plot a segment of raw data
fig, axes = plt.subplots(8, 1, figsize=(14, 10), sharex=True)

# Select channels of interest (sensorimotor cortex)
channels = ['Fc3', 'Fc4', 'C3', 'Cz', 'C4', 'Cp3', 'Cp4', 'Pz']
channels = [ch for ch in channels if ch in raw.ch_names][:8]

start_time = 10  # Start at 10 seconds
duration = 5  # 5 seconds
sfreq = raw.info['sfreq']

start_idx = int(start_time * sfreq)
end_idx = int((start_time + duration) * sfreq)
times = np.arange(start_idx, end_idx) / sfreq

for i, (ax, ch) in enumerate(zip(axes, channels)):
    # Read only the plotted channel and time window instead of the full recording
    data = raw.get_data(picks=[ch], start=start_idx, stop=end_idx)[0] * 1e6  # Convert to microvolts
    ax.plot(times, data, 'b-', linewidth=0.5)
    ax.set_ylabel(ch, rotation=0, ha='right', fontsize=10)
    ax.set_ylim([-100, 100])
    if i < len(channels) - 1:
        ax.set_xticklabels([])

axes[-1].set_xlabel('Time (s)')
axes[0].set_title('Raw EEG Signals (Sensorimotor Cortex)', fontsize=12)
plt.tight_layout()
plt.savefig('../images/raw_signals.png', dpi=150, bbox_inches='tight')
plt.show()

### 1.3 Explore Events/Annotations

The dataset uses annotations to mark task events. In the left/right-hand runs used here, the codes mean:
- **T0:** Rest
- **T1:** Left hand motor imagery/movement
- **T2:** Right hand motor imagery/movement
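
To make the structure of the events array concrete, here is a minimal sketch that builds one by hand. The onsets and the `annotations_to_events` helper are made up for illustration; in the notebook, `mne.events_from_annotations` does this work. The sketch assumes codes are assigned to the sorted annotation labels starting at 1, which matches the `T1 = 2`, `T2 = 3` codes used later when epoching.

```python
import numpy as np

def annotations_to_events(onsets_s, descriptions, sfreq):
    """Hypothetical helper: convert annotation onsets (seconds) and labels
    into an MNE-style events array with columns [sample, 0, code]."""
    # Assumption: codes go to the sorted unique labels, starting at 1
    event_id = {name: code
                for code, name in enumerate(sorted(set(descriptions)), start=1)}
    samples = np.round(np.asarray(onsets_s) * sfreq).astype(int)
    codes = np.array([event_id[d] for d in descriptions])
    events = np.column_stack([samples, np.zeros_like(samples), codes])
    return events, event_id

# Made-up onsets: rest periods alternating with imagery trials
demo_onsets = [0.0, 4.2, 8.3, 12.5]
demo_labels = ['T0', 'T2', 'T0', 'T1']
demo_events, demo_event_id = annotations_to_events(demo_onsets, demo_labels, sfreq=160.0)

print(demo_event_id)
print(demo_events)
```

The third column is what the next cells compare against (`events[:, 2] == code`), and dividing the first column by the sampling rate recovers the onset in seconds.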

In [None]:
# Extract events from annotations
events, event_id = mne.events_from_annotations(raw)

print("Event mapping:")
for name, code in event_id.items():
    count = np.sum(events[:, 2] == code)
    print(f"  {name}: Code {code}, Count: {count}")

In [None]:
# Visualize event timeline
fig, ax = plt.subplots(figsize=(14, 3))

colors = {'T0': 'gray', 'T1': 'green', 'T2': 'red'}
for name, code in event_id.items():
    event_times = events[events[:, 2] == code, 0] / sfreq
    ax.scatter(event_times, [code] * len(event_times), 
               c=colors.get(name, 'blue'), label=name, s=50, alpha=0.7)

ax.set_xlabel('Time (s)')
ax.set_ylabel('Event Code')
ax.set_title('Event Timeline')
ax.legend()
plt.tight_layout()
plt.show()

---

## 2. Preprocessing Pipeline

EEG signals require careful preprocessing to remove noise and artifacts:

1. **Set electrode montage** for spatial information
2. **Bandpass filter** (1-40 Hz) to remove drift and high-frequency noise
3. **Re-reference** to average reference
4. **Create epochs** around events
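
Before running the MNE calls, it may help to see what steps 2 and 3 do to the numbers. The sketch below is a simplified NumPy illustration on synthetic data, not what MNE does internally: `fft_bandpass` is a hypothetical brick-wall filter (MNE designs a FIR filter instead), and the test frequencies are chosen to fall exactly on FFT bins so the result is clean.

```python
import numpy as np

def fft_bandpass(data, sfreq, l_freq, h_freq):
    """Brick-wall band-pass via the FFT (illustration only, not MNE's FIR filter)."""
    n_times = data.shape[-1]
    freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
    spectrum = np.fft.rfft(data, axis=-1)
    keep = (freqs >= l_freq) & (freqs <= h_freq)
    spectrum[..., ~keep] = 0.0  # zero every bin outside the pass band
    return np.fft.irfft(spectrum, n=n_times, axis=-1)

def average_reference(data):
    """Subtract the instantaneous mean across channels (axis 0) from each channel."""
    return data - data.mean(axis=0, keepdims=True)

# Synthetic 4-channel recording: 10 Hz rhythm + slow drift + 60 Hz line noise
demo_sfreq = 160.0
t = np.arange(0, 4.0, 1.0 / demo_sfreq)           # 640 samples, 0.25 Hz resolution
gains = np.array([1.0, 0.8, 1.2, 0.5])[:, None]   # per-channel amplitude
rhythm = gains * np.sin(2 * np.pi * 10 * t)
drift = 5.0 * np.sin(2 * np.pi * 0.25 * t)
line_noise = 2.0 * np.sin(2 * np.pi * 60 * t)
recording = rhythm + drift + line_noise

filtered = fft_bandpass(recording, demo_sfreq, l_freq=1.0, h_freq=40.0)
referenced = average_reference(filtered)

print("Max deviation from the clean rhythm:", np.abs(filtered - rhythm).max())
print("Max channel-mean after re-referencing:", np.abs(referenced.mean(axis=0)).max())
```

After filtering, only the 10 Hz rhythm survives; after re-referencing, the channels sum to zero at every time point, which is the defining property of the average reference.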

In [None]:
# Set up montage
montage = mne.channels.make_standard_montage('standard_1005')
raw.set_montage(montage)

# Plot sensor positions
fig = raw.plot_sensors(show_names=True, sphere='auto')
plt.title('EEG Electrode Positions (10-10 System)')
plt.savefig('../images/electrode_positions.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Apply bandpass filter
print("Applying bandpass filter (1-40 Hz)...")
raw_filtered = raw.copy()
raw_filtered.filter(l_freq=1.0, h_freq=40.0, fir_design='firwin')

# Compare PSD before and after filtering
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Before filtering
raw.compute_psd(fmax=80).plot(axes=axes[0], show=False)
axes[0].set_title('Before Filtering')

# After filtering
raw_filtered.compute_psd(fmax=80).plot(axes=axes[1], show=False)
axes[1].set_title('After Filtering (1-40 Hz)')

plt.tight_layout()
plt.savefig('../images/filtering_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Set average reference
print("Setting average reference...")
raw_filtered.set_eeg_reference('average', projection=False)

In [None]:
# Create epochs
print("Creating epochs...")

# Event IDs for motor imagery
event_id_motor = {'T1': 2, 'T2': 3}  # Left and Right hand

epochs = mne.Epochs(
    raw_filtered,
    events,
    event_id=event_id_motor,
    tmin=-0.5,  # 0.5s before event
    tmax=4.0,   # 4s after event
    baseline=(-0.5, 0),
    preload=True,
    picks='eeg'
)

print(f"\nCreated {len(epochs)} epochs")
print(f"  Left hand (T1): {len(epochs['T1'])} epochs")
print(f"  Right hand (T2): {len(epochs['T2'])} epochs")
print(f"  Time window: [{epochs.tmin}, {epochs.tmax}] seconds")
print(f"  Epoch shape: {epochs.get_data().shape}")

---

## 3. EEG Data Visualization

### 3.1 Power Spectral Density (PSD)

In [None]:
# Compare PSD between conditions
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

colors = {'T1': '#2ecc71', 'T2': '#e74c3c'}
labels = {'T1': 'Left Hand', 'T2': 'Right Hand'}

for event_name in ['T1', 'T2']:
    epochs_cond = epochs[event_name]
    spectrum = epochs_cond.compute_psd(method='welch', fmin=1, fmax=40)
    psds, freqs = spectrum.get_data(return_freqs=True)
    
    # Average over epochs and channels
    psds_mean = psds.mean(axis=(0, 1)) * 1e12
    psds_std = psds.std(axis=(0, 1)) * 1e12
    
    # Linear scale
    axes[0].plot(freqs, psds_mean, color=colors[event_name], 
                 label=labels[event_name], linewidth=2)
    axes[0].fill_between(freqs, psds_mean - psds_std, psds_mean + psds_std,
                         color=colors[event_name], alpha=0.2)
    
    # Log scale
    axes[1].semilogy(freqs, psds_mean, color=colors[event_name],
                     label=labels[event_name], linewidth=2)

# Add frequency band annotations
bands = {'Delta': (1, 4), 'Theta': (4, 8), 'Alpha': (8, 13), 'Beta': (13, 30)}
for ax in axes:
    for band_name, (fmin, fmax) in bands.items():
        ax.axvspan(fmin, fmax, alpha=0.1, color='gray')

axes[0].set_xlabel('Frequency (Hz)')
axes[0].set_ylabel('Power (uV^2/Hz)')
axes[0].set_title('PSD - Linear Scale')
axes[0].legend()

axes[1].set_xlabel('Frequency (Hz)')
axes[1].set_ylabel('Power (uV^2/Hz)')
axes[1].set_title('PSD - Log Scale')
axes[1].legend()

plt.tight_layout()
plt.savefig('../images/psd_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

### 3.2 Event-Related Potentials (ERPs)

In [None]:
# Plot ERPs for motor cortex channels
motor_channels = ['C3', 'Cz', 'C4']
motor_channels = [ch for ch in motor_channels if ch in epochs.ch_names]

fig, axes = plt.subplots(1, len(motor_channels), figsize=(14, 4))
times = epochs.times

for ax, ch_name in zip(axes, motor_channels):
    ch_idx = epochs.ch_names.index(ch_name)
    
    for event_name in ['T1', 'T2']:
        epochs_cond = epochs[event_name]
        data = epochs_cond.get_data()[:, ch_idx, :]
        
        mean = data.mean(axis=0) * 1e6
        sem = data.std(axis=0) / np.sqrt(len(data)) * 1e6
        
        ax.plot(times, mean, color=colors[event_name], 
                label=labels[event_name], linewidth=2)
        ax.fill_between(times, mean - sem, mean + sem,
                        color=colors[event_name], alpha=0.2)
    
    ax.axvline(0, color='k', linestyle='--', linewidth=1, alpha=0.5)
    ax.axhline(0, color='k', linestyle='-', linewidth=0.5, alpha=0.3)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (uV)')
    ax.set_title(f'Channel {ch_name}')
    ax.legend()

fig.suptitle('Event-Related Potentials (Motor Cortex)', fontsize=12, y=1.02)
plt.tight_layout()
plt.savefig('../images/erp_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

### 3.3 Topographic Maps

In [None]:
# Topographic maps at different time points
time_points = [0.5, 1.0, 2.0, 3.0]

fig, axes = plt.subplots(2, len(time_points), figsize=(14, 7))

for row, event_name in enumerate(['T1', 'T2']):
    epochs_cond = epochs[event_name]
    evoked = epochs_cond.average()
    
    for col, t in enumerate(time_points):
        evoked.plot_topomap(times=t, axes=axes[row, col], show=False,
                            colorbar=False, time_unit='s')
        if row == 0:
            axes[row, col].set_title(f't = {t}s')
    
    axes[row, 0].set_ylabel(labels[event_name], fontsize=12)

fig.suptitle('Topographic Maps Over Time', fontsize=12, y=1.02)
plt.tight_layout()
plt.savefig('../images/topomaps.png', dpi=150, bbox_inches='tight')
plt.show()

### 3.4 Time-Frequency Analysis

In [None]:
# Time-frequency analysis using Morlet wavelets
from mne.time_frequency import tfr_morlet

freqs = np.arange(4, 35, 1)
n_cycles = freqs / 2

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for row, event_name in enumerate(['T1', 'T2']):
    epochs_cond = epochs[event_name]
    
    for col, ch_name in enumerate(['C3', 'Cz', 'C4']):
        if ch_name not in epochs.ch_names:
            continue
            
        power = tfr_morlet(
            epochs_cond, freqs=freqs, n_cycles=n_cycles,
            return_itc=False, picks=ch_name, average=True
        )
        
        # Baseline normalize
        power.apply_baseline(baseline=(-0.5, 0), mode='percent')
        
        power.plot(
            picks=ch_name, axes=axes[row, col],
            show=False, colorbar=False,
            title=f'{labels[event_name]} - {ch_name}'
        )

fig.suptitle('Time-Frequency Representations (% change from baseline)', fontsize=12, y=1.02)
plt.tight_layout()
plt.savefig('../images/time_frequency.png', dpi=150, bbox_inches='tight')
plt.show()

---

## 4. Feature Extraction

We'll extract three types of features:
1. **Time-domain:** Statistical measures (mean, variance, etc.)
2. **Frequency-domain:** Band power (alpha, beta, etc.)
3. **Spatial:** Common Spatial Patterns (CSP)
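
As a rough idea of what the first two feature families compute, here is a NumPy-only sketch on a synthetic epoch. The helpers `band_power` and `time_features` are hypothetical stand-ins and are not the implementations in `src.features`; the band edges are the ones used for the PSD plot annotations above.

```python
import numpy as np

# Band edges (Hz) taken from the PSD plot annotations; FREQ_BANDS in src.features may differ
DEMO_BANDS = {'theta': (4, 8), 'alpha': (8, 13), 'beta': (13, 30)}

def band_power(epoch, sfreq, bands=DEMO_BANDS):
    """Mean periodogram power per band for one epoch shaped (n_channels, n_times)."""
    n_times = epoch.shape[-1]
    freqs = np.fft.rfftfreq(n_times, d=1.0 / sfreq)
    psd = np.abs(np.fft.rfft(epoch, axis=-1)) ** 2 / (sfreq * n_times)
    return {name: psd[:, (freqs >= lo) & (freqs < hi)].mean(axis=-1)
            for name, (lo, hi) in bands.items()}

def time_features(epoch):
    """Per-channel statistics: mean, variance, and peak-to-peak amplitude."""
    return {'mean': epoch.mean(axis=-1),
            'var': epoch.var(axis=-1),
            'ptp': np.ptp(epoch, axis=-1)}

# Synthetic 2-channel epoch: channel 0 oscillates at 10 Hz, channel 1 at 20 Hz
demo_sfreq = 160.0
t = np.arange(0, 2.0, 1.0 / demo_sfreq)
rng = np.random.default_rng(0)
demo_epoch = np.vstack([np.sin(2 * np.pi * 10 * t),
                        np.sin(2 * np.pi * 20 * t)])
demo_epoch = demo_epoch + 0.1 * rng.standard_normal(demo_epoch.shape)

bp = band_power(demo_epoch, demo_sfreq)
tf = time_features(demo_epoch)
print({name: np.round(power, 4) for name, power in bp.items()})
print({name: np.round(value, 3) for name, value in tf.items()})
```

Channel 0 shows most of its power in the alpha band and channel 1 in the beta band, which is the kind of contrast band-power features are meant to capture.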

In [None]:
# Get labels
labels_array = epochs.events[:, -1]
print(f"Labels: {np.unique(labels_array, return_counts=True)}")

In [None]:
# Extract features from all epochs
X, y, feature_names = extract_features_from_epochs(epochs)

print(f"Feature matrix shape: {X.shape}")
print(f"Labels shape: {y.shape}")
print(f"Number of features: {len(feature_names)}")
print(f"\nSample feature names: {feature_names[:10]}")

In [None]:
# Handle any NaN or Inf values
X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

# Split data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Scale features
X_train_scaled, X_test_scaled, scaler = scale_features(X_train, X_test)

print(f"Training set: {X_train_scaled.shape}")
print(f"Test set: {X_test_scaled.shape}")

### 4.1 Common Spatial Patterns (CSP)

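CSP is a spatial filtering technique for EEG. It learns linear combinations of channels whose output variance is maximal for one class and minimal for the other, so that band-power differences between the two classes become easy to separate.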
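
For intuition, two-class CSP can be sketched in a few lines of NumPy: whiten the sum of the two class covariances, then diagonalize one class covariance in the whitened space. This is a simplified illustration on synthetic data, with hypothetical helper names and no covariance regularization; MNE's `CSP` class is what the notebook actually uses.

```python
import numpy as np

def mean_covariance(trials):
    """Average spatial covariance over trials shaped (n_trials, n_channels, n_times)."""
    return np.mean([x @ x.T / x.shape[1] for x in trials], axis=0)

def csp_filters(trials_a, trials_b):
    """Return spatial filters (rows) and the class-A variance fraction of each filter."""
    cov_a, cov_b = mean_covariance(trials_a), mean_covariance(trials_b)
    # Whiten the composite covariance: whitener.T @ (cov_a + cov_b) @ whitener = I
    evals, evecs = np.linalg.eigh(cov_a + cov_b)
    whitener = evecs / np.sqrt(evals)
    # In whitened space the two classes share eigenvectors, with eigenvalues lam and 1 - lam
    lam, rotation = np.linalg.eigh(whitener.T @ cov_a @ whitener)
    order = np.argsort(lam)[::-1]
    return (whitener @ rotation).T[order], lam[order]

def log_variance_features(trials, filters):
    """Project trials through the filters and take the log-variance of each component."""
    return np.log((filters @ trials).var(axis=-1))

# Synthetic data: three hidden sources mixed into three channels.
# Class A has a strong first source, class B a strong third source.
rng = np.random.default_rng(42)
mixing = np.array([[1.0, 0.5, 0.2], [0.3, 1.0, 0.4], [0.2, 0.6, 1.0]])
scale_a = np.array([3.0, 1.0, 1.0])[:, None]
scale_b = np.array([1.0, 1.0, 3.0])[:, None]
trials_a = mixing @ (scale_a * rng.standard_normal((40, 3, 320)))
trials_b = mixing @ (scale_b * rng.standard_normal((40, 3, 320)))

demo_filters, lam = csp_filters(trials_a, trials_b)
feats_a = log_variance_features(trials_a, demo_filters)
feats_b = log_variance_features(trials_b, demo_filters)
print("Class-A variance fraction per filter:", np.round(lam, 2))
```

Filters with a fraction near 1 or near 0 are the discriminative ones; a fraction near 0.5 means the filter output has similar variance in both classes and carries little class information.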

In [None]:
# Fit CSP on all epochs for visualization only; Section 6 refits it inside each training fold
csp = CSP(n_components=6, reg='ledoit_wolf', log=True, norm_trace=False)
csp.fit(epochs.get_data(), y)

# Plot CSP patterns
fig, axes = plt.subplots(1, 6, figsize=(14, 3))

for idx, ax in enumerate(axes):
    mne.viz.plot_topomap(
        csp.patterns_[idx], epochs.info,
        axes=ax, show=False, cmap='RdBu_r'
    )
    ax.set_title(f'CSP {idx+1}')

fig.suptitle('Common Spatial Patterns', fontsize=12, y=1.05)
plt.tight_layout()
plt.savefig('../images/csp_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Extract CSP features
X_csp = csp.transform(epochs.get_data())

# Visualize CSP features
fig, ax = plt.subplots(figsize=(10, 6))

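# MNE orders CSP components by discriminability by default (component_order='mutual_info'),
# so the first two components are the most informative pair to plot
for label, color in zip([2, 3], ['green', 'red']):
    mask = y == label
    ax.scatter(X_csp[mask, 0], X_csp[mask, 1], c=color, alpha=0.7,
               label='Left Hand' if label == 2 else 'Right Hand')

ax.set_xlabel('CSP Component 1')
ax.set_ylabel('CSP Component 2')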
ax.set_title('CSP Feature Space')
ax.legend()
plt.tight_layout()
plt.savefig('../images/csp_features.png', dpi=150, bbox_inches='tight')
plt.show()

---

## 5. Classification with Classical ML

We'll compare several classical machine learning algorithms:
- Linear Discriminant Analysis (LDA)
- Support Vector Machine (SVM)
- Random Forest

In [None]:
# Define models
models = {
    'LDA': LinearDiscriminantAnalysis(),
    'SVM-RBF': SVC(kernel='rbf', C=1.0, probability=True),
    'SVM-Linear': SVC(kernel='linear', C=1.0, probability=True),
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42)
}

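# Cross-validation on the unscaled training features. The scaler is refit inside each
# fold so validation data never influences the scaling statistics
# (assumes scale_features applies standard scaling, as StandardScaler does).
from sklearn.pipeline import make_pipeline

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
results = {}

print("Cross-Validation Results (5-fold):\n")
for name, model in models.items():
    fold_model = make_pipeline(StandardScaler(), model)
    scores = cross_val_score(fold_model, X_train, y_train, cv=cv, scoring='accuracy')
    results[name] = {'mean': scores.mean(), 'std': scores.std(), 'scores': scores}
    print(f"{name:15} Accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")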

In [None]:
# Train best model and evaluate on test set
best_model_name = max(results, key=lambda x: results[x]['mean'])
print(f"\nBest model: {best_model_name}")

best_model = models[best_model_name]
best_model.fit(X_train_scaled, y_train)
y_pred = best_model.predict(X_test_scaled)

print(f"\nTest Accuracy: {accuracy_score(y_test, y_pred):.3f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=['Left Hand', 'Right Hand']))

In [None]:
# Plot confusion matrix
fig, ax = plt.subplots(figsize=(6, 5))

cm = confusion_matrix(y_test, y_pred)
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=['Left Hand', 'Right Hand'],
            yticklabels=['Left Hand', 'Right Hand'], ax=ax)

ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title(f'Confusion Matrix ({best_model_name})')

plt.tight_layout()
plt.savefig('../images/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Model comparison bar chart
fig, ax = plt.subplots(figsize=(10, 5))

model_names = list(results.keys())
means = [results[name]['mean'] for name in model_names]
stds = [results[name]['std'] for name in model_names]

# Use a separate name so the condition palette `colors` from Section 3 is not overwritten
bar_colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(model_names)))
bars = ax.bar(model_names, means, yerr=stds, color=bar_colors, edgecolor='black', capsize=5)

ax.axhline(0.5, color='red', linestyle='--', label='Chance level')
ax.set_ylabel('Accuracy')
ax.set_title('Model Comparison - Cross-Validation Accuracy')
ax.set_ylim([0, 1])
ax.legend()

# Add value labels
for bar, mean in zip(bars, means):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{mean:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.savefig('../images/model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

---

## 6. Classification with CSP + LDA Pipeline

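The CSP + LDA pipeline is a classic approach for motor imagery BCI: CSP reduces each epoch to a few log-variance features, and LDA separates the two classes with a linear boundary. Wrapping both steps in a scikit-learn `Pipeline` refits CSP inside every cross-validation fold, so held-out epochs never influence the spatial filters.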
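
To show what the LDA stage does with a handful of features per epoch, here is a minimal two-class LDA in NumPy. It is a sketch under simplifying assumptions (pooled covariance, equal priors, no shrinkage), fed with made-up 2-D features rather than real CSP output; `fit_lda` and `predict_lda` are hypothetical helpers, and the notebook itself uses scikit-learn's `LinearDiscriminantAnalysis`.

```python
import numpy as np

def fit_lda(X, y):
    """Two-class LDA with pooled covariance and equal priors (illustrative only)."""
    classes = np.unique(y)
    X0, X1 = X[y == classes[0]], X[y == classes[1]]
    mu0, mu1 = X0.mean(axis=0), X1.mean(axis=0)
    centered = np.vstack([X0 - mu0, X1 - mu1])
    pooled_cov = centered.T @ centered / (len(X) - 2)
    w = np.linalg.solve(pooled_cov, mu1 - mu0)  # normal vector of the decision boundary
    b = -w @ (mu0 + mu1) / 2                    # boundary passes midway between class means
    return w, b, classes

def predict_lda(X, w, b, classes):
    """Assign the second class when the linear score is positive."""
    return np.where(X @ w + b > 0, classes[1], classes[0])

# Made-up 2-D features; labels 2 (left) and 3 (right) mirror the notebook's event codes
rng = np.random.default_rng(0)
left = rng.normal(loc=[-1.0, 1.0], scale=0.5, size=(50, 2))
right = rng.normal(loc=[1.0, -1.0], scale=0.5, size=(50, 2))
X_tr = np.vstack([left[:40], right[:40]])
y_tr = np.array([2] * 40 + [3] * 40)
X_te = np.vstack([left[40:], right[40:]])
y_te = np.array([2] * 10 + [3] * 10)

w, b, classes = fit_lda(X_tr, y_tr)
demo_accuracy = (predict_lda(X_te, w, b, classes) == y_te).mean()
print(f"Held-out accuracy on synthetic features: {demo_accuracy:.2f}")
```

Because the decision rule is a single linear score, LDA stays stable with the small number of epochs available per subject, which is one reason it pairs well with CSP.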

In [None]:
from sklearn.pipeline import Pipeline

# Create CSP + LDA pipeline
csp_lda_pipeline = Pipeline([
    ('csp', CSP(n_components=6, reg='ledoit_wolf', log=True)),
    ('lda', LinearDiscriminantAnalysis())
])

# Get epoch data
X_epochs = epochs.get_data()

# Split
X_train_ep, X_test_ep, y_train_ep, y_test_ep = train_test_split(
    X_epochs, y, test_size=0.2, random_state=42, stratify=y
)

# Cross-validation
scores_csp_lda = cross_val_score(csp_lda_pipeline, X_train_ep, y_train_ep, cv=cv)
print(f"CSP + LDA Cross-Validation Accuracy: {scores_csp_lda.mean():.3f} (+/- {scores_csp_lda.std():.3f})")

# Train and evaluate
csp_lda_pipeline.fit(X_train_ep, y_train_ep)
y_pred_csp = csp_lda_pipeline.predict(X_test_ep)
print(f"CSP + LDA Test Accuracy: {accuracy_score(y_test_ep, y_pred_csp):.3f}")

---

## 7. Multi-Subject Analysis

Let's analyze multiple subjects to understand variability across individuals.

In [None]:
# Analyze multiple subjects
subjects = [1, 2, 3, 4, 5]
subject_results = []

print("Processing multiple subjects...\n")

for subj in subjects:
    try:
        # Load and preprocess
        raw_subj = load_subject_data(subj, RUNS_IMAGERY)
        raw_subj.set_montage(mne.channels.make_standard_montage('standard_1005'))
        raw_subj.filter(l_freq=1.0, h_freq=40.0, fir_design='firwin')
        raw_subj.set_eeg_reference('average', projection=False)
        
        # Create epochs
        events_subj, _ = mne.events_from_annotations(raw_subj)
        epochs_subj = mne.Epochs(
            raw_subj, events_subj, event_id=event_id_motor,
            tmin=-0.5, tmax=4.0, baseline=(-0.5, 0),
            preload=True, picks='eeg'
        )
        
        # Extract data
        X_subj = epochs_subj.get_data()
        y_subj = epochs_subj.events[:, -1]
        
        # CSP + LDA pipeline
        scores = cross_val_score(csp_lda_pipeline, X_subj, y_subj, cv=5)
        
        subject_results.append({
            'subject': subj,
            'n_epochs': len(epochs_subj),
            'accuracy_mean': scores.mean(),
            'accuracy_std': scores.std()
        })
        
        print(f"Subject {subj}: Accuracy = {scores.mean():.3f} (+/- {scores.std():.3f})")
        
    except Exception as e:
        print(f"Subject {subj}: Error - {e}")

# Convert to DataFrame
df_results = pd.DataFrame(subject_results)
print(f"\nMean accuracy across subjects: {df_results['accuracy_mean'].mean():.3f}")

In [None]:
# Plot subject variability
fig, ax = plt.subplots(figsize=(10, 5))

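# Use a separate name so the condition palette `colors` from Section 3 is not overwritten
subject_colors = plt.cm.Set2(np.linspace(0, 1, len(df_results)))
bars = ax.bar(df_results['subject'].astype(str), df_results['accuracy_mean'],
              yerr=df_results['accuracy_std'], color=subject_colors, edgecolor='black', capsize=5)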

ax.axhline(0.5, color='red', linestyle='--', label='Chance level')
ax.axhline(df_results['accuracy_mean'].mean(), color='blue', linestyle='-.',
           label=f'Mean = {df_results["accuracy_mean"].mean():.3f}')

ax.set_xlabel('Subject ID')
ax.set_ylabel('Classification Accuracy')
ax.set_title('Motor Imagery Classification - Subject Variability')
ax.set_ylim([0, 1])
ax.legend()

plt.tight_layout()
plt.savefig('../images/subject_variability.png', dpi=150, bbox_inches='tight')
plt.show()

---

## 8. Summary and Conclusions

### Key Findings

1. **EEG signals contain discriminative information** for distinguishing left vs right hand motor imagery
2. **Preprocessing is crucial:** Band-pass filtering (1-40 Hz) and average re-referencing remove drift and high-frequency noise before feature extraction
3. **CSP is highly effective** for motor imagery classification, extracting spatial patterns that maximize class separability
4. **Inter-subject variability** is significant, highlighting the need for subject-specific calibration
5. **Classical ML achieves reasonable accuracy** for this binary classification task

### Next Steps

- Deep learning approaches (EEGNet, CNN)
- Neural state prediction using RNNs
- Real-time BCI implementation
- Transfer learning across subjects

In [None]:
# Collect and print summary statistics
summary = {
    'dataset': 'PhysioNet EEG Motor Imagery',
    'n_subjects_analyzed': len(df_results),
    'n_channels': len(epochs.ch_names),
    'sampling_rate': epochs.info['sfreq'],
    'best_single_subject': df_results.loc[df_results['accuracy_mean'].idxmax()].to_dict(),
    'mean_accuracy': df_results['accuracy_mean'].mean(),
    'std_accuracy': df_results['accuracy_mean'].std(),  # spread of mean accuracy across subjects
    'model': 'CSP + LDA'
}

print("Summary Statistics:")
for key, value in summary.items():
    print(f"  {key}: {value}")

In [None]:
print("\n" + "="*60)
print("EEG Motor Imagery Analysis Complete!")
print("="*60)