In [17]:
import sys

assert sys.version_info >= (3, 10)

In [18]:
from packaging.version import Version
import mne
import sklearn
import torch

assert Version(mne.__version__) >= Version("1.10.1")
assert Version(sklearn.__version__) >= Version("1.4.0")
assert Version(torch.__version__) >= Version("2.1.0")

In [19]:
from pathlib import Path
import importlib

DATASET_DIR = Path("../../0-raw-data/motor-imaginary")
EXTRACT_DIR = Path(DATASET_DIR / "data")

def download_and_extract_motor_imaginery_data():
    target_dir = DATASET_DIR.resolve()
    if str(target_dir) not in sys.path:
        sys.path.append(str(target_dir))

    import data_fetcher
    importlib.reload(data_fetcher)

    data_fetcher.download_and_extract_data(delete_zip=False)

download_and_extract_motor_imaginery_data()

[Fetcher] Starting data preparation...
[Download] Skip: BCICIV_2a_gdf.zip already exists
[Extract] Skip: already extracted at /home/kanathipp/Stuffs/Works/final-project-federated-learning/0-raw-data/motor-imaginary/data
[Fetcher] Completed.


In [20]:
import mne
def read_data(path):
    raw = mne.io.read_raw_gdf(path, 
                              preload=True,
                              eog=['EOG-left', 'EOG-central', 'EOG-right']
                             )
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    raw.set_eeg_reference()
    events = mne.events_from_annotations(raw)
    epochs = mne.Epochs(raw, events[0], event_id=[5,6,7,8],on_missing ='warn')
    features = epochs.get_data()
    labels = epochs.events[:,-1]
    return features,labels

In [21]:
%%capture
features,labels,groups=[],[],[]
for i in range(1,10):
  feature,label=read_data(Path(EXTRACT_DIR/ f'A0{i}T.gdf'))
  features.append(feature)
  labels.append(label)
  groups.append([i]*len(label))

In [22]:
import numpy as np

features = np.concatenate(features)
labels = np.concatenate(labels)
groups = np.concatenate(groups)
features = np.moveaxis(features, 1, 2)


features.shape,labels.shape,groups.shape

((3808, 176, 22), (3808,), (3808,))

In [23]:
np.isnan(features).sum()

np.int64(0)

In [24]:
unique, counts = np.unique(labels, return_counts=True)
unique, counts

(array([5, 6, 7, 8]), array([ 136, 2376,  648,  648]))

In [25]:
unique, counts = np.unique(groups, return_counts=True)
unique, counts

(array([1, 2, 3, 4, 5, 6, 7, 8, 9]),
 array([440, 440, 440, 288, 440, 440, 440, 440, 440]))

In [26]:
from sklearn.preprocessing import StandardScaler
from sklearn.base import TransformerMixin, BaseEstimator

#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
class StandardScaler3D(BaseEstimator,TransformerMixin):
    #batch, sequence, channels
    def __init__(self):
        self.scaler = StandardScaler()

    def fit(self,X,y=None):
        self.scaler.fit(X.reshape(-1, X.shape[2]))
        return self

    def transform(self,X):
        return self.scaler.transform(X.reshape( -1,X.shape[2])).reshape(X.shape)

In [None]:
from sklearn.model_selection import GroupShuffleSplit

gss_test = GroupShuffleSplit(n_splits=1,test_size=0.2, random_state=42)
train_val_index, test_index = next(gss_test.split(features, labels, groups))

gss_val = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
relative_train_index, relative_val_index = next(gss_val.split(features[train_val_index], labels[train_val_index], groups[train_val_index]))
train_index = train_val_index[relative_train_index]
val_index = train_val_index[relative_val_index]

train_features, train_labels = features[train_index], labels[train_index]
val_features,  val_labels  = features[val_index],  labels[val_index]
test_features,  test_labels  = features[test_index],  labels[test_index]

scaler = StandardScaler3D()
train_features = scaler.fit_transform(train_features)
val_features  = scaler.transform(val_features)
test_features  = scaler.transform(test_features)

train_features = np.moveaxis(train_features, 1, 2)
val_features = np.moveaxis(val_features, 1, 2)
test_features  = np.moveaxis(test_features, 1, 2)

จำนวน group:
  train_groups: 5
  val_groups  : 1
  test_groups : 3
id group:
  train_groups: [8 2 6 1 9]
  val_groups  : [3]
  test_groups : [5 4 7]
จำนวน sample:
  train: 2200
  val  : 440
  test : 1168
overlap(train,val): 0
overlap(train,test): 0
overlap(val,test): 0
final shapes:
  train_features: (2200, 22, 176)
  val_features  : (440, 22, 176)
  test_features : (1168, 22, 176)


In [34]:
import torch

train_features = torch.Tensor(train_features)
val_features = torch.Tensor(val_features)
test_features = torch.Tensor(test_features)

train_labels = torch.Tensor(train_labels)
val_labels = torch.Tensor(val_labels)
test_labels = torch.Tensor(test_labels)

len(val_features), len(val_labels), len(test_features), len(test_labels)

(440, 440, 1168, 1168)

In [29]:
train_features.shape

torch.Size([2048, 22, 176])

In [30]:
val_features.shape

torch.Size([880, 22, 176])

In [31]:
test_features.shape

torch.Size([880, 22, 176])

In [32]:
import numpy as np

def remap_np(y: torch.Tensor) -> torch.Tensor:
    y_np = y.view(-1).cpu().numpy()
    uniq = np.unique(y_np)                       
    lut = {u: i for i, u in enumerate(uniq)}    
    y_new = np.vectorize(lut.get)(y_np)
    return torch.as_tensor(y_new, dtype=torch.long, device=y.device)

train_labels = remap_np(train_labels)
val_labels = remap_np(val_labels)
test_labels   = remap_np(test_labels)