In [None]:
from petastorm import make_reader
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import sys
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from sklearn import metrics
import pandas as pd
import seaborn as sn
import os
import xarray as xr
import warnings
from tqdm import tqdm
sys.path.append(os.path.join(os.getcwd(), ".."))
from Metrics import Our_SQA_method

# Folder holding the merged, formatted Parquet dataset read through petastorm.
path_formatted_glasgow = "/workspaces/maitrise/data/20220902_data_physio_formatted_merged/merged/dataParquet"
# The path is absolute (starts with "/"), so "file://" + path gives the
# canonical three-slash URL "file:///workspaces/..." instead of four slashes.
path_petastorm = f"file://{path_formatted_glasgow}"

# Folder holding the Parquet dataset under 20221006_physio_quality/set-a.
path_formated_cinc2011= "/workspaces/maitrise/data/20221006_physio_quality/set-a/dataParquet"
path_petastorm_cinc2011 = f"file://{path_formated_cinc2011}"

# Output folder for everything this notebook writes.
save_path = "/workspaces/maitrise/results"




In [None]:
# Load the pathology map and keep only its second column (index 1) as a
# NumPy array; the position of an entry in `labels` is used as its class index.
# NOTE(review): which field of Dx_map.csv column 1 holds is not visible here — confirm.
path_labels_patho = "/workspaces/maitrise/data/Dx_map.csv"
labels = pd.read_csv(path_labels_patho).to_numpy()[:,1]

In [None]:
## What we need:
## 1) Patient ID
## 2) Lead names (saved as indexes)
## 3) Pathology (repeated on each lead)
## 4) ECG signal
if save_path is not None:
    # exist_ok removes the race between an existence check and the creation.
    os.makedirs(save_path, exist_ok=True)

i_stop = 150            # maximum number of recordings to load
SIGNAL_LENGTH = 5000    # recordings with another number of samples are skipped
N_LEADS = 12            # leads per recording
counter = 0             # number of recordings actually stored so far

# Buffers are pre-allocated for i_stop recordings and trimmed after the loop.
ECG_signals = torch.zeros((i_stop,SIGNAL_LENGTH,N_LEADS))
Leads_index  = torch.zeros((N_LEADS))
SQA_score = torch.zeros((i_stop,N_LEADS))
Pathologies = torch.zeros((i_stop,N_LEADS))
with make_reader(path_petastorm) as reader:
    for idx, sample in enumerate(reader):
        if idx == 0:
            # Lead names come from the first sample; leads are then referred to by position.
            lead_names = sample.signal_names.astype(str)
            Leads_index[:] = torch.arange(N_LEADS)
        if len(sample.signal[:,0]) != SIGNAL_LENGTH:
            continue
        if counter == i_stop:
            break

        ECG_signals[counter,:,:] = torch.tensor(sample.signal[:,:])
        # Per-lead quality score; the method receives the signal transposed.
        # NOTE(review): 500 is presumably the sampling frequency in Hz — confirm.
        SQA_score[counter,:] = torch.tensor(Our_SQA_method.SQA_method_lead_score(sample.signal[:,:].T,500))

        # Pick the diagnostic code used as label for this recording.
        if 0 in sample.score_classes:
            diagnostic = sample.diagnostics[0]
        else:
            diagnostic = sample.diagnostics[np.where(sample.diagnostics == sample.score_classes[0])[0][0]]
        # The class index is the position of that code in `labels`, repeated on every lead.
        Pathologies[counter,:] = torch.tensor(np.where(diagnostic==labels)[0],dtype = torch.int64)
        counter +=1

# If the reader yielded fewer than i_stop valid recordings, drop the rows that
# were never filled so all-zero signals are not used as training samples.
ECG_signals = ECG_signals[:counter]
SQA_score = SQA_score[:counter]
Pathologies = Pathologies[:counter]
        


### Convention :

For the lead names, we will use the following convention : 

|I|II|III|aVR|aVL|aVF|V1|V2|V3|V4|V5|V6|
|---|---|---|---|---|---|---|---|---|---|---|---|
|0|1|2|3|4|5|6|7|8|9|10|11|

### GAN Generator and Discriminator

Classes for the Discriminator and the Generator. *THIS IS TEMPORARY*: they will be moved to another git repository later. They are only used here to check that the conditional GAN does what it is supposed to do.

