In [1]:
import tensorflow as tf
import numpy as np
import cv2
import pickle
import dlib

### Defining the Keras `Sequential` model with the full architecture we used.

In [2]:
def _conv_block(filters, kernel_size, dropout_rate, **first_conv_kwargs):
    """Return the layers of one Conv-BN-Conv-BN-MaxPool-Dropout stage, in order.

    Args:
        filters: Number of filters used by both convolutions of the stage.
        kernel_size: Kernel size used by both convolutions of the stage.
        dropout_rate: Rate of the Dropout layer that closes the stage.
        **first_conv_kwargs: Extra keyword arguments forwarded only to the
            first convolution (used to declare the network's input shape).

    Returns:
        A list of six freshly constructed Keras layers.
    """
    layers = tf.keras.layers
    return [
        layers.Conv2D(filters=filters, kernel_size=kernel_size,
                      activation='relu', **first_conv_kwargs),
        layers.BatchNormalization(),
        layers.Conv2D(filters=filters, kernel_size=kernel_size,
                      activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(dropout_rate),
    ]


def get_model():
    """Build the (uncompiled, untrained) CNN classifier.

    The network takes 96x96 images with 3 channels and ends in a 7-way
    softmax. It consists of four convolutional stages (64, 128, 256 and 512
    filters) followed by a 512-unit dense head.

    Layers are created in exactly the same order as before, so weights saved
    from the earlier definition still line up with this one. The only change
    is that `input_shape` is now given to the first layer alone; it used to be
    repeated on the second convolution, where it has no purpose.

    Returns:
        A `tf.keras.Sequential` model.
    """
    layers = tf.keras.layers
    return tf.keras.Sequential([
        # Only the very first layer declares the input shape.
        *_conv_block(64, (5, 5), 0.3, input_shape=(96, 96, 3)),
        *_conv_block(128, (3, 3), 0.3),
        *_conv_block(256, (3, 3), 0.3),
        *_conv_block(512, (3, 3), 0.15),

        # Classification head.
        layers.Flatten(name='flatten'),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.15),
        layers.Dense(7, activation='softmax'),
    ])

### Loading our saved model weights and label encoder object to use them for real-time testing

In [3]:
# Load the saved weights
# The architecture is rebuilt with get_model() first; load_weights() then fills
# that layer stack from "best_weights.h5" (a relative path, so it is resolved
# against the notebook's current working directory).
model = get_model()
model.load_weights("best_weights.h5") 

# Load LabelEncoder 
def load_object(name):
    """Unpickle and return the object stored in the file ``{name}.pck``.

    SECURITY: ``pickle.load`` can execute arbitrary code while deserializing.
    Only call this on files you created yourself or otherwise fully trust.

    Args:
        name: Path of the pickle file without its ``.pck`` extension.

    Returns:
        The deserialized Python object.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        pickle.UnpicklingError: If the file content is not a valid pickle.
    """
    # The context manager closes the handle even if unpickling fails; the
    # previous version left the file open.
    with open(f"{name}.pck", "rb") as pickle_file:
        return pickle.load(pickle_file)

# Unpickle the label encoder from "LabelEncoder.pck". It is later passed to
# RealtimePrediction, which calls its inverse_transform() to turn predicted
# class indices back into labels.
Le = load_object("LabelEncoder")

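**Security note:** `pickle.load` can execute arbitrary code, so only load `.pck` files that you created yourself or otherwise trust. See scikit-learn's notes on the security and maintainability limitations of pickle-based persistence: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations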


### Defining various utility functions to be used for prediction.

In [4]:
# Preprocessing applied to each test image during realtime prediction.
def ProcessImage(image):
    """Turn an image into a single-item 96x96 batch tensor.

    The input is converted to a tensor, resized to 96x96 with bilinear
    interpolation, and given a new leading axis of size 1.

    Args:
        image: Image data accepted by ``tf.convert_to_tensor``.

    Returns:
        The resized tensor with an extra dimension inserted at axis 0.
    """
    as_tensor = tf.convert_to_tensor(image)
    resized = tf.image.resize(as_tensor, [96, 96], method="bilinear")
    return tf.expand_dims(resized, 0)

# For realtime prediction
def RealtimePrediction(image , model, encoder_):
    """Classify one preprocessed batch and return the label of its first item.

    Args:
        image: Input handed unchanged to ``model.predict`` (in this notebook,
            the output of ProcessImage).
        model: Object with a Keras-style ``predict(x, verbose=...)`` method
            returning one row of class scores per input item.
        encoder_: Object whose ``inverse_transform`` maps an array of class
            indices back to labels.

    Returns:
        The decoded label for the first item of the batch.
    """
    # verbose=0: this runs once per video frame, and a per-call progress bar
    # would flood the notebook output.
    scores = model.predict(image, verbose=0)
    class_indices = np.argmax(scores, axis=1)
    return encoder_.inverse_transform(class_indices)[0]

# Convert a dlib-style rectangle into an OpenCV-style (x, y, w, h) bounding box.
def rect_to_bb(rect):
    """Return ``(x, y, w, h)`` for a rectangle exposing left/top/right/bottom.

    Args:
        rect: Object with ``left()``, ``top()``, ``right()`` and ``bottom()``
            accessor methods.

    Returns:
        A 4-tuple: the left edge, the top edge, ``right - left`` and
        ``bottom - top``.
    """
    left_edge, top_edge = rect.left(), rect.top()
    width = rect.right() - left_edge
    height = rect.bottom() - top_edge
    return (left_edge, top_edge, width, height)

### Capturing video with OpenCV and making predictions.

In [9]:
# Open capture device 0 and create dlib's frontal face detector.
cam = cv2.VideoCapture(0)
detector = dlib.get_frontal_face_detector()

# Pixels of surrounding context added on every side of a detected face before
# it is cropped and classified.
FACE_PADDING = 10

# try/finally guarantees the camera and the preview window are released even
# when the loop is interrupted (e.g. kernel interrupt) or a frame raises.
try:
    while True:
        ret, frame = cam.read()
        if not ret:
            # No frame was delivered; stop instead of processing garbage.
            break

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        frame_h, frame_w = gray.shape[:2]

        for rect in detector(gray, 0):
            (x, y, w, h) = rect_to_bb(rect)

            # Clamp the padded crop to the frame. Without the lower clamp a
            # face within FACE_PADDING pixels of the top/left edge produced a
            # negative slice start, which wraps around in NumPy and yielded an
            # empty crop, so the face was silently skipped.
            top = max(y - FACE_PADDING, 0)
            left = max(x - FACE_PADDING, 0)
            bottom = min(y + h + FACE_PADDING, frame_h)
            right = min(x + w + FACE_PADDING, frame_w)
            if bottom <= top or right <= left:
                # Detection lies entirely outside the frame; nothing to crop.
                continue

            # The crop is grayscale; replicate it to 3 channels because the
            # model's first layer is declared with a 3-channel input.
            img = cv2.cvtColor(gray[top:bottom, left:right], cv2.COLOR_GRAY2RGB)
            output = RealtimePrediction(ProcessImage(img), model, Le)

            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
            # Put the label above the box, or just inside its top edge when
            # there is not enough room above.
            z = y - 15 if y - 15 > 15 else y + 15
            cv2.putText(frame, str(output), (x, z),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 255, 0), 2)

        # Show every frame exactly once, with or without annotations.
        cv2.imshow("Frame", frame)

        # Press 'q' in the preview window to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cam.release()
    cv2.destroyAllWindows()

