In [1]:
from tools.data import DreemDatasets
from preprocessing.features import ExtractFeatures
import numpy as np
import matplotlib.pyplot as plt
from models.cnn import CNN
import torch
import torch.utils.data
import torch.optim as optim
from tools.trainer import CNNTrainer as Trainer

In [2]:
# The seven EEG derivations to load: eeg_1 ... eeg_7.
use_datasets = [f'eeg_{i}' for i in range(1, 8)]
batch_size = 32

# Seed numpy and torch before any model is built so that weight initialisation
# and DataLoader shuffling are reproducible across Restart & Run All.
# (The train/val split itself is seeded separately via DreemDatasets(seed=0).)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
train_set, val_set = DreemDatasets(
    'dataset/train.h5',
    'dataset/train_y.csv',
    keep_datasets=use_datasets,
    split_train_val=0.8,
    seed=0,
    size=5000,
).get()

# Load every split into memory first (this can take a while) and only then
# close the h5 files, preserving the original load-all-then-close-all order.
splits_to_load = (
    (train_set, "dataset/eegs_bands/train"),
    (val_set, "dataset/eegs_bands/val"),
)
for subset, bands_dir in splits_to_load:
    subset.load_data(bands_dir)

for subset in (train_set, val_set):
    subset.close()

Loading data in memory...
5412 in 7 datasets to load
Loading dataset eeg_1 ...
Loading dataset eeg_2 ...
Loading dataset eeg_3 ...
Loading dataset eeg_4 ...
Loading dataset eeg_5 ...
Loading dataset eeg_6 ...
Loading dataset eeg_7 ...
Done.
Loading data in memory...
1353 in 7 datasets to load
Loading dataset eeg_1 ...
Loading dataset eeg_2 ...
Loading dataset eeg_3 ...
Loading dataset eeg_4 ...
Loading dataset eeg_5 ...
Loading dataset eeg_6 ...
Loading dataset eeg_7 ...
Done.


In [4]:
def dataset_to_arrays(dataset):
    """Materialise a loaded dataset split as (X, y) numpy arrays.

    ``dataset[:]`` yields a 3-tuple; its middle element is not used here.
    The signals are transposed with axes (1, 0, 2, 3) and flattened to
    ``(n, 7*4, 1500)``; the labels are reshaped to a column ``(n, 1)``.
    """
    signals, _, labels = dataset[:]
    # NOTE(review): assumes signals arrive as (7 eeg datasets, n, 4, 1500), so
    # the transpose puts the sample axis first — confirm against DreemDatasets.
    signals = signals.transpose(1, 0, 2, 3).reshape(-1, 7 * 4, 1500)
    return signals, labels.reshape(-1, 1)


X, y = dataset_to_arrays(train_set)
# Previously this read from train_set as well, silently duplicating training data.
X_val, y_val = dataset_to_arrays(val_set)

assert X.shape[0] == y.shape[0], "train signals/labels length mismatch"
assert X_val.shape[0] == y_val.shape[0], "val signals/labels length mismatch"

y.shape

(5412, 1)


In [5]:
# Mini-batch loaders over the in-memory torch datasets;
# only the training batches are shuffled.
make_loader = torch.utils.data.DataLoader
train_loader = make_loader(train_set.torch_dataset(), batch_size=batch_size, shuffle=True)
val_loader = make_loader(val_set.torch_dataset(), batch_size=batch_size)

In [6]:
# CNN hyper-parameters: one entry per convolutional group.
channels = [128, 64, 32]
kernel_sizes = [100, 50, 50]
kernel_pooling = [5, 2, 2]
# Derived rather than hard-coded so it cannot drift from the per-group lists.
number_groups = len(channels)

# Fall back to CPU instead of crashing in `model.cuda()` when no GPU is present.
use_cuda = torch.cuda.is_available()


def transform(data, _):
    """Flatten a batch to ``(batch, 7*4, 1500)`` before it reaches the model.

    Parameters
    ----------
    data : torch.Tensor
        Batch whose per-sample element count is 7*4*1500
        (presumably 7 EEG datasets x 4 bands x 1500 samples — TODO confirm).
    _ :
        Second item passed in by the trainer; ignored.

    Returns
    -------
    tuple
        ``(reshaped data, None)``.
    """
    # `reshape` instead of `view`: `view` raises on non-contiguous tensors.
    return data.reshape(-1, 7 * 4, 1500), None


In [7]:
# Build the CNN: 7*4 input channels of 1500 samples each, 5 outputs.
cnn_kwargs = dict(
    in_features=1500,
    out_features=5,
    in_channels=7 * 4,
    number_groups=number_groups,
    size_groups=1,
    hidden_channels=channels,
    kernel_sizes=kernel_sizes,
    kernel_pooling=kernel_pooling,
)
model = CNN(**cnn_kwargs)
if use_cuda:
    model.cuda()

In [8]:
# Adam at learning rate 1e-3 over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

trainer = Trainer(
    train_loader,
    val_loader,
    optimizer,
    model_50hz=model,
    log_every=10,
    save_folder='builds',
    transform=transform,
)

# NOTE(review): the captured run below was interrupted manually during epoch 18.
trainer.train(n_epochs=50)

  input = module(input)
Train - Epoch 1: : 170it [00:21,  7.76it/s, Loss: 1.6202894449234009]                           
Val - Epoch 1: : 43it [00:03, 14.10it/s, Loss: 1.6980857849121094]                            



Validation set: Average loss: 1.6669, Accuracy: 343/1353 (25%)



