<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/petrain_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

# Platform flag: True when running inside a Kaggle kernel (Kaggle exports the
# KAGGLE_KERNEL_RUN_TYPE environment variable), False otherwise (Google Colab).
# Override manually (e.g. `kaggle = False`) if auto-detection is not wanted.
kaggle = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ

# Imports

In [None]:
import os
import re
import shutil
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
from sklearn.model_selection import train_test_split
if not kaggle:
    from google.colab import drive
else:
    from kaggle_datasets import KaggleDatasets
    from kaggle_secrets import UserSecretsClient

if tf.io.gfile.exists('Alzheimer-disease-classification'):
    shutil.rmtree('Alzheimer-disease-classification')
! git clone https://github.com/Angelvj/Alzheimer-disease-classification.git

sys.path.insert(0,'/content/Alzheimer-disease-classification/code')

In [6]:
import functions.data_augmentation as augmentation
import functions.io_utils as io
import functions.lr_schedules as lr_schedules
from models import ResNet
from functions.model_evaluation import plot_epochs_history, get_dataset, count_data_items
from functions.data_augmentation import random_rotations, random_zoom, random_shift, downscale

# Hardware configuration

In [4]:
# Select the tf.distribute strategy used to build and train the model.
# 'TPU' tries to attach to a TPU and falls back to the default strategy on failure;
# 'GPU' mirrors across all visible GPUs; any other value (e.g. 'CPU') uses the
# default strategy, so STRATEGY is always defined.
DEVICE = 'TPU'  # 'TPU' or 'GPU'
tpu = None

if DEVICE == 'TPU':
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        STRATEGY = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        # Failure to resolve a TPU surfaces as ValueError (e.g. no TPU attached).
        print('Could not connect to TPU, setting default strategy')
        tpu = None
        STRATEGY = tf.distribute.get_strategy()
elif DEVICE == 'GPU':
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    STRATEGY = tf.distribute.get_strategy()

AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = STRATEGY.num_replicas_in_sync

print(f'Number of accelerators: {REPLICAS}')

Could not connect to TPU, setting default strategy
Number of accelerators: 1


# Pretrain ResNet-18 on the COVID-19 dataset

In [None]:
def augment_image(img):
    """Apply random spatial augmentation to a single volume (NumPy in, NumPy out).

    The singleton axes are squeezed away, the volume is randomly rotated
    (bounds -20, 20 — presumably degrees), zoomed (factor 0.9..1.1) and shifted
    (max 0.4), then resized back to its pre-augmentation shape with
    ``augmentation.downscale``. Finally a channel axis is appended last.

    Args:
        img: numpy array holding one volume; presumably (H, W, D, 1) as in
            SHAPE/IMG_SHAPE of this notebook — TODO confirm.

    Returns:
        float32 numpy array with a trailing channel axis. The explicit cast
        guarantees the dtype matches the ``tf.float32`` output type declared by
        the ``tf.numpy_function`` wrapper, which requires an exact match.
    """
    volume = img.squeeze()
    original_shape = volume.shape
    volume = augmentation.random_rotations(volume, -20, 20)
    volume = augmentation.random_zoom(volume, min=0.9, max=1.1)
    volume = augmentation.random_shift(volume, max=0.4)
    volume = augmentation.downscale(volume, original_shape)
    # Restore the channel axis as the last axis (same as axis=3 for a 3-D volume).
    volume = np.expand_dims(volume, axis=-1)
    return volume.astype(np.float32, copy=False)

@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def tf_augment_image(input):
    """Graph-compatible wrapper around the NumPy-based ``augment_image``.

    TensorFlow graphs cannot run NumPy code directly, so ``augment_image`` is
    dispatched through ``tf.numpy_function``. Accepts a float32 tensor of any
    shape and returns a float32 tensor; the static shape of the result is not
    known to TensorFlow.
    """
    return tf.numpy_function(func=augment_image, inp=[input], Tout=tf.float32)

# Dataset folder name; the trailing '/' is relied on when building paths below.
TFREC_DATASET = 'tfrec-covid19/'
IMG_SHAPE = (128, 128, 64, 1)
SHAPE = IMG_SHAPE  # alias kept for backward compatibility
CLASSES = ['normal', 'covid']
NUM_CLASSES = len(CLASSES)
SEED = 42  # fixes the train/validation split so runs are reproducible

