In [2]:
import torch
import torch.nn as nn  
from torch.utils.data import Dataset, DataLoader
import time
import os
import random
import numpy as np
from torch.profiler import profile, record_function, ProfilerActivity
import shutil
from torch.utils.tensorboard import SummaryWriter
from matplotlib import pyplot as plt 
from core.datahelper import DataSplitter, DataAE
from torch.optim.lr_scheduler import StepLR

## Models

In [3]:
class FcNet(nn.Module):
    """Fully connected baseline classifier.

    Flattens every dimension after the batch dimension (256 * 5 * 23 = 29440
    features per sample, i.e. one 1 x 1280 x 23 window) and passes it through a
    300 -> 100 -> 50 -> 20 ReLU stack to a single sigmoid probability.
    """

    def __init__(self):
        super().__init__()

        self.name = 'fc'

        # Layer sizes shrink step by step down to one output unit.
        self.linear1 = nn.Linear(256 * 5 * 23, 300)
        self.linear2 = nn.Linear(300, 100)
        self.linear3 = nn.Linear(100, 50)
        self.linear4 = nn.Linear(50, 20)
        self.linear5 = nn.Linear(20, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``."""
        hidden = torch.flatten(x, 1)
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = torch.relu(layer(hidden))
        return torch.sigmoid(self.linear5(hidden))

In [4]:
class ConvNet(nn.Module):
    """Convolutional classifier for 1 x 1280 x 23 windows.

    Three (conv 3x2 -> ReLU -> batch-norm -> 2x2 max-pool) stages, then a fourth
    conv + ReLU + batch-norm, which yields 32 x 156 x 1 = 4992 features for a
    1 x 1280 x 23 input. A 100 -> 50 -> 20 ReLU MLP then produces a sigmoid
    probability.
    """

    def __init__(self):
        super().__init__()

        self.name = 'conv'

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 2))
        self.norm1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm2 = nn.BatchNorm2d(32)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm3 = nn.BatchNorm2d(32)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))

        # The last conv stage has no pooling.
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm4 = nn.BatchNorm2d(32)

        self.linear5 = nn.Linear(4992, 100)
        self.linear6 = nn.Linear(100, 50)
        self.linear7 = nn.Linear(50, 20)
        self.linear8 = nn.Linear(20, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``."""
        pooled_stages = (
            (self.conv1, self.norm1, self.maxpool1),
            (self.conv2, self.norm2, self.maxpool2),
            (self.conv3, self.norm3, self.maxpool3),
        )
        for conv, norm, pool in pooled_stages:
            x = pool(norm(torch.relu(conv(x))))

        features = torch.flatten(self.norm4(torch.relu(self.conv4(x))), 1)

        for dense in (self.linear5, self.linear6, self.linear7):
            features = torch.relu(dense(features))
        return torch.sigmoid(self.linear8(features))

In [5]:
class LstmNet(nn.Module):
    """Convolutional feature extractor followed by a bidirectional LSTM.

    The conv stack matches ``ConvNet``. For a 1 x 1280 x 23 input it produces a
    32 x 156 x 1 map, which is read as 156 time steps of 32 features. A
    bidirectional LSTM (20 units per direction) runs over that sequence, and its
    output at the last time step feeds a sigmoid unit.
    """

    def __init__(self):
        super().__init__()
        self.name = 'lstm'

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 2))
        self.norm1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm2 = nn.BatchNorm2d(32)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm3 = nn.BatchNorm2d(32)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm4 = nn.BatchNorm2d(32)

        self.dropout = nn.Dropout(p=0.1)
        self.lstm5 = torch.nn.LSTM(input_size=32, hidden_size=20, bidirectional=True, batch_first=True)

        # 2 directions x 20 hidden units.
        self.linear6 = nn.Linear(40, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``."""
        for conv, norm, pool in (
            (self.conv1, self.norm1, self.maxpool1),
            (self.conv2, self.norm2, self.maxpool2),
            (self.conv3, self.norm3, self.maxpool3),
        ):
            x = pool(norm(torch.relu(conv(x))))
        features = self.norm4(torch.relu(self.conv4(x)))

        # (batch, 32, 156[, 1]) -> (batch, 156, 32): one 32-dim vector per time step.
        sequence = features.reshape(-1, 32, 156).transpose(1, 2).contiguous()

        outputs, _ = self.lstm5(self.dropout(sequence))
        last_step = outputs[:, -1, :]
        return torch.sigmoid(self.linear6(last_step))


In [6]:
class EncoderNet(nn.Module):
    """Convolutional encoder half of the autoencoder.

    Three (conv 3x2 -> ReLU -> batch-norm -> 2x2 max-pool) stages, then a fourth
    conv + ReLU + batch-norm without pooling. For a 1 x 1280 x 23 input the
    output is 32 x 156 x 1 per sample.
    """

    def __init__(self):
        super().__init__()
        self.name = 'encoder'

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 2))
        self.norm1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm2 = nn.BatchNorm2d(32)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm3 = nn.BatchNorm2d(32)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm4 = nn.BatchNorm2d(32)

    def forward(self, x):
        """Encode ``x`` into a 32-channel feature map."""
        stages = [
            (self.conv1, self.norm1, self.maxpool1),
            (self.conv2, self.norm2, self.maxpool2),
            (self.conv3, self.norm3, self.maxpool3),
        ]
        for conv, norm, pool in stages:
            x = pool(norm(torch.relu(conv(x))))
        return self.norm4(torch.relu(self.conv4(x)))

class DecoderNet(nn.Module):
    """Transposed-convolution decoder half of the autoencoder.

    Three (transposed conv -> ReLU -> batch-norm -> x2 upsample) stages, then a
    final transposed conv to one channel, ReLU and batch-norm. For a
    32 x 156 x 1 input the spatial sizes evolve as
    156x1 -> 158x2 -> 316x4 -> 318x5 -> 636x10 -> 639x11 -> 1278x22 -> 1280x23.
    """

    def __init__(self):
        super().__init__()
        self.name = 'decoder'

        self.upconv1 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm1 = nn.BatchNorm2d(32)
        self.upsample1 = nn.Upsample(scale_factor=2)

        self.upconv2 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm2 = nn.BatchNorm2d(32)
        self.upsample2 = nn.Upsample(scale_factor=2)

        # The taller (4, 2) kernel makes the height land exactly on 1280 at the end.
        self.upconv3 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=(4, 2))
        self.norm3 = nn.BatchNorm2d(32)
        self.upsample3 = nn.Upsample(scale_factor=2)

        self.upconv4 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(3, 2))
        self.norm4 = nn.BatchNorm2d(1)

    def forward(self, x):
        """Decode a 32-channel feature map back into a single-channel window."""
        upsampling_stages = (
            (self.upconv1, self.norm1, self.upsample1),
            (self.upconv2, self.norm2, self.upsample2),
            (self.upconv3, self.norm3, self.upsample3),
        )
        for upconv, norm, upsample in upsampling_stages:
            x = upsample(norm(torch.relu(upconv(x))))
        return self.norm4(torch.relu(self.upconv4(x)))

In [45]:
class LockedDropout(nn.Module):
    """Locked (variational) dropout: one mask per call, shared along dim 0.

    Each call samples a single Bernoulli mask of shape
    ``(1, x.size(1), x.size(2))`` and broadcasts it over the first dimension of
    ``x``. Every slice ``x[t]`` is therefore dropped with the same pattern. Kept
    activations are scaled by ``1 / (1 - dropout)`` so the expected value is
    unchanged.
    """

    def forward(self, x, dropout=0.5):
        """Apply locked dropout to a 3-D tensor.

        Args:
            x: 3-D tensor. The mask is shared along dim 0. This is presumably the
                time axis of a ``(seq, batch, features)`` layout; TODO confirm with callers.
            dropout: drop probability. A falsy value disables dropout.

        Returns:
            ``x`` unchanged in eval mode or when ``dropout`` is falsy. Otherwise
            ``x`` times the scaled mask, or all zeros when ``dropout >= 1``.
        """
        if not self.training or not dropout:
            return x
        if dropout >= 1:
            # Everything is dropped. Avoid the 0/0 = NaN that the rescaling would produce.
            return torch.zeros_like(x)
        keep = 1 - dropout
        # new_empty keeps the mask on x's device/dtype; the mask itself carries no grad.
        mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(keep) / keep
        return mask.expand_as(x) * x

class WeightDrop(torch.nn.Module):
    """DropConnect wrapper that applies dropout to selected weights of ``module``.

    For each name in ``weights``, the original parameter is removed from
    ``module`` and re-registered as ``<name>_raw``, so it stays a trainable leaf
    Parameter. Before every forward pass, a (possibly dropped) copy of the raw
    tensor is assigned back under the original name, and the wrapped module's
    own ``forward`` then uses it.

    Args:
        module: module whose weights are dropped (e.g. an ``nn.LSTM``).
        weights: attribute names of the parameters to drop (e.g. ``'weight_hh_l0'``).
        dropout: drop probability.
        variational: if True, drop whole rows (one mask entry per index along
            dim 0) instead of individual elements.
    """

    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        """No-op stand-in for ``flatten_parameters`` on wrapped RNN modules.

        This must be a real function rather than a lambda; otherwise pickling
        the module fails.
        """
        return

    def _setup(self):
        """Move each selected parameter to ``<name>_raw`` on the wrapped module."""
        # Workaround for cuDNN weight compaction on RNN modules: turn
        # flatten_parameters() into a no-op, because the weights are swapped on
        # every forward.
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', torch.nn.Parameter(w.data))

    def _setweights(self):
        """Assign a freshly dropped copy of each raw weight under its original name."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                # One mask entry per row, broadcast over the remaining dims. Giving
                # the mask the same rank as raw_w lets 1-D biases work as well.
                # The mask is created directly on raw_w's device and dtype.
                mask_shape = (raw_w.size(0),) + (1,) * (raw_w.dim() - 1)
                mask = raw_w.new_ones(mask_shape)
                # Respect eval mode: no rows are dropped when not training.
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=self.training)
                w = mask * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            if not self.training:
                # In eval mode, hand the module a detached tensor.
                w = w.data
            setattr(self.module, name_w, w)

    def forward(self, *args):
        """Refresh the dropped weights, then run the wrapped module's forward."""
        self._setweights()
        return self.module.forward(*args)

