In [None]:
# importing all the needed modules
import csv
import os
import numpy as np
from collections import Counter
import pandas as pd
from io import StringIO
import sklearn
import re
from sklearn.preprocessing import MinMaxScaler
import xml.etree.ElementTree as ET
import ecg_plot
import time

import random
import pywt
from numpy import savez_compressed
from sklearn.model_selection import train_test_split
from sklearn.cluster import MiniBatchKMeans
from imblearn.under_sampling import ClusterCentroids
from sklearn.utils import shuffle
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import roc_auc_score, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import signal
from matplotlib.patches import Shadow
from numpy import sqrt
import keras.backend as K
import yaml

# mini-batch size (not referenced in this part of the notebook; presumably used when training the model)
batch_size = 64

# Seed the random number generators so that re-running the notebook is reproducible.
# Previously only Python's `random` was seeded; NumPy's global generator is seeded as well now.
# NOTE(review): TensorFlow/Keras keep their own generator — also call tf.random.set_seed(SEED)
# if model training has to be reproducible.
SEED = 10
random.seed(SEED)
np.random.seed(SEED)

In [None]:
def get_config(file_path):
    """
    Reads the YAML configuration file.

    Parameters
    ----------
    file_path : str
        The file in which to search for configuration parameters.

    Returns
    -------
    dict
        The configuration parameters parsed with ``yaml.safe_load``
        (``None`` if the file is empty).

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist (the file is not silently skipped).
    yaml.YAMLError
        If the file is not valid YAML.
    """
    # An explicit encoding keeps parsing independent of the platform's default locale.
    # safe_load only builds plain Python objects, so the config file cannot execute code.
    with open(file_path, 'r', encoding='utf-8') as stream:
        return yaml.safe_load(stream)

# load the run configuration from config.yaml (resolved relative to the working directory);
# every key read below must be present in it, otherwise the lookup raises KeyError
config = get_config("config.yaml")
# NOTE(review): `dirs` is not used in the visible part of the notebook (the ECG files are read
# from `ecg_dirs`) — confirm whether it is still needed
dirs = config['data_directories']

# the path to the file with the patients' ids in the ECG dataset and patients' ids in the file with diagnoses
ids_path = config['path_to_ids'] 

# the path to the file with patients' diagnoses
diagnosisICD10_path = config['path_to_diagnoses']

# directories with the xml files, each xml file contains an ECG waveform of a person whose id corresponds to the first part of the file name
ecg_dirs = config['ecg_dirs']

In [None]:
def read_ids(file_path):
    """
    Reads two different types of patients' ids - the patient's id in the ECG dataset and
    the patient's id in the file with diagnoses.

    The file is tab-separated. The first column holds the id used in the diagnoses file and
    the second column holds the id used in the ECG dataset. Rows with fewer than two columns
    (e.g. blank lines) are skipped.

    Parameters
    ----------
    file_path : str
        The file in which to search for both types of ids.

    Returns
    -------
    dict
        A dictionary where the key is the patient's id in the ECG datasource and
        the value is the id of the patient in the file with the diagnoses.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    """
    ids = {}
    # newline='' is how the csv module expects its input files to be opened
    with open(file_path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter='\t'):
            if len(row) < 2:
                # blank or malformed line: there is no id pair to record
                continue
            ids[row[1]] = row[0]
    return ids

def read_diagnosis(file_path):
    """
    Reads the diagnoses from the file.

    The file is tab-separated. The first column is the patient's id and every following
    column is one diagnosis. Blank lines are skipped.

    Parameters
    ----------
    file_path : str
        The file in which to search for the diagnoses

    Returns
    -------
    dict
        A dictionary where the key is the patient's id in the file with the diagnoses and
        the value is the list of diagnoses made by a doctor according to ICD-10
        (empty columns show up as empty strings).

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    """
    diagnoses = {}
    # newline='' is how the csv module expects its input files to be opened
    with open(file_path, 'r', newline='') as f:
        for row in csv.reader(f, delimiter='\t'):
            if not row:
                # blank line: csv.reader yields an empty list, which has no patient id
                continue
            diagnoses[row[0]] = row[1:]
    return diagnoses

def get_actual_ids(dirs):
    """
    Collects the ids of the patients whose ECG was recorded.

    Every file name inside the given directories is expected to start with the patient's
    id followed by an underscore (``<patient_id>_...``). Files without an underscore are
    reported on stdout and ignored.

    Parameters
    ----------
    dirs : list of str
        Directories containing the ECG XML files (one file per recording).

    Returns
    -------
    list
        The patient ids, directory by directory in ``os.listdir`` order. An id appears once
        per matching file, so a patient with several recordings appears several times.
    """
    found_ids = []
    for ecg_dir in dirs:
        for file_name in os.listdir(ecg_dir):
            prefix, separator, _ = file_name.partition('_')
            if not separator:
                print('file name does not contain a patient id:' , file_name)
                continue
            found_ids.append(prefix)
    return found_ids

def print_dictionary(dictionary, num):
    """
    Prints the first ``num`` key/value pairs of a dictionary, in iteration order.

    Parameters
    ----------
    dictionary : dict
        Any dictionary
    num : int
        The number of pairs to print. Values <= 0 print nothing. Values larger than the
        dictionary print every pair.

    Returns
    -------
    None
    """
    if num <= 0:
        # The old counter-based loop never hit i == num for num <= 0 and printed everything.
        return
    for position, (k, v) in enumerate(dictionary.items(), start=1):
        print(k, v)
        if position >= num:
            break

def clean_empty_records(dictionary):
    """
    Deletes, in place, every entry whose value has length zero.

    Parameters
    ----------
    dictionary : dict
        Any dictionary whose values support ``len()``.

    Returns
    -------
    dict
        The same dictionary object, now without zero-length values. Callers may rely on
        the in-place mutation as well as on the return value.
    """
    # collect the keys first so the dictionary is not resized while it is being iterated
    empty_keys = [key for key, value in dictionary.items() if len(value) == 0]
    for key in empty_keys:
        del dictionary[key]
    return dictionary

def includes_a_disease(diagnoses, startswith_symbol):
    """
    Tells whether any diagnosis in the list starts with the given prefix.

    Parameters
    ----------
    diagnoses : list
        A list of diagnoses according to the ICD-10
    startswith_symbol: string
        The prefix to look for (normally an ICD-10 disease code such as 'I21')

    Returns
    -------
    bool
        True as soon as one diagnosis starts with ``startswith_symbol``, False otherwise
        (including for an empty list)
    """
    return any(diagnosis.startswith(startswith_symbol) for diagnosis in diagnoses)

In [None]:
# reading the ids of patients
all_ids = read_ids(ids_path)

# reading the diagnoses of patients
all_diagnoses = read_diagnosis(diagnosisICD10_path)

# checking what the dictionary of the two types of ids looks like
print_dictionary(all_ids, 5)

# checking the list of all the diagnoses for the patient with the id '1'
# (.get prints None instead of raising KeyError when there is no patient '1')
print(all_diagnoses.get('1'))

In [None]:
# reading the ids of patients whose ECGs were recorded
# (one entry per ECG file, so a patient with several recordings appears several times)
actual_ids = get_actual_ids(ecg_dirs)

In [None]:
# ECG ids for which no diagnoses were found; these files should be skipped while reading
# the ECG waveforms from the ecg_dirs
skipped_ids = []

# key = ECG file id, value = list of diagnoses
diagnoses_ids = {}

# Go over the ids of the patients whose ECGs were recorded and look up their diagnoses.
# Two lookups can fail: the ECG id may be missing from the id mapping, or the mapped id may
# be missing from the diagnoses file. Either way the patient has no diagnoses and is excluded.
# `diagnoses_key` replaces the old loop variable `id`, which shadowed the builtin id().
for actual_id in actual_ids:
    diagnoses_key = all_ids.get(actual_id)
    if diagnoses_key is not None and diagnoses_key in all_diagnoses:
        diagnoses_ids[actual_id] = all_diagnoses[diagnoses_key]
    else:
        skipped_ids.append(actual_id)

In [None]:
# A key is a patient's id and the value is the list of diagnoses of this patient. This loop drops
# the empty diagnoses and those not starting with the letter I. All the codes of diseases of the
# circulatory system (according to ICD-10) start with the letter I. Since this analysis is focused
# on detecting AMI, which is a heart condition, all the other codes are removed.
# Only existing keys are reassigned, so changing the dict while iterating over .items() is safe.
# NOTE(review): the loop variable `id` shadows the builtin id() at module level.
for id, diagnoses in diagnoses_ids.items():
    cleared_diagnoses = []
    for d in diagnoses:
        if len(d) != 0 and d[0] == 'I':
            cleared_diagnoses.append(d)
    diagnoses_ids[id] = cleared_diagnoses

# checking what the cleared dictionary, with only the ICD-10 codes of the diseases of the circulatory system, looks like
print_dictionary(diagnoses_ids, 10)

In [None]:
# remove the entries with empty lists of diagnoses
# (clean_empty_records mutates diagnoses_ids in place and returns the same object)
diagnoses_ids = clean_empty_records(diagnoses_ids)

# checking the resulting dictionary
print_dictionary(diagnoses_ids, 30)

In [None]:
def get_frequency_map_of_unique_diagnoses(dictionary, top_n=30):
    """
    Prints the ``top_n`` most frequent ICD-10 codes as (code, count) pairs.

    Only the code part of every diagnosis is counted. A diagnosis such as
    'I10 Essential (primary) hypertension' is reduced to 'I10' (everything before the
    first whitespace). A bare code such as 'I48' is used as is.

    Parameters
    ----------
    dictionary : dict
        A dictionary where the values are the lists of diagnoses according to the ICD-10
    top_n : int, optional
        How many of the most frequent codes to print (default 30, the previously
        hard-coded value).

    Returns
    -------
    None
    """
    code_counts = Counter(
        diagnosis.partition(' ')[0]
        for diagnoses in dictionary.values()
        for diagnosis in diagnoses
    )
    # most_common(n) gives the same ordering as the former most_common(200)[:30]
    # (ties keep first-seen order in both), without the redundant intermediate slice
    for pair in code_counts.most_common(top_n):
        print(pair)

# print the frequencies of the ICD-10 codes in the dataset (the 30 most common ones)
get_frequency_map_of_unique_diagnoses(diagnoses_ids)

In [None]:
# According to the ICD-10, the diseases of the circulatory system (codes I00-I99) are divided into 10 groups:
# I00-I02 Acute rheumatic fever
# I05-I09 Chronic rheumatic heart diseases
# I10-I15 Hypertensive diseases
# I20-I25 Ischaemic heart diseases
# I26-I28 Pulmonary heart disease and diseases of pulmonary circulation
# I30-I52 Other forms of heart disease
# I60-I69 Cerebrovascular diseases
# I70-I79 Diseases of arteries, arterioles and capillaries
# I80-I89 Diseases of veins, lymphatic vessels and lymph nodes, not elsewhere classified
# I95-I99 Other and unspecified disorders of the circulatory system

