In [1]:
# Workaround: file descriptors (FDs) apparently are not being released somewhere
# in this pipeline (root cause not found), so make sure the soft limit on open
# files is at least MIN_OPEN_FILES. Only ever raise the soft limit, and never
# above the hard limit: setrlimit raises ValueError for a soft limit above the
# hard one, and unconditionally setting 2048 would *lower* a larger or
# unlimited soft limit.
import resource

MIN_OPEN_FILES = 2048
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
soft_limit, hard_limit = rlimit
if soft_limit != resource.RLIM_INFINITY and soft_limit < MIN_OPEN_FILES:
    if hard_limit == resource.RLIM_INFINITY:
        new_soft_limit = MIN_OPEN_FILES
    else:
        new_soft_limit = min(MIN_OPEN_FILES, hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft_limit, hard_limit))

In [2]:
import math
import os
import re
import random
import multiprocessing as mp
import numpy as np
from scipy.io import loadmat
import torch
import pytorch_lightning as pl
from sklearn.metrics import roc_auc_score, average_precision_score
from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Batch sizes for the training / validation DataLoaders.
train_B = 16
val_B = 1

# Seed every RNG the notebook relies on (Python `random` drives the left/right
# flip augmentation; torch drives weight init and DataLoader shuffling) so that
# a Restart & Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# When True, segments are robust-scaled with the training split's median / IQR.
SCALE = True
data_folder = 'BH average 100-sps (0.75-38 Hz)'
# Sorted: os.listdir order is arbitrary, and this listing fixes dataset order.
files_names = sorted(os.listdir(data_folder))

# Patient groups used as cross-validation folds.
g1 = ['PAT31', 'MC', 'FEsg', 'DHut', 'NKra', 'PAT2']
g2 = ['RC', 'MBra', 'ESow', 'PAT1', 'EG', 'FigSa', 'PAT28']
g3 = ['PAT47', 'RA', 'PAT51', 'MPi', 'GNA', 'LM', 'DG']
g4 = ['PAT33', 'PB', 'PAT22', 'MAXJ', 'PAT19', 'LP', 'HeMod', 'PAT50', 'PJuly']
g5 = ['PAT48', 'KS', 'PAT3', 'PAT8', 'LRio', 'HB', 'ALo', 'JAlv', 'PAT25', 'AR']

