<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/petrain_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Define ResNet model

In [1]:
from __future__ import (
    absolute_import,
    division,
    print_function,
    unicode_literals
)
import six
from math import ceil
from keras.models import Model
from keras.layers import (
    Input,
    Activation,
    Dense,
    Flatten
)
from keras.layers.convolutional import (
    Conv3D,
    AveragePooling3D,
    MaxPooling3D
)
from keras.layers.merge import add
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras import backend as K

# This code is an adaptation of https://github.com/raghakot/keras-resnet/blob/master/resnet.py 
# to 3D

def _bn_relu(input):
    """Apply batch normalization followed by a ReLU activation.

    Args:
        input: Tensor to normalize. Normalization runs over the axis stored
            in the module-level ``CHANNEL_AXIS``.

    Returns:
        The tensor produced by the ReLU activation.
    """
    normalized = BatchNormalization(axis=CHANNEL_AXIS)(input)
    relu = Activation("relu")
    return relu(normalized)


def _conv_bn_relu(**conv_params):
    """Return a callable that applies conv -> BN -> relu to a tensor.

    Args:
        **conv_params: ``filters`` and ``kernel_size`` are required.
            ``strides`` (default ``(1, 1, 1)``), ``kernel_initializer``
            (default ``"he_normal"``), ``padding`` (default ``"same"``) and
            ``kernel_regularizer`` (default ``l2(1e-4)``) are optional.
            Any other keyword is ignored.

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """
    conv_kwargs = dict(
        filters=conv_params["filters"],
        kernel_size=conv_params["kernel_size"],
        strides=conv_params.get("strides", (1, 1, 1)),
        kernel_initializer=conv_params.get("kernel_initializer", "he_normal"),
        padding=conv_params.get("padding", "same"),
        kernel_regularizer=conv_params.get("kernel_regularizer", l2(1e-4)),
    )

    def apply(input):
        convolved = Conv3D(**conv_kwargs)(input)
        return _bn_relu(convolved)

    return apply


def _bn_relu_conv(**conv_params):
    """Return a callable that applies BN -> relu -> conv to a tensor.

    This is the pre-activation ordering proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf

    Args:
        **conv_params: ``filters`` and ``kernel_size`` are required.
            ``strides`` (default ``(1, 1, 1)``), ``kernel_initializer``
            (default ``"he_normal"``), ``padding`` (default ``"same"``) and
            ``kernel_regularizer`` (default ``l2(1e-4)``) are optional.
            Any other keyword is ignored.

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """
    conv_kwargs = dict(
        filters=conv_params["filters"],
        kernel_size=conv_params["kernel_size"],
        strides=conv_params.get("strides", (1, 1, 1)),
        kernel_initializer=conv_params.get("kernel_initializer", "he_normal"),
        padding=conv_params.get("padding", "same"),
        kernel_regularizer=conv_params.get("kernel_regularizer", l2(1e-4)),
    )

    def apply(input):
        # Activate first, then convolve the activated tensor.
        activated = _bn_relu(input)
        return Conv3D(**conv_kwargs)(activated)

    return apply


def _shortcut(input, residual):
    """Merge ``input`` and ``residual`` with an element-wise sum.

    When the two tensors differ in spatial size or channel count, ``input``
    is first passed through a 1 x 1 x 1 convolution whose strides and filter
    count make it match ``residual``; otherwise it is added unchanged.

    Args:
        input: Tensor entering the residual unit.
        residual: Tensor produced by the residual unit.

    Returns:
        The sum of the (possibly projected) input and the residual.
    """
    in_shape = K.int_shape(input)
    res_shape = K.int_shape(residual)

    # Per-axis downsampling factor between input and residual. The ratios
    # should be whole numbers when the architecture is configured correctly.
    strides = tuple(
        ceil(in_shape[axis] / res_shape[axis])
        for axis in (ROW_AXIS, COL_AXIS, DEPTH_AXIS)
    )
    same_channels = in_shape[CHANNEL_AXIS] == res_shape[CHANNEL_AXIS]

    if same_channels and all(stride <= 1 for stride in strides):
        # Shapes already agree: identity shortcut.
        return add([input, residual])

    projected = Conv3D(
        filters=res_shape[CHANNEL_AXIS],
        kernel_size=(1, 1, 1),
        strides=strides,
        kernel_initializer="he_normal", padding="valid",
        kernel_regularizer=l2(1e-4)
        )(input)
    return add([projected, residual])