def embedded_dropout(embed, words, dropout=0.1, scale=None):
    """Look up ``words`` in ``embed`` with whole-row (per-word) dropout.

    Each row of ``embed.weight`` is kept with probability ``1 - dropout``, and
    kept rows are scaled by ``1 / (1 - dropout)``. Every occurrence of a dropped
    word therefore maps to the zero vector.

    Args:
        embed: ``nn.Embedding`` whose weight, padding and norm settings are used.
        words: index tensor passed to ``F.embedding``.
        dropout: row-drop probability. A falsy value disables dropout.
        scale: optional tensor, expanded to the weight's shape and multiplied in.

    Returns:
        The embedded tensor, as returned by ``torch.nn.functional.embedding``.
    """
    weight = embed.weight
    if dropout:
        keep = 1 - dropout
        # One Bernoulli draw per vocabulary row, broadcast across the embedding dim.
        mask = weight.new_empty((weight.size(0), 1)).bernoulli_(keep) / keep
        masked_embed_weight = mask * weight
    else:
        masked_embed_weight = weight
    # `is not None`: truth-testing a multi-element tensor raises.
    if scale is not None:
        masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

    # Pass padding_idx through unchanged. F.embedding accepts None, but it maps a
    # negative index such as -1 to the last row, which would then get no gradient.
    return torch.nn.functional.embedding(
        words, masked_embed_weight,
        embed.padding_idx, embed.max_norm, embed.norm_type,
        embed.scale_grad_by_freq, embed.sparse,
    )

