# Cerebral Cortex Synthetic PPG Generation from the WESAD Dataset

> Run the following two cells ONLY if you are in Colab; otherwise skip them. They install Miniconda on Colab. Before running them, enable the GPU via Edit > Notebook settings > Hardware accelerator > GPU > Save.

In [None]:
# Environment sanity checks: the working directory and the `python` found on
# the shell PATH (which is not necessarily the kernel's own interpreter).
!pwd
!which python
!python --version
# Check whether a GPU is detected
!nvidia-smi

In [None]:
# Colab-only bootstrap: install Miniconda into the Colab runtime via condacolab.
import importlib, sys, os
# Put the parent directory first on the module search path.
sys.path.insert(0, os.path.abspath('..'))
# True only when the google.colab module has already been imported, i.e. in Colab.
IN_COLAB = 'google.colab' in sys.modules
# If in Colab, install conda/mamba using the condacolab Python package and
# wait until the kernel restarts after the installation.
if IN_COLAB:
    !pip install -q condacolab
    import condacolab
    condacolab.install_miniconda()

> Start running from the following cell after the kernel restarts, OR when running locally on Linux without the dependencies installed. Do not run the cells above this one.

## Check notebook dependencies

In [None]:
import importlib, sys, os
sys.path.insert(0, os.path.abspath('..'))
IN_COLAB = 'google.colab' in sys.modules
HAVE_CARDIOGEN = importlib.util.find_spec("CardioGen") is not None

if(not HAVE_CARDIOGEN):
    if IN_COLAB: 
        print("\nGetting CardioGen")
        !git clone https://github.com/SENSE-Lab-OSU/cardio_gen_model.git
        !conda env update -n base -f ./cardio_gen_model/conda_requirements_linux.yml
        !pip install ./cardio_gen_model
    else:
        raise SystemExit("Please install CardioGen from https://github.com/SENSE-Lab-OSU/cardio_gen_model.git")
else:
    print("CardioGen found")

## Import library functions

> Start running from the following cell when running locally on Linux with all dependencies installed. Do not run the cells above this one.

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import neurokit2 as nk
import glob
import pickle
import numpy as np
import sys
import os
import matplotlib.pyplot as plt

In [2]:
#Download the WESAD data if needed
import shutil
import getpass
from zipfile import ZipFile
import requests

def download_file(url,local_path,chunk_size=1<<20):
    '''Stream the resource at `url` into the file `local_path`.

    Args:
        url: HTTP(S) URL to fetch.
        local_path: destination file; it is created/overwritten.
        chunk_size: number of bytes read per iteration while streaming.

    Returns:
        local_path

    Raises:
        requests.HTTPError: on a 4xx/5xx response, so an error page is never
            saved in place of the expected file.
    '''
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(local_path, 'wb') as f:
            # iter_content (unlike copying r.raw) undoes gzip/deflate
            # Content-Encoding applied by the server in transit.
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return local_path

def _remove_extracted_members(zip_path,dir_path):
    '''Delete from dir_path only the top-level entries that zip_path contains.'''
    with ZipFile(zip_path,'r') as zf:
        top_level={name.split('/',1)[0] for name in zf.namelist()}
    # Never resolve to dir_path itself or its parent.
    for name in top_level-{'','.','..'}:
        target=os.path.join(dir_path,name)
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target, ignore_errors=True)
        elif os.path.lexists(target):
            os.remove(target)

def unzip_file(zip_path,dir_path):
    '''Extract every member of the archive zip_path into dir_path.

    The stdlib ZipFile is tried first. If it raises RuntimeError (as it does
    for encrypted members when no password is supplied), whatever was
    partially extracted is removed and extraction is retried with pyzipper's
    AESZipFile after prompting the user for the archive password.

    Raises:
        ImportError: if the archive is encrypted and pyzipper is not installed.
    '''
    try:
        with ZipFile(zip_path,'r') as zf:
            zf.extractall(path=dir_path)
    except RuntimeError:
        print('AES Encrypted zip file. Need pyzipper to continue.\n')
        # Undo only this archive's partial extraction. dir_path can hold
        # unrelated files (e.g. the WESAD_musig_Dsplit_*.pickle cache in
        # data/pre-training), so it must not be deleted wholesale.
        _remove_extracted_members(zip_path,dir_path)
        import pyzipper
        with pyzipper.AESZipFile(zip_path,'r',encryption=pyzipper.WZ_AES) as zf:
            pwd=getpass.getpass(prompt='Archive is password protected. Please enter the password to continue \n')
            zf.extractall(path=dir_path,pwd=pwd.encode())  

