## LAD-FFR-Estimation
### Non-Invasive Fractional Flow Reserve Estimation using Deep Learning on Intermediate Left Anterior Descending Coronary Artery Lesion Angiography Images.
#### _ M. Aria

In [None]:
import datetime
total_start = datetime.datetime.now()

import os, random, math, itertools
import numpy as np, pandas as pd

SEED = 42
os.environ['PYTHONHASHSEED']=str(SEED)
np.random.seed(SEED)
random.seed(SEED)

import tensorflow as tf
tf.random.set_seed(SEED)
import tensorflow.keras
import tensorflow_addons as tfa
from tensorflow.keras.layers import Dense, BatchNormalization, Dropout, Input
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.layers.experimental import preprocessing
import tensorflow.keras.backend as K

from tensorflow.keras.applications import DenseNet169
NETWORK = DenseNet169

from kaggle_datasets import KaggleDatasets

from sklearn.model_selection import train_test_split
from sklearn import metrics

import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
%matplotlib inline

try:
    tpu = None
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except:
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)
    
print("Tensorflow version ", tf.__version__)

In [None]:
from tensorflow.keras import mixed_precision
# This cell does not change the precision policy; it only reports the active one
# (the previous message claimed mixed precision was enabled without setting it).
# To enable it, call e.g. mixed_precision.set_global_policy('mixed_bfloat16') before
# building the model and keep the softmax output layer in float32.
print('Mixed precision policy:', mixed_precision.global_policy().name)

# Turn on XLA JIT compilation for TensorFlow graphs.
tf.config.optimizer.set_jit(True)
print('Accelerated Linear Algebra enabled')

In [None]:
from kaggle_secrets import UserSecretsClient

# Hand the notebook's Google Cloud credential to TensorFlow so that the GCS
# reads of GCS_PATH below can authenticate.
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [None]:
from tensorflow.keras.callbacks import Callback
class CosineAnnealingScheduler(Callback):
    """Keras callback applying cosine-annealed learning rates per epoch.

    At the start of epoch ``e`` the optimizer learning rate becomes
    ``eta_min + (eta_max - eta_min) * (1 + cos(pi * e / T_max)) / 2``; the
    value in effect is copied into the epoch logs under ``'lr'``.

    Args:
        T_max: number of epochs over which the rate decays from eta_max to eta_min.
        eta_max: learning rate at epoch 0.
        eta_min: learning rate reached at epoch T_max.
        verbose: print the new rate each epoch when > 0.
    """

    def __init__(self, T_max, eta_max, eta_min=0, verbose=0):
        super().__init__()
        self.T_max = T_max
        self.eta_max = eta_max
        self.eta_min = eta_min
        self.verbose = verbose

    def _annealed_lr(self, epoch):
        # Same expression (and evaluation order) as the original inline formula.
        return self.eta_min + (self.eta_max - self.eta_min) * (1 + math.cos(math.pi * epoch / self.T_max)) / 2

    def on_epoch_begin(self, epoch, logs=None):
        optimizer = self.model.optimizer
        if not hasattr(optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        new_lr = self._annealed_lr(epoch)
        K.set_value(optimizer.lr, new_lr)
        if self.verbose > 0:
            message = ('\nEpoch %05d: CosineAnnealingScheduler setting learning '
                       'rate to %s.' % (epoch + 1, new_lr))
            print(message)

    def on_epoch_end(self, epoch, logs=None):
        # An empty/None logs dict is replaced by a fresh one (as before).
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

dataset_id = 'lad-ffr'
GCS_PATH = KaggleDatasets().get_gcs_path(dataset_id)
# Global batch size: 128 examples per replica.
BATCH_SIZE = 128 * strategy.num_replicas_in_sync

CLASSES = ['FFRH', 'FFRL']
NUM_CLASSES = len(CLASSES)
IMAGE_SIZE = [380, 380]
# Derived from IMAGE_SIZE so the two cannot drift apart (RGB input).
input_shape = (*IMAGE_SIZE, 3)

LOSS = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.2)
METRICS = ['accuracy']