# The code in this block aims to check how many diagnoses belong to each of these groups

# one counter per ICD-10 circulatory-system group, keyed by the first code of the group
diag_codes_groups = dict.fromkeys(
    ['I00', 'I05', 'I10', 'I20', 'I26', 'I30', 'I60', 'I70', 'I80', 'I95'],
    0,
)

def starts_with(diagnosis, codes):
    """
    Tells whether the diagnosis begins with any of the given ICD-10 codes.

    Parameters
    ----------
    diagnosis : string
        A diagnosis according to the ICD-10
    codes: list
        The ICD-10 codes to try, checked in order until one matches

    Returns
    -------
    bool
        True if the diagnosis starts with one of the codes, False otherwise
        (including for an empty list of codes)
    """
    return any(diagnosis.startswith(code) for code in codes)

def is_within_I_range(diagnosis, int1, int2):
    """
    Tells whether a diagnosis falls into an ICD-10 block of circulatory-system codes.

    The block is given by the numeric parts of its first and last codes. For example,
    int1=0 and int2=2 stand for I00-I02 (Acute rheumatic fever). Every number in the
    inclusive range becomes the prefix 'I' + the number padded to at least two digits
    ('I00', 'I01', 'I02'). The diagnosis matches when it starts with one of these prefixes.

    Parameters
    ----------
    diagnosis : string
        A diagnosis according to the ICD-10
    int1: integer
        Numeric part of the first code of the block.
    int2: integer
        Numeric part of the last code of the block (inclusive).

    Returns
    -------
    bool
        True if the diagnosis starts with one of the codes of the block, False otherwise
    """
    block_prefixes = ['I{:02d}'.format(number) for number in range(int1, int2 + 1)]
    return starts_with(diagnosis, block_prefixes)
    

def devide_into_diagnosis_groups(dictionary):
    """
    Counts how many diagnoses fall into each ICD-10 circulatory-system group and prints the counts.

    The counts are stored in the module-level ``diag_codes_groups`` dictionary, keyed by the
    first code of every group. They are reset to zero first, so calling the function again
    (e.g. re-running the cell) does not add to the previous counts. Diagnoses outside all
    groups are printed with the prefix 'wrong code:'.

    Parameters
    ----------
    dictionary : dict
        A dictionary where the values are the lists of diagnoses

    Returns
    -------
    None
    """
    # (group key, first code number, last code number), tried in this order
    group_ranges = [
        ('I00', 0, 2),
        ('I05', 5, 9),
        ('I10', 10, 15),
        ('I20', 20, 25),
        ('I26', 26, 28),
        ('I30', 30, 52),
        ('I60', 60, 69),
        ('I70', 70, 79),
        ('I80', 80, 89),
        ('I95', 95, 99),
    ]
    # start from zero so repeated calls do not double-count
    for group_key, _, _ in group_ranges:
        diag_codes_groups[group_key] = 0

    for diagnoses in dictionary.values():
        for diagnosis in diagnoses:
            for group_key, first, last in group_ranges:
                if is_within_I_range(diagnosis, first, last):
                    diag_codes_groups[group_key] += 1
                    break
            else:
                # no group matched
                print('wrong code:', diagnosis)
    print(diag_codes_groups)

# prints how many diseases in the dataset belong to each group
# (the counts are also kept in the diag_codes_groups dictionary)
devide_into_diagnosis_groups(diagnoses_ids)

In [None]:
# I20-I25 - Ischaemic heart diseases
# Keep only the patients whose single circulatory-system diagnosis lies in I20-I25. Every other
# patient gets an empty list and is removed by clean_empty_records below.

for id, diagnoses in diagnoses_ids.items():
    cleared_diagnoses = []
    
    # should cut off patients with other additional heart diseases
    if len(diagnoses) == 1:
        if is_within_I_range(diagnoses[0], 20, 25):
            cleared_diagnoses.append(diagnoses[0])
    diagnoses_ids[id] = cleared_diagnoses
    
# NOTE: clean_empty_records mutates its argument and returns it, so diagnoses_ids_new is the
# same object as diagnoses_ids — diagnoses_ids is filtered too
diagnoses_ids_new = clean_empty_records(diagnoses_ids)
print_dictionary(diagnoses_ids_new, 50)
print(len(diagnoses_ids_new))

In [None]:
# I21 - Acute myocardial infarction

# keep only the entries where at least one diagnosis belongs to the Acute myocardial infarction (I21)
# group; every other entry gets deleted (iterating over a copy makes deleting inside the loop safe)
for id, diagnoses in diagnoses_ids.copy().items():
    if not includes_a_disease(diagnoses, 'I21'):
        del diagnoses_ids[id]

# checks 50 pairs from the resulting dictionary
print_dictionary(diagnoses_ids, 50)

# checks the size of the resulting dictionary
print(len(diagnoses_ids))

In [None]:
def label_the_data_for_AMI():
    """
    Labels the dataset based on the patients' diagnoses. Patients having no cardiac diseases are
    labeled with 0, patients diagnosed with AMI (ICD-10 I21) are labeled with 1. Others are
    excluded from the dataset.

    The ids, the diagnoses and the ECG file names are read from the paths stored in the
    module-level ``ids_path``, ``diagnosisICD10_path`` and ``ecg_dirs`` variables. Statistics
    about the result are printed.

    Parameters
    ----------
    None

    Returns
    -------
    dict:
        A dictionary where the key is a patient's id (the id in the ECG dataset) and the value
        is a label (0 or 1)
    list:
        The lists of circulatory-system diagnoses of the AMI-labelled patients, used to check how
        many diagnoses belong to each subgroup of I21 (I21.0, I21.1, I21.2, I21.3, I21.4, I21.9)
    """
    st = time.time()

    # key = patient's id in the ECG dataset, value = label
    id_label_dictionary = {}

    # reading the ids of patients (two types of ids, ECGs' and diagnoses' ids)
    all_ids = read_ids(ids_path)

    # reading the diagnoses of patients
    all_diagnoses = read_diagnosis(diagnosisICD10_path)

    # reading the ids of patients whose ECGs were recorded (one entry per ECG file)
    actual_ids = get_actual_ids(ecg_dirs)

    # key = ECG file id, value = diagnoses
    diagnoses_ids = {}

    # Count the ECG files whose patient has no diagnoses: the ECG id is missing from the id
    # mapping, or the mapped id is missing from the diagnoses file.
    skipped_ids = 0
    for actual_id in actual_ids:
        diagnoses_key = all_ids.get(actual_id)
        if diagnoses_key is not None and diagnoses_key in all_diagnoses:
            diagnoses_ids[actual_id] = all_diagnoses[diagnoses_key]
        else:
            skipped_ids += 1
    # the old message said "no ECG record found", but these are ECG ids without diagnoses
    print('skipped', skipped_ids, 'patient ids: no diagnoses found for them')

    # keep only the diagnoses of the circulatory system (ICD-10 codes starting with 'I')
    for patient_id, diagnoses in diagnoses_ids.items():
        diagnoses_ids[patient_id] = [d for d in diagnoses if d.startswith('I')]

    # collecting statistics
    other_diseases = 0
    no_cardiac_diseases = 0
    ami_patients = 0

    # all circulatory-system diagnoses of the AMI patients, to check how many diagnoses belong
    # to each subgroup of I21 (I21.0, I21.1, I21.2, I21.3, I21.4, I21.9)
    diag_test = []

    for patient_id, diagnoses in diagnoses_ids.items():
        if includes_a_disease(diagnoses, 'I21'):
            # patient has been diagnosed with AMI, is labeled with 1
            id_label_dictionary[patient_id] = 1
            ami_patients += 1
            diag_test.append(diagnoses)
        elif len(diagnoses) == 0:
            # patient has no cardiac diseases at all, is labeled with 0
            id_label_dictionary[patient_id] = 0
            no_cardiac_diseases += 1
        else:
            # other cardiac diseases not including AMI - remove this patient from the dataset
            other_diseases += 1

    print(other_diseases, 'patients have cardiac diseases not including AMI - excluded from the dataset')
    print(no_cardiac_diseases, 'patients have no cardiac diseases - included in the dataset with the label 0')
    print(ami_patients, 'patients have been diagnosed with AMI - included in the dataset with the label 1')
    print(len(id_label_dictionary), ' - dataset size')

    elapsed_time = time.time() - st
    print('Reading and cleaning diagnoses time:', elapsed_time/60, 'minutes')

    return id_label_dictionary, diag_test

In [None]:
# getting the labels of the data and the list with all the diagnoses
id_label_dictionary, diag_test = label_the_data_for_AMI()
elems = []

# checks how many diagnoses belong to each subgroup of I21 (I21.0, I21.1, I21.2, I21.3, I21.4, I21.9) to plot the pie chart
# (diag_test holds one list of diagnoses per AMI patient; only the I21 entries are collected)
for elem in diag_test:
    for diag in elem:
        if diag.startswith('I21'):
            elems.append(diag)
            
# print the 6 most frequent (diagnosis, frequency) pairs to check how many diagnoses belong to each I21 subgroup
for pair in Counter(elems).most_common(200)[:6]:
        print(pair)

In [None]:
# Drawing a pie chart to get a visual representation of the AMI subtypes' frequencies

# AMI subtypes (ICD-10 codes), used as wedge labels
labels_for_chart = [
    'I21.0', 'I21.1', 'I21.2',
    'I21.3', 'I21.4', 'I21.9',
]

# full ICD-10 descriptions of the subtypes, used in the legend
diagnoses_strings = [
    'I21.0 Acute transmural myocardial infarction of anterior wall',
    'I21.1 Acute transmural myocardial infarction of inferior wall',
    'I21.2 Acute transmural myocardial infarction of other sites',
    'I21.3 Acute transmural myocardial infarction of unspecified site',
    'I21.4 Acute subendocardial myocardial infarction',
    'I21.9 Acute myocardial infarction, unspecified',
]

# number of patients per subtype, in the same order as labels_for_chart
# NOTE(review): these counts are typed in by hand (presumably copied from the Counter output of
# the previous cell); they go stale if the data changes — consider deriving them from `elems`
data = [190, 194, 35, 29, 382, 230]
# how far each wedge is pulled out of the pie
explode = (0.0, 0.1, 0.0, 0.1, 0.1, 0.1)
colors = ("orange", "cyan", "lightcyan", "lightgrey", "peachpuff", "beige")
# edge style of the wedges
wp = {'linewidth': 1, 'edgecolor': "lightgreen"}

def func(pct, allvalues):
    '''
    Format of the labels printed on the pieces of the pie: the percentage with one decimal,
    followed on a second line by the absolute number of patients it corresponds to.

    Parameters
    ----------
    pct : float
        The percentage of the wedge, as passed to matplotlib's ``autopct`` callable.
    allvalues : sequence of numbers
        All the values of the pie; their sum is the 100% reference.

    Returns
    -------
    str
        The label text, e.g. '17.9%' and '(190 patients)' on two lines.
    '''
    # The body is indented with spaces only. The original mixed a space-indented docstring with
    # tab-indented statements, which Python 3 rejects with a TabError.
    # round() instead of int(): pct is a float recomputed from the counts, so a value such as
    # 189.99999999999997 must become 190, not be truncated to 189.
    absolute = int(round(pct / 100. * np.sum(allvalues)))
    return "{:.1f}%\n({:d} patients)".format(pct, absolute)

