<a href="https://colab.research.google.com/github/Angelvj/Alzheimer-disease-classification/blob/main/code/models/models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import six
import tensorflow.keras.backend as K
from math import ceil
from tensorflow.keras.layers import MaxPooling3D, Flatten, Dense, Conv3D, BatchNormalization, Input, Dropout, GlobalAveragePooling3D, add, Activation, concatenate
from tensorflow.keras import Model
from tensorflow.keras.regularizers import l2
from google.colab import drive

In [None]:
# Per-sample input volume shapes for each imaging modality. The trailing 1 is
# presumably the single channel axis (channels_last layout) — confirm against
# the data pipeline.
pet_shape, mri_shape = (79, 95, 68, 1), (121, 145, 121, 1)

# Useful functions

In [None]:
def change_input_shape(model, new_shape, name=None):
    """Rebuild a single-input functional Keras model with a different input shape.

    The model's config is cloned, its input layer (``layers[0]``) is replaced by
    a fresh ``InputLayer`` accepting ``(None, *new_shape)``, every layer that
    consumed the old input is re-pointed at the new one, and the weights of all
    non-input layers are copied across by layer name.

    Args:
        model: Functional ``tf.keras`` model whose first layer is its only
            ``InputLayer``.
        new_shape: Per-sample input shape, without the batch dimension.
        name: Optional new name for the input layer. Defaults to the existing
            input layer's name.

    Returns:
        A new model of the same class, with the same architecture and a copy
        of ``model``'s weights.

    Note:
        Assumes the legacy tf.keras 2.x functional config format, in which each
        inbound node is a list of ``[layer_name, node_index, tensor_index,
        kwargs]`` entries (the same format this function writes for
        ``input_layers``).
    """
    batch_input_shape = (None, *new_shape)
    model_config = model.get_config()

    old_input_config = model_config['layers'][0]
    old_input_name = old_input_config['name']
    input_layer_name = old_input_name if name is None else name
    # Preserve the original input dtype rather than forcing float32.
    input_dtype = old_input_config.get('config', {}).get('dtype', 'float32')

    model_config['layers'][0] = {
        'name': input_layer_name,
        'class_name': 'InputLayer',
        'config': {
            'batch_input_shape': batch_input_shape,
            'dtype': input_dtype,
            'sparse': False,
            'name': input_layer_name,
        },
        'inbound_nodes': [],
    }

    # Re-point every consumer of the old input layer at the new one. Rewiring
    # only layers[1] breaks models whose input feeds several layers once the
    # input is renamed; other inbound references are left untouched.
    for layer_config in model_config['layers'][1:]:
        layer_config['inbound_nodes'] = [
            [
                [input_layer_name, *entry[1:]] if entry[0] == old_input_name else entry
                for entry in node
            ]
            for node in layer_config['inbound_nodes']
        ]
    model_config['input_layers'] = [[input_layer_name, 0, 0]]

    new_model = model.__class__.from_config(model_config, custom_objects={})

    # from_config keeps layer names, so match layers by name instead of
    # relying on both models listing them in the same order.
    for layer in model.layers[1:]:
        new_model.get_layer(layer.name).set_weights(layer.get_weights())

    return new_model

# Feed-forward networks

In [None]:
def model_0_pet(input_shape=pet_shape):
    """Baseline PET classifier: a single conv/pool stage and a dense head.

    Args:
        input_shape: Shape of one input volume, excluding the batch axis.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"model_0_pet"`` ending in a
        3-unit softmax layer.
    """
    volume = Input(input_shape)

    conv = Conv3D(filters=32, kernel_size=5, activation="relu")(volume)
    pooled = MaxPooling3D(pool_size=2)(conv)

    flat = Flatten()(pooled)
    hidden = Dense(units=256, activation="relu")(flat)
    probs = Dense(units=3, activation="softmax")(hidden)

    return Model(volume, probs, name="model_0_pet")


def model_1_pet(input_shape=pet_shape):
    """PET classifier with two conv/pool stages and two hidden dense layers.

    Args:
        input_shape: Shape of one input volume, excluding the batch axis.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"model_1_pet"`` ending in a
        3-unit softmax layer.
    """
    inputs = Input(input_shape)
    x = inputs

    # Two identical 32-filter, 5x5x5 conv -> 2x max-pool stages.
    for _ in range(2):
        x = Conv3D(filters=32, kernel_size=5, activation="relu")(x)
        x = MaxPooling3D(pool_size=2)(x)

    x = Flatten()(x)
    for _ in range(2):
        x = Dense(units=256, activation="relu")(x)

    outputs = Dense(units=3, activation="softmax")(x)
    return Model(inputs, outputs, name="model_1_pet")

