In [None]:
# Import necessary libraries
import numpy as np
import time
import pandas as pd
import os
import cv2
import PIL.Image as Image
import matplotlib.pylab as plt

import tensorflow as tf
import tensorflow_hub as hub

In [None]:
# Mount Google Drive at /googledrive.
from google.colab import drive
drive.mount('/googledrive')

# Create a symbolic link to our Google Drive:
# make sure the colabdrive folder exists inside Drive, then point /colabdrive
# at it so that the rest of the notebook can use the shorter path.
! mkdir -p /googledrive/MyDrive/colabdrive
! ln -snf /googledrive/MyDrive/colabdrive/ /colabdrive

In [None]:
!unzip '/googledrive/MyDrive/colabdrive/output'  -d  '/googledrive/MyDrive/colabdrive/'

In [None]:
# Locations of the three dataset splits, reached through the /colabdrive path.
TRAINING_DATASET_PATH="/colabdrive/output/train"
TEST_DATASET_PATH="/colabdrive/output/test"
VALIDATION_DATASET_PATH="/colabdrive/output/val"
# Directory that receives the exported models.
EXPORT_PATH='/colabdrive/Saved Model'
# File name of the converted TF Lite model inside EXPORT_PATH.
TFLITE_EXPORT_PATH=EXPORT_PATH+"/model.tflite"

In [None]:
# Loader settings shared by the three splits.
batch_size = 32   # number of images per batch
img_height = 224  # image height passed to the loader
img_width = 224   # image width passed to the loader

# Create the train, test and validation sets; every split is loaded with the
# same seed, image dimensions and batch size.
_split_options = dict(
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
_load_split = tf.keras.preprocessing.image_dataset_from_directory

trainSet = _load_split(str(TRAINING_DATASET_PATH), **_split_options)
testSet = _load_split(str(TEST_DATASET_PATH), **_split_options)
validationSet = _load_split(str(VALIDATION_DATASET_PATH), **_split_options)

In [None]:
# Retrieve class names and the number of classes.
# The names are taken from the training split because the model's output units
# are indexed by the labels it is trained on. The other splits must expose the
# same classes in the same order, otherwise their integer labels would refer to
# different classes, so this is checked explicitly.
classNames = np.array(trainSet.class_names)
assert list(validationSet.class_names) == list(classNames), \
    "validation classes differ from training classes"
assert list(testSet.class_names) == list(classNames), \
    "test classes differ from training classes"
numberOfClasses = len(classNames)



In [None]:
# Rescale the images of a dataset and configure caching / prefetching.
def dataProcessing(dataset):
  """Return `dataset` with rescaled inputs, cached and prefetched.

  Args:
    dataset: dataset whose elements are (x, y) pairs. x is passed through the
      rescaling layer, y is returned unchanged.

  Returns:
    The mapped dataset with `cache()` and
    `prefetch(buffer_size=tf.data.AUTOTUNE)` applied, in that order.
  """
  # A single layer instance, reused for every element: multiplies x by 1/255.
  rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)

  def _rescale_inputs(images, labels):
    return rescale(images), labels

  return (
      dataset
      .map(_rescale_inputs)
      .cache()
      .prefetch(buffer_size=tf.data.AUTOTUNE)
  )

In [None]:
# Rescale the images of the three datasets.
# NOTE: each name is rebound to its processed dataset, so running this cell a
# second time passes the already processed datasets through dataProcessing again.
trainSet, testSet, validationSet = (
    dataProcessing(split) for split in (trainSet, testSet, validationSet)
)

In [None]:
# IF YOU RUN THE FOLLOWING CELLS, YOU WILL TRAIN A MODEL. YOU HAVE THE OPTION TO LOAD AN ALREADY TRAINED MODEL AND SAVE TIME BY RUNNING THE CELL (LOAD THE TRAINED MODEL)
# Set up the feature extractor layer using a pre-trained headless model from TensorFlow Hub.
# The layer is frozen (trainable=False). Its input shape follows the image
# dimensions used to build the datasets instead of repeating the literals.
feature_extractor_layer = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/5",
    input_shape=(img_height, img_width, 3),
    trainable=False)

# Create a Sequential model from the feature extractor layer and a top layer for
# the classification of traffic signs. The Dense layer has one unit per class
# and no activation.
model = tf.keras.Sequential([
  feature_extractor_layer,
  tf.keras.layers.Dense(numberOfClasses)
])

In [None]:
# Print a summary of the model's layers.
model.summary()

In [None]:
# Configure training: Adam optimizer, sparse categorical cross-entropy computed
# from logits, and accuracy reported under the name 'acc'.
_compile_options = dict(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['acc'],
)
model.compile(**_compile_options)

