In [1]:
import pyedflib
import mne
import numpy as np

import torchvision
import torchaudio
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from IPython.display import clear_output
from tqdm import tqdm
import matplotlib.ticker as ticker
from os import listdir
import os

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

%matplotlib inline

In [2]:
def ReadSignal(file_name):
    """Read one EDF recording into a fixed 20-row array plus filtered annotations.

    Parameters
    ----------
    file_name : str
        Path to the EDF file.

    Returns
    -------
    sigbufs : np.ndarray, shape (20, n_samples)
        Signal matrix. For files with 22 signals, rows 0-18 hold signals 0-18
        and row 19 holds signal 21; for files with 23 signals row 19 holds
        signal 20. For any other count the first ``n`` rows hold signals
        ``0..n-1`` and the remaining rows stay zero.
    new_annot : list of (onset, duration, description)
        Annotations with a positive duration whose description is one of the
        whitelisted labels below.
    time : list of float
        Time stamp of every sample, ``i / sampling_frequency``.
    freq : float
        Sampling frequency reported for signal 0.

    Notes
    -----
    A file with more than 20 signals that is neither 22 nor 23 still raises
    ``IndexError`` because the output buffer only has 20 rows.
    """
    # Whitelisted annotation labels. The non-ASCII entries look like Cyrillic
    # text whose UTF-8 bytes were decoded as Latin-1; they are kept
    # byte-for-byte because they must match the strings returned by the reader.
    kept_descriptions = ["Ð´Ð°Ð±Ð» Ñ\x81Ð¿Ð°Ð¹Ðº", "*", "?", "F7", "F7(blockUserBlock)", 'Ñ\x8dÐ°']

    f = pyedflib.EdfReader(file_name)
    try:
        n = f.signals_in_file
        sigbufs = np.zeros((20, f.getNSamples()[0]))  # or n

        # Map output rows to the signal indices they are read from.
        if n == 22:
            source_channels = list(range(19)) + [21]
        elif n == 23:
            source_channels = list(range(19)) + [20]
        else:
            source_channels = list(range(n))
        for row, channel in enumerate(source_channels):
            sigbufs[row, :] = f.readSignal(channel)

        # Query everything that needs the open handle before closing it.
        freq = f.samplefrequency(0)
        annotations = f.readAnnotations()
    finally:
        # Release the file handle even when a read above fails.
        f.close()

    step = 1 / freq
    time = [step * i for i in range(sigbufs.shape[1])]

    new_annot = [
        (onset, duration, description)
        for onset, duration, description in zip(annotations[0], annotations[1], annotations[2])
        if (duration > 0) and (description in kept_descriptions)
    ]
    return sigbufs, new_annot, time, freq

In [3]:
# EDF recordings to run the model on.
record_names = [
    "DataToLabel/P7.edf",
    "DataToLabel/P13.edf",
    "DataToLabel/P14.edf",
    "DataToLabel/P15.edf",
    "DataToLabel/P16.edf",
    "DataToLabel/P17.edf",
    "DataToLabel/P18.edf",
    "DataToLabel/P19.edf",
]

# Parallel lists, one entry per recording, in the order of record_names.
records, annots, times, freqs = [], [], [], []
for file_name in record_names:
    sigbufs, new_annot, time, freq = ReadSignal(file_name)
    records += [sigbufs]
    annots += [new_annot]
    times += [time]
    freqs += [freq]
    

In [4]:
def NormalizeAndClip(data):
    """Z-score every channel of every recording and clip to [-10, 10], in place.

    Parameters
    ----------
    data : list of np.ndarray
        Each entry is a 2-D array whose rows are normalised independently
        (mean and standard deviation are taken along axis 1). Every entry of
        the list is replaced by its normalised copy.

    Returns
    -------
    None

    Notes
    -----
    Rows with zero standard deviation (e.g. the all-zero rows left by
    ``ReadSignal`` for files with fewer than 20 signals) are mapped to zeros
    instead of the NaNs that a 0/0 division would produce.
    """
    for i in tqdm(range(len(data))):
        signal = data[i]
        means = signal.mean(axis=1, keepdims=True)
        stds = signal.std(axis=1, keepdims=True)
        # A constant row has std == 0; dividing by 1 leaves its centred
        # values (all zeros) untouched and avoids NaN.
        safe_stds = np.where(stds == 0, 1.0, stds)
        data[i] = np.clip((signal - means) / safe_stds, a_min=-10, a_max=10)

