Action preservation measure using gaze estimation from
https://github.com/glefundes/mobile-face-gaze

In [None]:
# pip install scikit-video

In [None]:
import cv2
import torch
import utils
import numpy as np

from PIL import Image
from models import gazenet
from mtcnn import FaceDetector
from mtcnn.visualization_utils import show_bboxes

import matplotlib.pyplot as plt
import skvideo.io

In [None]:
# Load the original (unfiltered) clip and preview its first frame.
# Replace the path below with the location of your unfiltered video.
unfiltered_path = "T002_ActionsShorter_mini_3239_3347_Use-Radio-or-Gadget.mp4"
unfiltered = skvideo.io.vread(unfiltered_path)
plt.imshow(unfiltered[0])

In [None]:
# Load the filtered clip (the submission being evaluated) and preview its first frame.
# Replace the path below with the location of your filtered video.
filtered_path = "submission_example.mp4"
filtered = skvideo.io.vread(filtered_path)
plt.imshow(filtered[0])

In [None]:
# Print both video shapes, one per line, so they can be compared by eye.
# The dimensions must be the same before frames are compared.
print(unfiltered.shape, filtered.shape, sep="\n")

In [None]:
# Preprocess the videos if necessary to ensure dimensions match:
# center-crop the unfiltered video down to the filtered video's frame size.
# Axes 1 and 2 are cropped independently, and each end index is derived from
# the target size rather than from a negative offset. This keeps the video
# intact when the sizes already match (a `0:-0` slice would be empty) and
# gives an exact match when the size difference is odd.
target_rows, target_cols = filtered.shape[1:3]
diff_frame = unfiltered.shape[1] - target_rows
# max(..., 0): never produce a negative start index if unfiltered is smaller.
offset = max(diff_frame, 0) // 2
offset_cols = max(unfiltered.shape[2] - target_cols, 0) // 2
unfiltered = unfiltered[
    :,
    offset:offset + target_rows,
    offset_cols:offset_cols + target_cols,
    :,
]
print(unfiltered.shape)
print(filtered.shape)

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [None]:
# Build the gaze network on the selected device and restore its saved weights.
# NOTE: torch.load unpickles the file, so only load weights from a trusted source.
weights_path = 'models/weights/gazenet.pth'
model = gazenet.GazeNet(device)
state_dict = torch.load(weights_path, map_location=device)
model.load_state_dict(state_dict)
print('Loaded model on {}'.format(device))

In [None]:
# Detect faces on the first unfiltered frame and show the detections
# (boxes, landmarks and eyeline) drawn over that frame.
face_detector = FaceDetector(device=device)
image_u = Image.fromarray(unfiltered[0])
bboxes, landmarks = face_detector.detect(image_u)
image_bboxes = show_bboxes(
    image_u,
    bboxes,
    landmarks,
    width=1,
    eyeline=True,
)

plt.figure(figsize=(10, 10))
plt.imshow(image_bboxes)

In [None]:
# Draw the boxes and landmarks that were detected on the unfiltered frame
# over the first filtered frame; no detection is run on the filtered frame.
image_f = Image.fromarray(filtered[0])
image_bboxes = show_bboxes(
    image_f,
    bboxes,
    landmarks,
    width=1,
    eyeline=True,
)

plt.figure(figsize=(10, 10))
plt.imshow(image_bboxes)

In [None]:
# Crop and normalize every confidently detected face of the unfiltered frame,
# collecting the face crops and gaze origins and showing each crop in a grid.
fig = plt.figure(figsize=(10, 10))
# np.ceil returns a float, but add_subplot requires integer grid dimensions.
rows = int(np.ceil(np.sqrt(len(bboxes))))
columns = rows + 1
plot_idx = 1
faces_u, origins_u = [], []
# The frame does not change between faces, so convert it to an array once.
frame_u = np.asarray(image_u)
for bbox, lm in zip(bboxes, landmarks):
    # Detection confidence check: the last bbox entry must exceed 0.98
    if bbox[-1] > 0.98:
        # Crop and normalize face
        face_u, gaze_origin_u, M_u = utils.normalize_face(lm, frame_u)
        faces_u.append(face_u)
        origins_u.append(gaze_origin_u)
        fig.add_subplot(rows, columns, plot_idx)
        plt.imshow(face_u)
        plot_idx += 1