In [None]:
# Custom callback to log the loss and accuracy of each batch individually, instead of the epoch average.
class CollectBatchStats(tf.keras.callbacks.Callback):
  """Record the 'loss' and 'acc' values reported after every training batch.

  Attributes:
    batch_losses: list of logs['loss'] values, one per recorded batch.
    batch_acc: list of logs['acc'] values, one per recorded batch.
  """

  def __init__(self):
    # Let the base class initialise its own state before adding ours.
    super().__init__()
    self.batch_losses = []
    self.batch_acc = []

  def on_train_batch_end(self, batch, logs=None):
    """Store the batch's loss/accuracy, then reset the model's metrics.

    Args:
      batch: index of the batch that just finished (unused).
      logs: dict with the 'loss' and 'acc' entries of the batch; may be None.
    """
    # logs defaults to None, so it is only indexed when it was provided.
    if logs:
      self.batch_losses.append(logs['loss'])
      self.batch_acc.append(logs['acc'])
    # Reset the metrics so that the next batch is reported on its own.
    self.model.reset_metrics()



In [None]:
# Train the model; the callback collects the statistics of every batch.
batchStatsCallback = CollectBatchStats()

history = model.fit(
    trainSet,
    validation_data=validationSet,
    epochs=4,
    callbacks=[batchStatsCallback],
)

In [None]:
# Evaluate the model on the test set.
# score[0] is printed as the loss and score[1] as the accuracy.
score = model.evaluate(testSet, verbose=0)
print(f'Test loss: {score[0]} / Test accuracy: {score[1]}')


In [None]:
# Save the trained model under EXPORT_PATH, in a directory named after the
# current time truncated to an integer.
t = time.time()

export_path = EXPORT_PATH+"/{}".format(int(t))
model.save(export_path)

# Last expression of the cell: shows the export directory.
export_path

In [None]:
#LOAD THE TRAINED MODEL

# Reload the saved model.
# When the training and saving cells were skipped, export_path does not exist
# yet. In that case fall back to the most recent export under EXPORT_PATH
# (exports are stored in directories named after an integer timestamp).
try:
    export_path
except NameError:
    savedExports = [name for name in os.listdir(EXPORT_PATH) if name.isdigit()]
    if not savedExports:
        raise FileNotFoundError("No saved model found in " + EXPORT_PATH)
    export_path = EXPORT_PATH + "/" + max(savedExports, key=int)

reloadedModel = tf.keras.models.load_model(export_path)

In [None]:
# Visualisation of model predictions for the first batch of images
imageBatch, _ = next(iter(testSet))

reloadedResultBatch = reloadedModel.predict(imageBatch)
predictedId = np.argmax(reloadedResultBatch, axis=-1)  # index of the highest score per image
predicteLlabelBatch = classNames[predictedId]

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
# The 6x5 grid holds at most 30 images; a batch can contain fewer than that,
# so the loop is bounded by the actual batch size.
for n in range(min(30, imageBatch.shape[0])):
  plt.subplot(6,5,n+1)
  plt.imshow(imageBatch[n])
  plt.title(predicteLlabelBatch[n].title())
  plt.axis('off')
_ = plt.suptitle("Model predictions")

In [None]:

# Run the reloaded model over the whole test set.
PredictionsOnTestSet = reloadedModel.predict(testSet)

In [None]:
# Convert the model into a TF Lite model and save it on Drive.
converter = tf.lite.TFLiteConverter.from_saved_model(export_path) # path to the SavedModel directory
tfliteModel = converter.convert()
# Write to TFLITE_EXPORT_PATH, the same constant the interpreter loads from,
# so the save and load locations cannot drift apart.
with open(TFLITE_EXPORT_PATH, 'wb') as f:
  f.write(tfliteModel)

In [None]:
# Load the TF Lite model
tfliteInterpreter = tf.lite.Interpreter(TFLITE_EXPORT_PATH)
inputDetails = tfliteInterpreter.get_input_details()    # Get input details
outputDetails = tfliteInterpreter.get_output_details()  # Get output details
# Resize the input so that predictions are made on one image at a time; the
# spatial size follows the dimensions used to build the datasets.
tfliteInterpreter.resize_tensor_input(inputDetails[0]['index'], (1, img_height, img_width, 3))
# The output tensor is not resized: resize_tensor_input is meant for input
# tensors, and allocate_tensors() sizes the outputs from the input shape.
tfliteInterpreter.allocate_tensors()

In [None]:

# Lite model prediction on a single image
testImage = Image.open(TEST_DATASET_PATH+'/20/00000_00000_00027.png')
# Force three colour channels: PNG files can be grayscale, palette-based or
# carry an alpha channel, none of which matches the (1, 224, 224, 3) input.
testImage = testImage.convert('RGB').resize((224,224))
testImage = np.array(testImage, dtype=np.float32)/255.0
Img = testImage[np.newaxis, ...]  # add the batch dimension

