In [None]:
from petastorm import make_reader
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import sys
from sklearn import metrics
import pandas as pd
import seaborn as sn
import os
import xarray as xr
import warnings
from tqdm import tqdm
sys.path.append(os.path.join(os.getcwd(), ".."))
from shared_utils.utils_data import format_data_to_xarray_2020,format_data_to_xarray

# Local directory of the first Parquet dataset, and the file:// URL built from it
# (this URL is the one handed to petastorm's make_reader further down).
# NOTE(review): the directory already starts with "/", so the resulting URL has
# four slashes ("file:////workspaces/...") — confirm this is intended.
path_formatted_glasgow = "/workspaces/maitrise/data/20220902_data_physio_formatted_merged/merged/dataParquet"
path_petastorm = f"file:///{path_formatted_glasgow}"

# Second Parquet dataset and its URL, built the same way; presumably the CinC 2011
# set-a data, judging from the variable and directory names — TODO confirm.
path_formated_cinc2011= "/workspaces/maitrise/data/20221006_physio_quality/set-a/dataParquet"
path_petastorm_cinc2011 = f"file:///{path_formated_cinc2011}"

# Directory that is created (if missing) to hold results.
save_path = "/workspaces/maitrise/results"



In [None]:
# Create the results directory when a save path is configured.
# exist_ok=True makes this idempotent and avoids the race between checking for
# the directory and creating it.
if save_path is not None:
    os.makedirs(save_path, exist_ok=True)



## What we need:
## 1) Patient ID
## 2) Lead names (saved as indexes)
## 3) Pathology (repeated on each lead)
## 4) ECG signal
i_stop = 50          # maximum number of recordings to keep
SIGNAL_LEN = 5000    # recordings with a different number of samples are skipped
N_LEADS = 12         # number of leads stored per recording

ECG_signals = torch.zeros((i_stop, SIGNAL_LEN, N_LEADS))
Leads_index = torch.zeros((i_stop, N_LEADS))
SQA_score = torch.zeros((i_stop, N_LEADS))  # allocated but not filled in this cell
Pathologys = torch.zeros((i_stop))

# Rows are written at `n_kept`, not at the reader index: skipped recordings must
# not leave all-zero rows behind, and the row index must never exceed i_stop - 1.
n_kept = 0
with make_reader(path_petastorm) as reader:
    for idx, sample in enumerate(reader):
        if idx == 0:
            lead_names = sample.signal_names.astype(str)
        if n_kept == i_stop:
            break
        if sample.signal.shape[0] != SIGNAL_LEN:
            continue

        Leads_index[n_kept, :] = torch.arange(N_LEADS)
        ECG_signals[n_kept, :, :] = torch.tensor(sample.signal[:, :])
        if 0 in sample.score_classes:
            Pathologys[n_kept] = int(sample.diagnostics[0])
        else:
            # Diagnostics equal to the first scored class; the first match is taken
            # so that several identical matches do not break the int() conversion.
            matching = sample.diagnostics[sample.diagnostics == sample.score_classes[0]]
            Pathologys[n_kept] = int(matching[0])
        n_kept += 1

if n_kept < i_stop:
    warnings.warn(f"Only {n_kept} recordings of length {SIGNAL_LEN} were found (requested {i_stop}).")

# Drop the unused, still-zero rows so every tensor holds exactly the kept recordings.
ECG_signals = ECG_signals[:n_kept]
Leads_index = Leads_index[:n_kept]
SQA_score = SQA_score[:n_kept]
Pathologys = Pathologys[:n_kept]

        


### Convention :

For the lead names, we will use the following convention : 

|I|II|III|aVR|aVL|aVF|V1|V2|V3|V4|V5|V6|
|---|---|---|---|---|---|---|---|---|---|---|---|
|0|1|2|3|4|5|6|7|8|9|10|11|

### GAN Generator and Discriminator

Classes for the Discriminator and the Generator. *THIS IS TEMPORARY*: they will be moved to another git repository later. They are only used here to check that the conditional GAN does what it is supposed to do.

