# Alzheimer's disease: textual explanation, visual explanation and classification
This notebook contains the whole procedure we use for classification and for explanation.

This project builds on the code of a colleague.

This notebook assumes that the dataset and the explanations already exist;
if they do not, run "Creation of the dataset" before this notebook.

In [2]:
import os, random, glob, cv2
import nltk
 
import pickle

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sys import platform
import re
import html
import string
import unicodedata
from nltk.tokenize import word_tokenize
import os

from sklearn.preprocessing import MinMaxScaler
from torch import nn
from tqdm.notebook import tqdm

import tensorflow as tf
from tensorflow.keras.utils import pad_sequences, to_categorical, plot_model 
from monai.data import CacheDataset

from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Input, Dense,
    LSTM, Embedding,
    Dropout, add,
    MaxPool3D, Conv3D,
    GlobalAveragePooling3D, BatchNormalization
)

import torch
from torch.utils.data import Dataset, DataLoader

import importlib
import Utility
importlib.reload(Utility)
from Utility import get_gradcam
from alzheimer_disease.src.helpers.utils import get_device
from alzheimer_disease.src.modules.training import training_model
from alzheimer_disease.src.helpers.config import get_config
from alzheimer_disease.src.modules.preprocessing import get_transformations
from alzheimer_disease.src.models.densenetmm import DenseNetMM

#nltk.download('punkt')

In [3]:
# Definition of all paths
dataset = 'oasis_aug'

# Project root. Override it with the ALZHEIMER_BASE_PATH environment variable
# instead of editing the notebook when running on another machine.
_base_path = os.environ.get('ALZHEIMER_BASE_PATH', '/Volumes/Seagate Bas/Vito/CV')
_config = get_config()
saved_path = os.path.join(_base_path, _config.get('SAVED_FOLDER'))
reports_path = os.path.join(_base_path, _config.get('REPORT_FOLDER'))
logs_path = os.path.join(_base_path, _config.get('LOG_FOLDER'))
_data_path = os.path.join(_base_path, _config.get('LOCAL_DATA'))
data_path, meta_path, explanation_path = [
    os.path.join(_data_path, dataset, 'data/'),
    os.path.join(_data_path, dataset, 'meta/'),
    os.path.join(_data_path, dataset, 'explainability/')
]

device = get_device()

if platform == 'win32':
    # Paths end with a '/' separator and later cells append file names by plain string
    # concatenation (e.g. saved_path + 'file.pth'), so swap separators but keep the trailing one.
    saved_path, reports_path, logs_path, data_path, meta_path, explanation_path = (
        p.replace('/', '\\')
        for p in (saved_path, reports_path, logs_path, data_path, meta_path, explanation_path)
    )

saved_path, reports_path, logs_path, data_path, meta_path, explanation_path, device

('/Volumes/Seagate Bas/Vito/CV/saved/',
 '/Volumes/Seagate Bas/Vito/CV/reports/',
 '/Volumes/Seagate Bas/Vito/CV/logs/',
 '/Volumes/Seagate Bas/Vito/CV/data/oasis_aug/data/',
 '/Volumes/Seagate Bas/Vito/CV/data/oasis_aug/meta/',
 '/Volumes/Seagate Bas/Vito/CV/data/oasis_aug/explainability/',
 'cpu')

In [4]:
# Classification / data hyper-parameters
SIZE = 128                      # in_size of DenseNetMM and size of get_transformations
output_length = 1024            # token length of every tokenized explanation
epochs = 30                     # training epochs (classifier and caption model)
name_model = 'DenseNetMM_best'  # file stem of the saved classifier weights (<saved_path><name_model>.pth)

CHANNELS = ['T2w']              # MRI channels to load: 'T1w', 'T2w' or both

# tabular features appended to the image features
FEATURES = [
    'sex', 'age', 'bmi', 'education',
    'cdr_memory', 'cdr_orientation', 'cdr_judgment', 'cdr_community',
    'cdr_hobbies', 'cdr_personalcare',
    'boston_naming_test', 'depression', 'sleeping_disorder', 'motor_disturbance',
]
MULTICLASS = True               # True: 3 classes from CDR (0, 0.5, other); False: binary

