<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/best_models_data_augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
# Execution-environment switch read by later cells:
#   True  -> import the Kaggle helpers and use paths under /kaggle/input/
#   False -> import google.colab.drive and use paths under the mounted Drive
kaggle = False

# Imports

In [28]:
import os, shutil, re
import pandas as pd, numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import scipy
import skimage.transform as transform
import seaborn as sns
if kaggle:
    from kaggle_datasets import KaggleDatasets
    from kaggle_secrets import UserSecretsClient
else:
    from google.colab import drive
import nibabel as nib

# Import the most used layers
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Flatten, Dense, Input, BatchNormalization, Dropout, ReLU

# Hardware config.

In [29]:
# Accelerator selection. 'TPU' tries to connect to a TPU cluster and falls back
# to the default strategy on failure; 'GPU' uses MirroredStrategy; any other
# value uses the default strategy.
DEVICE = 'TPU'  # 'TPU' or 'GPU'
tpu = None  # Stays None unless a TPU cluster resolver was created successfully

if DEVICE == 'TPU':
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        STRATEGY = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        print('Could not connect to TPU, setting default strategy')
        # Reset so later `tpu is None` checks see that no TPU is in use
        tpu = None
        STRATEGY = tf.distribute.get_strategy()
elif DEVICE == 'GPU':
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    # Without this branch STRATEGY was left undefined for any other DEVICE
    # value and the REPLICAS line below raised NameError.
    STRATEGY = tf.distribute.get_strategy()

AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = STRATEGY.num_replicas_in_sync

print(f'Number of accelerators: {REPLICAS}')

Could not connect to TPU, setting default strategy
Number of accelerators: 1


# Data augmentation

In [30]:
def augment_image(img):
    """
    Apply a random rotation, zoom and shift to a single 3D volume.

    Args:
        img: numpy array holding one volume. All singleton axes are squeezed
            away first, so the squeezed array must be 3D
            (assumes the input is (x, y, z, 1) -- TODO confirm against caller).

    Returns:
        float32 numpy array with the squeezed input shape plus one new axis
        at position 3 (the channel axis).
    """
    img = img.squeeze()
    original_shape = img.shape
    img = random_rotations(img, -5, 5)
    img = random_zoom(img, min=0.9, max=1.1)
    img = random_shift(img, max=0.2)
    # random_flip is currently not applied.
    # Bring the volume back to the shape it had before the transformations
    # (scipy's rotate enlarges the array by default to fit the rotated data).
    img = downscale(img, original_shape)
    img = np.expand_dims(img, axis=3)  # Restore channel axis
    # tf_augment_image declares Tout=tf.float32 for this function, so the
    # returned dtype must be float32 whatever the resize step produced.
    return img.astype(np.float32, copy=False)

def downscale(image, shape):
    """
    Resize `image` to `shape` with anti-aliasing enabled.

    Original author's note: when upscaling, anti_aliasing should be False.
    """
    return transform.resize(
        image,
        output_shape=shape,
        anti_aliasing=True,
        mode='constant',
    )

@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def tf_augment_image(input):
    """
    Graph-compatible wrapper around the numpy-based `augment_image`.

    The numpy/scipy code cannot run inside a TensorFlow graph directly, so it
    is invoked through tf.numpy_function with a declared float32 output.
    """
    return tf.numpy_function(func=augment_image, inp=[input], Tout=tf.float32)

def random_rotations(img, min_angle, max_angle):
    """
    Rotate a 3D image by a random integer angle within a random plane.

    The angle is drawn from [min_angle, max_angle] (both included) and the
    rotation plane is picked among the three pairs of axes of the volume.
    Raises AssertionError when `img` is not 3D.
    """
    assert img.ndim == 3, "Image must be 3D"
    candidate_planes = ((1, 0), (1, 2), (0, 2))
    # Draw order (angle first, then plane) is kept so seeded runs are unchanged
    degrees = np.random.randint(low=min_angle, high=max_angle + 1)
    plane = candidate_planes[np.random.randint(low=0, high=len(candidate_planes))]
    return scipy.ndimage.rotate(img, degrees, axes=plane)