def _residual_block(block_function, filters, kernel_regularizer, repetitions,
                      is_first_layer=False):
    """Builds a residual block with repeating bottleneck blocks.
    """
    def f(input):
        for i in range(repetitions):
            strides = (1, 1, 1)
            if i == 0 and not is_first_layer:
                strides = (2, 2, 2)
            input = block_function(filters=filters, strides=strides,
                                   kernel_regularizer=kernel_regularizer,
                                   is_first_block_of_first_layer=(
                                       is_first_layer and i == 0)
                                   )(input)
        return input

    return f


def basic_block(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
                is_first_block_of_first_layer=False):
    """Return a callable building a basic residual unit of two 3 x 3 x 3 convs.

    Intended for resnets with <= 34 layers; follows the pre-activation scheme
    in http://arxiv.org/pdf/1603.05027v2.pdf

    Args:
        filters: Number of filters of both convolutions.
        strides: Strides of the first convolution.
        kernel_regularizer: Regularizer applied to both convolutions.
        is_first_block_of_first_layer: When True the first convolution is
            applied without a preceding BN -> relu.

    Returns:
        A function mapping an input tensor to the unit's output tensor.
    """
    def apply(input):
        first_conv_kwargs = dict(filters=filters, kernel_size=(3, 3, 3),
                                 strides=strides,
                                 kernel_regularizer=kernel_regularizer)
        if not is_first_block_of_first_layer:
            conv1 = _bn_relu_conv(**first_conv_kwargs)(input)
        else:
            # The stem already ended with bn -> relu -> maxpool, so the
            # activation is not repeated here.
            conv1 = Conv3D(padding="same", kernel_initializer="he_normal",
                           **first_conv_kwargs)(input)

        second_conv = _bn_relu_conv(filters=filters, kernel_size=(3, 3, 3),
                                    kernel_regularizer=kernel_regularizer)
        return _shortcut(input, second_conv(conv1))

    return apply


def bottleneck(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
               is_first_block_of_first_layer=False):
    """Return a callable building a bottleneck residual unit.

    The unit is 1 x 1 x 1 -> 3 x 3 x 3 -> 1 x 1 x 1 convolutions, the last one
    with ``filters * 4`` filters. Intended for resnets with > 34 layers;
    follows the pre-activation scheme in http://arxiv.org/pdf/1603.05027v2.pdf

    Args:
        filters: Number of filters of the first two convolutions.
        strides: Strides of the first convolution.
        kernel_regularizer: Regularizer applied to all three convolutions.
        is_first_block_of_first_layer: When True the first convolution is
            applied without a preceding BN -> relu.

    Returns:
        A function mapping an input tensor to the unit's output tensor.
    """
    def apply(input):
        reduce_kwargs = dict(filters=filters, kernel_size=(1, 1, 1),
                             strides=strides,
                             kernel_regularizer=kernel_regularizer)
        if not is_first_block_of_first_layer:
            reduced = _bn_relu_conv(**reduce_kwargs)(input)
        else:
            # The stem already ended with bn -> relu -> maxpool, so the
            # activation is not repeated here.
            reduced = Conv3D(padding="same", kernel_initializer="he_normal",
                             **reduce_kwargs)(input)

        spatial = _bn_relu_conv(filters=filters, kernel_size=(3, 3, 3),
                                kernel_regularizer=kernel_regularizer)(reduced)
        expanded = _bn_relu_conv(filters=filters * 4, kernel_size=(1, 1, 1),
                                 kernel_regularizer=kernel_regularizer)(spatial)

        return _shortcut(input, expanded)

    return apply

def _handle_data_format():
    """Set the module-level axis indices from the Keras image data format.

    Assigns ``ROW_AXIS``, ``COL_AXIS``, ``DEPTH_AXIS`` and ``CHANNEL_AXIS``
    (index 0 is left for the batch axis): channels come last for
    ``'channels_last'`` and first for any other format.
    """
    global ROW_AXIS, COL_AXIS, DEPTH_AXIS, CHANNEL_AXIS

    channels_last = K.image_data_format() == 'channels_last'
    if channels_last:
        ROW_AXIS, COL_AXIS, DEPTH_AXIS, CHANNEL_AXIS = 1, 2, 3, 4
    else:
        CHANNEL_AXIS, ROW_AXIS, COL_AXIS, DEPTH_AXIS = 1, 2, 3, 4


