# <center> Age estimation using Convolutional Neural Network on 12-lead ECG </center>


<div style="width:100%;text-align: center;"> <img align=middle src="https://raw.githubusercontent.com/Bsingstad/DL-images/main/ECGage.png" alt="ECG age estimation" style="height:500px;margin-top:3rem;"> </div>

----------

[Bjørn-Jostein Singstad](https://www.kaggle.com/bjoernjostein) - February 2022

Other relevant Notebooks: 
* [EDA Shaoxing and Ningbo ECG data](https://www.kaggle.com/bjoernjostein/eda-shaoxing-and-ningbo-ecg-data/)
* [Physionet Challenge 2020](https://www.kaggle.com/bjoernjostein/physionet-challenge-2020)

----------

### AI can be used to detect your physiological/cardiovascular age only by looking at your ECG

For more than 100 years, the ECG has been a widely used tool for heart diagnostics, starting with **Willem Einthoven's** string galvanometer.

<div style="width:100%;text-align: center;"> <img align=middle src="https://upload.wikimedia.org/wikipedia/commons/1/1c/Willem_Einthoven_ECG.jpg" alt="Willem Einthoven's string galvanometer ECG" style="height:500px;margin-top:3rem;"> </div>


Since the 1960s we have also been able to store ECGs in a digital format, allowing them to be interpreted automatically by algorithms. Despite this, conventional rule-based algorithms have shown limitations in interpreting ECGs, while the new era of AI holds promise. Not only can AI-based ECG interpretation tools reveal various types of diagnoses from the ECG; studies have also shown that they can detect other biological factors such as age and sex ([Attia et al 2019](https://www.ahajournals.org/doi/full/10.1161/CIRCEP.119.007284)).

In this notebook we use an [InceptionTime](https://link.springer.com/article/10.1007/s10618-020-00710-y)-type 1-dimensional Convolutional Neural Network (the AlexNet for time series) to train and validate an age-prediction (regression) model based on 8 datasets:
1. Chapman-Shaoxing Database
2. China Physiological Signal Challenge Database
3. China Physiological Signal Challenge Database - extra
4. The Georgia 12-lead ECG Challenge Database
5. Ningbo Database
6. PTB Diagnostic Database
7. PTB-XL Database
8. St. Petersburg INCART Database

These databases consist of a total of **88253** ECGs

In [None]:
import pandas as pd
from scipy import stats
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow_addons as tfa
import tensorflow as tf
import tqdm
from scipy import signal
from tensorflow import keras
from keras.utils import plot_model
from keras.preprocessing.sequence import pad_sequences
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score
from scipy.io import loadmat
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
%load_ext autoreload
%autoreload
%reload_ext autoreload
sns.set_style("dark")

In [None]:
print("Collecting labels, ECG filenames and metadata...")
# import_key_data returns five sequences that the rest of the notebook indexes
# in parallel; turn each one into a NumPy array so they can be filtered with
# the same index arrays.
(gender, age, labels, ecg_len, ecg_filenames) = [
    np.asarray(seq) for seq in import_key_data("../input/")
]
print(f"Total number of patients found: {len(age)}")

In [None]:
#####################################################
print("remove all ECGs not equal to 10 seconds")
age, gender,ecg_filenames, labels = only_ten_sec(ecg_len, age, gender, ecg_filenames, labels)

In [None]:
# Drop every recording whose age is "NaN" or whose gender is "NaN" or
# "Unknown". The drop indices are computed ONCE from the unmodified arrays and
# then applied to all four arrays, which keeps them index-aligned. Deleting
# array by array with indices recomputed from a different array that has not
# been shortened yet would remove the wrong elements from the others.
gender_values = np.asarray(gender)
drop_mask = (
    (np.asarray(age) == "NaN")
    | (gender_values == "NaN")
    | (gender_values == "Unknown")
)
drop_idx = np.where(drop_mask)
ecg_filenames = np.delete(ecg_filenames, drop_idx)
gender = np.delete(gender, drop_idx)
labels = np.delete(labels, drop_idx)
age = np.delete(age, drop_idx)

In [None]:
# Convert the raw age and gender fields with the project's clean-up helpers
# (age first, then gender, as before).
age, gender = clean_up_age_data(age), clean_up_gender_data(gender)

remaining = len(age)
print("Total number of patients left after data cleaning: {}".format(remaining))

In [None]:
# Remove recordings whose cleaned gender value is 0. The positions are found
# once and the same positions are removed from every parallel array.
zero_gender_idx = np.where(gender == 0)
age, ecg_filenames, labels, gender = [
    np.delete(arr, zero_gender_idx) for arr in (age, ecg_filenames, labels, gender)
]

In [None]:
print("Total number of patients left after data cleaning: " + str(len(age)))

In [None]:
# Split the cleaned recordings into cross-validation folds. The training cell
# reads folds[i][0] as training indices and folds[i][1] as test indices.
# NOTE(review): presumably split_data stratifies on age and/or gender — confirm.
folds = split_data(age, gender)

In [None]:
# Stop once the validation loss has not improved for 3 consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=3,
    verbose=1,
)

# Divide the learning rate by 10 when the validation loss plateaus for 2
# epochs, then wait 2 epochs before monitoring again.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    mode='min',
    factor=0.1,
    patience=2,
    cooldown=2,
    min_delta=0.0001,
    min_lr=0,
    verbose=1,
)

In [None]:
samp_freq = 100  # sampling frequency (Hz) the ECGs are resampled to
time = 10  # recording length in seconds
num_leads = 12
batchsize = 16
epoch = 20
mae_score_total = []  # one mean absolute error per cross-validation fold

print("Training model...")
for i, fold in enumerate(folds):
    train_ind = fold[0]
    test_ind = fold[1]

    # A fresh model is built for every fold.
    model = attia_network_age(samp_freq, time, num_leads)
    # Keras expects integer step counts; round up so the final partial batch
    # of each epoch is still consumed.
    train_steps = int(np.ceil(len(train_ind) / batchsize))
    val_steps = int(np.ceil(len(test_ind) / batchsize))
    model.fit(
        x=shuffle_batch_generator_age(
            batch_size=batchsize,
            gen_x=generate_X_age(ecg_filenames[train_ind]),
            gen_y=generate_y_age(age[train_ind]),
            num_leads=num_leads,
        ),
        epochs=epoch,
        steps_per_epoch=train_steps,
        validation_data=shuffle_batch_generator_age(
            batch_size=batchsize,
            gen_x=generate_X_age(ecg_filenames[test_ind]),
            gen_y=generate_y_age(age[test_ind]),
            num_leads=num_leads,
        ),
        validation_freq=1,
        validation_steps=val_steps,
        verbose=1,
        callbacks=[reduce_lr],
    )

    # Evaluate recording by recording on the held-out fold. These lists are
    # reset for every fold, so after the loop true_age/pred_age hold only the
    # last fold's results (the DataFrame cell builds on them).
    # (No list named f1_score here: that would clobber sklearn's f1_score.)
    mse_score = []
    mae_score = []
    pred_age = []
    true_age = []
    target_len = samp_freq * time
    for j in tqdm.tqdm(test_ind):
        data, header_data = load_challenge_data(ecg_filenames[j])
        # Fields 2 and 3 of the first header line are used as the recording's
        # sampling frequency and number of samples.
        header_fields = header_data[0].split(" ")
        rec_freq = int(header_fields[2])
        if rec_freq != samp_freq:
            rec_len = int(header_fields[3])
            new_len = int((rec_len / rec_freq) * samp_freq)
            data_new = np.ones([num_leads, new_len])
            for k, lead in enumerate(data):
                data_new[k] = signal.resample(lead, new_len)
            data = data_new
            # NOTE(review): pad_sequences defaults to dtype='int32', so the
            # resampled signal is truncated to integers here, whereas
            # recordings already at samp_freq keep their original dtype.
            # Left as-is in case the training generator does the same — TODO
            # confirm and pass dtype=data.dtype if training keeps floats.
            data = pad_sequences(data, maxlen=target_len, truncating='post', padding="post")
        # (leads, samples) -> (samples, leads)
        data = np.moveaxis(data, 0, -1)
        y_true = age[j]
        true_age.append(y_true)
        y_hat = model.predict(np.expand_dims(data, axis=0))
        pred_age.append(y_hat)
        mse_score.append(mean_squared_error(np.expand_dims(y_true, axis=0), y_hat.ravel()))
        mae_score.append(mean_absolute_error(np.expand_dims(y_true, axis=0), y_hat.ravel()))
    mse_score = np.asarray(mse_score).mean()
    mae_score = np.asarray(mae_score).mean()
    print("MAE score {}".format(mae_score))
    print("MSE score {}".format(mse_score))
    mae_score_total.append(mae_score)



In [None]:
# Summary over folds: mean and (population, ddof=0) standard deviation of the
# per-fold MAE.
fold_maes = np.asarray(mae_score_total)
mean_mae = fold_maes.mean()
std_mae = fold_maes.std()
print("mean MAE score {}".format(mean_mae))
print("std MAE score {}".format(std_mae))

In [None]:
# One row per evaluated recording (last fold): true age, predicted age and
# the absolute difference between them.
true_flat = np.asarray(true_age).ravel()
pred_flat = np.asarray(pred_age).ravel()
df = pd.DataFrame(data={
    "True age": true_flat,
    "Predicted age": pred_flat,
    "Error": abs(true_flat - pred_flat),
})

## Analysis:


### Predicted vs actual age - Numbers of true and predicted ages 

In [None]:
# Overlaid 100-bin histograms of predicted and true ages.
hist_fig, hist_ax = plt.subplots(figsize=(30,15))
for column in ("Predicted age", "True age"):
    hist_ax.hist(df[column], bins=100, label=column, alpha=0.6)
hist_ax.legend()
hist_ax.set_xlim(0,100)
hist_ax.set_xlabel("Age")
hist_ax.set_ylabel("Number of patients")
plt.show()

### Comparing the predicted age vs the actual age. The red line shows the correct age, while the green line shows the best linear fit of the predicted age

In [None]:
# Least-squares fit of predicted age against true age over all rows of df.
# slope, intercept and reg_line stay defined at notebook level and are reused
# by later plotting cells.
slope, intercept, r_value, p_value, std_err = stats.linregress(df["True age"],df["Predicted age"])
# Fitted line evaluated at ages 0..100 (plotted against its array index).
reg_line = (np.arange(101)*slope) + intercept
plt.figure(figsize=(30,15))
sns.set(font_scale=2)
sns.scatterplot(data=df, x="True age", y="Predicted age")
# NOTE(review): sns.set changes global plotting settings; this call runs after
# the scatter is drawn and persists until the next sns.set call.
sns.set(font_scale=5)
# Identity line y = x for ages 0..100: where a perfect prediction would lie.
plt.plot(np.arange(101),'red', label="1x + 0")
plt.plot(reg_line,'green', label = "{}x + {}".format(round(slope,2), round(intercept,2)))
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.show()

### Comparing the predicted age vs the actual age as a box plot. This shows for which ages the model's predictions are most accurate and for which they are most uncertain. The red line shows the correct age, while the green line shows the best linear fit of the predicted age

In [None]:
# Box plot of the predicted ages for every distinct true age, overlaid with the
# identity line and the overall regression line (slope, intercept and
# reg_line come from the regression cell above).
plt.figure(figsize=(30,15))
sns.set(font_scale=2)
sns.boxplot(data=df, x="True age", y="Predicted age")
# NOTE(review): sns.boxplot treats x as categorical and places boxes at
# positions 0..k-1, whereas these lines are drawn against x = 0..100; they
# only line up with the boxes if true ages start at 0 with no gaps — confirm.
plt.plot(np.arange(101),'red', label="1x + 0")
plt.plot(reg_line,'green', label = "{}x + {}".format(round(slope,2), round(intercept,2)))
plt.xticks(fontsize=12)
plt.legend()
plt.xlim(0,100)
plt.show()

### Total absolute error between real and predicted age for each year

In [None]:
# Sum of the absolute errors of all patients that share the same true age.
error_per_age = df.groupby("True age").sum().reset_index()
plt.figure(figsize=(30,15))
sns.barplot(data=error_per_age, x="True age", y="Error")
plt.ylabel("Total absolute error pr age")
plt.xticks(fontsize=12)
plt.show()

### Mean absolute error between real and predicted age for each year

In [None]:
# Mean absolute error per true age = total absolute error / number of patients.
by_age = df.groupby("True age")
df_err = by_age.sum().reset_index()
# Both aggregations come from the same groupby, so they share the same sorted
# age order. Assign the counts positionally: the count Series is indexed by
# age while df_err has a 0..k-1 RangeIndex, and assigning the Series directly
# would align on index labels and attach the wrong counts (mostly NaN).
df_err["count"] = by_age["Error"].count().to_numpy()
df_err["Mean Error"] = df_err["Error"]/df_err["count"]

plt.figure(figsize=(30,15))
sns.barplot(x="True age", y="Mean Error",data=df_err)
plt.ylabel("Mean absolute error pr age")
plt.xticks(fontsize=12)
plt.show()

In [None]:
# SNOMED CT codes of the 30 diagnoses analysed below; entry k is the code of
# diagnoses[k].
snomed_ct_codes = np.array([
    "164889003", "164890007", "6374002", "426627000", "733534002",
    "713427006", "270492004", "713426002", "39732003", "445118002",
    "164909002", "251146004", "698252002", "426783006", "284470004",
    "10370003", "365413008", "427172004", "164947007", "111975006",
    "164917005", "47665007", "59118001", "427393009", "426177001",
    "427084000", "63593006", "164934002", "59931005", "17338001",
])


In [None]:
# Human-readable names of the diagnoses; entry k names snomed_ct_codes[k].
diagnoses = np.array([
    "atrial fibrillation",
    "atrial flutter",
    "bundle branch block",
    "bradycardia",
    "complete left bundle branch block",
    "complete right bundle branch block",
    "1st degree av block",
    "incomplete right bundle branch block",
    "left axis deviation",
    "left anterior fascicular block",
    "left bundle branch block",
    "low qrs voltages",
    "nonspecific intraventricular conduction disorder",
    "sinus rhythm",
    "premature atrial contraction",
    "pacing rhythm",
    "poor R wave Progression",
    "premature ventricular contractions",
    "prolonged pr interval",
    "prolonged qt interval",
    "qwave abnormal",
    "right axis deviation",
    "right bundle branch block",
    "sinus arrhythmia",
    "sinus bradycardia",
    "sinus tachycardia",
    "supraventricular premature beats",
    "t wave abnormal",
    "t wave inversion",
    "ventricular premature beats",
])


In [None]:
# File paths of the ECGs in the test split of the final training fold.
all_ecg_filenames = np.asarray(ecg_filenames)
val_ecg_filenames = all_ecg_filenames[test_ind]

In [None]:
# Gather every label code found in the headers of the validation recordings
# and sort them (numerically when every code is an integer string).
classes = set()
for ecg_file in tqdm.tqdm(val_ecg_filenames):
    classes.update(get_labels(load_header(ecg_file.replace('.mat','.hea'))))
all_numeric = all(is_integer(code) for code in classes)
classes = sorted(classes, key=(lambda code: int(code)) if all_numeric else None)
num_classes = len(classes)

In [None]:
# Keep only the non-empty label codes that are among the scored SNOMED CT
# codes, in sorted order.
scored_code_set = {str(code) for code in snomed_ct_codes}
scored_classes = sorted(
    code for code in classes if code != '' and code in scored_code_set
)

In [None]:
# Translate each scored SNOMED CT code into its diagnosis name, in the order of
# scored_classes, so that class_names[k] names scored_classes[k].
# NOTE(review): the loop variables i and j remain defined at notebook level
# after this cell runs.
class_names = []
for j in range(len(scored_classes)):
    for i in range(len(snomed_ct_codes)):
        if (str(snomed_ct_codes[i]) == scored_classes[j]):
            class_names.append(diagnoses[i])

In [None]:
# Multi-hot matrix of scored diagnoses for the validation recordings:
# row i belongs to val_ecg_filenames[i], column c to scored_classes[c].
num_recordings = len(val_ecg_filenames)
num_classes = len(scored_classes)
# Plain bool: the np.bool alias was removed in NumPy 1.24.
val_labels = np.zeros((num_recordings, num_classes), dtype=bool)
class_index = {code: c for c, code in enumerate(scored_classes)}

for i, ecg_file in enumerate(val_ecg_filenames):
    current_labels = get_labels(load_header(ecg_file.replace('.mat','.hea')))
    for lab in current_labels:
        # Labels outside the scored set have no column and are skipped. The
        # assignment must stay inside this branch; otherwise a stale column
        # index left over from an earlier label would be set.
        if lab in class_index:
            val_labels[i, class_index[lab]] = 1
val_labels = val_labels * 1  # bool -> int 0/1

### Comparing the predicted age vs the actual age for 30 different ECG diagnoses. The red line shows the correct age, while the green line shows the best linear fit of the predicted age

In [None]:
# Scatter predicted vs. true age separately for every scored diagnosis on a
# 6 x 5 grid. A recording appears in each panel whose diagnosis it carries;
# val_labels rows and df rows are aligned because both follow test_ind order.
n_rows, n_cols = 6, 5
fig, ax = plt.subplots(n_rows, n_cols, sharex=True, sharey=True)
fig.set_figheight(46)
fig.set_figwidth(46)
ages = np.arange(101)
# One panel per scored class, capped by the grid size: fewer than 30 scored
# classes may be present, and indexing past them would raise IndexError.
n_panels = min(len(class_names), val_labels.shape[1], n_rows * n_cols)
for i in range(n_panels):
    panel = ax[i // n_cols, i % n_cols]
    df_temp = df.iloc[np.where(val_labels[:, i] == 1)[0]]
    panel.set_xlabel("True age")
    panel.set_ylabel("Predicted age")
    if df_temp.empty:
        panel.set_title(class_names[i] + ", no recordings")
        continue
    MAE = np.abs(df_temp["True age"] - df_temp["Predicted age"]).mean()
    panel.set_title(class_names[i] + ", MAE= {}".format(round(MAE, 2)))
    panel.scatter(df_temp["True age"], df_temp["Predicted age"])
    panel.plot(ages, 'red', label="1x + 0")
    # linregress needs at least two distinct true ages.
    if df_temp["True age"].nunique() > 1:
        slope, intercept, r_value, p_value, std_err = stats.linregress(
            df_temp["True age"], df_temp["Predicted age"]
        )
        # Draw this diagnosis' own fit so the line matches its legend label.
        panel.plot(ages * slope + intercept, 'green',
                   label="{}x + {}".format(round(slope, 2), round(intercept, 2)))
    panel.legend()
# Hide grid cells that have no diagnosis assigned.
for k in range(n_panels, n_rows * n_cols):
    ax.flat[k].set_visible(False)

