Written by Yashil Pudaruth.

*To avoid out-of-memory (OOM) errors, restart the kernel and clear all outputs after each run.*

In [None]:
#import libraries

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras
import os
from keras import layers
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.applications.mobilenet_v2 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator

In [None]:
# Hyperparameters: images per batch and the (square) input resolution that
# the MobileNetV2 base model is built for in this notebook.
batch_size = 32
img_height = img_width = 224

In [None]:
# Training and validation images live in separate sub-folders of Dataset2.
train_path, val_path = (os.path.join('Dataset2', split)
                        for split in ('Train', 'Validation'))

In [None]:
# Random augmentations applied to the training images only.
augmentation = dict(rotation_range=30,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    horizontal_flip=True)

# Both generators apply the MobileNetV2 preprocess_input to every image;
# only the training generator additionally augments.
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input,
                                   **augmentation)
val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)

In [None]:
#training data specified as training dataset and input pipeline setup
# NOTE: Keras takes target_size as (height, width), not (width, height).

train_ds = train_datagen.flow_from_directory(directory=train_path,
                                             target_size=(img_height, img_width),
                                             color_mode="rgb",
                                             batch_size=batch_size,
                                             class_mode="categorical",
                                             shuffle=True,
                                             seed=42)

#validation data specified as validation dataset and input pipeline setup
# shuffle=False keeps the validation order fixed between epochs.
val_ds = val_datagen.flow_from_directory(directory=val_path,
                                         target_size=(img_height, img_width),
                                         color_mode="rgb",
                                         batch_size=batch_size,
                                         class_mode="categorical",
                                         shuffle=False)

In [None]:
#instantiate base model with ImageNet weights and no classification top;
# the input shape is taken from the shared hyperparameters so it cannot drift
# from the size the data pipelines resize to.
base_model = MobileNetV2(weights='imagenet',
                         input_shape=(img_height, img_width, 3),
                         include_top=False)

#freeze layers in base model so only the new top is trained
base_model.trainable = False

In [None]:
#create new input tensor, sized from the shared hyperparameters
# (image tensors are (height, width, channels))
inputs = keras.Input(shape=(img_height, img_width, 3))

#place base model on new input tensor and set batch normalisation layers to inference mode
new_top = base_model(inputs, training=False)

#add the global average pooling, 64 neuron fully connected, 20% dropout and
#3 neuron Softmax output layers (one output unit per class)
new_top = layers.GlobalAveragePooling2D()(new_top)
new_top = layers.Dense(64, activation="relu")(new_top)
new_top = layers.Dropout(0.2)(new_top)

outputs = layers.Dense(3, activation='softmax')(new_top)

#build the new model
model = keras.Model(inputs=inputs, outputs=outputs)

In [None]:
#print a summary of the assembled model's layers
model.summary()

In [None]:
#set number of epochs and initial learning rate

INIT_LR = 1e-3
EPOCHS = 10

# Learning-rate decay expressed as an explicit schedule:
#     lr(step) = INIT_LR / (1 + (INIT_LR / EPOCHS) * step)
# This is the same formula the legacy Adam(decay=...) argument applied per
# optimizer step. Newer TF/Keras optimizers reject both the `lr` alias and the
# `decay` kwarg, so the schedule plus `learning_rate=` is the portable spelling.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=INIT_LR,
    decay_steps=1,
    decay_rate=INIT_LR / EPOCHS)

#create custom Adam optimiser with the decaying learning rate
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)


#compile the model
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Write the model to disk only when val_loss improves, so the file always
# holds the lowest-validation-loss epoch.
callback = tf.keras.callbacks.ModelCheckpoint(
    'model_mnv2_3cls.h5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

# Class index 1 counts double in the loss. Presumably that is
# 'Incorrectly_Masked' (alphabetical folder order) - confirm via
# train_ds.class_indices.
class_weights = {0: 1, 1: 2, 2: 1}

# Train and keep the History object for the learning-curve plots.
mymodel = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=[callback],
)

In [None]:
#retrieve per-epoch learning metrics recorded by model.fit
acc = mymodel.history['accuracy']
val_acc = mymodel.history['val_accuracy']

loss = mymodel.history['loss']
val_loss = mymodel.history['val_loss']

#define epochs range for x-axis of learning graphs; derived from the number of
#recorded epochs (not EPOCHS) so it still matches if training stops early,
#e.g. when an EarlyStopping callback is added
epochs_range = range(1, len(acc) + 1)

In [None]:
#plot learning graphs vs epochs

plt.figure(figsize=(6, 6))

plt.plot(epochs_range, acc, label="train_acc")
plt.plot(epochs_range, val_acc, label="val_acc")