In [5]:
# Standardise every recording in place: each row is z-scored and clipped to
# [-10, 10]; the entries of `records` are replaced by the normalised arrays.
NormalizeAndClip(records)

100%|████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 26.70it/s]


In [7]:
# One float32 tensor per recording, in the same order as `records`.
test_data = [torch.FloatTensor(record) for record in records]

In [8]:
class conbr_block(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU building block.

    Parameters
    ----------
    in_layer : int
        Number of input channels.
    out_layer : int
        Number of output channels.
    kernel_size, stride, dilation : int
        Forwarded unchanged to ``nn.Conv1d``.
    padding : int, optional
        Zero padding added to both ends of the sequence. Defaults to 3, the
        value that used to be hard-coded; with ``stride=1`` it keeps the
        sequence length unchanged only for ``kernel_size=7, dilation=1``.
        Pass ``dilation * (kernel_size - 1) // 2`` for other odd kernels.
    """
    def __init__(self, in_layer, out_layer, kernel_size, stride, dilation, padding=3):
        super(conbr_block, self).__init__()

        self.conv1 = nn.Conv1d(in_layer, out_layer, kernel_size=kernel_size, stride=stride,
                               dilation=dilation, padding=padding, bias=True)
        self.bn = nn.BatchNorm1d(out_layer)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU to ``x``."""
        x = self.conv1(x)
        x = self.bn(x)
        out = self.relu(x)

        return out

class se_block(nn.Module):
    """Squeeze-and-excitation style gate whose output is added to the input.

    The input is averaged over its last dimension, passed through two 1x1
    convolutions (``in_layer -> out_layer // 8 -> in_layer``) with ReLU and
    sigmoid, and the resulting per-channel value is *added* to ``x``.

    ``fc`` and ``fc2`` are not used in ``forward``; they are kept because
    their parameters are still part of ``state_dict()``.
    """
    def __init__(self,in_layer, out_layer):
        super().__init__()

        squeezed = out_layer // 8
        # Registration order is kept so state_dict keys stay in the same order.
        self.conv1 = nn.Conv1d(in_layer, squeezed, kernel_size=1, padding=0)
        self.conv2 = nn.Conv1d(squeezed, in_layer, kernel_size=1, padding=0)
        self.fc = nn.Linear(1, squeezed)
        self.fc2 = nn.Linear(squeezed, out_layer)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self,x):
        """Return ``x`` plus the broadcast per-channel gate."""
        gate = nn.functional.adaptive_avg_pool1d(x, 1)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return x + gate

class re_block(nn.Module):
    """Residual block: two ``conbr_block`` layers, an ``se_block``, then a skip add.

    Both convolutional layers use stride 1. The input is added to the
    block's output, so ``in_layer`` must equal ``out_layer`` and the
    convolutions must preserve the sequence length.
    """
    def __init__(self, in_layer, out_layer, kernel_size, dilation):
        super().__init__()

        self.cbr1 = conbr_block(in_layer, out_layer, kernel_size, 1, dilation)
        self.cbr2 = conbr_block(out_layer, out_layer, kernel_size, 1, dilation)
        self.seblock = se_block(out_layer, out_layer)

    def forward(self,x):
        """Return ``x + seblock(cbr2(cbr1(x)))``."""
        branch = self.seblock(self.cbr2(self.cbr1(x)))
        return x + branch

