In [1]:
# For some reason, file descriptors (FDs) do not get released.
# Workaround: raise the soft limit on open files to at least 2048.
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
soft_limit, hard_limit = rlimit
target_soft = 2048
# The soft limit may not exceed the hard limit (setrlimit raises ValueError).
if hard_limit != resource.RLIM_INFINITY:
    target_soft = min(target_soft, hard_limit)
# Only ever raise the soft limit; never lower one that is already larger.
if soft_limit != resource.RLIM_INFINITY and soft_limit < target_soft:
    resource.setrlimit(resource.RLIMIT_NOFILE, (target_soft, hard_limit))

In [2]:
import math
import os
import re
import random
import multiprocessing as mp
import numpy as np
from scipy.io import loadmat
import torch
import pytorch_lightning as pl
from copy import deepcopy
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.optim.swa_utils import AveragedModel, SWALR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Batch sizes
train_B = 16
val_B = 1

p_dropout = 0.

SCALE = True
data_folder = 'BH average 100-sps (0.75-38 Hz)'
# Sorted so the dataset order (and therefore the seeded shuffling) does not
# depend on the filesystem's directory-listing order.
files_names = sorted(os.listdir(data_folder))

# Seed every RNG used below: Python `random` (left/right flip augmentation),
# NumPy, and torch (weight init, DataLoader shuffling, dropout).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Patient groups used as cross-validation folds.
g1 = ['PAT31', 'MC', 'FEsg', 'DHut', 'NKra', 'PAT2']
g2 = ['RC', 'MBra', 'ESow', 'PAT1', 'EG', 'FigSa', 'PAT28']
g3 = ['PAT47', 'RA', 'PAT51', 'MPi', 'GNA', 'LM', 'DG']
g4 = ['PAT33', 'PB', 'PAT22', 'MAXJ', 'PAT19', 'LP', 'HeMod', 'PAT50', 'PJuly']
g5 = ['PAT48', 'KS', 'PAT3', 'PAT8', 'LRio', 'HB', 'ALo', 'JAlv', 'PAT25', 'AR']

In [None]:
# Fold g5 is held out for validation; g1-g4 are used for training.
# g1..g5 are intentionally NOT deleted: the cross-fold inference cell at the
# end of the notebook needs all five groups again.
train_patients = g1 + g2 + g3 + g4
val_patients = g5

In [4]:
def prep_patient_data(patient_name, folder=None, file_names=None):
    """Load every segment file that belongs to one patient.

    A file belongs to ``patient_name`` when its name starts with the patient
    name immediately followed by a non-digit character, so ``'PAT2'`` does not
    pick up ``'PAT28_...'`` files. The label comes from the file name:
    ``'sp'`` -> 1.0, ``'ns'`` -> 0.0.

    Args:
        patient_name: patient identifier used as the file-name prefix.
        folder: directory holding the ``.mat`` files; defaults to the global
            ``data_folder``.
        file_names: candidate file names inside ``folder``; defaults to a sorted
            listing of ``folder``. This avoids relying on a notebook global that
            may have been deleted.

    Returns:
        (patient_segments, patient_labels): lists with one float32 tensor per
        file (the ``'segment'`` array of the .mat file) and one scalar label
        tensor per file.

    Raises:
        ValueError: if a matching file name contains neither ``'sp'`` nor ``'ns'``.
    """
    if folder is None:
        folder = data_folder
    if file_names is None:
        file_names = sorted(os.listdir(folder))

    patient_segments = []
    patient_labels = []

    # re.match anchors at the start; \D stops 'PAT2' from matching 'PAT28...'.
    regex = re.compile(re.escape(patient_name) + r'\D')

    for file_name in file_names:
        if not regex.match(file_name):
            continue

        if 'sp' in file_name:
            label = torch.tensor(1.0).float()
        elif 'ns' in file_name:
            label = torch.tensor(0.0).float()
        else:
            # Previously the label of the preceding file was silently reused here.
            raise ValueError(f'Cannot infer label (sp/ns) from file name {file_name!r}')

        # segment: np.array (no_electrodes, no_time_samples)
        segment = loadmat(os.path.join(folder, file_name))['segment']
        patient_segments.append(torch.tensor(segment).float())
        patient_labels.append(label)

    return patient_segments, patient_labels

In [5]:
def prepare_dataset(patients_names):
    """Gather the segments and labels of several patients into two flat lists.

    Patients are loaded one after another with ``prep_patient_data``. The
    output keeps that order: all files of the first patient, then the second,
    and so on.

    Returns:
        (group_segments, group_labels)
    """
    all_segments, all_labels = [], []
    for name in patients_names:
        segments, labels = prep_patient_data(name)
        all_segments.extend(segments)
        all_labels.extend(labels)
    return all_segments, all_labels

