In [1]:
# For some reason, file descriptors (FDs) do not get released during a run.
# Workaround: raise this process's soft limit on open files.
import resource

rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
soft_limit, hard_limit = rlimit

# Requesting a soft limit above the hard limit raises ValueError, so clamp to it.
desired_soft_limit = 2048
if hard_limit != resource.RLIM_INFINITY:
    desired_soft_limit = min(desired_soft_limit, hard_limit)

# Only ever raise the limit: an unlimited or already larger soft limit is kept.
if soft_limit != resource.RLIM_INFINITY and soft_limit < desired_soft_limit:
    resource.setrlimit(resource.RLIMIT_NOFILE, (desired_soft_limit, hard_limit))

In [2]:
import math
import os
import re
import random
import multiprocessing as mp
import numpy as np
from scipy.io import loadmat
import torch
import pytorch_lightning as pl
from copy import deepcopy
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.optim.swa_utils import AveragedModel, SWALR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Batch sizes used by the training / validation DataLoaders.
train_B = 16
eval_B = 1

# Model hyper-parameters passed to CustomLightningModule.
p_dropout = 0.2
no_temporal_kernels = 15
no_spatial_kernels = 2

# When True, segments are centred/scaled with the training median and IQR.
SCALE = True

# Channel (electrode) names. CustomDataset maps these names to row indices for
# the left/right flip, so this order presumably matches the row order of every
# loaded segment -- TODO confirm against the data files.
order = [
    'FP1', 'F7', 'T3', 'T5', 'O1',
    'F3', 'C3', 'P3',
    'FZ', 'CZ', 'PZ',
    'FP2', 'F8', 'T4', 'T6', 'O2',
    'F4', 'C4', 'P4'
]

# Patient identifiers split into five groups; one group is held out for
# evaluation and the other four are used for training.
g1 = ['PAT31', 'MC', 'FEsg', 'DHut', 'NKra', 'PAT2']
g2 = ['RC', 'MBra', 'ESow', 'PAT1', 'EG', 'FigSa', 'PAT28']
g3 = ['PAT47', 'RA', 'PAT51', 'MPi', 'GNA', 'LM', 'DG']
g4 = ['PAT33', 'PB', 'PAT22', 'MAXJ', 'PAT19', 'LP', 'HeMod', 'PAT50', 'PJuly']
g5 = ['PAT48', 'KS', 'PAT3', 'PAT8', 'LRio', 'HB', 'ALo', 'JAlv', 'PAT25', 'AR']

# Folder holding the .mat segment files (relative to the working directory)
# and the listing of its entries, which prep_patient_data filters by name.
data_folder = 'BH average 100-sps (0.75-38 Hz)'
files_names = os.listdir(data_folder)

In [4]:
# Name of the patient group held out for evaluation in this run.
eval_group = 'g2'

# Group name -> patients of that group, in the order g1..g5.
patient_groups = {'g1': g1, 'g2': g2, 'g3': g3, 'g4': g4, 'g5': g5}

if eval_group not in patient_groups:
    raise NameError('Unknown Eval Group !')

# The held-out group is evaluated; all remaining groups (kept in g1..g5 order)
# are concatenated into the training patient list.
eval_patients = patient_groups[eval_group]
train_patients = [
    patient
    for group_name, group_patients in patient_groups.items()
    if group_name != eval_group
    for patient in group_patients
]

In [5]:
def prep_patient_data(patient_name, folder=None, file_names=None):
    """Load every segment file of one patient together with its label.

    A file belongs to the patient when its name starts with ``patient_name``
    followed by a non-digit character (so 'PAT2' does not pick up 'PAT22...').
    The label is 1.0 when the file name contains 'sp' and 0.0 when it
    contains 'ns'.

    Args:
        patient_name: Patient identifier; treated as literal text.
        folder: Directory holding the .mat files. Defaults to the notebook's
            ``data_folder``.
        file_names: Candidate file names inside ``folder``. Defaults to a fresh
            ``os.listdir(folder)`` so the function keeps working after the
            module-level file listing has been deleted.

    Returns:
        (segments, labels): two parallel lists of float tensors. Each segment
        is the 'segment' array of one .mat file; each label is a 0-dim tensor.

    Raises:
        NameError: a matching file name contains neither 'sp' nor 'ns'.
    """
    if folder is None:
        folder = data_folder
    if file_names is None:
        file_names = os.listdir(folder)

    # re.escape keeps regex metacharacters in a patient name literal; the raw
    # string avoids the invalid '\D' escape of a normal string literal.
    pattern = re.compile(re.escape(patient_name) + r'\D[0-9.a-z_A-Z]*')

    patient_segments = []
    patient_labels = []

    for file_name in file_names:
        if not pattern.match(file_name):
            continue

        if 'sp' in file_name:
            label_value = 1.0
        elif 'ns' in file_name:
            label_value = 0.0
        else:
            raise NameError('Unknown label!')

        mat_content = loadmat(os.path.join(folder, file_name))
        patient_segments.append(torch.tensor(mat_content['segment']).float())
        patient_labels.append(torch.tensor(label_value, dtype=torch.float))

    return patient_segments, patient_labels

