# Model Structures
In this file I show the structure of each neural network used in this project by building the model and printing its summary.
The settings used are those of the control training runs: T1c images, 2 classes, no clinical data.

In [3]:
import tensorflow as tf
from pathlib import Path
import os
from time import strftime
from functools import partial
import numpy as np
from enum import Enum, auto

In [4]:
class Dataset(Enum):
    """Identifies which dataset a model is being built for."""

    NORMAL = 1
    PRETRAIN_ROUGH = 2  # external brain_tumor_dataset
    PRETRAIN_FINE = 3

# Settings read as module-level globals by the model builders in this file.
dropout_rate = 0.4  # rate passed to the Dropout layers of the dense head
l2_regularization = 0.0001  # factor given to tf.keras.regularizers.l2 for conv/dense kernels
learning_rate = 0.001  # learning rate of the SGD optimizer
input_shape = (240, 240, 1)  # shape of image_input; presumably (height, width, channels) — confirm
contrast_DA = False  # not read by any code visible in this file; presumably toggles a contrast augmentation — confirm
clinical_data = False  # when truthy, sex_input and age_input are concatenated to the image features
use_layer = False  # when truthy, layer_input is concatenated to the image features
activation_func = "mish"  # activation used by the dense partials, the stem Activation layers and most conv partials
num_classes = 2  # 2 -> one sigmoid output unit; 3..6 -> softmax over num_classes units
dataset_type = Dataset.NORMAL  # selects which inputs the built model exposes

In [5]:
class NormalizeToRange(tf.keras.layers.Layer):
    """Layer that rescales input tensor values to [0, 1] or [-1, 1].

    The minimum and maximum are reduced over the whole input tensor (all
    axes at once), so every element of the input shares the same scaling.

    Args:
        zero_to_one: If True the output range is [0, 1], otherwise [-1, 1].
        epsilon: Lower bound applied to ``max - min`` before dividing, so a
            constant input does not divide by zero.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, zero_to_one=True, epsilon=1e-7, **kwargs):
        super().__init__(**kwargs)
        self.zero_to_one = zero_to_one
        self.epsilon = epsilon

    def call(self, inputs):
        min_val = tf.reduce_min(inputs)
        max_val = tf.reduce_max(inputs)

        # Clamp the divisor from below to avoid dividing by zero.
        range_val = tf.maximum(max_val - min_val, self.epsilon)

        if self.zero_to_one:
            # Normalize to [0, 1]
            normalized = (inputs - min_val) / range_val
        else:
            # Normalize to [-1, 1]
            normalized = 2 * (inputs - min_val) / range_val - 1
        return normalized

    def get_config(self):
        """Return the layer config including the constructor arguments.

        Without this, a saved/cloned layer would be re-created with the
        default ``zero_to_one`` and ``epsilon`` values.
        """
        config = super().get_config()
        config.update({
            "zero_to_one": self.zero_to_one,
            "epsilon": self.epsilon,
        })
        return config
    
class RandomRescale(tf.keras.layers.Layer):
    """Custom layer for random rescaling of images during training.

    When ``training`` is truthy, one scaling factor is drawn per call, the
    input is resized by that factor and then cropped or padded back to its
    original height and width. Otherwise the input is returned unchanged.

    Args:
        scale_range (tuple): Minimum and maximum scaling factors.
            Values < 1.0 zoom out, and > 1.0 zoom in.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, scale_range=(0.8, 1.2), **kwargs):
        super().__init__(**kwargs)
        self.scale_range = scale_range

    def call(self, inputs, training=None):
        if not training:
            return inputs

        # A single scalar factor is sampled for the whole input tensor.
        factor = tf.random.uniform([], self.scale_range[0], self.scale_range[1])

        # Axes 1 and 2 are read as height and width, i.e. the input is
        # indexed as if it carries a leading batch axis.
        dims = tf.shape(inputs)
        orig_h = dims[1]
        orig_w = dims[2]

        target_h = tf.cast(tf.cast(orig_h, tf.float32) * factor, tf.int32)
        target_w = tf.cast(tf.cast(orig_w, tf.float32) * factor, tf.int32)

        resized = tf.image.resize(inputs, [target_h, target_w])

        # Bring the result back to the original spatial size.
        return tf.image.resize_with_crop_or_pad(resized, orig_h, orig_w)

    def get_config(self):
        base_config = super().get_config()
        base_config.update({"scale_range": self.scale_range})
        return base_config
    