## Recurrent dropout

In [38]:
class LstmDropNet(nn.Module):
    """``LstmNet`` variant with DropConnect on the LSTM's hidden-to-hidden weights.

    Same conv feature extractor and bidirectional LSTM (20 units per direction)
    as ``LstmNet``. The LSTM is wrapped in ``WeightDrop`` with p=0.5 on the
    forward and reverse ``weight_hh``/``bias_hh`` tensors.
    """

    def __init__(self):
        super().__init__()
        self.name = 'drop-lstm'

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 2))
        self.norm1 = nn.BatchNorm2d(32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm2 = nn.BatchNorm2d(32)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm3 = nn.BatchNorm2d(32)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 2))
        self.norm4 = nn.BatchNorm2d(32)

        self.dropout = nn.Dropout(p=0.1)
        recurrent_weights = ['weight_hh_l0', 'bias_hh_l0', 'weight_hh_l0_reverse', 'bias_hh_l0_reverse']
        self.lstm5 = WeightDrop(
            torch.nn.LSTM(input_size=32, hidden_size=20, bidirectional=True, batch_first=True),
            recurrent_weights,
            dropout=0.5,
        )

        # 2 directions x 20 hidden units.
        self.linear6 = nn.Linear(40, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``."""
        for conv, norm, pool in (
            (self.conv1, self.norm1, self.maxpool1),
            (self.conv2, self.norm2, self.maxpool2),
            (self.conv3, self.norm3, self.maxpool3),
        ):
            x = pool(norm(torch.relu(conv(x))))
        features = self.norm4(torch.relu(self.conv4(x)))

        # (batch, 32, 156[, 1]) -> (batch, 156, 32): 156 time steps of 32 features.
        sequence = features.reshape(-1, 32, 156).transpose(1, 2).contiguous()

        outputs = self.lstm5(self.dropout(sequence))[0]
        return torch.sigmoid(self.linear6(outputs[:, -1, :]))

In [47]:
class EncoderLstmNet(nn.Module):
    """Bidirectional-LSTM classifier on top of a frozen, pre-trained encoder.

    The encoder's parameters are frozen (``requires_grad = False``). ``train``
    always puts the encoder back into eval mode, so its batch-norm statistics
    are never updated.

    Args:
        encoder: module whose output holds 32 * 156 values per sample
            (e.g. ``EncoderNet`` applied to 1 x 1280 x 23 windows).
    """

    def __init__(self, encoder):
        super(EncoderLstmNet, self).__init__()
        self.name = 'encoder-lstm'

        self.encoder = encoder
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(p=0.1)
        self.lstm = torch.nn.LSTM(input_size=32, hidden_size=10, bidirectional=True, batch_first=True)
        # 2 directions x 10 hidden units.
        self.linear = nn.Linear(20, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``."""
        x = self.encoder(x)

        # Read the encoder output as 156 time steps of 32 features each.
        x = torch.reshape(x, (-1, 32, 156))
        x = torch.transpose(x, 1, 2)
        x = x.contiguous()

        x = self.dropout(x)
        x = self.lstm(x)[0]
        # NOTE(review): at the last position, the reverse direction has only seen
        # the final time step. Consider h_n or pooling if that matters.
        x = x[:, -1, :]

        x = torch.sigmoid(self.linear(x))

        return x

    def train(self, mode=True):
        """Set train/eval mode for the trainable head; the encoder always stays in eval mode.

        ``mode`` must be forwarded. Otherwise ``train(False)`` / ``eval()`` would
        silently switch the head back to training mode and keep dropout active
        during evaluation.
        """
        super().train(mode)
        self.encoder.eval()
        return self
        

In [60]:
class EncoderDropLstmNet(nn.Module):
    """``EncoderLstmNet`` variant with DropConnect on the LSTM's recurrent weights.

    The frozen encoder (``requires_grad = False``, kept in eval mode by
    ``train``) feeds a bidirectional LSTM (10 units per direction) wrapped in
    ``WeightDrop`` with p=0.5 on the forward and reverse ``weight_hh``/``bias_hh``.

    Args:
        encoder: module whose output holds 32 * 156 values per sample
            (e.g. ``EncoderNet`` applied to 1 x 1280 x 23 windows).
    """

    def __init__(self, encoder):
        super(EncoderDropLstmNet, self).__init__()
        self.name = 'encoder-drop-lstm'

        self.encoder = encoder
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(p=0.1)
        self.lstm = torch.nn.LSTM(input_size=32, hidden_size=10, bidirectional=True, batch_first=True)
        self.lstm = WeightDrop(self.lstm, ['weight_hh_l0', 'bias_hh_l0', 'weight_hh_l0_reverse', 'bias_hh_l0_reverse'], dropout=0.5)
        # 2 directions x 10 hidden units.
        self.linear = nn.Linear(20, 1)

    def forward(self, x):
        """Return per-sample probabilities of shape ``(batch, 1)``."""
        x = self.encoder(x)

        # Read the encoder output as 156 time steps of 32 features each.
        x = torch.reshape(x, (-1, 32, 156))
        x = torch.transpose(x, 1, 2)
        x = x.contiguous()

        x = self.dropout(x)
        x = self.lstm(x)[0]
        # NOTE(review): at the last position, the reverse direction has only seen
        # the final time step.
        x = x[:, -1, :]

        x = torch.sigmoid(self.linear(x))

        return x

    def train(self, mode=True):
        """Set train/eval mode for the trainable head; the encoder always stays in eval mode.

        ``mode`` must be forwarded. Otherwise ``train(False)`` / ``eval()`` would
        leave dropout and weight-drop active during evaluation.
        """
        super().train(mode)
        self.encoder.eval()
        return self

## Metrics

In [8]:
class AccuracyMetric:
    """Fraction of predictions equal to the targets, accumulated over batches.

    Correct predictions and samples are counted over the whole pass, so each
    batch is weighted by its size. A smaller final batch no longer skews the
    result, as it did with an average of per-batch ratios.
    """

    def __init__(self):
        self.name = 'accuracy'
        self.sum = 0    # correct predictions seen so far
        self.count = 0  # compared elements seen so far

    def clean(self):
        """Reset the accumulated counts."""
        self.sum = 0
        self.count = 0

    def accumulate(self, y_pred, y):
        """Add one batch of (already thresholded) predictions and targets."""
        self.sum += (y_pred == y).sum().item()
        self.count += y.numel()

    def calculate(self):
        """Return the accuracy, or NaN if nothing has been accumulated."""
        return self.sum / self.count if self.count else float('nan')


class RecallMetric:
    """Recall (sensitivity) TP / (TP + FN) for binary 0/1 predictions."""

    def __init__(self):
        self.name = 'recall'
        self.tp = 0
        self.fn = 0

    def clean(self):
        """Reset the accumulated counts."""
        self.tp = 0
        self.fn = 0

    def accumulate(self, y_pred, y):
        """Add one batch of thresholded predictions and targets."""
        self.tp += ((y_pred == y) & (y == 1)).sum().item()
        # For 0/1 values, y > y_pred only when y == 1 and y_pred == 0.
        self.fn += (y > y_pred).sum().item()

    def calculate(self):
        """Return recall, or NaN when no positives were seen (instead of dividing by zero)."""
        positives = self.tp + self.fn
        return self.tp / positives if positives else float('nan')


class SpecifityMetric:
    """Specificity TN / N for binary 0/1 predictions (N = number of negative targets)."""

    def __init__(self):
        self.name = 'specifity'
        self.tn = 0
        self.n = 0

    def clean(self):
        """Reset the accumulated counts."""
        self.tn = 0
        self.n = 0

    def accumulate(self, y_pred, y):
        """Add one batch of thresholded predictions and targets."""
        self.tn += ((y_pred == y) & (y == 0)).sum().item()
        self.n += (y == 0).sum().item()

    def calculate(self):
        """Return specificity, or NaN when no negatives were seen (instead of dividing by zero)."""
        return self.tn / self.n if self.n else float('nan')


In [9]:
class ELDataset(Dataset):
    """Dataset that reads one ``.npy`` recording from disk per item.

    Args:
        files: sequence of ``(filepath, label)`` pairs. Each file's array is
            transposed and reshaped to ``(1, 1280, 23)`` float32.
    """

    def __init__(self, files):
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # Profiler labels make loading time visible in torch.profiler traces.
        with record_function("get_data"):
            path, target = self.files[idx]
            with record_function("np_read"):
                array = np.load(path)
            sample = np.ascontiguousarray(array.T, dtype=np.float32)
            return sample.reshape(1, 1280, 23), np.array([target], dtype=np.float32)


class NormilizeTransform:
    """Standardize a sample: subtract ``mean``, then divide by ``std``."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, sample):
        centered = sample - self.mean
        return centered / self.std

class EEGRamDataset(Dataset):
    """Dataset that preloads every recording into a single in-memory array.

    Args:
        files: sequence of ``(filepath, label)`` pairs. Each ``.npy`` array is
            transposed and reshaped to ``(1, 1280, 23)`` float32.
        gen_normilize: if True, compute the global mean/std of the loaded data
            and use a ``NormilizeTransform`` with them. This replaces any
            ``transform`` that was passed in.
        transform: optional callable applied to each sample in ``__getitem__``.
    """

    def __init__(self, files, gen_normilize=False, transform=None):
        total = len(files)
        self._data = np.zeros((total, 1, 1280, 23), dtype=np.float32)
        self._labels = np.zeros((total, 1), dtype=np.float32)
        for row, (path, target) in enumerate(files):
            array = np.load(path)
            self._data[row] = np.ascontiguousarray(array.T, dtype=np.float32).reshape(1, 1280, 23)
            self._labels[row] = np.array([target], dtype=np.float32)

        if gen_normilize:
            self.mean = self._data.mean()
            self.std = self._data.std()
            transform = NormilizeTransform(self.mean, self.std)
        else:
            self.mean = None
            self.std = None

        self.transform = transform

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        sample, label = self._data[idx], self._labels[idx]
        if not self.transform:
            return sample, label
        return self.transform(sample), label

    def get_mean_std(self):
        """Return ``(mean, std)`` computed at construction, or ``(None, None)``."""
        return self.mean, self.std
    

## Dataloaders

In [10]:
def get_data_loaders_p_s(path_to_data, patient, seizure, mean=None, std=None, batch_size=64):
    """Build the train/test loaders for one (patient, seizure) split.

    Args:
        path_to_data: root directory passed to ``DataSplitter.load_files``.
        patient: patient identifier forwarded to ``DataSplitter.load_files``.
        seizure: seizure identifier forwarded to ``DataSplitter.load_files``.
        mean: normalization mean. If None, mean and std are computed from the
            training split. Either way, the same statistics normalize both splits.
        std: normalization standard deviation (used when ``mean`` is given).
        batch_size: batch size for both loaders.

    Returns:
        ``(train_dataloader, test_dataloader)``. Only the training loader is shuffled.
    """
    train, test = DataSplitter.load_files(path_to_data, path_to_data, patient, seizure)

    if mean is None:
        train_data = EEGRamDataset(train, gen_normilize=True)
        mean, std = train_data.get_mean_std()
    else:
        train_data = EEGRamDataset(train, transform=NormilizeTransform(mean, std))

    test_data = EEGRamDataset(test, transform=NormilizeTransform(mean, std))
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation needs no shuffling. A fixed order keeps batch-averaged
    # metrics reproducible from epoch to epoch.
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_dataloader, test_dataloader

def get_ae_dataloaders(path_to_data, bacth_size=64):
    """Build the train/test loaders for autoencoder training.

    The test split is normalized with the training split's mean/std.

    Args:
        path_to_data: root directory passed to ``DataAE.load_files``.
        bacth_size: batch size for both loaders. Defaults to 64, the value the
            training loader previously hard-coded.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader is shuffled.
    """
    train, test = DataAE.load_files(path_to_data, path_to_data)
    train_dataset = EEGRamDataset(train, gen_normilize=True)
    train_mean, train_std = train_dataset.get_mean_std()
    test_dataset = EEGRamDataset(test, transform=NormilizeTransform(train_mean, train_std))

    # Use the requested batch size for both loaders; it used to be ignored for training.
    train_data = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
    test_data = DataLoader(test_dataset, batch_size=bacth_size, shuffle=False)

    return train_data, test_data

## Per-patient training

In [65]:
class PatientSpecTrainig:
    """Per-patient / per-seizure training of the seizure classifiers.

    All methods are static; the class only groups them.
    - Checkpoints are written to
      ``<path_to_models>/<model.name>/<patient>/<seizure>/<epoch>.pt``.
    - TensorBoard logs go to ``logs/<model.name>/<patient>/<seizure>``.
    """

    # Fixed normalization statistics used for the encoder-based models
    # (presumably those of the autoencoder's training data; TODO confirm).
    mean = 0.21090616
    std = 65.68308

    @staticmethod
    def _latest_epoch(directory, suffix):
        """Return the largest ``<epoch>`` among ``<epoch><suffix>`` files in ``directory``.

        Returns None if the directory is missing or contains no such file.
        Unrelated files are ignored.
        """
        if not os.path.isdir(directory):
            return None
        epochs = [
            int(name[:-len(suffix)])
            for name in os.listdir(directory)
            if name.endswith(suffix) and name[:-len(suffix)].isdigit()
        ]
        return max(epochs, default=None)

    @staticmethod
    def train_one_epoch(model, train_dataloader, loss_fn, optimizer, device='cpu'):
        """Run one optimization pass over ``train_dataloader``, printing the loss every 50 batches."""
        train_size = len(train_dataloader.dataset)
        for i, (x, y) in enumerate(train_dataloader):
            x, y = x.to(device), y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 50 == 0:
                loss, current = loss.item(), i * len(x)
                print(f"Batch Loss: {loss:>7f}  [{current:>5d}/{train_size:>5d}]")

    @staticmethod
    def train_patient_seizure(p_ind, s_indx, gen_model, train_dataloader, test_dataloader, path_to_models, loss_fn, optimizer_gen, metrics=None, steps=15, device='cpu', batch_size=32, rewrite=False):
        """Train (or resume training) one model on one patient/seizure split.

        Args:
            p_ind: patient index (used in checkpoint/log paths).
            s_indx: seizure index (used in checkpoint/log paths).
            gen_model: zero-argument factory returning a fresh model with a ``name`` attribute.
            train_dataloader: training loader; its dataset must provide ``get_mean_std()``.
            test_dataloader: evaluation loader.
            path_to_models: root directory for checkpoints.
            loss_fn: loss applied to ``(model(x), y)``.
            optimizer_gen: callable ``model -> optimizer``.
            metrics: metric objects (``clean``/``accumulate``/``calculate``/``name``).
            steps: last epoch to train (inclusive).
            device: torch device for the model and batches.
            batch_size: unused; kept for backward compatibility (the loaders fix the batch size).
            rewrite: if True, delete existing checkpoints and logs and start from epoch 1.
                Otherwise resume from the newest ``<epoch>.pt`` checkpoint, if any.
        """
        metrics = [] if metrics is None else metrics
        model = gen_model().to(device)
        optimizer = optimizer_gen(model)
        scheduler = StepLR(optimizer, step_size=300, gamma=0.5)
        mean, std = train_dataloader.dataset.get_mean_std()

        checkpoint_dir = os.path.join(path_to_models, model.name, str(p_ind), str(s_indx))
        logs_path = os.path.join('logs', model.name, str(p_ind), str(s_indx))
        epoch = 1

        if rewrite:
            for stale in (checkpoint_dir, logs_path):
                if os.path.exists(stale):
                    shutil.rmtree(stale)
        else:
            last_epoch = PatientSpecTrainig._latest_epoch(checkpoint_dir, '.pt')
            if last_epoch is not None:
                checkpoint = torch.load(os.path.join(checkpoint_dir, str(last_epoch) + '.pt'), map_location=device)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                # Restore the LR schedule so it continues where it stopped.
                # Checkpoints without this key fall back to a fresh schedule.
                if 'scheduler_state_dict' in checkpoint:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                epoch = checkpoint['epoch'] + 1

        writer = SummaryWriter(logs_path)
        try:
            while epoch <= steps:
                print('-' * 30)
                print('Epoch {}:'.format(epoch))
                model.train(True)
                PatientSpecTrainig.train_one_epoch(model, train_dataloader, loss_fn, optimizer, device)

                model.train(False)
                with torch.no_grad():
                    train_metrics = PatientSpecTrainig.get_metrics(model, train_dataloader, loss_fn, metrics, device)
                    val_metrics = PatientSpecTrainig.get_metrics(model, test_dataloader, loss_fn, metrics, device)

                for title, split, values in (('Train', 'train', train_metrics), ('Test', 'test', val_metrics)):
                    print(title)
                    for key, value in values.items():
                        print('{0}: {1}'.format(key, value))
                        writer.add_scalar("{0}/{1}".format(key, split), value, epoch)
                writer.flush()

                scheduler.step()
                PatientSpecTrainig.save_model(epoch, model, os.path.join(checkpoint_dir, str(epoch) + '.pt'), optimizer, scheduler, mean, std, val_metrics)
                epoch += 1
        finally:
            writer.close()

    @staticmethod
    def train_patient_model(p_ind, model_class, data_path, model_path, learning_parametrs, device='cpu', rewrite=False, encoder_checkpoint='55.tar'):
        """Train ``model_class`` on every seizure split of patient ``p_ind``.

        Args:
            p_ind: patient index.
            model_class: model class; its ``__name__`` selects ``learning_parametrs[...]['steps']``.
            data_path: root of the preprocessed data.
            model_path: root for checkpoints. The pre-trained autoencoder is read
                from ``<model_path>/autoencoder/<encoder_checkpoint>``.
            learning_parametrs: per-class training settings.
            device: torch device.
            rewrite: forwarded to ``train_patient_seizure``.
            encoder_checkpoint: autoencoder checkpoint file used by encoder-based models.
        """
        loss_fn = torch.nn.BCELoss().to(device)
        # Use the function's own arguments, not notebook globals.
        gen_model = PatientSpecTrainig.get_model(model_class, path_to_encoder=os.path.join(model_path, 'autoencoder', encoder_checkpoint))
        uses_encoder = model_class in (EncoderLstmNet, EncoderDropLstmNet)
        for s_indx in range(DataSplitter.get_seizures_number(data_path, p_ind)):
            print("Patient {0} - seizure {1}".format(p_ind, s_indx))
            # Encoder-based models use the fixed class-level statistics; others
            # compute them from each training split.
            if uses_encoder:
                mean, std = PatientSpecTrainig.mean, PatientSpecTrainig.std
            else:
                mean, std = None, None
            train_dataloader, test_dataloader = get_data_loaders_p_s(data_path, p_ind, s_indx, mean, std, 64)
            PatientSpecTrainig.train_patient_seizure(
                p_ind, s_indx, gen_model, train_dataloader, test_dataloader, model_path,
                loss_fn, PatientSpecTrainig.get_optimizer,
                steps=learning_parametrs[model_class.__name__]['steps'],
                device=device, rewrite=rewrite,
                metrics=[AccuracyMetric(), RecallMetric(), SpecifityMetric()],
            )

    @staticmethod
    def get_optimizer(model):
        """RMSprop (lr=1e-3, alpha=0.9) over all of ``model``'s parameters."""
        return torch.optim.RMSprop(model.parameters(), lr=0.001, alpha=0.9)

    @staticmethod
    def get_metrics(model, dataloader, loss_fn, metrics=None, device='cpu'):
        """Evaluate ``model`` on ``dataloader``.

        Returns:
            Dict with ``'loss'`` (mean of per-batch losses, NaN for an empty loader)
            plus one entry per metric, keyed by ``metric.name``. Predictions are
            rounded to 0/1 before being passed to the metrics.
        """
        metrics = [] if metrics is None else metrics
        for m in metrics:
            m.clean()

        total_loss = 0.0
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred_v = model(x)
            # .item() per batch: accumulate a float instead of a tensor, so no
            # autograd graph is kept alive if this runs outside no_grad.
            total_loss += loss_fn(pred_v, y).item()

            pred_round_v = pred_v.round()
            for m in metrics:
                m.accumulate(pred_round_v, y)

        n_batches = len(dataloader)
        res = {'loss': total_loss / n_batches if n_batches else float('nan')}
        for m in metrics:
            res[m.name] = m.calculate()
        return res

    @staticmethod
    def get_model(model_class, path_to_encoder=None):
        """Return a zero-argument factory that builds a fresh ``model_class``.

        For the encoder-based classes, the factory builds an ``EncoderNet``, loads
        ``encoder_state_dict`` from ``path_to_encoder`` into it, and passes it to
        the model's constructor.
        """
        if model_class in (EncoderLstmNet, EncoderDropLstmNet):
            def gen_model():
                encoder = EncoderNet()
                # Load onto the CPU; train_patient_seizure moves the whole model to its device.
                checkpoint_autoencoder = torch.load(path_to_encoder, map_location='cpu')
                encoder.load_state_dict(checkpoint_autoencoder['encoder_state_dict'])
                return model_class(encoder)
            return gen_model

        def gen_model():
            return model_class()
        return gen_model

    @staticmethod
    def save_model(epoch, model, path_to_save, optimizer, sheduler, mean, std, metrics):
        """Write a training checkpoint to ``path_to_save``, creating parent directories."""
        os.makedirs(os.path.dirname(path_to_save), exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': sheduler.state_dict(),
            'mean': mean,
            'std': std,
            'metrics': metrics,
            }, path_to_save)