def best_model_pet(input_shape=pet_shape):
    """Deeper PET classifier built from L2-regularised 3x3x3 conv stages.

    Three conv stages (16, 64 and 128 filters with 2, 3 and 3 convolutions)
    each end in 2x max-pooling; batch normalisation precedes the last stage.
    The head is Flatten -> Dropout(0.2) -> Dense(256) -> Dense(128) -> a
    3-unit softmax.

    Args:
        input_shape: Shape of one input volume, excluding the batch axis.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"model_4_pet"``.
    """
    def regularised_conv(filters):
        # A fresh L2(5e-4) regulariser for every layer, as in the original stack.
        return Conv3D(filters=filters, kernel_size=3, activation='relu',
                      kernel_regularizer=l2(0.0005))

    inputs = Input(input_shape)
    x = inputs

    for _ in range(2):
        x = regularised_conv(16)(x)
    x = MaxPooling3D(pool_size=2)(x)

    for _ in range(3):
        x = regularised_conv(64)(x)
    x = MaxPooling3D(pool_size=2)(x)

    x = BatchNormalization(momentum=0.9)(x)
    for _ in range(3):
        x = regularised_conv(128)(x)
    x = MaxPooling3D(pool_size=2, strides=2)(x)

    x = Flatten()(x)
    x = Dropout(rate=0.2)(x)
    for units in (256, 128):
        x = Dense(units=units, activation='relu')(x)

    outputs = Dense(units=3, activation="softmax")(x)
    return Model(inputs, outputs, name="model_4_pet")

def best_model_mri(input_shape=mri_shape):
    """MRI classifier: four conv stages, global average pooling, dense head.

    The first three stages end in 2x max-pooling; the last ends in global
    average pooling. Batch normalisation precedes the 128- and 256-filter
    stages. The head is two (Dropout(0.2) -> Dense) pairs of 256 and 128 units
    followed by a 3-unit softmax.

    Args:
        input_shape: Shape of one input volume, excluding the batch axis.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``"model_4_mri"``.
    """
    # (filters, kernel_size, number of convolutions, batch-norm before stage?)
    stages = (
        (32, 5, 2, False),
        (64, 3, 3, False),
        (128, 3, 3, True),
        (256, 3, 3, True),
    )
    last_stage = len(stages) - 1

    inputs = Input(input_shape)
    x = inputs

    for stage_idx, (filters, kernel, n_convs, norm_first) in enumerate(stages):
        if norm_first:
            x = BatchNormalization(momentum=0.9)(x)
        for _ in range(n_convs):
            x = Conv3D(filters=filters, kernel_size=kernel, activation='relu')(x)
        if stage_idx < last_stage:
            x = MaxPooling3D(pool_size=2)(x)

    x = GlobalAveragePooling3D()(x)

    for units in (256, 128):
        x = Dropout(rate=0.2)(x)
        x = Dense(units=units, activation='relu')(x)

    outputs = Dense(units=3, activation="softmax")(x)
    return Model(inputs, outputs, name="model_4_mri")

# 3D ResNets

In [None]:
# This code is an adaptation of https://github.com/raghakot/keras-resnet/blob/master/resnet.py 
# to 3D
def _bn_relu(input):
    """Helper to build a BN -> relu block
    """
    norm = BatchNormalization(axis=CHANNEL_AXIS)(input)
    return Activation("relu")(norm)

def _conv_bn_relu(**conv_params):
    """Return a builder that applies Conv3D -> BatchNormalization -> ReLU.

    Required keys: ``filters`` and ``kernel_size``. Optional keys and their
    defaults: ``strides`` (1, 1, 1), ``kernel_initializer`` "he_normal",
    ``padding`` "same", ``kernel_regularizer`` l2(1e-4). Other keys are ignored.
    """
    conv_kwargs = dict(
        filters=conv_params["filters"],
        kernel_size=conv_params["kernel_size"],
        strides=conv_params.get("strides", (1, 1, 1)),
        kernel_initializer=conv_params.get("kernel_initializer", "he_normal"),
        padding=conv_params.get("padding", "same"),
        kernel_regularizer=conv_params.get("kernel_regularizer", l2(1e-4)),
    )

    def f(x):
        convolved = Conv3D(**conv_kwargs)(x)
        return _bn_relu(convolved)

    return f