fig, ax = plt.subplots(figsize=(10, 7))

# one wedge per AMI subtype; the built-in shadow is disabled because a custom one is drawn below
wedges, texts, autotexts = ax.pie(
    data,
    autopct=lambda pct: func(pct, data),
    explode=explode,
    labels=labels_for_chart,
    shadow=False,
    colors=colors,
    startangle=90,
    wedgeprops=wp,
    pctdistance=0.6,
    textprops=dict(color="black", fontsize=16),
)

# hand-made drop shadow: a slightly offset copy of every wedge, drawn just beneath it
for w in wedges:
    s = Shadow(w, -0.01, -0.01)
    s.set_gid(w.get_label() + "_shadow")
    s.set_zorder(w.get_zorder() - 0.1)
    ax.add_patch(s)

# legend with the full ICD-10 descriptions, placed to the right of the pie
ax.legend(
    wedges,
    diagnoses_strings,
    loc="center left",
    bbox_to_anchor=(1, 0, 0.5, 1),
    fontsize="16",
)

plt.setp(autotexts, size=12, weight="bold")
ax.set_title("Acute myocardial infarction: subtypes", size=18.0)
ax.axis('equal')

# saving the plot to use it in the thesis
plt.savefig('../../images/STEMI_proportions.svg', format="svg", dpi=800, bbox_inches='tight')

# show plot
plt.show()


In [None]:
def get_waves_from_file(filepath):
    """
    Reads the 12-lead ECG median-sample waveform stored in an XML file.

    Parameters
    ----------
    filepath: string
        A file name where to read the ECG from

    Returns
    -------
    list:
        A list of lists, one per ``WaveformData`` element under
        ``RestingECGMeasurements/MedianSamples`` (each ECG lead). Returns an empty list when
        ``pd.read_xml`` raises for any reason (missing file, missing tag, parse error).
    """
    median_samples_xpath = '//RestingECGMeasurements/MedianSamples/WaveformData'
    try:
        waveform_frame = pd.read_xml(filepath, xpath=median_samples_xpath)
    except Exception:
        # best effort: ValueError (tag not found) and every other read error are treated alike
        return []
    return get_waveform(waveform_frame)

def get_waveform(df, column='WaveformData'):
    """
    Converts the comma-separated lead strings of a dataframe column into lists of integers.

    Parameters
    ----------
    df: dataframe
        The dataframe (or any mapping) whose ``column`` holds one comma-separated string per ECG lead
    column: string
        The name of the field containing the 12-lead ECG

    Returns
    -------
    list:
        A list of lists of ints, one inner list per ECG lead, in column order
    """
    leads = []
    for raw_lead in df[column]:
        # drop every whitespace character (spaces, line breaks) before splitting on commas
        compact_lead = re.sub(r"\s+", "", raw_lead)
        leads.append([int(sample) for sample in compact_lead.split(',')])
    return leads

def normalize_all_waveforms(waves):
    """
    Min-max scales the waveforms to the range [-1, 1].

    A ``MinMaxScaler`` is fitted on all the waveforms together, with every waveform as one
    row. The scaler treats each column (sample position) as a separate feature, using the
    minimum and maximum of that position across all waveforms.

    NOTE(review): the scaler is fitted on every waveform passed in. If these waveforms are later
    split into train/test sets, test-set statistics leak into the scaling — confirm this is intended.

    Parameters
    ----------
    waves: list
        A list of equal-length waveforms (e.g. the flattened 12-lead ECGs)

    Returns
    -------
    list:
        One array of shape (1, n_samples) per input waveform, in input order
    """
    scaler = MinMaxScaler((-1, 1)).fit(waves)
    # transform the whole matrix at once (element-wise, so identical to row-by-row),
    # then hand back one (1, n_samples) row per waveform as before
    scaled_matrix = scaler.transform(np.array(waves))
    return [scaled_row.reshape(1, -1) for scaled_row in scaled_matrix]

def get_gender(filepath):
    """
    Reads the patient's gender from the ``PatientInfo`` section of the XML file.

    Parameters
    ----------
    filepath: string
        An XML file name containing the ECG and the patient's information

    Returns
    -------
    string:
        The ``Gender`` value of the first ``PatientInfo`` row (e.g. 'MALE' or 'FEMALE'), or an
        empty string when reading fails. The error is printed first.
    """
    patient_info_xpath = '//CardiologyXML/PatientInfo'
    gender = ''
    try:
        patient_info = pd.read_xml(filepath, xpath=patient_info_xpath)
        gender = patient_info['Gender'][0]
    except ValueError as value_error:
        print('value error', value_error)
    except Exception as e:
        print('exception', e)
    return gender

def is_normal_based_on_machine_prediction(filepath):
    """
    Checks whether the machine interpretation stored in the XML file calls the ECG normal.

    The interpretation is assumed to come from the Acute cardiac ischemia time-insensitive
    predictive instrument (ACI-TIPI). Only ``Interpretation`` elements that are direct children
    of the root are inspected. The ECG counts as normal when any child of their ``Diagnosis``
    element has the text 'Normal ECG'.

    Parameters
    ----------
    filepath: string
        An XML file name (or file object) containing the ECG and the ACI-TIPI predictions

    Returns
    -------
    bool:
       True if a 'Normal ECG' diagnosis is found, otherwise False.
    """
    root = ET.parse(filepath).getroot()
    for interpretation in root.findall('Interpretation'):
        diagnosis_node = interpretation.find('Diagnosis')
        # `is None` rather than truthiness: an Element without children is falsy
        if diagnosis_node is None:
            continue
        if any(statement.text == 'Normal ECG' for statement in diagnosis_node):
            return True
    return False

def get_waveforms(labels_dict):
    """
    Divides the dataset by gender into two datasets and extracts the ECG waveforms. Healthy
    patients whose ECG is not normal according to the ACI-TIPI prediction are removed.

    The ECG files are read from the directories in the module-level ``ecg_dirs``, sorted by name
    within each directory. The patient id is the part of the file name before the first '_'.
    A patient is dropped (its label is deleted from ``labels_dict``) when:

    * the patient has more than one ECG recording (it is unclear which one the diagnoses refer to),
    * no waveform can be read from the file,
    * the label is 0 (healthy) but the ACI-TIPI does not call the ECG normal,
    * the gender is missing or is neither 'MALE' nor 'FEMALE'.

    Parameters
    ----------
    labels_dict: dict
        A dictionary where the key is a patient's id and the value is a label (0 or 1).
        It is modified in place; pass a copy to keep the original.

    Returns
    -------
    dict:
       Males only: key = patient's id, value = the flattened 12-lead ECG (numpy array)
    dict:
       Females only: key = patient's id, value = the flattened 12-lead ECG (numpy array)
    dict:
        The remaining labels (the same object as ``labels_dict``, with dropped patients removed)
    """
    X_males, X_females = {}, {}
    st = time.time()
    for ecg_dir in ecg_dirs:
        for file in sorted(os.listdir(ecg_dir)):
            position = file.find('_')
            if position == -1:
                print('file name does not contain a patient id:' , file)
                continue
            patient_id = file[:position]

            # no label: the patient is not part of the dataset or has already been dropped
            if patient_id not in labels_dict:
                continue
            label = labels_dict[patient_id]

            # if this patient had a secondary ECG recording, it's better to remove this patient from the dataset
            # otherwise it's unclear based on which ECG the diagnoses were made
            if patient_id in X_males or patient_id in X_females:
                X_males.pop(patient_id, None)
                X_females.pop(patient_id, None)
                del labels_dict[patient_id]
                continue

            # first ECG recording of this patient
            file_name = os.path.join(ecg_dir, file)
            try:
                waves = np.array(get_waves_from_file(file_name)).ravel()
            except KeyError:
                # waveform column missing from the parsed XML: treat the file as unreadable
                # (previously this KeyError silently left the label without a waveform)
                waves = np.array([])
            if len(waves) == 0:
                del labels_dict[patient_id]
                continue

            # healthy patient: keep only ECGs that the ACI-TIPI calls normal
            if label == 0 and not is_normal_based_on_machine_prediction(file_name):
                del labels_dict[patient_id]
                continue

            gender = get_gender(file_name)
            if gender == 'MALE':
                X_males[patient_id] = waves
            elif gender == 'FEMALE':
                X_females[patient_id] = waves
            else:
                # Missing ('') or unexpected gender: the patient fits neither dataset, so its label
                # is dropped too. Previously only '' was dropped; any other value kept a label with
                # no waveform and disabled the duplicate-recording check for that patient.
                del labels_dict[patient_id]

    elapsed_time = time.time() - st
    print('Reading waveforms time:', elapsed_time/60, 'minutes')
    return X_males, X_females, labels_dict

def standardize_and_filter(ecgs, labels_dict, n_leads=12, kernel_size=13):
    """
    Min-max normalizes the ECGs to [-1, 1] and applies a median filter to every lead.

    Parameters
    ----------
    ecgs: dict
        key = patient's id, value = the flattened 12-lead ECG (the leads concatenated one
        after another)
    labels_dict: dict
        key = patient's id, value = label (0 or 1); must contain every key of ``ecgs``
    n_leads: int, optional
        Number of leads concatenated in each flattened ECG (default 12). The median filter is
        applied to every lead separately, so its window never mixes the end of one lead with
        the start of the next. ``n_leads=1`` filters the whole flattened signal at once
        (the previous behaviour).
    kernel_size: int, optional
        Odd size of the median filter window (default 13).

    Returns
    -------
    numpy.ndarray:
        The normalized waveforms, shape (n_patients, n_samples)
    numpy.ndarray:
        The normalized and median-filtered waveforms, same shape
    numpy.ndarray:
        The corresponding labels, in the same order
    """
    st = time.time()
    patient_ids = list(ecgs.keys())
    labels = np.array([labels_dict[patient_id] for patient_id in patient_ids])

    if patient_ids:
        waves = [ecgs[patient_id] for patient_id in patient_ids]
        normalized_waves = np.array(normalize_all_waveforms(waves))
        normalized_waves = normalized_waves.reshape(normalized_waves.shape[0], -1)
        # Split every flattened ECG back into its leads, filter each lead on its own and
        # concatenate the results again.
        filtered_waves = np.array([
            np.concatenate([
                signal.medfilt(lead, kernel_size=kernel_size)
                for lead in wave.reshape(n_leads, -1)
            ])
            for wave in normalized_waves
        ])
    else:
        # nothing to scale: a MinMaxScaler cannot be fitted on zero samples
        normalized_waves = np.empty((0, 0))
        filtered_waves = np.empty((0, 0))

    elapsed_time = time.time() - st
    print('Standardizing (+ filtering) waveforms time:', elapsed_time/60, 'minutes')
    return normalized_waves, filtered_waves, labels