# Phase 1: train the classification head on the frozen backbone.
Epochs = 120
Early_Stop = 10
# `learning_rate` replaces the deprecated `lr` keyword. During this phase the
# CosineAnnealingScheduler in `Callbacks` overwrites the rate at every epoch start.
OPTIMIZER = tensorflow.keras.optimizers.Adam(learning_rate = 1e-2, decay = 1e-5)

# Phase 2: fine-tune the partially unfrozen backbone.
Fine_Tune_Epochs = 600
Fine_Tune_Early_Stop = 100
Fine_Tune_OPTIMIZER = tensorflow.keras.optimizers.Adam(learning_rate = 1e-4, decay = 1e-6)
Fine_Tune_filepath = "Best-Model-FT.h5"

Callbacks = [
    CosineAnnealingScheduler(Epochs, 1e-3, 1e-5),
    EarlyStopping(monitor='val_loss', patience=Early_Stop, mode='auto', min_delta=0.00001, verbose=2, restore_best_weights=True)]

FT_Callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=10, verbose=2, mode='min', min_delta=0.0001, cooldown=1, min_lr=1e-6),
    ModelCheckpoint(Fine_Tune_filepath, monitor='val_accuracy', verbose=2, save_best_only=True, save_weights_only=False, mode='max'),
    EarlyStopping(monitor='val_accuracy', patience=Fine_Tune_Early_Stop, mode='auto', min_delta=0.00001, verbose=2, restore_best_weights=True)]

In [None]:
# One sub-folder per class under Train/ and Test/. Shuffling uses the seeded
# `random` module, so the order is reproducible for a given glob listing.
filenames = tf.io.gfile.glob(f'{GCS_PATH}/Train/*/*')
random.shuffle(filenames)

test_filenames = tf.io.gfile.glob(f'{GCS_PATH}/Test/*/*')
random.shuffle(test_filenames)

In [None]:
# Number of training images per class. The class is the image's parent folder,
# so match the exact "/<class>/" path component rather than a bare substring.
# (The old loop also left a module-level `id` shadowing the builtin.)
CLASS_COUNT = {
    class_name: sum(1 for filename in filenames if f'/{class_name}/' in filename)
    for class_name in CLASSES
}

print("High FFR sample count : " + str(CLASS_COUNT['FFRH']))
print("Low FFR Sample count : " + str(CLASS_COUNT['FFRL']))

In [None]:
# Bar chart of the training class balance. The second count must be FFRL
# (it previously repeated the FFRH count).
data = {'Cases':['High FFR', 'Low FFR'],
        'Cases_count':[CLASS_COUNT['FFRH'], CLASS_COUNT['FFRL']]
       }

df = pd.DataFrame(data)

