# This notebook shows how to use the trained CNN models

In [1]:
import tensorflow as tf
from resizeimage import resizeimage
from PIL import Image, ImageOps
import numpy as np
import cv2

In [2]:
# Emotion labels in model-output order: the position of a label in this list is
# the class index it names (get_prediction_label indexes this list with the
# np.argmax of each model's prediction vector).
CLASS_NAMES_WITHOUT_DISGUST = [
    "Angry",
    "Fear",
    "Happy",
    "Sad",
    "Surprise",
    "Neutral",
]

In [3]:
%%time
# Load the four trained CNN models from the local "models" directory (the paths
# are relative to the notebook's working directory).
# Prerequisite: download the saved models from drive and place them in a folder
# named "models" (for example), so that models/CNNmodel3 ... models/CNNmodel6
# exist. If a model is missing, load_model raises
# "OSError: SavedModel file does not exist at: ...".
# The loaded models are kept as notebook-level globals and used by predict().
CNNModel3 = tf.keras.models.load_model("models/CNNmodel3")
CNNModel4 = tf.keras.models.load_model("models/CNNmodel4")
CNNModel5 = tf.keras.models.load_model("models/CNNmodel5")
CNNModel6 = tf.keras.models.load_model("models/CNNmodel6")

OSError: SavedModel file does not exist at: models/CNNmodel3/{saved_model.pbtxt|saved_model.pb}

In [7]:
def predict(image):
    """Run all four loaded CNN models on a single preprocessed image.

    Args:
        image: NumPy array holding one 48x48 single-channel image (it is
            reshaped to ``(1, 48, 48, 1)`` here); presumably the output of
            ``preprocess_image`` -- verify against the caller.

    Returns:
        A 4-tuple ``(CNNModel3_predictions, CNNModel4_predictions,
        CNNModel5_predictions, CNNModel6_predictions)`` with the raw
        ``.predict`` output of each model, in that order.
    """
    # Batch of one grayscale image for CNNModel3/4/5.
    gray_batch = image.reshape(1, 48, 48, 1)

    # CNNModel6 is a transfer learning model that expects rgb input, so the
    # grayscale image is expanded to three channels (as float32) for it.
    rgb_batch = cv2.cvtColor(np.float32(image), cv2.COLOR_GRAY2RGB).reshape(1, 48, 48, 3)

    grayscale_outputs = []
    for grayscale_model in (CNNModel3, CNNModel4, CNNModel5):
        grayscale_outputs.append(grayscale_model.predict(gray_batch))

    rgb_output = CNNModel6.predict(rgb_batch)

    return (*grayscale_outputs, rgb_output)
    

def preprocess_image(image):
    """Turn an image into the 48x48 single-channel array used by ``predict``.

    Steps: resize/crop to 48x48 with ``resizeimage.resize_cover``, convert to
    grayscale, cast the pixels to int32, divide by 255.0 and reshape to
    ``(48, 48, 1)``.

    Args:
        image: the image to preprocess; it is passed to
            ``resizeimage.resize_cover`` and ``ImageOps.grayscale``, so
            presumably a ``PIL.Image.Image`` -- confirm against the caller.

    Returns:
        A floating-point NumPy array of shape ``(48, 48, 1)``.
    """
    covered = resizeimage.resize_cover(image, [48,48])
    gray_pixels = np.array(ImageOps.grayscale(covered))
    # Cast first, then scale: dividing by the float 255.0 yields float values.
    scaled_pixels = gray_pixels.astype("int32") / 255.0
    return scaled_pixels.reshape(48, 48, 1)

def get_prediction_label(CNNModel3_predictions, CNNModel4_predictions,
                         CNNModel5_predictions, CNNModel6_predictions):
    """Map each model's prediction output to its most likely emotion label.

    For every argument, the index of the largest value (``np.argmax``) is
    looked up in ``CLASS_NAMES_WITHOUT_DISGUST``.

    Returns:
        A 4-tuple of label strings, ordered CNNModel3, CNNModel4, CNNModel5,
        CNNModel6.
    """
    per_model_outputs = (
        CNNModel3_predictions,
        CNNModel4_predictions,
        CNNModel5_predictions,
        CNNModel6_predictions,
    )
    return tuple(
        CLASS_NAMES_WITHOUT_DISGUST[np.argmax(model_output)]
        for model_output in per_model_outputs
    )


In [None]:
# Path of the image you want to classify -- replace with your own file.
IMAGE_PATH = "path/to/your/image.jpg"

# preprocess_image requires an image argument (calling it with none raises a
# TypeError), so open the file explicitly. The context manager closes the file
# once the pixels have been converted to an array.
with Image.open(IMAGE_PATH) as input_image:
    image = preprocess_image(input_image)

# Raw prediction output of each of the four models.
CNNModel3_predictions, CNNModel4_predictions, CNNModel5_predictions, CNNModel6_predictions = predict(image)

# Most likely emotion label per model.
(CNNModel3_predicted_label,
 CNNModel4_predicted_label,
 CNNModel5_predicted_label,
 CNNModel6_predicted_label) = get_prediction_label(CNNModel3_predictions, CNNModel4_predictions,
                                                   CNNModel5_predictions, CNNModel6_predictions)

print(f"Predictions: \n{CNNModel3_predicted_label, CNNModel4_predicted_label, CNNModel5_predicted_label, CNNModel6_predicted_label}")