def _get_block(identifier):
    """Resolve a block function given either its name or the function itself.

    Args:
        identifier: A string naming a module-level object, or any other
            object, which is returned unchanged.

    Returns:
        The object found in this module's globals for a string identifier,
        otherwise ``identifier`` itself.

    Raises:
        ValueError: If a string identifier is missing from the module globals
            or resolves to a falsy value.
    """
    if not isinstance(identifier, six.string_types):
        return identifier

    resolved = globals().get(identifier)
    if resolved:
        return resolved
    raise ValueError('Invalid {}'.format(identifier))


class ResnetBuilder(object):
    """Factory for 3D ResNet Keras models."""

    # Number of residual units in each of the four stages, keyed by network
    # depth. ResNet-101 uses three units in its first stage.
    _REPETITIONS = {18: [2, 2, 2, 2],
                    34: [3, 4, 6, 3],
                    50: [3, 4, 6, 3],
                    101: [3, 4, 23, 3],
                    152: [3, 8, 36, 3]}

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions, reg_factor):
        """Builds a custom 3D ResNet like architecture.
        Args:
            input_shape: The input shape without the batch axis. It must have
                4 entries, ordered according to the active Keras image data
                format, e.g. (nb_rows, nb_cols, nb_depth, nb_channels) for
                'channels_last'.
            num_outputs: The number of units of the final dense layer. A
                softmax activation is used when it is greater than 1 and a
                sigmoid otherwise.
            block_fn: The block function to use. This is either `basic_block`
                or `bottleneck`, given as the function or as its name.
                The original paper used basic_block for layers < 50
            repetitions: Number of repetitions of various block units.
                At each block unit, the number of filters are doubled and the
                input size is halved
            reg_factor: L2 regularization factor used for the stem
                convolution, the residual stages and the final dense layer.
        Returns:
            The keras `Model`.
        Raises:
            ValueError: If `input_shape` does not have 4 entries, or
                `block_fn` is a name that cannot be resolved.
        """
        _handle_data_format()
        if len(input_shape) != 4:
            raise ValueError("Input should have 4 dimensions")

        # Load function from str if needed.
        block_fn = _get_block(block_fn)

        # Stem: strided 7x7x7 convolution followed by max pooling.
        inputs = Input(shape=input_shape)
        conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7, 7),
                              strides=(2, 2, 2),
                              kernel_regularizer=l2(reg_factor))(inputs)
        pool1 = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2),
                             padding="same")(conv1)

        # Residual stages: the filter count doubles after every stage.
        block = pool1
        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block(block_fn, filters=filters,
                                    kernel_regularizer=l2(reg_factor),
                                    repetitions=r,
                                    is_first_layer=(i == 0))(block)
            filters *= 2

        # last activation
        block = _bn_relu(block)
        block_shape = K.int_shape(block)

        # Classifier block: the pooling window spans the whole remaining
        # volume, so each channel is reduced to a single value.
        pool2 = AveragePooling3D(pool_size=(block_shape[ROW_AXIS],
                                            block_shape[COL_AXIS],
                                            block_shape[DEPTH_AXIS]),
                                 strides=(1, 1, 1))(block)
        flatten1 = Flatten()(pool2)
        activation = "softmax" if num_outputs > 1 else "sigmoid"
        dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                      activation=activation,
                      kernel_regularizer=l2(reg_factor))(flatten1)

        return Model(inputs=inputs, outputs=dense)


    @staticmethod
    def build_resnet(num_layers, input_shape, num_outputs, reg_factor=1e-4):
        """Build resnet 18, 34, 50, 101 or 152

        Args:
            num_layers: Network depth; one of 18, 34, 50, 101 or 152.
            input_shape: Input shape, see `build`.
            num_outputs: Number of units of the final dense layer.
            reg_factor: L2 regularization factor.
        Returns:
            The keras `Model`.
        Raises:
            KeyError: If `num_layers` is not a supported depth.
        """
        repetitions = ResnetBuilder._REPETITIONS[num_layers]

        # Depths below 50 use basic blocks, deeper networks use bottlenecks.
        block_fn = basic_block if num_layers < 50 else bottleneck

        # Pass a copy so the shared table cannot be modified by callers.
        return ResnetBuilder.build(input_shape, num_outputs, block_fn,
                                     list(repetitions), reg_factor=reg_factor)

