<a href="https://colab.research.google.com/github/abelowska/mlNeuro/blob/main/2025/MLN_p3_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple P300 Speller

Detect whether the participant sees a target or a non-target stimulus, using the [BI2015a dataset](https://neurotechx.github.io/moabb/generated/moabb.datasets.BI2015a.html#moabb.datasets.BI2015a).

In [None]:
!pip install moabb
!pip install mne

Imports

In [None]:
import moabb
import mne
import numpy as np
import matplotlib.pyplot as plt
import time
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

## Prepare data

### 1. Fetch data

In [None]:
# Get data for one subject. It might take a while
dataset = moabb.datasets.BI2015a()
data = dataset.get_data(subjects=[1])

In [None]:
# Inspect the nested dict: subject -> session -> run -> Raw
data

Extract the MNE `Raw` object from the downloaded data:

In [None]:
subject = 1
session = '0'
run = '0'

raw = data[subject][session][run]
raw

### 2. Simple Raw pre-processing

In [None]:
fig = raw.plot()
fig = raw.compute_psd().plot()
# 1. re-reference: to almost-mastoids
raw.set_eeg_reference(ref_channels=['T7', 'T8'])

# 2. band-pass filter
raw_filtered = raw.copy().filter(
    picks=['eeg'],
    l_freq=.1,
    h_freq=30.0,
    n_jobs=10,
    method='iir',
    iir_params=None
    )

# 3. Notch filter
power_freq = 50
nyquist_freq = raw_filtered.info['sfreq'] / 2

raw_filtered = raw_filtered.notch_filter(
    picks=['eeg', 'eog'],
    freqs=np.arange(power_freq, nyquist_freq, power_freq),
    n_jobs=10,
)

fig = raw_filtered.plot()
fig = raw_filtered.compute_psd().plot()

### 3. Create segments around stimuli

In [None]:
# find events on the STIM channel
events = mne.find_events(raw_filtered)

# create events dict
event_ids = {'Target': 2, 'Non-Target': 1}

# create segments
tmin = -0.2
tmax = 0.6
baseline = (-0.2,0)
epochs = mne.Epochs(
    raw_filtered,
    events,
    event_id=event_ids,
    tmin=tmin,
    tmax=tmax,
    baseline=baseline,
)

epochs

### 4. Look into EEG signal for target and non-target stimuli

In [None]:
# create ERPs
target_erp = epochs['Target'].average()
nontarget_erp = epochs['Non-Target'].average()

# compare target and non-target ERPs
picks = ['Cz']

fig = mne.viz.plot_compare_evokeds(
    evokeds = {'target': target_erp, 'non-target': nontarget_erp},
    picks=picks,
    invert_y=True
)

## ML model

Now you can use your `epochs` to train a model.

In [None]:
# your code here
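
One possible starting point, sketched under assumptions rather than given as the intended solution: flatten each epoch into a feature vector and fit the `LinearDiscriminantAnalysis` imported earlier. The arrays here are synthetic stand-ins (random noise plus an injected deflection for targets) so that the block runs on its own. With real data, `X_3d` would come from `epochs.get_data()` and `y` from `epochs.events[:, 2]`. The names `X_3d`, `X`, `y`, `lda`, the shrinkage setting, and the cross-validation scheme are illustrative choices, not part of the original notebook.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)

# Synthetic stand-in for epochs.get_data(): (n_epochs, n_channels, n_times)
n_epochs, n_channels, n_times = 200, 8, 40
X_3d = rng.normal(size=(n_epochs, n_channels, n_times))

# Synthetic stand-in for epochs.events[:, 2]: 2 = Target, 1 = Non-Target
y = np.where(rng.random(n_epochs) < 0.2, 2, 1)

# Inject a P300-like deflection into the second half of every target epoch
X_3d[y == 2, :, n_times // 2:] += 1.0

# Flatten channels x time into one feature vector per epoch
X = X_3d.reshape(n_epochs, -1)

# Shrinkage LDA (an assumption) is a common choice when there are
# many features and relatively few trials
lda = LinearDiscriminantAnalysis(solver="lsqr", shrinkage="auto")

# Balanced accuracy accounts for the target / non-target class imbalance
scores = cross_val_score(lda, X, y, cv=5, scoring="balanced_accuracy")
print(f"Balanced accuracy: {scores.mean():.2f}")

# Fit on all training epochs so the model can be used on the test session
lda.fit(X, y)
```

Because targets are much rarer than non-targets in a P300 paradigm, plain accuracy can look good even for a model that always predicts "Non-Target", which is why the sketch reports balanced accuracy instead.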

## Test your model

In [None]:
def get_test_data(session='1'):
  # Load one run of the requested session and preprocess it like the training data
  subject = 1
  run = '0'

  test_raw = data[subject][session][run]
  # 1. re-reference: to almost-mastoids
  test_raw.set_eeg_reference(ref_channels=['T7', 'T8'])

  # 2. band-pass filter
  test_raw_filtered = test_raw.copy().filter(
      picks=['eeg'],
      l_freq=.1,
      h_freq=30.0,
      n_jobs=10,
      method='iir',
      iir_params=None
      )

  # 3. Notch filter
  power_freq = 50
  nyquist_freq = test_raw_filtered.info['sfreq'] / 2

  test_raw_filtered = test_raw_filtered.notch_filter(
      picks=['eeg', 'eog'],
      freqs=np.arange(power_freq, nyquist_freq, power_freq),
      n_jobs=10,
  )

  # find events on the STIM channel
  events = mne.find_events(test_raw_filtered)

  # create events dict
  event_ids = {'Target': 2, 'Non-Target': 1}

  # create segments
  tmin = -0.2
  tmax = 0.6
  baseline = (-0.2,0)
  test_epochs = mne.Epochs(
      test_raw_filtered,
      events,
      event_id=event_ids,
      tmin=tmin,
      tmax=tmax,
      baseline=baseline,
  )

  return test_epochs


def test_checker(X_test, y_test, model, n_samples=10):
  for i in range(len(X_test[:n_samples])):
      print(f"Checking test trial {i + 1}...\n")
      time.sleep(1.4)

      # Get the prediction for the current sample
      y_pred = model.predict(X_test[i].reshape(1, -1))

      # Check if the prediction is correct
      if y_pred[0] == y_test[i]:
          print("Correct! ❤️\n\n")
      else:
          print("Incorrect! 😢\n\n")

      time.sleep(0.5)

In [None]:
test_epochs = get_test_data(session='1')

Transform your test data in exactly the same way as the training data, so that it can be passed to the testing procedure:

In [None]:
# your code here
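
A simple way to keep the training and test transformations identical is to route both through one helper. The sketch below assumes the features are flattened epochs; `to_features` is a hypothetical helper, and the random arrays stand in for `epochs.get_data()` and `test_epochs.get_data()`. With real data, the labels would be `y_test = test_epochs.events[:, 2]`.

```python
import numpy as np


def to_features(epochs_data):
    """Flatten (n_epochs, n_channels, n_times) into (n_epochs, n_features)."""
    epochs_data = np.asarray(epochs_data)
    return epochs_data.reshape(epochs_data.shape[0], -1)


rng = np.random.default_rng(1)
train_data = rng.normal(size=(50, 8, 40))  # stand-in for epochs.get_data()
test_data = rng.normal(size=(20, 8, 40))   # stand-in for test_epochs.get_data()

# The same helper is applied to both sets, so the feature layout matches
X_train = to_features(train_data)
X_test = to_features(test_data)

# The model can only predict on inputs with the feature count it was fit on
assert X_train.shape[1] == X_test.shape[1]
print(X_train.shape, X_test.shape)
```

If the training pipeline includes fitted steps (for example a scaler), those should be fit on the training data only and then applied unchanged to the test data.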

Then run `test_checker()`:

In [None]:
test_checker(X_test=X_test, y_test=y_test, model=lda)