Libraries

In [2]:
import tensorflow as tf
from   tensorflow.keras.models import Model
from   tensorflow import keras
from   tensorflow.keras import layers, regularizers
from   tensorflow.keras.applications import EfficientNetB3
from   tensorflow.keras.layers import Conv2D, UpSampling2D, Concatenate
from   keras.layers import Lambda
import pathlib
from   tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

#!pip install tabulate

Parameters and Hyperparameters settings

In [3]:
# Load and preprocess the data
batch_size = 4
img_height = 300#256
img_width = 300#256
IMG_SIZE = img_height
EPOCHS = 6
base_model = EfficientNetB3(weights = 'imagenet', include_top = False, input_shape = (img_height , img_width, 3))

Datasets' addresses

In [4]:
MDD_training_dir = "C:/Users/39351/Documents/Most_Killer_Cancers_Train_Test_Sets_100Tiles/train/"
MDD_val_dir =      "C:/Users/39351/Documents/Most_Killer_Cancers_Train_Test_Sets_100Tiles/val/"
#MDD_training_dir = MDD_val_dir
MDD_test_dir =     "C:/Users/39351/Documents/Most_Killer_Cancers_Train_Test_Sets_100Tiles/test/"

Setting the Keras Dataloaders

In [5]:
MDD_training_dir = pathlib.Path(MDD_training_dir)
MDD_val_dir = pathlib.Path(MDD_val_dir)
#label_mode ='binary'
label_mode ='categorical'

train_ds = tf.keras.utils.image_dataset_from_directory(
  MDD_training_dir,
  labels = 'inferred',
  label_mode =label_mode,
  #validation_split = 0.2,
  #subset = 'training',
  shuffle = True,
  seed = 42,
  image_size = (img_height, img_width),
  batch_size = batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
  MDD_val_dir,
  labels = 'inferred',
  label_mode =label_mode,
  #validation_split = 0.2,
  #subset = 'validation',
  shuffle = True,
  seed = 42,
  image_size = (img_height, img_width),
  batch_size = batch_size)

test_ds = tf.keras.utils.image_dataset_from_directory(
  MDD_test_dir,
  labels = 'inferred',
  label_mode =label_mode,
  #validation_split = 0.2,
  #subset = 'validation',
  shuffle = False,
  image_size = (img_height, img_width),
  batch_size = batch_size)

# Extract class names
class_names = val_ds.class_names

# Calculate the number of classes
num_classes = len(class_names)

print("Class Names:", class_names)
print("Number of Classes:", num_classes)

Found 86900 files belonging to 11 classes.
Found 13900 files belonging to 11 classes.
Found 27400 files belonging to 11 classes.
Class Names: ['Astrocytoma', 'Breast_Carcinoma', 'Colon_Adenocarcinoma', 'Cutaneous_Melanoma', 'Gastric_Adenocarcinoma', 'Glioblastoma', 'Hepatocarcinoma', 'Non_Tumor', 'Nsclc_Adenocarcinoma', 'Nsclc_Squamous_Cell_Carcinoma', 'Oligodendroglioma']
Number of Classes: 11


In [6]:
class TransformerBlock(layers.Layer):
    def __init__(self, num_heads, embed_dim, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.rate = rate

        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training):
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)
        
    def get_config(self):
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'embed_dim': self.embed_dim,
            'ff_dim': self.ff_dim,
            'rate': self.rate
        })
        return config

In [7]:
# Define pyramid module layersfrom keras.layers import Lambda
c5 = base_model.get_layer('block7a_project_bn').output
p4 = Conv2D(filters=64, kernel_size=(2, 2), activation='relu')(c5)
p4 = UpSampling2D(size=(2, 2), interpolation='bilinear')(p4)

c4 = base_model.get_layer('block4a_expand_activation').output
c4 = Conv2D(filters=64, kernel_size=(1, 1), activation='relu')(c4)
        
c4 = Lambda(lambda x: tf.image.resize(x, (18, 18)))(c4)  # Resize spatial dimensions
p4 = Concatenate()([p4, c4])
        
p3 = Conv2D(filters=64, kernel_size=(2, 2), dilation_rate=(2, 2), activation='relu')(p4)
p3 = UpSampling2D(size=(2, 2), interpolation='bilinear')(p3)
c3 = base_model.get_layer('block3a_expand_activation').output
c3 = Conv2D(filters=64, kernel_size=(1, 1), activation='relu')(c3)
        
c3 = Lambda(lambda x: tf.image.resize(x, (32, 32)))(c3)  # Resize spatial dimensions
p3 = Concatenate()([p3, c3])

# Create the new model with the pyramid module connected to CNN
net = tf.keras.Model(inputs=base_model.input, outputs=p3)

In [8]:
'''for layer in net.layers:
    layer.trainable = False'''
# Add a layer to convert the input to the HSV color space
x = TransformerBlock(num_heads = 4, embed_dim = 128, ff_dim = 96, rate=0.3)(net.output)