class UNET_1D(nn.Module):
    """1-D U-Net with three strided encoder stages and a two-channel output.

    Parameters
    ----------
    input_dim : int
        Number of input channels. It is also used as the kernel size of the
        average-pooling branches.
    layer_n : int
        Base channel width; encoder stages use 1x, 2x, 3x and 4x this value.
    kernel_size : int
        Kernel size of every convolution.
    depth : int
        Stored on the instance but not used: every stage is built with two
        residual blocks. It is left that way so existing checkpoints load.

    Notes
    -----
    * ``layer5`` and ``AvgPool1D3`` are constructed but never used in
      ``forward``. They are kept so the module tree (and therefore the
      ``state_dict`` keys) matches previously saved weights.
    * Each of layer2-layer4 downsamples by a stride of 5 and the decoder
      upsamples by 5 three times, so the input length should be a multiple of
      125 (e.g. 4000) for the skip connections to line up.
    """
    def __init__(self ,input_dim,layer_n,kernel_size,depth):
        super().__init__()
        self.input_dim = input_dim
        self.layer_n = layer_n
        self.kernel_size = kernel_size
        self.depth = depth

        # Down-sampled copies of the raw input that are concatenated to the
        # encoder features at matching resolutions.
        self.AvgPool1D1 = nn.AvgPool1d(input_dim, stride=5, padding=8)
        self.AvgPool1D2 = nn.AvgPool1d(input_dim, stride=25, padding=8)
        self.AvgPool1D3 = nn.AvgPool1d(input_dim, stride=125, padding=8)  # unused in forward

        self.layer1 = self.down_layer(self.input_dim, self.layer_n, self.kernel_size,1, 2)
        self.layer2 = self.down_layer(self.layer_n, int(self.layer_n*2), self.kernel_size,5, 2)
        self.layer3 = self.down_layer(int(self.layer_n*2)+int(self.input_dim), int(self.layer_n*3), self.kernel_size,5, 2)
        self.layer4 = self.down_layer(int(self.layer_n*3)+int(self.input_dim), int(self.layer_n*4), self.kernel_size,5, 2)
        # Unused in forward; kept for checkpoint compatibility.
        self.layer5 = self.down_layer(int(self.layer_n*4)+int(self.input_dim), int(self.layer_n*5), self.kernel_size,4, 2)

        self.cbr_up1 = conbr_block(int(self.layer_n*7), int(self.layer_n*3), self.kernel_size, 1, 1)
        self.cbr_up2 = conbr_block(int(self.layer_n*5), int(self.layer_n*2), self.kernel_size, 1, 1)
        self.cbr_up3 = conbr_block(int(self.layer_n*3), self.layer_n, self.kernel_size, 1, 1)
        self.upsample = nn.Upsample(scale_factor=5, mode='nearest')
        self.upsample1 = nn.Upsample(scale_factor=5, mode='nearest') #for 4000 it is 5 and for 100 is 4

        self.outcov = nn.Conv1d(self.layer_n, 2, kernel_size=self.kernel_size, stride=1,padding = 3)

    def down_layer(self, input_layer, out_layer, kernel, stride, depth):
        """Build one encoder stage: a strided ``conbr_block`` followed by ``depth`` ``re_block`` layers."""
        block = [conbr_block(input_layer, out_layer, kernel, stride, 1)]
        for _ in range(depth):
            block.append(re_block(out_layer, out_layer, kernel, 1))
        return nn.Sequential(*block)

    def forward(self, x):
        """Map ``x`` of shape (batch, input_dim, length) to 2-channel outputs of shape (batch, 2, length)."""
        # Pooled raw input at 1/5 and 1/25 resolution. (The 1/125 branch was
        # computed here before but never consumed, so it is no longer evaluated.)
        pool_x1 = self.AvgPool1D1(x)
        pool_x2 = self.AvgPool1D2(x)

        ############# Encoder #####################
        out_0 = self.layer1(x)
        out_1 = self.layer2(out_0)
        out_2 = self.layer3(torch.cat([out_1, pool_x1], 1))
        bottleneck = self.layer4(torch.cat([out_2, pool_x2], 1))

        ############# Decoder ####################
        up = self.upsample1(bottleneck)
        up = self.cbr_up1(torch.cat([up, out_2], 1))

        up = self.upsample(up)
        up = self.cbr_up2(torch.cat([up, out_1], 1))

        up = self.upsample(up)
        up = self.cbr_up3(torch.cat([up, out_0], 1))

        return self.outcov(up)

In [9]:
def CollectingPreds(model, test_data, receptive_field=None, overlap=None):
    """Run ``model`` window by window over every record and collect class predictions.

    Parameters
    ----------
    model : torch.nn.Module
        Network returning per-sample class scores of shape
        (batch, n_classes, length). It is switched to eval mode and moved to
        the CPU.
    test_data : sequence of torch.Tensor
        One 2-D tensor (channels, samples) per record.
    receptive_field : int, optional
        Window length in samples. Defaults to the module-level
        ``RECEPTIVE_FIELD``.
    overlap : int, optional
        Extra samples taken on each side of a window. Defaults to the
        module-level ``OVERLAP``.

    Returns
    -------
    list of np.ndarray
        One flat integer array of predicted class indices per record. Samples
        after the last complete window are not predicted; a record shorter
        than one window yields an empty array.

    Notes
    -----
    NOTE(review): with ``overlap > 0`` the model output for each window is
    concatenated without trimming the overlap margins — confirm that this is
    the intended alignment before using a non-zero overlap.
    """
    if receptive_field is None:
        receptive_field = RECEPTIVE_FIELD
    if overlap is None:
        overlap = OVERLAP

    model.eval()
    model.cpu()
    all_preds = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for record in test_data:
            n_samples = record.size()[1]
            # "+ 1" so that a window ending exactly at the last sample is kept.
            window_starts = range(overlap, n_samples - receptive_field - overlap + 1, receptive_field)
            record_preds = []
            for idx in tqdm(window_starts):
                window = record[:, idx - overlap:idx + receptive_field + overlap][None, ...]
                probs = torch.softmax(model(window), dim=1)
                preds = np.argmax(probs.cpu().numpy(), axis=1)
                record_preds.append(preds.ravel())
            if record_preds:
                all_preds.append(np.concatenate(record_preds))
            else:
                all_preds.append(np.zeros(0, dtype=np.intp))
    return all_preds