def _bn_relu_conv(**conv_params):
    """Return a builder that applies BatchNormalization -> ReLU -> Conv3D.

    This is the pre-activation ordering proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf.

    Required keys: ``filters`` and ``kernel_size``. Optional keys and their
    defaults: ``strides`` (1, 1, 1), ``kernel_initializer`` "he_normal",
    ``padding`` "same", ``kernel_regularizer`` l2(1e-4). Other keys are ignored.
    """
    conv_kwargs = dict(
        filters=conv_params["filters"],
        kernel_size=conv_params["kernel_size"],
        strides=conv_params.get("strides", (1, 1, 1)),
        kernel_initializer=conv_params.get("kernel_initializer", "he_normal"),
        padding=conv_params.get("padding", "same"),
        kernel_regularizer=conv_params.get("kernel_regularizer", l2(1e-4)),
    )

    def f(x):
        activated = _bn_relu(x)
        return Conv3D(**conv_kwargs)(activated)

    return f


def _shortcut(input, residual):
    """Sum ``input`` and ``residual``, projecting ``input`` when shapes differ.

    The spatial stride along each axis is ``ceil(input_dim / residual_dim)``.
    If any stride exceeds 1 or the channel counts differ, ``input`` goes
    through a strided 1x1x1 convolution (he_normal init, L2 1e-4) with the
    residual's channel count. Otherwise it is used unchanged (identity).
    Relies on the module-level axis indices set by ``_handle_data_format()``.
    """
    in_shape = K.int_shape(input)
    res_shape = K.int_shape(residual)

    strides = tuple(
        ceil(in_shape[axis] / res_shape[axis])
        for axis in (ROW_AXIS, COL_AXIS, DEPTH_AXIS)
    )
    channels_differ = in_shape[CHANNEL_AXIS] != res_shape[CHANNEL_AXIS]

    if any(s > 1 for s in strides) or channels_differ:
        projection = Conv3D(
            filters=res_shape[CHANNEL_AXIS],
            kernel_size=(1, 1, 1),
            strides=strides,
            kernel_initializer="he_normal",
            padding="valid",
            kernel_regularizer=l2(1e-4),
        )
        return add([projection(input), residual])

    return add([input, residual])


def _residual_block(block_function, filters, kernel_regularizer, repetitions,
                    is_first_layer=False):
    """Return a builder that chains ``repetitions`` residual units.

    Every unit uses ``filters`` and ``kernel_regularizer``. The first unit
    downsamples with strides (2, 2, 2) unless ``is_first_layer`` is set. In
    that case it keeps strides (1, 1, 1) and is flagged as the first block of
    the first layer. All remaining units use strides (1, 1, 1).
    """
    def f(input):
        x = input
        for idx in range(repetitions):
            leading = idx == 0
            downsample = leading and not is_first_layer
            unit = block_function(
                filters=filters,
                strides=(2, 2, 2) if downsample else (1, 1, 1),
                kernel_regularizer=kernel_regularizer,
                is_first_block_of_first_layer=(is_first_layer and leading),
            )
            x = unit(x)
        return x

    return f


def basic_block(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
                is_first_block_of_first_layer=False):
    """Two 3x3x3 convolutions plus a shortcut, for ResNets with <= 34 layers.

    Uses the pre-activation ordering of http://arxiv.org/pdf/1603.05027v2.pdf.
    ``strides`` applies to the first convolution only.

    Returns:
        A callable mapping an input tensor to the residual unit's output.
    """
    def f(input):
        if is_first_block_of_first_layer:
            # The stem already ran BN -> ReLU -> max-pool, so start with a bare conv.
            entry_conv = Conv3D(filters=filters, kernel_size=(3, 3, 3),
                                strides=strides, padding="same",
                                kernel_initializer="he_normal",
                                kernel_regularizer=kernel_regularizer)
        else:
            entry_conv = _bn_relu_conv(filters=filters, kernel_size=(3, 3, 3),
                                       strides=strides,
                                       kernel_regularizer=kernel_regularizer)
        x = entry_conv(input)

        second_conv = _bn_relu_conv(filters=filters, kernel_size=(3, 3, 3),
                                    kernel_regularizer=kernel_regularizer)
        residual = second_conv(x)
        return _shortcut(input, residual)

    return f


