<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/petrain_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import shutil
import sys
import tensorflow as tf

# Remove any previous checkout so the clone below starts from a clean directory.
if tf.io.gfile.exists('Alzheimer-disease-classification'):
    shutil.rmtree('Alzheimer-disease-classification')
# IPython shell escape: clone the project sources into the current working directory.
! git clone https://github.com/Angelvj/Alzheimer-disease-classification.git

# Make the repository's `code` folder importable (it provides `models` and `functions`).
# NOTE(review): the absolute path assumes the working directory is /content (Colab) - confirm elsewhere.
sys.path.insert(0,'/content/Alzheimer-disease-classification/code')

In [2]:
from models import ResNet
from functions.tfrec_loading import read_tfrecord, load_dataset, count_data_items
from functions.data_augmentation import random_rotations, random_zoom, random_shift, downscale
import re

# Data augmentation

In [3]:
def augment_image(img):
    """Randomly rotate a single image and hand it back with a channel axis.

    Steps: drop singleton axes, call ``random_rotations(volume, -20, 20)``,
    pass the result through ``downscale`` with the squeezed shape as target,
    then insert a new axis at position 3.

    Other augmentations are currently switched off:
    ``random_zoom(img, min=0.9, max=1.1)``, ``random_shift(img, max=0.2)``
    and ``random_flip(img)``.
    """
    volume = img.squeeze()
    target_shape = volume.shape

    rotated = random_rotations(volume, -20, 20)
    resized = downscale(rotated, target_shape)

    # Restore channel axis
    return np.expand_dims(resized, axis=3)

@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def tf_augment_image(input):
    """Graph-side wrapper around ``augment_image``.

    ``augment_image`` operates on NumPy arrays, so it is invoked through
    ``tf.numpy_function`` with a declared ``tf.float32`` output.
    """
    return tf.numpy_function(augment_image, [input], tf.float32)

# Load tfrecords

In [4]:
def get_dataset(filenames, img_shape, num_classes, autotune, batch_size = 4, 
                train=False, augment=False, cache=False, no_order=True):
    """Build a batched, prefetched ``tf.data`` pipeline from the given files.

    Args:
        filenames: files forwarded to ``load_dataset``; also passed to
            ``count_data_items`` to size the shuffle buffer.
        img_shape: image shape forwarded to ``load_dataset``.
        num_classes: number of classes forwarded to ``load_dataset``.
        autotune: parallelism/prefetch setting used for loading, for the
            augmentation ``map`` and for ``prefetch``.
        batch_size: number of elements per batch.
        train: if True, the dataset repeats indefinitely and is shuffled
            (the caller must then bound each epoch, e.g. with
            ``steps_per_epoch``).
        augment: if True (and ``train`` is True), every image goes through
            ``tf_augment_image``.
        cache: if True, cache the loaded elements; only sensible when the
            dataset fits in RAM.
        no_order: forwarded to ``load_dataset`` (presumably allows
            non-deterministic read order - confirm in ``load_dataset``).

    Returns:
        The resulting ``tf.data.Dataset`` of batches.
    """
    dataset = load_dataset(filenames, img_shape, num_classes, autotune, no_order)
    if cache:
        # Cache before repeat/augment so only the raw decoded elements are stored.
        dataset = dataset.cache()
    if train:
        dataset = dataset.repeat()
        if augment:
            dataset = dataset.map(lambda img, label: (tf_augment_image(img), label), num_parallel_calls=autotune)

        # Shuffle buffer sized to the item count reported by count_data_items.
        dataset = dataset.shuffle(count_data_items(filenames))

    dataset = dataset.batch(batch_size)
    # Use the caller-supplied setting instead of the notebook-level AUTO global,
    # which is defined in a later cell and silently ignored the `autotune` argument.
    dataset = dataset.prefetch(autotune)
    return dataset

# Hardware config.

In [5]:
# Select the accelerator and build the matching tf.distribute strategy.
DEVICE = 'TPU' # or 'GPU'
tpu = None

if DEVICE == 'TPU':
    try:
        # TPUClusterResolver raises ValueError when no TPU can be located.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        STRATEGY = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        print('Could not connect to TPU, setting default strategy')
        tpu = None
        STRATEGY = tf.distribute.get_strategy()
elif DEVICE == 'GPU':
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    # Any other DEVICE value previously left STRATEGY undefined and crashed
    # with a NameError below; fall back to the default strategy instead.
    print(f'Unknown DEVICE {DEVICE!r}, setting default strategy')
    STRATEGY = tf.distribute.get_strategy()

AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = STRATEGY.num_replicas_in_sync

print(f'Number of accelerators: {REPLICAS}')

Could not connect to TPU, setting default strategy
Number of accelerators: 1


In [None]:
from google.colab import drive
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow.keras as keras

# --- Training hyper-parameters ---
SEED = 34
EPOCHS = 100
BATCH_SIZE = 4
LR = 1e-5
METRICS = ['accuracy']

# --- Data description ---
CLASSES = ['normal', 'covid']
NUM_CLASSES = 2
IMG_SHAPE = (128, 128, 64, 1)
USE_TFREC = True

# --- Data location: mount Google Drive, the dataset lives inside it ---
drive.mount('/content/drive')

DS_PATH = '/content/drive/My Drive/data/tfrec-covid19/' # or GCS path

# DS_PATH already ends with '/', so the file name is appended without a leading
# slash; a doubled slash names a different object when DS_PATH is a GCS path.
metadata = pd.read_csv(DS_PATH + 'covid_dataset_summary.csv', encoding='utf-8')

# First column: file names relative to DS_PATH.
X = DS_PATH + metadata.iloc[:, 0].to_numpy()
# Last len(CLASSES) columns: per-class indicators (presumably one-hot - confirm
# against the CSV); argmax turns each row into a class index.
y = np.argmax(metadata.iloc[:, -len(CLASSES):].to_numpy(), axis=1)


# Stratified 80/20 split of the file list.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size = 0.2, random_state = SEED, stratify = y)
# The label arrays are only needed for stratification; the datasets built from
# X_train / X_val are not given them, so they are dropped here.
y_train, y_val = None, None


# Learning-rate schedule: staircase exponential decay (x0.96 every 100000 steps).
# NOTE(review): the LR constant defined with the other hyper-parameters is not
# used; the schedule starts from this value instead - confirm which is intended.
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

# Create optimizer, loss and model inside the distribution strategy's scope.
with STRATEGY.scope():
    OPT = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    LOSS = tf.keras.losses.BinaryCrossentropy()
    # Use the shared constants (previously duplicated as literals) so the model
    # always matches the shape / class count the data pipeline is built with.
    model = ResNet.ResnetBuilder.build_resnet(18, IMG_SHAPE, NUM_CLASSES)
    model.compile(optimizer = OPT, loss=LOSS, metrics= METRICS)


cbks = [
    # Keep only the best model seen so far on disk
    # (ModelCheckpoint monitors val_loss unless told otherwise).
    keras.callbacks.ModelCheckpoint(filepath="pretrained_3D_resnet.h5", save_best_only=True),
    # Stop once validation accuracy has not improved for 15 epochs.
    keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=15),
]
    

# Train on the cached, augmented, repeating training set; validate on the plain
# validation set. Step counts are the rounded number of batches, at least 1.
history = model.fit(
    x=get_dataset(
        X_train, IMG_SHAPE, NUM_CLASSES, AUTO,
        batch_size=BATCH_SIZE, train=True, augment=True, cache=True,
    ),
    epochs=EPOCHS,
    callbacks=cbks,
    steps_per_epoch=max(1, int(np.rint(count_data_items(X_train) / BATCH_SIZE))),
    validation_data=get_dataset(
        X_val, IMG_SHAPE, NUM_CLASSES, AUTO,
        batch_size=BATCH_SIZE, train=False,
    ),
    validation_steps=max(1, int(np.rint(count_data_items(X_val) / BATCH_SIZE))),
)


if tf.__version__ == "2.4.1": # TODO: delete when tensorflow fixes the bug
    # Workaround: re-evaluate on the training files (no repeat, no shuffle, no
    # augmentation) and overwrite the last-epoch training metrics in `history`.
    # `batch_size` is not passed to evaluate(): the dataset is already batched,
    # and Keras documents that it must not be given for dataset inputs.
    # count_data_items is called with the file list only, like every other call
    # in this notebook.
    scores = model.evaluate(
        get_dataset(X_train, IMG_SHAPE, NUM_CLASSES, AUTO, batch_size = BATCH_SIZE, train=False),
        steps = max(1, int(np.rint(count_data_items(X_train)/BATCH_SIZE))))
    for metric_name, score in zip(model.metrics_names, scores):
        history.history[metric_name][-1] = score