In [10]:
def MergeClose(predictions, threshold):
    """Fill short gaps between events in a 0/1 prediction sequence, in place.

    When a run of ones ends and any non-zero value occurs within the next
    ``threshold`` samples, the zeros up to that value are set to 1, joining
    the two events. Zeros before the first event and after the last one are
    left untouched. Returns None.
    """
    total = len(predictions)
    pos = 0
    inside_event = False
    while pos < total:
        # Walk to the end of the current run of ones, if any.
        while pos < total and predictions[pos] == 1:
            inside_event = True
            pos += 1
        if pos < total and inside_event:
            if not np.any(predictions[pos:pos + threshold]):
                # Nothing close enough ahead: the event is over.
                inside_event = False
            else:
                # Bridge the gap up to the next event.
                while pos < total and predictions[pos] == 0:
                    predictions[pos] = 1
                    pos += 1
        pos += 1

def DeleteShortEvents(predictions, threshold):
    """Zero out, in place, every run of ones shorter than ``threshold`` samples.

    ``predictions`` must support slice assignment of a scalar (e.g. a NumPy
    array). Returns None.
    """
    n_samples = len(predictions)
    cursor = 0
    while cursor < n_samples:
        run_start = cursor
        # Advance over the run of ones starting here (possibly empty).
        while cursor < n_samples and predictions[cursor] == 1:
            cursor += 1
        run_len = cursor - run_start
        if run_len < threshold:
            predictions[run_start:cursor] = 0
        cursor += 1
def PostProcessing(predictions, threshold1, threshold2=None):
    """Merge close events, then drop short ones, modifying ``predictions`` in place.

    ``threshold1`` is the gap threshold passed to ``MergeClose``.
    ``threshold2`` is the minimum event length passed to
    ``DeleteShortEvents``; when omitted, ``threshold1`` is used for both.
    """
    min_event_len = threshold1 if threshold2 is None else threshold2
    MergeClose(predictions, threshold1)
    DeleteShortEvents(predictions, min_event_len)

In [20]:
# Windowing used by CollectingPreds: non-overlapping 4000-sample windows.
OVERLAP = 0
RECEPTIVE_FIELD = 4000

# UNET_1D(input_dim=20, layer_n=128, kernel_size=7, depth=3)
model = UNET_1D(20,128,7,3)
# Inference runs on the CPU (CollectingPreds calls model.cpu()), so map the
# checkpoint tensors to the CPU as well; without map_location, loading a
# checkpoint saved from a CUDA device fails on a machine without a GPU.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(
    torch.load("./CrossValidationResults/13RecordsNormalized/Split2/Unet1d", map_location="cpu")
)
all_preds = CollectingPreds(model, test_data)

# Gap / minimum-event threshold in samples. Kept low because of the lower
# sampling rate; for 500 Hz data it should be raised to about 30.
threshold = 20

# Post-process every record's predictions in place.
for j in range(len(all_preds)):
    PostProcessing(all_preds[j], threshold)

100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.11it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.46it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.25it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.38it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.64it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.75it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.87it/s]
100%|██████████████████████████████████████████████████████| 30/30 [00:01<00:00, 16.76it/s]