def bottleneck(filters, strides=(1, 1, 1), kernel_regularizer=l2(1e-4),
               is_first_block_of_first_layer=False):
    """1x1x1 -> 3x3x3 -> 1x1x1 bottleneck unit for ResNets deeper than 34 layers.

    Uses the pre-activation ordering of http://arxiv.org/pdf/1603.05027v2.pdf.
    ``strides`` applies to the first (1x1x1) convolution only.

    Returns:
        A callable mapping an input tensor to the unit's output. The final
        convolution has ``filters * 4`` channels.
    """
    def f(input):
        if is_first_block_of_first_layer:
            # The stem already ran BN -> ReLU -> max-pool, so start with a bare conv.
            entry_conv = Conv3D(filters=filters, kernel_size=(1, 1, 1),
                                strides=strides, padding="same",
                                kernel_initializer="he_normal",
                                kernel_regularizer=kernel_regularizer)
        else:
            entry_conv = _bn_relu_conv(filters=filters, kernel_size=(1, 1, 1),
                                       strides=strides,
                                       kernel_regularizer=kernel_regularizer)
        reduced = entry_conv(input)

        spatial = _bn_relu_conv(filters=filters, kernel_size=(3, 3, 3),
                                kernel_regularizer=kernel_regularizer)(reduced)
        residual = _bn_relu_conv(filters=filters * 4, kernel_size=(1, 1, 1),
                                 kernel_regularizer=kernel_regularizer)(spatial)

        return _shortcut(input, residual)

    return f

def _handle_data_format():
    """Set the module-level axis indices for the active Keras image data format.

    'channels_last' gives ROW, COL, DEPTH, CHANNEL = 1, 2, 3, 4.
    Any other value gives CHANNEL = 1 and ROW, COL, DEPTH = 2, 3, 4.
    """
    global ROW_AXIS, COL_AXIS, DEPTH_AXIS, CHANNEL_AXIS
    channels_last = K.image_data_format() == 'channels_last'
    ROW_AXIS, COL_AXIS, DEPTH_AXIS, CHANNEL_AXIS = (
        (1, 2, 3, 4) if channels_last else (2, 3, 4, 1)
    )


def _get_block(identifier):
    """Resolve a residual-unit function given either the function or its name.

    Args:
        identifier: A non-string object, which is returned unchanged, or the
            name of a callable defined at module level in this file (e.g.
            ``"basic_block"`` or ``"bottleneck"``).

    Returns:
        The resolved block function.

    Raises:
        ValueError: If ``identifier`` is a string that does not name a
            callable module-level object.
    """
    # Python 3 only (TF2), so plain ``str`` replaces ``six.string_types``.
    if not isinstance(identifier, str):
        return identifier
    block = globals().get(identifier)
    # Reject names that resolve to non-callables (module aliases, shape
    # constants, dunder strings) instead of handing them to the builder.
    if block is None or not callable(block):
        raise ValueError('Invalid {}'.format(identifier))
    return block


class ResnetBuilder(object):
    """Factory for 3D ResNet-style classifiers."""

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions, reg_factor):
        """Build a custom 3D ResNet-like architecture.

        Args:
            input_shape: Per-sample input shape with exactly 4 entries: three
                spatial dimensions plus the channel axis, ordered according to
                ``K.image_data_format()`` (e.g. ``(rows, cols, depth, channels)``
                for channels_last).
            num_outputs: Units in the final layer. Softmax is used when
                ``num_outputs > 1``, sigmoid otherwise.
            block_fn: Residual unit to use, ``basic_block`` or ``bottleneck``,
                or the name of either as a string. The original paper used
                ``basic_block`` for fewer than 50 layers.
            repetitions: Number of residual units in each stage. Each stage
                after the first doubles the filter count and downsamples with
                stride 2.
            reg_factor: L2 factor applied to the convolution and dense kernels.

        Returns:
            An uncompiled ``tf.keras.Model``.

        Raises:
            ValueError: If ``input_shape`` does not have 4 entries, or
                ``block_fn`` is an unknown name.
        """
        _handle_data_format()
        if len(input_shape) != 4:
            raise ValueError("Input should have 4 dimensions")

        # Load function from str if needed.
        block_fn = _get_block(block_fn)

        inputs = Input(shape=input_shape)
        # Stem: strided 7x7x7 conv -> BN -> ReLU, then a strided max-pool.
        conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7, 7), strides=(2, 2, 2),
                              kernel_regularizer=l2(reg_factor))(inputs)
        block = MaxPooling3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="same")(conv1)

        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block(block_fn, filters=filters,
                                    kernel_regularizer=l2(reg_factor),
                                    repetitions=r, is_first_layer=(i == 0))(block)
            filters *= 2

        # Pre-activation units end in a conv, so finish with BN -> ReLU.
        block = _bn_relu(block)

        # Classifier block.
        pooled = GlobalAveragePooling3D()(block)
        activation = "softmax" if num_outputs > 1 else "sigmoid"
        outputs = Dense(units=num_outputs, kernel_initializer="he_normal",
                        activation=activation,
                        kernel_regularizer=l2(reg_factor))(pooled)

        return tf.keras.Model(inputs=inputs, outputs=outputs)

    @staticmethod
    def build_resnet(num_layers, input_shape, num_outputs, reg_factor=1e-4):
        """Build a 3D ResNet-18, -34, -50, -101 or -152.

        Raises:
            KeyError: If ``num_layers`` is not one of the supported depths.
        """
        # Residual units per stage (He et al., 2015, Table 1). Depth is
        # 2 * sum + 2 for basic blocks and 3 * sum + 2 for bottlenecks.
        repetitions = {18: [2, 2, 2, 2],
                       34: [3, 4, 6, 3],
                       50: [3, 4, 6, 3],
                       101: [3, 4, 23, 3],
                       152: [3, 8, 36, 3]}

        block_fn = {18: basic_block, 34: basic_block, 50: bottleneck,
                    101: bottleneck, 152: bottleneck}

        return ResnetBuilder.build(input_shape, num_outputs, block_fn[num_layers],
                                   repetitions[num_layers], reg_factor=reg_factor)