Train - Epoch 2:   1%|          | 1/169.125 [00:00<00:27,  6.08it/s, Loss: 1.6395312547683716]


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 2: : 170it [00:19,  8.60it/s, Loss: 1.6123155355453491]                           
Val - Epoch 2: : 43it [00:02, 14.98it/s, Loss: 1.663747787475586]                             
Train - Epoch 3:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6746, Accuracy: 320/1353 (23%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 3: : 170it [00:19, 13.58it/s, Loss: 1.5938670635223389]                           
Val - Epoch 3: : 43it [00:04,  9.82it/s, Loss: 1.6804420948028564]                            
Train - Epoch 4:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6566, Accuracy: 358/1353 (26%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 4: : 170it [00:23,  3.92it/s, Loss: 1.536651611328125]                            
Val - Epoch 4: : 43it [00:02, 14.62it/s, Loss: 1.686047911643982]                            
Train - Epoch 5:   0%|          | 0/169.125 [00:00<?, ?it/s, Loss: 1.5940333604812622]


Validation set: Average loss: 1.6471, Accuracy: 366/1353 (27%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 5: : 170it [00:19,  8.62it/s, Loss: 1.4561983346939087]                           
Val - Epoch 5: : 43it [00:05,  8.32it/s, Loss: 1.7354542016983032]                            
Train - Epoch 6:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6450, Accuracy: 374/1353 (27%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 6: : 170it [00:22,  6.02it/s, Loss: 1.2503067255020142]                           
Val - Epoch 6: : 43it [00:02, 15.08it/s, Loss: 1.6666842699050903]                            
  0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6401, Accuracy: 381/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 7: : 170it [00:19,  8.59it/s, Loss: 1.450714111328125]                            
Val - Epoch 7: : 43it [00:04,  8.99it/s, Loss: 1.7180752754211426]                            
Train - Epoch 8:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6343, Accuracy: 387/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 8: : 170it [00:23,  7.26it/s, Loss: 1.5210355520248413]                           
Val - Epoch 8: : 43it [00:02, 14.72it/s, Loss: 1.6998323202133179]                            
Train - Epoch 9:   0%|          | 0/169.125 [00:00<?, ?it/s, Loss: 1.4456874132156372]


Validation set: Average loss: 1.6305, Accuracy: 390/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 9: : 170it [00:19,  8.62it/s, Loss: 1.3735445737838745]                           
Val - Epoch 9: : 43it [00:04,  9.50it/s, Loss: 1.706855297088623]                             
Train - Epoch 10:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6293, Accuracy: 398/1353 (29%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 10: : 170it [00:23,  4.10it/s, Loss: 1.583541989326477]                            
Val - Epoch 10: : 43it [00:03, 12.49it/s, Loss: 1.674306035041809]                             
  0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6363, Accuracy: 384/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 11: : 170it [00:19,  8.60it/s, Loss: 1.5213453769683838]                           
Val - Epoch 11: : 43it [00:04,  8.81it/s, Loss: 1.6414953470230103]                            
Train - Epoch 12:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6349, Accuracy: 388/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 12: : 170it [00:23,  5.01it/s, Loss: 1.4963903427124023]                           
Val - Epoch 12: : 43it [00:02, 14.74it/s, Loss: 1.6485189199447632]                            
Train - Epoch 13:   1%|          | 1/169.125 [00:00<00:19,  8.42it/s, Loss: 1.3399925231933594]


Validation set: Average loss: 1.6289, Accuracy: 389/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 13: : 170it [00:21,  8.02it/s, Loss: 1.441310167312622]                            
Val - Epoch 13: : 43it [00:04, 10.22it/s, Loss: 1.6380342245101929]                            
Train - Epoch 14:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6268, Accuracy: 397/1353 (29%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 14: : 170it [00:23,  4.96it/s, Loss: 1.380204200744629]                            
Val - Epoch 14: : 43it [00:03, 11.50it/s, Loss: 1.7616692781448364]                            
Train - Epoch 15:   1%|          | 1/169.125 [00:00<00:19,  8.74it/s, Loss: 1.4672635793685913]


Validation set: Average loss: 1.6336, Accuracy: 397/1353 (29%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 15: : 170it [00:19, 13.54it/s, Loss: 1.4747726917266846]                           
Val - Epoch 15: : 43it [00:04,  9.80it/s, Loss: 1.7310110330581665]                            
Train - Epoch 16:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6308, Accuracy: 397/1353 (29%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 16: : 170it [00:23,  7.19it/s, Loss: 1.4200325012207031]                           
Val - Epoch 16: : 43it [00:04, 10.63it/s, Loss: 1.6704049110412598]                            
Train - Epoch 17:   0%|          | 0/169.125 [00:00<?, ?it/s, Loss: 1.3007540702819824]


Validation set: Average loss: 1.6257, Accuracy: 401/1353 (29%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 17: : 170it [00:20,  8.31it/s, Loss: 1.3247218132019043]                           
Val - Epoch 17: : 43it [00:04,  9.14it/s, Loss: 1.6989020109176636]                            
Train - Epoch 18:   0%|          | 0/169.125 [00:00<?, ?it/s]


Validation set: Average loss: 1.6337, Accuracy: 389/1353 (28%)


Saved models in builds/2018-12-03 13:52:17.


Train - Epoch 18:  28%|██▊       | 48/169.125 [00:07<00:23,  5.14it/s, Loss: 1.620147466659546] 


KeyboardInterrupt: 