sns.set(style="darkgrid")
plt.figure(figsize=(10,8))
sns.barplot(x=df.index, y= df['Cases_count'].values)
plt.title('Number of samples', fontsize=14)
plt.xlabel('Case type', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.xticks(range(len(df.index)), ['High FFR', 'Low FFR'])
plt.show()

print(df)

In [None]:
# Wrap the shuffled path lists as datasets of file-path strings (train first).
train_list_ds, test_list_ds = (
    tf.data.Dataset.from_tensor_slices(paths) for paths in (filenames, test_filenames)
)

In [None]:
# Number of file paths in each split.
TRAIN_IMG_COUNT = tf.data.experimental.cardinality(train_list_ds).numpy()
Test_IMG_COUNT = tf.data.experimental.cardinality(test_list_ds).numpy()

print("Training images count: " + str(TRAIN_IMG_COUNT))
print("Testing images count: " + str(Test_IMG_COUNT))

In [None]:
def get_label(file_path):
    """Return the one-hot label of an image path.

    The class is the name of the image's parent directory
    (``.../<class>/<file>``). The result is a float32 vector of length
    NUM_CLASSES with 1.0 at that class's position in CLASSES. This is the
    format used by CategoricalCrossentropy and by the later ``argmax`` calls.

    Args:
        file_path: scalar string tensor holding the image path.
    """
    # GCS paths always use '/', independent of the host OS separator.
    parts = tf.strings.split(file_path, '/')
    # Comparing the scalar folder name with the CLASSES list broadcasts to a
    # bool vector (one-hot). The previous `int(...)` on that 2-element tensor
    # cannot work inside Dataset.map.
    return tf.cast(tf.equal(parts[-2], CLASSES), tf.float32)

In [None]:
def decode_img(img):
    """Decode PNG bytes into a float32 RGB image resized to IMAGE_SIZE.

    Histogram equalisation (tfa) is applied to the decoded uint8 image. The
    image is then converted to float32 in [0, 1] and resized.
    """
    decoded = tf.image.decode_png(img, channels=3)
    equalized = tfa.image.equalize(decoded)
    as_float = tf.image.convert_image_dtype(equalized, tf.float32)
    return tf.image.resize(as_float, IMAGE_SIZE)

In [None]:
def process_path(file_path):
    """Map an image path to an ``(image, label)`` pair for tf.data."""
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes), get_label(file_path)

In [None]:
# Decode images and derive labels in parallel.
train_ds, test_ds = (
    list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
    for list_ds in (train_list_ds, test_list_ds)
)

In [None]:
def prepare_for_training(ds, cache=True, shuffle=True):
    """Cache, shuffle, batch and prefetch a dataset of ``(image, label)`` pairs.

    Args:
        ds: unbatched tf.data.Dataset.
        cache: False to skip caching, True to cache in memory, or a string
            giving a cache file path.
        shuffle: shuffle with a 1000-element buffer before batching. Defaults
            to True, which keeps the previous behaviour.

    Returns:
        The dataset batched to BATCH_SIZE and prefetched.
    """
    if cache:
        ds = ds.cache(cache) if isinstance(cache, str) else ds.cache()

    if shuffle:
        ds = ds.shuffle(buffer_size=1000)
    ds = ds.batch(BATCH_SIZE)

    # Prefetching overlaps input work with computation whether or not the data
    # is cached. It was previously skipped when cache was False.
    return ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
# The training pipeline is cached in memory; the test pipeline is not.
train_ds = prepare_for_training(train_ds, cache=True)
test_ds = prepare_for_training(test_ds, cache=False)

In [None]:
# Random geometric and contrast augmentation, applied to training batches only.
# (Left disabled: tf.config.set_soft_device_placement(True))
augmentation_layers = [
    preprocessing.RandomRotation(factor=0.3, fill_mode='nearest'),
    preprocessing.RandomTranslation(height_factor=0.15, width_factor=0.15, fill_mode='reflect'),
    preprocessing.RandomZoom(0.15),
    preprocessing.RandomContrast(factor=0.15),
]
img_augmentation = Sequential(augmentation_layers, name="Augmentation")


def _augment_batch(images, labels):
    """Apply img_augmentation in training mode; labels pass through unchanged."""
    return img_augmentation(images, training=True), labels


train_ds = train_ds.map(_augment_batch, num_parallel_calls=AUTOTUNE)

In [None]:
def show_batch(image_batch, label_batch, n_images=15):
    """Plot up to ``n_images`` images of a batch, titled with their class names.

    Args:
        image_batch: array-like of images, each displayable by ``plt.imshow``.
        label_batch: one-hot labels; each title is ``CLASSES[argmax(label)]``.
        n_images: maximum number of images to draw. The grid is 5x5, so at
            most 25 are shown. Smaller batches no longer raise IndexError.
    """
    n_shown = min(n_images, len(image_batch), 25)
    plt.figure(figsize=(10,10))
    for n in range(n_shown):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n])
        plt.title(CLASSES[np.argmax(label_batch[n])])
        plt.axis("off")

