In [1]:
import mne
import torch

%matplotlib widget

In [2]:

from bci_aic3.data import BCIDataset, load_data
from bci_aic3.paths import CONFIG_DIR, LABEL_MAPPING_PATH, RAW_DATA_DIR, TRAINING_STATS_PATH
from bci_aic3.util import read_json_to_dict, normalize, save_training_stats, load_training_stats
from bci_aic3.models.eegnet import EEGNet

In [3]:
import numpy as np
from torch.utils.data import DataLoader


In [4]:
# Label mapping read from LABEL_MAPPING_PATH; passed to the training dataset below.
# NOTE(review): presumably maps class names to integer label ids — confirm against the JSON file.
label_mapping = read_json_to_dict(LABEL_MAPPING_PATH)

In [6]:
# 1. Training split of the motor-imagery (MI) task.
# No training_stats are supplied here; the statistics are read back from this dataset in the
# next cell, and the summary further down shows its samples come out z-scored.
train_mi = BCIDataset(
    "train.csv",
    split="train",
    task_type="MI",
    base_path=RAW_DATA_DIR,
    training_stats=None,
    label_mapping=label_mapping,
)

2400it [02:26, 16.36it/s]


In [7]:
# Per-channel 'mean' / 'std' statistics of the training split (bound and displayed in one go).
(training_stats := train_mi.get_training_stats())

{'mean': tensor([298498.0625, 301786.0625, 334266.9375, 326867.8750, 347858.1875,
         288628.1250, 285722.0938, 286179.0625]),
 'std': tensor([ 25025.9785,  27414.4160,  71625.7344,  81569.4297, 121221.1562,
          28605.6387,  30016.5605,  28722.1230])}

In [8]:
# Persist the MI training statistics to disk; the next cell reloads them.
save_training_stats(
    training_stats,
    TRAINING_STATS_PATH / "mi_train.pt",
)

In [9]:
# Round-trip check: the statistics reloaded from disk must match the in-memory ones exactly,
# instead of relying on eyeballing the two printed dicts.
stats = load_training_stats(TRAINING_STATS_PATH / "mi_train.pt")
assert stats.keys() == training_stats.keys(), "reloaded stats have different keys"
assert all(torch.equal(stats[key], training_stats[key]) for key in training_stats), (
    "reloaded stats differ from the in-memory training stats"
)
stats

{'mean': tensor([298498.0625, 301786.0625, 334266.9375, 326867.8750, 347858.1875,
         288628.1250, 285722.0938, 286179.0625]),
 'std': tensor([ 25025.9785,  27414.4160,  71625.7344,  81569.4297, 121221.1562,
          28605.6387,  30016.5605,  28722.1230])}

In [10]:

# Test split of the MI task. It receives the statistics computed on the training split
# (presumably so it is normalized exactly like the training data).
# label_mapping is not passed here — NOTE(review): presumably test.csv carries no labels; confirm.
test_mi = BCIDataset(
    "test.csv",
    task_type="MI",
    split="test",
    base_path=RAW_DATA_DIR,
    training_stats=training_stats,
)

50it [00:02, 18.11it/s]


In [11]:
test_mi.tensor_data.size()  # (n_trials, n_channels, n_times)

torch.Size([50, 8, 2250])

In [12]:
# A single batch holding the whole training split, in dataset order.
train_mi_loader = DataLoader(dataset=train_mi, batch_size=len(train_mi), shuffle=False)

In [13]:
# 2. Materialise the whole training split as numpy arrays.
data_batch, labels = next(iter(train_mi_loader))
data, labels = data_batch.numpy(), labels.numpy()  # data: [n_epochs, n_channels, n_times]

In [14]:
tuple(array.shape for array in (data, labels))

((2400, 8, 2250), (2400,))

In [15]:
# min / max / mean / std of the normalized training data.
tuple(stat_fn(data) for stat_fn in (np.min, np.max, np.mean, np.std))

(np.float32(-4.385366),
 np.float32(6.321079),
 np.float32(-2.476852e-06),
 np.float32(0.9999999))

In [16]:
# 3. Channel metadata for MNE.
# MNE stores EEG in volts. The samples here are already z-scored (the summary above shows
# mean ~0 / std ~1), so SCALING_FACTOR only relabels one normalized unit as 1 µV —
# presumably so MNE's default EEG plot scalings give a sensible range. The resulting values
# are not physical voltages.
SCALING_FACTOR = 1e-6
sfreq = 250.0  # sampling rate in Hz — NOTE(review): assumed for this dataset; confirm
ch_names = ["FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8"]
# Fail early with a clear message if the channel list and the data disagree.
assert len(ch_names) == data.shape[1], (
    f"{len(ch_names)} channel names for {data.shape[1]} data channels"
)
ch_types = ["eeg"] * len(ch_names)
info = mne.create_info(ch_names, sfreq, ch_types)

In [19]:
# 4. Wrap the trials in an MNE EpochsArray with the class labels as events.
#    Each epoch is treated as starting at t=0; adjust tmin if trials include a pre-stimulus window.
# events rows are (sample index, previous value, event code); the event code is the class label.
# Passing events/event_id to the constructor lets MNE validate them (dtype, one row per epoch,
# codes covered by event_id) instead of overwriting the attributes after construction.
event_id = {str(label): int(label) for label in np.unique(labels)}
events = np.column_stack(
    (np.arange(len(labels)), np.zeros(len(labels), dtype=int), labels)
)
epochs = mne.EpochsArray(
    data * SCALING_FACTOR, info, events=events, event_id=event_id, tmin=0.0
)

Not setting metadata
2400 matching events found
No baseline correction applied
0 projection items activated