def random_zoom(img,min=0.7, max=1.2):
    """
    Apply a random isotropic scaling to a 3D image.

    Args:
        img: 3D numpy array.
        min, max: bounds of the uniformly sampled scale factor.

    Returns:
        Array with the same shape as `img`.

    Note: scipy's affine_transform uses the matrix to map output coordinates
    to input coordinates, and no offset is applied, so the scaling is taken
    about index (0, 0, 0) rather than about the image centre.
    """
    # Local import: the file only does `import scipy`, which does not
    # guarantee that the `ndimage` submodule is loaded. The deprecated
    # `scipy.ndimage.interpolation` namespace is no longer used.
    from scipy import ndimage

    zoom = np.random.sample() * (max - min) + min  # Uniform in [min, max)
    # Homogeneous 4x4 matrix scaling the three spatial axes by `zoom`
    zoom_matrix = np.diag([zoom, zoom, zoom, 1.0])

    return ndimage.affine_transform(img, zoom_matrix)

def random_flip(img):
    """
    Mirror the image along one axis picked at random among axes 0, 1 and 2.

    Singleton axes are squeezed out of the result.
    """
    chosen_axis = np.random.randint(3)
    mirrored = np.flip(img, axis=chosen_axis)
    return np.squeeze(mirrored)

def random_shift(img, max=0.4):
    """
    Translate a 3D image by an independent random integer offset on each axis.

    Args:
        img: 3D numpy array (a ValueError is raised by the shape unpacking
            otherwise).
        max: fraction of each axis length that bounds the total shift range;
            the offset on an axis of length n is drawn from
            [-int(n*max/2), int(n*max/2)], both included.

    Returns:
        Array with the same shape as `img`.
    """
    # Local import: the file only does `import scipy`, which does not
    # guarantee that the `ndimage` submodule is loaded. The deprecated
    # `scipy.ndimage.interpolation` namespace is no longer used.
    from scipy import ndimage

    (x, y, z) = img.shape
    max_shifts = (int(x * max / 2), int(y * max / 2), int(z * max / 2))
    # `high` is exclusive in randint, hence the +1. The previous
    # randint(-m, m) raised ValueError whenever m == 0 (short axes or max=0)
    # and could never produce the +m offset.
    shifts = [np.random.randint(-m, m + 1) for m in max_shifts]

    # Homogeneous 4x4 matrix: identity plus the offsets in the last column
    translation_matrix = np.eye(4)
    translation_matrix[:3, 3] = shifts

    return ndimage.affine_transform(img, translation_matrix)

# Tfrecords loading