In [None]:
# Collect the segments and labels of every training patient, patient by patient.
train_segments, train_labels = [], []

for p in train_patients:
    segments, labels = prep_patient_data(p)
    train_segments.extend(segments)
    train_labels.extend(labels)


# Same for the held-out evaluation patients.
eval_segments, eval_labels = [], []

for p in eval_patients:
    segments, labels = prep_patient_data(p)
    eval_segments.extend(segments)
    eval_labels.extend(labels)

In [6]:
def median_IQR_func(group_segments):
    """Compute the median and inter-quartile range over all values of all segments.

    Args:
        group_segments: non-empty sequence of tensors (any shapes); every
            element of every tensor contributes one value.

    Returns:
        (median, IQR) as Python floats, where IQR = Q75 - Q25.

    Raises:
        ValueError: ``group_segments`` is empty.
    """
    if len(group_segments) == 0:
        raise ValueError('median_IQR_func needs at least one segment.')

    # Flatten and concatenate once instead of copying element by element through
    # Python floats; quantiles do not depend on the order of the values.
    values = torch.cat(
        [torch.as_tensor(segment).flatten().float() for segment in group_segments]
    )

    levels = torch.tensor([0.25, 0.5, 0.75], dtype=values.dtype, device=values.device)
    quantiles = values.quantile(levels)

    median = quantiles[1].item()
    IQR = quantiles[2].item() - quantiles[0].item()

    return median, IQR


In [7]:
def scaler_func(group_segments, median, IQR):
    """Replace every segment in the list by ``(segment - median) / IQR``.

    The list object itself is updated (each slot is overwritten with a new
    tensor) and the same list is returned.
    """
    for position, segment in enumerate(group_segments):
        centred = segment - median
        group_segments[position] = centred / IQR

    return group_segments


In [None]:
# Scaling statistics come from the training segments only and are then applied
# to both splits, so nothing is estimated from the evaluation data.
# NOTE: scaler_func overwrites the list entries, so re-running this cell on the
# same lists scales the data a second time.
if SCALE:
    median, IQR = median_IQR_func(train_segments)
    train_segments = scaler_func(train_segments, median, IQR)
    eval_segments = scaler_func(eval_segments, median, IQR)