def get_file(zip_file_url,zip_path,dir_path):
    '''Download the archive at zip_file_url to zip_path, extract it into
    dir_path with unzip_file, then delete the downloaded archive.'''
    # open(zip_path, 'wb') inside download_file fails if the destination
    # folder does not exist yet, so create it (and the extraction dir) first.
    os.makedirs(os.path.dirname(zip_path) or '.', exist_ok=True)
    os.makedirs(dir_path, exist_ok=True)
    print('\nDownloading {} File...\n'.format(zip_path))
    download_file(zip_file_url,zip_path)
    print('Extracting Files from {}...\n'.format(zip_path))
    unzip_file(zip_path,dir_path)
    print('Deleting temp Files...\n')
    os.remove(zip_path)
    print('All Done!\n')


# Locate the project root. IN_COLAB is recomputed here because local runs are
# told to start at the imports cell, which skips the cell that first defines it.
IN_COLAB = 'google.colab' in sys.modules
# On Colab the repo is cloned into ./cardio_gen_model by the dependency cell.
proj_path = './cardio_gen_model' if IN_COLAB else '.'

data_dir=f'{proj_path}/data/pre-training'

wesad_data_dir=data_dir+'/WESAD/'
wesad_data_url='https://uni-siegen.sciebo.de/s/HGdUkoNlW1Ub0Gx/download'

# Download and unpack WESAD only if it has not been extracted yet.
if not os.path.exists(wesad_data_dir):
    get_file(wesad_data_url,data_dir+'/WESAD.zip',data_dir)

In [5]:
#import useful library functions
import os
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pickle
from math import gcd
import neurokit2 as nk
import seaborn as sns
import copy
import time
# Apply seaborn's default theme to all subsequent matplotlib figures.
sns.set()
import warnings
# Silence DeprecationWarning and UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# Only surface TensorFlow log records at ERROR level or above.
tf.get_logger().setLevel('ERROR')
# Make float32 the default Keras float dtype.
tf.keras.backend.set_floatx('float32')

import sys
from CardioGen.lib.simulator_for_CC import Simulator
from CardioGen.HR2Rpeaks import HR2Rpeaks_Simulator
from CardioGen.Rpeaks2Sig import Rpeaks2Sig_Simulator
from CardioGen.lib.utils import filtr_HR, get_continous_wins
from CardioGen.lib.data import load_data_wesad as load_data

#Define global constants
model_path=proj_path+'/data/post-training/'
data_path=proj_path+'/data/pre-training/WESAD/'

# Dataset and windowing settings taken from CardioGen's WESAD loader.
n_classes=load_data.n_classes
n_stresses=load_data.n_stresses
# Deep copy because get_synth_data later overwrites load_data.class_ids.
all_class_ids=copy.deepcopy(load_data.class_ids)
win_len_s=load_data.win_len_s
step_s=load_data.step_s
bsize=load_data.test_bsize
Fs_ppg=load_data.Fs_ppg_new
Fs_ecg=load_data.Fs_ecg_new

# Version of the model_weights to use. Refer to README for details.
# NOTE(review): not referenced in this notebook; the checkpoints actually
# restored come from model_weights_v12 (see cell output) — confirm they agree.
ver=12
Dsplit_filename = (f'{proj_path}/data/pre-training/'
                   f'WESAD_musig_Dsplit_w{win_len_s}s{step_s}b{bsize}.pickle')
if not os.path.isfile(Dsplit_filename):
    # Raise explicitly: a bare `assert False` disappears under `python -O`.
    raise FileNotFoundError('Could not find existing Dsplit_mask_dict. '
                            'Run get_train_data in R2S mode first. '
                            f'(looked for {Dsplit_filename})')
# NOTE: pickle.load can execute arbitrary code; only load files produced locally.
with open(Dsplit_filename, 'rb') as fp:
    musig_dict,Dsplit_mask_dict = pickle.load(fp)

