# MPIIGaze Cross Data Generalization

This notebook guides you through reproducing the cross-dataset generalization experiment on MPIIGaze.

Run each of the following code blocks in order, one at a time.

## Data Downloading
Run this code block to download the dataset.

In [None]:
!wget http://datasets.d2.mpi-inf.mpg.de/MPIIGaze/MPIIFaceGaze_normalized.zip
!unzip MPIIFaceGaze_normalized.zip

## Boilerplate
Run this code to initialize the models and data loaders for later use.

In [None]:
'''
Parses frames out of a video file for the use of testing the images.
'''

import numpy as np
import cv2
import av
import mediapipe as mp
import math
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from threading import local
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import h5py
import torch.utils.data as data
import scipy.io as sio
from PIL import Image
import os
import os.path
import torchvision.transforms as transforms

# This code is converted from https://github.com/CSAILVision/GazeCapture/blob/master/code/faceGridFromFaceRect.m

# Given face detection data, generate face grid data.
#
# Input Parameters:
# - frameW/H: The frame in which the detections exist
# - gridW/H: The size of the grid (typically same aspect ratio as the
#     frame, but much smaller)
# - labelFaceX/Y/W/H: The face detection (x and y are 0-based images
#     coordinates)
# - parameterized: Whether to actually output the grid or just the
#     [x y w h] of the 1s square within the gridW x gridH grid.
# https://drive.google.com/drive/folders/1ZcYb4eH2jPndS5nkqQFcLHdGNM9dTF5C?usp=sharing
# https://drive.google.com/file/d/gpip1UdhuJ_bulreFyGa8CziK4tdwXeHC2PIh/view?usp=sharing

def generate_centered_face_grid(gridW, gridH):
    """Build a flattened face grid containing a fixed 11x11 face square.

    The square always covers rows and columns 7 (inclusive) to 18 (exclusive);
    these bounds are constants and do not scale with ``gridW``/``gridH``.

    Args:
        gridW: Number of grid columns.
        gridH: Number of grid rows.

    Returns:
        A 1-D float array of length ``gridW * gridH`` laid out column by
        column, with 1.0 inside the face square and 0.0 elsewhere.
    """
    face_lo, face_hi = 7, 18
    side = face_hi - face_lo

    mask = np.zeros((gridH, gridW))
    mask[face_lo:face_hi, face_lo:face_hi] = np.ones((side, side))

    # Column-major flattening is the same as flattening the transpose.
    return mask.flatten(order="F")

def faceGridFromFaceRect(frameW, frameH, gridW, gridH, labelFaceX, labelFaceY, labelFaceW, labelFaceH, parameterized):
    """Project a face rectangle from frame coordinates onto a coarse grid.

    Coordinates are zero-based: the rectangle is scaled by
    ``gridW / frameW`` and ``gridH / frameH`` and rounded with ``round``.

    Args:
        frameW, frameH: Size of the frame the detection lives in.
        gridW, gridH: Size of the target grid.
        labelFaceX, labelFaceY, labelFaceW, labelFaceH: Face rectangle in
            frame coordinates.
        parameterized: If true, return the scaled ``[x, y, w, h]`` list
            (not clamped to the grid). Otherwise return the occupancy mask.

    Returns:
        Either the ``[x, y, w, h]`` list, or a 1-D float array of length
        ``gridW * gridH`` (column-major) with 1.0 inside the clamped face
        rectangle.
    """
    x_scale = gridW / frameW
    y_scale = gridH / frameH

    mask = np.zeros((gridH, gridW))

    left = round(labelFaceX * x_scale)
    top = round(labelFaceY * y_scale)
    box_w = round(labelFaceW * x_scale)
    box_h = round(labelFaceH * y_scale)

    if parameterized:
        return [left, top, box_w, box_h]

    def _clamp(value, upper):
        # Keep a grid coordinate inside [0, upper].
        return int(min(upper, max(0, value)))

    # The far edges are derived from the unclamped origin, then everything
    # is clamped independently.
    right = _clamp(left + box_w, gridW)
    bottom = _clamp(top + box_h, gridH)
    left = _clamp(left, gridW)
    top = _clamp(top, gridH)

    mask[top:bottom, left:right] = np.ones((bottom - top, right - left))

    # Column-major flattening, equivalent to flattening the transpose.
    return mask.flatten(order="F")