In [None]:
# Build the raw (not yet scaled) segment and label lists for the train and validation folds.
train_segments, train_labels = prepare_dataset(train_patients)
val_segments, val_labels = prepare_dataset(val_patients)

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """In-memory segment dataset with optional left/right hemisphere flipping.

    Args:
        segments: list of tensors whose first axis is the electrode, ordered
            as in ``montage``.
        labels: list of labels aligned with ``segments``.
        flip_prob: probability that ``__getitem__`` returns a left/right
            mirrored copy of the segment (0 disables augmentation).
        montage: electrode-order name ('BH' or 'MCH'). It is only looked up,
            and therefore only validated, when ``flip_prob > 0``.
    """

    def __init__(self, segments, labels, flip_prob, montage):
        self.segments = segments
        self.labels = labels
        self.flip_prob = flip_prob
        # The swap table is only needed when augmentation can happen.
        if flip_prob > 0:
            self.idx_swaps = self.montage_swap_func(montage)

    @staticmethod
    def montage_swap_func(montage):
        """Return (left_row, right_row) index pairs of homologous electrodes.

        Raises:
            NameError: for a montage other than 'BH' or 'MCH'.
        """
        if montage == 'BH':
            order = ['Fp1','F7','T3','T5','O1','F3','C3','P3','Fz','Cz','Pz','Fp2','F8','T4','T6','O2','F4','C4','P4']
        elif montage == 'MCH':
            order = ['C3','C4','O1','O2','Cz','F3','F4','F7','F8','Fz','Fp1','Fp2','P3','P4','Pz','T3','T4','T5','T6']
        else:
            raise NameError('Unavailable Montage')

        row_of = dict(zip(order, range(len(order))))
        homologous = [('Fp1', 'Fp2'), ('F7', 'F8'), ('F3', 'F4'), ('C3', 'C4'),
                      ('P3', 'P4'), ('T3', 'T4'), ('T5', 'T6'), ('O1', 'O2')]
        return [(row_of[left], row_of[right]) for left, right in homologous]

    def flip_lr(self, segment):
        """Swap every homologous left/right row pair of ``segment`` in place; return it."""
        # Apply the same swaps to an index list, then gather once.
        permutation = list(range(segment.shape[0]))
        for left, right in self.idx_swaps:
            permutation[left], permutation[right] = permutation[right], permutation[left]
        segment.copy_(segment[permutation])
        return segment

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        draw = random.random()
        if draw < self.flip_prob:
            # Flip a copy so the stored segment is never modified.
            return self.flip_lr(torch.clone(self.segments[index])), self.labels[index]
        return self.segments[index], self.labels[index]

In [6]:
def median_IQR_func(group_segments):
    """Median and inter-quartile range of all samples in a list of segments.

    Every element of every segment is pooled into one float32 vector. The 0.25,
    0.5 and 0.75 quantiles are then computed with linear interpolation.

    Args:
        group_segments: iterable of tensors (or array-likes) holding the samples.

    Returns:
        (median, IQR) as Python floats, where IQR = Q3 - Q1.

    Raises:
        ValueError: if there are no samples at all.
    """
    # Vectorised flatten. The original made one `.item()` call per sample,
    # i.e. millions of Python-level calls for a full fold.
    flat = [torch.as_tensor(seg, dtype=torch.float32).reshape(-1) for seg in group_segments]
    values = torch.cat(flat) if flat else torch.empty(0)
    if values.numel() == 0:
        raise ValueError('median_IQR_func: no samples to compute quantiles from')

    # np.quantile instead of torch.quantile: the torch version refuses inputs
    # above 2**24 elements, which a full training fold can exceed.
    q1, median, q3 = np.quantile(values.cpu().numpy(), [0.25, 0.5, 0.75])
    return float(median), float(q3) - float(q1)

In [None]:
def scaler_func(group_segments, median, IQR):
    """Robust-scale every segment as (segment - median) / IQR.

    The list is updated in place: each element is replaced by a new scaled
    tensor, and the same list object is returned for convenience.
    """
    group_segments[:] = [(seg - median) / IQR for seg in group_segments]
    return group_segments

In [None]:
# Robust scaling: median/IQR are computed on the TRAINING fold only and then
# applied to both folds, so no validation statistics leak into training.
if SCALE:
    median, IQR = median_IQR_func(train_segments)
    train_segments = scaler_func(train_segments, median, IQR)
    val_segments = scaler_func(val_segments, median, IQR)

# Training set: random left/right hemisphere flip with probability 0.5 ('BH' electrode order).
# Validation set: no augmentation.
train_dataset = CustomDataset(train_segments, train_labels, 0.5, 'BH')
val_dataset = CustomDataset(val_segments, val_labels, 0., 'BH')