In [5]:
# Adapted from a colleague's train/test split and extended with the textual explanation of each session
def train_test_splitting(
        data_folder,
        meta_folder,
        explanation_folder,
        channels,
        features,
        train_ratio=.8,
        multiclass=False,
        verbose=True
):
    """
    Splitting train/eval/test.

    Sessions are assigned to splits through their subject. Numerical features are min-max
    scaled with a scaler fitted on the training split only and then applied unchanged to the
    evaluation and test splits. Identical raw values therefore get identical scaled values in
    every split, and eval/test values may fall slightly outside [0, 1].

    NOTE(review): in multiclass mode subjects are assigned per class, so a subject whose
    sessions carry different labels may end up in more than one split. Confirm this is acceptable.

    Args:
        data_folder (str): path of the folder containing images.
        meta_folder (str): path of the folder containing csv files.
        explanation_folder (str): path of the folder containing csv files of the explanation.
        channels (list): image channels to select (values `T1w`, `T2w` or both).
        features (list): features set to select.
        train_ratio (float): ratio of the training set, value between 0 and 1.
        multiclass (bool): `False` for binary classification, `True` for ternary classification.
        verbose (bool): whether or not print information.
    Returns:
        train_data (list): the training data ready to feed monai.data.Dataset
        eval_data (list): the evaluation data ready to feed monai.data.Dataset
        test_data (list): the testing data ready to feed monai.data.Dataset.
        (see https://docs.monai.io/en/latest/data.html#monai.data.Dataset).
        Each element is a dict with keys `image` (sorted paths of the selected channel files),
        `data` (scaled feature vector), `label` (`final_dx`), `explanation` (text) and `session_id`.
    Raises:
        IndexError: if a session has no matching row in the explanation csv.
    """
    scaler = MinMaxScaler()
    df = pd.read_csv(os.path.join(meta_folder, 'data_num.csv'))
    # BMI (rounded) only where both weight and height are known; 0 otherwise
    measured = df[(df['weight'] != .0) & (df['height'] != .0)]
    df['bmi'] = round(measured['weight'] / (measured['height'] * measured['height']), 0)
    df['bmi'] = df['bmi'].fillna(.0)
    sessions = [s.split('_')[0] for s in os.listdir(data_folder) if os.path.isdir(os.path.join(data_folder, s))]
    subjects = list(set(sessions))

    # textual explanation of each session (file and column are spelled 'explaination')
    explanation = pd.read_csv(os.path.join(explanation_folder, 'explaination.csv'), sep=';')

    # applying splitting on subjects to prevent data leakage
    random.shuffle(subjects)
    split_train = int(len(subjects) * train_ratio)
    train_subjects, test_subjects = subjects[:split_train], subjects[split_train:]
    split_eval = int(len(train_subjects) * .8)
    eval_subjects = train_subjects[split_eval:]
    train_subjects = train_subjects[:split_eval]

    # applying multiclass label correction and splitting
    if multiclass:
        train_subjects, eval_subjects, test_subjects = [], [], []
        df.loc[df['cdr'] == .0, 'final_dx'] = .0
        df.loc[df['cdr'] == .5, 'final_dx'] = 1.
        df.loc[(df['cdr'] != .0) & (df['cdr'] != .5), 'final_dx'] = 2.
        # balance the classes by downsampling every class to the size of the smallest one
        m = np.min(np.unique(df['final_dx'].to_numpy(), return_counts=True)[1])
        df = pd.concat([
            df[df['final_dx'] == .0].sample(m),
            df[df['final_dx'] == 1.].sample(m),
            df[df['final_dx'] == 2.].sample(m)
        ], ignore_index=True)
        n_test = m - int(m * train_ratio)
        n_eval = m - n_test - int(m * train_ratio * train_ratio)
        for i in range(3):
            sub = list(set(df[df['final_dx'] == float(i)]['subject_id'].to_numpy()))
            random.shuffle(sub)
            counter = 0
            for j in range(len(sub)):
                # fill test, then eval, then train by cumulative session count of the class
                counter += len(df[df['subject_id'] == sub[j]])
                if counter <= n_test:
                    test_subjects.append(sub[j])
                elif counter > n_test and counter <= (n_test + n_eval):
                    eval_subjects.append(sub[j])
                else:
                    train_subjects.append(sub[j])

    # metadata rows of each split; copies so that the scaling below does not write into views of df
    X_train = df[df['subject_id'].isin(train_subjects)].copy()
    X_eval = df[df['subject_id'].isin(eval_subjects)].copy()
    X_test = df[df['subject_id'].isin(test_subjects)].copy()

    # loading explanation of subjects
    explanation_train = explanation[explanation['subject_id'].isin(X_train['subject_id'].values)]
    explanation_eval = explanation[explanation['subject_id'].isin(X_eval['subject_id'].values)]
    explanation_test = explanation[explanation['subject_id'].isin(X_test['subject_id'].values)]

    # scaling numerical data: fit on the training split only and reuse the same transform,
    # so no statistics of eval/test leak into preprocessing
    X_train.loc[:, features] = scaler.fit_transform(X_train[features])
    for split in (X_eval, X_test):
        if len(split):
            split.loc[:, features] = scaler.transform(split[features])

    # arranging data in dictionaries, including the explanation and the session reference
    train_data = _build_split_records(X_train, explanation_train, df, data_folder, channels, features)
    eval_data = _build_split_records(X_eval, explanation_eval, df, data_folder, channels, features)
    test_data = _build_split_records(X_test, explanation_test, df, data_folder, channels, features)

    # print data splitting information
    if verbose:
        n_subjects = len(df['subject_id'].unique())
        n_sessions = len(df)

        def _share(count, total, digits):
            return f'{count} ({round(count * 100 / total, digits)} %)'

        print(''.join(['> ' for _ in range(40)]))
        print(f'\n{"":<20}{"TRAINING":<20}{"EVALUATION":<20}{"TESTING":<20}\n')
        print(''.join(['> ' for _ in range(40)]))
        tsb1, tsb2, tsb3 = (_share(len(s), n_subjects, 0) for s in (train_subjects, eval_subjects, test_subjects))
        tss1, tss2, tss3 = (_share(len(s), n_sessions, 2) for s in (train_data, eval_data, test_data))
        print(f'\n{"subjects":<20}{tsb1:<20}{tsb2:<20}{tsb3:<20}\n')
        print(f'{"sessions":<20}{tss1:<20}{tss2:<20}{tss3:<20}\n')

    return train_data, eval_data, test_data