In [31]:
def read_tfrecord(example):
    """
    Decode one serialized example into an (image, one_hot_label) pair.

    Both features are read as variable-length float lists, densified and
    reshaped with the module-level IMG_SHAPE and NUM_CLASSES.
    """
    feature_spec = {
        "image": tf.io.VarLenFeature(tf.float32),
        "one_hot_label": tf.io.VarLenFeature(tf.float32),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    dense_label = tf.sparse.to_dense(parsed['one_hot_label'])
    one_hot_label = tf.reshape(dense_label, [NUM_CLASSES])

    # The shape comes from a constant (not from the record) because, per the
    # original author's note, TPU needs sizes to be known statically.
    dense_image = tf.sparse.to_dense(parsed['image'])
    image = tf.reshape(dense_image, IMG_SHAPE)

    return image, one_hot_label


def load_dataset(filenames, labels, use_tfrec, no_order=True):
    """
    Build the raw (unbatched) dataset of (image, one_hot_label) pairs.

    With use_tfrec=True the TFRecord files are read in parallel and parsed with
    read_tfrecord; no_order=True additionally allows non-deterministic element
    order. Otherwise the dataset is produced from a Python generator.
    """
    if not use_tfrec:
        # NOTE(review): generator_fn is not defined anywhere in the visible
        # notebook; this branch presumably needs it from elsewhere -- confirm.
        element_signature = (
            tf.TensorSpec(shape=IMG_SHAPE, dtype=tf.float32),
            tf.TensorSpec(shape=(NUM_CLASSES,), dtype=tf.float32),
        )
        return tf.data.Dataset.from_generator(
            generator_fn(filenames, labels),
            output_signature=element_signature,
        )

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    if no_order:
        # Allow order-altering optimizations
        relaxed_order = tf.data.Options()
        relaxed_order.experimental_deterministic = False
        records = records.with_options(relaxed_order)
    return records.map(read_tfrecord, num_parallel_calls=AUTO)

def count_data_items(filenames, use_tfrec):
    """
    Return the number of examples described by `filenames`.

    Args:
        filenames: sequence of file names.
        use_tfrec: when True every name must contain the number of examples it
            stores as `-<count>.` (e.g. 'train00-128.tfrec') and the counts are
            summed; when False each file counts as one example.

    Returns:
        Total number of examples as an int.

    Raises:
        ValueError: if use_tfrec is True and a name carries no `-<count>.` part.
    """
    if not use_tfrec:
        return len(filenames)

    # Compiled once instead of once per file. `+` (not `*`) so that a '-.'
    # sequence earlier in the path cannot match an empty count.
    count_pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        match = count_pattern.search(str(filename))
        if match is None:
            raise ValueError(
                f"Cannot read the number of examples from file name: {filename!r}")
        total += int(match.group(1))
    return total

def get_dataset(filenames, labels=None, use_tfrec=True, batch_size = 4, train=False, augment=False, cache=True, no_order=True):
    """
    Build the batched, prefetched input pipeline.

    Args:
        filenames: files to read.
        labels: labels forwarded to load_dataset (used by its non-TFRecord branch).
        use_tfrec: whether `filenames` are TFRecord files.
        batch_size: number of elements per batch.
        train: when True the dataset is shuffled and repeated indefinitely.
        augment: when True (and train is True) every element goes through
            tf_augment_image.
        cache: cache the parsed elements; do it only if the dataset fits in RAM.
        no_order: forwarded to load_dataset.

    Returns:
        A tf.data.Dataset of (image_batch, label_batch).
    """
    dataset = load_dataset(filenames, labels, use_tfrec, no_order)

    if cache:
        dataset = dataset.cache()  # Do it only if dataset fits in ram
    if train:
        # Shuffle BEFORE repeat and BEFORE the augmentation map:
        #  - every pass over the data yields each element exactly once instead
        #    of mixing elements of consecutive passes;
        #  - the shuffle buffer holds the (cached) source elements rather than
        #    a full buffer of freshly augmented copies that all had to be
        #    computed before the first batch could be produced.
        dataset = dataset.shuffle(count_data_items(filenames, use_tfrec))
        dataset = dataset.repeat()
        if augment:
            # Placed after cache/repeat so a new random transform is drawn
            # every time an element is served.
            dataset = dataset.map(lambda img, label: (tf_augment_image(img), label), num_parallel_calls=AUTO)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO)
    return dataset

# Train schedule

In [18]:
def get_lr_decay_callback(batch_size=4, verbose=False, lr_max=0.00001, lr_exp_decay=0.9, lr_min=0.000001):
    """
    Build a Keras callback applying an exponential learning-rate decay.

    The rate for epoch e is (lr_max - lr_min) * lr_exp_decay**e + lr_min.

    Args:
        batch_size: currently unused; kept so existing callers keep working.
        verbose: forwarded to LearningRateScheduler.
        lr_max: learning rate at epoch 0.
        lr_exp_decay: multiplicative decay applied per epoch.
        lr_min: value the learning rate decays towards.

    Returns:
        A tf.keras.callbacks.LearningRateScheduler.
    """
    def lrfn(epoch):
        return (lr_max - lr_min) * lr_exp_decay ** epoch + lr_min

    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=verbose)

# Build models

In [19]:
def build_model_4(input_shape):
    """
    Build the "model_4" 3D CNN.

    Layout: three stacks of 3x3x3 ReLU convolutions (2x16, 3x64 and 3x128
    filters), each followed by 2x max pooling, with batch normalization before
    the last stack, then a 256-unit dense layer and a 3-way softmax output.
    """
    def conv_stack(tensor, filters, depth):
        # Chain `depth` identical Conv3D layers
        for _ in range(depth):
            tensor = Conv3D(filters=filters, kernel_size=3, activation='relu')(tensor)
        return tensor

    inputs = tf.keras.layers.Input(input_shape)

    x = conv_stack(inputs, filters=16, depth=2)
    x = MaxPooling3D(pool_size=2)(x)

    x = conv_stack(x, filters=64, depth=3)
    x = MaxPooling3D(pool_size=2)(x)

    x = BatchNormalization()(x)
    x = conv_stack(x, filters=128, depth=3)
    x = MaxPooling3D(pool_size=2)(x)

    x = Flatten()(x)
    x = Dense(units=256, activation="relu")(x)

    # The number of output units is fixed to 3 here (not read from NUM_CLASSES)
    outputs = Dense(units=3, activation="softmax")(x)

    return tf.keras.Model(inputs, outputs, name="model_4")

