In [1]:
! pip install mne
! pip install moabb
# Restart session afterwards!

Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mne
Successfully installed mne-1.9.0
Collecting moabb
  Downloading moabb-1.2.0-py3-none-any.whl.metadata (14 kB)
Collecting coverage<8.0.0,>=7.0.1 (from moabb)
  Downloading coverage-7.8.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.5 kB)
Collecting edfio<0.5.0,>=0.4.2 (from moabb)
  Downloading edfio-0.4.8-py3-none-any.whl.metadata (3.9 kB)
Collecting edflib-python<2.0.0,>=1.0.6 (from moabb)
  Downloading EDFlib_Python-1.0.8-py3-none-any.whl.metadata (1.3 kB)
Collecting memory-profiler<0.62.0,>=0.61.0 (from moabb)
  Downloading memory_profiler-0.61.0-py3-none-any.whl.metadata (20 kB)
Collecting mne-bids>=0.14 (from moabb)
  Downloading mne_bids-0.16

In [11]:
import moabb
import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import random
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, precision_score, recall_score, balanced_accuracy_score
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score, RepeatedStratifiedKFold, GridSearchCV, train_test_split, RepeatedKFold
from sklearn import linear_model
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


In [2]:
# load data from the subject
dataset = moabb.datasets.BI2015a()
data = dataset.get_data(subjects=[2])
subject = 2
session = '0'
run = '0'

raw = data[subject][session][run]

MNE_DATA is not already configured. It will be set to default location in the home directory - /root/mne_data
All datasets will be downloaded to this location, if anything is already downloaded, please move manually to this location
Attempting to create new mne-python configuration file:
/root/.mne/mne-python.json


  set_config(key, get_config("MNE_DATA"))
Downloading data from 'https://zenodo.org/record/3266930/files/subject_02_mat.zip' to file '/root/mne_data/MNE-braininvaders2015a-data/record/3266930/files/subject_02_mat.zip'.
100%|████████████████████████████████████████| 107M/107M [00:00<00:00, 140GB/s]
SHA256 hash of downloaded file: c13ab3a18dbd661f5c9bb630445bc2979476b922f84af91e38a6b4246ebeddb5
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.


In [17]:
# Re-reference
raw.set_eeg_reference(ref_channels=['T7', 'T8'])

# Band-pass filter
raw_filtered = raw.copy().filter(
      picks=['eeg'],
      l_freq=.1,
      h_freq=30.0,
      n_jobs=10,
      method='iir',
      iir_params=None
      )

# Notch filter
power_freq = 50
nyquist_freq = raw_filtered.info['sfreq'] / 2

test_raw_filtered = raw_filtered.notch_filter(
    picks=['eeg', 'eog'],
    freqs=np.arange(power_freq, nyquist_freq, power_freq),
    n_jobs=10,
)

# Create epochs
events = mne.find_events(raw_filtered)
event_ids = {'Target': 2, 'Non-Target': 1}
tmin = -0.2
tmax = 0.6
baseline = (-0.2,0)
epochs = mne.Epochs(
    test_raw_filtered,
    events,
    event_id=event_ids,
    tmin=tmin,
    tmax=tmax,
    baseline=baseline,
)



EEG channel type selected for re-referencing
Applying a custom ('EEG',) reference.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 30 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 0.10, 30.00 Hz: -6.02, -6.02 dB



[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done   8 tasks      | elapsed:    6.0s


Filtering raw data in 1 contiguous segment
Setting up band-stop filter

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 3381 samples (6.604 s)



[Parallel(n_jobs=10)]: Done  30 out of  32 | elapsed:    9.0s remaining:    0.6s
[Parallel(n_jobs=10)]: Done  32 out of  32 | elapsed:    9.2s finished
[Parallel(n_jobs=10)]: Using backend LokyBackend with 10 concurrent workers.
[Parallel(n_jobs=10)]: Done   8 tasks      | elapsed:    0.4s


468 events found on stim channel STI 014
Event IDs: [1 2]
Not setting metadata
468 matching events found
Applying baseline correction (mode: mean)
0 projection items activated


[Parallel(n_jobs=10)]: Done  30 out of  32 | elapsed:    0.8s remaining:    0.1s
[Parallel(n_jobs=10)]: Done  32 out of  32 | elapsed:    0.8s finished


In [5]:
# Compare number of target vs non-target events
non_target_count = 0
target_count = 0
for event in epochs.events[:,-1]:
  if event == 1:
    non_target_count += 1
  else:
    target_count += 1
print(f"Target occurances: {target_count}")
print(f"Non-target occurances: {non_target_count}")

Target occurances: 78
Non-target occurances: 390


In [21]:
# Balance number of target/non-target events to avoid bias
nt_ids = np.where(np.isin(epochs.events[:, -1], (1)))[0]
ids_to_drop = []
for i in range(non_target_count - target_count):
  while True:
    random_id = random.choice(nt_ids)
    if random_id not in ids_to_drop:
      ids_to_drop.append(random_id)
      break
epochs.drop(indices=ids_to_drop, reason="balance dataset")

Dropped 312 epochs: 4, 5, 7, 8, 9, 11, 12, 13, 14, 16, 18, 19, 20, 21, 22, 23, 26, 28, 29, 31, 33, 34, 35, 36, 37, 38, 39, 42, 44, 45, 46, 49, 50, 51, 52, 53, 54, 56, 57, 58, 59, 60, 61, 63, 65, 67, 68, 69, 72, 73, 75, 76, 79, 81, 82, 83, 84, 85, 88, 89, 90, 91, 92, 94, 95, 96, 98, 100, 101, 104, 105, 106, 108, 109, 110, 111, 113, 114, 116, 117, 120, 121, 124, 125, 127, 128, 130, 131, 132, 134, 136, 137, 140, 141, 146, 147, 148, 149, 151, 152, 153, 154, 155, 156, 157, 158, 160, 161, 162, 163, 164, 165, 167, 170, 171, 173, 176, 177, 180, 182, 183, 185, 187, 188, 189, 190, 191, 193, 194, 195, 196, 197, 198, 200, 203, 204, 205, 206, 210, 211, 213, 216, 217, 219, 220, 221, 223, 225, 226, 227, 228, 229, 230, 233, 234, 235, 236, 238, 239, 240, 241, 243, 244, 245, 246, 247, 249, 250, 253, 254, 255, 256, 257, 258, 259, 260, 261, 263, 264, 265, 267, 268, 269, 270, 272, 273, 274, 279, 280, 281, 285, 286, 287, 288, 291, 292, 293, 295, 296, 297, 299, 301, 302, 303, 304, 305, 308, 309, 310, 311, 31

Unnamed: 0,General,General.1
,MNE object type,Epochs
,Measurement date,Unknown
,Participant,Unknown
,Experimenter,Unknown
,Acquisition,Acquisition
,Total number of events,156
,Events counts,Non-Target: 78  Target: 78
,Time range,-0.199 – 0.600 s
,Baseline,-0.200 – 0.000 s
,Sampling frequency,512.00 Hz


In [22]:
# Get X data
ml_data = epochs.get_data(picks="Cz", tmin=0.25, tmax=0.4)

# Assign X and y
X = ml_data
y = epochs.events[:, -1]
# Reshape X to two dimentions
X = X.reshape(X.shape[0], -1)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=5)

