In [None]:
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Input, Dropout, Flatten, BatchNormalization, Dense, Activation
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint


In [None]:
from google.colab import drive
drive.mount('/content/drive2')

# Dataset path and class names
dataset_path = '/content/drive2/MyDrive/MedicalWasteSplit'
train_directory = os.path.join(dataset_path, 'train')
# flow_from_directory assigns label indices from the alphanumerically sorted
# sub-folder names, whereas os.listdir returns entries in arbitrary order.
# Sorting here makes class_names[i] the class the generators encode as index i.
class_names = sorted(
    item for item in os.listdir(train_directory)
    if os.path.isdir(os.path.join(train_directory, item))
)
print("Class names:", class_names)

# Image dimensions (height, width) every image is resized to before the network
img_height, img_width = 224, 224

Mounted at /content/drive2
Class names: ['glove_single_latex', 'glove_pair_surgery', 'glove_single_nitrile', 'glove_single_surgery', 'test_tube', 'shoe_cover_single', 'medical_glasses', 'shoe_cover_pair', 'urine_bag', 'medical_cap', 'glove_pair_nitrile', 'gauze', 'glove_pair_latex']


In [None]:
# Training-time augmentation: scale pixels by 1/255 (i.e. to [0, 1] for 8-bit
# images) and randomly transform each image on the fly so every epoch sees
# slightly different variants of the training set.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    # random geometric transforms
    rotation_range=40,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    # how pixels uncovered by a transform are filled in
    fill_mode='nearest',
)

In [None]:
# Validation and test images get only the same 1/255 scaling as training data,
# with no augmentation, so evaluation sees the images as they are.
valid_datagen, test_datagen = ImageDataGenerator(rescale=1 / 255), ImageDataGenerator(rescale=1 / 255)

In [None]:
# Intentionally left without code: `train_datagen` (training-time augmentation)
# is defined once, in the augmentation cell above. Re-creating it here with
# identical settings was redundant, so this cell can be deleted.

In [None]:
# Batch size
batch_size = 32
# Number of output classes, derived from the class folders found under the
# training directory instead of hard-coding 13, so the model head cannot drift
# out of sync with the data generators.
num_classes = len(class_names)



In [None]:
# Create one directory iterator per dataset split.
def _make_split_generator(datagen, split, shuffle=True):
    """Return a categorical `flow_from_directory` iterator for one split.

    Args:
        datagen: ImageDataGenerator holding the preprocessing for this split.
        split: sub-folder of `dataset_path` ('train', 'validation' or 'test').
        shuffle: whether to shuffle the sample order (Keras default is True).
    """
    return datagen.flow_from_directory(
        os.path.join(dataset_path, split),
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=shuffle,
    )


train_generator = _make_split_generator(train_datagen, 'train')
# Evaluation splits are read in a fixed order (shuffle=False) so that rows of
# model.predict(...) line up with `<generator>.classes` / `.filenames`, e.g. for
# a confusion matrix; whole-split metrics are unaffected by the ordering.
validation_generator = _make_split_generator(valid_datagen, 'validation', shuffle=False)
test_generator = _make_split_generator(test_datagen, 'test', shuffle=False)

Found 2967 images belonging to 13 classes.
Found 629 images belonging to 13 classes.
Found 649 images belonging to 13 classes.


In [None]:
# VGG16 convolutional base with ImageNet weights. include_top=False drops the
# original ImageNet classifier so a custom head can be stacked on top.
# The input size follows the configured img_height/img_width rather than a
# repeated literal 224, so both stay in sync with the generators.
base_model = VGG16(input_shape=(img_height, img_width, 3), include_top=False, weights="imagenet")

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# Partial fine-tuning: every VGG16 layer except the last four is frozen; those
# four keep their default trainable=True and are updated during training.
frozen_layers = base_model.layers[:-4]
for layer in frozen_layers:
    layer.trainable = False

In [None]:
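# Alternative (not used in this run): freeze the *entire* base model for pure
# feature extraction, instead of fine-tuning its last layers as done above.
#for layer in base_model.layers:
#    layer.trainable = False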

In [None]:
# Define the custom head of the model on top of the partially frozen VGG16 base:
# Dropout -> Flatten -> BN -> 2 x [Dense(1024) -> BN -> ReLU -> Dropout] -> softmax.
model = Sequential()
model.add(base_model)
model.add(Dropout(0.2))
model.add(Flatten())            # (7, 7, 512) feature map -> 25088-dim vector
model.add(BatchNormalization())
model.add(Dense(1024, kernel_initializer='glorot_uniform'))
model.add(BatchNormalization())
model.add(Activation('relu'))   # BN is applied before the non-linearity
model.add(Dropout(0.2))
model.add(Dense(1024, kernel_initializer='glorot_uniform'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.2))
# One softmax unit per class; the width comes from num_classes instead of a
# hard-coded 13 so it always matches the number of class folders.
model.add(Dense(num_classes, activation='softmax'))

In [None]:
# Compile the model.
# The callbacks monitor 'val_auc', which only appears in the training logs if
# an AUC metric named 'auc' is compiled in; with metrics=['accuracy'] alone,
# EarlyStopping/ModelCheckpoint can never find it. The explicit name='auc'
# keeps the log key stable when this cell is re-run (an unnamed metric may be
# auto-renamed, e.g. to 'auc_1').
# With one-hot labels and the default multi_label=False, AUC flattens all
# (sample, class) entries into a single binary problem before computing it.
OPT = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(loss='categorical_crossentropy',
              metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
              optimizer=OPT)


In [None]:
# Define Callbacks
# NOTE: 'val_auc' is only present in the logs when the model is compiled with
# an AUC metric named 'auc'; otherwise Keras logs a warning and both callbacks
# skip their action.
filepath = './best_weights.hdf5'
# Stop after 5 epochs without a new maximum in validation AUC. When stopping
# triggers, restore the weights of the best epoch so the in-memory model
# matches the saved checkpoint instead of the (possibly worse) last epoch.
earlystopping = EarlyStopping(monitor='val_auc', mode='max', patience=5,
                              restore_best_weights=True, verbose=1)
# Overwrite `filepath` with the model each time validation AUC improves.
checkpoint = ModelCheckpoint(filepath, monitor='val_auc', mode='max', save_best_only=True, verbose=1)
callback_list = [earlystopping, checkpoint]

In [None]:
# Display the model summary: per-layer output shapes and parameter counts,
# followed by total / trainable / non-trainable parameter totals (the frozen
# VGG16 layers contribute to the non-trainable count).
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 vgg16 (Functional)          (None, 7, 7, 512)         14714688  
                                                                 
 dropout (Dropout)           (None, 7, 7, 512)         0         
                                                                 
 flatten (Flatten)           (None, 25088)             0         
                                                                 
 batch_normalization (Batch  (None, 25088)             100352    
 Normalization)                                                  
                                                                 
 dense (Dense)               (None, 1024)              25691136  
                                                                 
 batch_normalization_1 (Bat  (None, 1024)              4096      
 chNormalization)                                       

In [None]:
# Model fitting: up to 15 epochs on the augmented training stream, evaluating
# on the validation split after every epoch; callback_list handles early
# stopping and best-model checkpointing.
model_history = model.fit(
    x=train_generator,
    epochs=15,
    validation_data=validation_generator,
    verbose=1,
    callbacks=callback_list,
)

Epoch 1/15



Epoch 2/15



Epoch 3/15



Epoch 4/15

In [None]:
# Model fitting (second run, up to 10 epochs).
# NOTE(review): the model is neither rebuilt nor recompiled before this call,
# so training continues from whatever weights the previous fit left behind,
# and `model_history` is overwritten with this run's history only.
model_history = model.fit(
    x=train_generator,
    epochs=10,
    validation_data=validation_generator,
    verbose=1,
    callbacks=callback_list,
)

Epoch 1/10



Epoch 2/10



Epoch 3/10



Epoch 4/10



Epoch 5/10



Epoch 6/10



Epoch 7/10



Epoch 8/10



Epoch 9/10

In [None]:
import matplotlib.pyplot as plt

In [None]:
def plot_history(history):
    # Plot training & validation accuracy values
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend(['Train', 'Validation'], loc='upper left')