# Visualization

In [20]:
def plot_epochs_history(num_epochs, history):
    """
    Plot accuracy and loss curves over epochs in one figure.

    `history` must provide the keys 'accuracy', 'val_accuracy', 'loss' and
    'val_loss', each with `num_epochs` values. Accuracy is drawn first and loss
    on a twin y-axis; the best validation accuracy and the lowest validation
    loss are highlighted.
    """
    plt.figure(figsize=(15, 5))

    # Accuracy curves
    plt.plot(np.arange(num_epochs), history['accuracy'], '-o',
             label='Train acc', color='#ff7f0e')
    plt.plot(np.arange(num_epochs), history['val_accuracy'], '-o',
             label='Val acc', color='#1f77b4')
    best_epoch = np.argmax(history['val_accuracy'])
    best_acc = np.max(history['val_accuracy'])
    x_low, x_high = plt.xlim()
    xdist = x_high - x_low
    y_low, y_high = plt.ylim()
    ydist = y_high - y_low
    plt.scatter(best_epoch, best_acc, s=150, color='#1f77b4')
    plt.text(best_epoch - 0.03 * xdist, best_acc - 0.13 * ydist,
             'max acc\n%.2f' % best_acc, size=14)
    plt.ylabel('ACC', size=14)
    plt.xlabel('Epoch', size=14)
    plt.legend(loc=2)

    # Loss curves on a twin y-axis. The pyplot calls below act on the current
    # axes -- presumably the twin axes just created; verify if the layout changes.
    loss_axes = plt.gca().twinx()
    loss_axes.plot(np.arange(num_epochs), history['loss'], '-o',
                   label='Train Loss', color='#2ca02c')
    loss_axes.plot(np.arange(num_epochs), history['val_loss'], '-o',
                   label='Val Loss', color='#d62728')
    lowest_epoch = np.argmin(history['val_loss'])
    lowest_loss = np.min(history['val_loss'])
    y_low, y_high = plt.ylim()
    ydist = y_high - y_low
    plt.scatter(lowest_epoch, lowest_loss, s=150, color='#d62728')
    plt.text(lowest_epoch - 0.03 * xdist, lowest_loss + 0.05 * ydist,
             'min loss', size=14)
    plt.ylabel('Loss', size=14)
    plt.legend(loc=3)
    plt.show()
    
def plot_cm(labels, predictions):
    """
    Draw the confusion matrix of `predictions` against `labels` as an
    annotated heatmap (rows: actual label, columns: predicted label).
    """
    matrix = confusion_matrix(labels, predictions)

    plt.figure(figsize=(5, 5))
    sns.heatmap(matrix, annot=True, fmt="d")
    plt.title('Confusion matrix')
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')

# Functions for training and evaluating models