In [8]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset of (segment, label) pairs with an optional left/right flip.

    With probability ``flip_prob`` a sample is returned as a copy in which the
    rows of mirrored electrode pairs (FP1/FP2, F7/F8, ...) are exchanged.
    ``order`` lists the electrode names in row order and is only consulted
    when ``flip_prob`` is positive.
    """

    def __init__(self, segments, labels, flip_prob, order):
        self.segments = segments
        self.labels = labels
        self.flip_prob = flip_prob

        # The swap table is only needed (and only built) when flipping can occur.
        if flip_prob > 0:
            self.idx_swaps = self.idx_swap_func(order)

    @staticmethod
    def idx_swap_func(order):
        """Return the (left_row, right_row) index pairs of the mirrored electrodes."""
        row_of = {}
        for row, name in enumerate(order):
            row_of[name] = row

        mirrored_names = [
            ('FP1', 'FP2'), ('F7', 'F8'), ('F3', 'F4'), ('C3', 'C4'),
            ('P3', 'P4'), ('T3', 'T4'), ('T5', 'T6'), ('O1', 'O2'),
        ]

        index_pairs = []
        for left_name, right_name in mirrored_names:
            index_pairs.append((row_of[left_name], row_of[right_name]))
        return index_pairs

    def flip_lr(self, segment):
        """Exchange the left and right electrode rows of ``segment`` in place."""
        for left_row, right_row in self.idx_swaps:
            # The right-hand side is materialised as a copy before the write.
            segment[[left_row, right_row]] = segment[[right_row, left_row]]

        return segment

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        flip_now = random.random() < self.flip_prob
        seg = self.segments[index]

        if flip_now:
            # Flip a clone so the stored segment stays untouched.
            seg = self.flip_lr(torch.clone(seg))

        return seg, self.labels[index]

In [None]:
# Training samples are flipped left/right with probability 0.5;
# evaluation samples are never flipped (flip_prob = 0).
train_dataset = CustomDataset(train_segments, train_labels, 0.5, order)
eval_dataset = CustomDataset(eval_segments, eval_labels, 0, order)


In [None]:
# Drop the names that are no longer needed; the datasets built above keep their
# own references to the segment/label lists.
# `files_names` is deliberately NOT deleted: prep_patient_data reads it as a
# global and is called again by the inference cells further down.
del train_segments, train_labels, train_patients
del eval_segments, eval_labels, eval_patients

In [9]:
class SpatioTemporalCNN(torch.nn.Module):
    """Temporal convolution followed by one spatial convolution per temporal kernel.

    Input:  (B, 1, T, 19)
    Output: (B, N, no_temporal_kernels * no_spatial_kernels) with N = T - 49
            (e.g. T = 300 gives N = 251). In the last axis the features of
            temporal kernel c occupy positions
            [c * no_spatial_kernels, (c + 1) * no_spatial_kernels).
    """

    def __init__(self, no_temporal_kernels=15, no_spatial_kernels=2):
        super().__init__()

        # (B, 1, 300, 19) ====> (B, no_temporal_kernels, 251, 19)
        self.temporal_cnn = torch.nn.Conv2d(
            in_channels=1, out_channels=no_temporal_kernels, kernel_size=(50, 1)
        )

        # One independent spatial convolution (across all 19 electrodes) per
        # temporal kernel.
        self.module_list = torch.nn.ModuleList([
            torch.nn.Conv2d(in_channels=1, out_channels=no_spatial_kernels, kernel_size=(1, 19))
            for _ in range(no_temporal_kernels)
        ])

    def forward(self, x):
        # x: (B, 1, T, 19)
        x = self.temporal_cnn(x)                                  # (B, K, N, 19)

        # Each temporal feature map goes through its own spatial convolution.
        # Concatenating the results (instead of writing into a pre-allocated
        # torch.empty buffer) keeps the dtype of the convolution outputs.
        per_kernel_outputs = [
            spatial_conv(x[:, c:c + 1, :, :])                     # (B, S, N, 1)
            for c, spatial_conv in enumerate(self.module_list)
        ]
        y = torch.cat(per_kernel_outputs, dim=-1)                 # (B, S, N, K)

        # (B, S, N, K) ====> (B, N, K, S) ====> (B, N, K * S)
        return y.permute(0, 2, 3, 1).flatten(-2, -1)
    

In [10]:
class MaskedMultiHeadSelfAttention(torch.nn.Module):
    """Multi-head self-attention with optional masking.

    The attention logits are plain dot products q.k (no 1/sqrt(d_qk) scaling).

    Args:
        no_input_features: feature size C of the input (and of the output).
        H: number of heads.
        d_qk: query/key size per head.
        d_v: value size per head.
        dim: axis to attend across, -2 for inputs shaped (B, T or E, C) or
            -1 for inputs shaped (B, C, T or E).
        mask: default mask; one of the strings "all_but_self", "no_future",
            "no_present_no_future", or a tensor broadcastable to
            (B, H, E, E) whose zero/False entries are masked out.
        bias_att, bias_out: whether the two linear layers use a bias.

    Raises:
        NameError: ``dim`` is not -1 or -2, or an unknown mask string is used.
    """

    def __init__(self, no_input_features, H, d_qk, d_v, dim, mask=None, bias_att=False, bias_out=False):
        super().__init__()
        self.H = H
        self.d_qk = d_qk
        self.d_v = d_v

        # dim: dimension across which to attend
        if dim not in {-1, -2}:
            raise NameError('Error: Enforce your attention dimension to either -1 or -2 !')
        self.dim = dim

        # Most probably mask will be a string ("all_but_self", "no_future", or "no_present_no_future").
        # But it can be an actual mask though of shape (E, E) or (T, T)
        self.mask = mask

        # A single projection produces q, k and v for all heads at once.
        self.att_lin = torch.nn.Linear(in_features=no_input_features, out_features=(2 * H * d_qk) + (H * d_v), bias=bias_att)
        self.out_lin = torch.nn.Linear(in_features=H*d_v, out_features=no_input_features, bias=bias_out)

    @staticmethod
    def qkv_func(q, k, v, mask=None):
        """Attend with q, k, v and return the heads concatenated per position.

        q, k: (B, H, E or T, d_qk); v: (B, H, E or T, d_v).
        Returns a tensor of shape (B, E or T, H * d_v).
        """
        B, H, E, _ = q.shape
        d_v = v.shape[-1]

        # logits: (B, H, E or T, E or T)
        logits = torch.matmul(q, k.transpose(-2, -1))

        if mask is not None:
            if isinstance(mask, str):
                # `keep` is True where attention is allowed.
                if mask == 'all_but_self':
                    keep = torch.eye(E, device=logits.device) == 0
                elif mask == 'no_future':
                    keep = torch.ones((E, E), device=logits.device).tril() == 1
                elif mask == 'no_present_no_future':
                    # NOTE(review): the first row keeps nothing, so all of its
                    # logits are equal after masking and the softmax is uniform
                    # over every position -- confirm this is intended.
                    keep = torch.ones((E, E), device=logits.device).tril(-1) == 1
                else:
                    raise NameError(f'Error: Available mask strings are ("all_but_self", "no_future", "no_present_no_future") but "{mask}" was given!')
            else:
                # Tensor mask: zero / False entries are masked out.
                keep = mask.to(device=logits.device) != 0

            # The mask must be broadcastable to (B, H, E, E) or (B, H, T, T).
            keep = keep.expand(B, H, E, E)
            # Out-of-place fill with the most negative value of the logits' dtype.
            logits = logits.masked_fill(~keep, torch.finfo(logits.dtype).min)

        s_m = logits.softmax(dim=-1)     # s_m: (B, H, E, E) or (B, H, T, T)

        # (B, H, E, E) * (B, H, E, d_v) = (B, H, E, d_v) ===transpose===> (B, E, H, d_v) ===reshape===> (B, E, H * d_v)
        return torch.matmul(s_m, v).transpose(-2, -3).reshape((B, E, H * d_v))

    def forward(self, x, mask=None):
        """Return the attention output, shaped like ``x``; ``x`` is not modified."""
        if self.dim == -1:
            # x: (B, C, T or E). Use a transposed view rather than transposing
            # the caller's tensor in place.
            x = x.transpose(-1, -2)

        # x: (B, T or E, C=no_input_features)
        B, T, _ = x.shape

        # (B, T or E, C) ====self.att_lin====> ( B, T or E, (H * d_qk) + (H * d_qk) + (H * d_v) )
        q, k, v = self.att_lin(x).split(
            split_size=[self.H * self.d_qk, self.H * self.d_qk, self.H * self.d_v], dim=-1
        )

        q = q.reshape((B, T, self.H, self.d_qk)).transpose(-2, -3)     # q: (B, self.H, T or E, self.d_qk)
        k = k.reshape((B, T, self.H, self.d_qk)).transpose(-2, -3)     # k: (B, self.H, T or E, self.d_qk)
        v = v.reshape((B, T, self.H, self.d_v)).transpose(-2, -3)      # v: (B, self.H, T or E, self.d_v)

        # A mask given to forward() takes priority over the one given at construction.
        if mask is None:
            mask = self.mask

        y = self.out_lin(self.qkv_func(q, k, v, mask=mask))           # (B, T or E, C=no_input_features)

        if self.dim == -1:
            y = y.transpose(-1, -2)                                    # (B, C=no_input_features, T or E)

        return y

In [11]:
class Block(torch.nn.Module):
    """Self-attention + position-wise linear layer with residuals, then max-pooling.

    Pipeline (for x shaped (B, T, C)):
        x = x + dropout(batch_norm(attention(x)))
        y = dropout(relu(batch_norm(linear(x))))
        x = zero_pad(x) + zero_pad(y)      # the narrower one is padded
        x = maxpool over T with ``kernel_size``

    Args:
        kernel_size: max-pool window (and stride) along the attended axis.
        no_input_features / no_output_features: feature sizes of the linear layer.
        H, d_qk, d_v, mask, bias_att, bias_out: forwarded to the attention layer.
        dim: -2 for inputs shaped (B, T or E, C), -1 for (B, C, T or E); the
            output uses the same layout as the input.
        p_dropout: dropout probability.
    """

    def __init__(self, kernel_size, no_input_features, no_output_features, H, d_qk, d_v, dim, mask=None, bias_att=False, bias_out=False, p_dropout=0.2):
        super().__init__()

        # Most probably mask will be a string ("all_but_self", "no_future", or "no_present_no_future").
        # But it can be an actual mask though of shape (E, E) or (T, T)
        self.mask = mask

        # Regardless of the provided dim (dimension across which to attend),
        # forward() brings the input to (B, T or E, C) first, which is why the
        # attention layer below is always built with dim = -2.
        self.dim = dim

        self.att = MaskedMultiHeadSelfAttention(no_input_features, H, d_qk, d_v, -2, mask=mask, bias_att=bias_att, bias_out=bias_out)
        self.batch_norm_att = torch.nn.BatchNorm1d(no_input_features)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p_dropout)

        self.lin = torch.nn.Linear(no_input_features, no_output_features)
        # This norm runs on the output of self.lin, so it must be sized with
        # the number of OUTPUT features.
        self.batch_norm_lin = torch.nn.BatchNorm1d(no_output_features)

        self.maxpool = torch.nn.MaxPool1d(kernel_size)

    def forward(self, x, mask=None):
        if self.dim == -1:
            # x: (B, C, T or E). Work on a transposed view so the caller's
            # tensor keeps its layout.
            x = x.transpose(-1, -2)

        # x: (B, T or E, no_input_features)

        if mask is None:
            mask = self.mask

        y = self.att(x, mask)                                            # (B, T, no_input_features)
        # BatchNorm1d normalises axis 1, hence the transposes around it.
        y = self.batch_norm_att(y.transpose(-1, -2)).transpose(-1, -2)   # (B, T, no_input_features)
        y = self.dropout(y)

        x = x + y                                                        # (B, T, no_input_features)

        y = self.lin(x)                                                  # (B, T, no_output_features)
        y = self.batch_norm_lin(y.transpose(-1, -2)).transpose(-1, -2)   # (B, T, no_output_features)
        y = self.relu(y)
        y = self.dropout(y)

        # Pad the narrower of x and y with zeros so they can be summed.
        B, T, no_input_features = x.shape
        no_output_features = y.shape[2]

        z = torch.zeros(
            (B, T, abs(no_output_features - no_input_features)),
            device=x.device, dtype=x.dtype
        )

        if no_output_features > no_input_features:
            x = torch.cat((x, z), dim=-1)
        else:
            y = torch.cat((y, z), dim=-1)

        # max_no_features = max(no_output_features, no_input_features)
        x = x + y                                                        # (B, T, max_no_features)

        # The pooled axis must be last for MaxPool1d (T ===> T').
        x = self.maxpool(x.transpose(-1, -2)).transpose(-1, -2)          # (B, T', max_no_features)

        if self.dim == -1:
            x = x.transpose(-1, -2)                                      # (B, max_no_features, T')

        return x

In [12]:
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up from 0 to the base learning rate, then cosine decay to 0.

    Args:
        optimizer: wrapped optimizer.
        warmup: number of steps of linear warm-up (0 disables the warm-up).
        max_num_iters: step at which the cosine decay reaches 0.
    """

    def __init__(self, optimizer, warmup, max_num_iters):
        # These must be set before super().__init__, which already calls get_lr().
        self.warmup = warmup
        self.max_num_iters = max_num_iters
        super().__init__(optimizer)

    def get_lr(self):
        # returns a list of the learning rates (one per parameter group)
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplier applied to the base learning rates at ``epoch``."""
        # Warm-up phase; skipped entirely when warmup == 0 (avoids dividing by zero).
        if self.warmup > 0 and epoch <= self.warmup:
            slope = (1 - 0) / (self.warmup - 0)
            return slope * epoch

        decay_steps = self.max_num_iters - self.warmup
        if decay_steps <= 0:
            # No room for a decay phase: hold the base learning rate instead of
            # dividing by zero.
            return 1.0

        # Half a cosine period spans the decay phase (factor 1 -> 0).
        f = 1 / (2 * decay_steps)
        return 0.5 * (1 + math.cos(2 * math.pi * f * (epoch - self.warmup)))

In [13]:
# Module-level bookkeeping updated by CustomLightningModule.validation_epoch_end:
# the lowest validation loss seen so far, the epoch it occurred at, a deep copy
# of the model's state_dict at that epoch, and the average precision measured there.
min_validation_loss = math.inf
best_epoch = 0
best_model_state = None
AP_at_min_val_loss = 0.

In [14]:
class CustomLightningModule(pl.LightningModule):
    """Spatio-temporal CNN followed by two attention blocks and a linear classifier.

    The module returns one logit per segment and is trained with
    BCEWithLogitsLoss. It reads the notebook-level ``train_dataset``,
    ``eval_dataset``, ``train_B``, ``eval_B`` and ``swa_model``, and records the
    best validation epoch in the notebook-level bookkeeping variables.

    Args:
        p_dropout: dropout probability used throughout the network.
        no_temporal_kernels: number of temporal convolution kernels.
        no_spatial_kernels: number of spatial kernels per temporal kernel.
    """

    def __init__(self, p_dropout, no_temporal_kernels, no_spatial_kernels):
        super().__init__()

        # Feature width produced by the CNN (15 * 2 = 30 with the notebook's
        # settings). Deriving it keeps the layers below consistent when the
        # kernel counts change.
        no_features = no_temporal_kernels * no_spatial_kernels

        self.spatiotemporal_cnn = SpatioTemporalCNN(
            no_temporal_kernels=no_temporal_kernels,
            no_spatial_kernels=no_spatial_kernels
        )

        self.batch_norm = torch.nn.BatchNorm1d(no_features)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p_dropout)
        self.maxpool = torch.nn.MaxPool1d(4)

        # kernel_size, no_input_features, no_output_features, H, d_qk, d_v, dim, mask=None, bias_att=False, bias_out=False, p_dropout=0.2
        self.block_1 = Block(
            4, no_features, no_features, 1, no_features, no_features, -2, mask=None,
            bias_att=False, bias_out=False, p_dropout=p_dropout
        )
        self.block_2 = Block(
            5, no_features, no_features, 1, no_features, no_features, -2, mask=None,
            bias_att=False, bias_out=False, p_dropout=p_dropout
        )

        # 3 time steps remain for 300-sample segments: 251 -> 62 -> 15 -> 3.
        self.class_lin = torch.nn.Linear(in_features=3 * no_features, out_features=1)

        self.criterion = torch.nn.BCEWithLogitsLoss()

    def forward(self, segment):
        """Return one logit per segment; ``segment`` is (B, 19, 300) and is not modified."""
        B = segment.shape[0]

        # Out-of-place views: the caller's tensor keeps its shape, so the same
        # tensor can be passed to the model more than once.
        x = segment.transpose(-1, -2).unsqueeze(1)    # (B, 1, 300, 19)
        x = self.spatiotemporal_cnn(x)                # (B, 251, F)   F = no_features

        x = x.transpose(-1, -2)                       # (B, F, 251)
        x = self.batch_norm(x)                        # (B, F, 251)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.maxpool(x)                           # (B, F, 251 // 4 = 62)
        x = x.transpose(-1, -2)                       # (B, 62, F)

        x = self.block_1(x)                           # (B, 62 // 4 = 15, F)
        x = self.block_2(x)                           # (B, 15 // 5 = 3, F)

        x = x.flatten(-2, -1)                         # (B, 3 * F)

        return self.class_lin(x).reshape((B,))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=0)
        cosine_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=100, max_num_iters=self.trainer.max_epochs)
        swa_scheduler = SWALR(optimizer, swa_lr=0.0001, anneal_epochs=100, anneal_strategy='cos')
        # NOTE(review): validation_epoch_end steps these schedulers by hand.
        # Under automatic optimization Lightning may also step schedulers
        # returned from here on its own -- confirm they are not stepped twice.
        return [optimizer], [cosine_scheduler, swa_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            train_dataset, batch_size=train_B, shuffle=True, num_workers=0, drop_last=True
        )

    def training_step(self, batch, batch_idx):
        x, y = batch

        logits = self(x)
        loss = self.criterion(logits, y)

        return {'loss': loss}

    def training_epoch_end(self, training_step_outputs):
        # Report the mean training loss of the epoch (rank 0 only).
        if self.global_rank == 0:
            train_loss = torch.stack([dic['loss'] for dic in training_step_outputs]).mean()
            print(f'Epoch: {self.current_epoch}')
            print(f'Training loss BEFORE optimizer step: {round(train_loss.item(), 5)}')
            print('----------------------------------------------------------------------------')

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            eval_dataset, batch_size=eval_B, shuffle=False, num_workers=0, drop_last=False
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch

        logits = self(x)
        loss = self.criterion(logits, y)
        return {'loss': loss, 'y': y, 'logits': logits}

    def validation_epoch_end(self, validation_step_outputs):
        """Step the schedulers, report validation metrics and track the best epoch."""
        global min_validation_loss, best_epoch, best_model_state, AP_at_min_val_loss

        # Before this epoch the cosine schedule is stepped, afterwards the SWA one.
        swa_start_epoch = 1000

        cosine_scheduler, swa_scheduler = self.lr_schedulers()

        if self.current_epoch > 0:
            if self.current_epoch < swa_start_epoch:
                cosine_scheduler.step()
            else:
                swa_scheduler.step()

        if self.global_rank == 0:

            # Weights are averaged only once the SWA learning rate has been annealed.
            # NOTE(review): this needs more than swa_start_epoch + anneal_epochs
            # epochs of training; with fewer epochs swa_model is never updated.
            if self.current_epoch > (swa_start_epoch + swa_scheduler.anneal_epochs):
                swa_model.update_parameters(self)

            val_loss = torch.stack([dic['loss'] for dic in validation_step_outputs]).mean()
            val_loss = val_loss.item()
            print(f'val_loss: {val_loss}')

            y = torch.cat([dic['y'] for dic in validation_step_outputs]).cpu()
            logits = torch.cat([dic['logits'] for dic in validation_step_outputs]).cpu()
            pred = torch.sigmoid(logits)

            y = np.array(y)
            pred = np.array(pred)

            try:
                validation_AP = average_precision_score(y, pred)

                # Keep a copy of the weights at the lowest validation loss
                # (the first two epochs are never selected).
                if (val_loss < min_validation_loss) and (self.current_epoch > 1):
                    min_validation_loss = val_loss
                    best_epoch = self.current_epoch
                    best_model_state = deepcopy(self.state_dict())
                    AP_at_min_val_loss = validation_AP

                print(f'validation_AP = {round(validation_AP, 5)} ... AP_at_min_val_loss = {round(AP_at_min_val_loss, 5)} @ epoch {best_epoch}')

            except ValueError:
                # Reported instead of raised so that training keeps running.
                print(f'ValueError @ epoch: {self.current_epoch} ====> y = {y}')

            print('=======================================================================================')
            

In [None]:
# Build the model; the dataloaders are provided by the module itself.
my_lightning_module = CustomLightningModule(p_dropout, no_temporal_kernels, no_spatial_kernels)

# Weight-averaged copy of the model; validation_epoch_end updates it through the
# global name `swa_model`.
# NOTE(review): validation_epoch_end only updates swa_model after epoch
# swa_start_epoch + anneal_epochs (1000 + 100), but max_epochs below is 500, so
# with these settings swa_model is never updated -- confirm this is intended.
swa_model = AveragedModel(my_lightning_module)

# Checkpointing, progress bar and logger are all switched off; the best weights
# are kept in the global `best_model_state` instead.
trainer = pl.Trainer(gpus=0, enable_checkpointing=False, enable_progress_bar=False, logger=False, max_epochs=500)

trainer.fit(my_lightning_module)

# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(my_lightning_module.train_dataloader(), swa_model)

In [None]:
# Persist the weights recorded at the lowest validation loss. The file name
# carries the evaluation group and the average precision (in percent) reached
# at that epoch.
if best_model_state is None:
    # Nothing was recorded (e.g. training did not run long enough); refuse to
    # write a checkpoint that would only contain None.
    raise RuntimeError('No best model state was recorded; nothing to save.')

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)

PATH = f'models/{eval_group}_satelight_{round(100 * AP_at_min_val_loss, 2)}.pt'
torch.save(best_model_state, PATH)

In [None]:
g1
==
maximum validation_AP = 0.86006 @ val_loss: 0.3857518136501312 @ train_loss: 0.02103

Epoch: 125
Training loss BEFORE optimizer step: 0.08427
----------------------------------------------------------------------------
val_loss: 0.29432666301727295
validation_AP = 0.80855 ... AP_at_min_val_loss = 0.80855 @ epoch 126
====================================================================

Epoch: 223
Training loss BEFORE optimizer step: 0.10962
----------------------------------------------------------------------------
val_loss: 0.27389994263648987
validation_AP = 0.84424 ... AP_at_min_val_loss = 0.84424 @ epoch 224
====================================================================



g2
==
maximum validation_AP = 0.89437 @ val_loss: 0.3216995894908905 @ train_loss: 0.03656

Epoch: 158
Training loss BEFORE optimizer step: 0.04207
--------------------------------------------------------------------
val_loss: 0.261417955160141
validation_AP = 0.88421 ... AP_at_min_val_loss = 0.88421 @ epoch 159
====================================================================



g3
==
maximum validation_AP = 0.91116 @ val_loss: 0.28639698028564453 @ train_loss: 0.08264
validation_AP = 0.9136 @ val_loss: 0.3856848478317261 @ train_loss: 0.0163
maximum validation_AP = 0.91742 @ val_loss: 0.33682939410209656 @ train_loss: 0.0261

Epoch: 81
Training loss BEFORE optimizer step: 0.10821
-------------------------------------------------------------------
val_loss: 0.22748209536075592
validation_AP = 0.89869 ... AP_at_min_val_loss = 0.89869 @ epoch 82
===================================================================

Epoch: 108
Training loss BEFORE optimizer step: 0.08981
----------------------------------------------------------------------------
val_loss: 0.23838919401168823
validation_AP = 0.90162 ... AP_at_min_val_loss = 0.90162 @ epoch 109
====================================================================



g4
==
validation_AP = 0.94751 @ val_loss: 0.2212197333574295 @ train_loss: 0.01003

Epoch: 235
Training loss BEFORE optimizer step: 0.03781
----------------------------------------------------------------------------
val_loss: 0.15423248708248138
validation_AP = 0.95322 ... AP_at_min_val_loss = 0.95322 @ epoch 236
====================================================================



g5
==
Epoch: 184
Training loss BEFORE optimizer step: 0.01027
---------------------------------------------------------------------
val_loss: 0.14156535267829895
validation_AP = 0.96198 ... AP_at_min_val_loss = 0.96198 @ epoch 185
=====================================================================

In [29]:
import pandas as pd


# Fresh model instance; its weights are overwritten with a saved checkpoint
# before every prediction in the inference loop below.
model = CustomLightningModule(p_dropout, no_temporal_kernels, no_spatial_kernels)

# File names to score. The active line selects the files whose name contains
# 'ns'; the commented-out line is the variant for files containing 'sp'.
#IEDs_names = [item for item in os.listdir(data_folder) if 'sp' in item]
NIEDs_names = [item for item in os.listdir(data_folder) if 'ns' in item]

In [30]:
def prepare_dataset(patients_names):
    """Load and concatenate the segments and labels of several patients.

    Returns two parallel lists (segments, labels), ordered patient by patient
    in the order given by ``patients_names``.
    """
    group_segments, group_labels = [], []

    for name in patients_names:
        segments_of_patient, labels_of_patient = prep_patient_data(name)
        group_segments.extend(segments_of_patient)
        group_labels.extend(labels_of_patient)

    return group_segments, group_labels

In [31]:
def _best_checkpoint_name(fold_name, models_dir='models/'):
    """Return the saved checkpoint of ``fold_name`` with the highest AP in its name.

    Checkpoints are named '<fold>_satelight_<AP in percent>.pt'. The AP is
    compared numerically, because a plain max() over the names compares them as
    text (where e.g. '9.5' sorts after '84.42').
    """
    candidates = [item for item in os.listdir(models_dir) if f'{fold_name}_satelight' in item]
    if not candidates:
        raise FileNotFoundError(f'No saved model for {fold_name} found in {models_dir}')

    def _ap_in_name(file_name):
        stem, _ = os.path.splitext(file_name)
        try:
            return float(stem.rsplit('_', 1)[-1])
        except ValueError:
            # Unparseable names rank below every parseable one.
            return float('-inf')

    # The name itself breaks ties, so the choice is deterministic.
    return max(candidates, key=lambda file_name: (_ap_in_name(file_name), file_name))


def _fold_artifacts(fold_name, train_patient_names):
    """Return (checkpoint name, training median, training IQR) for one fold."""
    model_name = _best_checkpoint_name(fold_name)
    fold_train_segments, _ = prepare_dataset(train_patient_names)
    fold_median, fold_IQR = median_IQR_func(fold_train_segments)
    return model_name, fold_median, fold_IQR


# For every fold: the checkpoint to use and the scaling statistics of the
# patients that fold was trained on (all groups except the fold itself).
g1_model_name, g1_median, g1_IQR = _fold_artifacts('g1', g2 + g3 + g4 + g5)
g2_model_name, g2_median, g2_IQR = _fold_artifacts('g2', g1 + g3 + g4 + g5)
g3_model_name, g3_median, g3_IQR = _fold_artifacts('g3', g1 + g2 + g4 + g5)
g4_model_name, g4_median, g4_IQR = _fold_artifacts('g4', g1 + g2 + g3 + g5)
g5_model_name, g5_median, g5_IQR = _fold_artifacts('g5', g1 + g2 + g3 + g4)

In [32]:
# One (initially empty) float slot per file name; the inference loop fills in
# the predicted probability. The commented-out line is the variant for the
# 'sp' file list.
#s = pd.Series(index=IEDs_names, dtype='float64')
s = pd.Series(index=NIEDs_names, dtype='float64')

In [33]:
# Score every file with the model of the fold in which its patient was held
# out, after scaling it with that fold's training median / IQR.
# (To score the 'sp' files instead, iterate over IEDs_names.)

# fold name -> (patients, median, IQR, checkpoint file name)
fold_settings = {
    'g1': (g1, g1_median, g1_IQR, g1_model_name),
    'g2': (g2, g2_median, g2_IQR, g2_model_name),
    'g3': (g3, g3_median, g3_IQR, g3_model_name),
    'g4': (g4, g4_median, g4_IQR, g4_model_name),
    'g5': (g5, g5_median, g5_IQR, g5_model_name),
}

# patient -> fold; setdefault keeps the first fold if a patient were listed twice.
patient_fold = {}
for fold_name, (fold_patients, _, _, _) in fold_settings.items():
    for patient in fold_patients:
        patient_fold.setdefault(patient, fold_name)

loaded_states = {}     # fold name -> state_dict, each read from disk only once
active_fold = None     # fold whose weights are currently loaded into `model`

model.eval()

# Inference only: no gradients are needed.
with torch.no_grad():
    for NIED_name in NIEDs_names:
        segment = loadmat(os.path.join(data_folder, NIED_name))['segment']
        segment = torch.tensor(segment).float()

        # The patient name is the part of the file name before the first '_'.
        patient_name = NIED_name.split('_', 1)[0]

        if patient_name not in patient_fold:
            raise NameError('Unknown Patient')

        fold_name = patient_fold[patient_name]
        _, fold_median, fold_IQR, fold_model_name = fold_settings[fold_name]

        if fold_name != active_fold:
            if fold_name not in loaded_states:
                loaded_states[fold_name] = torch.load(f'models/{fold_model_name}')
            model.load_state_dict(loaded_states[fold_name])
            active_fold = fold_name

        segment = ((segment - fold_median) / fold_IQR).unsqueeze(0)     # add the batch axis

        s.at[NIED_name] = torch.sigmoid(model(segment)).item()

In [34]:
s

RC_24_ns.mat      0.000196
FEsg_38_ns.mat    0.032728
HB_2_ns.mat       0.000244
FEsg_26_ns.mat    0.054695
FigSa_9_ns.mat    0.013867
                    ...   
LM_3_ns.mat       0.035856
RC_6_ns.mat       0.002633
FEsg_19_ns.mat    0.096821
NKra_10_ns.mat    0.082357
PAT48_ns_6.mat    0.001178
Length: 593, dtype: float64

In [35]:
# Write the scores to the 'satelight' sheet of the results workbook.
# (For the 'sp' files the workbook is 'IEDs_results.xlsx'.)
results_path = 'NIEDs_results.xlsx'

# Append mode needs an existing workbook; create the file on the first run.
writer_mode = 'a' if os.path.exists(results_path) else 'w'

with pd.ExcelWriter(results_path, mode=writer_mode) as writer:
    s.to_excel(writer, sheet_name='satelight', header=False)