def _build_split_records(split_df, split_explanation, labels_df, data_folder, channels, features):
    """Turn the metadata rows of one split into the list of sample dicts used downstream.

    The session id comes straight from the metadata, not from parsing a path string.
    This keeps the lookup correct whatever the OS path separator is.
    """
    records = []
    for session_id in split_df['session_id'].values:
        session_path = os.path.join(data_folder, session_id)
        records.append({
            'image': sorted(
                os.path.join(session_path, f) for f in os.listdir(session_path)
                if any(c in f for c in channels)
            ),
            'data': split_df[split_df['session_id'] == session_id][features].values[0],
            'label': labels_df[labels_df['session_id'] == session_id]['final_dx'].values[0],
            'explanation': split_explanation[split_explanation['session_id'] == session_id]['explaination'].values[0],
            'session_id': session_id,
        })
    return records

In [6]:
# Classifier: DenseNetMM over the selected MRI channels, with the FEATURES vector appended
densenet = DenseNetMM(
    name=name_model,
    in_channels=len(CHANNELS),
    in_size=SIZE,
    in_features_size=len(FEATURES),
    append_features=True,
    out_channels=3 if MULTICLASS else 2,
)

In [7]:
train_transform, eval_transform = get_transformations(size=SIZE)

# The split shuffles subjects with `random` and, in multiclass mode, samples sessions with
# pandas (which draws from NumPy's global RNG). Seed both so every run produces the same split.
# NOTE(review): weights saved from an earlier, unseeded run may have been trained on sessions
# that now land in val/test. Retrain after seeding to get a clean evaluation.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
np.random.seed(SPLIT_SEED)

train, val, test = train_test_splitting(
    data_folder=data_path,
    meta_folder=meta_path,
    explanation_folder=explanation_path,
    channels=CHANNELS,
    features=FEATURES,
    multiclass=MULTICLASS,
    verbose=True
)

> > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > 

                    TRAINING            EVALUATION          TESTING             

> > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > > 

subjects            380 (65.0 %)        98 (17.0 %)         111 (19.0 %)        

sessions            435 (63.6 %)        111 (16.23 %)       138 (20.18 %)       



In [8]:
# Inspect one training record: image paths, scaled features, label, explanation text and session id
train[0]

{'image': ['/Volumes/Seagate Bas/Vito/CV/data/oasis_aug/data/OAS30788_MR_d4108/sub-OAS30788_sess-d4108_acq-TSE_T2w.nii.gz'],
 'data': array([1.        , 0.8       , 0.47826087, 0.66666667, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.96666667, 0.5       , 0.5       , 0.5       ]),
 'label': 0.0,
 'explanation': "**Summary of Heatmap Analysis**\n\nThe heatmap analysis reveals that the machine learning model focused on specific regions of the brain to make its classification decision. The highlighted regions are not areas affected by Alzheimer's Disease, but rather areas that the model considered crucial for its prediction.\n\n1. **Frontal-to-Occipital (GapMap) left**: This region accounts for 19.9% of the heatmap. According to the Julich-Brain Atlas, this area is responsible for processing visual information and is involved in attention and executive functions. The model's focus on this region may indicate that it is considering the patient's

In [9]:
# Load the trained classifier weights if they exist, otherwise train the classifier.
classifier_weights_path = os.path.join(saved_path, f'{name_model}.pth')
if os.path.isfile(classifier_weights_path):
    print(f'Loading {name_model}.pth')
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    densenet.load_state_dict(torch.load(classifier_weights_path, map_location=device))
else:
    print('Train of the model')
    train_metrics = training_model(
        model=densenet,
        data=[train, val],
        transforms=[train_transform, eval_transform],
        epochs=epochs,
        device=device,
        paths=[saved_path, reports_path, logs_path],
        num_workers=0,
        verbose=True
    )

Loading DenseNetMM_best.pth


## Image Captioning

In [10]:
# Name given to the DenseNetMM instance used as the caption model's image feature extractor
name_fextractor = 'DenseNetMMFeatureExtractor'

In [11]:
# Caption-model hyper-parameters
VOCABULARY_SIZE = 1179  # max number of tokens kept by the TextVectorization tokenizer
EMBEDDING_DIM = 512     # width of the word-embedding vectors
UNITS = 512             # hidden width of the decoder's first feed-forward Dense layer
BATCH_SIZE = 32         # samples per batch in the tf.data pipelines
BUFFER_SIZE = 1000      # shuffle buffer size of the training pipeline

def preprocess(text):
    """Normalize an explanation and wrap it with start/end tokens.

    Steps:
    - lower-case the text;
    - remove every character that is neither a word character nor whitespace;
    - collapse whitespace runs to a single space;
    - strip both ends;
    - return ``'[start] <text> [end]'``.

    Args:
        text (str): raw explanation text.
    Returns:
        str: the normalized text surrounded by the ``[start]``/``[end]`` tokens.
    """
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    # raw string: '\s' inside a normal string literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return '[start] ' + text + ' [end]'

# All records of every split, so the tokenizer vocabulary covers the whole corpus.
# NOTE(review): this includes val/test explanations in the vocabulary; confirm this is intended.
entire_df = train + val + test

all_text = [record['explanation'] for record in entire_df]

# TextVectorization maps each text to a fixed-length sequence of integer token ids.
# standardize=None: texts are split as-is (no lower-casing / punctuation stripping).
# NOTE(review): `preprocess` (which adds [start]/[end]) is applied neither here nor when
# the captions are tokenized for training. Confirm this is intended.
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=VOCABULARY_SIZE,
    standardize=None,
    output_sequence_length=output_length
)