In [None]:
class Discriminator(nn.Module):
    """Conditional 1-D convolutional discriminator with two heads.

    The input is concatenated along the channel axis with a learned embedding
    of the pathology labels and passed through eight Conv1d blocks (strides
    1,2,1,2,1,2,1,2). The flattened result feeds two linear heads:

    * ``out1``: real/fake validity through a Sigmoid, shape ``(batch, 1)``
    * ``out2``: pathology class probabilities through a Softmax, shape
      ``(batch, n_classes)``

    The linear heads expect ``features*8`` inputs, so the convolution stack
    must reduce the last axis of the input to length 1.

    Args:
        in_channels: number of channels of ``x`` (axis 1).
        features: base number of convolution filters.
        n_classes: number of pathology classes, also the embedding size.
        len_time_serie: stored as ``len_seq``; not used by the computation.
    """

    def __init__(self,in_channels,features = 32,n_classes = 111,len_time_serie = 5000):
        super().__init__()
        self.len_seq = len_time_serie
        self.pathologies = n_classes
        self.Embedding_path = nn.Embedding(n_classes,n_classes)
        self.model = nn.Sequential(
            self._Block(in_channels+n_classes,features,1),
            self._Block(features,features,2),
            self._Block(features,features*2,1),
            self._Block(features*2,features*2,2),
            self._Block(features*2,features*4,1),
            self._Block(features*4,features*4,2),
            self._Block(features*4,features*8,1),
            self._Block(features*8,features*8,2),
            nn.Flatten()
        )
        self.out1 = nn.Sequential(nn.Linear(features*8,1),nn.Sigmoid())
        # Explicit dim: the implicit choice is deprecated and emits a warning.
        self.out2 = nn.Sequential(nn.Linear(features*8,n_classes),nn.Softmax(dim=1))

    def _Block(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        """Return Conv1d (no bias) -> BatchNorm1d -> LeakyReLU(0.2) -> Dropout(0.25)."""
        return nn.Sequential(
            nn.Conv1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode = "zeros",bias = False),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25)
        )

    def forward(self,x,lab_path):
        """Score a batch of signals conditioned on pathology labels.

        Args:
            x: tensor of shape ``(batch, in_channels, length)``.
            lab_path: integer class indices holding ``batch*length`` elements,
                one label per position of the last axis of ``x``.

        Returns:
            ``(validity, path_class)`` with shapes ``(batch, 1)`` and
            ``(batch, n_classes)``.
        """
        # One label per (sample, position); the embedding comes out as
        # (batch, length, n_classes). Transposing (not reshaping) keeps each
        # position's embedding vector aligned with that position.
        lab_path = lab_path.reshape(x.size(0),x.size(2))
        embedding_path = self.Embedding_path(lab_path).transpose(1,2)
        d_in = torch.cat([x,embedding_path],1)
        o = self.model(d_in)
        validity = self.out1(o)
        path_class = self.out2(o)
        return validity,path_class


class _ResidualBlock(nn.Module):
    """Residual unit: ``y = Conv1d(x)``, output ``LeakyReLU(0.25)(BatchNorm1d(y)) + y``."""

    def __init__(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        super().__init__()
        self.conv = nn.Conv1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode="zeros")
        self.norm = nn.BatchNorm1d(out_channels)
        self.activation = nn.LeakyReLU(0.25)

    def forward(self,x):
        fe_add = self.conv(x)
        fe = self.activation(self.norm(fe_add))
        return torch.add(fe,fe_add)