def get_face_grid(face, frameW, frameH, gridSize):
    """Scale a face box onto a square ``gridSize`` x ``gridSize`` grid.

    Args:
        face: Sequence ``(x, y, w, h)`` in frame coordinates.
        frameW, frameH: Size of the frame.
        gridSize: Side length of the square grid.

    Returns:
        The ``[x, y, w, h]`` box in grid units (the parameterized form of
        ``faceGridFromFaceRect``), not the flattened occupancy mask.
    """
    x, y, w, h = face
    return faceGridFromFaceRect(
        frameW, frameH,
        gridSize, gridSize,
        x, y, w, h,
        True,
    )


def crop_to_bounds(img, bounds):
    """Return the region of ``img`` described by ``bounds``.

    Args:
        img: Array indexed as ``img[row, column]``.
        bounds: Four values ``x, y, w, h`` (left column, top row, width,
            height).

    Returns:
        The slice ``img[y:y + h, x:x + w]``.
    """
    left, top, width, height = bounds
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return img[rows, cols]


# Module-level face-grid side length (in cells).
# NOTE(review): the functions visible in this file take their own gridSize
# argument or hard-code 25 rather than reading this value — confirm it is used.
gridSize = 25

def get_frames(video_file, stream=None):
    """
    Decode every frame of the first video stream in ``video_file``.

    Each frame is converted to a BGR numpy image and paired with its
    timestamp, computed as ``frame.pts * video.time_base``.

    Args:
        video_file: Anything accepted by ``av.open``.
        stream: Optional callable invoked as ``stream(image, frame_time, idx)``
            for every frame. Values that are not callable are ignored.

    Returns:
        A list of ``[image, frame_time]`` pairs when no callable ``stream`` is
        given; an empty list when frames were delivered through ``stream``.
    """
    # Decide once, not per frame, whether frames go to the callback.
    emit = stream if callable(stream) else None

    frames = []
    # The context manager closes the container even if decoding or the
    # callback raises.
    with av.open(video_file) as container:
        video = container.streams.video[0]
        for idx, frame in enumerate(container.decode(video)):
            image = cv2.cvtColor(frame.to_rgb().to_ndarray(), cv2.COLOR_RGB2BGR)
            frame_time = float(frame.pts * video.time_base)
            if emit is None:
                frames.append([image, frame_time])
            else:
                emit(image, frame_time, idx)

    return frames

# Thread-local holder: detect_features() lazily creates a face detector on it
# the first time it runs in a given thread and reuses it afterwards.
detector_storage = local()

def detect_features(np_img):
    """Detect the first face in an image and derive eye boxes from it.

    The detector is built from ``detector.tflite`` on first use in each thread
    and cached on ``detector_storage``.

    Args:
        np_img: Image array indexed as (row, column, ...); it is handed to
            MediaPipe tagged as SRGB.

    Returns:
        ``[np_img, [], []]`` when nothing is detected. Otherwise a tuple
        ``(np_img, [face_box], [[right_eye_box, left_eye_box], face_grid])``
        where every box is ``[x, y, w, h]`` and ``face_grid`` is the face box
        scaled onto a 25x25 grid. Eye boxes are squares centred on keypoints 0
        and 1 (treated as right and left eye) with side
        ``2 * ceil(face_width / 8)``; they are not clipped to the image.
    """
    if not hasattr(detector_storage, "detector"):
        # First call on this thread: build the detector and keep it around.
        detector_storage.base_options = python.BaseOptions(model_asset_path='detector.tflite')
        detector_storage.options = vision.FaceDetectorOptions(base_options=detector_storage.base_options)
        detector_storage.detector = vision.FaceDetector.create_from_options(detector_storage.options)

    wrapped = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_img)
    detections = detector_storage.detector.detect(wrapped).detections
    if len(detections) == 0:
        return [np_img, [], []]

    im_height, im_width = np_img.shape[0], np_img.shape[1]
    first = detections[0]
    box = first.bounding_box
    face_box = [box.origin_x, box.origin_y, box.width, box.height]

    # Half the eye-box side, proportional to the face width.
    half = math.ceil(box.width / 8)

    def _eye_box(keypoint):
        # Keypoint coordinates are scaled by the image size to get pixels.
        cx = math.floor(keypoint.x * im_width)
        cy = math.floor(keypoint.y * im_height)
        return [cx - half, cy - half, half * 2, half * 2]

    right_eye_box = _eye_box(first.keypoints[0])
    left_eye_box = _eye_box(first.keypoints[1])

    return np_img, [face_box], [[right_eye_box, left_eye_box],
                                get_face_grid(face_box, im_width, im_height, 25)]

