In [1]:
import math

import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf

## Carga de los datos

In [2]:
mne.set_log_level(verbose=False)

In [14]:
print(1) if 1  else print(2)

1


In [17]:
class EDFData_TF(tf.keras.utils.Sequence):
    def __init__(self, path, batch_size, channels=None):
        super(EDFData_TF, self).__init__()
        self.path = path
        self.batch_size = batch_size
        self.channels = channels if channels else 'all'
        self.epochs, self.sampling_rate = self.get_epochs(path)
        self.id_to_class_dict = {value-1:key for key, value in self.epochs.event_id.items()}

    def __getitem__(self, idx):
        # In TF, should return a full batch

        X = self.epochs[idx * self.batch_size:(idx + 1)*self.batch_size].load_data()._data
        Y = self.epochs[idx * self.batch_size:(idx + 1)*self.batch_size].events[:,-1]-1

        # return tf.squeeze(tf.Tensor(self.epochs[idx].load_data()._data)), tf.Tensor([self.epochs[idx].events[0][-1]])-1
        return X, Y

    def __len__(self):
        # In TF, len should return the number of batches
        return math.ceil(len(self.epochs)/self.batch_size)

    def get_epochs(self, path):
        data = mne.io.read_raw_edf(path)
        sampling_rate = data.info['sfreq']
        events, events_id = mne.events_from_annotations(data, regexp='Sleep stage')

        tmax = 30. - 1. / sampling_rate  # tmax is included
        epochs = mne.Epochs(raw=data, 
                            events=events,
                            event_id=events_id,
                            tmin=0., 
                            tmax=tmax, 
                            baseline=None, 
                            event_repeated='merge',
                            picks=self.channels)

        epochs.drop_bad()
        return epochs, sampling_rate

In [20]:
dataset = EDFData_TF("../Data/PSG1.edf", batch_size=4)

  data = mne.io.read_raw_edf(path)
  data = mne.io.read_raw_edf(path)


In [21]:
dataset.epochs.ch_names

['C3']

In [12]:
len(dataset)

221

In [13]:
for a, b in dataset:
    break

In [14]:
a.shape, b.shape

((4, 50, 15360), (4,))

## Model

In [15]:
sr = int(dataset.sampling_rate)
sr

512

In [18]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv1D(128, kernel_size=sr//2, padding='same', strides=sr//4, activation="relu", input_shape=(50,15360), data_format='channels_first'),
    tf.keras.layers.MaxPooling1D(8),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Conv1D(128, kernel_size=8, padding='same', strides=1, activation="relu", data_format='channels_first'),
    tf.keras.layers.Conv1D(128, kernel_size=8, padding='same', strides=1, activation="relu", data_format='channels_first'),
    tf.keras.layers.Conv1D(128, kernel_size=8, padding='same', strides=1, activation="relu", data_format='channels_first'),
    tf.keras.layers.MaxPooling1D(4),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.LSTM(128),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(5, activation="softmax")
])

In [19]:
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(), optimizer=tf.optimizers.Adam(learning_rate=0.001), metrics=["accuracy"])
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1d_1 (Conv1D)            (None, 128, 120)          1638528   
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 16, 120)           0         
_________________________________________________________________
dropout (Dropout)            (None, 16, 120)           0         
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 128, 120)          16512     
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 128, 120)          131200    
_________________________________________________________________
conv1d_4 (Conv1D)            (None, 128, 120)          131200    
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 32, 120)           0

In [20]:
model.fit(dataset, epochs=10)

Epoch 1/10

KeyboardInterrupt: 