In [None]:
# Drop notebook-level names; the Dataset objects keep their own references to
# the segment/label lists. `files_names` is kept because `prep_patient_data`
# (via `prepare_dataset`) reads it again in the cross-fold inference cell.
del train_segments, train_labels, train_patients, val_patients, val_segments, val_labels


In [7]:
class PositionalEncoding(torch.nn.Module):
    """Add a fixed sinusoidal positional encoding along the attended axis.

    For an input laid out as (..., d, L), with d features and L positions,
    feature row ``i`` receives ``sin`` (even ``i``) or ``cos`` (odd ``i``) of
    ``pos / base ** (2 * (i // 2 + 1) / d)``.

    Args:
        base: base of the geometric frequency progression.
        dim: default attended axis: -1 for (..., d, L) inputs, -2 for (..., L, d)
            inputs. It can be overridden per call in ``forward``.
    """

    def __init__(self, base, dim=None):
        super().__init__()
        self.base = base
        # dim: dimension across which to attend
        if (dim is not None) and (dim not in {-1, -2}):
            raise NameError('Error: Set the attention dimension to either -1 or -2 !')

        self.dim = dim

    def _sinusoid_table(self, d, T, device):
        """Return the (d, T) float32 encoding table."""
        t = torch.arange(T, device=device).reshape(1, T)
        col_len = math.ceil(d / 2)
        # NOTE: the frequency index starts at 1 (a standard transformer PE
        # starts at 0). Kept so trained checkpoints see the same encoding.
        col = torch.arange(start=1, end=col_len + 1, device=device).reshape(col_len, 1)
        angle = t / (self.base ** ((2 * col) / d))      # (ceil(d/2), T)

        pe = torch.zeros((d, T), device=device)
        pe[0::2, :] = torch.sin(angle)                  # ceil(d/2) even rows
        pe[1::2, :] = torch.cos(angle[: d // 2])        # floor(d/2) odd rows
        return pe

    def forward(self, x, dim=None):
        """Return ``x`` plus the positional encoding.

        The input is NOT modified; a new tensor of the same shape is returned.
        Any number of leading (batch) dimensions is accepted.

        Args:
            x: (..., d, L) tensor when dim == -1, or (..., L, d) when dim == -2.
            dim: attended axis for this call; falls back to ``self.dim``.

        Raises:
            NameError: if no valid attended axis (-1 or -2) is available.
        """
        if dim is None:
            if self.dim is None:
                raise NameError('Error: Expecting attention dimension to be either -1 or -2, but None was given!')
            dim = self.dim
        elif dim not in {-1, -2}:
            raise NameError('Error: Set the attention dimension to either -1 or -2 !')

        if dim == -2:
            T, d = x.shape[-2], x.shape[-1]
            pe = self._sinusoid_table(d, T, x.device).transpose(0, 1)   # (L, d)
        else:
            d, T = x.shape[-2], x.shape[-1]
            pe = self._sinusoid_table(d, T, x.device)                   # (d, L)

        # Out-of-place add: mutating the input in place is fragile under autograd
        # and silently changes the caller's tensor.
        return x + pe.to(x.dtype)

In [8]:
class MaskedMultiHeadSelfAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention with an optional mask.

    Args:
        no_input_features: feature size C of the input and of the output.
        H: number of heads.
        d_qk: per-head query/key size.
        d_v: per-head value size.
        dim: attended axis: -2 for (B, L, C) inputs, -1 for (B, C, L) inputs
            (L is T or E).
        mask: default mask, used when ``forward`` gets none. Either one of the
            strings "all_but_self", "no_future", "no_present_no_future", or a
            tensor of shape (L, L) (or broadcastable to (B, H, L, L)) whose
            True/non-zero entries mark allowed query->key pairs.
        bias_att, bias_out: bias flags of the fused QKV and output projections.
    """

    def __init__(self, no_input_features, H, d_qk, d_v, dim, mask=None, bias_att=True, bias_out=True):
        super().__init__()
        self.H = H
        self.d_qk = d_qk
        self.d_v = d_v

        # dim: dimension across which to attend
        if dim not in {-1, -2}:
            raise NameError('Error: Enforce your attention dimension to either -1 or -2 !')
        self.dim = dim

        self.mask = mask

        # One fused projection producing [Q | K | V] for all heads.
        self.att_lin = torch.nn.Linear(in_features=no_input_features, out_features=(2 * H * d_qk) + (H * d_v), bias=bias_att)
        self.out_lin = torch.nn.Linear(in_features=H*d_v, out_features=no_input_features, bias=bias_out)

    @staticmethod
    def _build_mask(mask, L, device):
        """Turn a mask spec into a boolean (L, L)-like tensor on ``device`` (True = may attend)."""
        if isinstance(mask, str):
            if mask == 'all_but_self':
                return torch.eye(L, device=device) == 0
            if mask == 'no_future':
                return torch.ones((L, L), device=device).tril() == 1
            if mask == 'no_present_no_future':
                return torch.ones((L, L), device=device).tril(-1) == 1
            raise NameError(f'Error: Available mask strings are ("all_but_self", "no_future", "no_present_no_future") but "{mask}" was given!')
        # User-supplied tensor: every entry not equal to False/0 is allowed.
        # It is also moved to the logits' device.
        return torch.as_tensor(mask, device=device) != False

    @staticmethod
    def qkv_func(q, k, v, mask=None):
        """Scaled dot-product attention with the heads merged at the end.

        Args:
            q, k: (B, H, L, d_qk); v: (B, H, L, d_v).
            mask: optional mask spec (see the class docstring).

        Returns:
            (B, L, H * d_v) tensor. A query with no allowed key gets an all-zero
            row instead of a uniform average over every key.
        """
        B, H, L, d_qk = q.shape
        d_v = v.shape[-1]

        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_qk)   # (B, H, L, L)

        allowed = None
        if mask is not None:
            allowed = MaskedMultiHeadSelfAttention._build_mask(mask, L, logits.device).expand(B, H, L, L)
            logits = logits.masked_fill(~allowed, -9e15)

        weights = logits.softmax(dim=-1)                                 # (B, H, L, L)
        if allowed is not None:
            # E.g. position 0 under "no_present_no_future" has no allowed key.
            # Without this, softmax over all -9e15 is uniform and leaks the future.
            weights = weights * allowed.any(dim=-1, keepdim=True)

        # (B, H, L, d_v) -> (B, L, H, d_v) -> (B, L, H * d_v)
        return torch.matmul(weights, v).transpose(-2, -3).reshape((B, L, H * d_v))

    def forward(self, x, mask=None):
        """Self-attention over the ``self.dim`` axis of ``x``.

        ``x`` is not modified. The output has the same shape and layout as ``x``.
        A ``mask`` given here takes priority over ``self.mask``.
        """
        if self.dim == -1:
            x = x.transpose(-1, -2)      # (B, C, L) -> (B, L, C), as a view

        B, L, _ = x.shape

        q, k, v = self.att_lin(x).split([self.H * self.d_qk, self.H * self.d_qk, self.H * self.d_v], dim=-1)
        q = q.reshape((B, L, self.H, self.d_qk)).transpose(-2, -3)     # (B, H, L, d_qk)
        k = k.reshape((B, L, self.H, self.d_qk)).transpose(-2, -3)     # (B, H, L, d_qk)
        v = v.reshape((B, L, self.H, self.d_v)).transpose(-2, -3)      # (B, H, L, d_v)

        y = self.out_lin(self.qkv_func(q, k, v, mask=self.mask if mask is None else mask))   # (B, L, C)

        if self.dim == -1:
            y = y.transpose(-1, -2)      # back to (B, C, L)
        return y

In [9]:
class EncoderBlock(torch.nn.Module):
    """Post-LayerNorm transformer encoder block: self-attention, then feed-forward.

        y = LN(x + Dropout(MHSA(x)))
        y = LN(y + Dropout(FFN(y)))

    Args:
        no_input_features: feature size C of the input and output.
        no_mid_features: hidden size of the feed-forward sub-layer.
        H, d_qk, d_v: number of heads and per-head query/key and value sizes.
        dim: attended axis of the input: -2 for (B, L, C), -1 for (B, C, L).
        mask: default attention mask (string name or tensor), used when
            ``forward`` is called without one.
        bias_att, bias_out: bias flags of the attention projections.
        p_dropout: dropout probability.
    """

    def __init__(self, no_input_features, no_mid_features, H, d_qk, d_v, dim, mask=None, bias_att=True, bias_out=True, p_dropout=0.1):
        super().__init__()

        self.mask = mask

        # The sub-layers always work on (B, L, C); `forward` transposes
        # (B, C, L) inputs when dim == -1.
        self.dim = dim

        self.att = MaskedMultiHeadSelfAttention(no_input_features, H, d_qk, d_v, -2, mask=mask, bias_att=bias_att, bias_out=bias_out)
        self.norm_att = torch.nn.LayerNorm(no_input_features)

        self.seq_feed_forward = torch.nn.Sequential(
            torch.nn.Linear(no_input_features, no_mid_features),
            torch.nn.Dropout(p_dropout),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(no_mid_features, no_input_features)
        )
        self.norm_feed_forward = torch.nn.LayerNorm(no_input_features)

        self.dropout = torch.nn.Dropout(p_dropout)

    def forward(self, x, mask=None):
        """Apply the block. ``x`` is not modified; the output keeps ``x``'s shape/layout.

        A ``mask`` given here takes priority over ``self.mask``.
        """
        if self.dim == -1:
            x = x.transpose(-1, -2)      # (B, C, L) -> (B, L, C), as a view (no in-place mutation)

        att_mask = self.mask if mask is None else mask
        y = self.norm_att(x + self.dropout(self.att(x, att_mask)))
        y = self.norm_feed_forward(y + self.dropout(self.seq_feed_forward(y)))

        if self.dim == -1:
            y = y.transpose(-1, -2)      # back to (B, C, L)
        return y
    

In [None]:
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by a half-cosine decay of the learning rate.

    Each param group's LR is ``base_lr * factor(last_epoch)``. The factor rises
    linearly from 0 to 1 over ``warmup`` steps. It then follows
    ``0.5 * (1 + cos(pi * progress))`` down to 0 at ``max_num_iters`` and stays
    at 0 afterwards.

    Args:
        optimizer: wrapped optimizer.
        warmup: number of warm-up steps (0 disables warm-up).
        max_num_iters: step at which the cosine reaches 0.
    """

    def __init__(self, optimizer, warmup, max_num_iters):
        self.warmup = warmup
        self.max_num_iters = max_num_iters
        super().__init__(optimizer)

    def get_lr(self):
        """Return the list of learning rates (one per param group) for ``self.last_epoch``."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the LR multiplier (not the LR itself) for step ``epoch``."""
        # `warmup > 0` check: warmup == 0 used to divide by zero at construction.
        if self.warmup > 0 and epoch <= self.warmup:
            slope = (1 - 0) / (self.warmup - 0)
            return slope * epoch

        decay_steps = self.max_num_iters - self.warmup
        if decay_steps <= 0:
            # No decay phase: the schedule is over once warm-up ends.
            return 0.0

        # Clamp so the factor stays at 0 past max_num_iters instead of climbing
        # back up the next cosine period.
        progress = min(epoch - self.warmup, decay_steps)
        f = 1 / (2 * decay_steps)
        return 0.5 * (1 + math.cos(2 * math.pi * f * progress))
    

In [None]:
# Notebook-level bookkeeping, updated (via `global`) by
# CustomLightningModule.validation_epoch_end: the lowest validation loss seen so
# far, the epoch it occurred at, a deep copy of the state_dict at that epoch,
# and the validation AP at that epoch.
min_validation_loss = math.inf
best_epoch = 0
best_model_state = None
AP_at_min_val_loss = 0.


In [10]:
class CustomLightningModule(pl.LightningModule):
    """Convolutional front-end plus three attention blocks for binary segment classification.

    Stacked (1, k) convolutions and pools reduce each electrode's time series to a
    13-dim feature vector. Three encoder blocks then attend across the 19
    electrodes ("all_but_self" mask). The electrode-summed features of all four
    stages are concatenated and mapped to one logit per segment.

    Optimisation is manual (``automatic_optimization = False``). The LR
    schedulers are stepped once per epoch from ``validation_epoch_end``, which
    also records the lowest-validation-loss weights in the notebook globals
    ``min_validation_loss``, ``best_epoch``, ``best_model_state`` and
    ``AP_at_min_val_loss``.

    Args:
        p_dropout: dropout probability of the convolutional and attention layers.
    """

    def __init__(self, p_dropout):
        super().__init__()

        # Temporal feature extractor. All kernels/pools are (1, k), so electrodes
        # are never mixed here; the time axis shrinks from 300 to 1.
        self.vgg = torch.nn.Sequential(
            # (B, 1, 19, 300) ====> (B, 2, 19, 298)
            torch.nn.Conv2d(in_channels=1, out_channels=2, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 2, 19, 298) ====> (B, 3, 19, 296)
            torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 3, 19, 296) ==/2==> (B, 3, 19, 148)
            torch.nn.MaxPool2d(kernel_size=(1,2)),

            # (B, 3, 19, 148) ====> (B, 4, 19, 146)
            torch.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 4, 19, 146) ====> (B, 5, 19, 144)
            torch.nn.Conv2d(in_channels=4, out_channels=5, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 5, 19, 144) ==/2==> (B, 5, 19, 72)
            torch.nn.MaxPool2d(kernel_size=(1,2)),

            # (B, 5, 19, 72) ====> (B, 6, 19, 70)
            torch.nn.Conv2d(in_channels=5, out_channels=6, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 6, 19, 70) ====> (B, 7, 19, 68)
            torch.nn.Conv2d(in_channels=6, out_channels=7, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 7, 19, 68) ====> (B, 8, 19, 66)
            torch.nn.Conv2d(in_channels=7, out_channels=8, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 8, 19, 66) ====> (B, 9, 19, 64)
            torch.nn.Conv2d(in_channels=8, out_channels=9, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 9, 19, 64) ==/2==> (B, 9, 19, 32)
            torch.nn.MaxPool2d(kernel_size=(1,2)),

            # (B, 9, 19, 32) ====> (B, 10, 19, 30)
            torch.nn.Conv2d(in_channels=9, out_channels=10, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 10, 19, 30) ====> (B, 11, 19, 28)
            torch.nn.Conv2d(in_channels=10, out_channels=11, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 11, 19, 28) ====> (B, 12, 19, 26)
            torch.nn.Conv2d(in_channels=11, out_channels=12, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 12, 19, 26) ====> (B, 13, 19, 24)
            torch.nn.Conv2d(in_channels=12, out_channels=13, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 13, 19, 24) ==/2==> (B, 13, 19, 12)
            torch.nn.MaxPool2d(kernel_size=(1,2)),

            # (B, 13, 19, 12) ====> (B, 13, 19, 10)
            torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 13, 19, 10) ====> (B, 13, 19, 8)
            torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 13, 19, 8) ====> (B, 13, 19, 6)
            torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 13, 19, 6) ====> (B, 13, 19, 4)
            torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 13, 19, 4) ====> (B, 13, 19, 2)
            torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=(1, 3)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout),
            # (B, 13, 19, 2) ====> (B, 13, 19, 1)
            torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=(1, 2)),
            torch.nn.LeakyReLU(), torch.nn.Dropout(p=p_dropout)
        )

        self.positional_encoding = PositionalEncoding(100)

        # features: 13, no_mid_features: 8, H: 3, d_qk: 5, d_v: 7, dim: -2
        self.encoder_block_1 = EncoderBlock(13, 8, 3, 5, 7, -2, mask='all_but_self', p_dropout=p_dropout)

        # features: 13, no_mid_features: 8, H: 3, d_qk: 5, d_v: 7, dim: -2
        self.encoder_block_2 = EncoderBlock(13, 8, 3, 5, 7, -2, mask='all_but_self', p_dropout=p_dropout)

        # features: 13, no_mid_features: 8, H: 3, d_qk: 5, d_v: 7, dim: -2
        self.encoder_block_3 = EncoderBlock(13, 8, 3, 5, 7, -2, mask='all_but_self', p_dropout=p_dropout)

        # Classification head over the concatenated outputs of the 4 stages (4 x 13 features).
        self.seq = torch.nn.Sequential(
            torch.nn.Linear(in_features=13*4, out_features=25),
            torch.nn.Dropout(0.0),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=25, out_features=10),
            torch.nn.Dropout(0.0),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=10, out_features=5),
            torch.nn.Dropout(0.0),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(in_features=5, out_features=1)
        )

        self.criterion = torch.nn.BCEWithLogitsLoss()
        # Optimisation is done by hand in training_step.
        self.automatic_optimization = False

    def forward(self, segment):
        """Return one logit per segment.

        Args:
            segment: (B, 19, 300) tensor (electrodes x time samples).

        Returns:
            (B,) tensor of logits.
        """
        H_0 = self.vgg(segment.unsqueeze(dim=1))     # H_0: (B, 13, 19, 1)
        H_0.squeeze_(dim=-1)                         # H_0: (B, 13, 19)
        H_0.transpose_(-1, -2)                       # H_0: (B, 19, 13)
        H_0 = self.positional_encoding(H_0, -2)      # H_0: (B, 19, 13)

        H_1 = self.encoder_block_1(H_0)             # H_1: (B, 19, 13)
        H_1 = self.positional_encoding(H_1, -2)     # H_1: (B, 19, 13)

        H_2 = self.encoder_block_2(H_1)             # H_2: (B, 19, 13)
        H_2 = self.positional_encoding(H_2, -2)     # H_2: (B, 19, 13)

        H_3 = self.encoder_block_3(H_2)             # H_3: (B, 19, 13)
        H_3 = self.positional_encoding(H_3, -2)     # H_3: (B, 19, 13)

        H = torch.cat((H_0, H_1, H_2, H_3), dim=-1)      # H: (B, 19, 13*4)
        H = H.sum(dim=-2)                                # H: (B, 13*4), summed over electrodes
        H = self.seq(H)                                  # H: (B, 1)

        return H.reshape((H_0.shape[0],))

    def configure_optimizers(self):
        """Adam plus two schedulers (cosine warm-up and SWA), both stepped manually."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=0)
        cosine_scheduler = CosineWarmupScheduler(optimizer=optimizer, warmup=100, max_num_iters=self.trainer.max_epochs)
        swa_scheduler = SWALR(optimizer, swa_lr=0.0001, anneal_epochs=100, anneal_strategy='cos')
        return [optimizer], [cosine_scheduler, swa_scheduler]

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            train_dataset, batch_size=train_B, shuffle=True, num_workers=0, drop_last=True
        )

    def training_step(self, batch, batch_idx):
        """One manual optimisation step on ``batch``.

        Before the update it records the batch loss in eval mode (dropout off, no
        grad); that value is reported by ``training_epoch_end``.
        """
        x, y = batch

        self.eval()
        with torch.no_grad():
            logits_before_optimizer_step = self(x)
            loss_before_optimizer_step = self.criterion(logits_before_optimizer_step, y)
        self.train()

        # Lightning's LightningOptimizer wrapper around the Adam optimizer.
        opt = self.optimizers()

        def closure():
            logits = self(x)
            loss = self.criterion(logits, y)
            opt.zero_grad()
            self.manual_backward(loss)
            return loss

        opt.step(closure=closure)

        # Collected by Lightning and passed to training_epoch_end.
        return {'loss_before_optimizer_step': loss_before_optimizer_step.detach()}

    def training_epoch_end(self, training_step_outputs):
        """Print the epoch's mean pre-update training loss (rank 0 only)."""
        # Guard: torch.stack([]) would raise if the epoch had no training steps.
        if self.global_rank == 0 and training_step_outputs:
            train_loss = torch.stack([dic['loss_before_optimizer_step'] for dic in training_step_outputs]).mean()
            print(f'Epoch: {self.current_epoch}')
            print(f'Training loss BEFORE optimizer step: {round(train_loss.item(), 5)}')
            print('----------------------------------------------------------------------------')

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            val_dataset, batch_size=val_B, shuffle=False, num_workers=0, drop_last=False
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        return {'loss': loss, 'y': y, 'logits': logits}

    def validation_epoch_end(self, validation_step_outputs):
        """Step the LR schedule, update SWA, track the best model and report val loss / AP.

        Lightning also calls this after the short sanity-check pass that runs
        before the first training epoch.
        """
        global min_validation_loss, best_epoch, best_model_state, AP_at_min_val_loss

        swa_start_epoch = 1000

        cosine_scheduler, swa_scheduler = self.lr_schedulers()

        # The sanity-check pass must not advance the schedule. Every real pass,
        # epoch 0 included, advances it exactly once. The previous
        # `current_epoch > 0` guard also skipped epoch 0's real pass, which kept
        # the warm-up at step 0 (LR = 0) for the first two epochs.
        if not self.trainer.sanity_checking:
            if self.current_epoch < swa_start_epoch:
                cosine_scheduler.step()
            else:
                swa_scheduler.step()

        if self.global_rank == 0:

            if self.current_epoch > (swa_start_epoch + swa_scheduler.anneal_epochs):
                swa_model.update_parameters(self)

            val_loss = torch.stack([dic['loss'] for dic in validation_step_outputs]).mean().item()
            print(f'val_loss: {val_loss}')

            y = np.array(torch.cat([dic['y'] for dic in validation_step_outputs]).cpu())
            logits = torch.cat([dic['logits'] for dic in validation_step_outputs]).cpu()
            pred = np.array(torch.sigmoid(logits))

            try:
                validation_AP = average_precision_score(y, pred)

                # Keep the weights with the lowest validation loss (epochs 0 and 1 are ignored).
                if (val_loss < min_validation_loss) and (self.current_epoch > 1):
                    min_validation_loss = val_loss
                    best_epoch = self.current_epoch
                    best_model_state = deepcopy(self.state_dict())
                    AP_at_min_val_loss = validation_AP

                print(f'validation_AP = {round(validation_AP, 5)} ... AP_at_min_val_loss = {round(AP_at_min_val_loss, 5)} @ epoch {best_epoch}')

            except ValueError:
                print(f'ValueError @ epoch: {self.current_epoch} ====> y = {y}')

            print('=======================================================================================')
    