# Data augmentation

In [2]:
import scipy
import skimage.transform as transform
import tensorflow as tf
import re

def augment_image(img):
    """Rotate a volume and resize the result back to its original shape.

    Args:
        img: Array that is 3D once its singleton axes are squeezed out.

    Returns:
        The augmented volume with a channel axis appended at position 3.
    """
    volume = img.squeeze()
    target_shape = volume.shape

    # Only the rotation is applied; the zoom, shift and flip helpers defined
    # in this file are currently not used here.
    rotated = random_rotations(volume, -20, 20)

    # Bring the rotated volume back to the shape it had before the rotation.
    restored = downscale(rotated, target_shape)
    return np.expand_dims(restored, axis=3)  # Restore channel axis

def downscale(image, shape):
    """Resize ``image`` to ``shape`` with anti-aliasing enabled.

    For upscaling, anti_aliasing should be False.
    """
    resize_options = {'mode': 'constant', 'anti_aliasing': True}
    return transform.resize(image, shape, **resize_options)

@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def tf_augment_image(input):
    """Run the NumPy-based ``augment_image`` from inside a TensorFlow graph.

    The augmentation is written with NumPy/SciPy, so it is wrapped with
    ``tf.numpy_function`` and its result is declared as float32.
    """
    return tf.numpy_function(augment_image, [input], tf.float32)

def random_rotations(img, min_angle, max_angle):
    """Rotate a 3D image by a random angle in a randomly chosen plane.

    Args:
        img: 3D array to rotate.
        min_angle: Smallest rotation angle in degrees (inclusive, integer).
        max_angle: Largest rotation angle in degrees (inclusive, integer).

    Returns:
        The rotated array. ``scipy.ndimage.rotate`` is called with its
        default settings, so the output shape is adapted to contain the whole
        rotated input and can differ from ``img.shape``.
    """
    assert img.ndim == 3, "Image must be 3D"
    rotation_axes = [(1, 0), (1, 2), (0, 2)]
    # Draw the angle from [min_angle, max_angle]; the upper bound of randint
    # is exclusive, hence the +1.
    angle = np.random.randint(low=min_angle, high=max_angle + 1)
    axes_random_id = np.random.randint(low=0, high=len(rotation_axes))
    axis = rotation_axes[axes_random_id]  # Select a random rotation plane
    return scipy.ndimage.rotate(img, angle, axes=axis)

def random_zoom(img,min=0.7, max=1.2):
    """Apply a random isotropic scaling to a 3D image.

    Args:
        img: 3D array to transform.
        min: Lower bound of the scaling factor (inclusive).
        max: Upper bound of the scaling factor (exclusive).

    Returns:
        The transformed array, with the same shape as ``img``.
    """
    # Generate a random factor in [min, max)
    zoom = np.random.sample()*(max - min) + min
    # NOTE(review): affine_transform maps output coordinates to input
    # coordinates, so factors above 1 shrink the content, and the scaling is
    # anchored at index (0, 0, 0) rather than at the volume centre. Confirm
    # this is the intended behaviour.
    zoom_matrix = np.array([[zoom, 0, 0, 0],
                            [0, zoom, 0, 0],
                            [0, 0, zoom, 0],
                            [0, 0, 0, 1]])

    # Call the function from scipy.ndimage directly: the
    # scipy.ndimage.interpolation namespace is deprecated.
    return scipy.ndimage.affine_transform(img, zoom_matrix)

def random_flip(img):
    """Reverse the image along one of its first three axes, chosen at random.

    Args:
        img: Array with at least three axes.

    Returns:
        The flipped array with singleton axes removed.
    """
    candidate_axes = (0, 1, 2)
    chosen = np.random.randint(len(candidate_axes))
    flipped = np.flip(img, axis=chosen)
    return np.squeeze(flipped)