In [None]:
# min / max / mean / std of the data as MNE stores it (volts, i.e. after SCALING_FACTOR),
# mirroring the normalized-data summary above. get_data() returns a copy by default, so
# fetch the array once rather than once per statistic.
epochs_data = epochs.get_data()
epochs_data.min(), epochs_data.max(), epochs_data.mean(), epochs_data.std()

np.float64(6.321078672044678e-06)

In [21]:
# Epochs 10-19, all 8 channels, in volts.
epochs.get_data(picks=range(8), item=range(10, 20), units="V").shape

(10, 8, 2250)

In [22]:
# First epoch in µV, which undoes SCALING_FACTOR — compare with train_mi[0] below.
epochs.get_data(picks=range(8), item=0, units="uV")

array([[[-1.57248587, -1.57540796, -1.56028921, ..., -1.58432556,
         -1.57036754, -1.57353986],
        [-1.68972622, -1.68376516, -1.68479846, ..., -1.68882627,
         -1.69588407, -1.70191083],
        [-0.75545694, -0.71620428, -0.71392901, ..., -0.70742902,
         -0.73556225, -0.770356  ],
        ...,
        [-1.17179582, -1.17173954, -1.17275226, ..., -1.17680577,
         -1.17818456, -1.1783751 ],
        [-1.77858954, -1.77287404, -1.778968  , ..., -1.78216669,
         -1.7917281 , -1.79514757],
        [-1.61137655, -1.60652678, -1.61100343, ..., -1.62324295,
         -1.63081609, -1.63395055]]], shape=(1, 8, 2250))

In [23]:
epochs.info["ch_names"]

['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']

In [24]:
# Sanity check: the first dataset sample must survive the round trip into MNE
# (dividing by SCALING_FACTOR recovers the original values, up to float32 precision).
first_sample = train_mi[0]
np.testing.assert_allclose(
    epochs.get_data(item=0)[0] / SCALING_FACTOR,
    first_sample[0].numpy(),
    rtol=1e-5,
    atol=1e-6,
)
first_sample

(tensor([[-1.5725, -1.5754, -1.5603,  ..., -1.5843, -1.5704, -1.5735],
         [-1.6897, -1.6838, -1.6848,  ..., -1.6888, -1.6959, -1.7019],
         [-0.7555, -0.7162, -0.7139,  ..., -0.7074, -0.7356, -0.7704],
         ...,
         [-1.1718, -1.1717, -1.1728,  ..., -1.1768, -1.1782, -1.1784],
         [-1.7786, -1.7729, -1.7790,  ..., -1.7822, -1.7917, -1.7951],
         [-1.6114, -1.6065, -1.6110,  ..., -1.6232, -1.6308, -1.6340]]),
 tensor(0))

### Preprocessing

In [None]:
# NOTE(review): no cell in this notebook defines `raw_trial`, so a fresh kernel raised
# NameError here. It is rebuilt from one epoch below — confirm TRIAL_INDEX is the intended
# trial. Building it inside this cell also makes the cell idempotent: re-running it no
# longer filters the same object a second time.
TRIAL_INDEX = 0
L_FREQ, H_FREQ = 5.0, 30.0  # band-pass edges (Hz)
LINE_FREQ = 50.0  # mains frequency (Hz)

raw_trial = mne.io.RawArray(
    epochs.get_data(item=TRIAL_INDEX)[0], epochs.info.copy(), verbose=False
)

# Set a standard montage. The channel names are upper-case ("FZ", "CZ", "PZ", "OZ") while
# standard_1020 spells them "Fz", "Cz", ..., so match case-insensitively. With the default
# match_case=True plus on_missing="ignore", those channels silently got no position.
montage = mne.channels.make_standard_montage("standard_1020")
raw_trial.set_montage(montage, match_case=False, on_missing="warn")

# Filter. The notch is redundant while H_FREQ < LINE_FREQ, but keeps mains noise
# suppressed if the band is widened.
raw_trial.filter(l_freq=L_FREQ, h_freq=H_FREQ, fir_design="firwin", verbose=False)
raw_trial.notch_filter(freqs=LINE_FREQ, fir_design="firwin", verbose=False)

# Re-reference to the common average (added as a projector, then applied).
raw_trial.set_eeg_reference(ref_channels="average", projection=True, verbose=False)
raw_trial.apply_proj(verbose=False)

In [None]:
# Interactive browser for the preprocessed trial. The trailing ';' suppresses the returned
# figure's repr, which can otherwise render the browser a second time.
raw_trial.plot();

In [None]:
# Browse the trial one second at a time (starting at t=0), showing all channels at once.
# ';' replaces the print() trick for suppressing the returned figure's repr.
raw_trial.plot(duration=1, n_channels=len(raw_trial.ch_names));

In [None]:
# Power spectral density up to 50 Hz, leaving out the channels listed in `exclude`.
# NOTE(review): the reason for dropping the midline channels PZ/CZ/FZ is not recorded — document it.
raw_trial.compute_psd(fmax=50, exclude=["PZ", "CZ", "FZ"]).plot();

In [None]:
# Preprocessed samples of the trial as a numpy array (channels x samples). Show only the
# shape instead of dumping the full array into the notebook.
trial_data = raw_trial.get_data()
trial_data.shape

In [None]:
# Sensor locations: a 2D top-down view followed by a 3D view.
# block=True is omitted: it halts a script until the window closes and is not needed in a
# notebook. ';' replaces the print() trick for suppressing the returned figure's repr.
raw_trial.plot_sensors(kind="topomap", ch_type="eeg", show_names=True);
raw_trial.plot_sensors(kind="3d", ch_type="eeg", show_names=True);