In [21]:
def CreateNewAnnotation(time_start, labels, freq):
    """Convert a per-sample 0/1 label sequence into annotation arrays.

    Parameters
    ----------
    time_start : float
        Time (in seconds) of the first label.
    labels : sequence
        Per-sample labels; every maximal run of values equal to 1 becomes one
        "ModelPrediction" annotation.
    freq : float
        Sampling frequency in Hz; one sample lasts ``1 / freq`` seconds.

    Returns
    -------
    (label_starts, label_lens, desc) : tuple of np.ndarray
        Onsets, durations and descriptions. The first entry is a
        "StartPredictionTime" marker at ``time_start`` and the last an
        "EndPredictionTime" marker at ``time_start + len(labels) / freq``;
        both markers have duration -1.
    """
    sample_period = 1 / freq
    n_labels = len(labels)

    label_starts = [time_start]
    label_lens = [-1]
    desc = ["StartPredictionTime"]

    i = 0
    while i < n_labels:
        if labels[i] == 1:
            run_start = i
            # Bounds check is required: an event may run to the very last
            # sample, in which case there is no terminating non-1 label.
            while i < n_labels and labels[i] == 1:
                i += 1
            desc.append("ModelPrediction")
            label_starts.append(time_start + run_start * sample_period)
            label_lens.append((i - run_start) * sample_period)
        i += 1

    # The end marker always sits right after the last label, independent of
    # where the scan index stopped.
    label_starts.append(time_start + n_labels * sample_period)
    label_lens.append(-1)
    desc.append("EndPredictionTime")

    return np.array(label_starts), np.array(label_lens), np.array(desc)

In [22]:
def WriteEDF(predictions, freq):
    """Export every recording in ``record_names`` with model predictions added as annotations.

    For each file the predicted events (from ``CreateNewAnnotation``) are
    placed in front of the annotations already stored in the file, and the
    result is written to ``DataToLabel/Preds_<original file name>``,
    overwriting any existing file.

    Parameters
    ----------
    predictions : sequence
        Per-record label sequences, in the order of ``record_names``.
    freq : float or sequence of float
        Sampling frequency in Hz. A scalar is applied to every record; a
        sequence supplies one value per record.
    """
    per_record_freq = np.ndim(freq) > 0

    for idx, record_name in enumerate(record_names):
        record_freq = freq[idx] if per_record_freq else freq
        time_start = 0

        onsets, durations, descriptions = CreateNewAnnotation(time_start, predictions[idx], record_freq)
        # The start/end markers carry a duration of -1; annotations need a
        # non-negative duration.
        durations = np.clip(durations, a_min=0, a_max=None)

        data = mne.io.read_raw_edf(record_name)
        try:
            existing = data.annotations
            # Concatenate field by field so onsets and durations stay numeric
            # and a file without existing annotations is handled as well.
            annotations = mne.Annotations(
                np.concatenate([onsets, existing.onset]),
                np.concatenate([durations, existing.duration]),
                np.concatenate([descriptions, existing.description]),
            )
            data.set_annotations(annotations)

            data.export("DataToLabel/Preds_" + record_name.split("/")[1], overwrite=True)
        finally:
            # Release the file even when the export fails.
            data.close()
        

In [23]:
# NOTE(review): only the first recording's sampling frequency (freqs[0]) is
# passed, so it is used for every file - confirm that all recordings share
# the same sampling frequency.
WriteEDF(all_preds, freqs[0])

Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P7.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.
Reading 0 ... 121588  =      0.000 ...   609.455 secs...
Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P13.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.
Reading 0 ... 121389  =      0.000 ...   608.458 secs...


  data = mne.io.read_raw_edf(record_names[i])
  data = mne.io.read_raw_edf(record_names[i])
  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
  data = mne.io.read_raw_edf(record_names[i])
  data = mne.io.read_raw_edf(record_names[i])


Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P14.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.
Reading 0 ... 120593  =      0.000 ...   604.468 secs...
Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P15.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.


  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
  data = mne.io.read_raw_edf(record_names[i])
  data = mne.io.read_raw_edf(record_names[i])
  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)


Reading 0 ... 120991  =      0.000 ...   606.463 secs...
Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P16.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.
Reading 0 ... 120991  =      0.000 ...   606.463 secs...
Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P17.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.
Reading 0 ... 121986  =      0.000 ...   611.450 secs...


  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
  data = mne.io.read_raw_edf(record_names[i])
  data = mne.io.read_raw_edf(record_names[i])
  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
  data = mne.io.read_raw_edf(record_names[i])
  data = mne.io.read_raw_edf(record_names[i])


Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P18.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Overwriting existing file.
Reading 0 ... 120593  =      0.000 ...   604.468 secs...
Extracting EDF parameters from /mnt/hdd-home/gromov_n/extreme-events/EEG seizure/EpiActivity/DataToLabel/P19.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
  data = mne.io.read_raw_edf(record_names[i])


Overwriting existing file.
Reading 0 ... 120593  =      0.000 ...   604.468 secs...


  data.export("DataToLabel/Preds_" + record_names[i].split("/")[1], overwrite=True)
