In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras import Sequential
from keras.layers import Dense, BatchNormalization, Dropout, Flatten, Resizing, Rescaling, RandomFlip, RandomRotation, RandomBrightness, RandomContrast, RandomZoom
from keras.utils import image_dataset_from_directory
from keras import Sequential, Input
from keras.optimizers import Adam
from keras.applications import EfficientNetV2L
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from keras import callbacks
from torchvision import transforms
from keras.preprocessing import image
from PIL import Image
from keras.losses import SparseCategoricalCrossentropy as scc
import cv2
from keras_cv.layers import Grayscale
# Bare last expression of the cell: the notebook renders this string,
# showing which TensorFlow build the kernel is running.
f'Tensorflow version: {tf.__version__}'

In [None]:
# Input resolution (square), mini-batch size and the seed shared by both splits.
imagesize = 256
batchsize = 32
rand_seed = 10

# Both subsets are read from the same directory with identical options; only
# `subset` differs. Passing the same seed and validation_split to both calls
# is what keeps the 80/20 training/validation split consistent.
ct_scan_dir = '/kaggle/input/medical-scan-classification-dataset/Covid/Covid/CT Scan'
loader_options = dict(
    labels='inferred',
    color_mode='rgb',
    image_size=(imagesize, imagesize),
    batch_size=batchsize,
    shuffle=True,
    seed=rand_seed,
    validation_split=0.2,
)

train_ds = image_dataset_from_directory(ct_scan_dir, subset='training', **loader_options)

valid_ds = image_dataset_from_directory(ct_scan_dir, subset='validation', **loader_options)

In [None]:
#x = valid_ds.take(1)

In [None]:
#for i, l in x:
#    print(i.shape)
#    break

In [None]:
#for img, label in valid_ds:
#    img = np.repeat(img, 3, 2)

In [None]:
# Class labels as exposed by the training dataset (loaded with
# labels='inferred', so presumably one entry per class sub-directory).
class_names = train_ds.class_names

In [None]:
#AUTOTUNE = tf.data.AUTOTUNE

#train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
#valid_ds = valid_ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
# ImageNet-pretrained EfficientNetV2L backbone without its classification
# head, sized for `imagesize` x `imagesize` 3-channel inputs.
base_model = EfficientNetV2L(
    weights='imagenet',
    include_top=False,
    input_shape=(imagesize, imagesize, 3),
)
# Start fully frozen; the following cell selectively unfreezes the tail.
base_model.trainable = False

In [None]:
# Partial fine-tuning: mark the backbone trainable, then walk its layers in
# order and freeze everything that comes before 'block7a_expand_conv'. That
# layer and every layer after it remain trainable.
# NOTE(review): this also unfreezes any BatchNormalization layers in the tail;
# Keras' transfer-learning guidance suggests keeping those frozen — consider.
base_model.trainable = True

set_trainable = False
for layer in base_model.layers:
    # Latch: becomes True at the marker layer and never flips back.
    set_trainable = set_trainable or layer.name == 'block7a_expand_conv'
    layer.trainable = set_trainable

#for layer in base_model.layers:
#  print(layer.name,layer.trainable)

In [None]:
# Stand-alone augmentation pipeline: random flips, rotation, brightness,
# contrast, a 1/255 rescale and random zoom, applied in that order.
# NOTE(review): nothing else in the visible notebook references
# `data_augmentation`; the model cell declares its own augmentation layers.
# Confirm whether this cell is still needed.
data_augmentation = Sequential([
    RandomFlip('horizontal_and_vertical'),
    RandomRotation(0.2),
    RandomBrightness(0.2),
    RandomContrast(0.2),
    Rescaling(1./255),
    RandomZoom(height_factor=(-0.1, 0.1), width_factor = (-0.1, 0.1))
])

In [None]:
# torchvision transform configured for 3 output channels.
# NOTE(review): `grayscale` is not referenced elsewhere in the visible
# notebook; the model uses the keras_cv `Grayscale` layer instead — confirm
# whether this cell (and the torchvision dependency) is still needed.
grayscale = transforms.Grayscale(num_output_channels = 3)

In [None]:
# Base learning rate: handed to Adam at compile time and also used by
# lr_scheduler as the amount added per epoch during warm-up.
lr_init = 0.00015

In [None]:
# NOTE(review): `lr_const` is not referenced anywhere else in the visible
# notebook — confirm whether it is still needed.
lr_const = 0.00015