In [None]:
# getting the ECG waveforms for males and females and the updated dictionary of the labels for both genders
X_males, X_females, labels_dict = get_waveforms(id_label_dictionary.copy())
print(len(X_males), len(X_females))

In [None]:
# normalize and apply the median filter to the waves for both genders
normalized_waves_males, waves_males, labels_males = standardize_and_filter(X_males, labels_dict)
normalized_waves_females, waves_females, labels_females = standardize_and_filter(X_females, labels_dict)
print(waves_males[0].reshape(12, 600).shape)

In [None]:
def save_waves(waves, title, filepath='../../images/ecg6', columns = 6):
    """
    Draws ECG leads with ecg_plot, titles and labels the chart, saves it and shows it.

    Parameters
    ----------
    waves: list
        The ECG leads to draw (one row per lead)
    title: string
        The title of the plot
    filepath: string
        The file name handed to ``ecg_plot.save_as_png``
    columns: integer
        The number of columns the leads are arranged in

    Returns
    -------
    None
    """
    # ecg_plot gets an empty title; the real title is set on the current axes afterwards
    ecg_plot.plot(waves, sample_rate = 500, title = "", columns=columns)
    current_axes = plt.gca()
    current_axes.set_title(title)
    current_axes.set_xlabel("Time (seconds)")
    # NOTE(review): the callers in this notebook pass min-max normalized values, so the
    # 'mV' unit on the y axis may not be accurate
    current_axes.set_ylabel("Signal amplitude (mV)")
    ecg_plot.save_as_png(filepath, dpi = 800)
    ecg_plot.show()

def save_one_wave(wave, filepath='../images/ecg1', title='ECG 1'):
    """
    Plots one ECG lead, saves the plot as a PNG and shows it.

    Parameters
    ----------
    wave: list
        One ECG lead
    filepath: string
        The file name to save the plot in, without extension. ``ecg_plot.save_as_png`` appends
        '.png' itself, which is why the other paths handed to it in this notebook have no
        extension; the previous default '../images/ecg1.png' was saved as 'ecg1.png.png'.
    title: string, optional
        The title of the plot (default 'ECG 1', the previously hard-coded title)

    Returns
    -------
    None
    """
    # NOTE(review): the other figures in this notebook are saved under '../../images/';
    # confirm whether '../images/' is the intended directory
    ecg_plot.plot_1(wave, title = title, sample_rate = 500)
    ecg_plot.save_as_png(filepath, dpi = 800)
    ecg_plot.show()

# save the plot of the 12 ECG leads to use in the thesis
save_waves(normalized_waves_males[1].reshape(12, 600)[:6], '6 first ECG leads', filepath='../../images/ecg6_2')

# save the plot of one ECG lead to use in the thesis
save_one_wave(waves[1].reshape(12, 600)[0])

In [None]:
# print and save the filtered ECG leads to use in the thesis
wave_filtered = waves_males[1]
save_waves(wave_filtered.reshape(12, 600)[:6], '6 first ECG leads with the median filter applied', filepath='../../images/ecg6_filtered_2', columns=6)

In [None]:
def save_together_with_different_diagnoses(wave1, wave2, wave3, wave1_no_filter, wave2_no_filter, wave3_no_filter, title, filepath, sample_rate=500):
    """
    Saves and plots leads from 3 different ECGs on one chart, each shown both filtered and
    non-filtered. The first lead belongs to the STEMI-diagnosed patient, the others to
    healthy individuals.

    Parameters
    ----------
    wave1: list
        One filtered ECG lead of a STEMI-diagnosed patient
    wave2: list
        One filtered ECG lead of a healthy individual
    wave3: list
        One filtered ECG lead of a healthy individual
    wave1_no_filter: list
        One non-filtered ECG lead of a STEMI-diagnosed patient
    wave2_no_filter: list
        One non-filtered ECG lead of a healthy individual
    wave3_no_filter: list
        One non-filtered ECG lead of a healthy individual
    title: string
        A title of the plot
    filepath: string
        A file name to save the plot in
    sample_rate: number, optional
        Samples per second, used to build the time axis (default 500). With 600 samples this
        gives 0 to 1.2 s, the axis that was previously hard-coded.

    Returns
    -------
    None
    """
    n_samples = len(wave1)
    # The time axis is derived from the data instead of a hard-coded 600-point axis, so leads of
    # other lengths no longer fail with a shape mismatch. For 600 samples at 500 Hz the values
    # are the same as before.
    xpoints = np.linspace(0.0, n_samples / sample_rate, num=n_samples)
    ypoints1 = np.array(wave1)
    ypoints2 = np.array(wave2)
    ypoints3 = np.array(wave3)

    ypoints1_no_filter = np.array(wave1_no_filter)
    ypoints2_no_filter = np.array(wave2_no_filter)
    ypoints3_no_filter = np.array(wave3_no_filter)

    # start a fresh figure so nothing left on the current pyplot figure ends up in this plot
    plt.figure()
    plt.title(title, size=13.0)
    plt.xlabel("Time (seconds)", size=13.0)
    plt.ylabel("Signal amplitude (mV)", size=13.0)
    plt.grid(True)
    plt.plot(xpoints, ypoints1, 'r-', label='Filtered 1st lead of the STEMI-patient', linewidth=1)
    plt.plot(xpoints, ypoints2, 'g-', label='Filtered 1st lead of a healthy patient1', linewidth=1)
    plt.plot(xpoints, ypoints3, 'b-', label='Filtered 1st lead of a healthy patient2', linewidth=1)

    # the non-filtered leads use the same colours, half transparent
    plt.plot(xpoints, ypoints1_no_filter, 'r-', label='1st lead of the STEMI-patient', linewidth=1, alpha=0.5)
    plt.plot(xpoints, ypoints2_no_filter, 'g-', label='1st lead of a healthy patient1', linewidth=1, alpha=0.5)
    plt.plot(xpoints, ypoints3_no_filter, 'b-', label='1st lead of a healthy patient2', linewidth=1, alpha=0.5)

    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left", fontsize="13")
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.savefig(filepath, dpi = 800, bbox_inches = 'tight')
    plt.show()
    
# Plots and saves the plot of the leads from 3 different ECGs. Waveform 0 was diagnosed with STEMI
# NOTE(review): nothing in this cell checks that sample 0 is AMI-positive and samples 1-2 are
# healthy — confirm with labels_males[:3]. [:600] selects the first lead, assuming 600 samples
# per lead as in the reshape(12, 600) calls above.
save_together_with_different_diagnoses(waves_males[0][:600], waves_males[1][:600], waves_males[2][:600], normalized_waves_males[0][:600], normalized_waves_males[1][:600], normalized_waves_males[2][:600],'Comparison of the 1st ECG leads of the STEMI diagnosed and healthy patients', '../../images/ecg_Comparison_diff_diagnoses2')

In [None]:
def reverse_one_hot(predictions):
    """
    Converts per-class probabilities into hard class labels by picking the most probable class
    (for two classes this is the default 0.5 threshold; ties go to class 0).

    Parameters
    ----------
    predictions: list
        A sequence of probability vectors; here each vector holds the probability of class 0
        (healthy patients) and class 1 (AMI diagnosed patients)

    Returns
    -------
    list:
        The index of the most probable class for every prediction
    """
    return [np.argmax(np.array(probabilities)) for probabilities in predictions]

def moved_threshold(predictions, threshold):
    """
    Converts per-class probabilities into hard labels using a custom decision threshold.

    A sample is labelled 1 (AMI diagnosed) when its class-0 (healthy) probability is strictly
    below `threshold`, and 0 otherwise. When the two probabilities sum to 1 this is the same as
    predicting AMI when the class-1 probability exceeds ``1 - threshold``.

    Parameters
    ----------
    predictions: list
        A sequence of pairs holding the probability of class 0 (healthy patients) and class 1
        (AMI diagnosed patients)
    threshold: number
        The cut-off applied to the class-0 (healthy) probability

    Returns
    -------
    list:
        A list of predicted labels (0 or 1)
    """
    return [1 if probabilities[0] < threshold else 0 for probabilities in predictions]

def get_model(image_shape, num_classes):
    """
    Defines a simple CNN model

    Two 3x3 convolutions, one 2x2 max-pooling step, a 64-unit dense layer with dropout and a
    softmax output with one unit per class, compiled with categorical cross-entropy and Adadelta.

    Parameters
    ----------
    image_shape: tuple
        The shape of the input into the model; its first three entries are used as
        (height, width, channels)
    num_classes: number
        The number of classes for classification (units of the softmax output)
    Returns
    -------
    object:
        A compiled model that is ready to be trained on one-hot encoded labels
    """
    model = Sequential()
    model.add(Conv2D(16, (3, 3), activation='relu', input_shape=tuple(image_shape[:3])))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    # A softmax over mutually exclusive classes pairs with categorical cross-entropy, as in
    # get_complex_model. For two one-hot classes this gives the same value as BinaryCrossentropy,
    # but unlike BinaryCrossentropy it stays correct when num_classes > 2.
    # NOTE(review): Adadelta is used with its library-default learning rate; confirm that is intended.
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adadelta(), metrics=['accuracy', 'mae', 'mse', tf.keras.metrics.AUC()])
    return model

def get_complex_model(image_shape, num_classes):
    """
    Builds and compiles the deeper CNN of this notebook (the one trained on the CWT scalograms)

    The network stacks conv(8) -> conv(16) -> max-pool -> dropout(0.2) -> conv(16) -> max-pool ->
    flatten -> dense(128) -> dropout(0.4) -> softmax(num_classes). It is compiled with categorical
    cross-entropy and the Adadelta optimizer.

    Parameters
    ----------
    image_shape: tuple
        The shape of the input into the model; its first three entries are used as
        (height, width, channels)
    num_classes: number
        The number of classes for classification (units of the softmax output)
    Returns
    -------
    object:
        A compiled model that is ready to be trained
    """
    height, width, channels = image_shape[0], image_shape[1], image_shape[2]
    stack = [
        # first convolutional stage
        Conv2D(8, (3, 3), activation='relu', input_shape=(height, width, channels)),
        Conv2D(16, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.2),
        # second convolutional stage
        Conv2D(16, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        # classifier head
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.4),
        Dense(num_classes, activation='softmax'),
    ]
    model = Sequential()
    for layer in stack:
        model.add(layer)
    model.compile(
        loss=tf.keras.losses.categorical_crossentropy,
        optimizer=tf.keras.optimizers.Adadelta(),
        metrics=['accuracy', 'mae', 'mse', tf.keras.metrics.AUC()],
    )
    return model
    