In [None]:
my_lightning_module = CustomLightningModule(p_dropout)

# Running weight average. `validation_epoch_end` reads this global and only
# updates it after epoch swa_start_epoch + anneal_epochs (1000 + 100), so with
# max_epochs=400 below it is never updated in this run.
swa_model = AveragedModel(my_lightning_module)

# NOTE(review): `gpus=` is a Lightning 1.x argument (replaced by accelerator/devices in 2.0).
trainer = pl.Trainer(gpus=0, enable_checkpointing=False, enable_progress_bar=False, logger=False, max_epochs=400)

trainer.fit(my_lightning_module)

# Update bn statistics for the swa_model at the end.
# NOTE(review): the network defined above has no BatchNorm layers, so this is
# presumably a no-op here; confirm before relying on swa_model.
torch.optim.swa_utils.update_bn(my_lightning_module.train_dataloader(), swa_model)

In [None]:
# Persist the weights from the epoch with the lowest validation loss. The file
# name encodes the held-out fold (g5) and the validation AP (in %) at that epoch.
if best_model_state is None:
    raise RuntimeError('No best model state was recorded (no epoch after epoch 1 improved the validation loss); nothing to save.')
os.makedirs('models', exist_ok=True)
PATH = f'models/g5_attention_GIN_10k_AP_{round(100 * AP_at_min_val_loss, 2)}.pt'
torch.save(best_model_state, PATH)


