In [1]:
import h5py
import json
import requests 
import os
import shutil

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (15,5)
plt.rcParams['axes.grid'] = 'on'

from pathlib import Path
from scipy import signal
from scipy.signal import butter, lfilter

from helpers import *
from dataset import DreemDataset

from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold

import warnings

# Dreem Case study



#### Create dataset object

Using all processings defined in the previous notebook, we can create a dataset object that will load raw records, process and cut them into 30sec data points. The init is a bit long since all the processing is done here.

In [50]:
# Paths to data
PATH_TO_DATA = Path("data/h5/")
PATH_TO_PROCESSED_DATA = Path("data/h5_processed/")
PATH_TO_HYPNOGRAM = Path("data/hypnograms/")

# List of expected records
records_list =  [
"8e0bf011-1db6-46fa-a3cd-496e60c0de6f",
"d8a9babd-8454-42e9-9286-eb66c996d3e6",
"c5080eac-a388-4b1f-818f-a7f902fe4c06",
"62492470-d4d5-4dee-8030-80cca44fb002",
"87748119-6fff-45d2-9219-888532fb7efd",
"9bd9224a-bbdf-46c2-a494-3bbfcfd7e776",
"8f3dc41c-df99-4a5f-82cf-6b9f6e265b92"
]

# Instantiating the data set will load and process records into samples stored in PATH_TO_PROCESSED_DATA
dataset = DreemDataset(
    records_list, PATH_TO_DATA, PATH_TO_HYPNOGRAM, Path(PATH_TO_PROCESSED_DATA))
print(f"Number of samples in dataset = {len(dataset)}")
print(f"Sample shape is {dataset[0][0].shape}")

Number of samples in dataset = 5467
Sample shape is (7, 1500)


For now a sample is a (7, 1500) matrix.

#### Define learning framework

We can keep one record as validation and the remaining in the training process. The latter will be splitted into a train and test set to monitor networks learning process.


We are in a context of multiclass classification (C=5). Our choice of metric must include the class imbalance information. 
- We can analyze the performance of a model precisely by looking at traditional metrics (confusion_matrix  -> precision, recall, roc_curve, balanced accuracy, etc.) for each class value.
- We can provide a global metric with a weighted average of per-class metrics (balanced accuracy or f1 score with average = micro to take into account class imbalances)

#### Define learning model