In [None]:
# Full classifier: augmentation -> grayscale (replicated to 3 channels) ->
# EfficientNetV2L backbone -> dense head with batch normalisation.
model = Sequential([
    # Augmentation layers placed inside the model.
    RandomZoom(0.1,0.1),
    RandomContrast(0.1),
    RandomBrightness(0.1),
    RandomFlip("horizontal"),
    # keras_cv layer; output_channels=3 keeps the 3-channel shape the
    # backbone was built with.
    Grayscale(output_channels=3),
    # NOTE(review): rescaling is left disabled; Keras' EfficientNetV2
    # applications are documented to rescale inputs internally by default —
    # confirm before re-enabling.
    #Rescaling(1./255),
    base_model,
    Flatten(),
    Dense(1024, activation = 'relu'),
    BatchNormalization(),
    Dense(512, activation = 'relu'),
    BatchNormalization(),
    Dense(512, activation = 'relu'),
    BatchNormalization(),
    Dense(256, activation = 'relu'),
    BatchNormalization(),
    Dense(128, activation = 'relu'),
    BatchNormalization(),
    Dropout(0.2),
    # One output unit per class found by the data loader, instead of a
    # hard-coded 4, so the head always matches the inferred integer labels.
    Dense(len(class_names), activation = 'softmax')
])

In [None]:
#model = create_model(base_model)
#model.summary()

In [None]:
# Stop once validation accuracy has failed to improve by at least 0.001 for
# 10 epochs in a row, and roll the model back to its best weights.
early_stopping = callbacks.EarlyStopping(
    monitor='val_accuracy', min_delta=0.001, patience=10, restore_best_weights=True
)

In [None]:
def lr_scheduler(epoch, lr):
    """Per-epoch learning-rate schedule for ``LearningRateScheduler``.

    Three phases, selected by the zero-based epoch index:

    * ``epoch < 8``  - warm-up: the current rate is increased by ``lr_init``.
    * ``8 <= epoch < 25`` - hold: the current rate is returned unchanged.
    * ``epoch >= 25`` - decay: the current rate is multiplied by ``exp(-0.1)``.

    Args:
        epoch: Zero-based index of the epoch about to start.
        lr: Learning rate currently set on the optimizer.

    Returns:
        The learning rate to use for this epoch, always as a plain Python
        ``float``. The scheduler callback validates the return type, and a
        ``tf.Tensor`` (what ``tf.math.exp`` would produce) is rejected by
        newer Keras releases.
    """
    if epoch < 8:
        return float(lr + lr_init)
    if epoch < 25:
        return float(lr)
    # numpy keeps the arithmetic in host floats rather than creating a tensor.
    return float(lr * np.exp(-0.1))

In [None]:
# Wrap lr_scheduler in a Keras callback; verbose=1 asks the callback to
# report the learning rate it sets.
lr_callback = callbacks.LearningRateScheduler(lr_scheduler, verbose = 1)

In [None]:
def plot_history(history):
    """Plot training vs. validation accuracy and loss side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the per-epoch series ``'accuracy'``, ``'val_accuracy'``,
            ``'loss'`` and ``'val_loss'`` (e.g. the value returned by
            ``model.fit``).

    Raises:
        KeyError: If any of the four series is missing from
            ``history.history``.
    """
    series = history.history
    acc, val_acc = series['accuracy'], series['val_accuracy']
    loss, val_loss = series['loss'], series['val_loss']

    # Explicit axes instead of the pyplot state machine, and labelled axes so
    # each panel can be read on its own.
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 6))

    acc_ax.plot(acc, label='Training Accuracy')
    acc_ax.plot(val_acc, label='Validation Accuracy')
    acc_ax.set(title='Training and Validation Accuracy',
               xlabel='Epoch', ylabel='Accuracy')
    acc_ax.legend(loc='lower right')

    loss_ax.plot(loss, label='Training Loss')
    loss_ax.plot(val_loss, label='Validation Loss')
    loss_ax.set(title='Training and Validation Loss',
                xlabel='Epoch', ylabel='Loss')
    loss_ax.legend(loc='upper right')

    plt.show()

In [None]:
# Adam starting at lr_init; sparse categorical cross-entropy takes the
# integer class labels directly, and accuracy is tracked as the metric.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=Adam(lr_init),
    metrics=['accuracy'],
)

In [None]:
# Train for at most 100 epochs; early stopping may end the run sooner and the
# scheduler callback sets the learning rate at each epoch.
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=100,
    callbacks=[early_stopping, lr_callback],
)

In [None]:
# Persist the trained model to the working directory (HDF5 file, per the
# .h5 extension).
model.save('covid_ct_scan.h5')

In [None]:
# Visualise the accuracy/loss curves recorded during the fit above.
plot_history(history)