if kaggle:
    # With a TPU attached no local path is used: the data is read through the
    # GCS path obtained from KaggleDatasets below.
    INPUT_DATAPATH = '/kaggle/input/' if tpu is None else None
    METADATA_PATH = '/kaggle/input/'
else:
    drive.mount('/content/drive')
    INPUT_DATAPATH = '/content/drive/MyDrive/data/'
    METADATA_PATH = '/content/drive/MyDrive/data/'

if INPUT_DATAPATH is None:
    user_secrets = UserSecretsClient()
    user_credential = user_secrets.get_gcloud_credential()
    user_secrets.set_tensorflow_credential(user_credential)
    # Normalise to exactly one trailing '/' so `DS_PATH + filename` is well formed.
    DS_PATH = KaggleDatasets().get_gcs_path(TFREC_DATASET.rstrip('/')).rstrip('/') + '/'
else:
    DS_PATH = INPUT_DATAPATH + TFREC_DATASET

# METADATA_PATH and TFREC_DATASET both end with '/', so plain concatenation
# yields a well-formed path on Colab and on Kaggle.
metadata = pd.read_csv(METADATA_PATH + TFREC_DATASET + 'covid_dataset_summary.csv', encoding='utf')

# First column: file names relative to the dataset folder. Last len(CLASSES)
# columns: per-class indicators (presumably one-hot); argmax gives the label.
X = DS_PATH + metadata.iloc[:, 0].to_numpy()
y = np.argmax(metadata.iloc[:, -len(CLASSES):].to_numpy(), axis=1)

# Stratify so both splits keep the class proportions. Only the file lists are
# kept; presumably get_dataset reads the labels from the records — confirm.
X_train, X_val, _, _ = train_test_split(X, y, test_size=0.2, stratify=y, random_state=SEED)

LR = 0.00001
BATCH_SIZE = 4
EPOCHS = 100

with STRATEGY.scope():
    # Optimizer, loss and model are created inside the strategy scope so their
    # variables are managed by the selected distribution strategy.
    OPT = tf.keras.optimizers.Adam(learning_rate=LR)
    # NOTE(review): with BinaryCrossentropy, Keras resolves 'accuracy' to binary
    # accuracy. If the targets are one-hot over NUM_CLASSES with a softmax output,
    # CategoricalCrossentropy may be intended — confirm against the model head.
    LOSS = tf.keras.losses.BinaryCrossentropy()
    model = ResNet.ResnetBuilder.build_resnet(18, IMG_SHAPE, NUM_CLASSES)
    model.compile(optimizer=OPT, loss=LOSS, metrics=['accuracy'])

# Save the model from the epoch with the best validation accuracy. ModelCheckpoint
# monitors 'val_loss' unless told otherwise, so the metric is set explicitly.
# Stop training once validation accuracy has not improved for 15 epochs.
cbks = [
    keras.callbacks.ModelCheckpoint('pretrained_3D_resnet18.h5', monitor='val_accuracy',
                                    mode='max', save_best_only=True),
    keras.callbacks.EarlyStopping(monitor='val_accuracy', mode='max', patience=15),
]

# Training data is augmented and cached; validation data is neither.
train_ds = get_dataset(X_train, IMG_SHAPE, NUM_CLASSES, AUTO, batch_size=BATCH_SIZE,
                       train=True, augment=tf_augment_image, cache=True)
val_ds = get_dataset(X_val, IMG_SHAPE, NUM_CLASSES, AUTO, batch_size=BATCH_SIZE, train=False)

history = model.fit(
    train_ds,
    epochs=EPOCHS,
    callbacks=cbks,
    # At least one step per epoch, even for tiny splits.
    steps_per_epoch=max(1, int(np.rint(count_data_items(X_train) / BATCH_SIZE))),
    validation_data=val_ds,
    validation_steps=max(1, int(np.rint(count_data_items(X_val) / BATCH_SIZE))),
)

# The ModelCheckpoint callback in `cbks` keeps its best epoch in
# 'pretrained_3D_resnet18.h5'. Save the final-epoch model under a separate name
# so that checkpoint is not overwritten.
tf.keras.models.save_model(model, 'pretrained_3D_resnet18_last_epoch.h5')