In [19]:
import os

import numpy as np

from src.utils import read_eeg_signal_from_file

labels = []
data = []

data_path = 'data/data_preprocessed_python'
N_EEG_CHANNELS = 32  # the first 32 channels of each recording are EEG

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# participant axis (and every downstream index such as X[0]) is reproducible.
for filename in sorted(os.listdir(data_path)):
    file_path = os.path.join(data_path, filename)
    trial = read_eeg_signal_from_file(file_path)
    labels.append(trial['labels'])
    data.append(trial['data'][:, :N_EEG_CHANNELS, :])  # leave only eeg channels

labels = np.array(labels)
data = np.array(data)

In [20]:
# The loading cell keeps only the first 32 (EEG) channels on axis 2, so check that
# invariant before building features.
# NOTE(review): the other axes are presumably (participants, trials, _, samples).
assert data.ndim == 4 and data.shape[2] == 32, data.shape
data.shape

(32, 40, 32, 8064)

In [21]:
import pickle

# Use a context manager so the file handle is closed deterministically.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('data/electrode_placement.dat', 'rb') as placement_file:
    electrode_placement = pickle.load(placement_file)

In [23]:
from tqdm import tqdm
from src.fourier import images_from_eeg

SAMPLE_RATE = 128  # samples per second
RESOLUTION = 28    # side length, in pixels, of each generated image
N_EEG_CHANNELS = 32
# Time window, in seconds into each trial recording, that is turned into images.
# NOTE(review): this was previously commented as "crop baseline", but the slice keeps
# only seconds 18-24 (a 6 s window), not everything after a baseline; confirm intent.
CROP_START_S, CROP_END_S = 18, 24

freq_bands = {'Theta': [4, 8],
              'Alpha': [8, 16],
              'Beta+Gamma': [16, 45]}
psd_config = dict(
    selected_channels=range(N_EEG_CHANNELS),
    freq_bands=freq_bands,
    # presumably in samples, i.e. 1 s non-overlapping windows; TODO confirm in src.fourier
    window_size=SAMPLE_RATE,
    step_size=SAMPLE_RATE,
    sample_rate=SAMPLE_RATE,
)

# Loop-invariant: the same time slice is applied to every trial.
crop = slice(SAMPLE_RATE * CROP_START_S, SAMPLE_RATE * CROP_END_S)

X = []
for participant_data in tqdm(data):
    participant_images = []
    for readings in participant_data:
        # readings is (channels, samples); transpose to (samples, channels) so the
        # slice selects a time window across all channels.
        subset = readings.T[crop]
        participant_images.append(
            images_from_eeg(subset, **psd_config, loc_dict=electrode_placement, resolution=RESOLUTION)
        )
    X.append(participant_images)
X = np.array(X, dtype=np.float32)

100%|██████████| 32/32 [00:42<00:00,  1.34s/it]


In [24]:
# One image stack per (participant, trial), each image RESOLUTION x RESOLUTION.
assert X.shape[:2] == data.shape[:2], (X.shape, data.shape)
assert X.shape[-2:] == (RESOLUTION, RESOLUTION), X.shape
X.shape

(32, 40, 6, 3, 28, 28)

In [25]:
# Targets: the first N_TARGETS label columns of every (participant, trial) that has
# images in X, rescaled with (x - 5) / 4, which maps a 1-9 rating scale onto [-1, 1].
# NOTE(review): columns 0 and 1 are presumably valence and arousal; confirm.
N_TARGETS = 2
LABEL_CENTER, LABEL_HALF_RANGE = 5, 4

# Vectorized equivalent of looping over labels[i][j][:2] for i, j in X's leading dims.
Y = labels[:X.shape[0], :X.shape[1], :N_TARGETS].astype(np.float32)
Y = (Y - LABEL_CENTER) / LABEL_HALF_RANGE

In [26]:
# The (x - 5) / 4 rescaling should place every target in [-1, 1]; fail loudly if not.
y_range = (Y.min(), Y.max())
assert -1.0 <= y_range[0] <= y_range[1] <= 1.0, y_range
y_range

(-1.0, 1.0)

In [27]:
# Persist the image tensor and targets. Context managers guarantee each file is
# flushed and closed, rather than relying on garbage collection of open() handles.
with open('data/X.dat', 'wb') as x_file:
    pickle.dump(X, x_file)
with open('data/Y.dat', 'wb') as y_file:
    pickle.dump(Y, y_file)

In [28]:
import plotly.express as px
import plotly.io as pio

pio.templates.default = "plotly_white"

# Which image to show: indices into X's four leading axes.
# NOTE(review): these axes are presumably (participant, trial, window, band); confirm
# against src.fourier.images_from_eeg.
PARTICIPANT, TRIAL, WINDOW, BAND = 0, 7, 3, 0

img = X[PARTICIPANT, TRIAL, WINDOW, BAND]
fig = px.imshow(
    img,
    zmin=img.min(),
    zmax=img.max(),
    title=f"participant {PARTICIPANT}, trial {TRIAL}, window {WINDOW}, band {BAND}",
)
fig.update_layout(
    width=400,
    height=400,
    # top margin raised from 25 so the title is not clipped
    margin=dict(l=25, r=50, b=25, t=40, pad=4),
)
fig.show()