'''
Pytorch model for the iTracker.

Author: Petr Kellnhofer ( pkel_lnho (at) gmai_l.com // remove underscores and spaces), 2018. 

Website: http://gazecapture.csail.mit.edu/

Cite:

Eye Tracking for Everyone
K.Krafka*, A. Khosla*, P. Kellnhofer, H. Kannan, S. Bhandarkar, W. Matusik and A. Torralba
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016

@inproceedings{cvpr2016_gazecapture,
Author = {Kyle Krafka and Aditya Khosla and Petr Kellnhofer and Harini Kannan and Suchendra Bhandarkar and Wojciech Matusik and Antonio Torralba},
Title = {Eye Tracking for Everyone},
Year = {2016},
Booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}
}

'''


class ItrackerImageModel(nn.Module):
    """Convolutional feature extractor shared by the image pathways.

    ITrackerModel applies one instance to both eye crops (shared weights);
    FaceImageModel owns a separate instance for the face crop.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.CrossMapLRN2d(size=5, alpha=0.0001, beta=0.75, k=1.0),
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2, groups=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.CrossMapLRN2d(size=5, alpha=0.0001, beta=0.75, k=1.0),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 64, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack and flatten each sample to a 1-D feature vector."""
        feature_maps = self.features(x)
        batch = feature_maps.size(0)
        return feature_maps.view(batch, -1)

class FaceImageModel(nn.Module):
    """Face pathway: its own ItrackerImageModel followed by two FC layers.

    The first linear layer expects ``12*12*64`` flattened features and the
    pathway outputs 64 features per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv = ItrackerImageModel()
        fc_layers = [
            nn.Linear(12*12*64, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return the 64-dimensional face embedding for a batch of face images."""
        return self.fc(self.conv(x))