In [6]:
class avgHRV2PPG_Augmentor(Simulator):
    '''
    Two-stage synthetic-PPG generator with a morphology model for one subject.

    Stage 1 calls the shared 'WESAD' HR2Rpeaks_Simulator on a conditioning
    array. In get_synth_data, column 0 holds the avgHRV trace and columns 1:
    hold the condition one-hots. It returns a synthetic R-peak train at Fs_out
    plus a tachogram.

    Stage 2 feeds the R-peak train to the subject-specific
    Rpeaks2Sig_Simulator (sig_id='ppg'), which outputs synthetic PPG. The
    conditions are repeated from Fs_tacho up to Fs_out and fed alongside it.

    Args:
        P_ID_out: subject whose PPG morphology model is loaded.
        path: directory holding the trained model checkpoints.
        latent_size_HRV, latent_size_Morph: latent sizes of the two models.
        Fs_tacho: sample rate of the conditioning/tachogram signals.
        Fs_out: output sample rate. Required, and it must be an integer
            multiple of Fs_tacho.
        win_len_s, step_s, bsize: window length (s), window step (s) and
            windows per generated block. A block spans
            win_len_s+(bsize-1)*step_s seconds.
        dict_musig: stored as self.dict_musig (not read by the methods here).

    Raises:
        ValueError: if Fs_out is missing or not a multiple of Fs_tacho.
    '''
    def __init__(self,P_ID_out,path='../data/post-training/',
                 latent_size_HRV=5,latent_size_Morph=2,Fs_tacho=5,
                 Fs_out=None,win_len_s=8,step_s=2,bsize=13,
                 dict_musig=None):
        
        super().__init__()
        if Fs_out is None:
            raise ValueError('Fs_out is required and must be a multiple of Fs_tacho')
        self.Fs_tacho=Fs_tacho
        self.Fs_out=Fs_out
        self.P_ID_out=P_ID_out
        self.path=path
        self.latent_size_HRV=latent_size_HRV
        self.latent_size_Morph=latent_size_Morph
        self.win_len_s=win_len_s
        self.step_s=step_s
        self.bsize=bsize
        # None default avoids one mutable {} being shared by every instance.
        self.dict_musig = {} if dict_musig is None else dict_musig
        up_factor=self.Fs_out/self.Fs_tacho
        if up_factor%1!=0:
            raise ValueError('Fs_out must be a multiple of Fs_tacho')
        self.up_factor=int(up_factor)
        
        # One HRV model shared across subjects; morphology is per subject.
        P_ID_HRV='WESAD'
        P_ID_Morph=P_ID_out

        #Create a universal HR2Rpeaks_Simulator object 
        self.sim_HR2pks=HR2Rpeaks_Simulator(
                    RNN_win_len_s=win_len_s+(bsize-1)*step_s,
                    step_size_s=step_s,P_ID=P_ID_HRV,path=path,Fs_HR=Fs_tacho,
                    Fs_tacho=Fs_tacho,latent_size=latent_size_HRV)
        
        #Create a subject-wise Rpeaks2Sig_Simulator object 
        self.sim_pks2ppg=Rpeaks2Sig_Simulator(Fs_in=Fs_out,Fs_out=Fs_out,
                    P_ID=P_ID_Morph,path=path,sig_id='ppg',
                    latent_size=latent_size_Morph,logging=False,batch_size=32,
                    RNN_win_len=win_len_s,win_step_size=step_s)
        
    def __call__(self,cond_HRV):
        '''Generate one synthetic sequence from a conditioning array.

        Uses this instance's own win_len_s/step_s/bsize (not notebook globals).

        Returns:
            (cond_HRV, arr_tacho_synth, arr_pk_synth, ppg_synth), where
            cond_HRV is the array returned by the HRV stage. Returns four
            Nones when the HRV stage produces no R-peaks.
        '''
        block_len_s=self.win_len_s+(self.bsize-1)*self.step_s
        #Get synthetic R-peaks and tachogram from HR2Rpeaks_Simulator object
        arr_pk_synth,cond_HRV,arr_tacho_synth=self.sim_HR2pks(cond_HRV,
                    Fs_out=self.Fs_out,step_size_s=block_len_s)
        
        if len(arr_pk_synth)==0:
            return None,None,None,None
        # Repeat every condition row up_factor times so the conditions line
        # up sample-for-sample with the Fs_out-rate R-peak train.
        cond_Morph=np.repeat(cond_HRV[:,1:], self.up_factor, axis=0)
        cond_ppg=np.concatenate([arr_pk_synth.reshape(-1,1),cond_Morph],axis=-1)
        
        #Get synthetic PPG from Rpeaks2Sig_Simulator object
        ppg_synth,cond_ppg=self.sim_pks2ppg(cond_ppg,step_size_s=self.win_len_s)
        
        return cond_HRV,arr_tacho_synth,arr_pk_synth,ppg_synth