In [11]:
# NOTE(review): move this import to the top imports cell.
import pandas as pd


# Model instance with randomly initialised weights until a state_dict is loaded into it.
model = CustomLightningModule(p_dropout)

# Score the 'ns'-labelled (label 0) files. The commented line selects the
# 'sp'-labelled files instead.
#IEDs_names = [item for item in os.listdir(data_folder) if 'sp' in item]
NIEDs_names = [item for item in os.listdir(data_folder) if 'ns' in item]


In [12]:
# Cross-fold inference. Each file is scored by the model trained with its
# patient's group held out. Before scoring, it is robust-scaled with the
# median/IQR of that model's training groups, the same statistics used
# during that model's training.
FOLDS = {'g1': g1, 'g2': g2, 'g3': g3, 'g4': g4, 'g5': g5}


def _checkpoint_ap(file_name):
    """Return the validation AP encoded in a checkpoint name ('..._AP_<value>.pt').

    Names without a parsable AP rank last (-inf). Comparing the parsed number
    instead of the raw string avoids e.g. 'AP_9.5' beating 'AP_85.3'.
    """
    match = re.search(r'_AP_(\d+(?:\.\d+)?)\.pt$', file_name)
    return float(match.group(1)) if match else -math.inf


fold_models = {}
fold_stats = {}
for fold_name, fold_patients in FOLDS.items():
    candidates = [item for item in os.listdir('models/') if f'{fold_name}_attention_GIN' in item]
    if not candidates:
        raise FileNotFoundError(f'No checkpoint matching "{fold_name}_attention_GIN" in models/')
    best_checkpoint = max(candidates, key=_checkpoint_ap)

    # Training groups of this model = every other group, in g1..g5 order.
    other_patients = [p for other, patients in FOLDS.items() if other != fold_name for p in patients]
    train_segments, _ = prepare_dataset(other_patients)
    fold_stats[fold_name] = median_IQR_func(train_segments)

    # Load each checkpoint once, not once per scored file.
    fold_model = CustomLightningModule(p_dropout)
    fold_model.load_state_dict(torch.load(f'models/{best_checkpoint}'))
    fold_model.eval()
    fold_models[fold_name] = fold_model