class Generator(nn.Module):
    """Conditional encoder/decoder generator with skip connections.

    The input is concatenated along the channel axis with learned embeddings
    of the lead index and of the pathology class, then goes through three
    (residual block -> skip branch -> strided down-sampling) stages and three
    transposed-convolution up-sampling stages whose outputs are summed with
    the matching skip branches. A final Conv1d + Sigmoid maps back to
    ``in_channels`` channels, so outputs lie in (0, 1).

    The residual blocks are registered sub-modules (``res_1`` .. ``res_3``),
    so their weights are kept between calls, returned by ``parameters()`` and
    trained by the optimizer.

    Args:
        z_dim: number of channels of the tensor passed to ``forward``.
        in_channels: number of channels of the generated output.
        features: base number of convolution filters.
        len_seq: stored as an attribute; not used by the computation.
        n_lead: number of distinct lead indices (embedding size).
        n_classes: number of pathology classes (embedding size).
    """

    def __init__(self,z_dim,in_channels,features = 32,len_seq = 5000,n_lead = 12,n_classes = 111):
        super().__init__()
        self.Embedding_path = nn.Embedding(n_classes,n_classes)
        self.Embedding_lead = nn.Embedding(n_lead,n_lead)
        self.features = features
        self.z_dim = z_dim
        self.len_seq = len_seq
        self.lead = n_lead
        self.pathologies = n_classes
        # The first residual block sees the input channels plus both embeddings.
        self.res_1 = self._novel_residual_block(z_dim+n_lead+n_classes,self.features,1)
        self.res_2 = self._novel_residual_block(self.features,self.features*2,1)
        self.res_3 = self._novel_residual_block(self.features*2,self.features*4,1)
        self.down_1 = self._downsampling_block(self.features,self.features,2)
        self.skip_attention_1 = self._downsampling_block(self.features,self.features,1,kernel_size=2,dillation_rate=2)
        self.down_2 = self._downsampling_block(self.features*2,self.features*2,2)
        self.skip_attention_2 = self._downsampling_block(self.features*2,self.features*2,1,kernel_size=2,dillation_rate=2)
        self.down_3 = self._downsampling_block(self.features*4,self.features*4,2)
        self.skip_attention_3 = self._downsampling_block(self.features*4,self.features*4,1,kernel_size=2,dillation_rate=2)
        self.up_3 = self._upsampling_block(self.features*4,self.features*4,2)
        self.up_2 = self._upsampling_block(self.features*4,self.features*2,2)
        self.up_1 = self._upsampling_block(self.features*2,self.features,2)
        self.final = nn.Sequential(nn.Conv1d(self.features,in_channels,kernel_size = 3,stride = 1,padding = 1,padding_mode = "zeros"),nn.Sigmoid())

    def _downsampling_block(self,in_channels,out_channels,stride,kernel_size=3,padding=1,dillation_rate = 1):
        """Return Conv1d (no bias, optional dilation) -> BatchNorm1d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode = "zeros",bias = False,dilation = dillation_rate),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2)
            )

    def _upsampling_block(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        """Return ConvTranspose1d (no bias) -> BatchNorm1d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.ConvTranspose1d(in_channels,out_channels,kernel_size,stride,padding,padding_mode = "zeros",bias = False),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2)
            )

    def _novel_residual_block(self,in_channels,out_channels,stride,kernel_size=3,padding=1):
        """Build a residual block as a persistent module (created once, in ``__init__``)."""
        return _ResidualBlock(in_channels,out_channels,stride,kernel_size,padding)

    def forward(self,x,lead_label,path_label):
        """Generate a signal from ``x`` conditioned on lead and pathology.

        Args:
            x: tensor of shape ``(batch, z_dim, 1)``; the embeddings are
                unsqueezed to length 1 before the channel-wise concatenation.
            lead_label: integer lead indices, shape ``(batch,)``.
            path_label: integer pathology indices, shape ``(batch,)``.

        Returns:
            Tensor with ``in_channels`` channels and the same length as ``x``.
        """
        embedding_lead = self.Embedding_lead(lead_label).unsqueeze(2)
        embedding_path = self.Embedding_path(path_label).unsqueeze(2)
        d_in = torch.cat([x,embedding_lead,embedding_path],dim=1)
        gen = self.res_1(d_in)
        skip_1 = self.skip_attention_1(gen)
        gen = self.down_1(gen)
        gen = self.res_2(gen)
        skip_2 = self.skip_attention_2(gen)
        gen = self.down_2(gen)
        gen = self.res_3(gen)
        skip_3 = self.skip_attention_3(gen)
        gen = self.down_3(gen)
        gen = self.up_3(gen)
        gen = torch.add(gen,skip_3)
        gen = self.up_2(gen)
        gen = torch.add(gen,skip_2)
        gen = self.up_1(gen)
        gen = torch.add(gen,skip_1)
        ECG_reconstruct = self.final(gen)
        return ECG_reconstruct


