<a href="https://colab.research.google.com/github/abelowska/mlNeuro/blob/main/2025/MLN_p3_classification_solutions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple P300 Speller - solutions

Detect when the participant sees a target stimulus and when a non-target stimulus, using the [BI2015a dataset](https://neurotechx.github.io/moabb/generated/moabb.datasets.BI2015a.html#moabb.datasets.BI2015a).

In [None]:
!pip install moabb
!pip install mne

Now **restart your session**, and then run the following cells.

Imports

In [None]:
import moabb
import mne
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from imblearn.under_sampling import RandomUnderSampler
from mne.decoding import LinearModel, Vectorizer, get_coef

## Prepare data

### 1. Fetch data

In [None]:
# Get data for one subject. It might take a while
dataset = moabb.datasets.BI2015a()
data = dataset.get_data(subjects=[2])

In [None]:
data

Extract `MNE` `Raw` from the downloaded data

In [None]:
subject = 2
session = '0'
run = '0'

raw = data[subject][session][run]
raw

### 2. Simple Raw pre-processing

In [None]:
fig = raw.plot()
fig = raw.compute_psd().plot()
# 1. re-reference: to almost-mastoids
raw.set_eeg_reference(ref_channels=['T7', 'T8'])

# 2. band-pass filter
raw_filtered = raw.copy().filter(
    picks=['eeg'],
    l_freq=.1,
    h_freq=30.0,
    n_jobs=10,
    method='iir',
    iir_params=None
    )

# 3. Notch filter
power_freq = 50
nyquist_freq = raw_filtered.info['sfreq'] / 2

raw_filtered = raw_filtered.notch_filter(
    picks=['eeg', 'eog'],
    freqs=np.arange(power_freq, nyquist_freq, power_freq),
    n_jobs=10,
)

fig = raw_filtered.plot()
fig = raw_filtered.compute_psd().plot()

### 3. Create segments around stimuli

In [None]:
# find events on the STIM channel
events = mne.find_events(raw_filtered)

# create events dict
event_ids = {'Target': 2, 'Non-Target': 1}

# create segments
tmin = -0.2
tmax = 0.6
baseline = (-0.2,0)
epochs = mne.Epochs(
    raw_filtered,
    events,
    event_id=event_ids,
    picks="eeg",
    tmin=tmin,
    tmax=tmax,
    baseline=baseline,
    preload=True
)

epochs

### 4. Look into EEG signal for target and non-target stimuli

In [None]:
# create ERPs
target_erp = epochs['Target'].average()
nontarget_erp = epochs['Non-Target'].average()

# compare target and non-target ERPs
picks = ['Cz']

fig = mne.viz.plot_compare_evokeds(
    evokeds = {'target': target_erp, 'non-target': nontarget_erp},
    picks=picks,
    invert_y=True
)

## ML model

Now you can use your `epochs` to create a model

### MODEL 1: Problem with imbalanced classes

**Feature**: mean amplitude at Cz channel in time window 0.25 - 0.45s after stimuli presentation
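
To make the feature definition concrete, here is a minimal sketch on made-up data (NumPy only, no MNE). The sampling rate, array shape, and the 2 µV deflection are assumptions chosen for illustration; they are not taken from the dataset.

```python
import numpy as np

# Hypothetical epoch layout: 6 epochs, 1 channel, samples from -0.2 s to 0.6 s.
toy_sfreq = 500  # assumed sampling rate, for illustration only
toy_times = np.arange(-100, 301) / toy_sfreq
toy_data = np.zeros((6, 1, toy_times.size))

# Boolean mask selecting the 0.25-0.45 s window.
toy_window = (toy_times >= 0.25) & (toy_times <= 0.45)

# Give the first three (toy "target") epochs a 2 µV deflection inside the window.
toy_data[:3, 0, toy_window] = 2e-6

# Averaging over the last (time) axis leaves one value per epoch and channel.
mean_amplitude = toy_data[:, :, toy_window].mean(axis=-1)
print(mean_amplitude.shape)
```

The cell below does the same thing with `epochs.get_data(..., tmin=..., tmax=...)` followed by `.mean(axis=-1)`.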

In [None]:
###### Creates data for fitting #######################
tmin = 0.25
tmax = 0.45
epochs_data = epochs.get_data(picks='Cz', tmin=tmin, tmax=tmax) # get epochs data in desired time-window and at desired channel
epochs_data = epochs_data.mean(axis=-1) # average signal within time-window to get mean amplitude

X = epochs_data.reshape(epochs_data.shape[0], -1) # reshape to (n_samples, n_features)
y = epochs.events[:,-1] - 1 # create labels: you can use events code for this!

print(f"Shape of X data: {X.shape}, shape of y data: {y.shape}\n")

###### Fit the simplest classification model ##########
clf = LogisticRegression(random_state=0).fit(X, y)
y_predicted = clf.predict(X)

###### Print classification results ##################
print(classification_report(y_true=y, y_pred=y_predicted))

As you can see, your model always predicts class 0. This is the majority class, and the behavior results from the large imbalance between the classes: **390:78**.
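
A quick way to see why accuracy is misleading here is to count the labels and compute what a "always predict the majority class" rule would score. The sketch below uses a toy events array with the same 390:78 split; the sample indices are made up and only the event codes matter.

```python
import numpy as np

# Toy events array in MNE layout: (sample index, previous value, event code).
toy_codes = np.array([1] * 390 + [2] * 78)  # 1 = Non-Target, 2 = Target
toy_events = np.column_stack(
    [np.arange(toy_codes.size), np.zeros_like(toy_codes), toy_codes]
)

toy_y = toy_events[:, -1] - 1  # Non-Target -> 0, Target -> 1
class_counts = np.bincount(toy_y)

# Accuracy of a rule that ignores the features and always answers "0".
majority_accuracy = class_counts.max() / class_counts.sum()

print(class_counts)
print(f"{majority_accuracy:.3f}")
```

With this split the trivial rule already reaches roughly 83% accuracy, so per-class precision and recall in `classification_report` are the numbers to watch.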

### MODEL 2: Balanced Classes but Still Performs Poorly
**Feature**: mean amplitude at Cz channel in time window 0.25 - 0.45s after stimuli presentation

In [None]:
###### Creates data for fitting #######################
tmin = 0.25
tmax = 0.45
epochs_data = epochs.get_data(picks='Cz', tmin=tmin, tmax=tmax) # get epochs data in desired time-window and at desired channel
epochs_data = epochs_data.mean(axis=-1) # average signal within time-window to get mean amplitude

X = epochs_data.reshape(epochs_data.shape[0], -1) # reshape to (n_samples, n_features)
y = epochs.events[:,-1] - 1 # create labels: you can use events code for this!

print(f"Shape of X data: {X.shape}, shape of y data: {y.shape}\n")

###### Balance classes #################################
### You can implement class balancing manually—for example, by iterating
### simultaneously over X and y and appending observations of each class
### to separate lists. Alternatively, you can use an off-the-shelf
### implementation, like the one shown below:###########

undersampler = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = undersampler.fit_resample(X, y)

print(f"After resampling:\nShape of X data: {X_resampled.shape}, shape of y data: {y_resampled.shape}\n")

###### Fit the simplest classification model ##########
clf = LogisticRegression(random_state=0).fit(X_resampled, y_resampled)
y_predicted = clf.predict(X_resampled)

###### Print classification results ##################
print(classification_report(y_true=y_resampled, y_pred=y_predicted))

Now it's working a bit better, but overall performance is still basically random.
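
The manual balancing mentioned in the cell's comments could look roughly like the sketch below. It is an index-based variant written with NumPy only; `undersample_majority` is a hypothetical helper name, and it will not necessarily pick the same trials as `RandomUnderSampler`.

```python
import numpy as np

def undersample_majority(X, y, seed=42):
    """Randomly drop rows so every class keeps as many samples as the rarest one."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    n_keep = counts.min()
    keep = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=n_keep, replace=False)
        for c in classes
    ])
    keep.sort()  # preserve the original trial order
    return X[keep], y[keep]

# Toy data with the same 390:78 imbalance as above.
toy_X = np.arange(468, dtype=float).reshape(-1, 1)
toy_y = np.array([0] * 390 + [1] * 78)

X_bal, y_bal = undersample_majority(toy_X, toy_y)
print(X_bal.shape, np.bincount(y_bal))
```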

### MODEL 3: Balanced Classes with Standard Scaler
**Feature**: mean amplitude at Cz channel in time window 0.25 - 0.45s after stimuli presentation, scaled

In [None]:
###### Creates data for fitting #######################
tmin = 0.25
tmax = 0.45
epochs_data = epochs.get_data(picks='Cz', tmin=tmin, tmax=tmax) # get epochs data in desired time-window and at desired channel
epochs_data = epochs_data.mean(axis=-1) # average signal within time-window to get mean amplitude

X = epochs_data.reshape(epochs_data.shape[0], -1) # reshape to (n_samples, n_features)
y = epochs.events[:,-1] - 1 # create labels: you can use events code for this!

print(f"Shape of X data: {X.shape}, shape of y data: {y.shape}\n")

###### Balance classes #################################
### You can implement class balancing manually—for example, by iterating
### simultaneously over X and y and appending observations of each class
### to separate lists. Alternatively, you can use an off-the-shelf
### implementation, like the one shown below:###########

undersampler = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = undersampler.fit_resample(X, y)

print(f"After resampling:\nShape of X data: {X_resampled.shape}, shape of y data: {y_resampled.shape}\n")

###### Fit the simplest classification model ##########
### But now use Pipelines and Standard Scaler #########
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0)).fit(X_resampled, y_resampled)
y_predicted = clf.predict(X_resampled)

###### Print classification results ##################
print(classification_report(y_true=y_resampled, y_pred=y_predicted))

It's a huge improvement! Now it's working much better, but the model is still biased toward the 0 class, and the recall for the 1 class is below chance level.
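
One plausible reason scaling helps so much: MNE stores EEG in volts, so amplitudes are on the order of 1e-6 to 1e-5, while `LogisticRegression` applies L2 regularization by default (`C=1.0`), which discourages the very large coefficients such small inputs would need. The sketch below illustrates the effect on synthetic data; the numbers are made up and this is not a verified diagnosis of the result above.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Synthetic one-feature data on a volt-like scale (values are made up).
toy_y = np.repeat([0, 1], 200)
toy_X = rng.normal(loc=toy_y * 1e-5, scale=5e-6).reshape(-1, 1)

raw_clf = LogisticRegression(random_state=0).fit(toy_X, toy_y)
scaled_clf = make_pipeline(
    StandardScaler(), LogisticRegression(random_state=0)
).fit(toy_X, toy_y)

print("feature std:", toy_X.std())
print("accuracy without scaling:", raw_clf.score(toy_X, toy_y))
print("accuracy with scaling:   ", scaled_clf.score(toy_X, toy_y))
```

Comparing the two printed accuracies on your own data is a cheap check of whether feature scale is the bottleneck.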

### MODEL 4: Balanced Classes with Standard Scaler and more features
**Feature**: amplitude at Cz channel in time window 0.25 - 0.45s after stimuli presentation, scaled

**Note that we no longer use the mean amplitude (one feature) but the whole signal from the time window (in this case, 102 features).**
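
The jump from one feature to many comes purely from skipping the averaging step and flattening each epoch. A small sketch with placeholder arrays (the shapes are assumptions matching the counts quoted in this notebook):

```python
import numpy as np

# Hypothetical shapes: 468 epochs, 1 channel, 102 time samples.
toy_epochs = np.zeros((468, 1, 102))
toy_X = toy_epochs.reshape(toy_epochs.shape[0], -1)
print(toy_X.shape)

# With several channels the flat vector is laid out channel by channel,
# so a row of length n_channels * n_times can be folded back into 2D.
toy_multi = np.arange(2 * 4 * 102).reshape(2, 4, 102)
toy_flat = toy_multi.reshape(toy_multi.shape[0], -1)
toy_restored = toy_flat[0].reshape(4, -1)
print(np.array_equal(toy_restored, toy_multi[0]))
```

The same fold-back is what allows flat model coefficients to be displayed as a channels-by-time image.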

In [None]:
###### Creates data for fitting #######################
tmin = 0.25
tmax = 0.45
picks=['Cz']
epochs_data = epochs.get_data(picks=picks, tmin=tmin, tmax=tmax) # get epochs data in desired time-window and at desired channel

X = epochs_data.reshape(epochs_data.shape[0], -1) # reshape to (n_samples, n_features)
y = epochs.events[:,-1] - 1 # create labels: you can use events code for this!

print(f"Shape of X data: {X.shape}, shape of y data: {y.shape}\n")

###### Balance classes #################################
### You can implement class balancing manually—for example, by iterating
### simultaneously over X and y and appending observations of each class
### to separate lists. Alternatively, you can use an off-the-shelf
### implementation, like the one shown below:###########

undersampler = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = undersampler.fit_resample(X, y)

print(f"After resampling:\nShape of X data: {X_resampled.shape}, shape of y data: {y_resampled.shape}\n")

###### Fit the simplest classification model ##########
### But now use Pipelines and Standard Scaler #########
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0)).fit(X_resampled, y_resampled)
y_predicted = clf.predict(X_resampled)

###### Print classification results ##################
print(classification_report(y_true=y_resampled, y_pred=y_predicted))

## Looking into the model: visualization of model parameters (coefficients)

In [None]:
coefs = clf['logisticregression'].coef_.flatten() # assuming that model was a pipeline
x_ticks = np.linspace(tmin, tmax, len(coefs))

abs_max = np.max(np.abs(coefs))
coefs_2d = coefs[np.newaxis, :]

plt.figure(figsize=(10, 2))
plt.imshow(
    coefs_2d,
    aspect='auto',
    cmap='bwr',
    interpolation='none',
    extent=[tmin, tmax, 0, len(picks)],
    vmin=-abs_max,
    vmax=abs_max
    )
plt.colorbar(label="Coefficient Value")
plt.yticks(np.arange(len(picks)) + 0.5, labels=picks[::-1])  # ticks at row centers; imshow draws the first row at the top
plt.xlabel("Time")
plt.title("Classifier Coefficients Over Time")
plt.tight_layout()
plt.show()

### MODEL 5: Balanced Classes with Standard Scaler and spatial features
**Feature**: amplitude at the Cz, CP1, CP2, and Pz channels in the time window 0.25 - 0.45 s after stimulus presentation, scaled

In [None]:
###### Creates data for fitting #######################
tmin = 0.25
tmax = 0.45
picks = ['Cz', 'CP1', 'CP2', 'Pz']
epochs_data = epochs.get_data(picks, tmin=tmin, tmax=tmax) # get epochs data in desired time-window and at desired channel

X = epochs_data.reshape(epochs_data.shape[0], -1) # reshape to (n_samples, n_features)
y = epochs.events[:,-1] - 1 # create labels: you can use events code for this!

print(f"Shape of X data: {X.shape}, shape of y data: {y.shape}\n")

###### Balance classes #################################
### You can implement class balancing manually—for example, by iterating
### simultaneously over X and y and appending observations of each class
### to separate lists. Alternatively, you can use an off-the-shelf
### implementation, like the one shown below:###########

undersampler = RandomUnderSampler(random_state=42)
X_resampled, y_resampled = undersampler.fit_resample(X, y)

print(f"After resampling:\nShape of X data: {X_resampled.shape}, shape of y data: {y_resampled.shape}\n")

###### Fit the simplest classification model ##########
### But now use Pipelines and Standard Scaler #########
clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0)).fit(X_resampled, y_resampled)
y_predicted = clf.predict(X_resampled)

###### Print classification results ##################
print(classification_report(y_true=y_resampled, y_pred=y_predicted))

#### Looking into the model: visualization of model parameters (coefficients)

In [None]:
coefs = clf['logisticregression'].coef_.flatten()  # assuming that model was a pipeline
coefs_2d = coefs.reshape(len(picks), -1)

x_ticks = np.linspace(tmin, tmax, coefs_2d.shape[-1])

abs_max = np.max(np.abs(coefs))

plt.figure(figsize=(10, 2))
plt.imshow(
    coefs_2d,
    aspect='auto',
    cmap='bwr',
    interpolation='none',
    extent=[tmin, tmax, 0, len(picks)],  # Adjust the extent of y-axis to match number of picks
    vmin=-abs_max,
    vmax=abs_max
    )
plt.colorbar(label="Coefficient Value")
plt.yticks(np.arange(len(picks)) + 0.5, labels=picks[::-1])  # ticks at row centers; imshow draws the first row at the top
plt.xlabel("Time")
plt.title("Classifier Coefficients Over Time")
plt.tight_layout()
plt.show()

## Test your model

In [None]:
def get_test_data(session='1'):
  subject = 2
  session = session
  run = '0'

  test_raw = data[subject][session][run]
  # 1. re-reference: to almost-mastoids
  test_raw.set_eeg_reference(ref_channels=['T7', 'T8'])

  # 2. band-pass filter
  test_raw_filtered = test_raw.copy().filter(
      picks=['eeg'],
      l_freq=.1,
      h_freq=30.0,
      n_jobs=10,
      method='iir',
      iir_params=None
      )

  # 3. Notch filter
  power_freq = 50
  nyquist_freq = test_raw_filtered.info['sfreq'] / 2

  test_raw_filtered = test_raw_filtered.notch_filter(
      picks=['eeg', 'eog'],
      freqs=np.arange(power_freq, nyquist_freq, power_freq),
      n_jobs=10,
  )

  # find events on the STIM channel
  events = mne.find_events(test_raw_filtered)

  # create events dict
  event_ids = {'Target': 2, 'Non-Target': 1}

  # create segments
  tmin = -0.2
  tmax = 0.6
  baseline = (-0.2,0)
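  test_epochs = mne.Epochs(
      test_raw_filtered,
      events,
      event_id=event_ids,
      picks="eeg",   # same channel selection as for the training epochs
      tmin=tmin,
      tmax=tmax,
      baseline=baseline,
      preload=True,  # load data now so that events and data stay aligned
  )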

  return test_epochs


def test_checker(X_test, y_test, model, n_samples=10):
  for i in range(len(X_test[:n_samples])):
      print(f"Checking test trial {i + 1}...\n")
      time.sleep(1.4)

      # Get the prediction for the current sample
      y_pred = model.predict(X_test[i].reshape(1, -1))

      # Check if the prediction is correct
      if y_pred[0] == y_test[i]:
          print("Correct! ❤️\n\n")
      else:
          print("Incorrect! 😢\n\n")

      time.sleep(0.5)
  y_test_predicted = model.predict(X_test)
  print(classification_report(y_test, y_test_predicted))

In [None]:
test_epochs = get_test_data(session='1')

Transform your test data in exactly the same way as the training data to simplify testing:

In [None]:
###### Creates data for fitting #######################
tmin = 0.25
tmax = 0.45
picks = ['Cz', 'CP1', 'CP2', 'Pz']
epochs_data = test_epochs.get_data(picks, tmin=tmin, tmax=tmax) # get epochs data in desired time-window and at desired channel

X_test = epochs_data.reshape(epochs_data.shape[0], -1) # reshape to (n_samples, n_features)
y_test = test_epochs.events[:,-1] - 1 # create labels: you can use events code for this!

print(f"Shape of X data: {X_test.shape}, shape of y data: {y_test.shape}\n")

###### Balance classes #################################
### You can implement class balancing manually—for example, by iterating
### simultaneously over X and y and appending observations of each class
### to separate lists. Alternatively, you can use an off-the-shelf
### implementation, like the one shown below:###########

undersampler = RandomUnderSampler(random_state=42)
X_test_resampled, y_test_resampled = undersampler.fit_resample(X_test, y_test)

print(f"After resampling:\nShape of X data: {X_test_resampled.shape}, shape of y data: {y_test_resampled.shape}\n")
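
Because the model is a pipeline, the fitted scaler travels with it: calling `predict()` on test features reuses the statistics learned from the training data rather than refitting them. A minimal sketch on made-up arrays (all `toy_` names are hypothetical):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(1)

# Made-up train and test features with the same number of columns.
toy_X_train = rng.normal(size=(100, 8))
toy_y_train = np.repeat([0, 1], 50)
toy_X_test = rng.normal(size=(20, 8))

toy_clf = make_pipeline(StandardScaler(), LogisticRegression(random_state=0))
toy_clf.fit(toy_X_train, toy_y_train)

# The scaler keeps the per-feature means learned during fit() ...
means_match = np.allclose(toy_clf[0].mean_, toy_X_train.mean(axis=0))
# ... and predict() reuses them, so the test set is never used for fitting.
toy_pred = toy_clf.predict(toy_X_test)
print(means_match, toy_pred.shape)
```

This only works if the test features have the same columns in the same order, which is why `picks`, `tmin`, and `tmax` must match the training cell.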

And run `test_checker()`!

In [None]:
test_checker(X_test=X_test_resampled, y_test=y_test_resampled, model=clf)