In [21]:
def evaluate_model_kfold(model_builder, train_filenames, n_folds, batch_size, epochs, 
                         plot_fold_results = True, plot_avg_results = True, train_labels=None, 
                         stratify=False, shuffle=True, random_state=None, use_tfrec=True, augment=True):
    """
    Train and validate a freshly built model on every fold of a k-fold split.

    Args:
        model_builder: callable invoked as model_builder(input_shape=IMG_SHAPE).
        train_filenames: array of file names; indexed with the fold index arrays.
        n_folds: number of folds.
        batch_size: batch size of the training, validation and evaluation datasets.
        epochs: number of training epochs per fold (also the x-range of the plots).
        plot_fold_results: plot the history of every fold.
        plot_avg_results: plot the fold-averaged history and print a summary.
        train_labels: optional array of labels aligned with train_filenames;
            passed to the splitter and, sliced per fold, to get_dataset.
        stratify: use StratifiedKFold instead of KFold.
        shuffle, random_state: forwarded to the splitter.
        use_tfrec: forwarded to get_dataset and count_data_items.
        augment: whether the training dataset is augmented.

    Returns:
        List with one Keras history dict per fold.

    Relies on the module-level STRATEGY, tpu, IMG_SHAPE, LR and METRICS.
    """
    folds_histories = []

    splitter_cls = StratifiedKFold if stratify else KFold
    skf = splitter_cls(n_splits=n_folds, shuffle=shuffle, random_state=random_state)

    for fold, (idx_train, idx_val) in enumerate(skf.split(train_filenames, train_labels)):
        if tpu is not None:
            tf.tpu.experimental.initialize_tpu_system(tpu)

        X_train = train_filenames[idx_train]
        X_val = train_filenames[idx_val]
        # Labels are optional. The previous guard (`use_tfrec is None`) was
        # never true, so train_labels=None crashed on `None[idx_train]`.
        y_train = None if train_labels is None else train_labels[idx_train]
        y_val = None if train_labels is None else train_labels[idx_val]

        n_train = count_data_items(X_train, use_tfrec)
        n_val = count_data_items(X_val, use_tfrec)
        # The training dataset repeats forever, so a rounded step count is fine.
        train_steps = max(1, int(np.rint(n_train / batch_size)))
        # Finite datasets: round up so the last partial batch is not skipped.
        eval_train_steps = max(1, int(np.ceil(n_train / batch_size)))
        val_steps = max(1, int(np.ceil(n_val / batch_size)))

        # Build model
        tf.keras.backend.clear_session()
        with STRATEGY.scope():
            model = model_builder(input_shape=IMG_SHAPE)
            # Optimizers and Losses create TF variables --> should always be initialized in the scope
            OPT = tf.keras.optimizers.Adam(learning_rate=LR)
            LOSS = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.00)
            model.compile(optimizer=OPT, loss=LOSS, metrics=METRICS, steps_per_execution=8)

        # Train
        print(f'Training for fold {fold + 1} of {n_folds}...')
        cbks = [get_lr_decay_callback(batch_size)] # TODO: set this up properly
        history = model.fit(
            get_dataset(X_train, y_train,  use_tfrec, train=True, augment=augment, batch_size=batch_size), 
            # Use the `epochs` argument; the global EPOCHS was used before, so
            # the argument was ignored and could disagree with the plots below.
            epochs = epochs, callbacks = cbks,
            steps_per_epoch = train_steps,
            validation_data = get_dataset(X_val, y_val, use_tfrec, batch_size = batch_size, train=False), 
            validation_steps = val_steps)
    
        if tf.__version__ == "2.4.1": # TODO: delete when tensorflow fixes the bug
            # Workaround: replace the last-epoch training metrics with values
            # from an explicit evaluation on the (non-augmented) training fold.
            scores = model.evaluate(get_dataset(X_train, y_train, use_tfrec, batch_size = batch_size, train=False), 
                                    batch_size = batch_size, steps = eval_train_steps)
            for name, score in zip(model.metrics_names, scores):
                history.history[name][-1] = score
            
        folds_histories.append(history.history)
        
        if plot_fold_results:
            plot_epochs_history(epochs, history.history)
        
    avg_history = avg_results_per_epoch(folds_histories)
            
    if plot_avg_results:
        
        plot_epochs_history(epochs, avg_history)

        print('-'*80)
        print('Results per fold')
        for i, fold_history in enumerate(folds_histories):
            print('-'*80)
            out = f"> Fold {i + 1} - loss: {fold_history['loss'][-1]} - accuracy: {fold_history['accuracy'][-1]}"
            out += f" - val_loss.: {fold_history['val_loss'][-1]} - val_accuracy: {fold_history['val_accuracy'][-1]}"
            print(out)

        print('-'*80)
        print('Average results over folds (on last epoch):')
        print(f"> loss: {avg_history['loss'][-1]}")
        print(f"> accuracy: {avg_history['accuracy'][-1]}")
        print(f"> cval_loss: {avg_history['val_loss'][-1]}")
        print(f"> cval_accuracy: {avg_history['val_accuracy'][-1]}")
        print('-'*80)

    return folds_histories