# Pretrained ResNet for transfer learning

In [None]:
# Build pretrained resnet for transfer learning
def build_pretrained_resnet18(input_shape=mri_shape, include_top=False, num_classes=3,
                              weights_path='/content/drive/MyDrive/pretrained_models/pretrained_3D_resnet18.h5'):
    """Build a frozen, pretrained 3D ResNet-18 feature extractor with a new head.

    The saved model is loaded, and its last layer (the pretrained classifier)
    is dropped. The remainder is rebuilt for ``input_shape`` and frozen, and
    a fresh ``Dense(num_classes)`` layer is attached.

    Args:
        input_shape: Per-sample input shape the returned model accepts.
        include_top: Currently unused and kept for backward compatibility. The
            pretrained classifier is always dropped and a new head added.
        num_classes: Units in the new output layer. It has no activation, so it
            emits logits. Presumably training uses a ``from_logits`` loss —
            confirm.
        weights_path: Location of the saved Keras model. Google Drive is
            mounted first when the path lies under ``/content/drive``.

    Returns:
        A ``tf.keras.Model`` taking an input named ``'mri_input'``.
    """
    if str(weights_path).startswith('/content/drive'):
        drive.mount('/content/drive')
    base = tf.keras.models.load_model(weights_path)
    # Drop the pretrained classifier: keep everything up to the penultimate layer.
    base = Model(base.input, base.layers[-2].output)
    base = change_input_shape(base, input_shape, name='new_input')
    base.trainable = False

    inputs = Input(shape=input_shape, name='mri_input')
    # Call the frozen base in inference mode, so that any BatchNormalization
    # layers it contains keep using their stored statistics.
    features = base(inputs, training=False)
    outputs = Dense(num_classes)(features)
    return Model(inputs, outputs)

# Two-input model

In [None]:
# Model with two inputs
def build_two_input_model(pet_shape=pet_shape, mri_shape=mri_shape, num_classes=3):
    """Build a two-branch PET + MRI classifier.

    Each branch is the matching ``best_model_*`` network with its last
    (softmax) layer removed. The two feature vectors are concatenated and fed
    to a single ``Dense`` layer named ``'label'``. That layer has no
    activation, so the model outputs logits.

    Args:
        pet_shape: Per-sample PET input shape.
        mri_shape: Per-sample MRI input shape.
        num_classes: Units in the ``'label'`` output layer.

    Returns:
        A ``tf.keras.Model`` with inputs ``[pet_input, mri_input]`` and a single
        output.
    """
    # PET branch: best PET network without its classifier.
    pet_model = best_model_pet(pet_shape)
    pet_model = Model(pet_model.input, pet_model.layers[-2].output)
    pet_input = Input(shape=pet_shape, name='pet_input')
    pet_features = pet_model(pet_input)

    # MRI branch: best MRI network without its classifier. The previous
    # ``print(mri_model.summary())`` debug line is dropped: summary() already
    # prints and returns None, so it also printed a stray "None".
    mri_model = best_model_mri(mri_shape)
    mri_model = Model(mri_model.input, mri_model.layers[-2].output)
    mri_input = Input(shape=mri_shape, name='mri_input')
    mri_features = mri_model(mri_input)

    # Fuse both branches and classify.
    x = concatenate([pet_features, mri_features])
    pred = Dense(num_classes, name='label')(x)
    return Model(inputs=[pet_input, mri_input], outputs=[pred])