In [66]:
# Root of the preprocessed data and of the saved checkpoints.
path_to_data = "C:\\data\\CHBData\\preprocess"
# Raw string: '\C' in a normal literal is an invalid escape sequence
# (a warning on recent Python versions).
path_to_models = r'D:\CHBModel'
# Patients to train on.
patients_ids = [1, 3, 7, 9, 10, 20, 21, 22]
# Number of training epochs per model class, looked up by class name.
learning_parametrs = {
    'FcNet': {'steps': 50},
    'ConvNet': {'steps': 50},
    'LstmNet': {'steps': 50},
    'EncoderLstmNet': {'steps': 300},
    'LstmDropNet': {'steps': 100},
    'EncoderDropLstmNet': {'steps': 300},
}

In [None]:
# Train the encoder + weight-dropped LSTM model for every selected patient.
# With rewrite=False, existing checkpoints are resumed rather than deleted.
for patient_id in patients_ids:
    PatientSpecTrainig.train_patient_model(
        patient_id,
        EncoderDropLstmNet,
        path_to_data,
        path_to_models,
        learning_parametrs,
        device='cuda',
        rewrite=False,
    )

## Training the autoencoder

In [176]:
class AETrainig:
    """Training loop for the convolutional autoencoder (``EncoderNet`` + ``DecoderNet``).

    All methods are static; the class only groups them.
    - Checkpoints are written to ``<path_to_models>/autoencoder/<epoch>.tar``.
    - TensorBoard logs go to ``logs/autoencoder``.
    """

    @staticmethod
    def _latest_epoch(directory, suffix):
        """Return the largest ``<epoch>`` among ``<epoch><suffix>`` files in ``directory``.

        Returns None if the directory is missing or contains no such file. A
        first run with ``rewrite=False`` then starts from scratch instead of crashing.
        """
        if not os.path.isdir(directory):
            return None
        epochs = [
            int(name[:-len(suffix)])
            for name in os.listdir(directory)
            if name.endswith(suffix) and name[:-len(suffix)].isdigit()
        ]
        return max(epochs, default=None)

    @staticmethod
    def train_one_epoch(encoder, decoder, dataloader, loss_fn, encoder_optimizer, decoder_optimizer, device='cpu'):
        """One reconstruction-training pass over ``dataloader``, printing the loss every 200 batches."""
        train_size = len(dataloader.dataset)
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)

            encoded = encoder(x)
            decoded = decoder(encoded)

            loss = loss_fn(decoded, x)

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

            if i % 200 == 0:
                loss, current = loss.item(), i * len(x)
                print(f"Batch Loss: {loss:>7f}  [{current:>5d}/{train_size:>5d}]")

    @staticmethod
    def _mean_loss(encoder, decoder, dataloader, loss_fn, device):
        """Mean per-batch reconstruction loss over ``dataloader`` as a float (NaN if empty).

        The caller is responsible for eval mode and ``torch.no_grad()``.
        """
        total = 0.0
        for x, _ in dataloader:
            x = x.to(device)
            total += loss_fn(decoder(encoder(x)), x).item()
        n_batches = len(dataloader)
        return total / n_batches if n_batches else float('nan')

    @staticmethod
    def train(train_dataloader, test_dataloader, path_to_models, device='cpu', steps=15, rewrite=False):
        """Train (or resume training) the autoencoder up to epoch ``steps`` inclusive.

        Args:
            train_dataloader: training loader; its dataset must provide ``get_mean_std()``.
            test_dataloader: evaluation loader.
            path_to_models: root directory; checkpoints go under ``autoencoder/``.
            device: torch device.
            steps: last epoch to train (inclusive).
            rewrite: if True, delete existing checkpoints and logs first. Otherwise
                resume from the newest ``<epoch>.tar`` checkpoint, if any.
        """
        encoder = EncoderNet().to(device)
        decoder = DecoderNet().to(device)
        loss_fn = torch.nn.MSELoss().to(device)

        encoder_optimizer = torch.optim.RMSprop(encoder.parameters(), lr=0.0002, alpha=0.98)
        decoder_optimizer = torch.optim.RMSprop(decoder.parameters(), lr=0.0002, alpha=0.98)

        train_mean, train_std = train_dataloader.dataset.get_mean_std()

        epoch = 1
        logs_path = os.path.join('logs', 'autoencoder')
        path_to_autoencoder = os.path.join(path_to_models, 'autoencoder')

        if rewrite:
            for stale in (path_to_autoencoder, logs_path):
                if os.path.exists(stale):
                    shutil.rmtree(stale)
        else:
            last_epoch = AETrainig._latest_epoch(path_to_autoencoder, '.tar')
            if last_epoch is not None:
                checkpoint_autoencoder = torch.load(os.path.join(path_to_autoencoder, str(last_epoch) + '.tar'), map_location=device)
                encoder.load_state_dict(checkpoint_autoencoder['encoder_state_dict'])
                decoder.load_state_dict(checkpoint_autoencoder['decoder_state_dict'])
                # NOTE(review): optimizer states are saved but not restored here, so
                # RMSprop statistics restart on resume. Confirm this is intended.
                epoch = checkpoint_autoencoder['epoch'] + 1

        writer = SummaryWriter(logs_path)
        try:
            while epoch <= steps:
                print('-' * 30)
                print('Epoch {}:'.format(epoch))

                encoder.train(True)
                decoder.train(True)
                AETrainig.train_one_epoch(encoder, decoder, train_dataloader, loss_fn, encoder_optimizer, decoder_optimizer, device)

                encoder.train(False)
                decoder.train(False)
                with torch.no_grad():
                    train_loss = AETrainig._mean_loss(encoder, decoder, train_dataloader, loss_fn, device)
                    test_loss = AETrainig._mean_loss(encoder, decoder, test_dataloader, loss_fn, device)

                print('Train')
                print('{0}: {1}'.format('Loss ', train_loss))
                writer.add_scalar("{0}/train".format('Loss'), train_loss, epoch)

                print('Test')
                print('{0}: {1}'.format('Loss ', test_loss))
                writer.add_scalar("{0}/test".format('Loss'), test_loss, epoch)

                writer.flush()

                AETrainig.save(os.path.join(path_to_autoencoder, str(epoch) + '.tar'), epoch, encoder, decoder, encoder_optimizer, decoder_optimizer, train_mean, train_std, None)
                epoch += 1
        finally:
            writer.close()

    @staticmethod
    def save(path_to_save, epoch, encoder, decoder, encoder_optimizer, decoder_optimizer, mean, std, metrics):
        """Write an autoencoder checkpoint to ``path_to_save``, creating parent directories."""
        os.makedirs(os.path.dirname(path_to_save), exist_ok=True)
        torch.save({
            'epoch': epoch,
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
            'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            'mean': mean,
            'std': std,
            'metrics': metrics,
            }, path_to_save)

In [44]:
# get_ae_dataloaders takes the batch size as its second argument.
ae_train_dataloader, ae_test_dataloader = get_ae_dataloaders(path_to_data, 64)

In [None]:
# Train the autoencoder for up to 100 epochs, resuming from existing checkpoints.
AETrainig.train(
    ae_train_dataloader,
    ae_test_dataloader,
    path_to_models,
    device='cuda',
    steps=100,
    rewrite=False,
)