def repeated_kfold(model_builder, train_filenames, n_folds, batch_size, epochs, reps=5, train_labels=None,
                   stratify=True, shuffle=True, random_state=None, use_tfrec=True):
    """
    Run evaluate_model_kfold `reps` times and collect the histories.

    Args:
        model_builder, train_filenames, n_folds, batch_size, epochs,
        train_labels, stratify, shuffle, use_tfrec: forwarded unchanged to
            evaluate_model_kfold.
        reps: number of repetitions.
        random_state: base seed of the splits. When it is an integer,
            repetition i uses `random_state + i`, so every repetition gets a
            different but reproducible partition. Other values (None or a
            RandomState instance) are forwarded as they are.

    Returns:
        List (one entry per repetition) of the per-fold history lists returned
        by evaluate_model_kfold.
    """
    reps_histories = []
    
    for i in range(reps):
        print(f'Repetition {i + 1}')
        # Passing the same integer seed to every repetition produced exactly
        # the same folds each time, so the repetitions never resampled the split.
        if isinstance(random_state, (int, np.integer)):
            rep_random_state = random_state + i
        else:
            rep_random_state = random_state

        folds_histories = evaluate_model_kfold(model_builder, train_filenames, n_folds,
                                             batch_size, epochs, train_labels=train_labels, stratify=stratify,
                                             shuffle=shuffle, random_state=rep_random_state, use_tfrec=use_tfrec)

        reps_histories.append(folds_histories)

    return reps_histories

def test_model_rkfold(model_builder, results_filename):
    """
    Evaluate `model_builder` with repeated k-fold and store the histories.

    Uses the module-level X_train, y_train, FOLDS, BATCH_SIZE, EPOCHS, REPS and
    SEED. The nested list of history dicts is written to `results_filename` as
    its Python repr (read back by show_rkfold_results).
    """
    # Evaluate model with repeated k-fold (because of the high variance)
    reps_results = repeated_kfold(model_builder, X_train, FOLDS, BATCH_SIZE, EPOCHS, reps=REPS, train_labels=y_train,
                   random_state=SEED)
    
    # Save results to disk; the context manager closes the file even if the
    # write fails.
    with open(results_filename, 'w') as f:
        f.write(repr(reps_results))
    
def train_model(model_builder, augment=True):
    """
    Train a model on the whole training set and return it.

    Args:
        model_builder: callable invoked as model_builder(IMG_SHAPE).
        augment: whether the training images are augmented. Defaults to True
            so that the final model is trained the same way as the models
            evaluated by evaluate_model_kfold (whose `augment` also defaults to
            True); pass False for the previous, non-augmented behaviour.

    Returns:
        The trained tf.keras model.

    Uses the module-level STRATEGY, LR, METRICS, IMG_SHAPE, X_train,
    BATCH_SIZE and EPOCHS.
    """
    with STRATEGY.scope():
        # Optimizer and loss are created inside the strategy scope, as in
        # evaluate_model_kfold.
        OPT = tf.keras.optimizers.Adam(learning_rate=LR)
        LOSS = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.00)
        model = model_builder(IMG_SHAPE)
        model.compile(optimizer=OPT, loss=LOSS, metrics=METRICS)

    cbks = [get_lr_decay_callback(BATCH_SIZE)] # TODO: set this up properly

    # max(1, ...) keeps at least one step per epoch for very small training sets
    steps = max(1, int(np.rint(count_data_items(X_train, use_tfrec=True)/BATCH_SIZE)))

    model.fit(
        get_dataset(X_train, train=True, augment=augment, batch_size=BATCH_SIZE), 
        epochs = EPOCHS, callbacks = cbks,
        steps_per_epoch = steps
        )
    
    return model
    
def avg_results_per_epoch(histories):
    """
    Average a list of history dicts element-wise, epoch by epoch.

    The metric names and the number of epochs are taken from the first
    history; the result maps every metric name to a list with the mean value
    of that metric at each epoch.
    """
    reference = histories[0]
    metric_names = list(reference.keys())
    n_epochs = len(reference[metric_names[0]])

    averaged = {}
    for name in metric_names:
        per_epoch_means = []
        for epoch in range(n_epochs):
            values_at_epoch = [single_history[name][epoch] for single_history in histories]
            per_epoch_means.append(np.mean(values_at_epoch))
        averaged[name] = per_epoch_means

    return averaged