In [None]:
# Leave-one-group-out split: group g1 is held out for validation and the
# remaining groups (in g2..g5 order) form the training set.
val_patients = g1
train_patients = [*g2, *g3, *g4, *g5]

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset of (segment, label) pairs with optional left/right flip.

    Args:
        segments: sequence of tensors whose first dimension indexes channels.
        labels: sequence of labels aligned with ``segments``.
        flip_prob: probability that ``__getitem__`` returns a copy of the
            segment with homologous left/right channels swapped; 0 disables it.
        montage: channel layout name ('BH' or 'MCH'); only looked up (and
            validated) when ``flip_prob > 0``.
    """

    def __init__(self, segments, labels, flip_prob, montage):
        self.segments = segments
        self.labels = labels
        self.flip_prob = flip_prob

        # The swap table is needed only when flipping can actually happen.
        if flip_prob > 0:
            self.idx_swaps = self.montage_swap_func(montage)

    @staticmethod
    def montage_swap_func(montage):
        """Return the (left_index, right_index) channel pairs to swap for ``montage``.

        Raises:
            NameError: if ``montage`` is neither 'BH' nor 'MCH'.
        """
        channel_orders = {
            'BH': ('Fp1', 'F7', 'T3', 'T5', 'O1', 'F3', 'C3', 'P3', 'Fz', 'Cz', 'Pz',
                   'Fp2', 'F8', 'T4', 'T6', 'O2', 'F4', 'C4', 'P4'),
            'MCH': ('C3', 'C4', 'O1', 'O2', 'Cz', 'F3', 'F4', 'F7', 'F8', 'Fz',
                    'Fp1', 'Fp2', 'P3', 'P4', 'Pz', 'T3', 'T4', 'T5', 'T6'),
        }
        if montage not in ('BH', 'MCH'):
            raise NameError('Unavailable Montage')

        position = {name: i for i, name in enumerate(channel_orders[montage])}

        # Homologous left/right electrode pairs; midline channels (Fz, Cz, Pz) stay put.
        homologous_pairs = (
            ('Fp1', 'Fp2'), ('F7', 'F8'), ('F3', 'F4'), ('C3', 'C4'),
            ('P3', 'P4'), ('T3', 'T4'), ('T5', 'T6'), ('O1', 'O2'),
        )
        return [(position[left], position[right]) for left, right in homologous_pairs]

    def flip_lr(self, segment):
        """Swap every pair in ``self.idx_swaps`` along dim 0 of ``segment``, in place.

        The argument itself is modified and returned; ``__getitem__`` passes a
        clone so the stored segment is never altered.
        """
        for left, right in self.idx_swaps:
            # The right-hand side is an advanced-indexing gather (a copy), so
            # both rows are read before either one is overwritten.
            segment[[left, right]] = segment[[right, left]]
        return segment

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Draw first, exactly once per item, regardless of flip_prob.
        flip = random.random() < self.flip_prob
        segment = self.segments[index]
        if flip:
            segment = self.flip_lr(torch.clone(segment))
        return segment, self.labels[index]

In [4]:
def prep_patient_data(patient_name, file_names=None, folder=None):
    """Load every labelled segment file that belongs to one patient.

    A file belongs to ``patient_name`` when its name starts with the patient
    name followed by a non-digit character, so 'PAT2' does not pick up
    'PAT28_...' files. The label comes from the file name: 'sp' -> 1.0
    (checked first), otherwise 'ns' -> 0.0.

    Args:
        patient_name: patient identifier used as the file-name prefix.
        file_names: file names to search. Defaults to a sorted listing of
            ``folder``, so this no longer depends on a global file list.
        folder: directory holding the .mat files; defaults to the
            notebook-level ``data_folder``.

    Returns:
        (segments, labels): lists of float32 tensors. Each segment is the
        ``'segment'`` array of its .mat file; each label is a 0-d tensor.

    Raises:
        ValueError: if a matching file name contains neither 'sp' nor 'ns'.
            Such a file would otherwise be given a stale label.
    """
    if folder is None:
        folder = data_folder
    if file_names is None:
        file_names = sorted(os.listdir(folder))

    # Raw f-string: '\D' is a regex escape, not a string escape. re.escape
    # keeps any regex metacharacters in the patient name literal.
    pattern = re.compile(rf'{re.escape(patient_name)}\D[0-9.a-z_A-Z]*')

    patient_segments = []
    patient_labels = []
    for file_name in file_names:
        if not pattern.match(file_name):
            continue

        # Decide the label before loading so an unlabelled file fails fast.
        if 'sp' in file_name:
            label_value = 1.0
        elif 'ns' in file_name:
            label_value = 0.0
        else:
            raise ValueError(f"Cannot infer label ('sp' or 'ns') from file name: {file_name!r}")

        mat = loadmat(os.path.join(folder, file_name))
        patient_segments.append(torch.tensor(mat['segment']).float())
        patient_labels.append(torch.tensor(label_value).float())

    return patient_segments, patient_labels

In [5]:
def prepare_dataset(patients_names):
    """Concatenate the segments and labels of several patients.

    Calls ``prep_patient_data`` once per patient, in the given order, and
    flattens the per-patient lists into one list of segments and one list
    of labels that stay aligned.

    Returns:
        (segments, labels) as two flat lists.
    """
    all_segments, all_labels = [], []
    for name in patients_names:
        segments, labels = prep_patient_data(name)
        all_segments.extend(segments)
        all_labels.extend(labels)
    return all_segments, all_labels

In [None]:
#train_segments, train_labels, train_filesnames = prepare_dataset(train_patients)
#val_segments, val_labels, val_filesnames = prepare_dataset(val_patients)
# Load every labelled segment for the training and validation patients.
train_segments, train_labels = prepare_dataset(train_patients)
val_segments, val_labels = prepare_dataset(val_patients)

# Fail fast on an empty or misaligned split (e.g. wrong data_folder or patient
# IDs) instead of hitting an opaque error later in scaling or validation.
assert len(train_segments) == len(train_labels) > 0, 'no training segments were loaded'
assert len(val_segments) == len(val_labels) > 0, 'no validation segments were loaded'

In [6]:
def median_IQR_func(group_segments):
    """Median and inter-quartile range of every sample in ``group_segments``.

    All values of all segments are pooled into one float32 vector (segment
    shapes need not match), and the 25th/50th/75th percentiles are computed
    with ``torch.quantile``'s default linear interpolation.

    Args:
        group_segments: non-empty iterable of numeric tensors.

    Returns:
        (median, IQR) as Python floats, with IQR = Q3 - Q1.

    Raises:
        ValueError: if ``group_segments`` is empty.
    """
    # One vectorised flatten + concat instead of a Python .item() call per
    # sample. float32 on CPU matches the precision of building the tensor
    # from a list of Python floats.
    flattened = [
        segment.detach().reshape(-1).to(device='cpu', dtype=torch.float32)
        for segment in group_segments
    ]
    if not flattened:
        raise ValueError('median_IQR_func needs at least one segment')
    pooled = torch.cat(flattened)

    # NOTE(review): torch.quantile rejects very large inputs (on the order of
    # 2**24 elements); confirm the pooled training set stays below that.
    probabilities = torch.tensor([0.25, 0.5, 0.75], dtype=pooled.dtype)
    q1, median, q3 = pooled.quantile(probabilities).tolist()
    return median, q3 - q1

In [None]:
def scaler_func(group_segments, median, IQR):
    """Robust-scale every segment as ``(segment - median) / IQR``.

    The list is updated in place: its elements are replaced by new, scaled
    tensors, and the original tensors are left untouched. The same list
    object is also returned for convenience.
    """
    group_segments[:] = [(segment - median) / IQR for segment in group_segments]
    return group_segments

In [None]:
# Robust-scale both splits with statistics from the TRAINING split only.
# scaler_func rewrites the list it receives in place, so it gets copies: the
# raw lists stay unscaled and re-running this cell does not scale twice.
if SCALE:
    median, IQR = median_IQR_func(train_segments)
    train_segments_model = scaler_func(list(train_segments), median, IQR)
    val_segments_model = scaler_func(list(val_segments), median, IQR)
else:
    train_segments_model, val_segments_model = train_segments, val_segments

# Left/right flip augmentation (p=0.5) for training only; validation is never flipped.
train_dataset = CustomDataset(train_segments_model, train_labels, 0.5, 'BH')
val_dataset = CustomDataset(val_segments_model, val_labels, 0., 'BH')

In [None]:
# Drop the notebook-level names for the raw split lists; the datasets keep
# references to what they need. `files_names` is deliberately kept: the
# inference cell further down calls prepare_dataset again, and the data-loading
# code may still read that listing. Deleting it broke Restart & Run All.
del train_segments, train_labels, val_segments, val_labels

In [None]:
# Best-checkpoint bookkeeping; CustomLightningModule.validation_epoch_end
# updates these through `global` statements.
min_validation_loss, best_epoch = math.inf, 0
best_model_state, AP_at_min_val_loss = None, 0.
# Kept for compatibility; not read by any code visible in this notebook.
EARLY_STOPPING = False


In [7]:
# Output length of an unpadded ('valid') Conv1d:
#     L_out = floor((L_in - (L_kernel - 1) - 1) / stride) + 1
# i.e. L_in - (L_kernel - 1) for stride 1, and L_in / 2 for kernel_size=2,
# stride=2 with an even L_in.
class CustomLightningModule(pl.LightningModule):
    """All-convolutional binary classifier for 19-channel, 300-sample segments.

    Maps (B, 19, 300) inputs to one logit per segment and trains with
    BCEWithLogitsLoss. Stride-2 convolutions take the place of max-pooling.

    Optimization is manual (``automatic_optimization = False``) so that each
    training step can also record the batch loss *before* the optimizer step.

    Notebook coupling: the dataloaders read the globals ``train_dataset``,
    ``val_dataset``, ``train_B`` and ``val_B``. ``validation_epoch_end`` writes
    the globals ``min_validation_loss``, ``best_epoch``, ``best_model_state``
    and ``AP_at_min_val_loss``.

    NOTE: state_dict keys come from each layer's position in ``self.seq``
    (``seq.0.weight``, ``seq.2.weight``, ...). Saved checkpoints depend on that
    order, so do not insert, remove or reorder layers without retraining.
    """

    def __init__(self):
        super().__init__()
        # input: (B, 19, 300)
        self.seq = torch.nn.Sequential(
            # input: (B, 19, 300) ===> ( B, 32, (300 - (3 - 1)) / 1 ) = (B, 32, 298)
            torch.nn.Conv1d(in_channels=19, out_channels=32, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),
            # torch.nn.Dropout(p=0.3),

            # (B, 32, 298) ===> ( B, 32, (298 - (3 - 1)) / 1 ) = (B, 32, 296)
            torch.nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 32, 296) ===> (B, 32, 148)
            #torch.nn.MaxPool1d(kernel_size=2, stride=2),
            torch.nn.Conv1d(in_channels=32, out_channels=32, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),

            # (B, 32, 148) ===> ( B, 16, (148 - (3 - 1)) / 1 ) = (B, 16, 146)
            torch.nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),
            # torch.nn.Dropout(p=0.3),

            # (B, 16, 146) ===> ( B, 16, (146 - (3 - 1)) / 1 ) = (B, 16, 144)
            torch.nn.Conv1d(in_channels=16, out_channels=16, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 16, 144) ===> (B, 16, 72)
            #torch.nn.MaxPool1d(kernel_size=2, stride=2),
            torch.nn.Conv1d(in_channels=16, out_channels=16, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),

            # (B, 16, 72) ===> ( B, 8, (72 - (3 - 1)) / 1 ) = (B, 8, 70)
            torch.nn.Conv1d(in_channels=16, out_channels=8, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),
            # torch.nn.Dropout(p=0.3),

            # (B, 8, 70) ===> ( B, 8, (70 - (3 - 1)) / 1 ) = (B, 8, 68)
            torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 8, 68) ===> (B, 8, 34)
            #torch.nn.MaxPool1d(kernel_size=2, stride=2),
            torch.nn.Conv1d(in_channels=8, out_channels=8, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),

            # (B, 8, 34) ===> ( B, 4, (34 - (3 - 1)) / 1 ) = (B, 4, 32)
            torch.nn.Conv1d(in_channels=8, out_channels=4, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),
            # torch.nn.Dropout(p=0.3),

            # (B, 4, 32) ===> ( B, 4, (32 - (3 - 1)) / 1 ) = (B, 4, 30)
            torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 4, 30) ===> (B, 4, 15)
            #torch.nn.MaxPool1d(kernel_size=2, stride=2),
            torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=2, stride=2),
            torch.nn.LeakyReLU(),

            # (B, 4, 15) ===> ( B, 2, (15 - (3 - 1)) / 1 ) = (B, 2, 13)
            torch.nn.Conv1d(in_channels=4, out_channels=2, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),
            # torch.nn.Dropout(p=0.3),

            # (B, 2, 13) ===> ( B, 2, (13 - (3 - 1)) / 1 ) = (B, 2, 11)
            torch.nn.Conv1d(in_channels=2, out_channels=2, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # CANNOT DO MAXPOOL

            # (B, 2, 11) ===> ( B, 1, (11 - (3 - 1)) / 1 ) = (B, 1, 9)
            torch.nn.Conv1d(in_channels=2, out_channels=1, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 1, 9) ===> ( B, 1, (9 - (3 - 1)) / 1 ) = (B, 1, 7)
            torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 1, 7) ===> ( B, 1, (7 - (3 - 1)) / 1 ) = (B, 1, 5)
            torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 1, 5) ===> ( B, 1, (5 - (3 - 1)) / 1 ) = (B, 1, 3)
            torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding='valid'),
            torch.nn.LeakyReLU(),

            # (B, 1, 3) ===> ( B, 1, (3 - (3 - 1)) / 1 ) = (B, 1, 1)
            torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding='valid')
        )
        self.criterion = torch.nn.BCEWithLogitsLoss()
        # Steps are driven manually in training_step.
        self.automatic_optimization = False

    def forward(self, x):
        """Return one logit per segment: (B, 19, 300) -> (B,)."""
        B = x.shape[0]
        return self.seq(x).reshape((B,))

    def configure_optimizers(self):
        """Adam (default lr, no weight decay) plus a ReduceLROnPlateau scheduler.

        Under manual optimization Lightning does not step schedulers itself;
        validation_epoch_end steps this one with the validation loss.
        """
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=0)
        return [optimizer], [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.99, min_lr=0.00001)]

    def train_dataloader(self):
        """Shuffled loader over the global ``train_dataset``; the last partial batch is dropped."""
        return torch.utils.data.DataLoader(train_dataset, batch_size=train_B, shuffle=True,
                                           num_workers=0, drop_last=True)

    def training_step(self, batch, batch_idx):
        """One manual-optimization step on ``batch`` = (x, y).

        Returns:
            dict holding the detached batch loss measured *before* the
            optimizer step. Lightning collects these dicts and passes them to
            ``training_epoch_end``.
        """
        x, y = batch

        # Loss of the current weights on this batch, measured before updating them.
        # The eval()/train() toggle has no effect on this network as written
        # (it contains only Conv1d and LeakyReLU layers). It would matter if
        # the commented-out Dropout layers were enabled.
        self.eval()
        with torch.no_grad():
            logits_before_optimizer_step = self(x)
            loss_before_optimizer_step = self.criterion(logits_before_optimizer_step, y)

        self.train()

        # Lightning's optimizer wrapper around the Adam optimizer from configure_optimizers.
        opt = self.optimizers()

        def closure():
            # Recompute the loss with gradients enabled and backpropagate it.
            logits = self(x)
            loss = self.criterion(logits, y)
            opt.zero_grad()
            self.manual_backward(loss)
            return loss

        opt.step(closure=closure)

        return {'loss_before_optimizer_step': loss_before_optimizer_step.detach()}

    def training_epoch_end(self, training_step_outputs):
        """Print the epoch's mean pre-step training loss (rank 0 only)."""
        if self.global_rank == 0:
            train_loss = torch.stack([dic['loss_before_optimizer_step'] for dic in training_step_outputs]).mean()
            print(f'Epoch: {self.current_epoch}')
            print(f'Training loss BEFORE optimizer step: {round(train_loss.item(), 5)}')
            print('--------------------------------------')

    def val_dataloader(self):
        """Loader over the global ``val_dataset``, in order, keeping every sample."""
        return torch.utils.data.DataLoader(val_dataset, batch_size=val_B, shuffle=False,
                                           num_workers=0, drop_last=False)

    def validation_step(self, batch, batch_idx):
        """Return the batch loss plus labels and logits for the epoch-level metrics."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        return {'loss': loss, 'y': y, 'logits': logits}

    def validation_epoch_end(self, validation_step_outputs):
        """Report validation loss and AP, step the LR scheduler and track the best model.

        The best model is the one with the lowest mean validation loss seen at
        ``current_epoch > 1``. Its deep-copied weights, its epoch and its AP
        are stored in the notebook-level tracking globals.
        """
        global min_validation_loss, best_epoch, best_model_state, AP_at_min_val_loss

        # Mean of per-batch losses (per-segment losses when val_B == 1).
        val_loss = torch.stack([dic['loss'] for dic in validation_step_outputs]).mean().item()

        # Lightning never steps LR schedulers under manual optimization, so the
        # ReduceLROnPlateau from configure_optimizers is driven here. The
        # sanity-check pass that runs before training is skipped.
        if not self.trainer.sanity_checking:
            self.lr_schedulers().step(val_loss)

        if self.global_rank == 0:
            print(f'val_loss: {val_loss}')

            y = torch.cat([dic['y'] for dic in validation_step_outputs]).cpu()
            logits = torch.cat([dic['logits'] for dic in validation_step_outputs]).cpu()
            pred = torch.sigmoid(logits)

            y = np.array(y)
            pred = np.array(pred)

            try:
                validation_AP = average_precision_score(y, pred)

                if (val_loss < min_validation_loss) and (self.current_epoch > 1):
                    min_validation_loss = val_loss
                    best_epoch = self.current_epoch
                    best_model_state = deepcopy(self.state_dict())
                    AP_at_min_val_loss = validation_AP

                print(f'validation_AP = {round(validation_AP, 5)} ... AP_at_min_val_loss = {round(AP_at_min_val_loss, 5)} @ epoch {best_epoch}')

            except ValueError:
                print(f'ValueError @ epoch: {self.current_epoch} ====> y = {y}')

            print('=======================================================================================')
            

In [None]:
# Train on CPU for a fixed 500 epochs. Lightning checkpointing, progress bar and
# logger are off: the best weights are tracked in validation_epoch_end instead.
my_lightning_module = CustomLightningModule()
trainer = pl.Trainer(
    gpus=0,
    max_epochs=500,
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
)
trainer.fit(my_lightning_module)

In [None]:
# Save the weights with the lowest validation loss. The file name encodes the
# held-out group (g1, matching val_patients = g1) and the AP at that epoch as a
# percentage; the inference cell chooses checkpoints by that AP.
if best_model_state is None:
    raise RuntimeError('No best model state was recorded during training; nothing to save.')
os.makedirs('models', exist_ok=True)
PATH = f'models/g1_vanilla_vgg_10k_AP_{round(100 * AP_at_min_val_loss, 2)}.pt'
torch.save(best_model_state, PATH)

In [8]:
import pandas as pd


model = CustomLightningModule()

# Non-IED segments: every file in data_folder whose name contains 'ns'.
# (To score IED segments instead, select the names containing 'sp'.)
NIEDs_names = [
    file_name
    for file_name in os.listdir(data_folder)
    if 'ns' in file_name
]


In [9]:
# Out-of-fold scoring: each patient's segments are scored by the model trained
# with that patient's group held out. Segments are first scaled with the
# median/IQR of that fold's training patients, the same statistics used in training.
cv_groups = {'g1': g1, 'g2': g2, 'g3': g3, 'g4': g4, 'g5': g5}


def _checkpoint_AP(file_name):
    """Return the AP encoded as '..._AP_<value>.pt' in ``file_name``, or None."""
    match = re.search(r'_AP_(\d+(?:\.\d+)?)\.pt$', file_name)
    return float(match.group(1)) if match else None


def _best_checkpoint(group_name, model_dir='models'):
    """Return the name of the ``<group_name>_vanilla_vgg`` checkpoint with the highest AP.

    The AP is compared numerically. max() over raw file names is lexicographic
    and would rank e.g. '..._AP_9.5.pt' above '..._AP_85.3.pt'.

    Raises:
        FileNotFoundError: if no such checkpoint with a parsable AP exists.
    """
    tag = f'{group_name}_vanilla_vgg'
    candidates = []
    for file_name in os.listdir(model_dir):
        ap = _checkpoint_AP(file_name) if tag in file_name else None
        if ap is not None:
            candidates.append((ap, file_name))
    if not candidates:
        raise FileNotFoundError(f'No {tag} checkpoint with an _AP_<value>.pt suffix in {model_dir!r}')
    return max(candidates)[1]


# Load each fold's scaling statistics and model once, rather than re-reading
# the checkpoint from disk for every segment.
# NOTE(review): each checkpoint is chosen by its AP on the held-out group it
# then scores, so these scores may be optimistically biased.
fold_scaling = {}
fold_models = {}
for group_name in cv_groups:
    # Training patients of this fold: every other group, in g1..g5 order.
    fold_train_patients = [
        patient
        for other_name, patients in cv_groups.items()
        if other_name != group_name
        for patient in patients
    ]
    train_segments, _ = prepare_dataset(fold_train_patients)
    fold_scaling[group_name] = median_IQR_func(train_segments)

    fold_model = CustomLightningModule()
    checkpoint_path = os.path.join('models', _best_checkpoint(group_name))
    fold_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    fold_model.eval()
    fold_models[group_name] = fold_model

# Patient -> group name. If a patient were listed in several groups, the first
# one (in g1..g5 order) wins.
patient_group = {}
for group_name, patients in cv_groups.items():
    for patient in patients:
        patient_group.setdefault(patient, group_name)

# One sigmoid probability per segment file. To score IED segments instead,
# index the Series and iterate over the 'sp' file names.
s = pd.Series(index=NIEDs_names, dtype='float64')

with torch.no_grad():  # inference only: no autograd graph per segment
    for NIED_name in NIEDs_names:
        segment = torch.tensor(loadmat(os.path.join(data_folder, NIED_name))['segment']).float()

        # Segment files are named '<patient>_...'.
        patient_name = NIED_name.split('_', 1)[0]
        group_name = patient_group.get(patient_name)
        if group_name is None:
            raise NameError('Unknown Patient')

        fold_median, fold_IQR = fold_scaling[group_name]
        segment = ((segment - fold_median) / fold_IQR).unsqueeze(0)  # add batch dimension
        s.at[NIED_name] = torch.sigmoid(fold_models[group_name](segment)).item()

In [10]:
s

RC_24_ns.mat      0.005738
FEsg_38_ns.mat    0.000087
HB_2_ns.mat       0.050953
FEsg_26_ns.mat    0.103463
FigSa_9_ns.mat    0.004070
                    ...   
LM_3_ns.mat       0.058508
RC_6_ns.mat       0.004121
FEsg_19_ns.mat    0.003594
NKra_10_ns.mat    0.000157
PAT48_ns_6.mat    0.029589
Length: 593, dtype: float64

In [11]:
#with pd.ExcelWriter('IEDs_results.xlsx', mode='a') as writer:
# Write the scores (one row per segment file, no header) to sheet 'vanilla_vgg'.
# Append when the workbook already exists; otherwise create it, because
# mode='a' fails on a missing file. How pandas handles an already-existing
# sheet name in append mode depends on its version (error or renamed sheet).
# For IED scores, use 'IEDs_results.xlsx' instead.
results_path = 'NIEDs_results.xlsx'
writer_mode = 'a' if os.path.exists(results_path) else 'w'
with pd.ExcelWriter(results_path, mode=writer_mode) as writer:
    s.to_excel(writer, sheet_name='vanilla_vgg', header=False)