def learn_the_model_experiment(training_set_X, training_set_y, val_X, val_y, testing_set_X, testing_set_y, get_model_func, epochs, image_shape = (24, 300, 1), num_classes = 2):
    """
    Trains the model on the training dataset, validates it on the validation dataset and predicts
    the class probabilities for the testing and training datasets

    Parameters
    ----------
    training_set_X: list or numpy.ndarray
        The ECG waveforms from the training dataset; every sample must hold
        image_shape[0] * image_shape[1] * image_shape[2] values
    training_set_y: list
        The list of labels for the training dataset
    val_X: list or numpy.ndarray
        The ECG waveforms from the validation dataset
    val_y: list
        The list of labels for the validation dataset
    testing_set_X: list or numpy.ndarray
        The ECG waveforms from the testing dataset
    testing_set_y: list
        The list of labels for the testing dataset
    get_model_func: function
        A function (image_shape, num_classes) -> compiled model that should be trained and validated
    epochs: number
        The number of epochs for training the model
    image_shape: tuple
        The shape of one input sample of the model; its first three entries are used
    num_classes: number
        The number of classes for classification

    Returns
    -------
    numpy.ndarray:
        Per-sample class probabilities (class 0 = healthy patients, class 1 = AMI diagnosed
        patients) for the testing dataset
    numpy.ndarray:
        The same for the training dataset
    object:
        The Keras History object with the loss and metric values of every epoch
    """
    # one-hot encode the integer labels to match the per-class output of the model
    test_labels = tf.keras.utils.to_categorical(testing_set_y, num_classes)
    val_labels = tf.keras.utils.to_categorical(val_y, num_classes)
    train_labels = tf.keras.utils.to_categorical(training_set_y, num_classes)

    # np.asarray accepts plain lists (as documented) as well as arrays; calling .reshape
    # directly on a list would raise AttributeError
    sample_shape = tuple(image_shape[:3])
    train_images = np.asarray(training_set_X).reshape((len(training_set_X),) + sample_shape)
    val_images = np.asarray(val_X).reshape((len(val_X),) + sample_shape)
    test_images = np.asarray(testing_set_X).reshape((len(testing_set_X),) + sample_shape)

    model = get_model_func(image_shape, num_classes)
    model.summary()

    train_data_size = train_images.shape[0]
    test_data_size = test_images.shape[0]
    val_data_size = val_images.shape[0]

    print("model will be trained with {}, validated with {} and be tested with {} samples".format(train_data_size, val_data_size, test_data_size))
    print("Fitting model to the training data...")
    # batch_size is the notebook-level setting defined in the import cell
    history = model.fit(train_images, train_labels, batch_size=batch_size, epochs=epochs, verbose=1, validation_data=(val_images, val_labels))

    predictions_test = model.predict(test_images, batch_size=batch_size, verbose=1)
    predictions_train = model.predict(train_images, batch_size=batch_size, verbose=1)
    print(model.metrics_names)
    print('Test metrics values')
    print(model.evaluate(test_images, test_labels, batch_size=batch_size, verbose=1))
    print('Train metrics values')
    print(model.evaluate(train_images, train_labels, batch_size=batch_size, verbose=1))
    return predictions_test, predictions_train, history

def learn_and_test(training_set_X, training_set_y, val_X, val_y, testing_set_X, testing_set_y, get_model_func, epochs, image_shape = (24, 300, 1), num_classes=2, save_confision_matrix=False, confusion_matrix_dir='/home/umcg-asorova/project/images', random_state=None):
    """
    Reshuffles the training dataset, trains and validates the model, makes the predictions for the
    testing dataset and prints the accuracy, confusion matrices and classification reports

    Parameters
    ----------
    training_set_X: list
        A list of the ECG waveforms from the training dataset
    training_set_y: list
        The list of labels for the training dataset
    val_X: list
        A list of the ECG waveforms from the validation dataset
    val_y: list
        The list of labels for the validation dataset
    testing_set_X: list
        A list of the ECG waveforms from the testing dataset
    testing_set_y: list
        The list of labels for the testing dataset
    get_model_func: function
        A function that returns the compiled model that should be trained and validated
    epochs: number
        The number of epochs for training the model
    image_shape: tuple
        The shape of the input into the model
    num_classes: number
        The number of classes for classification
    save_confision_matrix: bool
        Indicates if the confusion matrices should be saved as images instead of shown
    confusion_matrix_dir: string
        Directory in which conf_mat_test.png / conf_mat_train.png are written when saving
    random_state: int or None
        Seed for reshuffling the training dataset (None = a different shuffle every run)

    Returns
    -------
    numpy.ndarray:
        Per-sample probabilities of class 0 (healthy patients) and class 1 (AMI diagnosed
        patients) for the testing dataset
    numpy.ndarray:
        The same for the (reshuffled) training dataset
    object:
        The Keras History object with the loss and metric values of every epoch
    """
    training_set_X, training_set_y = shuffle(training_set_X, training_set_y, random_state=random_state)

    st = time.time()
    predictions_test, predictions_train, history = learn_the_model_experiment(training_set_X, training_set_y, val_X, val_y, testing_set_X, testing_set_y, get_model_func, epochs, image_shape=image_shape, num_classes=num_classes)

    elapsed_time = time.time() - st
    print('Training model time (full):', elapsed_time/60, 'minutes')

    # Convert the probabilities to hard labels once and reuse them for every report below
    predicted_test = reverse_one_hot(predictions_test)
    predicted_train = reverse_one_hot(predictions_train)

    print("Evaluation accuracy score (full, test) = ", accuracy_score(testing_set_y, predicted_test))
    print("Evaluation accuracy score (full, train) = ", accuracy_score(training_set_y, predicted_train))

    cf_matrix_test = confusion_matrix(testing_set_y, predicted_test)
    cf_matrix_train = confusion_matrix(training_set_y, predicted_train)

    print("Confusion matrix for the testing dataset")
    print(cf_matrix_test)

    print("Confusion matrix for the training dataset")
    print(cf_matrix_train)

    print("Classification report for the testing dataset")
    print(classification_report(testing_set_y, predicted_test))
    print("Classification report for the training dataset")
    print(classification_report(training_set_y, predicted_train))

    # Annotated heatmaps only exist for the binary case; otherwise the matrix is printed
    for cf_matrix, title, file_name in (
        (cf_matrix_test, 'Testing set (full)', 'conf_mat_test.png'),
        (cf_matrix_train, 'Training set (full)', 'conf_mat_train.png'),
    ):
        if num_classes == 2:
            depict_confusion_matrix(cf_matrix, title, save=save_confision_matrix, filename=os.path.join(confusion_matrix_dir, file_name))
        else:
            print(cf_matrix)
    return predictions_test, predictions_train, history

def depict_confusion_matrix(cf_matrix, title, save=False, filename='confusion_matrix.png'):
    """
    Draws an annotated heatmap of a 2x2 confusion matrix and saves or shows it

    Every cell is annotated with its name (True Neg, False Pos, False Neg, True Pos), its count
    and its share of all samples. The current figure is cleared afterwards.

    Parameters
    ----------
    cf_matrix: list or numpy.ndarray
        The 2x2 confusion matrix (rows = true class, columns = predicted class)
    title: string
        The title of the plot
    save: bool
        Indicates if the plot should be saved as a picture instead of shown
    filename: string
        The file name to save the plot to in case the 'save' parameter is set to True

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If the matrix is not 2x2 (e.g. a 1x1 matrix when only one class occurs in both the true
        and the predicted labels)
    """
    cf_matrix = np.asarray(cf_matrix)
    if cf_matrix.shape != (2, 2):
        raise ValueError(f"Expected a 2x2 confusion matrix, got shape {cf_matrix.shape}")

    group_names = ['True Neg', 'False Pos', 'False Neg', 'True Pos']
    counts = cf_matrix.flatten()
    total = counts.sum()
    # an all-zero matrix would otherwise produce 0/0 = NaN percentages
    shares = counts / total if total else np.zeros(counts.shape)
    labels = np.asarray([
        f'{name}\n{count:0.0f}\n{share:.2%}'
        for name, count, share in zip(group_names, counts, shares)
    ]).reshape(2, 2)
    sns.heatmap(cf_matrix, annot=labels, fmt='', cmap='Blues')
    plt.title(title)
    if save:
        plt.savefig(filename, dpi=200)
    else:
        plt.show()
    plt.clf()

In [None]:
# Split the male data into the training set constituting 60%, the validation set making up 20%, and the testing set accounting for the remaining 20%
X_train, X_test, y_train, y_test = train_test_split(waves_males, labels_males, train_size=0.6, random_state=0, stratify=labels_males)
# halve the remaining 40%: the first half becomes the validation set, the second the testing set
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, train_size=0.5, random_state=0, stratify=y_test)

print('Original dataset shape (full):', Counter(labels_males))
print('Split dataset shape (train):', Counter(y_train))
print('Split dataset shape (val):', Counter(y_val))
print('Split dataset shape (test):', Counter(y_test))

In [None]:
# Split the female data into the training set constituting 60%, the validation set making up 20%, and the testing set accounting for the remaining 20%
X_train_females, X_test_females, y_train_females, y_test_females = train_test_split(waves_females, labels_females, train_size=0.6, random_state=0, stratify=labels_females)
# halve the remaining 40%: the first half becomes the validation set, the second the testing set
X_val_females, X_test_females, y_val_females, y_test_females = train_test_split(X_test_females, y_test_females, train_size=0.5, random_state=0, stratify=y_test_females)

print('Original dataset shape (full):', Counter(labels_females))
print('Split dataset shape (train):', Counter(y_train_females))
print('Split dataset shape (val):', Counter(y_val_females))
print('Split dataset shape (test):', Counter(y_test_females))

In [None]:
# creating a bar chart to visualize the proportions of the two classes (healthy individuals and AMI-diagnosed patients)
# Per-dataset class counts, in the order of the x tick labels set for this chart
healthy = (1565, 463, 2578, 114, 522, 860)
ami_diagnosed = (463, 463, 114, 114, 154, 38)
assert len(healthy) == len(ami_diagnosed), "every bar needs both a healthy and an AMI count"
# number of bars, derived from the data instead of being typed in separately
N = len(healthy)

ind = np.arange(N)
width = 0.4

# plt.subplots returns a (Figure, Axes) pair; unpack it rather than binding the whole tuple to `fig`
fig, ax = plt.subplots(figsize=(13, 7))
# stacked bars: AMI-diagnosed counts at the bottom, healthy counts on top of them
p1 = plt.bar(ind, ami_diagnosed, width)
p2 = plt.bar(ind, healthy, width, bottom=ami_diagnosed)

def addlabels(healthy, ami_diagnosed):
    """
    Writes the healthy-to-AMI ratio above each stacked bar of the current plot

    Parameters
    ----------
    healthy: sequence of numbers
        Number of healthy individuals per bar
    ami_diagnosed: sequence of numbers
        Number of AMI-diagnosed patients per bar (must be non-zero)

    Returns
    -------
    None
    """
    # Iterate over the data itself instead of the global bar count N, which later cells rebind
    for position, (n_healthy, n_ami) in enumerate(zip(healthy, ami_diagnosed)):
        # the label sits 100 units above the top of the stacked bar
        plt.text(position, n_healthy + n_ami + 100, "{:.1f} : 1.0".format(n_healthy / n_ami), ha='center', fontsize=15)

# annotate every bar with its healthy : AMI ratio
addlabels(healthy, ami_diagnosed)
 