def avg_reps_results(reps_histories):
    """
    Average the histories of a repeated k-fold run: first over the folds of
    each repetition, then over the repetitions.
    """
    per_repetition = []
    for folds_histories in reps_histories:
        per_repetition.append(avg_results_per_epoch(folds_histories))
    return avg_results_per_epoch(per_repetition)
    
def show_rkfold_results(results_file):
    """
    Load the repeated k-fold results saved by test_model_rkfold, plot the
    history averaged over folds and repetitions, and print a summary.

    Args:
        results_file: path of a text file containing the repr of a list
            (repetitions) of lists (folds) of history dicts.
    """
    # Load results from disk (the context manager closes the file)
    with open(results_file, 'r') as f:
        # SECURITY NOTE: eval() executes whatever the file contains. Only use
        # it on result files written by this notebook, never on files from an
        # untrusted source.
        reps_results = eval(f.read())
    
    reps_avgd_per_kfold = [avg_results_per_epoch(history) for history in reps_results]
    reps_avg = avg_results_per_epoch(reps_avgd_per_kfold)
    
    # Plot final result over epochs. The number of epochs is read from the
    # loaded data so it cannot disagree with the current global EPOCHS.
    n_epochs = len(reps_avg['val_accuracy'])
    plot_epochs_history(n_epochs, reps_avg)
    
    print('-'*80)
    print('Results per repetition (on last epoch)')
    # Iterate over the repetitions actually stored in the file instead of the
    # global REPS, which may have changed since the file was written.
    for i, rep_avg in enumerate(reps_avgd_per_kfold):
        print('-'*80)
        print(f"> Repetition {i + 1} - Loss: {rep_avg['val_loss'][-1]} - Accuracy : {rep_avg['val_accuracy'][-1]}")

    print('-'*80)
    print('Average results over repetitions (on last epoch):')
    print(f"> Train Accuracy: {reps_avg['accuracy'][-1]}")
    print(f"> Train Loss: {reps_avg['loss'][-1]}")
    print(f"> CV accuracy: {reps_avg['val_accuracy'][-1]}")
    print(f"> CV Loss: {reps_avg['val_loss'][-1]}")
    print('-'*80)
    
def show_test_results(model):
    """
    Evaluate `model` on the test set (module-level X_test / y_test), then show
    the confusion matrix and print the classification report.
    """
    # Evaluation is run for its printed metrics; the returned values are unused
    model.evaluate(get_dataset(X_test))

    # no_order=False requests deterministic element order for the predictions,
    # which are compared position by position with y_test below.
    class_probabilities = model.predict(get_dataset(X_test, no_order=False))
    y_pred = np.argmax(class_probabilities, axis=1)

    plot_cm(y_test, y_pred)
    report = classification_report(y_test, y_pred)
    print(report)

# Prepare datasets and evaluation constants

In [22]:
# Seed passed as random_state to the k-fold splits
SEED = 268

# Available TFRecord datasets and, at the same index, the shape of one image
TFREC_DATASETS = [
    'tfrec-pet-spatialnorm-elastic-standarized',
    'tfrec-mri-grey-minmax',
]
SHAPES = [
    (79, 95, 68, 1),
    (160, 160, 96, 1),
]

# Repeated k-fold configuration
REPS = 5
FOLDS = 10

CLASSES = ['NOR', 'AD', 'MCI']
NUM_CLASSES = len(CLASSES)

METRICS = ['accuracy']

# Data locations. On Kaggle with a TPU, INPUT_DATAPATH is None, which makes
# select_dataset resolve a GCS path instead of a local one.
if not kaggle:
    drive.mount('/content/drive')
    INPUT_DATAPATH = '/content/drive/MyDrive/data/'
    METADATA_PATH = '/content/drive/MyDrive/data/'
else:
    INPUT_DATAPATH = None if tpu is not None else '/kaggle/input/'
    METADATA_PATH = '/kaggle/input/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Training/testing with PET images