plt.plot(epochs_range, loss, label="train_loss")
plt.plot(epochs_range, val_loss, label="val_loss")

plt.title("Training and Validation for MobileNetV2")
# Only pin the lower bound: cross-entropy loss is not bounded by 1 (a uniform
# 3-class prediction already scores ln 3 ~= 1.1), so a fixed [0, 1] range
# would silently clip the loss curves.
plt.ylim(bottom=0.0)
plt.xticks(epochs_range)
plt.xlabel("Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="right")
plt.grid()
plt.show()

In [None]:
#classify function for performing classification on single images

def classify(img_path):
    """Classify a single image file with the global ``model``.

    The image is resized to the training resolution, passed through the same
    MobileNetV2 ``preprocess_input`` used for training, and predicted as a
    batch of one. The raw probability vector and a human-readable verdict
    are printed.

    Args:
        img_path: path to an image file that Keras' ``load_img`` can read.

    Returns:
        ``(class_name, confidence)``, where ``confidence`` is the softmax
        probability (0-1) of the predicted class. (Previously returned None;
        callers that ignore the result are unaffected.)
    """
    class_names = ['Correctly_Masked', 'Incorrectly_Masked', 'Not_Masked']
    # Keras takes target_size as (height, width).
    img = keras.preprocessing.image.load_img(img_path, target_size=(img_height, img_width))
    img_array = keras.preprocessing.image.img_to_array(img)
    img_array = preprocess_input(img_array)
    # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
    img_array = tf.expand_dims(img_array, 0)
    predictions = model.predict(img_array)
    print(predictions)
    score = predictions[0]
    print(score)
    best = int(np.argmax(score))
    confidence = float(score[best])
    print("Class {} with {:.2f}% confidence".format(class_names[best], 100 * confidence))
    return class_names[best], confidence

In [None]:
#test classification on single images

# img_path = "C:\\Users\Public\Dataset\Test\\85.png"
# classify(img_path)

In [None]:
#save last epoch model trained

#model.save("model_mnv2_3cls_full.h5")

In [None]:
from keras.models import load_model

In [None]:
# Reload the checkpoint that ModelCheckpoint kept during training (the epoch
# with the lowest validation loss), replacing the last-epoch model in memory.
best_model_path = "model_mnv2_3cls.h5"
model = load_model(best_model_path)

In [None]:
#read and preprocess test datasets

from sklearn.preprocessing import label_binarize

class_names = ['Correctly_Masked','Incorrectly_Masked','Not_Masked']
# Mirrors the extension white-list of Keras' DirectoryIterator (which loads
# Train/Validation), so stray non-image files such as Thumbs.db or .DS_Store
# are skipped instead of crashing load_img.
image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.ppm', '.tif', '.tiff')
test_path = os.path.join('Dataset2', 'Test')
x_test = []
y_test = []
# One sub-folder per class; the integer label is the class's index in class_names.
for label, c in enumerate(class_names):
    path = os.path.join(test_path, c)
    # os.listdir order is arbitrary; sorting makes the sample order reproducible.
    for file in sorted(os.listdir(path)):
        if not file.lower().endswith(image_extensions):
            continue
        img_path = os.path.join(path, file)
        # Same resize (target_size is (height, width)) and preprocess_input as
        # the training/validation generators.
        img = keras.preprocessing.image.load_img(img_path, target_size=(img_height, img_width))
        img_array = keras.preprocessing.image.img_to_array(img)
        x_test.append(preprocess_input(img_array))
        y_test.append(label)

In [None]:
# Stack the per-image arrays into a single batch array and the integer
# labels into a 1-D label vector.
x_test, y_test = np.array(x_test), np.array(y_test)

In [None]:
# One-hot encode the integer labels: one indicator column per class index.
class_ids = [0, 1, 2]
y_test2 = label_binarize(y_test, classes=class_ids)

# The number of classes is the number of one-hot columns.
n_classes = y_test2.shape[1]

In [None]:
#generate classification report and confusion matrix based on predictions

from sklearn.metrics import classification_report, confusion_matrix

preds = model.predict(x_test, batch_size=batch_size, verbose=1)

# Predicted and true class indices (y_test2 is one-hot, so argmax recovers
# the integer label).
y_pred = np.argmax(preds, axis=1)
y_true = np.argmax(y_test2, axis=1)

# Pin the label set so the report and the matrix always have one row/column
# per class, even if a class never occurs in y_true or y_pred; naming the
# classes makes the report readable.
class_ids = list(range(len(class_names)))
print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

cm = confusion_matrix(y_true, y_pred, labels=class_ids)
print(cm)

In [None]:
#function obtained from http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix for MobileNetV2',
                          cmap=None,
                          normalize=True):
    """
    Plot a confusion matrix as an annotated heat map and show it.

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix;
                  rows are true labels, columns are predicted labels

    target_names: tick labels for the classes, e.g. [0, 1, 2] or
                  ['high', 'medium', 'low']; None keeps the default ticks

    title:        the text to display at the top of the matrix

    cmap:         colormap for the cells, e.g. plt.get_cmap('jet') or
                  plt.cm.Blues; None means 'Blues'
                  see http://matplotlib.org/examples/color/colormaps_reference.html

    normalize:    If False, colour and annotate cells with the raw counts
                  If True, colour and annotate cells with per-row proportions
                  (each row divided by its total); rows with no samples are
                  shown as 0 instead of NaN

    The x-axis label also reports overall accuracy and misclassification
    rate, always computed from the raw counts.

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citation
    --------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import itertools

    import matplotlib.pyplot as plt
    import numpy as np

    cm = np.asarray(cm)

    # Summary statistics come from the raw counts, before any normalisation.
    accuracy = np.trace(cm) / np.sum(cm).astype('float')
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    if normalize:
        # Normalise *before* drawing so the cell colours, the colour bar and
        # the text-colour threshold below all refer to the same matrix.
        # Rows without any true samples stay 0 rather than becoming 0/0 = NaN.
        row_sums = cm.sum(axis=1, keepdims=True).astype('float')
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)

    plt.figure(figsize=(6, 4))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names)
        plt.yticks(tick_marks, target_names)

    # Cells above the threshold get white text so it stays readable on the
    # high-valued (dark, in the default 'Blues' map) cells.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = "{:0.4f}".format(cm[i, j]) if normalize else "{:,}".format(cm[i, j])
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\nAccuracy={:0.2f}; Misclassification={:0.2f}'.format(accuracy, misclass))
    # tight_layout must run after the axis labels exist; otherwise the
    # two-line x label is not accounted for and can be clipped.
    plt.tight_layout()
    plt.show()

In [None]:
# Show the confusion matrix as raw counts, with class indices as tick labels.
plot_confusion_matrix(cm=cm, target_names=[0, 1, 2], normalize=False)

In [None]:
#function obtained from https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html

from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score

# One-vs-rest precision/recall curve per class, keyed by class index.
per_class_curves = {
    k: precision_recall_curve(y_test2[:, k], preds[:, k]) for k in range(n_classes)
}
precision = {k: curve[0] for k, curve in per_class_curves.items()}
recall = {k: curve[1] for k, curve in per_class_curves.items()}
average_precision = {
    k: average_precision_score(y_test2[:, k], preds[:, k]) for k in range(n_classes)
}

# Micro-average: pool every (sample, class) indicator/score pair into one
# binary problem.
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test2.ravel(), preds.ravel())

average_precision["micro"] = average_precision_score(y_test2, preds, average="micro")

In [None]:
#plot micro-averaged precision-recall graph

plt.figure()

plt.step(recall['micro'], precision['micro'], where='post', lw=2,
         label='PR curve (AP = %0.2f)' % average_precision['micro'])

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid()
plt.legend(loc='lower right')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.05])
plt.title('Precision-Recall Curve for MobileNetV2')
# Render explicitly, like the other plotting cells. Without this, the notebook
# echoes the title's Text object as the cell output, and a plain script never
# displays the figure.
plt.show()

In [None]:
#function obtained from https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py

from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score
from numpy import interp
from itertools import cycle

# One-vs-rest ROC curve and its area for each class, keyed by class index.
fpr, tpr, roc_auc = {}, {}, {}
for cls in range(n_classes):
    cls_fpr, cls_tpr, _ = roc_curve(y_test2[:, cls], preds[:, cls])
    fpr[cls] = cls_fpr
    tpr[cls] = cls_tpr
    roc_auc[cls] = auc(cls_fpr, cls_tpr)

# Micro-average: pool every (sample, class) indicator/score pair.
fpr["micro"], tpr["micro"], _ = roc_curve(y_test2.ravel(), preds.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

In [None]:
# Plot the micro-averaged ROC curve against the chance diagonal.
plt.figure()

micro_auc = roc_auc['micro']
plt.plot(fpr['micro'], tpr['micro'], color='darkorange', lw=2,
         label='ROC curve (AUC = %0.2f)' % micro_auc)

# Chance-level reference line.
chance = [0, 1]
plt.plot(chance, chance, color='navy', lw=2, linestyle='--')

for set_limits in (plt.xlim, plt.ylim):
    set_limits([0.0, 1.05])

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC curve for MobileNetV2')
plt.legend(loc="lower right")
plt.grid()
plt.show()