In many possible ways of classifying EEG data, we will go with the implementation from [here](https://arxiv.org/pdf/1707.03321.pdf) (at least a part of it). The paper proposes a feature extractor network that is used to make predictions without temporal context. The features obtained with that trained network are then used to perform temporal sleep stage classification (by integrating features from adjacent time segments in the input) (this part is not done).

In [56]:
T = 1500
# C = 7

class BasicModel(nn.Module):

    def __init__(self, C=7):
        super().__init__()   
        
        # First conv layer
        self.conv1 = nn.Conv2d(
            in_channels = C, out_channels = C, 
            kernel_size = (C, 1), stride = (1, 1), 
            padding = 'same')  # VALID in paper, mistake ?
        # Activation ??
        
        self.conv2 = nn.Conv2d(
            in_channels = 1, 
            out_channels = 8, 
            kernel_size = (1, 25), 
            stride = (1, 1), 
            padding = 'same')
        self.relu2 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size = (1, 6), stride = (1, 6))
        
        self.conv3 = nn.Conv2d(
            in_channels = 8, 
            out_channels = 8, 
            kernel_size = (1, 25), 
            stride = (1, 1), 
            padding = 'same')
        self.relu3 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size = (1, 6), stride = (1, 6))
        
        self.dropout = nn.Dropout(p=0.5)
        self.layer_out = nn.Linear(C*(T//36)*8, 5) 
        # self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        # Start with (Batch_size, C, T)
        x = x[:, :, :, None] # (Batch_size, C, T, 1)
        # print("Expand", x.requires_grad)
        x = self.conv1(x) # (Batch_size, C, T, 1)
        # print("Conv1", x.requires_grad)
        x = torch.permute(x, (0, 3, 1, 2)) # (Batch_size, 1, C, T)
        # print("Permute", x.requires_grad)
        x = self.conv2(x) # (Batch_size, 8, C, T)
        x = self.relu2(x)
        x = self.pool3(x) # (Batch_size, 8, C, T//6)
        x = self.conv3(x) # (Batch_size, 8, C, T//6)
        x = self.relu3(x)
        x = self.pool4(x) # (Batch_size, 8, C, T//36)
        x = torch.flatten(x, start_dim=1) # (Batch_size, 8*C*(T//36))
        # print(x.requires_grad)
        x = self.dropout(x)
        x = self.layer_out(x) # (B, 5)
        # No softmax since we are using CrossEntropyLoss which 
        # expects logits as the model output not probabilities coming from softmax.
        # print(x.requires_grad)
        return x
    
# initialize weights as indicated in the paper
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.weight.data.normal_(0.0, 0.1)

#### Define learning process

We will follow the process defined in the paper. Since our classes are imbalanced, we will perform a custom batch sampling that uses a stratified splitting to maintain class ratios in the various batches.

The metric that we will use, as proposed in the paper, will be balanced accuracy (every class will have similar impact on the final score, this will account for the imbalance).

In [59]:
warnings.filterwarnings(action="ignore", message="y_pred contains classes not in y_true")

n_epochs = 2
model = BasicModel().double()
model.apply(weights_init)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

train_epoch_losses = []
test_epoch_losses = []

writer = SummaryWriter()
n_iter = 0
targets = np.array([dataset[i][1] for i in range(len(dataset))])

for epoch in range(n_epochs):
    
    train_epoch_loss = 0
    train_epoch_acc = 0
    
    # TODO stratified batches ?
    train_idx, test_idx = train_test_split(np.arange(len(targets)), stratify=targets, test_size=0.2, shuffle=True)  
    test_sampler = torch.utils.data.RandomSampler(test_idx)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=test_sampler)

    train_idx = train_idx[:len(train_idx)-len(train_idx)%128]
    skf = StratifiedKFold(n_splits=34)
    splits = skf.split(train_idx, targets[train_idx])
    
    for _, batch_idx in splits:
        
        train_sampler = torch.utils.data.RandomSampler(batch_idx)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=len(batch_idx), sampler=train_sampler)
        X_train_batch, y_train_batch = next(iter(train_loader))
        
        with torch.set_grad_enabled(True):
            
            # Zero the gradients
            optimizer.zero_grad()

            X_train_batch = Variable(X_train_batch, requires_grad=True)

            # Get prediction
            y_train_pred = model(X_train_batch)

            # Calculate loss
            train_loss = criterion(y_train_pred, torch.tensor(get_one_hot_encoding(y_train_batch)))

            # Backpropagate
            train_loss.backward()

            # Update weights
            optimizer.step()

            # Record batch loss
            train_epoch_loss += train_loss.item()
            
            train_acc = balanced_accuracy_score(y_train_batch, torch.argmax(y_train_pred, axis=1))
            train_epoch_acc+=train_acc
            
        writer.add_scalar('Loss/train', train_loss.item(), n_iter)
        writer.add_scalar('Acc/train', train_acc, n_iter)
        
        with torch.no_grad():
            test_epoch_loss = 0
            test_epoch_acc = 0
            for X_test_batch, y_test_batch in test_loader:     
                
                y_test_pred = model(X_test_batch)
                
                test_loss = criterion(y_test_pred, torch.tensor(get_one_hot_encoding(y_test_batch)))
                test_epoch_loss += test_loss.item()
                
                test_acc = balanced_accuracy_score(y_test_batch, torch.argmax(y_test_pred, axis=1))
                
                test_epoch_acc+=test_acc
                
            writer.add_scalar('Loss/test_epoch', test_epoch_loss/len(test_loader), n_iter)
            writer.add_scalar('Acc/test_epoch', test_epoch_acc/len(test_loader), n_iter)
                
        n_iter +=1
    
    train_loss_str = f'Train Loss: {train_epoch_loss/34:.5f}'
    train_acc_str = f'acc: {train_epoch_acc/34:.5f}'
      
    test_loss_str = f'Test Loss: {test_epoch_loss/len(test_loader):.5f}'
    test_acc_str = f'acc: {test_epoch_acc/len(test_loader):.5f}'
    
    print(f'Epoch {epoch+1:03}: | {train_loss_str} ({train_acc_str}) | {test_loss_str} ({test_acc_str})')

Epoch 001: | Train Loss: 2.05559 (acc: 0.34417) | Test Loss: 2.18590 (acc: 0.31079)
Epoch 002: | Train Loss: 0.65304 (acc: 0.57516) | Test Loss: 2.66313 (acc: 0.35466)


Training monitored on tensorboard, test loss goes up after the first two epochs so we stop the training there. We can use this network as it is but the idea is to push forward with the implementation proposed in [this paper](https://arxiv.org/pdf/1707.03321.pdf). Overall performance is bad, we can see that the training performance increases so the model learns, however it overfits very quickly. We will not go further on this subject, by lack of time. Some ideas : 
- Finish implementing paper with the second model
- Analyze results per class, identify where the selected model fails

#### Save trained model

In [60]:
torch.save(model.state_dict(), 'models/saved_model_state_dict.pth')

#### Test API call

Run *app.py* with flask before executing this cell.

In [64]:
domain = 'http://127.0.0.1:5000/hypnogram'
record_id = '87748119-6fff-45d2-9219-888532fb7efd'

r = requests.get(url = f'{domain}/{record_id}')

In [73]:
print(r.json()[:100]) # Returns a lot of DEEPs because the model is bad

['DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'DEEP', 'N2', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'N2', 'DEEP', 'DEEP', 'N2', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'N2', 'N2', 'DEEP', 'DEEP', 'N2', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP', 'DEEP']