In [None]:
def select_dataset(ds_id):
    """
    Select the active dataset by its index in TFREC_DATASETS.

    Sets the globals DS (dataset name), IMG_SHAPE (shape of one image) and
    DS_PATH (where the files are read from). When INPUT_DATAPATH is None the
    path is obtained from KaggleDatasets after registering the gcloud
    credential; otherwise it is INPUT_DATAPATH followed by the dataset name.
    """
    global DS, IMG_SHAPE, DS_PATH, INPUT_DATAPATH
    DS = TFREC_DATASETS[ds_id]
    IMG_SHAPE = SHAPES[ds_id]

    if INPUT_DATAPATH is not None:
        DS_PATH = INPUT_DATAPATH + DS
        return

    secrets_client = UserSecretsClient()
    gcloud_credential = secrets_client.get_gcloud_credential()
    secrets_client.set_tensorflow_credential(gcloud_credential)
    DS_PATH = KaggleDatasets().get_gcs_path(DS)


# PET dataset (index 0 of TFREC_DATASETS).
# NOTE(review): this call was annotated "Not intensity normalized", but the
# dataset name ends in "standarized" -- confirm which one is right.
select_dataset(0)

metadata_train = pd.read_csv(f'{METADATA_PATH}{DS}/train/train_summary.csv', encoding='utf-8')
metadata_test = pd.read_csv(f'{METADATA_PATH}{DS}/test/test_summary.csv', encoding='utf-8')

# First CSV column: file name. Last len(CLASSES) columns: per-class scores
# (presumably a one-hot label -- verify against the CSV), reduced to a class
# index with argmax.
X_train = f'{DS_PATH}/train/' + metadata_train.iloc[:, 0].to_numpy()
y_train = np.argmax(metadata_train.iloc[:, -len(CLASSES):].to_numpy(), axis=1)
X_test = f'{DS_PATH}/test/' + metadata_test.iloc[:, 0].to_numpy()
y_test = np.argmax(metadata_test.iloc[:, -len(CLASSES):].to_numpy(), axis=1)

# Training hyper-parameters read by the training helpers
LR = 0.00001
BATCH_SIZE = 4
EPOCHS = 50

# 1) repeated k-fold estimate, 2) report it, 3) final training, 4) test, 5) save
test_model_rkfold(build_model_4, 'pet-spatialnorm-elastic-standarized_model-4_augmentation_results.txt')
show_rkfold_results('pet-spatialnorm-elastic-standarized_model-4_augmentation_results.txt')
model = train_model(build_model_4)
show_test_results(model)
model.save('pet-spatialnorm-elastic-standarized_augmentation_model-4.h5')

Repetition 1
Training for fold 1 of 10...
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50

# Training/testing with MRI images

In [None]:
# MRI dataset (index 1 of TFREC_DATASETS).
# NOTE(review): this call was annotated "Not intensity normalized", but the
# dataset name ends in "minmax" -- confirm which one is right.
select_dataset(1)

metadata_train = pd.read_csv(f'{METADATA_PATH}{DS}/train/train_summary.csv', encoding='utf-8')
metadata_test = pd.read_csv(f'{METADATA_PATH}{DS}/test/test_summary.csv', encoding='utf-8')

# First CSV column: file name. Last len(CLASSES) columns: per-class scores
# (presumably a one-hot label -- verify against the CSV), reduced to a class
# index with argmax.
X_train = f'{DS_PATH}/train/' + metadata_train.iloc[:, 0].to_numpy()
y_train = np.argmax(metadata_train.iloc[:, -len(CLASSES):].to_numpy(), axis=1)
X_test = f'{DS_PATH}/test/' + metadata_test.iloc[:, 0].to_numpy()
y_test = np.argmax(metadata_test.iloc[:, -len(CLASSES):].to_numpy(), axis=1)

# Training hyper-parameters read by the training helpers
LR = 0.00001
BATCH_SIZE = 4
EPOCHS = 50

# 1) repeated k-fold estimate, 2) report it, 3) final training, 4) test, 5) save
test_model_rkfold(build_model_4, 'mri-grey-minmax_model-4_augmentation_results.txt')
show_rkfold_results('mri-grey-minmax_model-4_augmentation_results.txt')
model = train_model(build_model_4)
show_test_results(model)
model.save('mri-grey-minmax_augmentation_model-4.h5')