def initialize_weights(model):
    """Initialise convolution and batch-norm parameters in place (DCGAN style).

    * Conv1d / ConvTranspose1d weights: N(0, 0.02).
    * BatchNorm1d scale: N(1, 0.02), shift: 0. A scale centred on 0 would
      almost cancel every normalised activation at the start of training.

    BatchNorm layers built with ``affine=False`` have no weight and are skipped.

    Args:
        model: any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m,(nn.Conv1d,nn.ConvTranspose1d)):
            nn.init.normal_(m.weight.data,0.0,0.02)
        elif isinstance(m,nn.BatchNorm1d) and m.weight is not None:
            nn.init.normal_(m.weight.data,1.0,0.02)
            nn.init.constant_(m.bias.data,0.0)

def test():
    """Smoke-test the output shapes of Discriminator and Generator on random data.

    Both networks are conditional, so random labels are supplied: one
    pathology label per position for the discriminator, one lead index and
    one pathology index per sample for the generator.
    """
    N,in_channels,HW= 8,5000,1
    z_dim = 1000
    n_lead,n_classes = 12,111  # defaults of Generator / Discriminator
    x = torch.rand((N,in_channels,HW))
    path_labels = torch.randint(n_classes,(N,HW))
    disc = Discriminator(in_channels,features= 8)
    # The discriminator returns a (validity, class probabilities) pair.
    validity,path_class = disc(x,path_labels)
    assert validity.shape == (N,1)
    assert path_class.shape == (N,n_classes)
    print("Discriminator OK")

    ##Generator :
    gen = Generator(z_dim,in_channels,8)
    z = torch.rand((N,z_dim,1))
    lead_labels = torch.randint(n_lead,(N,))
    patho_labels = torch.randint(n_classes,(N,))
    assert gen(z,lead_labels,patho_labels).shape== (N,in_channels,1)
    print("Generator OK")


In [None]:
##The dataset 

class DatasetECGReconstruction(Dataset):
    """Map-style dataset over pre-loaded ECG tensors.

    Item ``k`` is the tuple ``(signals[k], pathology[k], quality_scores[k, :], leads)``;
    ``leads`` is returned unchanged for every item. The dataset length is
    ``len(pathology)``.
    """

    def __init__(self,signals,leads,pathology,quality_scores):
        # Keep references only; nothing is copied.
        self.ECGs, self.leads = signals, leads
        self.path, self.SQA = pathology, quality_scores

    def __len__(self):
        """Number of items, taken from the pathology container."""
        return len(self.path)

    def __getitem__(self, index):
        """Return ``(signal, pathology, per-lead quality scores, leads)`` for ``index``."""
        return (
            self.ECGs[index],
            self.path[index],
            self.SQA[index,:],
            self.leads,
        )

In [None]:
# Wrap the pre-loaded tensors (signals, lead indexes, pathology labels, SQA scores)
# in a Dataset so they can be batched by a DataLoader.
d_ECG = DatasetECGReconstruction(ECG_signals,Leads_index,Pathologies,SQA_score)

In [None]:
##Training 


## Hyperparameters
# No device selection: everything runs on CPU (no GPU available for the moment).
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Learning rate shared by both Adam optimizers.
Learning_rate = 2e-4
# Number of recordings per DataLoader batch.
batch_size = 32
# Number of training epochs.
epochs = 10
# The signal length is used as the channel dimension of both networks.
input_channels = 5000 #Signal length
# NOTE(review): HW, number_leads and number_pathologies are not referenced by the
# visible cells; Generator/Discriminator are built with their default
# n_lead=12 / n_classes=111 — confirm whether 39 is still meant to be used.
HW = 1
number_leads = 12
number_pathologies = 39
# Number of channels of the tensor fed to the generator.
z_dim = 5000
# Base number of convolution filters for the discriminator / generator.
feat_d = 32
feat_g = 32

## DataLoader, generator and discriminator initialisation:
DL_ECG = DataLoader(d_ECG,batch_size=batch_size,shuffle = True)
gen = Generator(z_dim,input_channels,feat_g)
disc = Discriminator(input_channels,feat_d)


##Loss function used by the article : 

def LSGAN_disc_real(d):
    """Least-squares loss pulling the scores ``d`` towards the target 1: ``mean((d - 1)^2)``."""
    distance_to_target = d - 1
    return (distance_to_target ** 2).mean()

def LSGAN_disc_fake(d):
    """Least-squares loss pulling the scores ``d`` towards the target -1: ``mean((d + 1)^2)``.

    NOTE(review): the Discriminator in this notebook ends its validity head
    with a Sigmoid, so scores lie in (0, 1) and the -1 target cannot be
    reached — confirm this is the intended formulation.
    """
    distance_to_target = d + 1
    return (distance_to_target ** 2).mean()

def cross_cat_entrop(yi,yipred,eps = 1e-12):
    """Categorical cross-entropy ``-sum(yi * log(yipred))`` reduced over axis 0.

    Args:
        yi: target weights (e.g. one-hot), same shape as ``yipred``.
        yipred: predicted probabilities.
        eps: lower bound applied to ``yipred`` before the logarithm, so a
            predicted probability of exactly 0 gives a large finite penalty
            instead of ``-inf`` (and ``0 * -inf = NaN`` where ``yi`` is 0).

    Returns:
        Tensor with axis 0 summed out.
    """
    safe_pred = torch.clamp(yipred,min = eps)
    return -torch.sum(yi*torch.log(safe_pred),dim = 0)

# One Adam optimizer per network, both with beta1 = 0.5.
opt_gen = optim.Adam(gen.parameters(),lr = Learning_rate,betas = (0.5,0.999))
opt_disc = optim.Adam(disc.parameters(),lr = Learning_rate,betas = (0.5,0.999))

# Fixed generator inputs, reused after every epoch so snapshots are comparable.
# torch.randint excludes its upper bound, so the bounds are the number of
# leads / classes the generator was built with: the last lead and the last
# class can then be sampled too.
fixed_noise = torch.rand(8,z_dim,1)
fixed_label = torch.randint(gen.lead,(8,))
fixed_patho = torch.randint(gen.pathologies,(8,))

folder_real = "/workspaces/maitrise/results/real_sig"
folder_fake = "/workspaces/maitrise/results/fake_sig"
step = 0
# exist_ok removes the race between an existence check and the creation.
os.makedirs(folder_real,exist_ok = True)
os.makedirs(folder_fake,exist_ok = True)


# TensorBoard writers for real and generated signals.
write_real = SummaryWriter(os.path.join(folder_real,"logs","real"))
write_fake = SummaryWriter(os.path.join(folder_fake,"logs","fake"))
# Reconstruction loss between the generated lead and its reference.
L2Loss = nn.MSELoss()
# Both networks start in training mode.
gen.train()
disc.train()

In [None]:
##Training code

def batch_adaptator(dataload_batch):
    """Build the generator inputs and targets from one DataLoader batch.

    For every recording the lead with the lowest quality score is selected.

    Args:
        dataload_batch: sequence ``(signals, pathologies, quality_scores, ...)``
            where ``signals`` is indexed as ``[sample, time, lead]`` and
            ``quality_scores`` as ``[sample, lead]``.

    Returns:
        Tuple ``(real_utilized, reference, index, pathologies, path_lead)``:

        * ``real_utilized``: ``(batch, time, 1)``; the selected lead when its
          score is below 0.5, uniform random noise otherwise.
        * ``reference``: ``(batch, time, 1)``; always the selected lead.
        * ``index``: position of the selected lead for each sample.
        * ``pathologies``: ``dataload_batch[1]`` cast to int32.
        * ``path_lead``: first pathology label of each sample, int32.
    """
    ecg_batch = dataload_batch[0]
    quality = dataload_batch[2]
    pathologies = dataload_batch[1].type(torch.int32)
    n_records, n_steps = ecg_batch.size(0), ecg_batch.size(1)

    path_lead = torch.empty((n_records))
    real_utilized = torch.empty((n_records,n_steps,1))
    reference = torch.empty((n_records,n_steps,1))

    values,index = torch.min(quality,dim = 1)
    for row in range(values.size(0)):
        # Column vector holding the lowest-quality lead of this recording.
        worst_lead = ecg_batch[row,:,index[row].item()].view(n_steps,1).detach().clone()
        reference[row] = worst_lead
        if values[row] <0.5:
            real_utilized[row] = worst_lead
        else :
            real_utilized[row,:] = torch.rand(n_steps,1)
        path_lead[row] = torch.tensor(pathologies[row,0].item())
    return real_utilized,reference,index,pathologies,path_lead.type(torch.int32)

def Recreate_original_batch(dataload_batch,fake_data,indexes):
    """Copy the batch signals and overwrite one lead per sample with generated data.

    Args:
        dataload_batch: sequence whose first element holds the signals,
            indexed as ``[sample, time, lead]``.
        fake_data: generated signals; ``fake_data[k]`` is flattened and
            written into the replaced lead of sample ``k``.
        indexes: for each sample, the position of the lead to replace.

    Returns:
        New tensor with the shape of the signals; the input is not modified.
    """
    original = dataload_batch[0]
    n_records, n_steps, n_leads = original.size(0), original.size(1), original.size(2)
    rebuilt = torch.empty((n_records,n_steps,n_leads))
    for row in range(n_records):
        rebuilt[row] = original[row]
        rebuilt[row,:,indexes[row].item()] = fake_data[row].view(-1)
    return rebuilt


# Adversarial training with alternating updates:
#   1) the discriminator learns on real signals and on *detached* fakes, so
#      its loss never sends gradients into the generator;
#   2) the generator is then updated to make the refreshed discriminator
#      score its fakes as real, plus 10x the L2 reconstruction loss.
# Stepping both optimizers on one summed loss would make the generator
# minimise the discriminator's loss, i.e. help it instead of fooling it.
gen_perf_sig = []  # generator outputs on the fixed inputs, one entry per epoch
n_batches = len(DL_ECG)
for epoch in range(epochs):
    for batch_idx,(real) in enumerate(DL_ECG):
        real_used,reference_sig,lead_index,pathologies,gen_path = batch_adaptator(real)
        fake_used = gen(real_used,lead_index,gen_path)
        # Full 12-lead batch in which the selected lead is the generated one.
        fake = Recreate_original_batch(real,fake_used,lead_index)

        ##Train discriminator
        disc_real,disc_patho_pred = disc(real[0],pathologies)
        disc_real,disc_pred_label = disc_real.reshape(-1),torch.argmax(disc_patho_pred,dim=1).type(torch.int32)
        loss_disc_real = LSGAN_disc_real(disc_real)
        disc_fake,disc_fake_pred = disc(fake.detach(),pathologies)
        disc_fake,disc_fake_label = disc_fake.reshape(-1),torch.argmax(disc_fake_pred,dim=1).type(torch.int32)
        loss_disc_fake = LSGAN_disc_fake(disc_fake)
        # TODO: categorical loss on the predicted pathology is not used yet:
        #categorical_loss = cross_cat_entrop(pathologies[:,0],disc_pred_label)
        loss_disc = (loss_disc_real+loss_disc_fake)
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ##Train generator
        # Fresh forward pass through the updated discriminator, graph kept up to the generator.
        disc_fake_gen,_ = disc(fake,pathologies)
        loss_gen_adv = LSGAN_disc_real(disc_fake_gen.reshape(-1))
        loss_rec = L2Loss(fake_used,reference_sig)
        tot_loss = (loss_gen_adv+10*loss_rec)
        opt_gen.zero_grad()
        tot_loss.backward()
        opt_gen.step()

        print(f"Epoch [{epoch+1}/{epochs}] Batch {batch_idx+1}/{n_batches} Total_loss : {tot_loss.item():.4f}, Loss D : {loss_disc.item():.4f}, Loss G adv : {loss_gen_adv.item():.4f}, loss Reconstruct: {loss_rec.item():.4f}")

    # Snapshot on the fixed inputs: eval mode and no_grad so the snapshot does
    # not update batch-norm statistics or keep an autograd graph in memory.
    gen.eval()
    with torch.no_grad():
        gen_perf_sig.append(gen(fixed_noise,fixed_label,fixed_patho))
    gen.train()

# One figure per epoch snapshot, one subplot per fixed generator input.
print(gen_perf_sig[0].size())
for index,i in enumerate(gen_perf_sig):
    n_cols = int(i.size(0)/2)
    # squeeze=False keeps `ax` two-dimensional even with a single column.
    fig,ax = plt.subplots(
        nrows = 2,ncols = n_cols,figsize = (10,15),squeeze = False
    )
    time_axis = np.linspace(0,i.size(1),i.size(1))
    coordinates = [(x,y) for x in range(2) for y in range(n_cols)]
    for counter,(row,col) in enumerate(coordinates):
        # Labels belong to the plotted sample (`counter`), not to the epoch
        # (`index`): indexing by epoch goes out of range once there are more
        # snapshots than fixed inputs.
        lead = lead_names[fixed_label[counter].item()]
        patho = labels[fixed_patho[counter].item()]
        ax[row,col].plot(time_axis,i[counter,:].detach().numpy())
        ax[row,col].set_xlabel("Time step")
        ax[row,col].set_ylabel("amplitude")
        ax[row,col].grid()
        ax[row,col].set_title(f"Generated lead {lead}\npathology {patho}")
    fig.suptitle(f"All generated signals at epoch {index+1}")
    plt.show()


        