plt.ylabel('Number of patients', fontsize=13)
plt.title('Datasets: healthy to the AMI-diagnosed patient ratio', fontsize=13)
# tick labels follow the order of the counts in `healthy` / `ami_diagnosed`
plt.xticks(ind, ('Training set M', 'Training set M \nafter undersampling', 'Training set F', 'Training set F \nafter undersampling', 'Testing set M', 'Testing set F'), fontsize=13)
# y ticks every 200 patients from 0 up to (but excluding) 3600
plt.yticks(np.arange(0, 3600, 200), fontsize=13)
# p2 is the healthy (upper) segment and p1 the AMI (lower) segment of each bar
plt.legend((p2[0], p1[0]), ('No heart diseases diagnosed', 'AMI diagnosed'), fontsize=13)
plt.savefig('../../images/datasets_after_undersampling', dpi = 800, bbox_inches = 'tight')
 
plt.show()

In [None]:
# solving the class imbalance problem with the ClusterCentroids undersampling algorithm for the male training dataset
# The k-means estimator that produces the centroids is seeded so that the resampled training
# set (and every result computed from it) is the same on every run.
cc = ClusterCentroids(
    estimator=MiniBatchKMeans(n_init=2, random_state=0), sampling_strategy='not minority', random_state=0
)
st = time.time()
X_resampled_train, y_resampled_train = cc.fit_resample(X_train, y_train)
elapsed_time = time.time() - st
print('Undersampling time (full):', elapsed_time/60, 'minutes')

print('Original dataset shape (train):', Counter(y_train))
print('Resampled dataset shape (train):', Counter(y_resampled_train))

In [None]:
# solving the class imbalance problem with the ClusterCentroids undersampling algorithm for the female training dataset
# The k-means estimator that produces the centroids is seeded so that the resampled training
# set (and every result computed from it) is the same on every run.
cc = ClusterCentroids(
    estimator=MiniBatchKMeans(n_init=2, random_state=0), sampling_strategy='not minority', random_state=0
)
st = time.time()
X_resampled_train_females, y_resampled_train_females = cc.fit_resample(X_train_females, y_train_females)
elapsed_time = time.time() - st
print('Undersampling time (full):', elapsed_time/60, 'minutes')

print('Original dataset shape (train):', Counter(y_train_females))
print('Resampled dataset shape (train):', Counter(y_resampled_train_females))

In [None]:
def check_with_diff_threshold(y_test, pred_test, threshold):
    """
    Prints the accuracy, confusion matrix and classification report of the testing set for a
    given classification threshold

    Parameters
    ----------
    y_test: list
        The original labels of the testing dataset
    pred_test: list
        Pairs of probabilities of class 0 (healthy patients) and class 1 (AMI diagnosed patients)
        predicted by the model
    threshold: number
        Passed to moved_threshold: a sample is labelled AMI when its class-0 probability is below it

    Returns
    -------
    None
    """
    thresholded_labels = moved_threshold(pred_test, threshold)
    print("Evaluation accuracy score (full, test) = ", accuracy_score(y_test, thresholded_labels))

    for heading, summarize in (
        ("Confusion matrix for the testing dataset", confusion_matrix),
        ("Classification report for the testing dataset", classification_report),
    ):
        print(heading)
        print(summarize(y_test, thresholded_labels))

In [None]:
# 250 epochs, AMI, males
pred_test, pred_train, history = learn_and_test(
    training_set_X=X_resampled_train, training_set_y=y_resampled_train,
    val_X=X_val, val_y=y_val,
    testing_set_X=X_test, testing_set_y=y_test,
    get_model_func=get_model, epochs=250,
    image_shape=(12, 600, 1), num_classes=2, save_confision_matrix=False,
)

In [None]:
# 250 epochs, AMI, females
pred_test_females, pred_train_females, history_females = learn_and_test(
    training_set_X=X_resampled_train_females, training_set_y=y_resampled_train_females,
    val_X=X_val_females, val_y=y_val_females,
    testing_set_X=X_test_females, testing_set_y=y_test_females,
    get_model_func=get_model, epochs=250,
    image_shape=(12, 600, 1), num_classes=2, save_confision_matrix=False,
)

In [None]:
# checks the accuracy of prediction with the default threshold (males)
check_with_diff_threshold(y_test=y_test, pred_test=pred_test, threshold=0.5)

In [None]:
# checks the accuracy of prediction with threshold set to 0.4 (males)
# the threshold is compared with the healthy-class probability: AMI is predicted when P(healthy) < 0.4
check_with_diff_threshold(y_test=y_test, pred_test=pred_test, threshold=0.4)

In [None]:
# checks the accuracy of prediction with threshold set to 0.6 (males)
# the threshold is compared with the healthy-class probability: AMI is predicted when P(healthy) < 0.6
check_with_diff_threshold(y_test=y_test, pred_test=pred_test, threshold=0.6)

In [None]:
# checks the accuracy of prediction with threshold set to 0.6 (females)
# the threshold is compared with the healthy-class probability: AMI is predicted when P(healthy) < 0.6
check_with_diff_threshold(y_test=y_test_females, pred_test=pred_test_females, threshold=0.6)

In [None]:
def plot_both_loss_and_accuracy_curves(history, history_females, model_type = 'simple', cwt_applied=None):
    """
    Plots and saves the loss and accuracy curves for both genders

    Left panel: training (solid) and validation (dashed) loss; right panel: the same for
    accuracy. Males are drawn in green, females in red. The figure is saved as
    '../../images/loss_curves_<model_type>'.

    Parameters
    ----------
    history: object
        Keras History of the model trained on the male data; its .history dict must contain
        'loss', 'val_loss', 'accuracy' and 'val_accuracy'
    history_females: object
        The same for the model trained on the female data
    model_type: string
        The type of the model - simple or complex - used in the output file name
    cwt_applied: bool or None
        Whether the model was trained on CWT scalograms (only affects the titles). When None it
        is inferred from model_type: 'simple' means raw ECGs; any other value means scalograms,
        which is what the complex model is trained on in this notebook.

    Returns
    -------
    None
    """
    if cwt_applied is None:
        cwt_applied = model_type != 'simple'
    # The complex model is trained on scalograms, so its titles must not say "CWT not applied"
    title_suffix = '(CWT applied)' if cwt_applied else '(CWT not applied)'

    # Plot learning curves
    plt.figure(figsize=(12, 6))
    _plot_metric_panel(1, 'loss', 'Loss', history, history_females, title_suffix)
    _plot_metric_panel(2, 'accuracy', 'Accuracy', history, history_females, title_suffix)

    plt.tight_layout()
    plt.savefig('../../images/loss_curves_' + model_type, dpi = 800, bbox_inches = 'tight')
    plt.show()


def _plot_metric_panel(position, metric, metric_label, history, history_females, title_suffix):
    """Draws training/validation curves of one metric for both genders into subplot (1, 2, position)."""
    plt.subplot(1, 2, position)
    for hist, gender, color in ((history, 'Males', 'green'), (history_females, 'Females', 'red')):
        plt.plot(hist.history[metric], label=f'Training {metric_label}: {gender}', color=color)
        plt.plot(hist.history['val_' + metric], label=f'Validation {metric_label}: {gender}', color=color, linestyle='--')

    plt.title(f'Training and Validation {metric_label} {title_suffix}', fontsize=13)
    plt.xlabel('Epoch', fontsize=13)
    plt.ylabel(metric_label, fontsize=13)
    plt.legend(fontsize=11)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)

In [None]:
# plot the learning curves for the model trained on the raw ECGs (simple model)
plot_both_loss_and_accuracy_curves(history=history, history_females=history_females, model_type='simple')

In [None]:
# plot the learning curves for the model trained on the scalograms (complex model)
# NOTE: history_cwt and history_cwt_females are created by the 50-epoch CWT training cells further
# down; on a fresh kernel run those before this cell.
plot_both_loss_and_accuracy_curves(history=history_cwt, history_females=history_cwt_females, model_type='complex')

In [None]:
def plot_two_rocs(y_test, pred_test, y_test_females, pred_test_females, model_type='simple', cwt_applied=None):
    """
    Plots and saves the ROC curves for both genders

    For each gender the point that maximises the geometric mean sqrt(TPR * (1 - FPR)) is marked
    and reported as the optimal threshold. Males are drawn in green, females in red. The figure
    is saved as '../../images/aucs_<model_type>'.

    Parameters
    ----------
    y_test: list
        The original labels of the male testing dataset
    pred_test: list or numpy.ndarray
        Pairs of probabilities of class 0 (healthy patients) and class 1 (AMI diagnosed patients)
        predicted by the model for the male testing dataset
    y_test_females: list
        The original labels of the female testing dataset
    pred_test_females: list or numpy.ndarray
        The same for the female testing dataset
    model_type: string
        The type of the model - simple or complex - used in the output file name
    cwt_applied: bool or None
        Whether the model was trained on CWT scalograms (only affects the title). When None it is
        inferred from model_type: 'simple' means raw ECGs; any other value means scalograms.

    Returns
    -------
    None
    """
    if cwt_applied is None:
        cwt_applied = model_type != 'simple'
    title_suffix = '(CWT applied)' if cwt_applied else '(CWT not applied)'

    plt.figure(figsize=(8, 8))

    for labels, predictions, gender, short, color in (
        (y_test, pred_test, 'Males', 'M', 'green'),
        (y_test_females, pred_test_females, 'Females', 'F', 'red'),
    ):
        fpr, tpr, thresholds, roc_auc, ix = _roc_with_gmean_optimum(labels, predictions)
        plt.plot(fpr, tpr, color=color, lw=2, label=f'ROC curve: {gender} (AUC = {roc_auc:.2f})')
        plt.scatter(fpr[ix], tpr[ix], marker='o', color=color, label=f'Optimal threshold ({short}): {thresholds[ix]:.2f}, \n(FPR, TPR) = ({fpr[ix]:.2f}, {tpr[ix]:.2f})')
        plt.axvline(x=fpr[ix], color=color, linestyle='-.')
        plt.axhline(y=tpr[ix], color=color, linestyle='-.')

    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Guess')
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.title(f'ROC Curve {title_suffix}', fontsize=13)
    plt.legend(loc='lower right', fontsize=13)
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)

    plt.savefig('../../images/aucs_' + model_type, dpi = 800, bbox_inches = 'tight')
    plt.show()


def _roc_with_gmean_optimum(labels, predictions):
    """Returns fpr, tpr, thresholds, the AUC and the index of the point maximising sqrt(TPR * (1 - FPR))."""
    # Column 1 is the predicted class-1 (AMI) probability; np.asarray also accepts list input.
    # NOTE: the resulting threshold applies to P(AMI). moved_threshold instead compares P(healthy),
    # so for two-class softmax outputs a threshold t found here roughly corresponds to
    # moved_threshold(..., 1 - t).
    positive_scores = np.asarray(predictions)[:, 1]
    fpr, tpr, thresholds = roc_curve(labels, positive_scores)
    roc_auc = auc(fpr, tpr)
    gmeans = np.sqrt(tpr * (1 - fpr))
    return fpr, tpr, thresholds, roc_auc, int(np.argmax(gmeans))