In [None]:
class Discriminator(nn.Module):
    """1-D convolutional discriminator that outputs one probability per input.

    Args:
        in_channels: number of channels of the input tensor.
        features: channel width of the first convolution; later blocks use
            2x, 4x and 8x this value.
        n_lead, n_classes: accepted for the planned conditional version
            (lead / pathology embeddings and a class head) but currently unused.

    Input:  tensor of shape (batch, in_channels, length).
    Output: tensor of shape (batch, 1) with values in [0, 1] (Sigmoid).
    """

    def __init__(self,in_channels,features = 32,n_lead = 12,n_classes = 13):
        super().__init__()
        # TODO: conditioning (embeddings for lead and pathology labels, plus a
        # pathology classification head) is not implemented yet.
        self.model = nn.Sequential(
            self._Block(in_channels,features,1),
            self._Block(features,features,2),
            self._Block(features,features*2,1),
            self._Block(features*2,features*2,2),
            self._Block(features*2,features*4,1),
            self._Block(features*4,features*4,2),
            self._Block(features*4,features*8,1),
            self._Block(features*8,features*8,2),
            # Collapse the temporal axis so the linear head always receives
            # features*8 values, whatever the input length. This is an identity
            # when the convolution stack already produces a length of 1.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )
        self.out1 = nn.Sequential(nn.Linear(features*8,1),nn.Sigmoid())

    def _Block(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        """Return a Conv1d -> BatchNorm1d -> LeakyReLU(0.2) -> Dropout(0.25) block."""
        return nn.Sequential(
            nn.Conv1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode = "zeros",bias = False),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25)
        )

    def forward(self,x,lab_signal=0,lab_path=0):
        """Return the validity score of `x`; `lab_signal` and `lab_path` are ignored for now."""
        o = self.model(x)
        validity = self.out1(o)
        return validity


