In [1]:
# For some reason, file descriptors (FDs) do not get released.
# Workaround: raise this process' soft limit on the number of open files.
import resource

# rlimit is the (soft, hard) pair that was in force before any change.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
soft_nofile, hard_nofile = rlimit

# The soft limit can never exceed the hard limit (setrlimit raises ValueError
# otherwise), so cap the desired value of 65535 at the hard limit.
if hard_nofile == resource.RLIM_INFINITY:
    wanted_nofile = 65535
else:
    wanted_nofile = min(65535, hard_nofile)

# Only ever raise the soft limit; never lower a limit that is already higher.
if soft_nofile != resource.RLIM_INFINITY and soft_nofile < wanted_nofile:
    resource.setrlimit(resource.RLIMIT_NOFILE, (wanted_nofile, hard_nofile))

In [2]:
import math
import os
import re
import random
import multiprocessing as mp
import numpy as np
from scipy.io import loadmat
from scipy.integrate import simpson
import torch
import torch_scatter
import pytorch_lightning as pl
from sklearn.metrics import roc_auc_score, average_precision_score
from copy import deepcopy
from torch.optim.swa_utils import AveragedModel, SWALR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ---------------------------------------------------------------------------
# Configuration cell: every tunable of the notebook lives here.
# ---------------------------------------------------------------------------

# Batch size
train_B = 16
val_B = 1

# Dropout probability
p_dropout = 0.

# When True, segments are scaled with (x - median) / IQR before training.
SCALE = True
data_folder = 'BH average 100-sps (0.75-38 Hz)'
# os.listdir returns entries in arbitrary order; sort them so that the order
# of the loaded examples is reproducible across runs and machines.
files_names = sorted(os.listdir(data_folder))

# Surrogate-based thresholding of the connectivity (see threshold_func).
USE_SURROGATES = False
no_surrogates = 500
surrogate_thr = 0.9

# wPLI estimation parameters (see wpli_raw): sampling rate, samples per
# sub-window and overlap between consecutive sub-windows.
fs = 100
nperseg = 140
noverlap = 135
freq_bands = ['delta', 'theta', 'alpha', 'beta']

# Model options (presumably forwarded to CustomLightningModule -- confirm at the call site).
attention = True
dedicated_vgg = False
eps = 0.
train_eps = False
dedicated_GIN = True

# `montage` starts as the montage name and is replaced by the ordered list of
# electrode labels; the position of a label is the electrode index used in the graphs.
montage = 'BH'
if montage == 'BH':
    montage = ['Fp1','F7','T3','T5','O1','F3','C3','P3','Fz','Cz','Pz','Fp2','F8','T4','T6','O2','F4','C4','P4']
elif montage == 'MCH':
    montage = ['C3','C4','O1','O2','Cz','F3','F4','F7','F8','Fz','Fp1','Fp2','P3','P4','Pz','T3','T4','T5','T6']
else:
    raise NameError('Unavailable Montage')


# Patient identifiers, partitioned into five groups (folds).
g1 = ['PAT31', 'MC', 'FEsg', 'DHut', 'NKra', 'PAT2']
g2 = ['RC', 'MBra', 'ESow', 'PAT1', 'EG', 'FigSa', 'PAT28']
g3 = ['PAT47', 'RA', 'PAT51', 'MPi', 'GNA', 'LM', 'DG']
g4 = ['PAT33', 'PB', 'PAT22', 'MAXJ', 'PAT19', 'LP', 'HeMod', 'PAT50', 'PJuly']
g5 = ['PAT48', 'KS', 'PAT3', 'PAT8', 'LRio', 'HB', 'ALo', 'JAlv', 'PAT25', 'AR']

In [None]:
# Group g2 is held out for validation; the four remaining groups are used for training.
# NOTE: the groups are deleted afterwards, so this cell cannot be re-run
# without re-running the configuration cell first.
val_patients = g2
train_patients = [*g1, *g3, *g4, *g5]
del g1, g2, g3, g4, g5