In [None]:
# Plotting and saving the ROC curves for both genders for the simple model trained on the raw ECGs
plot_two_rocs(y_test=y_test, pred_test=pred_test, y_test_females=y_test_females, pred_test_females=pred_test_females, model_type='simple')

In [None]:
# Plotting and saving the ROC curves for both genders for the complex model trained on the scalograms
# NOTE: pred_test_males and the scalogram-based pred_test_females are created by the CWT training
# cells further down. On a fresh kernel run those first; otherwise pred_test_males is undefined and
# pred_test_females still holds the raw-ECG (simple model) predictions.
plot_two_rocs(y_test=y_test, pred_test=pred_test_males, y_test_females=y_test_females, pred_test_females=pred_test_females, model_type='complex')

In [None]:
# 150 epochs, AMI, males
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names keep pred_test / pred_train of the 250-epoch run intact for the cells above.
pred_test_150, pred_train_150, history_150 = learn_and_test(X_resampled_train, y_resampled_train, X_val, y_val, X_test, y_test, get_model, 150, image_shape = (12, 600, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# 150 epochs, AMI, females
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names avoid overwriting the male pred_test / pred_train used by the cells above.
pred_test_females_150, pred_train_females_150, history_females_150 = learn_and_test(X_resampled_train_females, y_resampled_train_females, X_val_females, y_val_females, X_test_females, y_test_females, get_model, 150, image_shape = (12, 600, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# 200 epochs, AMI, males
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names keep pred_test / pred_train of the 250-epoch run intact for the cells above.
pred_test_200, pred_train_200, history_200 = learn_and_test(X_resampled_train, y_resampled_train, X_val, y_val, X_test, y_test, get_model, 200, image_shape = (12, 600, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# 200 epochs, AMI, females
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names avoid overwriting the male pred_test / pred_train used by the cells above.
pred_test_females_200, pred_train_females_200, history_females_200 = learn_and_test(X_resampled_train_females, y_resampled_train_females, X_val_females, y_val_females, X_test_females, y_test_females, get_model, 200, image_shape = (12, 600, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# 500 epochs, AMI, males
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names keep pred_test / pred_train of the 250-epoch run intact for the cells above.
pred_test_500, pred_train_500, history_500 = learn_and_test(X_resampled_train, y_resampled_train, X_val, y_val, X_test, y_test, get_model, 500, image_shape = (12, 600, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# 500 epochs, AMI, females
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names avoid overwriting the male pred_test / pred_train used by the cells above.
pred_test_females_500, pred_train_females_500, history_females_500 = learn_and_test(X_resampled_train_females, y_resampled_train_females, X_val_females, y_val_females, X_test_females, y_test_females, get_model, 500, image_shape = (12, 600, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# the name of the mother wavelet - Morlet
waveletname = "morl"

# the scales for the mother wavelet to perform the CWT: 1, 2, ..., 199
scales = np.arange(1, 199 + 1)

def create_time_frequency_images(X_resampled_train, lead=0, n_leads=12, n_samples=600):
    """
    Performs the CWT on one ECG lead (the first one by default) of every record and returns the
    CWT coefficients

    The transform uses the module-level `scales` and `waveletname` with a sampling period of 1.

    Parameters
    ----------
    X_resampled_train: list
        A list of the ECG waveforms; each record holds n_leads * n_samples values, lead after lead
    lead: integer
        Index of the lead to transform (0 = first lead)
    n_leads: integer
        Number of leads per record
    n_samples: integer
        Number of samples per lead

    Returns
    -------
    numpy.ndarray:
        float32 array of shape (number of records, len(scales), n_samples - 1, 1) holding the CWT
        coefficients; the last time sample of each lead is dropped
    """
    records = np.asarray(X_resampled_train).reshape(len(X_resampled_train), n_leads, n_samples)
    # one lead per record: shape (number of records, n_samples)
    lead_signals = records[:, lead, :]

    # keep n_samples - 1 time steps so the output matches the (199, 599, 1) model input used downstream
    width = n_samples - 1
    train_data_cwt = np.ndarray(shape=(len(lead_signals), len(scales), width, 1), dtype=np.float32)

    for ii, lead_signal in enumerate(lead_signals):
        coeff, _ = pywt.cwt(lead_signal, scales, waveletname, 1)
        train_data_cwt[ii, :, :, 0] = coeff[:, :width]
    return train_data_cwt

# performing the CWT on the training, validation and testing male datasets
# the training set is the undersampled one; validation and testing sets keep their original class ratio
train_data_cwt = create_time_frequency_images(X_resampled_train)
test_data_cwt = create_time_frequency_images(X_test)
val_data_cwt = create_time_frequency_images(X_val)


In [None]:
# performing the CWT on the training, validation and testing female datasets
# the training set is the undersampled one; validation and testing sets keep their original class ratio
train_data_cwt_females = create_time_frequency_images(X_resampled_train_females)
test_data_cwt_females = create_time_frequency_images(X_test_females)
val_data_cwt_females = create_time_frequency_images(X_val_females)

In [None]:
def get_signal(X_resampled_train, ii=0, lead=0, n_leads=12, n_samples=600):
    """
    Extracts one lead (the first one by default) out of the 12-lead ECG recording for a certain patient

    Parameters
    ----------
    X_resampled_train: list
        A list of the ECG waveforms; each record holds n_leads * n_samples values, lead after lead
    ii: integer
        The index of the element in the X_resampled_train corresponding to a certain patient
    lead: integer
        Index of the lead to extract (0 = first lead)
    n_leads: integer
        Number of leads per record
    n_samples: integer
        Number of samples per lead

    Returns
    -------
    numpy.ndarray:
        A copy of the selected lead (n_samples values) of the ECG recording of that patient
    """
    records = np.asarray(X_resampled_train).reshape(len(X_resampled_train), n_leads, n_samples)
    # Copy only the requested lead instead of first building the lead array for every patient
    return np.array(records[ii, lead, :])
    
def perform_CWT_on_one_lead(X_resampled_train, ii=0, dt=0.002, scales=None, wavelet='morl'):
    """
    Performs a CWT on the first lead of the ECG recording corresponding to a certain patient

    Parameters
    ----------
    X_resampled_train: list
        A list of the ECG waveforms from the training dataset
    ii: integer
        The index of the element in the X_resampled_train corresponding to a certain patient
    dt: number
        The time step (sampling period) in seconds; 1200 milliseconds / 600 measurements =
        2 milliseconds per measurement -> 0.002 seconds
    scales: array-like or None
        The scales for the mother wavelet; None means 1, 2, ..., 199
    wavelet: string
        The name of the mother wavelet (default 'morl' - Morlet)

    Returns
    -------
    numpy.ndarray:
        The CWT coefficients (one row per scale) of the first lead of that patient's recording
    number:
        The starting time
    number:
        The number of measurements in one lead
    number:
        The time step
    array-like:
        The scales for the mother wavelet that were used to perform the CWT
    """
    lead_signal = get_signal(X_resampled_train, ii=ii)

    # number of measurements in the lead (600 for the records used here)
    N = lead_signal.shape[0]

    # starting time
    t0 = 0

    if scales is None:
        # same range as the module-level scales used for the training images
        scales = np.arange(1, 200)
    coefs, _ = pywt.cwt(lead_signal, scales, wavelet, dt)
    return coefs, t0, N, dt, scales

In [None]:
def plot_pic(coefs, ax, t0, N, dt, scales, max_c, min_c):
    """
    Draws the CWT coefficients of one ECG lead as a scalogram on the given axis and returns the image

    The x axis spans the time of the first to the last sample, the y axis the first to the last
    scale; colours use the seismic colormap clipped to [min_c, max_c].

    Parameters
    ----------
    coefs: list
        The CWT coefficients of the ECG lead (one row per scale)
    ax: object
        The matplotlib axis on which to draw the image
    t0: number
        The starting time
    N: number
        The number of measurements in one lead
    dt: number
        The time step
    scales: list
        The scales for the mother wavelet that were used to perform the CWT
    max_c: number
        The largest value of the colour range
    min_c: number
        The lowest value of the colour range

    Returns
    -------
    object:
        The image (AxesImage) created for the scalogram
    """
    t_end = t0 + (N - 1) * dt
    extent = [t0, t_end, scales[0], scales[-1]]
    return ax.imshow(coefs, extent=extent, cmap=plt.cm.seismic, aspect="auto", vmin=min_c, vmax=max_c)

# plot two scalograms next to each other to see the differences
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
titles = ['The scalogram of the healthy patient\n', 'The scalogram of the STEMI-diagnosed patient\n']

# person with the index 0 is a healthy individual, with the index 605 - AMI-diagnosed patient
# NOTE(review): the indices refer to the order of X_resampled_train produced by the undersampling
# cell; confirm they still point to the intended patients after re-running it
patients = [0, 605]

# Perform CWT on the healthy individual and on the AMI-diagnosed patient
coeffs = []
for patient in patients:
    coefs, t0, N, dt, scales = perform_CWT_on_one_lead(X_resampled_train, ii=patient)
    coeffs.append(coefs)

# The seismic colormap is diverging, so centre it on zero with symmetric limits shared by both
# plots. Using the smallest absolute coefficient as the lower limit would clip every negative
# coefficient to the bottom colour of the map.
max_c = max(np.abs(c).max() for c in coeffs)
min_c = -max_c

images = []
# creating pictures for both CWT-preprocessed leads
for ax, coefs, title in zip(axs, coeffs, titles):
    im = plot_pic(coefs, ax, t0, N, dt, scales, max_c, min_c)
    images.append(im)
    ax.set_title(title, fontsize=18)
    ax.set_ylabel('Scales', fontsize=18)
    ax.set_xlabel('Time (seconds)', fontsize=18)
    ax.tick_params(axis='x', labelsize=14)
    ax.tick_params(axis='y', labelsize=14)
    # dotted white guides at 0.5 s and 0.8 s
    ax.axvline(x=0.5, color='white', linestyle=':', linewidth=2)
    ax.axvline(x=0.8, color='white', linestyle=':', linewidth=2)

# one colorbar per panel, each tied to its own image (both images share the same colour limits)
for index in (1, 0):
    cbar = fig.colorbar(images[index], ax=axs[index])
    cbar.ax.set_ylabel('Coefficients', rotation=-90, va="bottom", fontsize=18)
    cbar.ax.tick_params(labelsize=14)

# save the picture to use in the thesis
plt.savefig('../../images/spectograms', dpi = 800, bbox_inches = 'tight')
plt.show()

In [None]:
# AMI, males, 50 epochs, clustercentroids, complex_model, input - time frequency images
pred_test_males, pred_train_males, history_cwt = learn_and_test(
    training_set_X=train_data_cwt, training_set_y=y_resampled_train,
    val_X=val_data_cwt, val_y=y_val,
    testing_set_X=test_data_cwt, testing_set_y=y_test,
    get_model_func=get_complex_model, epochs=50,
    image_shape=(199, 599, 1), num_classes=2, save_confision_matrix=False,
)

In [None]:
# AMI, females, 50 epochs, clustercentroids, complex_model, input - time frequency images
pred_test_females, pred_train_females, history_cwt_females = learn_and_test(
    training_set_X=train_data_cwt_females, training_set_y=y_resampled_train_females,
    val_X=val_data_cwt_females, val_y=y_val_females,
    testing_set_X=test_data_cwt_females, testing_set_y=y_test_females,
    get_model_func=get_complex_model, epochs=50,
    image_shape=(199, 599, 1), num_classes=2, save_confision_matrix=False,
)

In [None]:
# AMI, females, 10 epochs, clustercentroids, complex_model, input - time frequency images
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names keep the 50-epoch predictions used by the plotting cells intact.
pred_test_cwt_females_10, pred_train_cwt_females_10, history_cwt_females_10 = learn_and_test(train_data_cwt_females, y_resampled_train_females, val_data_cwt_females, y_val_females, test_data_cwt_females, y_test_females, get_complex_model, 10, image_shape = (199, 599, 1), num_classes=2, save_confision_matrix=False)

In [None]:
# AMI, females, 7 epochs, clustercentroids, complex_model, input - time frequency images
# learn_and_test takes the validation set before the testing set and returns three values; the
# epoch-suffixed names keep the 50-epoch predictions used by the plotting cells intact.
pred_test_cwt_females_7, pred_train_cwt_females_7, history_cwt_females_7 = learn_and_test(train_data_cwt_females, y_resampled_train_females, val_data_cwt_females, y_val_females, test_data_cwt_females, y_test_females, get_complex_model, 7, image_shape = (199, 599, 1), num_classes=2, save_confision_matrix=False)

In [None]:
def get_avg_roc_10splits(get_specific_model_function, X_train, y_train, X_test, y_test, image_shape, batch_size, epochs, num_classes=2):
    """
    Trains a freshly built model 10 times on a fixed train/test split and averages the results.

    Note that this is not k-fold cross-validation: the split passed in is reused on every run.
    Each run builds a new model, shuffles the order of the training and testing samples,
    trains the model and predicts on both datasets. For every run it prints the Keras metric
    names and evaluation values, the accuracy scores, the confusion matrices and the
    classification reports for both the training and the testing datasets.

    Parameters
    ----------
    get_specific_model_function: function
        Called as ``get_specific_model_function(image_shape=image_shape, num_classes=num_classes)``
        and expected to return a compiled model. A new model is requested on every run because
        ``.fit()`` continues training from the current weights, and every run must start from scratch.
    X_train: array-like
        The inputs (e.g. ECG waveforms) of the training dataset; every sample must be
        reshapeable to ``image_shape``.
    y_train: array-like
        The integer class labels of the training dataset.
    X_test: array-like
        The inputs of the testing dataset; every sample must be reshapeable to ``image_shape``.
    y_test: array-like
        The integer class labels of the testing dataset.
    image_shape: tuple
        The per-sample input shape of the model, e.g. ``(height, width, channels)``.
    batch_size: int
        The number of samples per batch used for fitting, predicting and evaluating.
    epochs: int
        The number of epochs for training the model.
    num_classes: int
        The number of classes for classification.

    Returns
    -------
    float
        The mean testing ROC AUC over the 10 runs, computed from the predicted class scores.
    float
        The mean duration of one run in MINUTES (building the model, shuffling, training and
        predicting on both datasets; the printed evaluation afterwards is not included).
    float
        The mean training accuracy over the 10 runs.
    float
        The mean testing accuracy over the 10 runs.
    """
    n_runs = 10
    # Accept plain lists as well as arrays: the samples are reshaped with ``.reshape`` below.
    X_train, y_train = np.asarray(X_train), np.asarray(y_train)
    X_test, y_test = np.asarray(X_test), np.asarray(y_test)

    roc_auc_list = []
    times = []
    train_acc = []
    test_acc = []
    for _ in range(n_runs):
        st = time.time()

        # A new model per run, so no weights carry over from the previous run.
        model = get_specific_model_function(image_shape=image_shape, num_classes=num_classes)
        X_shuffled_train, y_shuffled_train = shuffle(X_train, y_train)
        X_shuffled_test, y_shuffled_test = shuffle(X_test, y_test)

        test_labels = tf.keras.utils.to_categorical(y_shuffled_test, num_classes)
        train_labels = tf.keras.utils.to_categorical(y_shuffled_train, num_classes)

        train_images = X_shuffled_train.reshape((X_shuffled_train.shape[0],) + tuple(image_shape))
        test_images = X_shuffled_test.reshape((X_shuffled_test.shape[0],) + tuple(image_shape))

        model.fit(train_images, train_labels, batch_size=batch_size, epochs=epochs, verbose=0)
        predictions_test = model.predict(test_images, batch_size=batch_size, verbose=0)
        predictions_train = model.predict(train_images, batch_size=batch_size, verbose=0)
        # Score-based AUC: one column of predicted scores per class, the same shape as the
        # one-hot labels. Hard labels reshaped to (n, 1) do not match the (n, num_classes)
        # one-hot labels and are rejected by roc_auc_score.
        roc_auc_list.append(roc_auc_score(test_labels, predictions_test))
        times.append(time.time() - st)

        print(model.metrics_names)
        print('Test metrics values')
        print(model.evaluate(test_images, test_labels, batch_size=batch_size, verbose=0))
        print('Train metrics values')
        print(model.evaluate(train_images, train_labels, batch_size=batch_size, verbose=0))

        y_pred_test = reverse_one_hot(predictions_test)
        y_pred_train = reverse_one_hot(predictions_train)

        # The predictions follow the SHUFFLED sample order, so they are compared with the
        # shuffled labels; comparing with the original y_test / y_train would misalign them.
        acc_score_test = accuracy_score(y_shuffled_test, y_pred_test)
        print("Evaluation accuracy score (full, test) = ", acc_score_test)
        acc_score_train = accuracy_score(y_shuffled_train, y_pred_train)
        print("Evaluation accuracy score (full, train) = ", acc_score_train)
        train_acc.append(acc_score_train)
        test_acc.append(acc_score_test)

        print("Confusion matrix for the testing dataset")
        print(confusion_matrix(y_shuffled_test, y_pred_test))

        print("Confusion matrix for the training dataset")
        print(confusion_matrix(y_shuffled_train, y_pred_train))

        print("Classification report for the testing dataset")
        print(classification_report(y_shuffled_test, y_pred_test))
        print("Classification report for the training dataset")
        print(classification_report(y_shuffled_train, y_pred_train))

    return np.mean(roc_auc_list), np.mean(times) / 60, np.mean(train_acc), np.mean(test_acc)

In [None]:
# Repeated training runs (10 per setting) of the model on the males, batch size = 100,
# for each of the epoch counts below.
# train_test_split returns (X_train, X_test, y_train, y_test); the split is fixed by
# random_state=0, so it is computed once outside the loop. Distinct names are used so the
# global X_test / y_test of the earlier cells are not overwritten.
waves_train_males, waves_test_males, labels_train_males, labels_test_males = train_test_split(
    waves_males, labels_males, train_size=0.75, random_state=0, stratify=labels_males)
epochs = [100, 150, 200, 250, 300]
for e in epochs:
    mean_roc_auc, mean_time_min, mean_train_acc, mean_test_acc = get_avg_roc_10splits(
        get_model, waves_train_males, labels_train_males, waves_test_males, labels_test_males,
        (60, 120, 1), 100, e, num_classes=2)
    # get_avg_roc_10splits reports the mean run time in minutes.
    print('mean roc auc', mean_roc_auc, 'mean time in min', mean_time_min,
          'mean train acc', mean_train_acc, 'mean test acc', mean_test_acc, 'epochs', e)


In [None]:
# Repeated training runs (10 per setting) of the model on the males, batch size = 64,
# for each of the epoch counts below.
# train_test_split returns (X_train, X_test, y_train, y_test); the split is fixed by
# random_state=0, so it is computed once outside the loop. Distinct names are used so the
# global X_test / y_test of the earlier cells are not overwritten.
waves_train_males, waves_test_males, labels_train_males, labels_test_males = train_test_split(
    waves_males, labels_males, train_size=0.75, random_state=0, stratify=labels_males)
epochs = [100, 150, 200, 250, 300]
for e in epochs:
    mean_roc_auc, mean_time_min, mean_train_acc, mean_test_acc = get_avg_roc_10splits(
        get_model, waves_train_males, labels_train_males, waves_test_males, labels_test_males,
        (60, 120, 1), 64, e, num_classes=2)
    # get_avg_roc_10splits reports the mean run time in minutes.
    print('mean roc auc', mean_roc_auc, 'mean time in min', mean_time_min,
          'mean train acc', mean_train_acc, 'mean test acc', mean_test_acc, 'epochs', e)


In [None]:
# Repeated training runs (10 per setting) of the model on the females, batch size = 100,
# for each of the epoch counts below.
# train_test_split returns (X_train, X_test, y_train, y_test); the split is fixed by
# random_state=0, so it is computed once outside the loop. The test labels of THIS split are
# passed on (previously they went to a misspelled name and a stale global was used instead).
waves_train_females, waves_test_females, labels_train_females, labels_test_females = train_test_split(
    waves_females, labels_females, train_size=0.75, random_state=0, stratify=labels_females)
epochs = [100, 150, 200, 250, 300]
for e in epochs:
    mean_roc_auc, mean_time_min, mean_train_acc, mean_test_acc = get_avg_roc_10splits(
        get_model, waves_train_females, labels_train_females, waves_test_females, labels_test_females,
        (60, 120, 1), 100, e, num_classes=2)
    # get_avg_roc_10splits reports the mean run time in minutes.
    print('mean roc auc', mean_roc_auc, 'mean time in min', mean_time_min,
          'mean train acc', mean_train_acc, 'mean test acc', mean_test_acc, 'epochs', e)


In [None]:
# Repeated training runs (10 per setting) of the model on the females, batch size = 64,
# for each of the epoch counts below.
# train_test_split returns (X_train, X_test, y_train, y_test); the split is fixed by
# random_state=0, so it is computed once outside the loop. The test labels of THIS split are
# passed on (previously they went to a misspelled name and a stale global was used instead).
waves_train_females, waves_test_females, labels_train_females, labels_test_females = train_test_split(
    waves_females, labels_females, train_size=0.75, random_state=0, stratify=labels_females)
epochs = [100, 150, 200, 250, 300]
for e in epochs:
    mean_roc_auc, mean_time_min, mean_train_acc, mean_test_acc = get_avg_roc_10splits(
        get_model, waves_train_females, labels_train_females, waves_test_females, labels_test_females,
        (60, 120, 1), 64, e, num_classes=2)
    # get_avg_roc_10splits reports the mean run time in minutes.
    print('mean roc auc', mean_roc_auc, 'mean time in min', mean_time_min,
          'mean train acc', mean_train_acc, 'mean test acc', mean_test_acc, 'epochs', e)