# perf_counter is a monotonic clock intended for measuring short durations.
startTime = time.perf_counter()
tfliteInterpreter.set_tensor(inputDetails[0]['index'], Img)
tfliteInterpreter.invoke()
tfliteModelPredictions = tfliteInterpreter.get_tensor(outputDetails[0]['index'])
endTime = time.perf_counter()

predictedClass = np.argmax(tfliteModelPredictions[0], axis=-1)
predictedClassName = classNames[predictedClass]
duration = endTime - startTime
print(f"Prediction time for an image: {duration} seconds")

In [None]:
# Show the image used for the TF Lite prediction, titled with the predicted class name.

plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
plt.imshow(testImage)
plt.title(predictedClassName.title())
plt.axis('off')
# Assign to _ so the returned object is not displayed as cell output.
_ = plt.suptitle("Model predictions for an image")

In [None]:
 # Accuracy of the lite model on test dataset

def load_dataset(directory, image_size=(224, 224)):
    """
    Load every readable image below `directory` into memory.

    `directory` is expected to contain one sub-directory per class; the name of
    the sub-directory is used as the label of every image inside it. Entries of
    `directory` that are not directories and files that OpenCV cannot decode
    are skipped.

    Arguments:
    directory -- path of the dataset root
    image_size -- (width, height) handed to cv2.resize, (224, 224) by default

    Returns:
    X_orig -- list with one RGB image per readable file, divided by 255.0
    y_orig -- np.array containing the class-directory name of each image
    """

    X_orig = []
    y_orig = []  # class-directory name of the image at the same position in X_orig

    # Sorted so that the order of the samples is reproducible between runs.
    for category in sorted(os.listdir(directory)):
        category_path = os.path.join(directory, category)
        if not os.path.isdir(category_path):
            continue  # stray file next to the class folders
        for file_name in sorted(os.listdir(category_path)):
            img = cv2.imread(os.path.join(category_path, file_name))
            # imread returns None for files it cannot decode; this has to be
            # checked before cvtColor, which raises on a None input.
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            X_orig.append(cv2.resize(img, image_size) / 255.0)
            y_orig.append(category)

    return X_orig, np.array(y_orig)



In [None]:
# Load the test images (X) and their class-directory names (y) into memory.
X, y = load_dataset(TEST_DATASET_PATH)

In [None]:
# Run the TF Lite model on every test image and collect the predicted class names.
predictedLabels = []

# The start time is taken once before the first image and the end time once
# after the last one, so (endTimeAverage - startTimeAverage) covers the whole
# loop and dividing it by the number of predictions gives the mean per image.
startTimeAverage = time.perf_counter()
for image in X:
  # Add the batch dimension and convert to float32 before feeding the interpreter.
  Img = np.array(image[np.newaxis, ...], dtype=np.float32)
  tfliteInterpreter.set_tensor(inputDetails[0]['index'], Img)
  tfliteInterpreter.invoke()
  tfliteModelPredictions = tfliteInterpreter.get_tensor(outputDetails[0]['index'])
  predictedClass = np.argmax(tfliteModelPredictions[0], axis=-1)
  predictedClassName = classNames[predictedClass]
  predictedLabels.append(predictedClassName)
endTimeAverage = time.perf_counter()


In [None]:
# Divide the interval between startTimeAverage and endTimeAverage by the number
# of collected predictions and print the result.
averageTimePerImage=(endTimeAverage-startTimeAverage)/len(predictedLabels)
print("Average time for making one prediction: ",averageTimePerImage )

In [None]:
def compute_accuracy(realLabels, predictedLabels):
  """Return the percentage of positions at which both label sequences agree.

  Args:
    realLabels: sequence of reference labels.
    predictedLabels: sequence of predicted labels, same length as realLabels.

  Returns:
    The share of matching positions as a value between 0 and 100; 0.0 when
    both sequences are empty.

  Raises:
    ValueError: if the two sequences do not have the same length.
  """
  total = len(realLabels)
  # Labels are compared position by position, so a length mismatch means the
  # sequences are misaligned and any result would be meaningless.
  if total != len(predictedLabels):
    raise ValueError(
        "realLabels and predictedLabels must have the same length "
        f"({total} != {len(predictedLabels)})")
  if total == 0:
    return 0.0  # nothing to compare; avoids dividing by zero
  correctPredictions = sum(
      1 for real, predicted in zip(realLabels, predictedLabels)
      if real == predicted)
  return correctPredictions / total * 100


In [None]:
# Compare the class-directory names (y) with the names predicted by the TF Lite model.
print("Lite model accuracy: ",compute_accuracy(y,predictedLabels))