# SleepEEGNet implemented in TensorFlow2
* This notebook is for implement SleepEEGNet (Original paper is found on https://arxiv.org/abs/1903.02108)
* In this implementation, some hyperparameters and some structures are little bit different from the author's original implementation.

In [1]:
path_eeg_fpz_cz_cs = '../data/eeg_fpz_cz.cs'
path_eeg_pz_oz_cs = '../data/eeg_pz_oz.cs'
path_eeg_fpz_cz_tm = '../data/eeg_fpz_cz.tm'
path_eeg_pz_oz_tm = '../data/eeg_pz_oz.tm'

In [2]:
SEED = 5

In [3]:
DEVICES = '0'

In [4]:
import sys
from env import *

In [5]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = DEVICES
import time
import tensorflow as tf
from tensorflow.keras import models, layers, losses, optimizers, metrics
from data_loader import DataLoader
import numpy as np
import IPython.display as ipd

In [6]:
EPOCHS = 0
ETA = 1e-4
BATCH_SIZE = 20

print('EPOCHS',EPOCHS)
print('ETA',ETA)
print('BATCH_SIZE',BATCH_SIZE)

EPOCHS 0
ETA 0.0001
BATCH_SIZE 20


In [7]:
MAX_TIME_STEP = 10
SEQUENCE_LENGTH = 10

In [8]:
classes = ['W', 'N1', "N2", "N3", "REM"]
NUM_CLASSES = len(classes)

char2numY = dict(zip(classes, range(len(classes))))
print('char2numY',char2numY)

char2numY['<SOD>'] = len(char2numY)
char2numY['<EOD>'] = len(char2numY)
print('char2numY', char2numY, len(char2numY))

num2charY = dict(zip(char2numY.values(), char2numY.keys()))
print('num2charY', num2charY)

char2numY {'W': 0, 'N1': 1, 'N2': 2, 'N3': 3, 'REM': 4}
char2numY {'W': 0, 'N1': 1, 'N2': 2, 'N3': 3, 'REM': 4, '<SOD>': 5, '<EOD>': 6} 7
num2charY {0: 'W', 1: 'N1', 2: 'N2', 3: 'N3', 4: 'REM', 5: '<SOD>', 6: '<EOD>'}


## Data

In [9]:
dloader = DataLoader()

In [10]:
# seq data loading
_, _, _, _, _, _, x_seq_train, y_seq_train, x_seq_valid, y_seq_valid, x_seq_test, y_seq_test = dloader(path_eeg_fpz_cz_cs, seed=SEED, len_seq=SEQUENCE_LENGTH, return_sequences=True)

print(x_seq_train.shape, y_seq_train.shape)
print(x_seq_valid.shape, y_seq_valid.shape)
print(x_seq_test.shape, y_seq_test.shape)

x_seq_train /= np.max(x_seq_train)
x_seq_valid /= np.max(x_seq_valid)
x_seq_test /= np.max(x_seq_test)

train_seq_dataset = tf.data.\
            Dataset.from_tensor_slices((x_seq_train, y_seq_train)).\
            batch(BATCH_SIZE).shuffle(len(x_seq_train))

valid_seq_dataset = tf.data.\
            Dataset.from_tensor_slices((x_seq_valid, y_seq_valid)).\
            batch(BATCH_SIZE)

test_seq_dataset = tf.data.\
            Dataset.from_tensor_slices((x_seq_test,y_seq_test)).\
            batch(BATCH_SIZE)

(2042, 10, 3000, 1) (2042, 10)
(1082, 10, 3000, 1) (1082, 10)
(1092, 10, 3000, 1) (1092, 10)


## Model

In [11]:
class CNN(models.Model):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.c1 = models.Sequential([
            layers.Conv1D(64, 50, strides=6, padding='same', activation=tf.nn.relu),
            layers.MaxPool1D(8, 8, padding='same'),
            
            layers.Dropout(.5),
            
            layers.Conv1D(128, 8, strides=1, padding='same', activation=tf.nn.relu),
            layers.Conv1D(128, 8, strides=1, padding='same', activation=tf.nn.relu),
            layers.Conv1D(128, 8, strides=1, padding='same', activation=tf.nn.relu),
            
            layers.MaxPool1D(4, 4, padding='same'),
        ])
        
        self.c2 = models.Sequential([
            layers.Conv1D(64, 400, strides=50, padding='same', activation=tf.nn.relu),
            layers.MaxPool1D(4, 4, padding='same'),
            
            layers.Dropout(.5),
            
            layers.Conv1D(128, 6, strides=1, padding='same', activation=tf.nn.relu),
            layers.Conv1D(128, 6, strides=1, padding='same', activation=tf.nn.relu),
            layers.Conv1D(128, 6, strides=1, padding='same', activation=tf.nn.relu),
            
            layers.MaxPool1D(2, 2, padding='same')
        ])
        
        self.drop = layers.Dropout(.5)
        
    def call(self, inputs, training=False):
        c1 = self.c1(inputs, training=training)
        c1 = layers.Flatten()(c1)
        c2 = self.c2(inputs, training=training)
        c2 = layers.Flatten()(c2)
        x = tf.concat([c1, c2], axis=-1)
        x = self.drop(x, training=training)
        return x
        
cnn = CNN()
cnn(np.ones((BATCH_SIZE,3000,1),dtype=np.float32)).shape

TensorShape([20, 3072])

In [12]:
class Encoder(models.Model):
    def __init__(self, cnn):
        super(Encoder, self).__init__()
        self.cnn = cnn
        self.lstm_f = layers.LSTM(128, return_sequences=True, return_state=True)
        self.lstm_b = layers.LSTM(128, return_sequences=True, return_state=True, go_backwards=True)
        
    def call(self, inputs, training=False):
        x = []
        for i in range(inputs.shape[1]):
            _x = cnn(inputs[:,i])
            _x = tf.expand_dims(_x, axis=1)
            x.append(_x)
        x = tf.concat(x, axis=1)
        f, fh, _ = self.lstm_f(x)
        b, bh, _ = self.lstm_b(x)
        x = tf.concat([f,b], axis=-1)
        h = tf.concat([fh, bh], axis=-1)
        return x, h
        
encoder = Encoder(cnn)
encoder(np.ones((BATCH_SIZE, SEQUENCE_LENGTH, 3000, 1),dtype=np.float32))[0].shape,\
encoder(np.ones((BATCH_SIZE, SEQUENCE_LENGTH, 3000, 1),dtype=np.float32))[1].shape 

(TensorShape([20, 10, 256]), TensorShape([20, 256]))

In [13]:
class Attention(models.Model):
    def __init__(self, latent):
        super(Attention, self).__init__()
        self.We = layers.Dense(latent)
        self.Wh = layers.Dense(latent)
        self.tanh = layers.Activation(tf.nn.tanh)
        self.softmax = layers.Activation(tf.nn.softmax)
        
    def call(self, encoder_hidden, decoder_hidden, training=False):
        WE = self.We(encoder_hidden)
        decoder_hidden = tf.expand_dims(decoder_hidden, axis=1)
        WH = self.Wh(decoder_hidden)
        x = WE + WH
        f = self.tanh(x)
        alpha = self.softmax(f)
        c = alpha * encoder_hidden
        c = tf.reduce_sum(c, axis=1)
        return c
        
attention = Attention(256)
attention(
    np.ones((BATCH_SIZE,SEQUENCE_LENGTH,256),dtype=np.float32), 
    np.ones((BATCH_SIZE, 256),dtype=np.float32)
).shape

TensorShape([20, 256])

In [14]:
class Decoder(models.Model):
    def __init__(self):
        super(Decoder, self).__init__()
        self.lstm_f = layers.LSTM(128, return_sequences=True, return_state=True)
        self.lstm_b = layers.LSTM(128, return_sequences=True, return_state=True, go_backwards=True)
        self.attention = Attention(256)
        self.classes = layers.Dense(7, activation=tf.nn.softmax)
        
    def call(self, decoder_input, prev_decoder_hidden, encoder_hidden, training=False):
        c = self.attention(encoder_hidden, prev_decoder_hidden)
        x = tf.concat([c, decoder_input], axis=-1)
        x = tf.expand_dims(x, axis=1)
        f, fh, _ = self.lstm_f(x, training=training)
        b, bh, _ = self.lstm_b(x, training=training)
        h = tf.concat([fh, bh], axis=-1)
        x = tf.concat([f, b], axis=-1)
        prediction = self.classes(x)
        prediction = tf.squeeze(prediction, axis=1)
        return prediction, h
        
decoder = Decoder()
decoder(
    np.ones((BATCH_SIZE,NUM_CLASSES+2),dtype=np.float32), # decoder 에 입력
    np.ones((BATCH_SIZE,256),dtype=np.float32), # 이전 단계 decoder hidden (첨에는 encoder마지막 놈의 hidden)
    np.ones((BATCH_SIZE,SEQUENCE_LENGTH,256),dtype=np.float32) # 인코더 전체 시퀀스 output
)[0].shape,\
decoder(
    np.ones((BATCH_SIZE,NUM_CLASSES+2),dtype=np.float32), # decoder 에 입력
    np.ones((BATCH_SIZE,256),dtype=np.float32), # 이전 단계 decoder hidden (첨에는 encoder마지막 놈의 hidden)
    np.ones((BATCH_SIZE,SEQUENCE_LENGTH,256),dtype=np.float32) # 인코더 전체 시퀀스 output
)[1].shape

(TensorShape([20, 7]), TensorShape([20, 256]))

## Utils

In [15]:
loss_object = losses.CategoricalCrossentropy()
acc_object = metrics.CategoricalAccuracy()
loss = metrics.Mean()
acc = metrics.Mean()

varlist = encoder.trainable_variables+decoder.trainable_variables
opt = optimizers.Adam(learning_rate=ETA)

In [16]:
def train_step(inputs):
    _X, _y = inputs
    # because SEQUENCE_LENGTH is fixed, I omit adding '<SOD>' and '<EOD>' for all y.
#     _y = np.insert(_y, 0, char2numY.get('<SOD>'), axis=1)
#     _y = np.insert(_y, _y.shape[-1], char2numY.get('<EOD>'), axis=1)
    _y = tf.one_hot(_y, depth=NUM_CLASSES+2)
    
    _loss = 0.
    
    with tf.GradientTape() as tape:
        encoder_output, hidden_state = encoder(_X, training=True)
        
        decoder_input = tf.one_hot(char2numY.get('<SOD>'), depth=NUM_CLASSES+2)
        decoder_input = tf.multiply(
            tf.ones((_X.shape[0],NUM_CLASSES+2),dtype=tf.float32),
            decoder_input) # begging from <SOD>

        for t in range(SEQUENCE_LENGTH):
            # During training step, valid step as well as test step, there is no '<EOD>' related implementation because of the reason line number 3.
            pred, hidden_state = decoder(decoder_input, hidden_state, encoder_output, training=True)
            _loss += loss_object(_y[:,t], pred)
            acc.update_state(acc_object(_y[:,t], pred))
            loss.update_state(_loss)
            decoder_input = _y[:,t]

    grads = tape.gradient(_loss, varlist)
    opt.apply_gradients(list(zip(grads, varlist)))

In [17]:
valid_loss = metrics.Mean()
valid_acc = metrics.Mean()

In [18]:
def valid_step(inputs):
    _X, _y = inputs
    _y = tf.one_hot(_y, depth=NUM_CLASSES+2)
    
    _loss = 0.
    
    encoder_output, hidden_state = encoder(_X, training=False)

    decoder_input = tf.one_hot(char2numY.get('<SOD>'), depth=NUM_CLASSES+2)
    decoder_input = tf.multiply(
        tf.ones((_X.shape[0],NUM_CLASSES+2),dtype=tf.float32),
        decoder_input)

    for t in range(SEQUENCE_LENGTH):
        pred, hidden_state = decoder(decoder_input, hidden_state, encoder_output, training=False)
        _loss += loss_object(_y[:,t], pred)
        valid_acc.update_state(acc_object(_y[:,t], pred))
        valid_loss.update_state(_loss)
        decoder_input = _y[:,t]

In [19]:
test_loss = metrics.Mean()
test_acc = metrics.Mean()

In [20]:
def test_step(inputs):
    _X, _y = inputs
    _y = tf.one_hot(_y, depth=NUM_CLASSES+2)
    
    _loss = 0.
    
    encoder_output, hidden_state = encoder(_X, training=False)

    decoder_input = tf.one_hot(char2numY.get('<SOD>'), depth=NUM_CLASSES+2)
    decoder_input = tf.multiply(
        tf.ones((_X.shape[0],NUM_CLASSES+2),dtype=tf.float32),
        decoder_input)

    for t in range(SEQUENCE_LENGTH):
        pred, hidden_state = decoder(decoder_input, hidden_state, encoder_output, training=False)
        _loss += loss_object(_y[:,t], pred)
        test_acc.update_state(acc_object(_y[:,t], pred))
        test_loss.update_state(_loss)
        decoder_input = _y[:,t]

## Training

In [21]:
total_time = time.time()

min_loss = 1e10
min_epoch = 0

for e in range(EPOCHS):
    start_time = time.time()
    
    for i,x in enumerate(train_seq_dataset):
        train_step(x)
        
    for i,x in enumerate(valid_seq_dataset):
        valid_step(x)
    
    ipd.clear_output(wait=True)
    print(f"{e+1}/{EPOCHS}, loss={loss.result():.8f}, train acc={acc.result()*100:.2f}%,")
    print(f"validation: loss={valid_loss.result():.8f}, acc={valid_acc.result()*100:.2f}%, {time.time()-start_time:.2f} sec/epoch, totally {time.time()-total_time:.2f} seconds")
    print(f"\tbest valid loss = {min_loss:.8f} at epoch-{min_epoch}")
    
    if min_loss > valid_loss.result():
        min_loss = valid_loss.result()
        min_epoch = e
        # save best model
        encoder.save_weights(f"../weights/sleepeegnet/encoder-{SEED}")
        decoder.save_weights(f"../weights/sleepeegnet/decoder-{SEED}")
    
    loss.reset_states()
    acc.reset_states()
    valid_loss.reset_states()
    valid_acc.reset_states()

In [22]:
encoder.load_weights(f"../weights/sleepeegnet/encoder-{SEED}")
decoder.load_weights(f"../weights/sleepeegnet/decoder-{SEED}")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f4ebfd61910>

## Test

In [23]:
start_time = time.time()
for i, x in enumerate(test_seq_dataset):
    test_step(x)
    ipd.clear_output(wait=True)
    print(i)
ipd.clear_output(wait=True)
print(f"test: loss={test_loss.result():.8f}, acc={test_acc.result()*100:.2f}%, {time.time()-start_time:.2f} seconds")
test_loss.reset_states()
test_acc.reset_states()

test: loss=2.12339211, acc=89.33%, 11.52 seconds