In [None]:
# Preview the first augmented training batch.
first_train_batch = next(iter(train_ds))
image_batch, label_batch = first_train_batch
show_batch(image_batch.numpy(), label_batch.numpy())

In [None]:
# Sanity check: print the label of the first example in one test batch.
single_test_batch = test_ds.take(1)
for image, label in single_test_batch:
    print(label[0])

In [None]:
# Balanced class weights, total / (NUM_CLASSES * count), keyed by the label
# index used in the one-hot labels (position in CLASSES). The old expression
# hard-coded 2 classes and relied on CLASS_COUNT's insertion order.
class_weight = {
    index: len(filenames) / (NUM_CLASSES * CLASS_COUNT[class_name])
    for index, class_name in enumerate(CLASSES)
}
print(class_weight)

In [None]:
def build_model(OPTIMIZER, LOSS, METRICS):
    """Build and compile the classifier: frozen ImageNet backbone plus dense head.

    NETWORK is instantiated without its top, with global average pooling and
    ImageNet weights, and frozen. The head is
    BN -> Dense(512, relu, L1L2) -> BN -> Dropout(0.2) -> Dense(256, relu) -> BN
    -> Dense(NUM_CLASSES, softmax).

    Args:
        OPTIMIZER, LOSS, METRICS: forwarded to ``model.compile``.

    Returns:
        The compiled ``tf.keras.Model`` named "LAD-FFR-Classifier".
    """
    backbone = NETWORK(include_top=False, input_tensor=Input(shape=input_shape),
                       weights="imagenet", pooling='avg')
    backbone.trainable = False

    features = BatchNormalization(axis=-1, name="Batch-Normalization-1")(backbone.output)

    features = Dense(512, activation='relu',
                     kernel_regularizer=regularizers.L1L2(l1=1e-5, l2=1e-4))(features)
    features = BatchNormalization(axis=-1, name="Batch-Normalization-3")(features)
    features = Dropout(.2, name="Dropout-2")(features)

    features = Dense(256, activation='relu')(features)
    features = BatchNormalization(axis=-1, name="Batch-Normalization-4")(features)

    probabilities = Dense(NUM_CLASSES, activation="softmax", name="Classifier")(features)
    classifier = tf.keras.Model(inputs=backbone.input, outputs=probabilities,
                                name="LAD-FFR-Classifier")
    classifier.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)
    return classifier

In [None]:
# Model variables must be created under the distribution strategy's scope.
with strategy.scope():
    model = build_model(OPTIMIZER=OPTIMIZER, LOSS=LOSS, METRICS=METRICS)