# patient -> group; the first group listing a patient wins (as in an if/elif chain).
patient_fold = {}
for fold_name, fold_patients in FOLDS.items():
    for patient in fold_patients:
        patient_fold.setdefault(patient, fold_name)

# To score the 'sp' files instead, use IEDs_names in place of NIEDs_names below.
s = pd.Series(index=NIEDs_names, dtype='float64')

with torch.no_grad():   # inference only: no autograd graph needed
    for NIED_name in NIEDs_names:
        segment = torch.tensor(loadmat(os.path.join(data_folder, NIED_name))['segment']).float()

        # The patient name is the file-name prefix before the first '_'.
        patient_name = NIED_name[: NIED_name.find('_')]
        fold_name = patient_fold.get(patient_name)
        if fold_name is None:
            raise NameError('Unknown Patient')

        fold_median, fold_IQR = fold_stats[fold_name]
        segment = (segment - fold_median) / fold_IQR
        s.at[NIED_name] = torch.sigmoid(fold_models[fold_name](segment.unsqueeze(0))).item()

In [13]:
# Sigmoid score per file (index: file name).
s

RC_24_ns.mat      0.001003
FEsg_38_ns.mat    0.004454
HB_2_ns.mat       0.005053
FEsg_26_ns.mat    0.003548
FigSa_9_ns.mat    0.000866
                    ...   
LM_3_ns.mat       0.006997
RC_6_ns.mat       0.000534
FEsg_19_ns.mat    0.002681
NKra_10_ns.mat    0.009048
PAT48_ns_6.mat    0.003915
Length: 593, dtype: float64

In [14]:
# Write the scores to sheet 'attention_GIN' of the results workbook. The
# workbook is created on first use, and the sheet is overwritten on re-runs.
# A plain mode='a' fails when the file is missing and, depending on the pandas
# version, either errors or writes a renamed duplicate when the sheet exists.
# (if_sheet_exists requires pandas >= 1.3.)
# For the 'sp' scores use 'IEDs_results.xlsx'.
results_path = 'NIEDs_results.xlsx'
if os.path.exists(results_path):
    writer_kwargs = {'mode': 'a', 'if_sheet_exists': 'replace'}
else:
    writer_kwargs = {'mode': 'w'}
with pd.ExcelWriter(results_path, **writer_kwargs) as writer:
    s.to_excel(writer, sheet_name='attention_GIN', header=False)