In [7]:
# Get all [avgHRV;Y_st;Y_Wid] windows of one subject and use its TRAIN-split
# sequences to seed synthetic PPG generation.
def get_synth_data(avghrv2ppg_aug_dict,class_name,seq_format_function,
                   Dsplit_mask_dict,stress_ids=(0,3)):
    '''Generate synthetic (input, output) windows seeded by one subject's HRV.

    The subject's training-split condition windows (selected with the 'hrv'
    mask) are stitched back into continuous sequences. For each sequence the
    avgHRV column (column 0) is kept fixed. The stress and class one-hots are
    then swept: every stress index in `stress_ids` is paired with every one
    of the n_classes subjects, and that subject's augmentor generates the
    signal. With the default stress_ids this gives 2*n_classes synthetic
    sequences per seed sequence.

    Args:
        avghrv2ppg_aug_dict: {subject name: avgHRV2PPG_Augmentor}; must
            contain every subject in all_class_ids.
        class_name: subject whose HRV sequences seed the generation.
        seq_format_function: callable(aug, cond_HRV, tacho, peaks, ppg) ->
            (in_wins, out_wins) that windows one generated sequence.
        Dsplit_mask_dict: data-split masks forwarded to
            load_data.get_train_data.
        stress_ids: 0-based stress indices to sweep; channel 1+s+1 is set.
            Each must be < n_stresses-1. The default (0, 3) matches the
            original hard-coded sweep.

    Returns:
        (in_data, out_data): float32 arrays of all windows, stacked on axis 0.

    Raises:
        ValueError: if no synthetic sequence could be generated.
    '''
    class_names=list(all_class_ids.keys())
    Fs_tacho=next(iter(avghrv2ppg_aug_dict.values())).Fs_tacho
    seq_len=((bsize-1)*step_s+win_len_s)*Fs_tacho
    
    # NOTE: overwrites the module-level load_data.class_ids so get_train_data
    # loads only this subject; it is not restored afterwards.
    load_data.class_ids={class_name:all_class_ids[class_name]}
    
    #Load all data
    list_cond_HRV,_,Dsplit_mask_dict=load_data.get_train_data(
                data_path,mode='HR2R',win_len_s=win_len_s,step_s=step_s,
                Fs_tacho=Fs_tacho,Dsplit_mask_dict=Dsplit_mask_dict)
    
    # Only the training split is used (must use the hrv mask).
    train_mask=Dsplit_mask_dict['hrv'][class_name][0].astype(int)
    start_idxs,end_idxs=get_continous_wins(train_mask)
    print(f'Generating synthetic data using subject {class_name} with '
          f'{len(start_idxs)} sequences')
    
    # One slot per (stress, class) pair; slot j = s*n_classes+c.
    n_slots=(n_stresses-1)*n_classes
    in_data_synth=[[] for _ in range(n_slots)]
    out_data_synth=[[] for _ in range(n_slots)]

    for i,(start,end) in enumerate(zip(start_idxs,end_idxs)):
        # Defragment windows into a continuous signal segment.
        in_seq=load_data.sliding_window_defragmentation(
                    [list_cond_HRV[0][start:end]],seq_len,step_s*Fs_tacho)
        
        # Keep only avgHRV; the condition one-hots are filled in per (s, c).
        cond_HRV_init=np.zeros(in_seq.shape)
        cond_HRV_init[:,0]=in_seq[:,0]*1
        start_time=time.time()
        
        for s in stress_ids:
            for c in range(n_classes):
                print(f'Stress={s+1}, Class={c}')
                j=s*n_classes+c
                cond_HRV=cond_HRV_init*1
                cond_HRV[:,1+s+1]=1 #extra +1 for skipping stress=0 channel
                cond_HRV[:,1+n_stresses+c]=1
                
                # Subject-specific model of output class c.
                out_class=class_names[c]
                avghrv2ppg_aug=avghrv2ppg_aug_dict[out_class]
                cond_HRV,arr_tacho_synth,arr_pk_synth,ppg_synth=avghrv2ppg_aug(
                                                                cond_HRV)
                # The augmentor returns all Nones when no R-peaks were made;
                # check BEFORE rescaling, which would fail on None.
                if cond_HRV is None:
                    continue
                
                # Undo the per-subject normalisation: rescale, add back mean.
                musig=musig_dict[out_class]['ppg']
                ppg_synth*=musig['sigma']
                ppg_synth+=musig['mu']
                
                in_wins,out_wins=seq_format_function(avghrv2ppg_aug,cond_HRV,
                                arr_tacho_synth,arr_pk_synth,ppg_synth)
                in_data_synth[j].append(in_wins)
                out_data_synth[j].append(out_wins)
                    
        print(f'Time taken for sequence {i}= {time.time()-start_time}')
    in_data_list=[np.concatenate(arr_list,axis=0) 
                  for arr_list in in_data_synth if arr_list]
    out_data_list=[np.concatenate(arr_list,axis=0) 
                   for arr_list in out_data_synth if arr_list]
    print('\n=======================\n',len(out_data_list))
    if not in_data_list:
        raise ValueError(f'No synthetic data could be generated for subject '
                         f'{class_name}')
    in_data=np.concatenate(in_data_list,axis=0).astype(np.float32)
    out_data=np.concatenate(out_data_list,axis=0).astype(np.float32)
    return in_data,out_data