In [None]:
# Crop and normalize the same face locations on the filtered frame (using the
# existing bboxes/landmarks), collecting the face crops and gaze origins and
# showing each crop in a grid.
fig = plt.figure(figsize=(10, 10))
# np.ceil returns a float, but add_subplot requires integer grid dimensions.
rows = int(np.ceil(np.sqrt(len(bboxes))))
columns = rows + 1
plot_idx = 1
faces_f, origins_f = [], []
# The frame does not change between faces, so convert it to an array once.
frame_f = np.asarray(image_f)
for bbox, lm in zip(bboxes, landmarks):
    # Detection confidence check: the last bbox entry must exceed 0.98
    if bbox[-1] > 0.98:
        # Crop and normalize face
        face_f, gaze_origin_f, M_f = utils.normalize_face(lm, frame_f)
        faces_f.append(face_f)
        origins_f.append(gaze_origin_f)
        fig.add_subplot(rows, columns, plot_idx)
        plt.imshow(face_f)
        plot_idx += 1

In [None]:
# Predict the gaze for every normalized face of the unfiltered frame and draw
# the result on a copy of that frame.
# np.array (unlike np.asarray) always copies, so the drawing calls get a
# writable buffer that is independent of the PIL image.
display = np.array(image_u)
for face_u, gaze_origin_u in zip(faces_u, origins_u):
    # Predict gaze; gradients are not needed for inference
    with torch.no_grad():
        gaze_u = model.get_gaze(face_u)
        gaze_u = gaze_u[0].data.cpu()
    # Mark the gaze origin with a filled dot, then overlay the predicted gaze
    display = cv2.circle(display, gaze_origin_u, 1, (0, 255, 0), -1)
    display = utils.draw_gaze(display, gaze_origin_u, gaze_u, length=100, color=(255,0,0), thickness=1)
    print(gaze_origin_u)
    print(gaze_u)
fig = plt.figure(figsize=(10, 10))
plt.imshow(display)

In [None]:
# Predict the gaze for every normalized face of the filtered frame and draw
# the result on a copy of that frame.
# np.array (unlike np.asarray) always copies, so the drawing calls get a
# writable buffer that is independent of the PIL image.
display = np.array(image_f)
for face_f, gaze_origin_f in zip(faces_f, origins_f):
    # Predict gaze; gradients are not needed for inference
    with torch.no_grad():
        gaze_f = model.get_gaze(face_f)
        gaze_f = gaze_f[0].data.cpu()
    # Mark the gaze origin with a filled dot, then overlay the predicted gaze
    display = cv2.circle(display, gaze_origin_f, 1, (0, 255, 0), -1)
    display = utils.draw_gaze(display, gaze_origin_f, gaze_f, length=100, color=(255,0,0), thickness=1)
    print(gaze_origin_f)
    print(gaze_f)
fig = plt.figure(figsize=(10, 10))
plt.imshow(display)

In [None]:
def rms_error(x, y):
    """Return the root-mean-square difference between *x* and *y*.

    The squared element-wise differences are averaged over all elements, so
    the result is a true RMS for vectors of any length (not only the
    2-component case a fixed divisor of 2 would cover).

    Args:
        x: Array-like of numbers.
        y: Array-like of numbers, broadcastable against *x*.

    Returns:
        The RMS of ``x - y`` as a NumPy float scalar.
    """
    # Convert to float first so unsigned-integer inputs cannot wrap around
    # when subtracted.
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return np.sqrt(np.mean(np.square(diff)))

In [None]:
# Measures difference in gaze between the unfiltered and filtered frame.
# gaze_u and gaze_f hold the values left by the last iteration of their
# prediction loops, so only that one face is compared here.
gaze_difference = rms_error(gaze_u.numpy(), gaze_f.numpy())
print(gaze_difference)

In [None]:
# TODO
# Calculate average difference in gaze over all frames of video