In [1]:
import mne
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

## Carga de los datos

In [104]:
class EDFData(torch.utils.data.Dataset):
    def __init__(self, path):
        super(EDFData, self).__init__()
        self.path = path
        # self.data = mne.io.read_raw_edf(path)
        # self.events, self.events_id = mne.events_from_annotations(self.data, regexp='Sleep stage')
        self.epochs = self.get_epochs(path)
        self.id_to_class_dict = {value:key for key, value in self.epochs.event_id.items()}

    def __getitem__(self, idx):
        return torch.squeeze(torch.Tensor(self.epochs[idx].load_data()._data)), torch.Tensor([self.epochs[idx].events[0][-1]])-1
    # def __getitem__(self, idx):
    #     return self.epochs[idx]._data, self.epochs[idx].events[0][-1]

    def __len__(self):
        return len(self.epochs)

    def get_epochs(self, path):
        data = mne.io.read_raw_edf(path)
        events, events_id = mne.events_from_annotations(data, regexp='Sleep stage')

        tmax = 30. - 1. / data.info['sfreq']  # tmax is included
        epochs = mne.Epochs(raw=data, 
                            events=events,
                            event_id=events_id,
                            tmin=0., 
                            tmax=tmax, 
                            baseline=None, 
                            event_repeated='merge')
        epochs.drop_bad()
        # epochs.load_data()
        return epochs

In [112]:
BATCH_SIZE = 4
dataset = EDFData("../Data/PSG1_0001.edf")
dataloader = torch.utils.data.DataLoader(dataset, batch_size = BATCH_SIZE, drop_last=True)

Extracting EDF parameters from e:\Python\TFM\SSD_IA3\Data\PSG1_0001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Used Annotations descriptions: ['Sleep stage N1', 'Sleep stage N2', 'Sleep stage N3', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
Not setting metadata
862 matching events found
No baseline correction applied
0 projection items activated
Loading data for 862 events and 15360 original time points ...
  data = mne.io.read_raw_edf(path)
  data = mne.io.read_raw_edf(path)
1 bad epochs dropped


In [106]:
# %%time
# for a, b in dataset:
#     pass
#     # print(a.shape)

## Model

In [113]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1d = nn.Conv1d(in_channels=50, out_channels=1, kernel_size=3, padding=1)
        self.linear = nn.Linear(in_features=15360, out_features=5)

    def forward(self, X):
        pred = F.relu(self.conv1d(X))
        pred = pred.view(X.shape[0],-1)
        pred = self.linear(pred)
        return pred

In [120]:
device="cpu"

In [121]:
model = Model()
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.CrossEntropyLoss()

In [122]:
%%time
EPOCHS = 1

for epoch in range(EPOCHS):
    loss_epoch = 0
    for X, Y in dataloader:
        X, Y = X.to(device), Y.to(device)
        Y = torch.squeeze(Y)
        Y=Y.long()
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, Y)
        loss.backward()
        loss_epoch += loss
    print(loss_epoch)

s and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original time points ...
Loading data for 1 events and 15360 original tim

In [123]:
Y, pred

(tensor([4, 4, 4, 4], device='cuda:0'),
 tensor([[  6.6568,  -3.6175, -12.1228, -13.2038,  -3.6589],
         [  6.3801,  -3.6081, -11.9478, -12.8430,  -3.3025],
         [  6.5559,  -3.6108, -12.1148, -13.1863,  -3.6667],
         [  6.3805,  -3.4911, -12.1439, -13.3419,  -3.8977]], device='cuda:0',
        grad_fn=<AddmmBackward>))

In [124]:
pred.shape, Y.shape

(torch.Size([4, 5]), torch.Size([4]))