cs = [1, 100, 1000, 10000]
results = []
for c in cs:
  # Fit
  lm =linear_model.LogisticRegression(C=c)
  lm.fit(X_train, y_train)
  # Get train and test accuracy
  y_predicted = lm.predict(X_test)
  y_predicted_train = lm.predict(X_train)
  cv_acc = cross_val_score(lm, X_train, y_train)

  results_dict = {
      'model_name': f"Logistic regression c: {c}", # name your pipeline
      'model': lm,
      'train_acc': accuracy_score(y_train, y_predicted_train),
      'mean_cv_acc': np.mean(cv_acc),
      'test_acc': accuracy_score(y_test, y_predicted),
      'c' : c
    }
  results.append(results_dict)
display(pd.DataFrame(results))
#???????

Using data from preloaded Raw for 156 events and 410 original time points ...


Unnamed: 0,model_name,model,train_acc,mean_cv_acc,test_acc,c
0,Logistic regression c: 1,LogisticRegression(C=1),0.516129,0.516,0.4375,1
1,Logistic regression c: 100,LogisticRegression(C=100),0.516129,0.516,0.4375,100
2,Logistic regression c: 1000,LogisticRegression(C=1000),0.516129,0.516,0.4375,1000
3,Logistic regression c: 10000,LogisticRegression(C=10000),0.516129,0.516,0.4375,10000


In [23]:
def get_test_data(session='1'):
  subject = 2
  session = session
  run = '0'

  test_raw = data[subject][session][run]
  # 1. re-reference: to almost-mastoids
  test_raw.set_eeg_reference(ref_channels=['T7', 'T8'])

  # 2. band-pass filter
  test_raw_filtered = test_raw.copy().filter(
      picks=['eeg'],
      l_freq=.1,
      h_freq=30.0,
      n_jobs=10,
      method='iir',
      iir_params=None
      )

  # 3. Notch filter
  power_freq = 50
  nyquist_freq = test_raw_filtered.info['sfreq'] / 2

  test_raw_filtered = test_raw_filtered.notch_filter(
      picks=['eeg', 'eog'],
      freqs=np.arange(power_freq, nyquist_freq, power_freq),
      n_jobs=10,
  )

  # fing events on the STIM channel
  events = mne.find_events(test_raw_filtered)

  # create events dict
  event_ids = {'Target': 2, 'Non-Target': 1}

  # create segments
  tmin = -0.2
  tmax = 0.6
  baseline = (-0.2,0)
  test_epochs = mne.Epochs(
      test_raw_filtered,
      events,
      event_id=event_ids,
      tmin=tmin,
      tmax=tmax,
      baseline=baseline,
  )

  return test_epochs


def test_checker(X_test, y_test, model, n_samples=10):
  for i in range(len(X_test[:n_samples])):
      print(f"Checking test trial {i + 1}...\n")
      time.sleep(1.4)

      # Get the prediction for the current sample
      y_pred = model.predict(X_test[i].reshape(1, -1))

      # Check if the prediction is correct
      if y_pred[0] == y_test[i]:
          print("Correct! ❤️\n\n")
      else:
          print("Incorrect! 😢\n\n")

      time.sleep(0.5)

In [24]:
test_checker(X_test=X_test, y_test=y_test, model=lm)

Checking test trial 1...

Correct! ❤️


Checking test trial 2...

Incorrect! 😢


Checking test trial 3...

Incorrect! 😢


Checking test trial 4...

Incorrect! 😢


Checking test trial 5...

Incorrect! 😢


Checking test trial 6...

Correct! ❤️


Checking test trial 7...

Correct! ❤️


Checking test trial 8...

Correct! ❤️


Checking test trial 9...

Correct! ❤️


Checking test trial 10...

Incorrect! 😢


