# Import Libs

In [12]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import cv2

# Configuration

In [13]:
# Emotion labels the classifier can emit. The position of each label matters:
# the predicted class index is used to look the name up in this list, and the
# list length sizes the model's final layer.
classes = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]

In [14]:
# Set to True to run the model on CUDA, False to run it on the CPU.
INFER_ON_GPU = True

# Loading Model

In [15]:
# Location of the saved weights; the file is read with torch.load and passed to
# model.load_state_dict further down.
# NOTE(review): the file name suggests 50 epochs, lr 0.01, batch size 8 — confirm.
MODEL_PATH = "./model/model_epoch50_lr0.01_batch8.pt"

In [16]:
# Build an untrained MobileNetV2 (architecture only); the trained weights are
# loaded into it from MODEL_PATH in a later cell.
model = models.mobilenet_v2()

In [17]:
# Replace the last classifier layer with a linear layer that has one output
# per emotion label, keeping the same number of input features.
final_layer = model.classifier[1]
num_features = final_layer.in_features
model.classifier[1] = nn.Linear(
    in_features=num_features,
    out_features=len(classes),
)

In [18]:
# Load the trained weights onto the selected device and put the model in
# inference mode.
#
# If a GPU was requested but CUDA is not available, fall back to the CPU so the
# notebook still runs on CPU-only machines instead of raising.
use_cuda = INFER_ON_GPU and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# map_location remaps the stored tensors to `device`, so a checkpoint saved on
# a GPU can be opened on a CPU-only machine (and vice versa).
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device)

# eval() turns Dropout off and makes BatchNorm use its running statistics.
# Without it, predictions are randomly perturbed and computed with batch
# statistics of a single image.
model.eval()

print(f'using {device}')

using cuda


# Real Time Video Inference

In [19]:
# Transformations applied to every cropped face before it is fed to the model.
# Per-channel mean / std passed to Normalize (the standard ImageNet statistics
# used throughout the torchvision documentation).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

preprocess_steps = [
    transforms.ToPILImage(),        # tensor / ndarray -> PIL image
    transforms.Grayscale(3),        # grayscale, replicated to 3 channels
    transforms.Resize(256),         # resize so the smaller edge is 256 px
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1], (C, H, W)
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
]
preprocess = transforms.Compose(preprocess_steps)

In [20]:
# Start inference
#
# Reads frames from the default webcam, detects faces with a Haar cascade,
# classifies the emotion of every detected face and draws the result on the
# frame. Stops when 'q' is pressed or the camera stops delivering frames.

# Build the face detector once: constructing it re-parses the cascade XML, so
# doing it inside the frame loop is wasted work on every frame.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

# Make sure Dropout is disabled and BatchNorm uses its running statistics.
model.eval()

cap = cv2.VideoCapture(0)

try:
    with torch.no_grad():
        while cap.isOpened():
            ret, frame = cap.read()  # (H, W, C), OpenCV delivers BGR channel order
            if not ret:
                # Camera unplugged / stream ended: leave the loop instead of
                # spinning forever without ever reaching waitKey().
                break

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray, 1.3, 5)
            for (x, y, w, h) in faces:
                # The torchvision pipeline (ToPILImage -> Grayscale) interprets
                # channels as RGB, so convert the BGR crop first; otherwise the
                # red and blue luminance weights are swapped.
                roi_face_frame = cv2.cvtColor(frame[y:y+h, x:x+w], cv2.COLOR_BGR2RGB)

                # preprocessing frame
                img = torch.from_numpy(roi_face_frame)
                img = img.permute(2, 0, 1)  # (H, W, C) -> (C, H, W), channel last -> channel first
                img = preprocess(img).unsqueeze(0).to(device)  # add batch dimension

                # make detection
                output = model(img)
                pred_idx = output.argmax(dim=1).item()  # plain int, safe for list indexing
                detected_emotion = classes[pred_idx]

                cv2.rectangle(frame, (x, y), (x+w, y+h), (255, 0, 0), 2)
                # Keep the label inside the image when the face touches the top edge.
                label_y = max(int(y) - 10, 20)
                cv2.putText(frame, detected_emotion, (int(x), label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.65, (0, 0, 255), 2)

            cv2.imshow('Emotion Recognition', frame)

            if cv2.waitKey(10) & 0xFF == ord('q'):
                break
finally:
    # Always free the camera and close the window, even if inference raises.
    cap.release()
    cv2.destroyAllWindows()