#x = keras.layers.Permute((3, 1, 2))(x)#
x = keras.layers.Reshape((int(x.shape[1]), int(x.shape[2]), int(x.shape[3])))(x)
x_c1 = keras.layers.Conv2D(64, (3, 3), padding='same', dilation_rate=(2, 2), activation='relu')(x)
x_c01 = TransformerBlock(num_heads = 4, embed_dim = 64, ff_dim = 64, rate=0.3)(x_c1)
x_c001 = keras.layers.Reshape((int(x_c01.shape[1]), int(x_c01.shape[2]), int(x_c01.shape[3])))(x_c01)
x_c11 = keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x_c001)
x_c11 = keras.layers.BatchNormalization()(x_c11)
x_c2 = keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(x_c11)
combined1 = keras.layers.concatenate([x_c1, x_c2])
x = keras.layers.MaxPooling2D()(combined1)
x = keras.layers.Flatten()(x)

x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.Dense(64, activation='relu', name = "features_vector")(x)

if num_classes == 2:
    x = keras.layers.Dense(1, activation='sigmoid',name = "pred")(x)
else:
    x = keras.layers.Dense(num_classes, activation='softmax')(x)

model = Model(inputs=net.input, outputs=x)
print(model)

'''from keras.utils.vis_utils import plot_model
plot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)'''

<keras.engine.functional.Functional object at 0x000001C7D656D070>


"from keras.utils.vis_utils import plot_model\nplot_model(model, to_file='model_plot.png', show_shapes=True, show_layer_names=True)"

In [9]:
def scheduler(epoch, lr):
   if epoch < 3:
     return lr
   else:
     return lr * tf.math.exp(-0.1)

# Define callbacks options
#early_stopping = EarlyStopping(monitor = 'val_accuracy', patience = 10, mode = 'max')
#early_stopping = EarlyStopping(monitor='val_loss', mode='min', baseline=0.0040)
reduce_lr = ReduceLROnPlateau(monitor = 'val_accuracy', factor = 0.1, patience = 1, verbose = 1, min_delta = 1e-4, mode = 'max')
best_stn_model = ModelCheckpoint('best_model/best_model.hdf5', save_best_only = True, monitor = 'val_accuracy', mode = 'max')
lr_rate_sched = tf.keras.callbacks.LearningRateScheduler(scheduler)

TN = keras.metrics.TrueNegatives()
TP = keras.metrics.TruePositives()
FN = keras.metrics.FalseNegatives()
FP = keras.metrics.FalsePositives()
SEN = keras.metrics.Recall()
AUC = keras.metrics.AUC()
SPE = tf.keras.metrics.SpecificityAtSensitivity(0.5)
SENS= tf.keras.metrics.SensitivityAtSpecificity(0.5)

opt = tf.keras.optimizers.SGD(learning_rate = 0.001)
if num_classes == 2:
  loss_func = 'binary_crossentropy'
else:
    loss_func = 'categorical_crossentropy'

model.compile(loss = loss_func, optimizer =  opt , metrics = ['accuracy', TN, TP, FN, FP, SEN, AUC, SPE, SENS])
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 300, 300, 3) 0                                            
__________________________________________________________________________________________________
rescaling (Rescaling)           (None, 300, 300, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
normalization (Normalization)   (None, 300, 300, 3)  7           rescaling[0][0]                  
__________________________________________________________________________________________________
stem_conv_pad (ZeroPadding2D)   (None, 301, 301, 3)  0           normalization[0][0]              
____________________________________________________________________________________________

In [None]:
print("MAViT training has being strated")
with tf.device('/device:GPU:0'):
   print("tf.keras code in this scope will run on GPU")
   hist = model.fit(train_ds,
              validation_data = val_ds,
              epochs = EPOCHS,
              verbose = 1,
              callbacks = [lr_rate_sched])
print("Model_Average accuracy on the best model: ", max(hist.history["val_accuracy"]))

In [None]:
model.save('best_model/MAViT.hdf5')

In [None]:
def Evaluation(tp, tn, fp, fn):
    accuracy = (tp+ tn) / (tp+tn+fp+fn)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    F1_score = (2 * tp) / ((2 * tp) + fp + fn)
    return accuracy, sensitivity, specificity, F1_score

print("Validation Set performacne: ")
stn_loss, stn_acc, stn_tn, stn_tp, stn_fn, stn_fp, stn_senv, stn_auc, stn_spec, stn_sensAt = model.evaluate(val_ds)
SAT_accuracy, SAT_sensitivity, SAT_specificity, SAT_F1_score = Evaluation(stn_tp, stn_tn, stn_fp, stn_fn)
print("Accuracy: {:.2f}%".format(SAT_accuracy * 100))
print("Sensitivity: {:.2f}%".format(SAT_sensitivity * 100))
print("Specificity: {:.2f}%".format(SAT_specificity * 100))
print("F1-Score: {:.2f}%".format(SAT_F1_score * 100))
print("AUC: {:.2f}%".format(stn_auc * 100))

print("*********\nTest Set performacne: ")
stn_loss, stn_acc, stn_tn, stn_tp, stn_fn, stn_fp, stn_senv, stn_auc, stn_spec, stn_sensAt = model.evaluate(test_ds)
SAT_accuracy, SAT_sensitivity, SAT_specificity, SAT_F1_score = Evaluation(stn_tp, stn_tn, stn_fp, stn_fn)
print("Accuracy: {:.2f}%".format(SAT_accuracy * 100))
print("Sensitivity: {:.2f}%".format(SAT_sensitivity * 100))
print("Specificity: {:.2f}%".format(SAT_specificity * 100))
print("MF1-Score: {:.2f}%".format(SAT_F1_score * 100))
print("AUC: {:.2f}%".format(stn_auc * 100))