# Augmentation pipeline shared by every model builder in this file: each
# builder applies this same Sequential instance directly to image_input.
# Layers run in the order listed below.
normal_data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip(mode = "horizontal"),
    tf.keras.layers.RandomRotation(factor = (-0.14, 0.14), fill_mode = "nearest"),
    # Rescale intensities to [0, 1]; note this sits between the rotation and
    # the translation/rescale steps.
    NormalizeToRange(zero_to_one=True),
    tf.keras.layers.RandomTranslation(
        height_factor = 0.05,
        width_factor = 0.05,
        fill_mode = "nearest",
        interpolation = "bilinear"
    ),
    # Random zoom with a factor in [0.7, 1.2], cropped/padded back to size.
    RandomRescale(scale_range=(0.7, 1.2))
], name = "normal_data_augmentation")

## Custom CNN

In [8]:
def build_conv_model():
    """Build, compile and print the summary of the custom CNN.

    Structure: augmentation -> three BatchNorm/Conv/MaxPool stages with
    64, 128 and 256 filters -> Flatten -> two BatchNorm/Dense/Dropout
    blocks (512 and 256 units) -> output head.

    Configuration is read from the module-level variables
    ``activation_func``, ``l2_regularization``, ``learning_rate``,
    ``input_shape``, ``dropout_rate``, ``clinical_data``, ``use_layer``,
    ``num_classes`` and ``dataset_type``.

    Returns:
        The compiled ``tf.keras.Model``. ``model.summary()`` is printed as a
        side effect.

    Raises:
        ValueError: If ``num_classes`` is not between 2 and 6.
    """
    DefaultConv2D = partial(
        tf.keras.layers.Conv2D,
        kernel_size = 3,
        padding = "same",
        activation = activation_func,
        kernel_initializer = "he_normal",
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)  # L2 Regularization
    )

    DefaultDenseLayer = partial(
        tf.keras.layers.Dense,
        activation = activation_func,
        kernel_initializer = "he_normal",
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    optimizer = tf.keras.optimizers.SGD(learning_rate = learning_rate, momentum = 0.9, nesterov = True)

    # Define inputs (only those selected below become part of the model)
    image_input = tf.keras.layers.Input(shape=input_shape, name="image_input")
    sex_input = tf.keras.layers.Input(shape=(1,), name="sex_input")
    age_input = tf.keras.layers.Input(shape=(1,), name="age_input")
    layer_input = tf.keras.layers.Input(shape=(1,), name="layer_input")

    # Choose Data Augmentation pipeline
    augment_layer = normal_data_augmentation

    # --- Model Architecture ---
    x = augment_layer(image_input) # Apply augmentation first

    x = tf.keras.layers.BatchNormalization(name = "b_norm_1")(x) # BN before first conv
    x = DefaultConv2D(filters = 64, kernel_size = 7, strides = 2, name = "conv_1")(x)
    x = tf.keras.layers.MaxPool2D(pool_size = (2,2), name = "pool_1")(x)

    x = tf.keras.layers.BatchNormalization(name = "b_norm_2")(x)
    x = DefaultConv2D(filters = 128, name = "conv_2a")(x)
    x = DefaultConv2D(filters = 128, name = "conv_2b")(x)
    x = tf.keras.layers.MaxPool2D(pool_size = (2,2), name = "pool_2")(x)

    x = tf.keras.layers.BatchNormalization(name = "b_norm_3")(x)
    x = DefaultConv2D(filters = 256, name = "conv_3a")(x)
    x = DefaultConv2D(filters = 256, name = "conv_3b")(x)
    x = tf.keras.layers.MaxPool2D(pool_size = (2,2), name = "pool_3")(x)

    image_features = tf.keras.layers.Flatten(name = "flatten")(x)

    # --- Feature Concatenation ---
    # The extra inputs are selected once, from 'clinical_data', 'use_layer'
    # and 'dataset_type', and the same list drives both the concatenation and
    # the model's input list. Non-NORMAL datasets get an image-only model;
    # previously their flags still fed the concatenation, so the model was
    # built with tensors missing from its inputs (disconnected graph).
    extra_inputs = []
    if dataset_type == Dataset.NORMAL:
        if clinical_data:
            extra_inputs.extend([sex_input, age_input])
        if use_layer:
            extra_inputs.append(layer_input)

    if extra_inputs:
        concatenated_features = tf.keras.layers.Concatenate(name = "concat_features")(
            [image_features, *extra_inputs]
        )
    else:
        concatenated_features = image_features # No concatenation needed

    # --- Dense Layers ---
    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_1")(concatenated_features)
    x = DefaultDenseLayer(units = 512, name = "dense_1")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_1")(x)

    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_2")(x)
    x = DefaultDenseLayer(units = 256, name = "dense_2")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_2")(x)

    # --- Output Layer ---
    if num_classes == 2:
        # Binary Classification: one unit followed by a float32 sigmoid
        x = tf.keras.layers.Dense(1, name=f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('sigmoid', dtype='float32', name='predictions')(x)
        loss = "binary_crossentropy"
        metrics = ["accuracy",
                   tf.keras.metrics.AUC(name = "auc"),
                   tf.keras.metrics.Precision(name = "precision"),
                   tf.keras.metrics.Recall(name = "recall")]
    elif 2 < num_classes <= 6:
        # Multi-class: one unit per class followed by a float32 softmax
        x = tf.keras.layers.Dense(num_classes, name=f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('softmax', dtype='float32', name='predictions')(x)
        loss = "sparse_categorical_crossentropy"
        metrics = ["accuracy"]
    else:
        raise ValueError("num_classes must have a value between 2 and 6")

    # --- Create and compile model ---
    model = tf.keras.Model(inputs = [image_input, *extra_inputs], outputs = [output])

    model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
    )

    model.summary()

    return model


In [9]:
# Build the custom CNN; build_conv_model() prints the model summary itself.
build_conv_model()

<Functional name=functional_1, built=True>

## ResNet34

In [12]:
def build_resnet34_model():
    """Build, compile and print the summary of the ResNet34-style model.

    Structure: augmentation -> stem (BatchNorm, 7x7 stride-2 conv, BatchNorm,
    activation, max pool) -> 16 residual units (3x64, 4x128, 6x256, 3x512
    filters) -> global average pooling -> two BatchNorm/Dense/Dropout blocks
    (512 and 256 units) -> output head.

    Configuration is read from the module-level variables
    ``activation_func``, ``l2_regularization``, ``learning_rate``,
    ``input_shape``, ``dropout_rate``, ``clinical_data``, ``use_layer``,
    ``num_classes`` and ``dataset_type``.

    Returns:
        The compiled ``tf.keras.Model``. ``model.summary()`` is printed as a
        side effect.

    Raises:
        ValueError: If ``num_classes`` is not between 2 and 6.
    """
    DefaultConv2D = partial(
        tf.keras.layers.Conv2D,
        kernel_size = 3,
        strides = 1,
        padding="same",
        activation = activation_func,
        kernel_initializer = "he_normal",
        use_bias = False,
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    DefaultDenseLayer = partial(
        tf.keras.layers.Dense,
        activation = activation_func,
        kernel_initializer = "he_normal",
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    class ResidualUnit(tf.keras.layers.Layer):
        """Two 3x3 convolutions with a skip connection.

        When ``strides > 1`` the skip path is a 1x1 strided convolution plus
        BatchNormalization; otherwise the input is added unchanged.

        NOTE(review): DefaultConv2D already applies ``activation_func``, so
        every conv output is activated before its BatchNormalization, and
        this unit's own ``activation`` stays at the default "relu" because
        the caller below never passes one — confirm this is intended.
        """

        def __init__(self, filters, strides=1, activation="relu", **kwargs):
            super().__init__(**kwargs)
            self.activation = tf.keras.activations.get(activation)
            self.main_layers = [
                DefaultConv2D(filters, strides = strides),
                tf.keras.layers.BatchNormalization(),
                self.activation,
                DefaultConv2D(filters),
                tf.keras.layers.BatchNormalization()
            ]
            self.skip_layers = []
            if strides > 1:
                self.skip_layers = [
                    DefaultConv2D(filters, kernel_size=1, strides=strides),
                    tf.keras.layers.BatchNormalization()
                ]

        def call(self, inputs):
            Z = inputs
            for layer in self.main_layers:
                Z = layer(Z)
            skip_Z = inputs
            for layer in self.skip_layers:
                skip_Z = layer(skip_Z)
            return self.activation(Z + skip_Z)

    optimizer = tf.keras.optimizers.SGD(learning_rate = learning_rate, momentum = 0.9, nesterov = True)

    # Define inputs (only those selected below become part of the model)
    image_input = tf.keras.layers.Input(shape=input_shape, name = "image_input")
    sex_input = tf.keras.layers.Input(shape=(1,), name = "sex_input")
    age_input = tf.keras.layers.Input(shape=(1,), name = "age_input")
    layer_input = tf.keras.layers.Input(shape=(1,), name = "layer_input")

    # Choose data augmentation pipeline
    augment_layer = normal_data_augmentation

    # --- Model Architecture ---
    x = augment_layer(image_input)
    x = tf.keras.layers.BatchNormalization(name = "b_norm_1")(x)
    x = DefaultConv2D(filters = 64, kernel_size = 7, strides = 2, name = "conv_1")(x)
    x = tf.keras.layers.BatchNormalization(name = "b_norm_2")(x)
    x = tf.keras.layers.Activation(activation_func)(x)
    x = tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = "same", name = "pool_1")(x)

    # A unit downsamples (stride 2) exactly when its filter count differs
    # from the previous unit's.
    prev_filters = 64
    unit_filters = [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3
    for unit_index, filters in enumerate(unit_filters):
        strides = 1 if filters == prev_filters else 2
        x = ResidualUnit(
            filters,
            strides = strides,
            name = f"residual_unit_layer_{unit_index}_filters_{filters}"
        )(x)
        prev_filters = filters

    x = tf.keras.layers.GlobalAveragePooling2D(name = "gap")(x)
    resnet_image_features = tf.keras.layers.Flatten(name = "flatten")(x)

    # --- Feature Concatenation ---
    # The extra inputs are selected once, from 'clinical_data', 'use_layer'
    # and 'dataset_type', and the same list drives both the concatenation and
    # the model's input list. Non-NORMAL datasets get an image-only model;
    # previously their flags still fed the concatenation, so the model was
    # built with tensors missing from its inputs (disconnected graph).
    extra_inputs = []
    if dataset_type == Dataset.NORMAL:
        if clinical_data:
            extra_inputs.extend([sex_input, age_input])
        if use_layer:
            extra_inputs.append(layer_input)

    if extra_inputs:
        concatenated_features = tf.keras.layers.Concatenate(name = "concat_features")(
            [resnet_image_features, *extra_inputs]
        )
    else:
        concatenated_features = resnet_image_features # No concatenation needed

    # --- Dense Layers ---
    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_1")(concatenated_features)
    x = DefaultDenseLayer(units = 512, name = "dense_1")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_1")(x)

    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_2")(x)
    x = DefaultDenseLayer(units = 256, name = "dense_2")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_2")(x)

    # --- Output Layer ---
    if num_classes == 2:
        # Binary Classification: one unit followed by a float32 sigmoid
        x = tf.keras.layers.Dense(1, name = f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('sigmoid', dtype='float32', name='predictions')(x)
        loss = "binary_crossentropy"
        metrics = ["accuracy",
                   tf.keras.metrics.AUC(name = "auc"),
                   tf.keras.metrics.Precision(name = "precision", thresholds = 0.5),
                   tf.keras.metrics.Recall(name = "recall", thresholds = 0.5),
                   tf.keras.metrics.F1Score(name = "f1_score", threshold = 0.5, average="micro")]
    elif 2 < num_classes <= 6:
        # Multi-class: one unit per class followed by a float32 softmax
        x = tf.keras.layers.Dense(num_classes, name = f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('softmax', dtype='float32', name='predictions')(x)
        loss = "sparse_categorical_crossentropy"
        metrics = ["accuracy"]
    else:
        raise ValueError("num_classes must have a value between 2 and 6")

    # --- Create and compile model ---
    model = tf.keras.Model(inputs = [image_input, *extra_inputs], outputs = [output])

    model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
    )

    model.summary()

    return model

In [14]:
# Note: despite its name, conv_model holds the ResNet34 model here.
conv_model = build_resnet34_model()

Because long layer names get truncated in the model summary, I'll print out all the layer names in full.

In [15]:
# Print each top-level layer name of the ResNet34 model (stored in conv_model).
for layer in conv_model.layers:
    print(layer.name)

image_input
normal_data_augmentation
b_norm_1
conv_1
b_norm_2
activation_1
pool_1
residual_unit_layer_0_filters_64
residual_unit_layer_1_filters_64
residual_unit_layer_2_filters_64
residual_unit_layer_3_filters_128
residual_unit_layer_4_filters_128
residual_unit_layer_5_filters_128
residual_unit_layer_6_filters_128
residual_unit_layer_7_filters_256
residual_unit_layer_8_filters_256
residual_unit_layer_9_filters_256
residual_unit_layer_10_filters_256
residual_unit_layer_11_filters_256
residual_unit_layer_12_filters_256
residual_unit_layer_13_filters_512
residual_unit_layer_14_filters_512
residual_unit_layer_15_filters_512
gap
flatten
b_norm_dense_1
dense_1
dropout_1
b_norm_dense_2
dense_2
dropout_2
dense_output_2cls
predictions


## ResNet152

In [1]:
def build_resnet152_model():
    """Build, compile and print the summary of the ResNet152-style model.

    Structure: augmentation -> stem (BatchNorm, 7x7 stride-2 conv, BatchNorm,
    activation, max pool) -> bottleneck residual units in four stages
    (3, 8, 36 and 3 units with 64, 128, 256 and 512 base filters) -> global
    average pooling -> two BatchNorm/Dense/Dropout blocks (512 and 256
    units) -> output head.

    Configuration is read from the module-level variables
    ``activation_func``, ``l2_regularization``, ``learning_rate``,
    ``input_shape``, ``dropout_rate``, ``clinical_data``, ``use_layer``,
    ``num_classes`` and ``dataset_type``.

    Returns:
        The compiled ``tf.keras.Model``. ``model.summary()`` is printed as a
        side effect.

    Raises:
        ValueError: If ``num_classes`` is not between 2 and 6.
    """
    # Only used for the stem convolution below.
    DefaultConv2D = partial(
        tf.keras.layers.Conv2D,
        kernel_size = 3,
        strides = 1,
        padding="same",
        activation = activation_func,
        kernel_initializer = "he_normal",
        use_bias = False,
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    DefaultDenseLayer = partial(
        tf.keras.layers.Dense,
        activation = activation_func,
        kernel_initializer = "he_normal",
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    class BottleneckResidualUnit(tf.keras.layers.Layer):
        """1x1 -> 3x3 -> 1x1 bottleneck with a skip connection.

        The unit outputs ``filters * 4`` channels. The skip path uses a 1x1
        projection convolution plus BatchNormalization when the unit is
        strided or the input channel count differs from ``filters * 4``;
        otherwise the input is added unchanged.

        NOTE(review): ``activation`` stays at the default "relu" because the
        caller below never passes one — confirm this is intended.
        """

        def __init__(self, filters, strides=1, activation="relu", **kwargs):
            super().__init__(**kwargs)
            self.activation = tf.keras.activations.get(activation)
            self.strides = strides
            self.filters = filters
            self.filters_out = filters * 4 # Output filters for bottleneck

            def make_conv(n_filters, kernel_size, conv_strides):
                # Bias-free convolution without activation; each call
                # creates its own L2 regularizer instance.
                return tf.keras.layers.Conv2D(
                    n_filters,
                    kernel_size = kernel_size,
                    strides = conv_strides,
                    padding = "same",
                    kernel_initializer = "he_normal",
                    use_bias = False,
                    kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
                )

            # --- Main Path Layers ---
            # The stride is applied in the first 1x1 convolution.
            self.conv1 = make_conv(filters, 1, strides)
            self.bn1 = tf.keras.layers.BatchNormalization()

            self.conv2 = make_conv(filters, 3, 1)
            self.bn2 = tf.keras.layers.BatchNormalization()

            self.conv3 = make_conv(self.filters_out, 1, 1)
            self.bn3 = tf.keras.layers.BatchNormalization()

            # --- Shortcut Path Layers (initialized but may not be used) ---
            # They are defined here; "build" decides whether "call" uses them.
            self.skip_conv = make_conv(self.filters_out, 1, strides)
            self.skip_bn = tf.keras.layers.BatchNormalization()
            self.needs_projection = False

        def build(self, input_shape):
            # A projection is required when the spatial size or the channel
            # count of the input differs from the main path's output.
            if self.strides > 1 or input_shape[-1] != self.filters * 4:
                self.needs_projection = True
            super().build(input_shape)

        def call(self, inputs):
            # Main path
            Z = self.conv1(inputs)
            Z = self.bn1(Z)
            Z = self.activation(Z)

            Z = self.conv2(Z)
            Z = self.bn2(Z)
            Z = self.activation(Z)

            Z = self.conv3(Z)
            Z = self.bn3(Z)

            # Shortcut Path
            if self.needs_projection:
                skip_Z = self.skip_conv(inputs)
                skip_Z = self.skip_bn(skip_Z)
            else:
                skip_Z = inputs

            output = self.activation(Z + skip_Z)
            return output

    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)

    # Define inputs (only those selected below become part of the model)
    image_input = tf.keras.layers.Input(shape=input_shape, name = "image_input")
    sex_input = tf.keras.layers.Input(shape=(1,), name = "sex_input")
    age_input = tf.keras.layers.Input(shape=(1,), name = "age_input")
    layer_input = tf.keras.layers.Input(shape=(1,), name = "layer_input")

    # Choose data augmentation pipeline
    augment_layer = normal_data_augmentation

    # --- Model Architecture ---
    x = augment_layer(image_input)
    x = tf.keras.layers.BatchNormalization(name = "b_norm_1")(x)
    x = DefaultConv2D(filters = 64, kernel_size = 7, strides = 2, name = "conv_1")(x)
    x = tf.keras.layers.BatchNormalization(name = "b_norm_2")(x)
    x = tf.keras.layers.Activation(activation_func)(x)
    x = tf.keras.layers.MaxPool2D(pool_size = 3, strides = 2, padding = "same", name = "pool_1")(x)

    # (base filters, number of units, stride of the first unit) per stage
    block_config = [
        (64, 3, 1),
        (128, 8, 2),
        (256, 36, 2),
        (512, 3, 2)
    ]

    for filters, blocks, stride in block_config:
        for block in range(blocks):
            # Only the first unit of a stage uses the stage stride.
            unit_stride = stride if block == 0 else 1
            x = BottleneckResidualUnit(filters, strides=unit_stride)(x)

    x = tf.keras.layers.GlobalAveragePooling2D(name = "gap")(x)
    resnet_image_features = tf.keras.layers.Flatten(name = "flatten")(x)

    # --- Feature Concatenation ---
    # The extra inputs are selected once, from 'clinical_data', 'use_layer'
    # and 'dataset_type', and the same list drives both the concatenation and
    # the model's input list. Non-NORMAL datasets get an image-only model;
    # previously their flags still fed the concatenation, so the model was
    # built with tensors missing from its inputs (disconnected graph).
    extra_inputs = []
    if dataset_type == Dataset.NORMAL:
        if clinical_data:
            extra_inputs.extend([sex_input, age_input])
        if use_layer:
            extra_inputs.append(layer_input)

    if extra_inputs:
        concatenated_features = tf.keras.layers.Concatenate(name = "concat_features")(
            [resnet_image_features, *extra_inputs]
        )
    else:
        concatenated_features = resnet_image_features # No concatenation needed

    # --- Dense Layers ---
    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_1")(concatenated_features)
    x = DefaultDenseLayer(units = 512, name = "dense_1")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_1")(x)

    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_2")(x)
    x = DefaultDenseLayer(units = 256, name = "dense_2")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_2")(x)

    # --- Output Layer ---
    if num_classes == 2:
        # Binary Classification: one unit followed by a float32 sigmoid
        x = tf.keras.layers.Dense(1, name = f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('sigmoid', dtype='float32', name='predictions')(x)
        loss = "binary_crossentropy"
        metrics = ["accuracy",
                   tf.keras.metrics.AUC(name = "auc"),
                   tf.keras.metrics.Precision(name = "precision", thresholds = 0.5),
                   tf.keras.metrics.Recall(name = "recall", thresholds = 0.5),
                   tf.keras.metrics.F1Score(name = "f1_score", threshold = 0.5, average="micro")]
    elif 2 < num_classes <= 6:
        # Multi-class: one unit per class followed by a float32 softmax
        x = tf.keras.layers.Dense(num_classes, name = f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('softmax', dtype='float32', name='predictions')(x)
        loss = "sparse_categorical_crossentropy"
        metrics = ["accuracy"]
    else:
        raise ValueError("num_classes must have a value between 2 and 6")

    # --- Create and compile model ---
    model = tf.keras.Model(inputs = [image_input, *extra_inputs], outputs = [output])

    model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
    )

    model.summary()

    return model

In [6]:
# Build the ResNet152 model; the builder prints the model summary itself.
resnet152_model = build_resnet152_model()

## ResNeXt50

In [8]:
def build_resnext_model(architecture = "ResNeXt50"):
    """Build, compile and print the summary of a ResNeXt-style model.

    Structure: augmentation -> stem (BatchNorm, 7x7 stride-2 conv, BatchNorm,
    activation, max pool) -> four stages of grouped-convolution blocks
    (cardinality 32, output filters 256 doubling per stage) -> global average
    pooling -> two BatchNorm/Dense/Dropout blocks (512 and 256 units) ->
    output head.

    Configuration is read from the module-level variables
    ``activation_func``, ``l2_regularization``, ``learning_rate``,
    ``input_shape``, ``dropout_rate``, ``clinical_data``, ``use_layer``,
    ``num_classes`` and ``dataset_type``.

    Args:
        architecture: "ResNeXt50" (blocks per stage 3, 4, 6, 3) or
            "ResNeXt101" (blocks per stage 3, 4, 23, 3).

    Returns:
        The compiled ``tf.keras.Model``. ``model.summary()`` is printed as a
        side effect.

    Raises:
        ValueError: If ``architecture`` is not recognized, or if
            ``num_classes`` is not between 2 and 6.
    """
    architectures = {
        "ResNeXt50": [3, 4, 6, 3],
        "ResNeXt101": [3, 4, 23, 3],
    }

    if architecture not in architectures:
        raise ValueError(f"Architecture {architecture} not recognized. Available architectures: {list(architectures.keys())}")

    repetitions = architectures[architecture]

    # Unlike the dense partial, convolutions carry no activation here; the
    # activation is applied separately after BatchNormalization.
    DefaultConv2D = partial(
        tf.keras.layers.Conv2D,
        kernel_size = 3,
        strides = 1,
        padding = "same",
        activation = None,
        kernel_initializer = "he_normal",
        use_bias = False,
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    DefaultDenseLayer = partial(
        tf.keras.layers.Dense,
        activation = activation_func,
        kernel_initializer = "he_normal",
        kernel_regularizer = tf.keras.regularizers.l2(l2_regularization)
    )

    class ResNeXtBlock(tf.keras.layers.Layer):
        """1x1 -> grouped 3x3 -> 1x1 block with a skip connection.

        The two inner convolutions use ``filters // 2`` channels and the
        3x3 convolution is split into ``cardinality`` groups. The skip path
        is a 1x1 projection plus BatchNormalization when the block is
        strided or ``filters`` differs from ``input_filters``.

        NOTE(review): ``activation`` stays at the default "relu" because the
        caller below never passes one — confirm this is intended.
        """

        def __init__(self, filters, cardinality, strides=1, input_filters = None, activation="relu", **kwargs):
            super().__init__(**kwargs)
            self.activation = tf.keras.activations.get(activation)
            self.main_layers = [
                DefaultConv2D(filters // 2, kernel_size=1, strides=1),
                tf.keras.layers.BatchNormalization(),
                self.activation,
                DefaultConv2D(filters // 2, kernel_size=3, strides=strides, groups=cardinality),
                tf.keras.layers.BatchNormalization(),
                self.activation,
                DefaultConv2D(filters, kernel_size=1, strides=1),
                tf.keras.layers.BatchNormalization()
            ]
            self.skip_layers = []
            if strides > 1 or filters != input_filters:
                self.skip_layers = [
                    DefaultConv2D(filters, kernel_size=1, strides=strides),
                    tf.keras.layers.BatchNormalization()
                ]

        def call(self, inputs):
            Z = inputs
            for layer in self.main_layers:
                Z = layer(Z)
            skip_Z = inputs
            for layer in self.skip_layers:
                skip_Z = layer(skip_Z)
            return self.activation(Z + skip_Z)

    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)

    # Define inputs (only those selected below become part of the model)
    image_input = tf.keras.layers.Input(shape=input_shape, name = "image_input")
    sex_input = tf.keras.layers.Input(shape=(1,), name = "sex_input")
    age_input = tf.keras.layers.Input(shape=(1,), name = "age_input")
    layer_input = tf.keras.layers.Input(shape=(1,), name = "layer_input")

    # Choose Data Augmentation pipeline
    augment_layer = normal_data_augmentation

    # --- Model Architecture ---
    x = augment_layer(image_input) # Apply augmentation first

    x = tf.keras.layers.BatchNormalization(name = "b_norm_1")(x)
    x = DefaultConv2D(filters=64, kernel_size=7, strides=2, name = "conv_1")(x)
    x = tf.keras.layers.BatchNormalization(name = "b_norm_2")(x)
    x = tf.keras.layers.Activation(activation_func)(x)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="same", name = "pool_1")(x)

    cardinality = 32
    filters = 256
    input_filters = x.shape[-1]
    for i, reps in enumerate(repetitions):
        for j in range(reps):
            # Every stage except the first downsamples in its first block.
            strides = 2 if i > 0 and j == 0 else 1
            x = ResNeXtBlock(
                filters,
                cardinality,
                strides=strides,
                input_filters=input_filters,
                name = f"resnext_stage_{i}_block_{j}"
            )(x)
            input_filters = x.shape[-1]
        filters *= 2  # output filters double from one stage to the next

    x = tf.keras.layers.GlobalAveragePooling2D(name = "gap")(x)
    resnext_image_features = tf.keras.layers.Flatten(name = "flatten")(x)

    # --- Feature Concatenation ---
    # The extra inputs are selected once, from 'clinical_data', 'use_layer'
    # and 'dataset_type', and the same list drives both the concatenation and
    # the model's input list. Non-NORMAL datasets get an image-only model;
    # previously their flags still fed the concatenation, so the model was
    # built with tensors missing from its inputs (disconnected graph).
    extra_inputs = []
    if dataset_type == Dataset.NORMAL:
        if clinical_data:
            extra_inputs.extend([sex_input, age_input])
        if use_layer:
            extra_inputs.append(layer_input)

    if extra_inputs:
        concatenated_features = tf.keras.layers.Concatenate(name = "concat_features")(
            [resnext_image_features, *extra_inputs]
        )
    else:
        concatenated_features = resnext_image_features # No concatenation needed

    # --- Dense Layers ---
    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_1")(concatenated_features)
    x = DefaultDenseLayer(units=512, name = "dense_1")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_1")(x)

    x = tf.keras.layers.BatchNormalization(name = "b_norm_dense_2")(x)
    x = DefaultDenseLayer(units=256, name = "dense_2")(x)
    x = tf.keras.layers.Dropout(dropout_rate, name = "dropout_2")(x)

    # --- Output Layer ---
    if num_classes == 2:
        # Binary Classification: one unit followed by a float32 sigmoid
        x = tf.keras.layers.Dense(1, name = f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('sigmoid', dtype='float32', name='predictions')(x)
        loss = "binary_crossentropy"
        metrics = ["accuracy",
                   tf.keras.metrics.AUC(name = "auc"),
                   tf.keras.metrics.Precision(name = "precision", thresholds = 0.5),
                   tf.keras.metrics.Recall(name = "recall", thresholds = 0.5),
                   tf.keras.metrics.F1Score(name = "f1_score", threshold = 0.5, average="micro")]
    elif 2 < num_classes <= 6:
        # Multi-class: one unit per class followed by a float32 softmax
        x = tf.keras.layers.Dense(num_classes, name = f"dense_output_{num_classes}cls")(x)
        output = tf.keras.layers.Activation('softmax', dtype='float32', name='predictions')(x)
        loss = "sparse_categorical_crossentropy"
        metrics = ["accuracy"]
    else:
        raise ValueError("num_classes must have a value between 2 and 6")

    # --- Create and compile model ---
    model = tf.keras.Model(inputs = [image_input, *extra_inputs], outputs = [output])

    model.compile(
        loss = loss,
        optimizer = optimizer,
        metrics = metrics
    )

    model.summary()

    return model

In [9]:
# Build the ResNeXt50 variant; the builder prints the model summary itself.
resnext50_model = build_resnext_model(architecture="ResNeXt50")

## ResNeXt101

In [None]:
# Build the ResNeXt101 variant; the builder prints the model summary itself.
resnext101_model = build_resnext_model(architecture="ResNeXt101")

## Transfer InceptionV3

## Transfer ResNet50V2