# Build the vocabulary from all explanations
tokenizer.adapt(all_text)

# TextVectorization has no `word_index` (that belongs to the legacy keras Tokenizer);
# vocabulary_size() already counts the padding ('') and OOV ('[UNK]') entries.
vocab_size = tokenizer.vocabulary_size()
print('Vocabulary Size: {}'.format(vocab_size))

Vocabulary Size: 1179


In [None]:
# Tokenizer vocabulary (index order of the token ids), shared by both lookup layers below.
_caption_vocabulary = tokenizer.get_vocabulary()

# word -> integer index; "" is treated as the mask (padding) token
word2idx = tf.keras.layers.StringLookup(
    vocabulary=_caption_vocabulary,
    mask_token="",
)

# integer index -> word, the inverse mapping of word2idx
idx2word = tf.keras.layers.StringLookup(
    vocabulary=_caption_vocabulary,
    mask_token="",
    invert=True,
)

In [12]:
# CNN encoder: same DenseNetMM architecture as the classifier, reduced below to its image branch
encoder = DenseNetMM(
    in_channels = len(CHANNELS),
    in_size = SIZE,
    in_features_size= len(FEATURES),
    out_channels = 3 if MULTICLASS else 2,
    append_features = True,
    name=name_fextractor
)

# Reuse the classifier weights for the feature extractor (train only if none are saved)
classifier_weights_path = os.path.join(saved_path, f'{name_model}.pth')
if os.path.isfile(classifier_weights_path):
    print(f'Loading {name_model}.pth')
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    encoder.load_state_dict(torch.load(classifier_weights_path, map_location=device))
else:
    print('Train of the model')
    train_metrics = training_model(
        model=encoder,
        data=[train, val],
        transforms=[train_transform, eval_transform],
        epochs=epochs,
        device=device,
        paths=[saved_path, reports_path, logs_path],
        num_workers=0,
        verbose=True
    )

# Keep only the image path of the network: image feature layers followed by the output layers.
# NOTE(review): the full model was built with append_features=True. Confirm that
# `output_layers` accepts the image features alone.
encoder = torch.nn.Sequential(
    encoder.features_img,
    encoder.output_layers,
)

encoder

Loading DenseNetMM_best.pth


