In [1]:
import matplotlib.pyplot as plt
import numpy as np
from pyedflib import highlevel

In [2]:
# В пояснении не нуждается
def generate_out_for_signal(signals, signal_headers, header):
    out = np.zeros((3, signals.shape[1]))
    
    ds1 = 0
    is1 = 0
    swd1 = 0
    
    for annotation in header['annotations']:
        if annotation[2] not in ['ds1', 'is1', 'swd1', 'ds2', 'is2', 'swd2', 'dds2']:
            raise Exception('Unknown annotation: ' + annotation[2])

        if annotation[2] == 'ds1':
            time = int(annotation[0] * signal_headers[0]['sample_rate'])
            ds1 = time
        elif annotation[2] == 'is1':
            time = int(annotation[0] * signal_headers[0]['sample_rate'])
            is1 = time
        elif annotation[2] == 'swd1':
            time = int(annotation[0] * signal_headers[0]['sample_rate'])
            swd1 = time
        elif annotation[2] == 'ds2' or annotation[2] == 'dds2':
            time = int(annotation[0] * signal_headers[0]['sample_rate'])
            out[0, ds1:time] = 1
        elif annotation[2] == 'is2':
            time = int(annotation[0] * signal_headers[0]['sample_rate'])
            out[1, is1:time] = 1
        elif annotation[2] == 'swd2':
            time = int(annotation[0] * signal_headers[0]['sample_rate'])
            out[2, swd1:time] = 1
    return out

In [3]:
signals, signal_headers, header = highlevel.read_edf('../data/Ati4x3_12m_BL_6h_fully_marked.edf')
out = generate_out_for_signal(signals, signal_headers, header)
X_test = signals.T
y_test = out.T
signals.shape, out.shape

((3, 8640400), (3, 8640400))

In [4]:
signals, signal_headers, header = highlevel.read_edf('../data/Ati4x6_14m_BL_6h_fully_marked.edf')
out = generate_out_for_signal(signals, signal_headers, header)
X_train = signals.T
y_train = out.T
signals.shape, out.shape

((3, 8640400), (3, 8640400))

In [5]:
y_test[:, 0].sum(), y_test[:, 1].sum(), y_test[:, 2].sum(), y_train[:, 0].sum(), y_train[:, 1].sum(), y_train[:, 2].sum()

(1746800.0, 88400.0, 127600.0, 2045048.0, 67013.0, 251131.0)

In [6]:
import torch
from torch.nn import LSTM
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

class LSTMModel(torch.nn.Module):
    def __init__(self, input_size, hidden_layer_size, output_size):
        super().__init__()
        self.hidden_layer_size = hidden_layer_size
        self.lstm = LSTM(input_size, hidden_layer_size, 1)
        self.linear = torch.nn.Linear(hidden_layer_size, output_size)

    def forward(self, input_seq, hidden_cell=None):
        lstm_out, hidden_cell_out = self.lstm(input_seq.view(len(input_seq) ,1, -1), hidden_cell)
        predictions = self.linear(lstm_out.view(len(input_seq), -1))
        return predictions, hidden_cell_out

model = LSTMModel(3, 3, 3)
model.train()
x_train_torch = torch.tensor(X_train, dtype=torch.float32)
y_train_torch = torch.tensor(y_train, dtype=torch.float32)
dataset = TensorDataset(x_train_torch, y_train_torch)
sampler = SequentialSampler(dataset)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = model.to(device)
loss_function = torch.nn.MSELoss()
loss_function = loss_function.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [13]:
torch.cuda.empty_cache()

In [None]:
from time import time
from tqdm import tqdm

for epoch in range(20):
    hidden_cell = None
    sum_loss = 0
    t = time()
    loss = 0
    optimizer.zero_grad()
    for x_train_torch, y_train_torch in tqdm(dataloader):
        x_train_torch, y_train_torch = x_train_torch.to(device), y_train_torch.to(device)
        outputs, hidden_cell = model(x_train_torch, hidden_cell)
        loss += loss_function(outputs, y_train_torch)
    loss.backward()
    optimizer.step()
    sum_loss += loss.detach().item()
    print('epoch {}, loss {}'.format(epoch, sum_loss / 128))


100%|██████████| 67504/67504 [02:08<00:00, 524.45it/s]


epoch 0, loss 169.32437133789062


100%|██████████| 67504/67504 [02:05<00:00, 536.64it/s]


epoch 1, loss 152.80780029296875


100%|██████████| 67504/67504 [02:17<00:00, 492.40it/s]


epoch 2, loss 138.02293395996094


  1%|          | 476/67504 [00:01<02:35, 431.16it/s]

In [13]:
x_test_torch = torch.tensor(X_test, dtype=torch.float32)
y_test_torch = torch.tensor(y_test, dtype=torch.float32)
validation_dataset = TensorDataset(x_test_torch, y_test_torch)
validation_sampler = SequentialSampler(validation_dataset)
validation_loader = DataLoader(validation_dataset, batch_size=128, sampler=validation_sampler)

model.eval()
running_vloss = 0.0
# y_pred = np.zeros((0, 3))
# y_true = np.zeros((0, 3))
i = 0
with torch.no_grad():
    hidden_cell = None
    for input, labels in tqdm(validation_loader):
        input, labels = input.to(device), labels.to(device)
        output, hidden_cell = model(input, hidden_cell)
        # y_pred = np.vstack((y_pred, output.detach().cpu().numpy()))
        # y_true = np.vstack((y_true, labels.detach().cpu().numpy()))
        vloss = loss_function(output, labels)
        running_vloss += vloss

y_pred.shape, y_true.shape

 15%|█▍        | 10000/67504 [02:07<12:13, 78.39it/s]


((1280128, 3), (1280128, 3))