def random_shift(img, max=0.4):
    """Translate a 3D image by an independent random offset along each axis.

    Args:
        img: 3D array to shift.
        max: Fraction of each axis length that bounds the shift; the offset
            along an axis of length ``n`` is an integer drawn from
            ``[-int(n*max/2), int(n*max/2)]``.

    Returns:
        The shifted array, with the same shape as ``img``.
    """
    (x, y, z) = img.shape
    max_shifts = (int(x*max/2), int(y*max/2), int(z*max/2))
    # The upper bound is made inclusive (+1) so the range is symmetric and
    # never empty: a bound of 0 simply yields no shift instead of raising.
    shift_x, shift_y, shift_z = (
        np.random.randint(-limit, limit + 1) for limit in max_shifts
    )

    translation_matrix = np.array([[1, 0, 0, shift_x],
                                   [0, 1, 0, shift_y],
                                   [0, 0, 1, shift_z],
                                   [0, 0, 0, 1]
                                   ])

    # Call the function from scipy.ndimage directly: the
    # scipy.ndimage.interpolation namespace is deprecated.
    return scipy.ndimage.affine_transform(img, translation_matrix)

# Load tfrecords

In [3]:
def read_tfrecord(example):
    """Parse one serialized example into an ``(image, one_hot_label)`` pair.

    Both features are stored as variable-length float lists; the label is
    reshaped to ``[NUM_CLASSES]`` and the image to ``IMG_SHAPE``.
    """
    feature_spec = {
        "image": tf.io.VarLenFeature(tf.float32),
        "one_hot_label": tf.io.VarLenFeature(tf.float32)
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    # VarLenFeature yields sparse tensors; densify before reshaping.
    label = tf.reshape(tf.sparse.to_dense(parsed['one_hot_label']),
                       [NUM_CLASSES])
    image = tf.reshape(tf.sparse.to_dense(parsed['image']), IMG_SHAPE)

    return image, label


def load_dataset(filenames, labels, use_tfrec, no_order=True):
    """Create an unbatched dataset of ``(image, one_hot_label)`` pairs.

    Args:
        filenames: Paths of the files to read.
        labels: Labels forwarded to ``generator_fn``; only used when
            ``use_tfrec`` is false.
        use_tfrec: Read ``filenames`` as TFRecord files when true, otherwise
            build the dataset from ``generator_fn(filenames, labels)``.
            NOTE(review): ``generator_fn`` is not defined in the visible part
            of this file — confirm it exists before using this path.
        no_order: For TFRecords, allow non-deterministic element order.

    Returns:
        A ``tf.data.Dataset``.
    """
    if not use_tfrec:
        return tf.data.Dataset.from_generator(generator_fn(filenames, labels),
            output_signature=(
                 tf.TensorSpec(shape=IMG_SHAPE, dtype=tf.float32),
                 tf.TensorSpec(shape=(NUM_CLASSES,), dtype=tf.float32)))

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO)
    if no_order:
        # Allow order-altering optimizations
        relaxed_order = tf.data.Options()
        relaxed_order.experimental_deterministic = False
        records = records.with_options(relaxed_order)
    return records.map(read_tfrecord, num_parallel_calls = AUTO)

def count_data_items(filenames, use_tfrec):
    """Return the number of samples referenced by ``filenames``.

    For TFRecord files the per-file count is read from the file name, as the
    digits between a ``-`` and the following ``.``; the counts are summed.
    Otherwise every file counts as one sample.
    """
    if not use_tfrec:
        return len(filenames)

    count_pattern = re.compile(r"-([0-9]*)\.")
    matches = map(count_pattern.search, filenames)
    return np.sum([int(match.group(1)) for match in matches])