class _ResidualConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> LeakyReLU(0.25), with the raw convolution output added back."""

    def __init__(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode="zeros")
        self.norm = nn.BatchNorm1d(out_channels)
        self.act = nn.LeakyReLU(0.25)

    def forward(self,x):
        fe_add = self.conv(x)
        return torch.add(self.act(self.norm(fe_add)),fe_add)


class Generator(nn.Module):
    """Encoder/decoder 1-D convolutional generator with additive skip connections.

    Args:
        z_dim: number of channels of the input tensor (first residual block input).
        in_channels: number of channels of the generated output.
        features: base channel width; deeper stages use 2x and 4x this value.
        len_seq: stored on the instance, not used by the layers.

    Input:  tensor of shape (batch, z_dim, length).
    Output: tensor of shape (batch, in_channels, length) with values in [0, 1] (Sigmoid).
    """

    def __init__(self,z_dim,in_channels,features = 32,len_seq = 5000):
        super().__init__()
        self.features = features
        self.z_dim = z_dim
        self.len_seq = len_seq
        # The residual blocks are created once here so that their weights are
        # registered as parameters: they are trained, saved in the state dict and
        # moved with .to(device), instead of being re-randomised at every forward.
        self.res_1 = self._novel_residual_block(self.z_dim,self.features,1)
        self.down_1 = self._downsampling_block(self.features,self.features,2)
        self.skip_attention_1 = self._downsampling_block(self.features,self.features,1,kernel_size=2,dillation_rate=2)
        self.res_2 = self._novel_residual_block(self.features,self.features*2,1)
        self.down_2 = self._downsampling_block(self.features*2,self.features*2,2)
        self.skip_attention_2 = self._downsampling_block(self.features*2,self.features*2,1,kernel_size=2,dillation_rate=2)
        self.res_3 = self._novel_residual_block(self.features*2,self.features*4,1)
        self.down_3 = self._downsampling_block(self.features*4,self.features*4,2)
        self.skip_attention_3 = self._downsampling_block(self.features*4,self.features*4,1,kernel_size=2,dillation_rate=2)
        self.up_3 = self._upsampling_block(self.features*4,self.features*4,2)
        self.up_2 = self._upsampling_block(self.features*4,self.features*2,2)
        self.up_1 = self._upsampling_block(self.features*2,self.features,2)
        self.final = nn.Sequential(nn.Conv1d(self.features,in_channels,kernel_size = 3,stride = 1,padding = 1,padding_mode = "zeros"),nn.Sigmoid())

    def _downsampling_block(self,in_channels,out_channels,stride,kernel_size=3,padding=1,dillation_rate = 1):
        """Return a Conv1d -> BatchNorm1d -> LeakyReLU(0.2) block."""
        return nn.Sequential(
            nn.Conv1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode = "zeros",bias = False,dilation = dillation_rate),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2)
            )

    def _upsampling_block(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        """Return a ConvTranspose1d -> BatchNorm1d -> LeakyReLU(0.2) block."""
        return nn.Sequential(
            nn.ConvTranspose1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode = "zeros",bias = False),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2)
            )

    def _novel_residual_block(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        """Return a residual convolution block (built once, reused at every forward)."""
        return _ResidualConvBlock(in_channels,out_channels,stride,kernel_size,padding)

    @staticmethod
    def _upsample_like(block,x,skip):
        """Run an upsampling block so that its output length equals the length of `skip`.

        A stride-2 transposed convolution alone yields 2*L - 1 samples, which does
        not match a skip tensor of even length; passing `output_size` lets the
        layer pick the output padding that gives the requested length.
        """
        conv, norm, act = block[0], block[1], block[2]
        return act(norm(conv(x,output_size=[skip.size(-1)])))

    def forward(self,x):
        """Map `x` of shape (batch, z_dim, length) to (batch, in_channels, length)."""
        gen = self.res_1(x)
        skip_1 = self.skip_attention_1(gen)
        gen = self.down_1(gen)
        gen = self.res_2(gen)
        skip_2 = self.skip_attention_2(gen)
        gen = self.down_2(gen)
        gen = self.res_3(gen)
        skip_3 = self.skip_attention_3(gen)
        gen = self.down_3(gen)
        gen = self._upsample_like(self.up_3,gen,skip_3)
        gen = torch.add(gen,skip_3)
        gen = self._upsample_like(self.up_2,gen,skip_2)
        gen = torch.add(gen,skip_2)
        gen = self._upsample_like(self.up_1,gen,skip_1)
        gen = torch.add(gen,skip_1)
        ECG_reconstruct = self.final(gen)
        return ECG_reconstruct


def initialize_weights(model):
    """Initialise convolution and batch-norm layers of `model` in place (DCGAN style).

    Conv1d / ConvTranspose1d weights are drawn from N(0, 0.02). BatchNorm1d
    scales are drawn from N(1, 0.02) and their biases set to 0: a scale centred
    on 0 would shrink every normalised activation to almost nothing.
    """
    for m in model.modules():
        if isinstance(m,(nn.Conv1d,nn.ConvTranspose1d)):
            nn.init.normal_(m.weight,0.0,0.02)
        elif isinstance(m,nn.BatchNorm1d) and m.weight is not None:
            # weight/bias are None when the layer was built with affine=False
            nn.init.normal_(m.weight,1.0,0.02)
            nn.init.zeros_(m.bias)

def test():
    """Smoke test: check the output shapes of the Discriminator and the Generator."""
    batch_size, n_channels, seq_len = 8, 5000, 1
    latent_dim = 1000

    # Discriminator: one score per batch element.
    sample = torch.rand((batch_size, n_channels, seq_len))
    discriminator = Discriminator(n_channels, features=8)
    verdict = discriminator(sample)
    assert verdict.shape == (batch_size, 1)
    print("Discriminator OK")

    # Generator: output has n_channels channels and a length of 1.
    generator = Generator(latent_dim, n_channels, 8)
    noise = torch.rand((batch_size, latent_dim, 1))
    generated = generator(noise)
    assert generated.shape == (batch_size, n_channels, 1)
    print("Generator OK")


In [None]:
##The dataset 

class DatasetECGReconstruction(Dataset):
    """Dataset yielding one ECG recording with its pathology label and lead indexes.

    Args:
        signals: indexable collection of recordings (one entry per recording).
        leads: indexable collection of lead indexes (one entry per recording).
        pathology: indexable collection of pathology labels; its length defines
            the dataset length.
    """

    def __init__(self,signals,leads,pathology):
        self.ECGs = signals
        self.leads = leads
        self.path = pathology

    def __len__(self):
        """Number of recordings, taken from the pathology labels."""
        return len(self.path)

    def __getitem__(self, index):
        """Return the recording, pathology and lead indexes stored at `index`."""
        Signal = self.ECGs[index]
        path_patient = self.path[index]
        # Select the row of this recording; returning the whole `leads` container
        # would attach the lead indexes of every recording to each item.
        leads = self.leads[index]
        return {"ECG recording":Signal,"pathology":path_patient,"Leads index":leads}

In [None]:
# Wrap the loaded tensors in the dataset consumed by the DataLoader.
d_ECG = DatasetECGReconstruction(
    signals=ECG_signals,
    leads=Leads_index,
    pathology=Pathologys,
)

In [None]:
# Shuffled mini-batches of 8 recordings.
DL_ECG = DataLoader(dataset=d_ECG, batch_size=8, shuffle=True)
print(DL_ECG)

In [None]:
# Sanity check: show the pathology labels of every batch.
for index, batch in enumerate(DL_ECG):
    pathology_labels = batch["pathology"]
    print(f"For index {index},we have the follwoing batch : ")
    print(pathology_labels)