In [1]:
import pyedflib
import mne
import numpy as np

import torchvision
import torchaudio
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from IPython.display import clear_output
from tqdm.notebook import tqdm
import matplotlib.ticker as ticker
from os import listdir
import os

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

%matplotlib inline

In [13]:
class conbr_block(nn.Module):
    def __init__(self, in_layer, out_layer, kernel_size, stride, dilation):
        super(conbr_block, self).__init__()
        self.stride = stride
        self.conv1 = nn.Conv1d(in_layer, out_layer, kernel_size=kernel_size, stride=stride, dilation = dilation, 
                               padding = int(np.ceil(dilation * (kernel_size-1) / 2)), bias=True) # for stride=1, else need to calculate and change
        self.bn = nn.BatchNorm1d(out_layer)
        self.relu = nn.ReLU()
    
    def forward(self,x):
        inp_shape = int(np.ceil(x.shape[2] / self.stride))
        x = self.conv1(x)
        x = self.bn(x)
        out = self.relu(x)[:, :, :inp_shape] 
        print("conbr_out", out.shape)
        return out      

class se_block(nn.Module):
    def __init__(self,in_layer, out_layer):
        super(se_block, self).__init__()
        
        self.conv1 = nn.Conv1d(in_layer, out_layer//8, kernel_size=1, padding=0)
        self.conv2 = nn.Conv1d(out_layer//8, in_layer, kernel_size=1, padding=0)
        self.fc = nn.Linear(1,out_layer//8)
        self.fc2 = nn.Linear(out_layer//8,out_layer)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self,x):

        x_se = nn.functional.adaptive_avg_pool1d(x,1)
        x_se = self.conv1(x_se)
        x_se = self.relu(x_se)
        x_se = self.conv2(x_se)
        x_se = self.sigmoid(x_se)
        
        x_out = torch.add(x, x_se)
        return x_out

class re_block(nn.Module):
    def __init__(self, in_layer, out_layer, kernel_size, dilation):
        super(re_block, self).__init__()
        
        self.cbr1 = conbr_block(in_layer,out_layer, kernel_size, 1, dilation)
        self.cbr2 = conbr_block(out_layer,out_layer, kernel_size, 1, dilation)
        self.seblock = se_block(out_layer, out_layer)
    
    def forward(self,x):
        x_re = self.cbr1(x)
        x_re = self.cbr2(x_re)
        x_re = self.seblock(x_re)        
        x_out = torch.add(x, x_re)
        return x_out          

class UNET_1D(nn.Module):
    def __init__(self ,input_dim,layer_n,kernel_size, n_down_layers, depth, n_features=1): # n_features for additional features in some other exps
        super(UNET_1D, self).__init__()
        self.input_dim = input_dim
        self.layer_n = layer_n
        self.kernel_size = kernel_size
        self.n_down_layers = n_down_layers
        self.depth = depth
        
        self.AvgPool1D = nn.ModuleList([nn.AvgPool1d(input_dim, stride=5**i, padding=8) for i in range(1, self.n_down_layers)])
        
        
        self.layer1 = self.down_layer(self.input_dim, self.layer_n, self.kernel_size,1, depth)
        self.layer1_sneo = self.down_layer(self.input_dim, self.layer_n, self.kernel_size,1, self.depth)
        self.layer1_mc = self.down_layer(self.input_dim, self.layer_n, self.kernel_size,1, self.depth)
        
        self.layer2 = self.down_layer(self.layer_n, int(self.layer_n*2), self.kernel_size,5, self.depth)
        
        self.down_layers = nn.ModuleList([self.down_layer(self.layer_n*(1+i)+n_features*self.input_dim, self.layer_n*(2+i), 
                                            self.kernel_size,5, self.depth) for i in range(1, self.n_down_layers)])


        self.cbr_up = nn.ModuleList([self.down_layer(int(self.layer_n*(2*i+1)), int(self.layer_n*i), self.kernel_size, 1, 1) 
                       for i in range(self.n_down_layers, 0, -1)]) #input size is a sizes sum of outs of 2 down layers for current down depth
        self.upsample = nn.Upsample(scale_factor=5, mode='nearest') 
        
        self.outcov = nn.Conv1d(self.layer_n, 2, kernel_size=self.kernel_size, stride=1,
                                padding = int(np.ceil(1 * (self.kernel_size-1) / 2)))
    
        
    def down_layer(self, input_layer, out_layer, kernel, stride, depth): # may be should change name on "depth_layer"
        block = []
        block.append(conbr_block(input_layer, out_layer, kernel, stride, 1))
        for i in range(depth):
            block.append(re_block(out_layer,out_layer,kernel,1))
        return nn.Sequential(*block)
        
        
            
    def forward(self, x):
        inp_shape = x.shape[2]
        
        
        
        #############Encoder#####################

        out_0 = self.layer1(x)
        out_1 = self.layer2(out_0)
        outs = [out_0, out_1]
        for i in range(self.n_down_layers-1):
            
            pool = self.AvgPool1D[i](x)
            print("pool ", pool.shape)
            print("outs[-1]", outs[-1].shape)
            x_down = torch.cat([outs[-1][:, :, :pool.shape[-1]],pool],1)
            outs.append(self.down_layers[i](x_down))

        for i in range(len(outs)):
            print("outs " + str(i), outs[i].shape)


        #############Decoder####################
        up = self.upsample(outs[-1])[:, :, :outs[-2].shape[2]]
        for i in range(self.n_down_layers):
            print("up outs", outs[-2-i].shape)           
            up = torch.cat([up,outs[-2-i]],1)
            up = self.cbr_up[i](up)
            if i + 1 < self.n_down_layers:
                up = self.upsample(up)[:, :, :outs[-3-i].shape[2]]

        print("up ", up.shape)
        out = self.outcov(up)


        return out[:, :, :inp_shape] 

In [16]:
model_check = UNET_1D(20,128,7, 5,depth=3, n_features=1)#.to("cuda")
out = model_check(torch.zeros((8, 20, 4000)))#.to("cuda"))
out.shape

conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 128, 4000])
conbr_out torch.Size([8, 256, 800])
conbr_out torch.Size([8, 256, 800])
conbr_out torch.Size([8, 256, 800])
conbr_out torch.Size([8, 256, 800])
conbr_out torch.Size([8, 256, 800])
conbr_out torch.Size([8, 256, 800])
conbr_out torch.Size([8, 256, 800])
pool  torch.Size([8, 20, 800])
outs[-1] torch.Size([8, 256, 800])
conbr_out torch.Size([8, 384, 160])
conbr_out torch.Size([8, 384, 160])
conbr_out torch.Size([8, 384, 160])
conbr_out torch.Size([8, 384, 160])
conbr_out torch.Size([8, 384, 160])
conbr_out torch.Size([8, 384, 160])
conbr_out torch.Size([8, 384, 160])
pool  torch.Size([8, 20, 160])
outs[-1] torch.Size([8, 384, 160])
conbr_out torch.Size([8, 512, 32])
conbr_out torch.Size([8, 512, 32])
conbr_out torch.Size([8, 512, 32])


torch.Size([8, 2, 4000])

In [13]:
kernels = np.arange(3, 15)
windows = np.arange(3000, 10000, 100)
depths = np.arange(2, 6)
down_layers = np.arange(2, 6)

In [14]:
for kernel in tqdm(kernels):
    for window in windows:
        for depth in depths:
            for down_layer in down_layers:
                model_check = UNET_1D(20,128,kernel, down_layer,depth,1)
                out = model_check(torch.zeros((8, 20, window)))
                assert out.shape == torch.Size([8, 2, window])

  0%|          | 0/12 [00:00<?, ?it/s]

outs 0 torch.Size([8, 128, 3000])
outs 1 torch.Size([8, 256, 600])
outs 2 torch.Size([8, 384, 120])
up outs torch.Size([8, 256, 600])
up outs torch.Size([8, 128, 3000])
up  torch.Size([8, 128, 3000])
outs 0 torch.Size([8, 128, 3000])
outs 1 torch.Size([8, 256, 600])
outs 2 torch.Size([8, 384, 120])
outs 3 torch.Size([8, 512, 24])
up outs torch.Size([8, 384, 120])
up outs torch.Size([8, 256, 600])
up outs torch.Size([8, 128, 3000])
up  torch.Size([8, 128, 3000])
outs 0 torch.Size([8, 128, 3000])
outs 1 torch.Size([8, 256, 600])
outs 2 torch.Size([8, 384, 120])
outs 3 torch.Size([8, 512, 24])
outs 4 torch.Size([8, 640, 5])
up outs torch.Size([8, 512, 24])
up outs torch.Size([8, 384, 120])
up outs torch.Size([8, 256, 600])
up outs torch.Size([8, 128, 3000])
up  torch.Size([8, 128, 3000])
outs 0 torch.Size([8, 128, 3000])
outs 1 torch.Size([8, 256, 600])
outs 2 torch.Size([8, 384, 120])
outs 3 torch.Size([8, 512, 24])
outs 4 torch.Size([8, 640, 5])
outs 5 torch.Size([8, 768, 1])
up outs to

In [4]:
model_check = UNET_1D(20,128,5, 2,5,1)
out = model_check(torch.zeros((8, 20, 4000)))
out.shape

torch.Size([8, 2, 4000])

In [24]:
model_check = UNET_1D(20,128,7, 3, 5,1)
out = model_check(torch.zeros((8, 20, 4000)))
out.shape

up torch.Size([8, 512, 160])
out_2 torch.Size([8, 384, 160])


torch.Size([8, 2, 4000])