def get_dataset(filenames, labels=None, use_tfrec=True, batch_size = 4, train=False, augment=False, cache=False, no_order=True):
    """Build the batched, prefetched input pipeline for training or evaluation.

    Args:
        filenames: Paths of the files to read.
        labels: Labels forwarded to ``load_dataset``.
        use_tfrec: Whether ``filenames`` are TFRecord files.
        batch_size: Number of samples per batch.
        train: When true the dataset repeats indefinitely and is shuffled
            with a buffer as large as the number of samples.
        augment: Apply ``tf_augment_image`` to every image. Only has an
            effect when ``train`` is true, so evaluation data is never
            augmented.
        cache: Cache the loaded samples; do it only if the dataset fits in
            RAM.
        no_order: Forwarded to ``load_dataset``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    dataset = load_dataset(filenames, labels, use_tfrec, no_order)

    if cache:
        dataset = dataset.cache() # Do it only if dataset fits in ram
    if train:
        dataset = dataset.repeat()

        if augment:
            # Augmentation comes after cache(), so each pass over the data
            # draws fresh random transformations.
            dataset = dataset.map(lambda img, label: (tf_augment_image(img), label), num_parallel_calls=AUTO)

        dataset = dataset.shuffle(count_data_items(filenames, use_tfrec))

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(AUTO)
    return dataset

In [4]:
# Build a ResNet-18 for inputs of shape (128, 128, 64, 1) with two outputs.
model = ResnetBuilder.build_resnet(
    num_layers=18,
    input_shape=(128, 128, 64, 1),
    num_outputs=2,
)

# Check augmentation

In [5]:
# Select the distribution strategy for the requested accelerator.
DEVICE = 'TPU' # or 'GPU'; any other value uses the default strategy
tpu = None

if DEVICE == 'TPU':
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        STRATEGY = tf.distribute.experimental.TPUStrategy(tpu)
    except ValueError:
        # No TPU could be resolved: fall back to the default strategy.
        print('Could not connect to TPU, setting default strategy')
        tpu = None
        STRATEGY = tf.distribute.get_strategy()
elif DEVICE == 'GPU':
    STRATEGY = tf.distribute.MirroredStrategy()
else:
    # Without this branch STRATEGY would be undefined for other DEVICE values
    # and the lines below would fail with a NameError.
    STRATEGY = tf.distribute.get_strategy()

AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = STRATEGY.num_replicas_in_sync

print(f'Number of accelerators: {REPLICAS}')

Could not connect to TPU, setting default strategy
Number of accelerators: 1


In [None]:
from google.colab import drive
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import keras


# ---- Experiment configuration ----
SEED = 34
NUM_CLASSES = 2
IMG_SHAPE = (128, 128, 64, 1)
# NOTE(review): LR is not referenced below; the optimizer uses the schedule
# that starts at initial_learning_rate. Confirm which value is intended.
LR = 0.00001
METRICS = ['accuracy']
BATCH_SIZE = 4
USE_TFREC = True
EPOCHS = 100
CLASSES = ['normal', 'covid']

drive.mount('/content/drive')

DS_PATH = '/content/drive/My Drive/data/tfrec-covid19/' # or GCS path

# DS_PATH already ends with '/', so no extra separator is added; a doubled
# separator would name a different object on an object store such as GCS.
metadata = pd.read_csv(DS_PATH + 'covid_dataset_summary.csv', encoding='utf-8')

# First column: file names. Last len(CLASSES) columns: one-hot labels.
X = DS_PATH + metadata.iloc[:, 0].to_numpy()
y = np.argmax(metadata.iloc[:, -len(CLASSES):].to_numpy(), axis=1)


# y is only needed to stratify the split; get_dataset is called with
# labels=None below, so the label arrays are dropped afterwards.
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size = 0.2, random_state = SEED, stratify = y)
y_train, y_val = None, None


# ---- Optimizer, loss and model ----
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)

with STRATEGY.scope():
    OPT = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
    LOSS = tf.keras.losses.BinaryCrossentropy()
    # Use the shared constants so the model always matches the data pipeline.
    model = ResnetBuilder.build_resnet(18, IMG_SHAPE, NUM_CLASSES)
    model.compile(optimizer = OPT, loss=LOSS, metrics= METRICS)


cbks = [keras.callbacks.ModelCheckpoint(
    "pretrained_3D_resnet.h5", save_best_only=True),
        keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=15)
    ]


# ---- Training ----
# Number of batches per pass over each split (at least one).
train_steps = max(1, int(np.rint(count_data_items(X_train, USE_TFREC)/BATCH_SIZE)))
val_steps = max(1, int(np.rint(count_data_items(X_val, USE_TFREC)/BATCH_SIZE)))

history = model.fit(
    get_dataset(X_train, None, USE_TFREC, train=True, augment=True, batch_size=BATCH_SIZE),
    epochs = EPOCHS, callbacks = cbks,
    steps_per_epoch = train_steps,
    validation_data = get_dataset(X_val, None, USE_TFREC, batch_size = BATCH_SIZE, train=False),
    validation_steps= val_steps)


if tf.__version__ == "2.4.1": # TODO: delete when tensorflow fixes the bug
    # Re-evaluate on the training split and overwrite the last recorded
    # value of each metric in the history with the evaluated score.
    scores = model.evaluate(get_dataset(X_train, None, USE_TFREC, batch_size = BATCH_SIZE, train=False),
                            batch_size = BATCH_SIZE, steps = train_steps)
    for metric_name, score in zip(model.metrics_names, scores):
        history.history[metric_name][-1] = score