In [8]:
def seq_format_function_TAu(avghrv2ppg_aug,cond_HRV,arr_tacho_synth,
                            arr_pk_synth,ppg_synth,bsize=5,bstep=2):
    '''Cut one generated sequence into aligned (input, output) blocks for TAu.

    Inputs come from the synthetic PPG. Outputs come from a 2-column array
    [cond_HRV[:,0], arr_tacho_synth], passed through load_data.resample with
    (Fs_tacho, up_factor, 1) — presumably upsampling it to Fs_out so both
    streams share one rate (TODO confirm resample's argument semantics).

    Window geometry uses the augmentor's own win_len_s/step_s (not notebook
    globals). Each block spans (bsize-1)*step_s+win_len_s seconds, and
    consecutive blocks start bstep*step_s seconds apart.

    Args:
        avghrv2ppg_aug: augmentor that produced the sequence; supplies Fs_out,
            Fs_tacho, up_factor, win_len_s and step_s.
        cond_HRV, arr_tacho_synth, arr_pk_synth, ppg_synth: the augmentor's
            outputs (arr_pk_synth is not used here).
        bsize: base windows per block (local to TAu formatting; default 5).
        bstep: block hop in units of step_s (default 2).

    Returns:
        (in_wins, out_wins) from load_data.sliding_window_fragmentation.
    '''
    Fs_out=avghrv2ppg_aug.Fs_out
    aug_step_s=avghrv2ppg_aug.step_s
    aug_win_len_s=avghrv2ppg_aug.win_len_s
    
    seq_out=np.stack([cond_HRV[:,0],arr_tacho_synth],axis=1)
    seq_out=load_data.resample(seq_out,avghrv2ppg_aug.Fs_tacho,
                avghrv2ppg_aug.up_factor,1,show_plots=False)
    
    # Fragment back into TAu-style blocks. Train/val masks from Dsplit could
    # be reused, but with GRU memory propagation val blocks may then overlap
    # data seen in train blocks.
    in_wins,out_wins=load_data.sliding_window_fragmentation([ppg_synth,seq_out],
                    ((bsize-1)*aug_step_s+aug_win_len_s)*Fs_out,
                    bstep*aug_step_s*Fs_out)
    return in_wins,out_wins

In [9]:
# ---- Synthetic-data generation settings ----
# Windowing used for the saved data, and where/how the results are cached.
seq_format_function = seq_format_function_TAu
save_name = 'WESAD_synth_TAu/ppg'
show_plots = False
# Cache-file suffix ('s14' presumably = stress channels 1 and 4 that
# get_synth_data sweeps — confirm).
suffix = 's14_c'