In [4]:
def wpli_raw(x, y, fs, nperseg, noverlap, norm='ortho'):
    """Weighted phase-lag index (wPLI) spectrum between paired signals.

    Each signal is cut into overlapping sub-windows of ``nperseg`` samples
    (consecutive windows start ``nperseg - noverlap`` samples apart), the
    cross-spectrum of every window pair is computed with a real FFT, and

        wPLI(f) = |mean_w Im(X_w(f) * conj(Y_w(f)))| / mean_w |Im(X_w(f) * conj(Y_w(f)))|

    is evaluated across the windows. Frequencies where the denominator is
    exactly zero get wPLI = 0.

    Parameters
    ----------
    x, y : np.ndarray
        Same shape, either (no_time_samples,) or (no_signals, no_time_samples).
        For 2-D input the wPLI is computed for each pair (x_i, y_i).
    fs : float
        Sampling rate; only used to build the frequency axis.
    nperseg : int
        Number of samples per sub-window.
    noverlap : int
        Number of samples shared by two consecutive sub-windows.
    norm : str
        Normalisation mode forwarded to ``np.fft.rfft``.

    Returns
    -------
    f : np.ndarray, shape (no_positive_freqs,)
    wpli : np.ndarray, shape (no_positive_freqs,) or (no_signals, no_positive_freqs)

    Raises
    ------
    NameError
        If the shapes are invalid, if ``noverlap >= nperseg``, if the signals
        are shorter than ``nperseg``, or if the sub-windows do not tile the
        signal exactly.
    """
    no_original_dimensions = x.ndim

    if (x.shape != y.shape) or (no_original_dimensions not in {1, 2}):
        raise NameError(
            f'x & y MUST be 1/2-D arrays of the same shape. Given x.shape: ({x.shape}), y.shape: ({y.shape})!'
        )

    if no_original_dimensions == 1:
        x = np.expand_dims(x, axis=0)
        y = np.expand_dims(y, axis=0)

    # Now, x.shape = y.shape: (no_signals, no_time_samples)
    no_signals, no_time_samples = x.shape
    step = nperseg - noverlap

    # Guard against a zero/negative hop (would divide by zero or walk backwards)
    # and against signals too short to hold a single sub-window (would silently
    # produce zero segments and NaN output).
    if (nperseg <= 0) or (step <= 0):
        raise NameError(
            f'nperseg ({nperseg}) MUST be positive and noverlap ({noverlap}) MUST be smaller than nperseg!'
        )
    if no_time_samples < nperseg:
        raise NameError(
            f'The signals have {no_time_samples} samples, fewer than nperseg ({nperseg})!'
        )

    start_time_idx_last_segment = no_time_samples - nperseg
    if (start_time_idx_last_segment % step) != 0:
        raise NameError('nperseg & noverlap should be chosen carefully!')

    no_segments = 1 + (start_time_idx_last_segment // step)

    # window_idx[s, t] is the time index of sample t of sub-window s.
    window_idx = (step * np.arange(no_segments))[:, None] + np.arange(nperseg)[None, :]
    X = np.asarray(x[:, window_idx], dtype=np.float64)     # (no_signals, no_segments, nperseg)
    Y = np.asarray(y[:, window_idx], dtype=np.float64)     # (no_signals, no_segments, nperseg)

    f = np.arange(start=0, stop=1 + (nperseg // 2), step=1) * (fs / nperseg)     # (no_positive_freqs,)

    X = np.fft.rfft(X, norm=norm)     # (no_signals, no_segments, no_positive_freqs)
    Y = np.fft.rfft(Y, norm=norm)     # (no_signals, no_segments, no_positive_freqs)

    # Imaginary part of the cross-spectrum of every sub-window.
    I = (X * Y.conjugate()).imag

    numerator = np.abs(np.mean(I, axis=-2, keepdims=False))       # (no_signals, no_positive_freqs)
    denominator = np.mean(np.abs(I), axis=-2, keepdims=False)     # (no_signals, no_positive_freqs)

    # A zero denominator implies a zero numerator; dividing by inf yields 0 instead of NaN.
    denominator[denominator == 0] = np.inf

    wpli = numerator / denominator     # (no_signals, no_positive_freqs)

    if no_original_dimensions == 1:
        wpli = wpli.squeeze(axis=0)     # (no_positive_freqs,)

    return f, wpli


In [5]:
def freq_FC_integral(f, freq_FC, freq_bands=['delta', 'theta', 'alpha', 'beta']):
    """Average a connectivity spectrum over named frequency bands.

    For every requested band the spectrum is integrated (Simpson's rule) over
    the frequency bins that fall inside the band, limits included, and divided
    by the width actually covered by those bins.

    Parameters
    ----------
    f : np.ndarray, shape (no_positive_freqs,)
        Uniformly spaced frequency axis (the spacing is taken from f[1] - f[0]).
    freq_FC : np.ndarray
        Shape (no_positive_freqs,) or (no_signals, no_positive_freqs).
    freq_bands : sequence of str
        Any of 'delta' (0.5-4), 'theta' (4-8), 'alpha' (8-13), 'beta' (13-30),
        'wide' (0.5-30). The default list is never mutated.

    Returns
    -------
    np.ndarray
        Shape (num_freq_bands,) for 1-D ``freq_FC`` and
        (no_signals, num_freq_bands) for 2-D ``freq_FC``.

    Raises
    ------
    NameError
        On malformed input, an unknown band name, or a band that holds fewer
        than two frequency bins (the band average would be undefined).
    """
    band_limits = {
        'delta': (0.5, 4.),
        'theta': (4., 8.),
        'alpha': (8., 13.),
        'beta': (13., 30.),
        'wide': (0.5, 30.),
    }

    num_freq_bands = len(freq_bands)
    no_original_dimensions_freq_FC = freq_FC.ndim

    if f.ndim != 1:
        raise NameError('f should be a 1-D array!')
    elif len(f) != freq_FC.shape[-1]:
        raise NameError(f'len(f) ({len(f)}) != freq_FC.shape[-1] ({freq_FC.shape[-1]}) ... They have to match !!')
    elif no_original_dimensions_freq_FC == 1:
        freq_FC = np.expand_dims(freq_FC, axis=0)
    elif no_original_dimensions_freq_FC != 2:
        raise NameError('freq_FC should be 1/2-D array!')

    if len(f) < 2:
        raise NameError('f should contain at least 2 frequencies!')

    # Now, freq_FC: (no_signals, no_positive_freqs)
    no_signals, _ = freq_FC.shape

    ans = np.zeros((no_signals, num_freq_bands))

    df = f[1] - f[0]
    for freq_band_idx, freq_band in enumerate(freq_bands):
        if freq_band not in band_limits:
            raise NameError('Unavailable Band')
        low, high = band_limits[freq_band]

        in_band = np.logical_and(low <= f, f <= high)     # (no_positive_freqs,)
        masked_f = f[in_band]

        # With 0 bins there is nothing to integrate, with 1 bin the covered
        # width is zero: fail loudly instead of IndexError / division by zero.
        if len(masked_f) < 2:
            raise NameError(
                f'Band "{freq_band}" [{low}, {high}] Hz contains {len(masked_f)} frequency bin(s); at least 2 are needed!'
            )

        masked_freq_FC = freq_FC[:, in_band]              # (no_signals, no_bins_in_band)

        # Normalise by the width spanned by the selected bins, not by (high - low).
        ans[:, freq_band_idx] = simpson(masked_freq_FC, dx=df) / (masked_f[-1] - masked_f[0])

    if no_original_dimensions_freq_FC == 1:
        ans = ans.squeeze(axis=0)       # (num_freq_bands,)

    return ans
    

In [6]:
def wpli_combined(x, y, fs, nperseg, noverlap, freq_bands=['delta', 'theta', 'alpha', 'beta']):
    """Band-averaged wPLI between paired signals.

    Chains ``wpli_raw`` (wPLI spectrum) and ``freq_FC_integral`` (average over
    each requested band).

    x, y: same shape, (no_time_samples,) or (no_signals, no_time_samples).
    Returns an array of shape (num_freq_bands,) for 1-D input and
    (no_signals, num_freq_bands) for 2-D input.
    """
    freqs, spectrum = wpli_raw(x, y, fs, nperseg, noverlap)
    return freq_FC_integral(freqs, spectrum, freq_bands=freq_bands)

In [7]:
def wpli_matrix_generator(segment, fs, nperseg, noverlap, freq_bands=['delta', 'theta', 'alpha', 'beta']):
    """Band-averaged wPLI for every unordered electrode pair of one segment.

    segment: (no_electrodes, no_time_samples); fs: sampling rate.

    Returns wpli_matrix of shape (no_connectivities, num_freq_bands). Rows are
    ordered (0,1), (0,2), ..., (0,E-1), (1,2), ..., (E-2,E-1); each row holds
    the connectivity of that pair for every band in ``freq_bands``.
    """
    n_channels, n_samples = segment.shape

    def _against_later_channels(first):
        # Pair channel `first` with every channel that follows it, in one batched call.
        partners = segment[first + 1:, :]
        repeated = np.broadcast_to(segment[first], (n_channels - first - 1, n_samples))
        return wpli_combined(repeated, partners, fs, nperseg, noverlap, freq_bands=freq_bands)

    per_channel = [_against_later_channels(first) for first in range(n_channels - 1)]

    # Stack the per-channel blocks: (no_connectivities, num_freq_bands)
    return np.concatenate(per_channel)


In [8]:
def surrogate_func(x, no_surrogates, seed=None, norm='ortho'):
    """Generate phase-shuffled surrogates of one or several signals.

    The amplitude spectrum of each signal is kept and its phases are randomly
    permuted across frequency bins before transforming back to the time
    domain. The permutation is applied cumulatively: surrogate k shuffles the
    phases already shuffled for surrogate k-1.

    Parameters
    ----------
    x : np.ndarray
        Shape (no_time_samples,) or (no_signals, no_time_samples).
    no_surrogates : int
        Number of surrogates per signal.
    seed : int or None
        Seed of the random generator. ``None`` draws fresh entropy; any
        integer, including 0, gives reproducible surrogates.
    norm : str
        Normalisation mode forwarded to ``np.fft.rfft`` / ``np.fft.irfft``.

    Returns
    -------
    np.ndarray
        (no_surrogates, no_time_samples) for 1-D input,
        (no_signals, no_surrogates, no_time_samples) for 2-D input.

    Raises
    ------
    NameError
        If ``x`` is neither 1-D nor 2-D.
    """
    no_original_dimensions = x.ndim

    if no_original_dimensions == 1:
        x = np.expand_dims(x, axis=0)
    elif no_original_dimensions != 2:
        raise NameError(f'x MUST be 1/2-D array. Given x.shape: ({x.shape}) !')

    # Now, x.shape: (no_signals, no_time_samples)
    no_signals, no_time_samples = x.shape
    surrogates = np.zeros((no_signals, no_surrogates, no_time_samples))

    X = np.fft.rfft(x, norm=norm)          # (no_signals, no_positive_freqs)
    X_abs = np.abs(X)                      # (no_signals, no_positive_freqs)
    X_angle = np.angle(X)                  # (no_signals, no_positive_freqs)

    # default_rng(None) is unseeded; passing the seed through unconditionally
    # keeps seed=0 reproducible (a truthiness test would discard it).
    rng = np.random.default_rng(seed)

    for surrogate_idx in range(no_surrogates):
        # Shuffle the phases independently for every signal, in place.
        rng.permuted(X_angle, axis=-1, out=X_angle)
        S = X_abs * np.exp(1.0j * X_angle)     # (no_signals, no_positive_freqs)
        surrogates[:, surrogate_idx, :] = np.fft.irfft(S, n=no_time_samples, norm=norm)

    if no_original_dimensions == 1:
        surrogates = surrogates.squeeze(axis=0)     # (no_surrogates, no_time_samples)

    return surrogates


In [9]:
def threshold_func(freq_FC_matrix, USE_SURROGATES, segment, no_surrogates=500, freq_FC_method='wpli', surrogate_thr=0.90, freq_bands=['delta','theta','alpha','beta'], min_thr=0.75, quantile=0.75):
    """Select, per frequency band, the connections that count as graph edges.

    A connection is kept in a band when its connectivity is at least
    max(band-wise ``quantile`` of the connectivities, ``min_thr``). With
    ``USE_SURROGATES`` a kept connection is dropped again unless the measured
    connectivity exceeds that of at least a fraction ``surrogate_thr`` of the
    phase-shuffled surrogate pairs.

    Parameters
    ----------
    freq_FC_matrix : np.ndarray, shape (no_connectivities, num_freq_bands)
        Row order for E electrodes: FC(0,1), FC(0,2), ..., FC(0,E-1), FC(1,2),
        FC(1,3), ..., FC(E-2,E-1).
    USE_SURROGATES : bool
        Whether to apply the surrogate test.
    segment : np.ndarray, shape (no_electrodes, no_time_samples)
        Only read when USE_SURROGATES is True, as are ``no_surrogates``,
        ``freq_FC_method``, ``surrogate_thr`` and ``freq_bands``.
    freq_FC_method : str
        Connectivity measure applied to the surrogates; only 'wpli' exists.

    Returns
    -------
    list of 1-D integer arrays, one per frequency band, holding the row
    indices (see the ordering above) of the retained connections.

    Notes
    -----
    The surrogate branch reads the module-level ``fs``, ``nperseg`` and
    ``noverlap``.
    """
    no_connectivities, num_freq_bands = freq_FC_matrix.shape

    # Per-band cut-off: the requested quantile, but never below min_thr.
    band_quantiles = np.quantile(freq_FC_matrix, quantile, axis=0)              # (num_freq_bands,)
    cutoffs = np.maximum(band_quantiles, np.full(num_freq_bands, min_thr))      # (num_freq_bands,)
    keep = freq_FC_matrix >= cutoffs     # (no_connectivities, num_freq_bands), cutoffs broadcast over rows

    if USE_SURROGATES:
        # Entries of `keep` that fail the surrogate test are switched back to False.
        no_electrodes = segment.shape[0]

        # Connections that survived in at least one band, and the electrodes they join.
        candidate_rows = np.flatnonzero(np.any(keep, axis=-1))
        low_electrodes, high_electrodes = index_to_min_max_electrodes(candidate_rows, no_electrodes)

        # Surrogates are generated once per involved electrode and shared by its connections.
        involved = np.unique(np.concatenate((low_electrodes, high_electrodes))).tolist()
        surrogate_slot = {electrode: slot for slot, electrode in enumerate(involved)}
        surrogates = surrogate_func(segment[involved], no_surrogates)

        for row, low, high in zip(candidate_rows.tolist(), low_electrodes.tolist(), high_electrodes.tolist()):
            x = surrogates[surrogate_slot[low]]
            y = surrogates[surrogate_slot[high]]

            # Only the bands in which this connection is still kept need testing.
            bands_idxs = np.flatnonzero(keep[row])
            bands_names = [freq_bands[band_idx] for band_idx in bands_idxs.tolist()]

            if freq_FC_method != 'wpli':
                raise NameError('The only defined frequency based functional connectivity is wpli!')
            surrogates_FC = wpli_combined(x, y, fs, nperseg, noverlap, freq_bands=bands_names)

            # Fraction of surrogate pairs whose connectivity is below the measured one, per band.
            beaten_fraction = (surrogates_FC < freq_FC_matrix[row, bands_idxs]).mean(axis=0)

            keep[row, bands_idxs[beaten_fraction < surrogate_thr]] = False

    return [np.flatnonzero(keep[:, band_idx]) for band_idx in range(num_freq_bands)]

In [10]:
def index_to_min_max_electrodes(index, E):
    """Map flat connection indices back to their (smaller, larger) electrode pair.

    IMPORTANT ASSUMPTION: the connectivity is symmetric, FC(x, y) = FC(y, x),
    which holds for wPLI. This mapping will NOT work for asymmetric measures.

    Flat indices enumerate the unordered pairs of E electrodes row by row:
        0     -> (0, 1)
        1     -> (0, 2)
        ...
        E - 2 -> (0, E - 1)
        E - 1 -> (1, 2)
        E     -> (1, 3)
        ...
        last  -> (E - 2, E - 1)

    The number of pairs whose smaller electrode is < m is
    a(m) = m * (E - 1 - (m - 1) / 2); for E = 19 that gives 18, 35, 51, 66, ...
    for m = 1, 2, 3, 4, ... The smaller electrode of an index is the largest m
    with a(m) <= index, obtained in closed form from the smaller root of the
    corresponding quadratic.

    index: integer numpy array of flat indices; E: number of electrodes.
    Returns (min_electrode, max_electrode), integer arrays shaped like ``index``.
    """
    # Coefficients of the quadratic m^2 + linear * m + constant = 0.
    linear = 3 - (2 * E)
    constant = 2 * (index + 1 - E)
    discriminant = (linear ** 2) - (4 * constant)
    smaller_root = (-linear - np.sqrt(discriminant)) / 2

    min_electrode = np.floor(1 + smaller_root).astype('int')

    # Pairs that precede the block of `min_electrode`, i.e. a(min_electrode).
    preceding = min_electrode - 1
    pairs_before = ((preceding + 1) * (E - 1 - (preceding / 2))).astype('int')

    # The offset inside the block selects the partner electrode.
    max_electrode = min_electrode + 1 + (index - pairs_before)

    return min_electrode, max_electrode


In [11]:
def adjacencies_converter(index, montage):
    """Turn flat connection indices into a symmetric edge list.

    index: 1-D numpy array of flat connection indices (see
    ``index_to_min_max_electrodes`` for the numbering).
    montage: ordered electrode labels; only its length is used here.

    Returns [src, dst], two 1-D torch.long tensors (possibly empty) of equal
    length, where src[i] -> dst[i] is an edge between electrode positions of
    the montage. Every connection (i, j) is emitted in both directions, as
    (i, j) immediately followed by (j, i).

    No validation of the input is performed.
    """
    low_nodes, high_nodes = index_to_min_max_electrodes(index, len(montage))

    src, dst = [], []
    for low, high in zip(low_nodes.tolist(), high_nodes.tolist()):
        # forward edge low -> high, then the reverse edge high -> low
        src.extend((low, high))
        dst.extend((high, low))

    return [torch.tensor(src, dtype=torch.long), torch.tensor(dst, dtype=torch.long)]

In [12]:
def prep_patient_data(patient_name):
    """Load every segment file of one patient and build its per-band graphs.

    A file of ``files_names`` belongs to the patient when its name starts with
    ``patient_name`` followed by a non-digit character (so that e.g. 'PAT2'
    does not match the files of 'PAT22'). The label is read from the file
    name: 'sp' -> 1.0, otherwise 'ns' -> 0.0 ('sp' is tested first).

    Reads the module-level ``files_names``, ``data_folder``, ``fs``,
    ``nperseg``, ``noverlap``, ``freq_bands``, ``USE_SURROGATES``,
    ``no_surrogates``, ``surrogate_thr`` and ``montage``.

    Returns
    -------
    patient_segments : list of float tensors, each (no_electrodes, no_time_samples)
    patient_adjs : list (one entry per file) of lists (one entry per frequency
        band) of [src, dst] pairs of 1-D torch.long tensors
    patient_labels : list of 0-d float tensors

    Raises
    ------
    NameError
        If a matching file name contains neither 'sp' nor 'ns'.
    """
    patient_segments = []
    patient_adjs = []
    patient_labels = []

    # Raw string: '\D' is a regex class, not a string escape. The patient name
    # is escaped so that it is always matched literally.
    regex = re.compile(rf'{re.escape(patient_name)}\D[0-9.a-z_A-Z]*')

    for file_name in files_names:
        if not regex.match(file_name):
            continue

        # Set the label. Failing here prevents silently reusing the label of
        # the previously processed file.
        if 'sp' in file_name:
            label = torch.tensor(1.0, dtype=torch.float)
        elif 'ns' in file_name:
            label = torch.tensor(0.0, dtype=torch.float)
        else:
            raise NameError(f'Cannot infer the label of "{file_name}": neither "sp" nor "ns" in its name!')

        segment = loadmat(os.path.join(data_folder, file_name))['segment']

        # 1) wpli_matrix: (no_connectivities, num_freq_bands)
        wpli_matrix = wpli_matrix_generator(segment, fs, nperseg, noverlap, freq_bands=freq_bands)

        # 2) Thresholding, with or without surrogates
        adjacencies = threshold_func(wpli_matrix, USE_SURROGATES, segment, no_surrogates=no_surrogates, freq_FC_method='wpli', surrogate_thr=surrogate_thr, freq_bands=freq_bands)
        del wpli_matrix

        # adjacencies becomes a list of lists:
        # outer list: different frequency bands
        # inner list: src & dst 1-D torch.long tensors
        adjacencies = [adjacencies_converter(adj, montage) for adj in adjacencies]

        patient_segments.append(torch.from_numpy(segment).float())
        patient_adjs.append(adjacencies)
        patient_labels.append(label)

    return patient_segments, patient_adjs, patient_labels


In [13]:
def prepare_dataset(patients_names):
    """Concatenate the prepared data of several patients.

    Calls ``prep_patient_data`` for every name, in the given order, and joins
    the results. Returns (group_segments, group_adjs, group_labels), three
    parallel lists with one entry per loaded segment file.

    Patients are processed sequentially; an earlier multiprocessing variant
    (mp.Pool + imap_unordered over prep_patient_data) is disabled.
    """
    group_segments, group_adjs, group_labels = [], [], []

    for name in patients_names:
        segments, adjs, labels = prep_patient_data(name)
        group_segments.extend(segments)
        group_adjs.extend(adjs)
        group_labels.extend(labels)

    return group_segments, group_adjs, group_labels

In [None]:
# Load the segments, build the per-band graphs and read the labels for the
# training and the validation patients. Each call returns three parallel lists.
train_segments, train_adjs, train_labels = prepare_dataset(train_patients)
val_segments, val_adjs, val_labels = prepare_dataset(val_patients)

In [14]:
def median_IQR_func(group_segments):
    """Median and inter-quartile range of all samples of a group of segments.

    group_segments: non-empty list of tensors (one per segment); every value
    of every segment contributes to the statistics.

    Returns (median, IQR) as Python floats, where IQR is the difference
    between the 0.75 and the 0.25 quantiles.
    """
    # Flatten and concatenate once instead of visiting every sample in Python.
    # The values are brought to the CPU in torch's default dtype, which is what
    # building a tensor from a list of Python floats would produce.
    samples = torch.cat([torch.as_tensor(segment).reshape(-1) for segment in group_segments])
    samples = samples.to(device='cpu', dtype=torch.get_default_dtype())

    quantiles = samples.quantile(torch.tensor([0.25, 0.5, 0.75]))
    median = quantiles[1].item()
    IQR = quantiles[2].item() - quantiles[0].item()

    return median, IQR


In [None]:
def scaler_func(group_segments, median, IQR):
    """Robust-scale every segment as (segment - median) / IQR.

    The list is modified in place (each entry is replaced by its scaled
    version) and the same list object is returned. Calling it twice on the
    same list therefore scales the data twice.
    """
    for position, segment in enumerate(group_segments):
        centered = segment - median
        group_segments[position] = centered / IQR

    return group_segments


In [None]:
# Robust scaling of the raw segments: the median and IQR are estimated on the
# training segments only and then applied to both splits.
# NOTE: scaler_func rescales the lists in place, so re-running this cell
# scales the already scaled data again.
if SCALE:
    median, IQR = median_IQR_func(train_segments)
    train_segments = scaler_func(train_segments, median, IQR)
    val_segments = scaler_func(val_segments, median, IQR)

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Dataset of (segment, per-band graph, label) examples with optional left/right flipping.

    segments: list of tensors, one per example, indexed by electrode along dim 0.
    adjs: list (per example) of lists (per frequency band) of [src, dst]
        1-D torch.long tensors.
    labels: list of labels, one per example.
    flip_prob: probability of returning the left/right mirrored example.
    montage: ordered electrode labels; only read when flip_prob > 0.

    When flip_prob > 0 the instance also carries ``idx_swaps_list`` and
    ``idx_swaps_dict`` (see ``montage_swap_func``).
    """

    def __init__(self, segments, adjs, labels, flip_prob, montage):
        self.segments = segments
        self.adjs = adjs
        self.labels = labels
        self.flip_prob = flip_prob

        # The swap tables are only needed (and only built) when flipping can happen.
        if flip_prob > 0:
            swaps = self.montage_swap_func(montage)
            self.idx_swaps_list, self.idx_swaps_dict = swaps

    @staticmethod
    def montage_swap_func(montage):
        """Locate the left/right electrode pairs of a montage.

        Returns (idx_swaps_list, idx_swaps_dict): the list of
        (left_idx, right_idx) tuples for the pairs whose two labels are both in
        ``montage``, and a dict mapping each of those indices to its mirror.
        """
        position_of = {label: position for position, label in enumerate(montage)}

        mirrored_labels = [('Fp1', 'Fp2'), ('F7', 'F8'), ('F3', 'F4'), ('C3', 'C4'), ('P3', 'P4'), ('T3', 'T4'), ('T5', 'T6'), ('O1', 'O2')]

        idx_swaps_list = [
            (position_of[left], position_of[right])
            for left, right in mirrored_labels
            if (left in position_of) and (right in position_of)
        ]

        idx_swaps_dict = {}
        for left_idx, right_idx in idx_swaps_list:
            idx_swaps_dict.update({left_idx: right_idx, right_idx: left_idx})

        return idx_swaps_list, idx_swaps_dict

    def flip_lr(self, segment, adj):
        """Return left/right mirrored copies of a segment and of its graphs.

        The inputs are left untouched. ``adj`` is a list (per frequency band)
        of [src, dst] 1-D torch.long tensors of equal length.
        """
        # Mirror the signals: exchange the rows of every left/right pair.
        flipped_segment = segment.clone()
        for left_idx, right_idx in self.idx_swaps_list:
            flipped_segment[left_idx] = segment[right_idx]
            flipped_segment[right_idx] = segment[left_idx]

        # Mirror the graphs: relabel every node that has a mirror electrode.
        mirror = self.idx_swaps_dict
        flipped_adj = []
        for band_edges in adj:
            src, dst = band_edges[0], band_edges[1]
            new_src, new_dst = src.clone(), dst.clone()
            for edge_idx, (src_node, dst_node) in enumerate(zip(src.tolist(), dst.tolist())):
                new_src[edge_idx] = mirror.get(src_node, src_node)
                new_dst[edge_idx] = mirror.get(dst_node, dst_node)
            flipped_adj.append([new_src, new_dst])

        return flipped_segment, flipped_adj

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Return (num_nodes, segment, label, adj), mirrored with probability flip_prob."""
        segment, adj = self.segments[index], self.adjs[index]
        num_nodes_in_example = torch.tensor(segment.shape[0])

        if random.random() < self.flip_prob:
            segment, adj = self.flip_lr(segment, adj)

        return num_nodes_in_example, segment, self.labels[index], adj
        

In [15]:
def collate_fn(batch):
    """Merge a list of dataset examples into one batched (block-diagonal) graph.

    ``batch`` is a list of 4-tuples (num_nodes, segment, label, adj), where
    ``adj`` is a list (per frequency band) of [src, dst] 1-D torch.long
    tensors. All examples must provide the same number of bands; the number of
    bands is read from the first example.

    The node indices of example k are shifted by the total number of nodes of
    the examples before it, so that all graphs live in one index space.

    Returns a flat tuple
        (num_nodes_in_examples, segments, labels, index_band_0, index_band_1, ...)
    with
        num_nodes_in_examples: 1-D, [#nodes_ex_1, ..., #nodes_ex_B]
        segments: all segments concatenated along dim 0
        labels: 1-D, [label_ex_1, ..., label_ex_B]
        index_band_i: (2, #edges) long tensor whose row 0 holds the destination
            nodes and row 1 the source nodes of that band.
    """
    num_nodes_in_examples, segments, labels, adjs = zip(*batch)

    num_nodes_in_examples = torch.tensor(num_nodes_in_examples)  # 1-D: [#nodes_ex_1, #nodes_ex_2, ...., #nodes_ex_B]
    segments = torch.cat(segments, dim=0)                        # (#nodes_ex_1 + .... + #nodes_ex_B, no_time_samples)
    labels = torch.tensor(labels)                                # 1-D: [label_ex_1, label_ex_2, ...., label_ex_B]

    # adjs is a tuple (batch examples) of lists (frequency bands) of [src, dst].
    # The band count comes from the data itself rather than from a global.
    num_freq_bands = len(adjs[0])
    src_lists = [[] for _ in range(num_freq_bands)]
    dst_lists = [[] for _ in range(num_freq_bands)]

    offset = 0
    for example_adj, node_count in zip(adjs, num_nodes_in_examples.tolist()):
        for freq_band_idx in range(num_freq_bands):
            src_lists[freq_band_idx].append(offset + example_adj[freq_band_idx][0])
            dst_lists[freq_band_idx].append(offset + example_adj[freq_band_idx][1])
        offset += node_count

    indices = []
    for src_list, dst_list in zip(src_lists, dst_lists):
        src = torch.cat(src_list, dim=-1)
        dst = torch.cat(dst_list, dim=-1)
        # Row 0: destinations, row 1: sources.
        indices.append(torch.stack([dst, src], dim=0))

    return (num_nodes_in_examples, segments, labels, *indices)


In [16]:
class MultiHeadAttention(torch.nn.Module):
    """Attention pooling with one learned query per head.

    The input positions (along dimension ``dim``) are projected to keys and
    values by a single linear layer; every head attends over the positions
    with its own learned query vector, and the pooled values of all heads are
    concatenated.

    Parameters
    ----------
    no_input_features : int
        Size of the feature dimension of the input.
    H : int
        Number of heads.
    d_qk : int
        Size of the query/key vectors of each head.
    d_v : int
        Size of the value vectors of each head.
    dim : int
        Dimension to attend across. -2: input is (B, positions, features);
        -1: input is (B, features, positions).
    bias_att : bool
        Whether the key/value projection has a bias.

    Raises
    ------
    NameError
        If ``dim`` is neither -1 nor -2.
    """

    def __init__(self, no_input_features, H, d_qk, d_v, dim, bias_att=True):
        super().__init__()
        self.H = H
        self.d_qk = d_qk
        self.d_v = d_v

        # dim: dimension across which to attend
        if dim not in {-1, -2}:
            raise NameError('Error: Enforce your attention dimension to either -1 or -2 !')
        self.dim = dim

        # One learned query vector per head.
        self.q = torch.nn.Parameter(torch.empty(H, d_qk))
        torch.nn.init.xavier_normal_(self.q)

        # A single projection yields the keys and the values of all heads.
        self.att_lin = torch.nn.Linear(in_features=no_input_features, out_features=(H * d_qk) + (H * d_v), bias=bias_att)

    @staticmethod
    def qkv_func(q, k, v):
        """Scaled dot-product attention with a single query per head.

        q: (B, H, 1, d_qk), k: (B, H, F, d_qk), v: (B, H, F, d_v), where F is
        the number of attended positions.

        Returns y: (B, H * d_v), the pooled values of all heads concatenated,
        and s_m: (B, H, 1, F), the attention weights (softmax over F).
        """
        B, H, F, d_qk = k.shape
        d_v = v.shape[-1]

        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_qk)     # logits: (B, H, 1, F)
        s_m = logits.softmax(dim=-1)                                        # s_m: (B, H, 1, F)

        # (B, H, 1, F) @ (B, H, F, d_v) = (B, H, 1, d_v) ==transpose==> (B, 1, H, d_v) ==reshape==> (B, H * d_v)
        y = torch.matmul(s_m, v).transpose(-2, -3).reshape((B, H * d_v))

        return y, s_m

    def forward(self, x):
        """Pool ``x`` across the attended dimension.

        x: (B, F, C) when dim == -2, (B, C, F) when dim == -1, with
        C = no_input_features and F the number of attended positions.

        Returns y: (B, H * d_v) and s_m: (B, H, 1, F). ``x`` itself is never
        modified.
        """
        if self.dim == -1:
            # Work on a transposed view; the caller's tensor keeps its layout
            # even if something below raises.
            x = x.transpose(-1, -2)

        # x: (B, F, C=no_input_features)
        B, F, _ = x.shape

        # (B, F, C) ====self.att_lin====> ( B, F, (H * d_qk) + (H * d_v) )
        k, v = self.att_lin(x).split(split_size=[self.H * self.d_qk, self.H * self.d_v], dim=-1)

        q = self.q.expand(B, 1, self.H, self.d_qk).transpose(-2, -3)   # q: (B, self.H, 1, self.d_qk)
        k = k.reshape((B, F, self.H, self.d_qk)).transpose(-2, -3)     # k: (B, self.H, F, self.d_qk)
        v = v.reshape((B, F, self.H, self.d_v)).transpose(-2, -3)      # v: (B, self.H, F, self.d_v)

        y, s_m = self.qkv_func(q, k, v)          # y: (B, self.H * self.d_v), s_m: (B, self.H, 1, F)

        return y, s_m
    

In [None]:
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by a cosine decay of the learning rate.

    The base learning rate of every parameter group is multiplied by a factor
    that rises linearly from 0 to 1 over the first ``warmup`` steps and then
    follows half a cosine period, reaching 0 at step ``max_num_iters``. The
    cosine is not clamped: past ``max_num_iters`` the factor rises again.

    warmup: number of warm-up steps; 0 disables the warm-up (the factor starts at 1).
    max_num_iters: step at which the factor reaches 0.
    """

    def __init__(self, optimizer, warmup, max_num_iters):
        # Must be set before super().__init__, which already performs a first step.
        self.warmup = warmup
        self.max_num_iters = max_num_iters
        super().__init__(optimizer)

    def get_lr(self):
        # Returns the list of learning rates, one per parameter group.
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Multiplicative factor (NOT the learning rate itself) at step ``epoch``."""
        # With warmup == 0 there is no ramp to interpolate on (and the slope
        # would divide by zero); the cosine branch then starts at factor 1.
        if self.warmup > 0 and epoch <= self.warmup:
            slope = (1 - 0) / (self.warmup - 0)
            lr_factor = slope * epoch
        else:
            # Half a cosine period spans the steps from warmup to max_num_iters.
            f = 1 / (2 * (self.max_num_iters - self.warmup))
            lr_factor = 0.5 * (1 + math.cos(2 * math.pi * f * (epoch - self.warmup)))

        return lr_factor


In [None]:
# Module-level bookkeeping, initialised here to "nothing seen yet".
# NOTE(review): these names are not read or written anywhere in the visible
# part of the notebook; presumably the training/validation code updates them
# to track the best checkpoint -- confirm against the code that uses them.
min_validation_loss = math.inf
best_epoch = 0
best_model_state = None
AP_at_min_val_loss = 0.
EARLY_STOPPING = False


In [17]:
class CustomLightningModule(pl.LightningModule):
    def __init__(self, num_freq_bands, dedicated_vgg=False, eps=0., train_eps=False, dedicated_GIN=True, attention=True, heads=1, d_qk=15, d_v=15):
        """Build the VGG feature extractor(s), three GIN layers and the classification head.

        Args:
            num_freq_bands: number of frequency bands (stored as an int buffer).
            dedicated_vgg: if True one VGG per frequency band, otherwise a single shared VGG.
            eps: initial value of the GIN epsilon.
            train_eps: if True epsilon is a trainable parameter, otherwise a fixed buffer.
            dedicated_GIN: if True each GIN layer owns one Linear per frequency band,
                otherwise a single Linear is shared by all bands.
            attention: if True the per-band graph embeddings are combined by
                MultiHeadAttention before the MLP; otherwise they are flattened.
            heads, d_qk, d_v: forwarded to MultiHeadAttention (only used when `attention`).

        The dropout probability is read from the module-level `p_dropout`.
        """
        super().__init__()

        self.register_buffer('num_freq_bands', torch.tensor(num_freq_bands, dtype=torch.int))

        # Feature extractor(s) applied to the raw node signals.
        num_vggs = num_freq_bands if dedicated_vgg else 1
        self.vggs = torch.nn.ModuleList([self.construct_vgg() for _ in range(num_vggs)])

        # One set of 13 -> 13 linear maps per GIN layer.
        num_gin_sets = num_freq_bands if dedicated_GIN else 1

        def gin_linears():
            return torch.nn.ModuleList(
                [torch.nn.Linear(in_features=13, out_features=13) for _ in range(num_gin_sets)]
            )

        self.h1_lin = gin_linears()
        self.h2_lin = gin_linears()
        self.h3_lin = gin_linears()

        # Epsilon of every GIN layer: trainable scalars or a constant buffer.
        for eps_name in ('eps1', 'eps2', 'eps3'):
            if train_eps:
                trainable = torch.nn.ParameterList(
                    [torch.nn.Parameter(torch.tensor(eps)) for _ in range(num_gin_sets)]
                )
                setattr(self, eps_name, trainable)
            else:
                self.register_buffer(eps_name, torch.tensor([eps] * num_gin_sets))

        # Classification head; its widths depend on whether attention pools the bands.
        self.attention = attention
        if attention:
            self.m_h_a = MultiHeadAttention(4 * 13, heads, d_qk, d_v, -2)
            mlp_in, mlp_hidden_1, mlp_hidden_2 = heads * d_v, 10, 5
        else:
            mlp_in, mlp_hidden_1, mlp_hidden_2 = num_freq_bands * 4 * 13, 30, 10

        self.class_mlp = torch.nn.Sequential(
            torch.nn.Linear(in_features=mlp_in, out_features=mlp_hidden_1),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(p=p_dropout),
            torch.nn.Linear(in_features=mlp_hidden_1, out_features=mlp_hidden_2),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(p=p_dropout),
            torch.nn.Linear(in_features=mlp_hidden_2, out_features=1)
        )

        self.criterion = torch.nn.BCEWithLogitsLoss()
        # training_step drives the optimizer by hand (closure-based step).
        self.automatic_optimization = False
    
    
    
    @staticmethod
    def construct_vgg():
        """Return the VGG-style stack of (1 x k) convolutions used as node feature extractor.

        Every convolution is followed by LeakyReLU and Dropout(p=p_dropout); a
        (1, 2) max-pool closes each stage except the last. Kernels only span the
        last axis, so the node axis is left untouched. For a 300-sample input the
        last axis shrinks 300 -> 296 -> 148 -> 144 -> 72 -> 64 -> 32 -> 24 -> 12
        -> 2 -> 1 while the channels grow from 1 to 13.
        """
        def conv_unit(c_in, c_out, width=3):
            # Conv -> LeakyReLU -> Dropout, no padding (last axis loses width - 1 samples).
            return [
                torch.nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(1, width)),
                torch.nn.LeakyReLU(),
                torch.nn.Dropout(p=p_dropout),
            ]

        # Channel count at the input/output of each convolution, stage by stage.
        stages = [
            [1, 2, 3],                    # 300 -> 298 -> 296, pooled to 148
            [3, 4, 5],                    # 148 -> 146 -> 144, pooled to 72
            [5, 6, 7, 8, 9],              # 72 -> 70 -> 68 -> 66 -> 64, pooled to 32
            [9, 10, 11, 12, 13],          # 32 -> 30 -> 28 -> 26 -> 24, pooled to 12
            [13, 13, 13, 13, 13, 13],     # 12 -> 10 -> 8 -> 6 -> 4 -> 2 (no pooling)
        ]

        layers = []
        last_stage = len(stages) - 1
        for stage_idx, channels in enumerate(stages):
            for c_in, c_out in zip(channels[:-1], channels[1:]):
                layers.extend(conv_unit(c_in, c_out))
            if stage_idx < last_stage:
                layers.append(torch.nn.MaxPool2d(kernel_size=(1, 2)))

        # Final width-2 convolution collapses the last axis: 2 -> 1.
        layers.extend(conv_unit(13, 13, width=2))

        return torch.nn.Sequential(*layers)
    
    
    
    @staticmethod
    def positional_embedding(xs, num_nodes_in_examples):
        """Add a sinusoidal electrode-position code to the node embeddings, in place.

        Args:
            xs: list of 2-D tensors of shape (#nodes_ex_1 + ... + #nodes_ex_B, d).
                Rows are ordered example by example, electrode by electrode.
            num_nodes_in_examples: 1-D tensor [#nodes_ex_1, ..., #nodes_ex_B].

        Returns:
            The same list `xs`; its tensors have been modified in place so the
            GNN can tell electrodes apart.
        """
        device = xs[0].device
        d = xs[0].shape[1]
        num_sin_rows = math.ceil(d / 2)

        def build_table(num_electrodes):
            # Even feature indices carry sin, odd ones cos; result is (E, d).
            electrode = torch.arange(num_electrodes, device=device).reshape(1, num_electrodes)
            pair = torch.arange(start=1, end=num_sin_rows + 1, device=device).reshape(num_sin_rows, 1)
            angle = electrode / (100 ** ((2 * pair) / d))

            table = torch.zeros((d, num_electrodes), device=device)
            table[0::2, :] = torch.sin(angle)
            # For odd d there is one cos row fewer than sin rows.
            table[1::2, :] = torch.cos(angle if d % 2 == 0 else angle[0:-1])
            return table.transpose(-1, -2)

        # One table per distinct electrode count present in the batch.
        tables = {E: build_table(E) for E in num_nodes_in_examples.unique().tolist()}

        offset = 0
        for n_nodes in num_nodes_in_examples.tolist():
            rows = slice(offset, offset + n_nodes)
            for x in xs:
                x[rows] = x[rows] + tables[n_nodes]
            offset += n_nodes

        return xs
    
    
    
    @staticmethod
    def gin_forward(H_before, sparse_matrices, eps, h):
        """Apply one GIN layer to every frequency band.

        For band b: y_b = dropout(leaky_relu(h_b((1 + eps_b) * H_b + A_b @ H_b))).

        Args:
            H_before: list with one 2-D node-feature tensor per frequency band.
            sparse_matrices: list with one sparse adjacency matrix per frequency band.
            eps: indexable collection of epsilons (ParameterList or 1-D buffer).
            h: ModuleList of linear maps. A single entry means the map (and
                epsilon) is shared by all bands; otherwise entry b serves band b.

        Returns:
            List of updated node-feature tensors, one per frequency band.
        """
        shared_weights = (len(h) == 1)

        H_after = []
        for freq_band_idx, sparse_mat in enumerate(sparse_matrices):
            weight_idx = 0 if shared_weights else freq_band_idx
            h_iter = h[weight_idx]
            eps_iter = eps[weight_idx]

            node_feats = H_before[freq_band_idx]
            y = ((1 + eps_iter) * node_feats) + torch.sparse.mm(sparse_mat, node_feats)
            y = h_iter(y)
            y = torch.nn.functional.leaky_relu(y)
            # F.dropout defaults to training=True and would keep dropping units
            # after model.eval(). This static method has no `self`, so the mode
            # is taken from the layer's own Linear module, which follows
            # train()/eval() of the parent model.
            y = torch.nn.functional.dropout(y, p=p_dropout, training=getattr(h_iter, 'training', True))
            H_after.append(y)

        return H_after
    
    
    
    @staticmethod
    def construct_H(num_nodes_in_examples, *Hs):
        """Concatenate the per-layer node features and pool them into one vector per graph.

        Args:
            num_nodes_in_examples: 1-D tensor [#nodes_ex_1, ..., #nodes_ex_B].
            *Hs: one list per layer (H0, H1, ...); each list holds one 2-D
                node-feature tensor per frequency band.

        Returns:
            Tensor of shape (B, num_freq_bands, -1): for every example and band,
            the scatter-reduced (torch_scatter default reduction) concatenation
            of all layers' node features.
        """
        batch_size = num_nodes_in_examples.numel()

        # example_ids[r] is the index of the example that node row r belongs to.
        example_ids = torch.arange(start=0, end=batch_size, step=1, device=num_nodes_in_examples.device)
        example_ids = example_ids.repeat_interleave(num_nodes_in_examples)

        pooled_per_band = []
        for freq_band_idx in range(len(Hs[0])):
            node_feats = torch.cat([H_i[freq_band_idx] for H_i in Hs], dim=-1)
            graph_feats = torch_scatter.scatter(node_feats, example_ids, dim=0)
            pooled_per_band.append(graph_feats.reshape(batch_size, 1, -1))

        return torch.cat(pooled_per_band, dim=1)     # (B, num_freq_bands, -1)
    
    
    
    def forward(self, num_nodes_in_examples, segments, sparse_matrices):
        """Classify a batch of graphs.

        Args:
            num_nodes_in_examples: 1-D tensor [#nodes_ex_1, ..., #nodes_ex_B].
            segments: (#nodes_ex_1 + ... + #nodes_ex_B, 300) node signals.
            sparse_matrices: list with one sparse adjacency matrix per frequency band.

        Returns:
            `logits` of shape (B,) when attention is disabled, otherwise the
            tuple `(logits, s_m)` where `s_m` is the second output of the
            attention block (shape (B, H, 1, F) per the attention module's comments).
        """
        num_bands = len(sparse_matrices)

        # Node embeddings: (#nodes, 13) per VGG.
        node_embeddings = []
        for vgg in self.vggs:
            emb = vgg(segments.unsqueeze(dim=0))     # (13, #nodes, 1)
            emb.squeeze_(dim=-1)                     # (13, #nodes)
            emb.transpose_(-1, -2)                   # (#nodes, 13)
            node_embeddings.append(emb)

        node_embeddings = self.positional_embedding(node_embeddings, num_nodes_in_examples)

        # A single shared VGG feeds the same embedding tensor to every band.
        shared_embedding = (len(node_embeddings) == 1) and (num_bands > 1)
        H0 = node_embeddings * num_bands if shared_embedding else node_embeddings

        # Three stacked GIN layers; keep every intermediate representation.
        layer_outputs = [H0]
        gin_layers = ((self.eps1, self.h1_lin), (self.eps2, self.h2_lin), (self.eps3, self.h3_lin))
        for layer_eps, layer_lins in gin_layers:
            layer_outputs.append(self.gin_forward(layer_outputs[-1], sparse_matrices, layer_eps, layer_lins))

        # H: (B, num_freq_bands, 4 * 13)
        H = self.construct_H(num_nodes_in_examples, *layer_outputs)

        if not self.attention:
            flat = H.reshape(num_nodes_in_examples.numel(), -1)     # (B, num_freq_bands*4*13)
            return self.class_mlp(flat).squeeze(dim=-1)             # (B,)

        attended, s_m = self.m_h_a(H)
        logits = self.class_mlp(attended).squeeze(dim=-1)           # (B,)
        return logits, s_m
    
    
    
    def configure_optimizers(self):
        """Return one Adam optimizer and two schedulers (cosine warm-up, SWA annealing).

        Both schedulers wrap the same optimizer; validation_epoch_end decides
        which one is stepped.
        """
        adam = torch.optim.Adam(self.parameters(), weight_decay=0)
        schedulers = [
            CosineWarmupScheduler(optimizer=adam, warmup=100, max_num_iters=self.trainer.max_epochs),
            SWALR(adam, swa_lr=0.0001, anneal_epochs=100, anneal_strategy='cos'),
        ]
        return [adam], schedulers
    
    
    
    def train_dataloader(self):
        """Shuffled training loader built from the module-level `train_*` arrays.

        The fourth CustomDataset argument is 0.5 here and 0 for validation
        (presumably an augmentation probability - TODO confirm against CustomDataset).
        """
        loader_options = dict(
            batch_size=train_B,
            shuffle=True,
            num_workers=0,     # alternative considered by the author: mp.cpu_count()-1
            collate_fn=collate_fn,
            pin_memory=True,
            drop_last=True,
        )
        train_set = CustomDataset(train_segments, train_adjs, train_labels, 0.5, montage)
        return torch.utils.data.DataLoader(train_set, **loader_options)
    
    
    
    def val_dataloader(self):
        """Ordered validation loader built from the module-level `val_*` arrays (no samples dropped)."""
        loader_options = dict(
            batch_size=val_B,
            shuffle=False,
            num_workers=0,
            collate_fn=collate_fn,
            pin_memory=True,
            drop_last=False,
        )
        val_set = CustomDataset(val_segments, val_adjs, val_labels, 0, montage)
        return torch.utils.data.DataLoader(val_set, **loader_options)
    
    
    
    def training_step(self, batch, batch_idx):
        """Run one manually-optimized training step.

        The batch is `(num_nodes_in_examples, segments, y, *indices)` where every
        entry of `indices` holds the COO indices of one frequency band's
        adjacency (one column per edge). The loss is first measured in eval mode
        without gradients, then the optimizer is stepped through a closure.

        Returns:
            dict with the detached loss measured BEFORE the optimizer step.
        """
        num_nodes_in_examples, segments, y, *indices = batch

        # Rebuild the block-diagonal adjacency of each band with unit edge weights.
        total_num_nodes = num_nodes_in_examples.sum().item()
        sparse_matrices = [
            torch.sparse_coo_tensor(
                band_indices,
                torch.ones((band_indices.shape[1],), dtype=torch.float, device=band_indices.device),
                [total_num_nodes, total_num_nodes],
            )
            for band_indices in indices
        ]

        def compute_logits():
            # With attention the model also returns the attention scores; drop them.
            output = self(num_nodes_in_examples, segments, sparse_matrices)
            return output[0] if self.attention else output

        self.eval()
        with torch.no_grad():
            loss_before_optimizer_step = self.criterion(compute_logits(), y)
        self.train()

        opt = self.optimizers()

        def closure():
            loss = self.criterion(compute_logits(), y)
            opt.zero_grad()
            self.manual_backward(loss)
            return loss

        opt.step(closure=closure)

        return {'loss_before_optimizer_step': loss_before_optimizer_step.detach()}
    
    
    
    def training_epoch_end(self, training_step_outputs):
        """Print the epoch's mean pre-step training loss (rank 0 only)."""
        if self.global_rank != 0:
            return

        step_losses = [output['loss_before_optimizer_step'] for output in training_step_outputs]
        train_loss = torch.stack(step_losses).mean()

        print(f'Epoch: {self.current_epoch}')
        print(f'Training loss BEFORE optimizer step: {round(train_loss.item(), 5)}')
        print('--------------------------------------')
    
    
    
    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch.

        The batch layout matches training_step: `(num_nodes_in_examples,
        segments, y, *indices)` with one COO index tensor per frequency band.

        Returns:
            dict with the batch `loss`, the targets `y` and the `logits`.
        """
        num_nodes_in_examples, segments, y, *indices = batch

        # Rebuild the block-diagonal adjacency of each band with unit edge weights.
        total_num_nodes = num_nodes_in_examples.sum().item()
        sparse_matrices = [
            torch.sparse_coo_tensor(
                band_indices,
                torch.ones((band_indices.shape[1],), dtype=torch.float, device=band_indices.device),
                [total_num_nodes, total_num_nodes],
            )
            for band_indices in indices
        ]

        output = self(num_nodes_in_examples, segments, sparse_matrices)
        # With attention the model returns (logits, attention scores).
        logits = output[0] if self.attention else output

        loss = self.criterion(logits, y)
        return {'loss': loss, 'y': y, 'logits': logits}
    
    
    
    def validation_epoch_end(self, validation_step_outputs):
        """Step the LR schedule, report validation metrics and track the best weights.

        Side effects on module-level state: updates `min_validation_loss`,
        `best_epoch`, `best_model_state` and `AP_at_min_val_loss` when the
        validation loss improves (only from epoch 2 on), and updates the global
        `swa_model` once the SWA phase plus its annealing window has passed.
        """
        global min_validation_loss, best_epoch, best_model_state, AP_at_min_val_loss

        swa_start_epoch = 1000

        cosine_scheduler, swa_scheduler = self.lr_schedulers()

        # Schedulers are stepped by hand: cosine warm-up before the SWA phase,
        # SWALR afterwards. Nothing is stepped at epoch 0.
        if self.current_epoch > 0:
            active_scheduler = cosine_scheduler if self.current_epoch < swa_start_epoch else swa_scheduler
            active_scheduler.step()

        # Everything below is reporting/bookkeeping and runs on rank 0 only.
        if self.global_rank != 0:
            return

        if self.current_epoch > (swa_start_epoch + swa_scheduler.anneal_epochs):
            swa_model.update_parameters(self)

        step_losses = [output['loss'] for output in validation_step_outputs]
        val_loss = torch.stack(step_losses).mean().item()
        print(f'val_loss: {val_loss}')

        y = np.array(torch.cat([output['y'] for output in validation_step_outputs]).cpu())
        logits = torch.cat([output['logits'] for output in validation_step_outputs]).cpu()
        pred = np.array(torch.sigmoid(logits))

        try:
            validation_AP = average_precision_score(y, pred)

            is_new_minimum = val_loss < min_validation_loss
            if is_new_minimum and (self.current_epoch > 1):
                min_validation_loss = val_loss
                best_epoch = self.current_epoch
                # Snapshot the weights; state_dict() alone would keep tracking later updates.
                best_model_state = deepcopy(self.state_dict())
                AP_at_min_val_loss = validation_AP

            print(f'validation_AP = {round(validation_AP, 5)} ... AP_at_min_val_loss = {round(AP_at_min_val_loss, 5)} @ epoch {best_epoch}')

        except ValueError:
            print(f'ValueError @ epoch: {self.current_epoch} ====> y = {y}')

        print('=======================================================================================')



In [None]:
# Build the model from the configuration-cell flags and train it on CPU.
num_freq_bands = len(freq_bands)

my_lightning_module = CustomLightningModule(
    num_freq_bands,
    dedicated_vgg=dedicated_vgg,
    eps=eps,
    train_eps=train_eps,
    dedicated_GIN=dedicated_GIN,
    attention=attention,
    heads=1,
    d_qk=25,
    d_v=25,
)

# Running average of the weights; validation_epoch_end updates it through this global name.
swa_model = AveragedModel(my_lightning_module)

trainer = pl.Trainer(
    gpus=0,
    enable_checkpointing=False,
    enable_progress_bar=False,
    logger=False,
    max_epochs=400,
)
trainer.fit(my_lightning_module)

# torch's SWA helper for refreshing BatchNorm statistics of the averaged model.
torch.optim.swa_utils.update_bn(my_lightning_module.train_dataloader(), swa_model)


In [None]:
# Persist the weights captured at the minimum validation loss. The 'g2' prefix
# matches the held-out group chosen in the train/validation split cell
# (val_patients = g2); change both together when training another fold.
if best_model_state is None:
    # validation_epoch_end never recorded an improvement, so there is nothing
    # meaningful to save (torch.save(None, ...) would only fail later at load time).
    raise RuntimeError('best_model_state is None: no best model was recorded during training')

# torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)

PATH = f'models/g2_FC_GIN_10k_AP_{round(100 * AP_at_min_val_loss, 2)}.pt'
torch.save(best_model_state, PATH)

In [18]:
import pandas as pd


# Inference model: same architecture as the trained one (heads=1, d_qk=25, d_v=25);
# its weights are loaded per patient group in the next cell.
num_freq_bands = len(freq_bands)
model_kwargs = dict(
    dedicated_vgg=dedicated_vgg,
    eps=eps,
    train_eps=train_eps,
    dedicated_GIN=dedicated_GIN,
    attention=attention,
    heads=1,
    d_qk=25,
    d_v=25,
)
model = CustomLightningModule(num_freq_bands, **model_kwargs)

# Files whose name contains 'ns' (the non-IED segments).
# For the IED segments use the 'sp' marker instead:
#IEDs_names = [item for item in os.listdir(data_folder) if 'sp' in item]
NIEDs_names = list(filter(lambda item: 'ns' in item, os.listdir(data_folder)))


In [19]:
# Leave-one-group-out scoring of every non-IED segment.
#
# Each patient group gK has a checkpoint 'models/gK_FC_GIN*' and normalization
# statistics (median / IQR) computed from the four OTHER groups. A segment is
# scored with the checkpoint and statistics of the group its patient belongs to.
#
# NOTE: g1..g5 must still be defined. The train/validation split cell near the
# top of the notebook runs `del g1, g2, g3, g4, g5`; re-run the configuration
# cell before this one if that cell was executed.
# To score the IED segments instead, iterate over IEDs_names.
patient_groups = {'g1': g1, 'g2': g2, 'g3': g3, 'g4': g4, 'g5': g5}


def _checkpoint_rank(file_name):
    """Sort key for checkpoints: AP parsed from '..._AP_<value>.pt', then the name.

    Comparing the raw file names would order the AP values as text
    (e.g. 'AP_9.5' > 'AP_85.3'), so the number is compared numerically.
    Names without a parsable AP rank below all others.
    """
    found = re.search(r'_AP_(\d+(?:\.\d+)?)\.pt$', file_name)
    ap_value = float(found.group(1)) if found else -math.inf
    return (ap_value, file_name)


# Per group: normalization statistics and the checkpoint weights, loaded once.
fold_assets = {}
for fold_name in patient_groups:
    checkpoints = [item for item in os.listdir('models/') if f'{fold_name}_FC_GIN' in item]
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint matching '{fold_name}_FC_GIN' found in 'models/'")
    best_checkpoint = max(checkpoints, key=_checkpoint_rank)

    # Training patients of this fold = all groups except the held-out one.
    fold_train_patients = [
        patient
        for other_name, group in patient_groups.items() if other_name != fold_name
        for patient in group
    ]
    # A dedicated name is used so the global `train_segments` read by
    # CustomLightningModule.train_dataloader is not overwritten.
    fold_train_segments, _, _ = prepare_dataset(fold_train_patients)
    fold_median, fold_IQR = median_IQR_func(fold_train_segments)
    del fold_train_segments

    fold_assets[fold_name] = {
        'median': fold_median,
        'IQR': fold_IQR,
        'state_dict': torch.load(f'models/{best_checkpoint}'),
    }


s = pd.Series(index=NIEDs_names, dtype='float64')

num_electrodes = 19     # every montage in the configuration cell lists 19 electrodes

model.eval()
loaded_fold = None      # group whose weights are currently in `model`

with torch.no_grad():   # inference only: no autograd graph needed
    for NIED_name in NIEDs_names:
        # File names start with the patient name followed by '_'.
        patient_name = NIED_name[: NIED_name.find('_')]
        fold_name = next((name for name, group in patient_groups.items() if patient_name in group), None)
        if fold_name is None:
            raise NameError('Unknown Patient')
        assets = fold_assets[fold_name]

        segment = loadmat(os.path.join(data_folder, NIED_name))['segment']

        # 1) wpli_matrix: (no_connectivities, num_freq_bands)
        wpli_matrix = wpli_matrix_generator(segment, fs, nperseg, noverlap, freq_bands=freq_bands)

        # 2) Thresholding, with or without surrogates
        adjacencies = threshold_func(wpli_matrix, USE_SURROGATES, segment, no_surrogates=no_surrogates, freq_FC_method='wpli', surrogate_thr=surrogate_thr, freq_bands=freq_bands)
        del wpli_matrix

        # adjacencies becomes a list (one entry per frequency band) of
        # [src, dst] pairs of 1-D torch.long tensors.
        adjacencies = [adjacencies_converter(adj, montage) for adj in adjacencies]

        sparse_matrices = []
        for src, dst in adjacencies:
            i = torch.stack([dst, src], dim=0)
            v = torch.ones((i.shape[1],), dtype=torch.float)
            sparse_matrices.append(torch.sparse_coo_tensor(i, v, [num_electrodes, num_electrodes]))

        segment = torch.tensor(segment).float()
        segment = (segment - assets['median']) / assets['IQR']

        # Swap weights only when the patient group changes.
        if fold_name != loaded_fold:
            model.load_state_dict(assets['state_dict'])
            loaded_fold = fold_name

        if attention:
            logit, _ = model(torch.tensor([num_electrodes]), segment, sparse_matrices)
        else:
            logit = model(torch.tensor([num_electrodes]), segment, sparse_matrices)

        s.at[NIED_name] = torch.sigmoid(logit).item()

In [20]:
s

RC_24_ns.mat      8.600377e-03
FEsg_38_ns.mat    1.120687e-04
HB_2_ns.mat       4.159114e-08
FEsg_26_ns.mat    2.963187e-01
FigSa_9_ns.mat    4.961883e-02
                      ...     
LM_3_ns.mat       6.453300e-06
RC_6_ns.mat       8.747430e-06
FEsg_19_ns.mat    1.311908e-01
NKra_10_ns.mat    6.252312e-03
PAT48_ns_6.mat    1.284350e-04
Length: 593, dtype: float64

In [21]:
# Write the scores to the 'FC_GIN' sheet of the results workbook.
# (Use 'IEDs_results.xlsx' when `s` holds the IED scores.)
results_path = 'NIEDs_results.xlsx'

if os.path.exists(results_path):
    # Append to the existing workbook and overwrite a previous 'FC_GIN' sheet,
    # so re-running the cell does not fail. Append mode requires openpyxl.
    writer_options = {'mode': 'a', 'engine': 'openpyxl', 'if_sheet_exists': 'replace'}
else:
    # Append mode raises when the workbook does not exist yet; create it instead.
    writer_options = {'mode': 'w'}

with pd.ExcelWriter(results_path, **writer_options) as writer:
    s.to_excel(writer, sheet_name='FC_GIN', header=False)

In [None]:
# SWA: evaluate the averaged model on the validation set.

loss_list = []
y_list = []
logits_list = []
s_m_list = []

dataset = CustomDataset(val_segments, val_adjs, val_labels, 0, montage)
dl = torch.utils.data.DataLoader(dataset, batch_size=val_B, shuffle=False, num_workers=mp.cpu_count()-1,
                                 collate_fn=collate_fn, pin_memory=False, drop_last=False)

bce_loss = torch.nn.BCEWithLogitsLoss()

swa_model.eval()
with torch.no_grad():
    for batch in dl:
        # batch: (num_nodes_in_examples, segments, y, one COO index tensor per frequency band)
        num_nodes_in_examples, segments, y, *indices = batch

        total_num_nodes = num_nodes_in_examples.sum().item()
        sparse_matrices = []
        for i in indices:
            v = torch.ones((i.shape[1],), dtype=torch.float, device=i.device)
            sparse_mat = torch.sparse_coo_tensor(i, v, [total_num_nodes, total_num_nodes])
            sparse_matrices.append(sparse_mat)

        if attention:
            logits, s_m = swa_model(num_nodes_in_examples, segments, sparse_matrices)
            # Attention scores exist only in this branch; appending them
            # unconditionally would raise NameError (or reuse a stale s_m).
            s_m_list.append(s_m)
        else:
            logits = swa_model(num_nodes_in_examples, segments, sparse_matrices)
        loss = bce_loss(logits, y)

        loss_list.append(loss)
        y_list.append(y)
        logits_list.append(logits)

    loss = torch.stack(loss_list)
    y = torch.cat(y_list, dim=-1)
    logits = torch.cat(logits_list, dim=-1)
    pred = torch.sigmoid(logits)
    if attention:
        s_m = torch.cat(s_m_list, dim=0)


y = np.array(y)
pred = np.array(pred)
if attention:
    s_m = np.array(s_m)

AP = average_precision_score(y, pred)

print(f'AP = {AP}')
print(f'loss.mean() = {loss.mean()}')

In [None]:
# Attention: reload the best weights and collect the attention scores on the validation set.

# This cell unpacks (logits, s_m) from the model, which only exists with attention enabled.
assert attention, 'This cell requires the model to be built with attention=True'

num_freq_bands = len(freq_bands)
# The architecture must match the trained model (d_qk=25, d_v=25 in the training
# cell); any other size makes load_state_dict fail with a shape mismatch.
my_mod = CustomLightningModule(num_freq_bands, dedicated_vgg=dedicated_vgg, eps=eps, train_eps=train_eps, dedicated_GIN=dedicated_GIN, attention=attention, heads=1, d_qk=25, d_v=25)
my_mod.load_state_dict(best_model_state)

loss_list = []
y_list = []
logits_list = []
s_m_list = []

dataset = CustomDataset(val_segments, val_adjs, val_labels, 0, montage)
dl = torch.utils.data.DataLoader(dataset, batch_size=val_B, shuffle=False, num_workers=mp.cpu_count()-1,
                                 collate_fn=collate_fn, pin_memory=False, drop_last=False)

bce_loss = torch.nn.BCEWithLogitsLoss()

my_mod.eval()
with torch.no_grad():
    for batch in dl:
        # batch: (num_nodes_in_examples, segments, y, one COO index tensor per frequency band)
        num_nodes_in_examples, segments, y, *indices = batch

        total_num_nodes = num_nodes_in_examples.sum().item()
        sparse_matrices = []
        for i in indices:
            v = torch.ones((i.shape[1],), dtype=torch.float, device=i.device)
            sparse_mat = torch.sparse_coo_tensor(i, v, [total_num_nodes, total_num_nodes])
            sparse_matrices.append(sparse_mat)

        logits, s_m = my_mod(num_nodes_in_examples, segments, sparse_matrices)
        loss = bce_loss(logits, y)

        loss_list.append(loss)
        y_list.append(y)
        logits_list.append(logits)
        s_m_list.append(s_m)

    loss = torch.stack(loss_list)
    y = torch.cat(y_list, dim=-1)
    logits = torch.cat(logits_list, dim=-1)
    pred = torch.sigmoid(logits)
    s_m = torch.cat(s_m_list, dim=0)


y = np.array(y)
pred = np.array(pred)
# Drop the singleton axes so the later cells can index s_m as (example, frequency band).
s_m = np.array(s_m).squeeze()

AP = average_precision_score(y, pred)

print(f'AP = {AP}')
print(f'loss.mean() = {loss.mean()}')

In [None]:
# Shape of the collected attention scores, then their mean over the first axis.
print(s_m.shape, s_m.mean(axis=0), sep='\n')

In [None]:
# Index of the highest attention score along axis 1, one value per row of s_m.
s_m_argmax = np.argmax(s_m, axis=1)
s_m_argmax

In [None]:
# How often each frequency band receives the highest attention score.
band_indices, counts = np.unique(s_m_argmax, return_counts=True)

# Translate the band indices into their names from the configuration cell.
unique = [freq_bands[i] for i in band_indices]

print(f'unique = {unique}')
print(f'counts = {counts}')

In [None]:
# Label distribution and mean absolute prediction error for the examples whose
# highest attention score falls on the chosen frequency band.
freq_band_idx = 3

mask = (s_m_argmax == freq_band_idx)
selected_y, selected_pred = y[mask], pred[mask]

u, c = np.unique(selected_y, return_counts=True)
print(f'u, c = ({u}, {c})')

diff = np.abs(selected_y - selected_pred)
print(f'mean difference: {diff.mean()}')