class FaceGridModel(nn.Module):
    """Face-grid pathway: flattens the grid and maps it to 128 features.

    Args:
        gridSize: Side length of the square face grid; the first linear layer
            takes ``gridSize * gridSize`` inputs.
    """

    def __init__(self, gridSize = 25):
        super().__init__()
        n_cells = gridSize * gridSize
        fc_layers = [
            nn.Linear(n_cells, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Flatten each sample of ``x`` and run it through the FC stack."""
        flat = x.view(x.size(0), -1)
        return self.fc(flat)



class ITrackerModel(nn.Module):
    """Full iTracker network combining the eye, face and face-grid pathways.

    Both eyes go through the same ``eyeModel`` (shared weights); their
    features are concatenated and reduced to 128 values. These are joined with
    the 64 face features and 128 grid features and mapped to a 2-value output.
    """

    def __init__(self):
        super().__init__()
        self.eyeModel = ItrackerImageModel()
        self.faceModel = FaceImageModel()
        self.gridModel = FaceGridModel()

        # Fuses the features of the two eyes.
        eye_feature_count = 2*12*12*64
        self.eyesFC = nn.Sequential(
            nn.Linear(eye_feature_count, 128),
            nn.ReLU(inplace=True),
        )

        # Fuses eyes (128) + face (64) + grid (128) into the final prediction.
        self.fc = nn.Sequential(
            nn.Linear(128+64+128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 2),
        )

    def forward(self, faces, eyesLeft, eyesRight, faceGrids):
        """Predict a 2-value output per sample from the four input batches."""
        # Eye pathway: left eye first, then right, through the shared network.
        left_features = self.eyeModel(eyesLeft)
        right_features = self.eyeModel(eyesRight)
        eye_features = self.eyesFC(torch.cat((left_features, right_features), 1))

        # Face and face-grid pathways.
        face_features = self.faceModel(faces)
        grid_features = self.gridModel(faceGrids)

        # Concatenate all pathways along the feature dimension.
        combined = torch.cat((eye_features, face_features, grid_features), 1)
        return self.fc(combined)

'''
Data loader for the iTracker.
Use prepareDataset.py to convert the dataset from http://gazecapture.csail.mit.edu/ to proper format.

Author: Petr Kellnhofer ( pkel_lnho (at) gmai_l.com // remove underscores and spaces), 2018. 

Website: http://gazecapture.csail.mit.edu/

Cite:

Eye Tracking for Everyone
K.Krafka*, A. Khosla*, P. Kellnhofer, H. Kannan, S. Bhandarkar, W. Matusik and A. Torralba
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016

@inproceedings{cvpr2016_gazecapture,
Author = {Kyle Krafka and Aditya Khosla and Petr Kellnhofer and Harini Kannan and Suchendra Bhandarkar and Wojciech Matusik and Antonio Torralba},
Title = {Eye Tracking for Everyone},
Year = {2016},
Booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}
}

'''

# Directory that MPIIGazeData searches for the mean-image files
# (mean_face_224.mat, mean_left_224.mat, mean_right_224.mat).
MEAN_PATH = './'

def clamp_eyes_to_frame(box, width, height):
    """Repair out-of-range eye-box origins in place.

    For ``box[0]`` (against ``width``) and ``box[1]`` (against ``height``):
    a value above 60000 is reset to 0, otherwise a value above 1000 is set to
    the corresponding frame dimension. Other entries are left untouched.

    NOTE(review): the 60000 threshold presumably catches negative coordinates
    that wrapped around when stored as uint16 — confirm against the writer.

    Args:
        box: Mutable sequence whose first two items are the x and y origin.
        width: Replacement for an oversized x origin.
        height: Replacement for an oversized y origin.

    Returns:
        None; ``box`` is modified in place.
    """
    for axis, limit in ((0, width), (1, height)):
        if box[axis] > 60000:
            box[axis] = 0
        elif box[axis] > 1000:
            box[axis] = limit

def loadMetadata(filename, silent = False, struct_as_record = False):
    """Load a MATLAB ``.mat`` file with ``scipy.io.loadmat``.

    Args:
        filename: Path (or file-like object) handed to ``sio.loadmat``.
        silent: When false, print a progress line before reading.
        struct_as_record: Forwarded to ``sio.loadmat``.

    Returns:
        The loaded metadata, or ``None`` if reading failed (a message is
        printed in that case). Callers must check for ``None`` before indexing.
    """
    # http://stackoverflow.com/questions/6273634/access-array-contents-from-a-mat-file-loaded-using-scipy-io-loadmat-python
    if not silent:
        print('\tReading metadata from %s...' % filename)
    try:
        metadata = sio.loadmat(filename, squeeze_me=True, struct_as_record=struct_as_record)
    except Exception:
        # Only ordinary errors are treated as "unreadable file"; KeyboardInterrupt
        # and SystemExit must still be able to stop the program.
        print('\tFailed to read the meta file "%s"!' % filename)
        return None
    return metadata

class SubtractMean(object):
    """Transform that subtracts a fixed mean image from a tensor image."""

    def __init__(self, meanImg):
        # The mean image is divided by 255 before being turned into a tensor.
        to_tensor = transforms.ToTensor()
        self.meanImg = to_tensor(meanImg / 255)

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: ``tensor`` with the stored mean image subtracted.
        """
        return tensor - self.meanImg

class MPIIGazeData(data.Dataset):
    """Dataset over the HDF5 file written by the preprocessing step.

    Each item is ``(face, left_eye, right_eye, face_grid, gaze)``: three
    mean-subtracted image tensors, a flattened face grid with a fixed centred
    face square, and the gaze converted from angles to an (x, y) offset in
    centimetres at ``screenDistanceCM``.

    Args:
        dataPath: Path of the HDF5 file.
        split: ``'train'`` selects subjects 0-13, ``'test'`` subject 14, any
            other value all 15 subjects.
        imSize: Size every image is resized to.
        gridSize: ``(width, height)`` of the face grid.
    """

    def __init__(self, dataPath: str, split: str = "all", imSize=(224,224), gridSize=(25, 25)):
        self.dataset_path = dataPath

        def _mean_image(mat_name):
            # Mean images are read from .mat files under MEAN_PATH.
            return loadMetadata(os.path.join(MEAN_PATH, mat_name))['image_mean']

        self.faceMean = _mean_image('mean_face_224.mat')
        self.eyeLeftMean = _mean_image('mean_left_224.mat')
        self.eyeRightMean = _mean_image('mean_right_224.mat')
        self.imSize = imSize
        self.gridSize = gridSize
        self.screenDistanceCM = 50.8

        def _pipeline(mean_image):
            # Resize -> tensor -> subtract the pathway-specific mean image.
            return transforms.Compose([
                transforms.Resize(self.imSize),
                transforms.ToTensor(),
                SubtractMean(meanImg=mean_image),
            ])

        self.transformFace = _pipeline(self.faceMean)
        self.transformEyeL = _pipeline(self.eyeLeftMean)
        self.transformEyeR = _pipeline(self.eyeRightMean)

        subjects_by_split = {'train': range(0, 14), 'test': range(14, 15)}
        subject_split = subjects_by_split.get(split, range(0, 15))

        # Flat list of [subject, sample] pairs covering the selected subjects.
        self.indices = []
        with h5py.File(self.dataset_path, 'r') as f:
            for subject in subject_split:
                n_samples = f.get(f'p{subject:02}/count')[()]
                self.indices.extend([subject, n] for n in range(n_samples))

    def transform_angle(self, angle):
        """Convert ``(pitch, yaw)`` into an ``[x, y]`` float32 offset.

        ``x = tan(yaw) * screenDistanceCM`` and
        ``y = -tan(pitch) * screenDistanceCM``.
        """
        pitch, yaw = angle[0], angle[1]
        dx = math.tan(yaw) * self.screenDistanceCM
        dy = -math.tan(pitch) * self.screenDistanceCM
        return np.array([dx, dy], np.float32)

    def __getitem__(
            self,
            index: int):
        subject, idx = self.indices[index]
        group = f'p{subject:02}'
        # The file is opened per item rather than kept open on the instance.
        with h5py.File(self.dataset_path, 'r') as f:
            image = f.get(f'{group}/image/{idx:04}')[()]
            eyes = f.get(f'{group}/eyes/{idx:04}')[()]
            gaze = f.get(f'{group}/gaze/{idx:04}')[()]

        reye, leye = eyes
        frame_h, frame_w = image.shape[0], image.shape[1]
        for eye_box in (reye, leye):
            clamp_eyes_to_frame(eye_box, frame_w, frame_h)

        def _to_pil(bgr):
            # Swap the channel order and wrap as a PIL image for torchvision.
            return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

        right_crop = crop_to_bounds(image, reye)
        left_crop = crop_to_bounds(image, leye)

        imFace = self.transformFace(_to_pil(image))
        imEyeR = self.transformEyeR(_to_pil(right_crop))
        imEyeL = self.transformEyeL(_to_pil(left_crop))

        face_grid = torch.FloatTensor(generate_centered_face_grid(*self.gridSize))
        gaze = torch.FloatTensor(self.transform_angle(gaze))
        return imFace, imEyeL, imEyeR, face_grid, gaze

    def __len__(self) -> int:
        return len(self.indices)

## Preprocessing
Run this code block to preprocess the dataset.

In [None]:
import argparse
import pathlib

import h5py
import numpy as np
import tqdm
import cv2
from concurrent.futures import ThreadPoolExecutor
import threading


def add_mat_data_to_hdf5(person_id: str, dataset_dir: pathlib.Path,
                         output_path: pathlib.Path, sem: threading.Semaphore) -> None:
    """Detect eyes for one subject and append the usable samples to the HDF5 file.

    Reads ``<dataset_dir>/<person_id>.mat``, runs ``detect_features`` on every
    image, drops images in which no face is found, and writes the remaining
    images, eye boxes, head poses and gazes under ``<person_id>/...`` in
    ``output_path`` together with a ``<person_id>/count`` entry.

    Args:
        person_id: Subject key, also used as the ``.mat`` file stem.
        dataset_dir: Directory containing the per-subject ``.mat`` files.
        output_path: HDF5 file opened in append mode.
        sem: Semaphore serialising writes to ``output_path`` across threads.

    Raises:
        AssertionError: If the subject does not have exactly 3000 samples.
    """
    with h5py.File(dataset_dir / f'{person_id}.mat', 'r') as f_input:
        images = f_input.get('Data/data')[()]
        labels = f_input.get('Data/label')[()][:, :4]
    assert len(images) == len(labels) == 3000

    # Move the channel axis last: (N, C, H, W) -> (N, H, W, C).
    images = images.transpose(0, 2, 3, 1).astype(np.uint8)

    poses = labels[:, 2:]
    gazes = labels[:, :2]

    kept = []
    for image, gaze, pose in tqdm.tqdm(zip(images, gazes, poses),
                                       total=len(images), desc="Eye processing"):
        im = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        _, face_bbx, eye_features = detect_features(im)
        if len(face_bbx) == 0:
            # No face detected: the sample cannot be cropped, so skip it.
            continue

        right_eye, left_eye = eye_features[0]
        # Boxes are stored as uint16, so negative coordinates wrap around to
        # large values.
        eye_pair = [np.array(right_eye).astype(np.uint16),
                    np.array(left_eye).astype(np.uint16)]
        # The unconverted image is stored, not the colour-swapped copy.
        kept.append((image, gaze, eye_pair, pose))

    # `with sem` releases the semaphore even if writing fails; a bare
    # acquire()/release() pair would leave the other workers blocked forever.
    with sem:
        with h5py.File(output_path, 'a') as f_output:
            f_output.create_dataset(f'{person_id}/count', data=len(kept))
            for index, (image, gaze, eye_pair, pose) in tqdm.tqdm(enumerate(kept),
                                                                  total=len(kept),
                                                                  leave=False):
                f_output.create_dataset(f'{person_id}/image/{index:04}',
                                        data=image)
                f_output.create_dataset(f'{person_id}/eyes/{index:04}',
                                        data=eye_pair)
                f_output.create_dataset(f'{person_id}/pose/{index:04}', data=pose)
                f_output.create_dataset(f'{person_id}/gaze/{index:04}', data=gaze)


def main():
    """Preprocess all 15 subjects into ``MPIIGaze/MPIIFaceGaze.h5``.

    Subjects are processed by a pool of 8 threads; writes to the shared output
    file are serialised through one semaphore.

    Raises:
        ValueError: If the output file already exists.
        Exception: Any error raised by a worker is re-raised here instead of
            being dropped silently.
    """
    from concurrent.futures import as_completed

    output_dir = 'MPIIGaze'
    # NOTE(review): the spelling 'normalizad' differs from the downloaded zip
    # name; it presumably matches the directory the archive extracts to —
    # confirm before changing it.
    dataset = 'MPIIFaceGaze_normalizad'

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    output_path = output_dir / 'MPIIFaceGaze.h5'
    if output_path.exists():
        raise ValueError(f'{output_path} already exists.')

    dataset_dir = pathlib.Path(dataset)
    sem = threading.Semaphore()
    person_ids = [f'p{n:02}' for n in range(15)]
    with ThreadPoolExecutor(max_workers=8) as executor:
        all_futures = [
            executor.submit(add_mat_data_to_hdf5, person_id, dataset_dir, output_path, sem)
            for person_id in person_ids
        ]

        # Advance the bar as each subject finishes. result() re-raises any
        # worker exception, so a failed subject cannot go unnoticed.
        with tqdm.tqdm(total=len(all_futures)) as pbar:
            for future in as_completed(all_futures):
                future.result()
                pbar.update(1)


# Run the preprocessing when this cell/script is executed directly.
if __name__ == '__main__':
    main()

## Training
Run this code block to re-train the algorithm and get the results.

In [None]:
# Currently achieves 4.5059 (2.12) on the validation set.
# Currently achieves 14.8343 (3.85) on MPIIGaze validation.
# Currently achieves 8.9463 (2.99) on true MPIIGaze cross generalization

import shutil, os, time, argparse

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

'''
Train/test code for iTracker.

Author: Petr Kellnhofer ( pkel_lnho (at) gmai_l.com // remove underscores and spaces), 2018. 

Website: http://gazecapture.csail.mit.edu/

Cite:

Eye Tracking for Everyone
K.Krafka*, A. Khosla*, P. Kellnhofer, H. Kannan, S. Bhandarkar, W. Matusik and A. Torralba
IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016

@inproceedings{cvpr2016_gazecapture,
Author = {Kyle Krafka and Aditya Khosla and Petr Kellnhofer and Harini Kannan and Suchendra Bhandarkar and Wojciech Matusik and Antonio Torralba},
Title = {Eye Tracking for Everyone},
Year = {2016},
Booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}
}

'''

def str2bool(v):
    """Parse a textual boolean (case-insensitive).

    Accepts 'yes', 'true', 't', 'y', '1' as True and 'no', 'false', 'f', 'n',
    '0' as False.

    Raises:
        argparse.ArgumentTypeError: For any other text.
    """
    token = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')

    for outcome, spellings in ((True, truthy), (False, falsy)):
        if token in spellings:
            return outcome
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Change these flags to control what happens.
doLoad = False # Load checkpoint at the beginning
# When true, main() runs one validation pass and returns without training.
doTest = False
# HDF5 file read by MPIIGazeData.
data_path = 'MPIIGaze/MPIIFaceGaze.h5'

# Device string used for the model, the loss and every batch.
device = 'cuda'
# DataLoader worker counts; cpu_workers is used when device == 'cpu',
# cuda_workers otherwise.
cpu_workers = 2
cuda_workers = 8
# Training stops once this epoch index is reached.
epochs = 25
# 64 on CPU, otherwise 50 samples per visible CUDA device.
# NOTE(review): this evaluates to 0 when device is not 'cpu' and no CUDA
# device is visible — confirm a GPU runtime is selected.
batch_size = 64 if device == 'cpu' else torch.cuda.device_count() * 50

# Optimizer hyper-parameters (SGD).
base_lr = 0.0001
momentum = 0.9
weight_decay = 1e-4
print_freq = 10
prec1 = 0
# Lowest validation error seen so far; starts very large so the first epoch
# always counts as the best.
best_prec1 = 1e20
lr = base_lr

count_test = 0
# Incremented once per training batch by train().
count = 0


def main():
    """Train (or, with ``doTest``, only validate) the iTracker model on MPIIGaze.

    Builds the model wrapped in ``DataParallel``, optionally restores a
    checkpoint, creates the train/test loaders, then alternates training and
    validation for the remaining epochs, saving a checkpoint after each one.
    ``best_prec1`` tracks the lowest value returned by ``validate``.
    """
    global best_prec1

    model = ITrackerModel()
    model = torch.nn.DataParallel(model)
    model.to(torch.device(device))
    imSize=(224,224)
    cudnn.benchmark = True
    cudnn.deterministic = False

    # Epoch to resume from; stays 0 unless a checkpoint is restored.
    start_epoch = 0
    if doLoad:
        saved = load_checkpoint()
        if saved:
            print('Loading checkpoint for epoch %05d with loss %.5f (which is the mean squared error not the actual linear error)...' % (saved['epoch'], saved['best_prec1']))
            state = saved['state_dict']
            try:
                model.module.load_state_dict(state)
            except RuntimeError:
                # Key mismatch: the state was saved from the DataParallel
                # wrapper, so load it into the wrapper instead.
                model.load_state_dict(state)
            start_epoch = saved['epoch']
            best_prec1 = saved['best_prec1']
        else:
            print('Warning: Could not read checkpoint!')

    dataTrain = MPIIGazeData(dataPath = data_path, split='train', imSize = imSize)
    dataVal = MPIIGazeData(dataPath = data_path, split='test', imSize = imSize)

    workers = cpu_workers if device == 'cpu' else cuda_workers

    train_loader = torch.utils.data.DataLoader(
        dataTrain,
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        dataVal,
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)

    criterion = nn.MSELoss().to(torch.device(device)) # Mean squared error loss function

    optimizer = torch.optim.SGD(model.parameters(), lr, # Uses stochastic gradient descent
                                momentum=momentum,
                                weight_decay=weight_decay)

    # Quick test
    if doTest:
        validate(val_loader, model, criterion, start_epoch)
        return

    # adjust_learning_rate() derives the rate from the epoch number alone, so
    # earlier epochs need no replay. The old replay loop reused `epoch` as its
    # loop variable, which restarted resumed runs one epoch too early.
    for epoch in range(start_epoch, epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember the lowest validation error and save checkpoint
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)


def train(train_loader, model, criterion,optimizer, epoch):
    """Run one training epoch and print a progress line after every batch.

    Args:
        train_loader: Iterable of ``(face, left_eye, right_eye, face_grid,
            gaze)`` batches.
        model: Network called as ``model(face, left_eye, right_eye, face_grid)``.
        criterion: Loss applied to ``(output, gaze)``.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the printed line.
    """
    global count

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    # Put the model in training mode.
    model.train()

    report = ('Epoch (train): [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t')

    end = time.time()
    target_device = torch.device(device)

    for i, batch in enumerate(train_loader):
        # Time spent waiting for the loader.
        data_time.update(time.time() - end)

        imFace, imEyeL, imEyeR, faceGrid, gaze = (
            tensor.to(target_device) for tensor in batch)

        # Forward pass and loss.
        output = model(imFace, imEyeL, imEyeR, faceGrid)
        loss = criterion(output, gaze)
        losses.update(loss.data.item(), imFace.size(0))

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Total time for this batch.
        batch_time.update(time.time() - end)
        end = time.time()

        count += 1

        print(report.format(
            epoch, i, len(train_loader), batch_time=batch_time,
            data_time=data_time, loss=losses))


def validate(val_loader, model, criterion, epoch):
    """Evaluate the model and print a progress line after every batch.

    Besides the criterion loss, the mean Euclidean distance between prediction
    and target is tracked per batch.

    Args:
        val_loader: Iterable of ``(face, left_eye, right_eye, face_grid,
            gaze)`` batches.
        model: Network called as ``model(face, left_eye, right_eye, face_grid)``.
        criterion: Loss applied to ``(output, gaze)``.
        epoch: Epoch number, used only in the printed line.

    Returns:
        The sample-weighted average Euclidean error over all batches.
    """
    global count_test

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    lossesLin = AverageMeter()

    report = ('Epoch (val): [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Error L2 {lossLin.val:.4f} ({lossLin.avg:.4f})\t')

    # Put the model in evaluation mode.
    model.eval()
    end = time.time()
    for i, batch in enumerate(val_loader):
        # Time spent waiting for the loader.
        data_time.update(time.time() - end)

        imFace, imEyeL, imEyeR, faceGrid, gaze = (
            tensor.to(device=device) for tensor in batch)

        # Forward pass without building the autograd graph.
        with torch.no_grad():
            output = model(imFace, imEyeL, imEyeR, faceGrid)

        loss = criterion(output, gaze)

        # Per-sample Euclidean distance, averaged over the batch.
        diff = output - gaze
        squared = torch.mul(diff, diff)
        lossLin = torch.mean(torch.sqrt(torch.sum(squared, 1)))

        n_samples = imFace.size(0)
        losses.update(loss.data.item(), n_samples)
        lossesLin.update(lossLin.item(), n_samples)

        # Total time for this batch.
        batch_time.update(time.time() - end)
        end = time.time()

        print(report.format(
            epoch, i, len(val_loader), batch_time=batch_time,
            loss=losses, lossLin=lossesLin))

    return lossesLin.avg

# Directory in which load_checkpoint()/save_checkpoint() read and write files.
CHECKPOINTS_PATH = '.'


def load_checkpoint(filename='checkpoint.pth.tar'):
    """Load a checkpoint from ``CHECKPOINTS_PATH``.

    The resolved path is printed before the existence check.

    Args:
        filename: Checkpoint file name inside ``CHECKPOINTS_PATH``.

    Returns:
        The object returned by ``torch.load``, or ``None`` if the file does
        not exist.
    """
    path = os.path.join(CHECKPOINTS_PATH, filename)
    print(path)
    return torch.load(path) if os.path.isfile(path) else None


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``CHECKPOINTS_PATH`` and optionally mark it as best.

    Args:
        state: Object passed to ``torch.save``.
        is_best: When true, the saved file is also copied to
            ``'best_' + filename`` in the same directory.
        filename: Checkpoint file name inside ``CHECKPOINTS_PATH``.
    """
    # Create the checkpoint directory on first use.
    os.makedirs(CHECKPOINTS_PATH, 0o777, exist_ok=True)

    latest = os.path.join(CHECKPOINTS_PATH, filename)
    torch.save(state, latest)

    if not is_best:
        return
    shutil.copyfile(latest, os.path.join(CHECKPOINTS_PATH, 'best_' + filename))


class AverageMeter(object):
    """Keeps the latest value and the weighted running average of a series.

    Attributes are ``val`` (last value), ``sum`` (weighted total), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate to ``base_lr`` decayed by 10x every 30 epochs.

    Args:
        optimizer: Optimizer whose parameter groups are updated in place.
        epoch: Zero-based epoch number; the rate is
            ``base_lr * 0.1 ** (epoch // 30)``.
    """
    lr = base_lr * (0.1 ** (epoch // 30))
    # Write to the live ``param_groups``. ``optimizer.state_dict()`` returns
    # packed copies of the groups, so edits made there never reach the
    # optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# Start training (or the validation-only run) when this cell/script is
# executed directly, then report completion.
if __name__ == "__main__":
    main()
    print('DONE')
