### 1. Dependencies

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import DataParallel

import time
import os
import numpy as np
import json
import cv2
from PIL import Image, ImageOps
import random
from tqdm import tqdm
import operator
import itertools
from scipy.io import  loadmat
import logging
from scipy import signal

from utils import data_transforms
from utils import get_paste_kernel, kernel_map
from utils_logging import setup_logger

### 2. Choose between Recasens or GazeNet

- Idea is you can just swap 
models.recasens, dataloader.recasens, training.train_recasens, etc...
- with the following
models.gazenet, dataloader.gazenet, training.train_gazenet

In [2]:
from models.recasens import GazeNet
from dataloader.recasens import GooDataset, GazeDataset
from models.__init__ import save_checkpoint, resume_checkpoint
from training.train_recasens import train, test, GazeOptimizer

  from .collection import imread_collection_wrapper


In [3]:
# Logger will save the training and test errors to a .log file 
logger = setup_logger(name='first_logger', 
                      log_dir ='./logs/',
                      log_file='train_recasens_newauc.log',
                      log_format = '%(asctime)s %(levelname)s %(message)s',
                      verbose=True)

### 3. Dataloaders
- Choose between GazeDataset (Gazefollow dataset) or GooDataset (GooSynth/GooReal)
- Set paths to image directories and pickle paths. For Gazefollow, images_dir and test_images_dir should be the same and both lead to the path containing the train and test folders.

In [4]:
# Dataloaders for GazeFollow
batch_size=64
workers=12
testbatchsize=batch_size*2

images_dir = '/home/eee198/Documents/datasets/GazeFollowData/'
pickle_path = '/home/eee198/Documents/datasets/GazeFollowData/train_annotations.mat'
test_images_dir = '/home/eee198/Documents/datasets/GazeFollowData/'
test_pickle_path = '/home/eee198/Documents/datasets/GazeFollowData/test_annotations.mat'

train_set = GazeDataset(images_dir, pickle_path, 'train')
train_data_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers)

val_set = GazeDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = torch.utils.data.DataLoader(val_set, batch_size=testbatchsize, num_workers=workers, shuffle=False)

In [4]:
# Dataloaders for GOO-Synth
batch_size=32
workers=12
testbatchsize=32

images_dir = '/hdd/HENRI/goosynth/1person/GazeDatasets/'
pickle_path = '/hdd/HENRI/goosynth/picklefiles/trainpickle2to19human.pickle'
test_images_dir = '/hdd/HENRI/goosynth/test/'
test_pickle_path = '/hdd/HENRI/goosynth/picklefiles/testpickle120.pickle'

train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers)

val_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = torch.utils.data.DataLoader(val_set, batch_size=testbatchsize, num_workers=workers, shuffle=False)

Total images in this set: 172800
Total images in this set: 19200


In [4]:
# Dataloaders for GOO-Real

batch_size=4
workers=4
testbatchsize=2


images_dir = '/home/shashimal/Desktop/gooreal/finalrealdatasetImgsV2/'
pickle_path = '/home/shashimal/Desktop/gooreal/oneshotrealhumansNew.pickle'
test_images_dir = '/home/shashimal/Desktop/gooreal/finalrealdatasetImgsV2/'
test_pickle_path = '/home/shashimal/Desktop/gooreal/testrealhumansNew.pickle'
train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers)

val_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = torch.utils.data.DataLoader(val_set, batch_size=testbatchsize, num_workers=workers, shuffle=False)

Total images in this set: 2450
Total images in this set: 2146


In [6]:
img, bbox, eyes_loc, shifted_grids, eyes, idx, eyes_bbox, gaze_final = next(iter(train_data_loader))

In [18]:
eyes

tensor([[0.2748, 0.2166],
        [0.4209, 0.6095],
        [0.3969, 0.2162],
        [0.4537, 0.6530]], dtype=torch.float64)

### 4. Load Model and Set Training Hyperparameters
- For Gazefollow, the model requires the alexnet_places365 pretrained model, provided here: https://urlzs.com/ytKK3
- When resuming training, set to True and set the resume_path for the saved model.
- Here, logging module is initialized (logger) to save training and testing errors.

In [5]:
# Loads model
net = GazeNet(placesmodel_path='./whole_alexnet_places365.pth')
net.cuda()

# Hyperparameters
start_epoch = 0
max_epoch = 5
learning_rate = 1e-4

# Initializes Optimizer
gaze_opt = GazeOptimizer(net, learning_rate)
optimizer = gaze_opt.getOptimizer(start_epoch)