# Model settings passed to every avgHRV2PPG_Augmentor.
latent_size_HRV, latent_size_Morph = 4, 2
Fs_tacho = 5

# Subject name -> avgHRV2PPG_Augmentor; filled on demand by get_aug_dict().
avghrv2ppg_aug_dict = {}

def get_aug_dict(aug_dict=None):
    '''Build one avgHRV2PPG_Augmentor per subject in all_class_ids.

    Args:
        aug_dict: dict to fill (subject name -> augmentor). Defaults to the
            module-level avghrv2ppg_aug_dict, which the generation loop reads.

    Returns:
        The filled dict. Subjects already present are kept as-is, so calling
        this again does not reload their checkpoints.
    '''
    if aug_dict is None:
        aug_dict=avghrv2ppg_aug_dict
    for clas_name in all_class_ids:
        if clas_name in aug_dict:
            continue
        #Create Simulator Model
        aug_dict[clas_name]=avgHRV2PPG_Augmentor(P_ID_out=clas_name,
                        path=model_path,
                        latent_size_HRV=latent_size_HRV,
                        latent_size_Morph=latent_size_Morph,Fs_tacho=Fs_tacho,
                        Fs_out=Fs_ppg,win_len_s=win_len_s,
                        step_s=step_s,bsize=bsize)
    print(len(aug_dict))
    return aug_dict

In [10]:
# Generate (or load cached) synthetic data for every subject, then stack it.
save_dir=f'{model_path}/{save_name}'
os.makedirs(save_dir,exist_ok=True)
in_data_list,out_data_list=[],[]

for class_name in all_class_ids:
    filename=f'{save_dir}/{class_name}_{suffix}.pickle'
    
    if os.path.isfile(filename):
        # Cache written by a previous run of this notebook. pickle.load can
        # execute arbitrary code, so only load files produced locally.
        with open(filename,'rb') as fp:
            in_data,out_data=pickle.load(fp)
    else:
        # Load the augmentor models lazily, only when something is missing.
        if len(avghrv2ppg_aug_dict)==0: get_aug_dict()
        in_data,out_data=get_synth_data(avghrv2ppg_aug_dict,class_name,
                                    seq_format_function,Dsplit_mask_dict)
        # Write to a temp file, then atomically rename it. An interrupted
        # dump (e.g. KeyboardInterrupt) therefore never leaves a truncated
        # cache that the isfile() check above would later try to load.
        tmp_filename=filename+'.tmp'
        with open(tmp_filename,'wb') as handle:
            pickle.dump([in_data,out_data],handle)
        os.replace(tmp_filename,filename)
            
    print(in_data.shape,out_data.shape)
    in_data_list.append(in_data)
    out_data_list.append(out_data)
    
in_data=np.concatenate(in_data_list,axis=0).astype(np.float32)
out_data=np.concatenate(out_data_list,axis=0).astype(np.float32)
print(in_data.shape,out_data.shape)

HRV Gen Model Exists. Loading ...
Restored from ./data/post-training/model_weights_v12/WESAD_HRV_model\ckpt-2000
Done!
ppg Morph gen_model exists. Loading ...
Restored from ./data/post-training/model_weights_v12/S2_ppg_Morph_model\ckpt-2400
Done!
HRV Gen Model Exists. Loading ...
Restored from ./data/post-training/model_weights_v12/WESAD_HRV_model\ckpt-2000
Done!
ppg Morph gen_model exists. Loading ...
Restored from ./data/post-training/model_weights_v12/S3_ppg_Morph_model\ckpt-2400
Done!
HRV Gen Model Exists. Loading ...
Restored from ./data/post-training/model_weights_v12/WESAD_HRV_model\ckpt-2000
Done!
ppg Morph gen_model exists. Loading ...
Restored from ./data/post-training/model_weights_v12/S4_ppg_Morph_model\ckpt-2400
Done!
HRV Gen Model Exists. Loading ...
Restored from ./data/post-training/model_weights_v12/WESAD_HRV_model\ckpt-2000
Done!
ppg Morph gen_model exists. Loading ...
Restored from ./data/post-training/model_weights_v12/S5_ppg_Morph_model\ckpt-2400
Done!
HRV Gen Mode

KeyboardInterrupt: 