IMPORTS

In [2]:
import tensorflow as tf
import numpy as np
import os

LOAD AND PREPROCESS  DATASET

In [3]:
from tensorflow.keras.preprocessing import image_dataset_from_directory
from sklearn.model_selection import train_test_split

In [5]:
# Define the path to your dataset directory
dataset_dir = "dataset/train"

# Parameters
batch_size = 32
img_height = 256  # Resize height for MRI images
img_width = 256   # Resize width for MRI images
validation_split = 0.2  # 20% data for validation

In [6]:
# Load dataset and split into training and validation sets
train_ds = image_dataset_from_directory(
    dataset_dir,
    validation_split=validation_split,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

val_ds = image_dataset_from_directory(
    dataset_dir,
    validation_split=validation_split,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

Found 3253 files belonging to 2 classes.
Using 2603 files for training.
Found 3253 files belonging to 2 classes.
Using 650 files for validation.


In [7]:
def preprocess_image(image, label):
    # Normalize pixel values between 0 and 1
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_ds = train_ds.map(preprocess_image)
val_ds = val_ds.map(preprocess_image)

In [8]:
# Optional: Improve performance by caching and prefetching
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Check the structure of the datasets
print("Training Dataset:", train_ds)
print("Validation Dataset:", val_ds)

Training Dataset: <_PrefetchDataset element_spec=(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>
Validation Dataset: <_PrefetchDataset element_spec=(TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>


TRAINING

In [15]:
from tensorflow.keras import layers, Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.applications import DenseNet121

In [12]:
# Define Darknet-53 with initial layers
def darknet53(input_shape=(224, 224, 3)):
    inputs = tf.keras.Input(shape=input_shape)
    
    # Initial Conv Layer
    x = layers.Conv2D(32, (3, 3), strides=(1, 1), padding='same', activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    
    # Darknet-53 residual block with a convolution shortcut when increasing filters
    def darknet_residual_block(x, filters):
        # Shortcut connection
        shortcut = x

        # First Conv layer with 1x1 kernel
        x = layers.Conv2D(filters, (1, 1), padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        
        # Second Conv layer with 3x3 kernel
        x = layers.Conv2D(filters * 2, (3, 3), padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)
        
        # Match dimensions for shortcut if necessary
        if shortcut.shape[-1] != x.shape[-1]:  # Check if the number of filters changed
            shortcut = layers.Conv2D(filters * 2, (1, 1), padding='same')(shortcut)
        
        # Add the shortcut
        return layers.Add()([shortcut, x])

    # Stack of residual blocks with increasing filter sizes
    for filters in [32, 64, 128, 256]:
        x = darknet_residual_block(x, filters)
    
    # Output block for feature extraction
    darknet_output = layers.GlobalAveragePooling2D()(x)
    return Model(inputs, darknet_output, name="darknet53")

In [13]:
# Load Darknet-53 model and DenseNet model
darknet_model = darknet53()

# Load pre-trained DenseNet and extract layers up to a certain depth
densenet_base = DenseNet121(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
densenet_output = layers.GlobalAveragePooling2D()(densenet_base.output)
densenet_model = Model(densenet_base.input, densenet_output, name="densenet121")

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/densenet/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5


In [19]:
# Combine the models using Concatenate
input_layer = tf.keras.Input(shape=(224, 224, 3))

# Pass input through Darknet and DenseNet models
darknet_features = darknet_model(input_layer)
densenet_features = densenet_model(input_layer)

In [20]:
combined_features = layers.Concatenate()([darknet_features, densenet_features])

# Add a few fully connected layers after concatenation
x = layers.Dense(128, activation='relu')(combined_features)
x = layers.Dense(64,activation='relu')(x)
x = layers.Dense(32,activation='relu')(x)
x = layers.Dense(8,activation='relu')(x)
x = layers.Dense(2,activation='relu')(x)

output = layers.Dense(2, activation='softmax')(x)  # Adjust number of classes as needed

In [21]:
# Define the final model
combined_model = Model(inputs=input_layer, outputs=output, name="combined_darknet_densenet")

# Compile the model
combined_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Show model summary
combined_model.summary()

Model: "combined_darknet_densenet"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_4 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 darknet53 (Functional)      (None, 512)                  1837024   ['input_4[0][0]']             
                                                                                                  
 densenet121 (Functional)    (None, 1024)                 7037504   ['input_4[0][0]']             
                                                                                                  
 concatenate_1 (Concatenate  (None, 1536)                 0         ['darknet53[0][0]',           
 )                                                                   'dens