# Is training resumed? If so, set the resume_path and set flag to True
resume_training = False
resume_path = './saved_models/completed/recasens_goo_epoch4.pth.tar'
if resume_training :
    net, optimizer, _ = resume_checkpoint(net, optimizer, resume_path)
    test(net, test_data_loader,logger, save_output=True)



In [6]:
test(net, test_data_loader,logger, save_output=True)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  0%|          | 1/1073 [00:05<1:31:28,  5.12s/it]

tensor([ 0.5609,  0.7500, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  0%|          | 2/1073 [00:05<39:11,  2.20s/it]  

tensor([ 0.4891,  0.1875, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  0%|          | 4/1073 [00:05<15:07,  1.18it/s]

tensor([ 0.4875,  0.6375, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  0%|          | 5/1073 [00:06<16:50,  1.06it/s]

tensor([ 0.5469,  0.1917, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|          | 6/1073 [00:07<14:58,  1.19it/s]

tensor([ 0.6766,  0.6146, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|          | 7/1073 [00:07<11:14,  1.58it/s]

tensor([ 0.5766,  0.4292, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|          | 9/1073 [00:09<13:09,  1.35it/s]

tensor([ 0.1172,  0.8104, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|          | 10/1073 [00:09<12:09,  1.46it/s]

tensor([ 0.7406,  0.6167, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|          | 12/1073 [00:10<07:42,  2.30it/s]

[tensor(0.3646, dtype=torch.float64)]
0.39375 0.3645833333333333
tensor([ 0.5719,  0.3896, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1

  1%|          | 13/1073 [00:11<12:51,  1.37it/s]

tensor([ 0.5094,  0.1625, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|▏         | 14/1073 [00:12<12:32,  1.41it/s]

tensor([ 0.5687,  0.1917, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|▏         | 16/1073 [00:12<08:53,  1.98it/s]

tensor([ 0.5312,  0.3500, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000,
        -1.0000, -1.0000, -1.0000, -1.00

  1%|▏         | 16/1073 [00:13<15:14,  1.16it/s]


KeyboardInterrupt: 

### 5. Training the Model
- Determine in which epochs do you want to save the model, as you might not want to save every epoch
- Training and test errors can be accessed in the logs directory set up earlier

In [16]:
best_l2 = np.inf

for epoch in range(1,5):
    
    # Update optimizer
    optimizer = gaze_opt.getOptimizer(epoch)

    # Train model
    train(net, train_data_loader, optimizer, epoch, logger)
    
    # Evaluate model
    if epoch+1 in [1,5]:
        scores = test(net, test_data_loader, logger, save_output=True)
    
    # Save model+optimizer with best L2 Score
    #if scores[1] < best_l2:
    #    best_l2 = scores[1]
    #    save_path = './saved_models/recasens_gooreal_notrain/'
    #    save_checkpoint(net, optimizer, 0, save_path)

 16%|█▌        | 99/613 [02:40<14:21,  1.68s/it]3.1793133854866027
3.1793133854866027
 32%|███▏      | 199/613 [05:15<10:55,  1.58s/it]3.068253939151764
3.068253939151764
 49%|████▉     | 299/613 [07:47<07:20,  1.40s/it]2.965940194129944
2.965940194129944
 65%|██████▌   | 399/613 [10:20<04:45,  1.34s/it]2.985254397392273
2.985254397392273
 81%|████████▏ | 499/613 [12:50<02:35,  1.36s/it]2.9243754649162295
2.9243754649162295
 98%|█████████▊| 599/613 [15:17<00:14,  1.04s/it]2.946562349796295
2.946562349796295
100%|██████████| 613/613 [15:35<00:00,  1.53s/it]
 16%|█▌        | 99/613 [02:25<11:11,  1.31s/it]2.9239328265190125
2.9239328265190125
 32%|███▏      | 199/613 [04:51<08:29,  1.23s/it]2.9015071249008177
2.9015071249008177
 49%|████▉     | 299/613 [07:13<08:28,  1.62s/it]2.8320250606536863
2.8320250606536863
 65%|██████▌   | 399/613 [09:38<06:44,  1.89s/it]2.8565286231040954
2.8565286231040954
 81%|████████▏ | 499/613 [12:04<03:26,  1.81s/it]2.825448603630066
2.825448603630066
 98%|

In [17]:
test(net, test_data_loader, logger, save_output=True)

  6%|▋         | 68/1073 [00:52<12:58,  1.29it/s]


KeyboardInterrupt: 