# Visualisation des features

In [1]:
from tools.data import DreemDatasets
from preprocessing.features import ExtractFeatures
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# EEG datasets to keep from the h5 file: eeg_1 … eeg_6.
use_datasets = [f'eeg_{idx}' for idx in range(1, 7)]
# Features to extract; presumably stored in this order along the last axis of the feature array — confirm.
features = 'min max energy mmd esis'.split()

## Créer features

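Passer directement à la section suivante si les features ont déjà été créées et sauvegardées (`dataset/train_features.npy` et `dataset/train_targets.npy`).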

On définit les transformations pour les différents signaux EEG. On choisit d'extraire toutes les bandes de fréquence (d'où le joker `'*'`).

In [3]:
# One feature extractor per kept EEG dataset. The dict is built from
# `use_datasets` (also passed as `keep_datasets` to DreemDatasets) so the
# transform keys cannot drift out of sync with the datasets actually kept.
# bands='*' requests every frequency band.
dataset_transforms = {
    name: ExtractFeatures(features, bands='*')
    for name in use_datasets
}

Récupère tous les datasets, puis charge le jeu d'entraînement en mémoire (le jeu de validation n'est pas chargé).

In [4]:
# Build the train/validation subsets from the h5 file and attach the
# per-dataset feature transforms (`dataset_transforms`).
# NOTE(review): split_train_val=0.8 is presumably the train fraction, and the
# meaning of size=5000 is defined by DreemDatasets — confirm.
train_set, val_set = DreemDatasets(
    'dataset/train.h5',
    'dataset/train_y.csv',
    keep_datasets=use_datasets,
    split_train_val=0.8,
    seed=0,
    size=5000,
    transforms=dataset_transforms,
).get()

# Load the training subset into memory; this can take a while.
train_set.load_data()

# Close the h5 files.
train_set.close()
val_set.close()

In [16]:
# train_set[:] returns a 3-tuple: keep the features (first element) and the
# targets (last element). The middle element is bound to `_unused` rather than
# `_`: in IPython, `_` is the last-output cache, and rebinding it stops IPython
# from updating `_`/`__`/`___` for the rest of the session.
data, _unused, targets = train_set[:]
del _unused

# Later cells mask the sample axis (axis 2) of `data` with `targets`,
# so both must describe the same samples.
assert data.shape[2] == len(targets), (data.shape, len(targets))

print(data.shape)

torch.Size([6, 4, 5412, 5])


In [17]:
# Cache the extracted features and targets on disk so the visualisation
# section can start from these files without recomputing them.
np.save(file="dataset/train_features.npy", arr=data)
np.save(file="dataset/train_targets.npy", arr=targets)

## Visualiser les features

In [18]:
# Reload the cached feature and target arrays written by the previous section.
data = np.load(file="dataset/train_features.npy")
targets = np.load(file="dataset/train_targets.npy")

# Layout: (number of eeg, number of bands, number of samples, number of features)
print(data.shape)

(6, 4, 5412, 5)


Calcul des moyennes et écarts types de chaque feature, pour chaque classe.

In [39]:
# Class names, indexed by their integer code in `targets`.
# NOTE(review): assumes targets are integer codes 0..4 in exactly this order
# — confirm against the label definition of dataset/train_y.csv.
target_labels = ['awake', 'rem', 'stage 1', 'stage 2', 'stage 3']

# Boolean mask over samples for each class.
keys = {name: targets == k for k, name in enumerate(target_labels)}
# Per-class feature arrays: the mask selects along the sample axis (axis 2).
data_labeled = {name: data[:, :, keys[name]] for name in target_labels}

# Samples per class. An empty class would make the per-class mean/std NaN.
print({name: int(mask.sum()) for name, mask in keys.items()})
# Spot check: 'esis' feature (index 4) of the first 'awake' sample, for every eeg x band.
print(data_labeled['awake'][:, :, 0, 4])

[[3.91906439e+11 1.77342339e+09 3.85493598e+08 1.67295816e+08]
 [2.23530653e+09 1.85387195e+08 7.56094186e+07 5.33588293e+07]
 [4.00329381e+11 1.63203297e+09 3.28742870e+08 1.38065804e+08]
 [3.95744810e+11 1.59899798e+09 3.36073319e+08 1.38781569e+08]
 [3.95971831e+11 1.59601065e+09 3.39927189e+08 1.41394434e+08]
 [3.05201776e+09 2.46106883e+08 8.05863961e+07 5.63361816e+07]]


In [22]:
# Frequency band names for axis 1 of `data`.
# NOTE(review): presumably the order in which ExtractFeatures(bands='*') emits bands — confirm.
bands = 'delta theta alpha beta'.split()

In [33]:
# Per-class statistics over the sample axis (axis 2). Each value has shape
# (number of eeg, number of bands, number of features).
means = {label: data_labeled[label].mean(axis=2) for label in target_labels}
std = {label: data_labeled[label].std(axis=2) for label in target_labels}

# Show the 'awake' statistics, one item per line.
print('std', std['awake'], 'mean', means['awake'], sep='\n')

std
[[[2.34175350e+05 1.70818732e+05 2.28951389e+11 6.28677043e+05
   4.57902779e+13]
  [5.40892034e+04 5.58943029e+04 7.94389757e+09 1.90833122e+05
   4.76633854e+12]
  [3.93917261e+04 3.96964625e+04 2.91554612e+09 1.19773538e+05
   3.06132342e+12]
  [3.29303131e+04 3.32432339e+04 1.66589573e+09 8.13442291e+04
   2.91531752e+12]]

 [[2.21008476e+05 2.40038345e+05 2.54737784e+11 8.49229654e+05
   5.09475568e+13]
  [6.67890100e+04 6.88002126e+04 1.01442818e+10 2.58502041e+05
   6.08656910e+12]
  [3.64326321e+04 3.81870268e+04 2.37871556e+09 1.24855230e+05
   2.49765134e+12]
  [2.40996498e+04 2.53814095e+04 7.65672976e+08 7.45609544e+04
   1.33992771e+12]]

 [[2.58726984e+05 1.86547300e+05 3.18823252e+11 8.07818100e+05
   6.37646504e+13]
  [5.89851174e+04 5.76132127e+04 6.98181420e+09 2.24768809e+05
   4.18908852e+12]
  [3.66911974e+04 3.73444792e+04 2.41359633e+09 1.28520656e+05
   2.53427614e+12]
  [3.08992698e+04 2.78920237e+04 1.11750859e+09 9.07502944e+04
   1.95564004e+12]]

 [[2.0