In [None]:
def _split_train_validation(validation_split):
    """Build ``(train_ds, val_ds)`` pipelines from a split of `filenames`.

    The last ``validation_split`` fraction of the (seed-shuffled) `filenames`
    list becomes the validation set. The split is the same on every call as
    long as `filenames` is not modified. Training batches are augmented with
    `img_augmentation`; validation batches are not.
    """
    n_val = int(round(len(filenames) * validation_split))
    if not 0 < n_val < len(filenames):
        raise ValueError(
            f'validation_split={validation_split} leaves an empty training or validation set')
    train_files, val_files = filenames[:-n_val], filenames[-n_val:]

    def _decoded(files):
        # Paths -> (image, one-hot label) pairs, decoded in parallel.
        return tf.data.Dataset.from_tensor_slices(files).map(process_path, num_parallel_calls=AUTOTUNE)

    fit_train = prepare_for_training(_decoded(train_files)).map(
        lambda x, y: (img_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)
    fit_val = prepare_for_training(_decoded(val_files))
    return fit_train, fit_val


def fit_model(Epochs, Callbacks, class_weight=None, validation_split=0.2, validation_data=None):
    """Train the global `model` and return the Keras History.

    Keras' ``validation_split`` only works for NumPy/tensor inputs. Passing it
    together with a tf.data.Dataset makes ``model.fit`` raise ValueError, so
    the validation data is built explicitly here.

    Args:
        Epochs: maximum number of epochs.
        Callbacks: list of Keras callbacks. The callbacks used here monitor
            val_loss / val_accuracy.
        class_weight: optional ``{label_index: weight}`` mapping.
        validation_split: fraction of the training files held out for
            validation. Used when `validation_data` is None.
        validation_data: optional ready-made validation dataset. When given,
            the global augmented `train_ds` is used for training.
    """
    if validation_data is None:
        fit_train, fit_val = _split_train_validation(validation_split)
    else:
        fit_train, fit_val = train_ds, validation_data

    history = model.fit(
        fit_train,
        validation_data=fit_val,
        epochs=Epochs,
        callbacks=Callbacks,
        verbose=2,
        class_weight=class_weight
    )
    return history

history = fit_model(Epochs, Callbacks, class_weight=class_weight)

In [None]:
def Plot_Learning_Curves():
    """Plot training vs. validation loss (left) and accuracy (right).

    Reads the curves from the global `history` returned by the latest fit.
    """
    logged = history.history
    # Fetch every series up front (same key order as before).
    series = {key: logged[key] for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')}

    sns.set(style="dark")
    plt.rcParams['figure.figsize'] = (14, 5)

    panels = (
        ('loss', 'val_loss', 'Training loss', 'Validation loss',
         'Training and validation loss', 'Loss'),
        ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
         'Training and validation accuracy', 'Accuracy'),
    )
    for position, (train_key, val_key, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(series[train_key], label=train_label)
        plt.plot(series[val_key], linestyle="--", label=val_label)
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend()

    plt.show()

Plot_Learning_Curves()

In [None]:
# Test-set evaluation after training the classification head only.
results = model.evaluate(test_ds, return_dict=True)
test_accuracy_percent = results['accuracy'] * 100
print ('\nModel Evaluation:')
print(test_accuracy_percent)

In [None]:
def fine_tune(OPTIMIZER, LOSS, METRICS):
    """Unfreeze most of the global `model` and recompile it for fine-tuning.

    Every layer from index 8 onward becomes trainable, except:
    - layers whose name contains 'block1';
    - BatchNormalization layers, which keep their current trainable flag.

    Args:
        OPTIMIZER: optimizer for the recompiled model. It was previously
            ignored in favour of the global Fine_Tune_OPTIMIZER.
        LOSS, METRICS: forwarded to ``model.compile``.

    Returns:
        The recompiled global `model`.
    """
    for layer in model.layers[8:]:
        # NOTE(review): this is a substring test. If backbone layers are named
        # like "convN_blockM_...", it also skips blocks 10-19, 100+, etc.
        # Confirm this is intended.
        if 'block1' in layer.name or isinstance(layer, BatchNormalization):
            continue
        layer.trainable = True

    model.compile(optimizer=OPTIMIZER, loss=LOSS, metrics=METRICS)
    return model

In [None]:
# Recompile under the strategy scope with the fine-tuning optimizer.
with strategy.scope():
    model = fine_tune(OPTIMIZER=Fine_Tune_OPTIMIZER, LOSS=LOSS, METRICS=METRICS)

In [None]:
# Fine-tuning run; `history` now holds its logs.
history = fit_model(Epochs=Fine_Tune_Epochs, Callbacks=FT_Callbacks, class_weight=class_weight)

In [None]:
# Learning curves of the fine-tuning run.
Plot_Learning_Curves()

In [None]:
# Reload the checkpoint with the best fine-tuning validation accuracy. Load it
# inside strategy.scope(), like the original model, so that evaluation uses the
# same distribution strategy (e.g. the TPU replicas) as training.
with strategy.scope():
    model = load_model(Fine_Tune_filepath)
results = model.evaluate(test_ds, return_dict=True)
print ('\nModel Evaluation:')
print(results['accuracy']*100)

In [None]:
def dataset_to_numpy_util(dataset, N):
    """Return the first ``N`` examples of a batched dataset as NumPy arrays.

    Args:
        dataset: batched tf.data.Dataset yielding ``(images, labels)``.
        N: number of examples to collect (rebatched into a single batch).

    Returns:
        ``(images, labels)`` NumPy arrays taken from the same batch, so they
        stay aligned even if the dataset reshuffles on each iteration.

    Raises:
        ValueError: if the dataset yields no examples. Previously this
            surfaced as an UnboundLocalError.
    """
    first_batch = next(iter(dataset.unbatch().batch(N)), None)
    if first_batch is None:
        raise ValueError('dataset_to_numpy_util: dataset is empty')
    images, labels = first_batch
    return images.numpy(), labels.numpy()

In [None]:
# Materialise the whole test set as NumPy arrays (one batch of Test_IMG_COUNT).
x_test, y_test = dataset_to_numpy_util(test_ds, N=Test_IMG_COUNT)

print("Evaluation Dataset:")
print('X shape: ', x_test.shape, ' Y shape: ', y_test.shape)

In [None]:
# Predict the whole test set, then show 25 randomly sampled test images.
# Labels are green when the prediction is correct and red otherwise.
preds = model.predict(x_test)
print('Shape of preds: ', preds.shape)

n_preds = preds.shape[0]
plt.figure(figsize = (12, 12))

# This draw is discarded, but it advances NumPy's global RNG, so it is kept to
# reproduce the same sample of images.
R = np.random.choice(n_preds)

for i in range(25):
    R = np.random.choice(n_preds)
    pred = np.argmax(preds[R])
    actual = np.argmax(y_test[R])
    label_colour = 'g' if pred == actual else 'r'

    plt.subplot(5, 5, i + 1)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel('I={} | P={} | L={}'.format(R, pred, actual), color=label_colour)
    plt.imshow((x_test[R] * 255).astype(np.uint8), cmap='binary')
plt.show()

In [None]:
# Predicted class probabilities for a single test image.
index = 0
single_prediction = preds[index]
plt.rcParams['figure.figsize'] = (6, 4)
plt.plot(single_prediction)
sns.set(style="darkgrid")
plt.show()

In [None]:
# Per-class precision / recall / F1 computed from class indices (argmax of the
# one-hot labels and of the predicted probabilities). `preds` is no longer
# overwritten with rounded values, so it keeps the predicted probabilities.
# Passing `labels` keeps the report valid even if a class is absent from the
# test predictions.
pred_labels = preds.argmax(axis=1)
class_metrics = metrics.classification_report(
    y_test.argmax(axis=1), pred_labels,
    labels=list(range(NUM_CLASSES)), target_names=CLASSES, zero_division=0)
print (class_metrics)

In [None]:
matrix = metrics.confusion_matrix(y_test.argmax(axis=1), preds.argmax(axis=1))

def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """Draw a confusion matrix with a value annotated in every cell.

    Args:
        cm: square confusion matrix (rows = true label, columns = predicted).
        target_names: tick labels for both axes, or None for no tick labels.
        title: figure title.
        cmap: matplotlib colormap; defaults to 'Blues'.
        normalize: if True, annotate row-normalised values (4 decimals);
            otherwise annotate raw counts. The colour image always shows the
            raw matrix.

    The x-axis label reports accuracy and misclassification rate computed from
    the raw matrix.
    """
    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy
    sns.set(style="dark")
    cmap = plt.get_cmap('Blues') if cmap is None else cmap

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        positions = np.arange(len(target_names))
        plt.xticks(positions, target_names, rotation=45)
        plt.yticks(positions, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        cell_format = "{:0.4f}"
        thresh = cm.max() / 1.5
    else:
        cell_format = "{:,}"
        thresh = cm.max() / 2

    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            value = cm[row, col]
            plt.text(col, row, cell_format.format(value),
                     horizontalalignment="center",
                     color="white" if value > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\n Accuracy={:0.4f}; Misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()

for show_normalized, figure_title in ((False, "Confusion Matrix"), (True, "Normalized Confusion Matrix")):
    plot_confusion_matrix(cm           = np.array(matrix),
                          normalize    = show_normalized,
                          target_names = CLASSES,
                          title        = figure_title)

In [None]:
y_true_idx = y_test.argmax(axis=1)
y_pred_idx = preds.argmax(axis=1)

accuracy = metrics.accuracy_score(y_true_idx, y_pred_idx)
print("Accuracy: ", accuracy)

accurate_predictions = metrics.accuracy_score(y_true_idx, y_pred_idx, normalize=False)
print("The number of accurate predictions is: ", accurate_predictions)

# ROC AUC needs a continuous score; hard argmax labels give a single operating
# point and a misleading AUC. Probabilities are recomputed from the model so this
# cell does not depend on whether `preds` still holds probabilities.
# Index 1 (CLASSES[1] = 'FFRL') is the positive class.
test_scores = model.predict(x_test)[:, 1]
model_auc = metrics.roc_auc_score(y_true_idx, test_scores)
print('AUC:', model_auc)

In [None]:
def plot_roc_curve(true_y, y_prob):
    """Plot the ROC curve in full and zoomed into the top-left corner.

    Args:
        true_y: binary ground-truth labels (1 = positive class).
        y_prob: continuous scores for the positive class, e.g. predicted
            probabilities. Hard 0/1 predictions collapse the curve to a single
            operating point.
    """
    fpr, tpr, thresholds = metrics.roc_curve(true_y, y_prob)
    model_auc = metrics.auc(fpr, tpr)

    plt.figure(1)
    plt.plot([0, 1], 'k--')
    plt.plot(fpr, tpr, label=f'AUC = {model_auc}')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve')
    plt.legend(loc='best')

    plt.figure(2)
    plt.xlim(0, 0.2)
    plt.ylim(0.8, 1)
    plt.plot([0, 1], 'k--')
    plt.plot(fpr, tpr, label=f'AUC = {model_auc}')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve (zoomed in at top left)')
    plt.legend(loc='best')

# Use the predicted probability of class 1 ('FFRL'), not the argmax label.
plot_roc_curve(y_test.argmax(axis=1), model.predict(x_test)[:, 1])

In [None]:
# Predict the class of one test image.
# x_test is already float32 in [0, 1] (decode_img uses convert_image_dtype), so it
# must not be divided by 255 again; that fed an almost-black image to the model.
test_image = x_test[0]
images = np.expand_dims(test_image, axis=0)

classes = model.predict(images, batch_size=BATCH_SIZE)
classes = np.argmax(classes, axis=1)

# Index the single prediction explicitly; int() on an array is deprecated in NumPy.
print ('Class:', CLASSES[int(classes[0])] )

In [None]:
! pip -q install tf-keras-vis

from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.gradcam import GradcamPlusPlus
from tf_keras_vis.scorecam import ScoreCAM
from tf_keras_vis.utils import normalize

os.mkdir('Maps')

def attention_map(map_kind, test_img, img_no, show_only=False):
    """Compute and draw a class-activation map for one image.

    The map explains the model's top predicted class for the image.

    Args:
        map_kind: 'Gradcam', 'GradcamPlusPlus', or anything else for ScoreCAM.
        test_img: a single image from x_test (already float32 in [0, 1]).
        img_no: index used in the saved file name ./Maps/{img_no}-{map_kind}.png.
        show_only: if True, show the image and the overlay side by side instead
            of saving the figure.
    """
    image_titles = ['Attention Map']

    # x_test images are already scaled to [0, 1]. The former extra `/255` fed
    # an almost-black image to the model.
    test_image = test_img
    X = np.array(test_image)

    subplot_args = { 'nrows': 1, 'ncols': 1, 'figsize': (6, 6),'subplot_kw': {'xticks': [], 'yticks': []} }

    y_pred = model.predict(X[np.newaxis,...])
    class_idxs_sorted = np.argsort(y_pred.flatten())[::-1]
    top_class = int(class_idxs_sorted[0])

    def score(output):
        # Score of the predicted class (previously always class 0).
        return output[0][top_class]

    def model_modifier(cloned_model):
        # Take gradients w.r.t. pre-softmax scores.
        cloned_model.layers[-1].activation = tf.keras.activations.linear
        return cloned_model

    # With clone=False, model_modifier edits the global `model` in place; restore
    # the output activation afterwards so later predictions still use softmax.
    output_layer = model.layers[-1]
    original_activation = output_layer.activation
    try:
        if map_kind == 'Gradcam':
            gradcam = Gradcam(model, model_modifier=model_modifier, clone=False)
            cam = gradcam(score, X, penultimate_layer=-1)

        elif map_kind == 'GradcamPlusPlus':
            gradcamplusplus = GradcamPlusPlus(model, model_modifier=model_modifier, clone=False)
            cam = gradcamplusplus(score, X, penultimate_layer=-1)

        else:
            scorecam = ScoreCAM(model)
            cam = scorecam(score, X, penultimate_layer=-1, max_N=10)
    finally:
        output_layer.activation = original_activation

    cam = normalize(cam)

    if show_only:
        f, ax = plt.subplots(1,2,figsize=(14,5))
        for i, title in enumerate(image_titles):
            heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
            ax[0].imshow(test_image)
            ax[1].imshow(test_image)
            cb = ax[1].imshow(heatmap, cmap='jet', alpha=0.5)
            f.colorbar(cb)
            plt.title(f"Predicted class: {CLASSES[class_idxs_sorted[0]]} ({y_pred[0,class_idxs_sorted[0]]})")
        plt.tight_layout()
        plt.show()

    else:
        f, ax = plt.subplots(**subplot_args)
        for i, title in enumerate(image_titles):
            heatmap = np.uint8(cm.jet(cam[0])[..., :3] * 255)
            ax.set_title(title, fontsize=14)
            ax.imshow(test_image)
            ax.imshow(heatmap, cmap='jet', alpha=0.5)
            ax.axis('off')
        plt.tight_layout()
        plt.savefig(f'./Maps/{img_no}-{map_kind}.png')
    plt.close()

In [None]:
# Grad-CAM, Grad-CAM++ and Score-CAM for the first test image.
for i in range(1):
    for map_kind in ('Gradcam', 'GradcamPlusPlus', 'ScoreCAM'):
        attention_map(map_kind, x_test[i], i, show_only=True)

In [None]:
# Wall-clock time for the whole notebook run (total_start is set in the first cell).
total_end = datetime.datetime.now()
elapsed = total_end - total_start
print ('Total time elapsed: ', elapsed)

----

**LAD-FFR-Classifier** V.1.12.00 | Non-Invasive Fractional Flow Reserve Estimation using Deep Learning on Intermediate Left Anterior Descending Coronary Artery Lesion Angiography Images.
<br>{Binary (FFR > 0.80 vs. FFR ≤ 0.80) classification with attention maps.}

© Proposed method implementation by [**Mehrad Aria**](https://www.mehradaria.com/) for paper [[Non-Invasive Fractional Flow Reserve Estimation using Deep Learning on Intermediate Left Anterior Descending Coronary Artery Lesion Angiography Images](https://doi.org/X)].
<br>Jun 2023 / Tabriz, Iran.

----