In [None]:
class TransformerEncoderLayer(tf.keras.layers.Layer):
    """Single self-attention block applied to the image embeddings.

    ``call`` runs:
    - LayerNorm;
    - Dense(embed_dim, relu);
    - multi-head self-attention;
    - residual add;
    - LayerNorm.

    There is no feed-forward sub-block after the attention.

    NOTE(review): Keras tracks sub-layers in assignment order, which presumably fixes the
    layout of saved weight files. Keep the order below stable to load existing checkpoints.

    Args:
        embed_dim (int): output width of the Dense projection; also the per-head ``key_dim``
            of the attention layer.
        num_heads (int): number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.layer_norm_1 = tf.keras.layers.LayerNormalization()
        self.layer_norm_2 = tf.keras.layers.LayerNormalization()
        self.attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)
        self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")

    #Forward Pass (call method):
    def call(self, x, training):
        """Encode ``x``; ``training`` is forwarded to the attention layer.

        Assumes ``x`` is (batch, seq, features). TODO confirm against the feature
        extractor's output.
        """
        x = self.layer_norm_1(x)
        x = self.dense(x)

        # self-attention: query, key and value are all the projected input
        attn_output = self.attention(
            query=x,
            value=x,
            key=x,
            attention_mask=None,
            training=training
        )

        # residual connection around the attention, then normalization
        x = self.layer_norm_2(x + attn_output)

        return x

In [None]:
class Embeddings(tf.keras.layers.Layer):
    """Sum of learned token embeddings and learned absolute position embeddings.

    Args:
        vocab_size (int): number of token ids covered by the token embedding table.
        embed_dim (int): width of both embedding tables.
        max_len (int): number of position embeddings. Positions 0..length-1 are looked up,
            so input sequences must not be longer than this.
    """

    def __init__(self, vocab_size, embed_dim, max_len):
        super().__init__()
        self.token_embeddings = tf.keras.layers.Embedding(
            vocab_size, embed_dim
        )
        
        # NOTE(review): `input_shape` seems to have no effect here (the layer is called on
        # position ids of shape (1, length)); confirm and consider dropping it.
        self.position_embeddings = tf.keras.layers.Embedding(
            max_len, embed_dim, input_shape=(None, max_len)
        )


    def call(self, input_ids):
        """Embed ``input_ids`` (sequence on the last axis) and add position embeddings.

        The position embeddings have a leading axis of size 1 and broadcast over the batch.
        """
        #input_ids: A tensor of token IDs representing the input sequences.
        length = tf.shape(input_ids)[-1]
        #A range of position IDs from 0 to length - 1 is created
        position_ids = tf.range(start=0, limit=length, delta=1)
        #adds a leading axis of size 1 so the position embeddings broadcast over the batch
        position_ids = tf.expand_dims(position_ids, axis=0)

        token_embeddings = self.token_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        return token_embeddings + position_embeddings

In [None]:
# Sanity check: embed one tokenized explanation; expected shape (1, output_length, EMBEDDING_DIM).
# (`train` is a list of record dicts, so the caption is taken from its 'explanation' field.)
Embeddings(tokenizer.vocabulary_size(), EMBEDDING_DIM, output_length)(
    tokenizer([train[0]['explanation']])
).shape

In [None]:
class TransformerDecoderLayer(tf.keras.layers.Layer):
    """Transformer decoder block that outputs a probability distribution over the vocabulary.

    ``call`` runs:
    - token + position embeddings;
    - masked self-attention, then Add & LayerNorm;
    - cross-attention over the encoder output, then Add & LayerNorm;
    - feed-forward (Dense(units, relu) -> Dropout -> Dense(embed_dim)), then Add & LayerNorm;
    - Dropout;
    - Dense(vocabulary_size, softmax).

    The vocabulary size and the maximum sequence length are read from the module-level
    ``tokenizer`` and ``output_length``, so both must exist before instantiation.

    Args:
        embed_dim (int): embedding width; also the per-head ``key_dim`` of both attention layers.
        units (int): hidden width of the first feed-forward Dense layer.
        num_heads (int): number of heads of both attention layers.
    """

    def __init__(self, embed_dim, units, num_heads):
        super().__init__()
        # embedding layer to create token and positional embeddings
        self.embedding = Embeddings(
            tokenizer.vocabulary_size(), embed_dim, output_length)
        # masked self-attention over the caption tokens
        self.attention_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        # cross-attention: caption tokens (query) attend to the encoder output (key/value)
        self.attention_2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim, dropout=0.1
        )
        # one layer normalization per residual connection

        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()
        # Dense layers for the feed-forward network and the vocabulary output layer
        self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
        self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)

        self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
        # dropout inside the feed-forward network and before the output layer
        self.dropout_1 = tf.keras.layers.Dropout(0.3)
        self.dropout_2 = tf.keras.layers.Dropout(0.5)


    def call(self, input_ids, encoder_output, training, mask=None):
        """Predict next-token probabilities for every position of ``input_ids``.

        Args:
            input_ids: caption token ids.
            encoder_output: output of the encoder, used as key/value of the cross-attention.
            training (bool): forwarded to the attention and dropout layers.
            mask: optional rank-2 (batch, seq) padding mask, non-zero for real tokens.
                The causal mask is only built when ``mask`` is given.
        Returns:
            Softmax probabilities over the vocabulary for each input position.
        """
        embeddings = self.embedding(input_ids)

        combined_mask = None
        padding_mask = None
        # Prepare the attention masks
        if mask is not None:
            causal_mask = self.get_causal_attention_mask(embeddings)
            # (batch, seq, 1): masks padded query positions in the cross-attention
            padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
            # (batch, 1, seq) combined with the causal mask: no attention to padding or future tokens
            combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
            combined_mask = tf.minimum(combined_mask, causal_mask)
        # Self-attention on the embeddings
        attn_output_1 = self.attention_1(
            query=embeddings,
            value=embeddings,
            key=embeddings,
            attention_mask=combined_mask,
            training=training
        )
        # Residual connection and normalization
        out_1 = self.layernorm_1(embeddings + attn_output_1)
        # Cross-attention over the encoder output
        attn_output_2 = self.attention_2(
            query=out_1,
            value=encoder_output,
            key=encoder_output,
            attention_mask=padding_mask,
            training=training
        )
        # Residual connection and normalization

        out_2 = self.layernorm_2(out_1 + attn_output_2)
        # Feed-forward network with dropout
        ffn_out = self.ffn_layer_1(out_2)
        ffn_out = self.dropout_1(ffn_out, training=training)
        ffn_out = self.ffn_layer_2(ffn_out)

        ffn_out = self.layernorm_3(ffn_out + out_2)
        ffn_out = self.dropout_2(ffn_out, training=training)
        preds = self.out(ffn_out)
        return preds

    def get_causal_attention_mask(self, inputs):
        """Build an int32 lower-triangular mask of shape (batch, seq, seq).

        Entry [b, i, j] is 1 when j <= i, so each position can attend only to itself and
        earlier positions. ``inputs`` only provides the batch size and sequence length
        (its first two dimensions).
        """
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        # tile multiples [batch_size, 1, 1]: repeat the mask for every batch element
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0
        )
        return tf.tile(mask, mult)
    

class ImageCaptioningModel(tf.keras.Model):
    """Captioning model: image embeddings -> transformer encoder -> transformer decoder.

    ``cnn_model`` (and the optional ``image_aug``) run outside the gradient tape, so only
    the encoder and decoder variables are updated by ``train_step``.

    NOTE(review): in this notebook ``cnn_model`` is a PyTorch module while the rest is
    Keras. Calling it on TF tensors is likely to fail; confirm how embeddings are produced.

    Args:
        cnn_model: callable mapping an image batch to image embeddings.
        encoder: layer applied to the image embeddings.
        decoder: layer mapping (caption tokens, encoder output) to next-token probabilities.
        image_aug: optional callable applied to the images in ``train_step`` only.
    """

    def __init__(self, cnn_model, encoder, decoder, image_aug=None):
        super().__init__()
        self.cnn_model = cnn_model
        self.encoder = encoder
        self.decoder = decoder
        self.image_aug = image_aug
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")
        self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")

    def calculate_loss(self, y_true, y_pred, mask):
        """Average the per-token loss (``self.loss`` from ``compile``) over non-padding positions."""
        loss = self.loss(y_true, y_pred)
        mask = tf.cast(mask, dtype=loss.dtype)
        loss *= mask
        return tf.reduce_sum(loss) / tf.reduce_sum(mask)

    def calculate_accuracy(self, y_true, y_pred, mask):
        """Fraction of non-padding positions where the argmax prediction equals the target.

        ``mask`` must be boolean (it is combined with ``tf.math.logical_and``).
        """
        accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
        accuracy = tf.math.logical_and(mask, accuracy)
        accuracy = tf.cast(accuracy, dtype=tf.float32)
        mask = tf.cast(mask, dtype=tf.float32)
        return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)

    def compute_loss_and_acc(self, img_embed, captions, training=True):
        """Compute the masked loss and accuracy of one batch with teacher forcing.

        The decoder receives ``captions[:, :-1]`` and is scored against ``captions[:, 1:]``;
        token id 0 (padding) is excluded from both metrics.

        Args:
            img_embed: image embeddings produced by ``cnn_model``.
            captions: tokenized captions, sequence on the second axis.
            training (bool): forwarded to the encoder and decoder, so dropout is only
                active during training.
        Returns:
            tuple: (loss, accuracy) scalar tensors.
        """
        encoder_output = self.encoder(img_embed, training=training)
        y_input = captions[:, :-1]
        y_true = captions[:, 1:]
        mask = (y_true != 0)
        y_pred = self.decoder(
            y_input, encoder_output, training=training, mask=mask
        )
        loss = self.calculate_loss(y_true, y_pred, mask)
        acc = self.calculate_accuracy(y_true, y_pred, mask)
        return loss, acc

    def train_step(self, batch):
        """One optimisation step: optional augmentation, forward pass, and encoder/decoder update."""
        imgs, captions = batch

        if self.image_aug:
            imgs = self.image_aug(imgs)

        img_embed = self.cnn_model(imgs)

        with tf.GradientTape() as tape:
            loss, acc = self.compute_loss_and_acc(
                img_embed, captions, training=True
            )

        train_vars = (
            self.encoder.trainable_variables + self.decoder.trainable_variables
        )
        grads = tape.gradient(loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))
        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    def test_step(self, batch):
        """Evaluation step: same forward pass as training, without augmentation or weight updates."""
        imgs, captions = batch

        img_embed = self.cnn_model(imgs)

        loss, acc = self.compute_loss_and_acc(
            img_embed, captions, training=False
        )

        self.loss_tracker.update_state(loss)
        self.acc_tracker.update_state(acc)

        return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}

    @property
    def metrics(self):
        # listed here so Keras resets the trackers at the start of each epoch
        return [self.loss_tracker, self.acc_tracker]

In [None]:
def load_data(img_dict):
    """Map one sample to an ``(inputs, caption_token_ids)`` pair for tf.data.

    The explanation text is tokenized with the module-level ``tokenizer``.

    NOTE(review): despite the name, the first element is the record's ``'data'`` entry
    (the scaled feature vector), not the MRI image. Confirm this is what the captioning
    model should consume.
    """
    caption_ids = tokenizer(img_dict['explanation'])
    return img_dict['data'], caption_ids

In [None]:
def make_caption_dataset(records, shuffle=False):
    """Build a batched tf.data pipeline of ``load_data`` outputs from split records.

    ``tf.data.Dataset.from_tensor_slices`` cannot slice a list of dicts with mixed value
    types. Only the two fields ``load_data`` reads are therefore gathered into a dict of
    columns, which tf.data slices element-wise into dicts.

    Args:
        records (list[dict]): split records as returned by ``train_test_splitting``.
        shuffle (bool): whether to shuffle with a buffer of ``BUFFER_SIZE`` elements.
    Returns:
        tf.data.Dataset: batches of ``BATCH_SIZE`` elements.
    """
    columns = {
        'data': np.stack([record['data'] for record in records]),
        'explanation': [record['explanation'] for record in records],
    }
    dataset = tf.data.Dataset.from_tensor_slices(columns).map(
        load_data, num_parallel_calls=tf.data.AUTOTUNE
    )
    if shuffle:
        dataset = dataset.shuffle(BUFFER_SIZE)
    return dataset.batch(BATCH_SIZE)


train_dataset = make_caption_dataset(train, shuffle=True)
# validation data does not need to be shuffled
val_dataset = make_caption_dataset(val)

In [None]:
# Captioning model:
# - the DenseNet feature extractor (not trained by this model) produces the image embeddings;
# - a 1-head transformer encoder layer processes them;
# - an 8-head transformer decoder layer attends to the encoder output.
encoder_transformer = TransformerEncoderLayer(embed_dim=EMBEDDING_DIM, num_heads=1)
decoder = TransformerDecoderLayer(embed_dim=EMBEDDING_DIM, units=UNITS, num_heads=8)

caption_model = ImageCaptioningModel(
    encoder=encoder_transformer,
    decoder=decoder,
    cnn_model=encoder,
)

In [1]:
# Per-token cross-entropy. The decoder already ends with a softmax (from_logits=False), and
# reduction="none" leaves the padding masking to ImageCaptioningModel.calculate_loss.
cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction="none",
    from_logits=False,
)

# Stop after 3 epochs without val_loss improvement and keep the best weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

caption_model.compile(loss=cross_entropy, optimizer=tf.keras.optimizers.Adam())

NameError: name 'tf' is not defined

In [None]:
# Train the caption model only when no saved weights exist; otherwise load them.
# (The original check was inverted: it retrained when weights existed and loaded when they didn't.)
caption_weights_path = os.path.join(saved_path, 'transformer_caption_model.h5')

if not os.path.isfile(caption_weights_path):
    print('Caption model training...')
    history = caption_model.fit(
        train_dataset,
        epochs=epochs,
        validation_data=val_dataset,
        callbacks=[early_stopping]
    )
    caption_model.save_weights(caption_weights_path)
else:
    print('Loading Caption Model...')
    # NOTE(review): HDF5 weights can only be loaded into a subclassed model whose variables
    # already exist (e.g. after running it on one batch). Confirm this works here.
    caption_model.load_weights(caption_weights_path)

# Merging of the proposed methods

In [None]:
def plot_grad_cam_explanation(image, label, pred, heatmap, mask, caption, alpha=128,
                              threshold=.8, class_names=None):
    """
    Plots model input image, Grad-CAM heatmap, segmentation mask and the explanation generated.

    Only the middle slice along the third axis of each 3D volume is shown. The caption is
    wrapped and drawn below the three panels.

    Args:
        image (numpy.ndarray): the input 3D image.
        label (int): the input image label.
        pred (int): model prediction for input image.
        heatmap (numpy.ndarray): the Grad-CAM 3D heatmap, displayed on a [0, 1] colour scale.
        mask (numpy.ndarray): the computed 3D segmentation mask; voxels equal to 1 are drawn in red.
        caption (str): the explanation generated caption.
        alpha (int): transparency of the mask overlay. Between 0 and 255.
        threshold (float): threshold value shown in the mask title (display only).
        class_names (sequence of str, optional): display name per class index. When omitted,
            0 is shown as 'NON-AD' and any other value as 'AD'.
    Returns:
        None.
    """
    import textwrap

    if not 0 <= alpha <= 255:
        print('\n' + ''.join(['> ' for _ in range(30)]))
        # str(): alpha is an int, and concatenating it to a str would raise TypeError
        print('\nERROR: alpha channel \033[95m ' + str(alpha) + '\033[0m out of range [0,255].\n')
        print(''.join(['> ' for _ in range(30)]) + '\n')
        return

    def _class_name(value):
        if class_names is not None:
            return class_names[int(value)]
        return 'NON-AD' if value == 0 else 'AD'

    # RGBA overlay: red (with the requested transparency) where the mask is set
    heatmap_mask = np.zeros((image.shape[0], image.shape[1], image.shape[2], 4), dtype='uint8')
    heatmap_mask[mask == 1] = [255, 0, 0, alpha]

    image_slice = image[:, :, image.shape[2] // 2].astype('float64')
    heatmap_slice = heatmap[:, :, heatmap.shape[2] // 2]
    mask_slice = heatmap_mask[:, :, heatmap_mask.shape[2] // 2, :]

    # min-max normalise the slice to [0, 1] (all zeros for a constant slice)
    value_range = image_slice.max() - image_slice.min()
    if value_range > 0:
        norm_img = (image_slice - image_slice.min()) / value_range
    else:
        norm_img = np.zeros_like(image_slice)

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))
    im_shows = [
        axs[0].imshow(norm_img, cmap='gray', interpolation='bilinear', vmin=.0, vmax=1.),
        axs[1].imshow(heatmap_slice, cmap='jet', interpolation='bilinear', vmin=.0, vmax=1.),
        axs[2].imshow(norm_img, cmap='gray', interpolation='bilinear', vmin=.0, vmax=1.),
    ]
    axs[2].imshow(mask_slice, interpolation='bilinear')
    axs[0].set_title('Label=' + _class_name(label) + ' | Prediction=' + _class_name(pred), fontsize=16)
    axs[1].set_title('Grad-CAM Heatmap', fontsize=16)
    axs[2].set_title('Mask - Threshold ' + str(threshold), fontsize=16)
    for im_show, ax in zip(im_shows, axs):
        ax.axis('off')
        fig.colorbar(im_show, ax=ax, ticks=np.linspace(0, 1, 6))

    # explanations are long paragraphs: wrap them and reserve space at the bottom of the figure
    wrapped_caption = textwrap.fill(str(caption), width=160)
    fig.tight_layout(rect=(0, 0.15, 1, 1))
    fig.text(0.5, 0.02, wrapped_caption, ha='center', va='bottom', fontsize=9)
    plt.show()

In [None]:
def get_results_and_plot(image_dict, predictor, generator, saved_path, plot=False,
                         tokenizer=None, max_length=None):
    '''
    Predict the class of one sample, compute its Grad-CAM and generate its explanation.

    :param image_dict: one sample record (as produced by train_test_splitting)
    :param predictor: classification model passed to get_gradcam
    :param generator: model used for the generation of the explanation (via ``generator.predict``)
    :param saved_path: directory where the model weights are stored (forwarded to get_gradcam)
    :param plot: if True we plot the Grad-CAM and the explanation
    :param tokenizer: optional tokenizer exposing ``get_vocabulary()``. When given and the
        generator output is a rank-3 array (assumed (batch, seq, vocab) probabilities), the
        first sequence is greedily decoded to text.
    :param max_length: optional maximum number of decoded tokens (only used with ``tokenizer``)
    :return: tuple (image, label, pred, heatmap, mask, explanation)
    '''
    # Keep the image, the mask and the prediction
    image, mask, pred, label, heatmap = get_gradcam(
        example=image_dict,
        model=predictor,
        saved_path=saved_path,
        threshold=.8,
    )

    # Generate the description from the processed image
    # NOTE(review): ImageCaptioningModel defines no `call`, so `predict` may not work on it. Confirm.
    explanation = generator.predict(image)
    if tokenizer is not None:
        explanation = _decode_token_probabilities(explanation, tokenizer, max_length)

    if plot:
        plot_grad_cam_explanation(image, label, pred, heatmap, mask, explanation)

    return image, label, pred, heatmap, mask, explanation


def _decode_token_probabilities(token_probs, tokenizer, max_length=None):
    """Greedy-decode the first sequence of a rank-3 probability array; other inputs pass through unchanged."""
    token_probs = np.asarray(token_probs)
    if token_probs.ndim != 3:
        return token_probs
    vocabulary = tokenizer.get_vocabulary()
    token_ids = token_probs[0].argmax(axis=-1)
    if max_length is not None:
        token_ids = token_ids[:max_length]
    words = (vocabulary[i] for i in token_ids)
    # drop padding and the start/end markers added by `preprocess`
    return ' '.join(w for w in words if w not in ('', '[start]', '[end]'))

In [None]:
# Get a random example from the entire dataset.
# entire_df is a plain list of record dicts (no .shape), so pick with random.choice.
example = random.choice(entire_df)
# assigned to a name so the returned volumes are not dumped as cell output
results = get_results_and_plot(
    image_dict=example,
    predictor=densenet,
    generator=caption_model,
    saved_path=saved_path,
    plot=True,
)