### 1. Dependencies

In [1]:
import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import argparse
import os
from datetime import datetime
import shutil
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import cv2

from utils_logging import setup_logger

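### 2. Choose the model: Recasens, GazeNet, or Chong

- The model family is chosen purely through imports: you can swap `models.recasens`, `dataloader.recasens`, `training.train_recasens`, etc.
- for the corresponding `models.gazenet`, `dataloader.gazenet`, `training.train_gazenet` modules (this notebook imports the `chong` variants in the cell below).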

In [2]:
from models.chong import ModelSpatial
from models.__init__ import save_checkpoint, resume_checkpoint
from dataloader.chong import GazeDataset, GooDataset
from dataloader import chong_imutils
from training.train_chong import train, test, GazeOptimizer

  from .collection import imread_collection_wrapper


In [3]:
# Create the logger that records training and test errors.
# Output goes to ./logs/train_chong_gooreal.log using the format below.
# verbose=True is passed through to setup_logger; its exact effect is
# defined in utils_logging (presumably console echo — TODO confirm).
logger = setup_logger(
    name='first_logger',
    log_file='train_chong_gooreal.log',
    log_dir='./logs/',
    log_format='%(asctime)s %(levelname)s %(message)s',
    verbose=True,
)

### 3. Dataloaders
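- Use `GazeDataset` for the GazeFollow dataset, or `GooDataset` for GOO-Synth / GOO-Real. Run only ONE of the three dataloader cells below: each assigns the same `train_set`, `train_data_loader`, `test_set` and `test_data_loader` variables, so whichever cell runs last wins.
- Set the paths to the image directories and annotation files (pickle files for GOO, `.mat` files for GazeFollow). For GazeFollow, `images_dir` and `test_images_dir` should be the same directory, namely the one containing both the train and test folders.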

In [4]:
# Dataloaders for GOO-Synth
# NOTE: this cell and the GOO-Real / GazeFollow cells assign the same variables
# (train_set, train_data_loader, test_set, test_data_loader); run only one.
batch_size = 32
workers = 12                       # DataLoader worker processes (was declared but unused)
test_batch_size = batch_size // 2  # evaluation uses half the training batch size

# All GOO-Synth data lives under one root; change it here to relocate the dataset.
data_root = '/hdd/HENRI/goosynth/'
images_dir = os.path.join(data_root, '1person/GazeDatasets/')
pickle_path = os.path.join(data_root, 'picklefiles/trainpickle2to19human.pickle')
test_images_dir = os.path.join(data_root, 'test/')
test_pickle_path = os.path.join(data_root, 'picklefiles/testpickle120.pickle')

train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(dataset=train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=workers)

test_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(test_set, batch_size=test_batch_size,
                              shuffle=False, num_workers=workers)

Number of Images: 172800
Number of Images: 19200


In [4]:
# Dataloaders for GOO-Real
# NOTE: this cell and the GOO-Synth / GazeFollow cells assign the same variables
# (train_set, train_data_loader, test_set, test_data_loader); run only one.
batch_size = 4
workers = 12                       # DataLoader worker processes (was declared but unused)
test_batch_size = batch_size // 2  # evaluation uses half the training batch size

# Train and test images share one directory; only the annotation pickles differ.
data_root = '/home/shashimal/Desktop/gooreal/'
images_dir = os.path.join(data_root, 'finalrealdatasetImgsV2/')
pickle_path = os.path.join(data_root, 'oneshotrealhumansNew.pickle')
test_images_dir = os.path.join(data_root, 'finalrealdatasetImgsV2/')
test_pickle_path = os.path.join(data_root, 'testrealhumansNew.pickle')

train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(dataset=train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=workers)

test_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(test_set, batch_size=test_batch_size,
                              shuffle=False, num_workers=workers)

Number of Images: 2450
Number of Images: 2146




In [4]:
# Dataloaders for GAZE (GazeFollow)
# NOTE: this cell and the two GOO cells assign the same variables
# (train_set, train_data_loader, test_set, test_data_loader); run only one.

batch_size = 32
workers = 12       # DataLoader worker processes (was declared but unused)
testbatchsize = 16  # evaluation batch size (was declared but unused)

# GazeFollow keeps train and test images under one root, with .mat annotations.
data_root = '/home/eee198/Documents/datasets/GazeFollowData/'
images_dir = data_root
pickle_path = os.path.join(data_root, 'train_annotations.mat')
test_images_dir = data_root
test_pickle_path = os.path.join(data_root, 'test_annotations.mat')

train_set = GazeDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(dataset=train_set,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=workers)

test_set = GazeDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(test_set, batch_size=testbatchsize,
                              shuffle=False, num_workers=workers)

In [6]:
# Hand cached-but-unused GPU memory held by PyTorch's CUDA caching allocator
# back before the model is built.
release_cached_gpu_memory = torch.cuda.empty_cache
release_cached_gpu_memory()

### 4. Load Model and Set Training Hyperparameters
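- The Chong model (`ModelSpatial`) is initialized from `initial_weights_for_spatial_training.pt` (a download command is commented out at the top of the cell below). The alexnet_places365 pretrained model, provided here: https://urlzs.com/ytKK3, was listed for GazeFollow; it is not loaded by this cell (presumably it is needed by the Recasens/GazeNet variants — TODO confirm).
- To resume training, set `resume_training = True` and point `resume_path` to the saved checkpoint.
- Training and testing errors are saved by the `logger` created earlier in the notebook.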

In [5]:
#!wget https://www.dropbox.com/s/s9y65ajzjz4thve/initial_weights_for_spatial_training.pt
init_weights = 'initial_weights_for_spatial_training.pt'

# Loads model
print("==> Constructing model")
net = ModelSpatial()
net.cuda()

# Hyperparameters
start_epoch = 0
max_epoch = 5
learning_rate = 3e-4

# Initial weights chong
print("==> Loading initial weights")
model_dict = net.state_dict()
pretrained_dict = torch.load(init_weights)
pretrained_dict = pretrained_dict['model']
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

# Initializes Optimizer
gaze_opt = GazeOptimizer(net, learning_rate)
optimizer = gaze_opt.getOptimizer(start_epoch)

# Resuming Training
resume_training = False
resume_path = './saved_models/chong_goosynth/model_epoch25.pth.tar'
if resume_training:
    net, optimizer, _ = resume_checkpoint(net, optimizer,resume_path)
    test(net, test_data_loader,logger, save_output=True)

==> Constructing model
==> Loading initial weights


In [6]:
# Evaluate the model once before any training, as a baseline.
# NOTE(review): in the recorded run this call failed inside test() with
# "ValueError: too many values to unpack (expected 8)" over a 1073-batch loader.
# The identical call after training ran over 135 batches. test_data_loader was
# evidently rebuilt in a cell not shown here — TODO confirm.
baseline_scores = test(net, test_data_loader, logger, save_output=True)
baseline_scores

  0%|          | 0/1073 [00:19<?, ?it/s]


ValueError: too many values to unpack (expected 8)

### 5. Training the Model
- Decide in which epochs you want to save the model, since you might not want to save every epoch.
- Training and test errors can be found in the logs directory set up earlier.

In [8]:
best_l2 = np.inf

# Set to True to evaluate on the test set after every epoch. The model and
# optimizer are then checkpointed whenever the L2 score (index 1 of the list
# returned by test()) improves. Off by default, which matches the recorded
# run: it only trained, and evaluated once after the loop.
evaluate_each_epoch = False
save_path = './saved_models/chong_gooreal_notrained/'

# Runs epochs start_epoch+1 .. max_epoch-1, i.e. 1..4 with the hyperparameters
# set above. This is the same range as the previously hard-coded range(1, 5).
for epoch in range(start_epoch + 1, max_epoch):

    # Update optimizer (getOptimizer is given the epoch, so its settings
    # presumably follow an epoch-based schedule)
    optimizer = gaze_opt.getOptimizer(epoch)

    # Train model
    print('training')
    train(net, train_data_loader, optimizer, epoch, logger)

    if not evaluate_each_epoch:
        continue

    # Evaluate model
    scores = test(net, test_data_loader, logger)

    # Save model+optimizer with best L2 score, tagged with the real epoch
    # (previously a hard-coded placeholder of 420).
    if scores[1] < best_l2:
        best_l2 = scores[1]
        save_checkpoint(net, optimizer, epoch, save_path)

training


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
 16%|█▌        | 99/613 [03:56<19:47,  2.31s/it]68.91891025543212
 32%|███▏      | 199/613 [07:43<15:50,  2.30s/it]68.11870040893555
 49%|████▉     | 299/613 [11:26<10:08,  1.94s/it]67.11412971496583
 65%|██████▌   | 399/613 [14:48<07:56,  2.23s/it]66.19479419708252
 81%|████████▏ | 499/613 [18:35<04:16,  2.25s/it]66.49509391784667
 98%|█████████▊| 599/613 [21:58<00:26,  1.92s/it]65.8030415725708
100%|██████████| 613/613 [22:25<00:00,  2.19s/it]


training


 16%|█▌        | 99/613 [03:13<16:32,  1.93s/it]65.58567909240723
 32%|███▏      | 199/613 [06:26<13:19,  1.93s/it]65.08650608062744
 49%|████▉     | 299/613 [09:39<10:03,  1.92s/it]64.79565208435059
 65%|██████▌   | 399/613 [12:52<06:51,  1.92s/it]64.39537612915039
 81%|████████▏ | 499/613 [16:04<03:38,  1.92s/it]65.25406341552734
 98%|█████████▊| 599/613 [19:16<00:26,  1.92s/it]64.75374084472656
100%|██████████| 613/613 [19:43<00:00,  1.93s/it]

training



 16%|█▌        | 99/613 [03:13<16:26,  1.92s/it]64.12285961151123
 32%|███▏      | 199/613 [06:26<13:16,  1.92s/it]64.52913795471191
 49%|████▉     | 299/613 [09:38<10:00,  1.91s/it]63.733947868347165
 65%|██████▌   | 399/613 [12:52<07:20,  2.06s/it]63.739145545959474
 81%|████████▏ | 499/613 [16:00<03:33,  1.87s/it]64.12066928863526
 98%|█████████▊| 599/613 [19:08<00:26,  1.87s/it]63.93507568359375
100%|██████████| 613/613 [19:34<00:00,  1.92s/it]

training



 16%|█▌        | 99/613 [03:06<16:01,  1.87s/it]64.13455711364746
 32%|███▏      | 199/613 [06:14<12:54,  1.87s/it]63.66544696807861
 49%|████▉     | 299/613 [09:21<09:47,  1.87s/it]63.89187271118164
 65%|██████▌   | 399/613 [12:30<06:39,  1.87s/it]63.60672634124756
 81%|████████▏ | 499/613 [15:37<03:33,  1.87s/it]63.72061542510986
 98%|█████████▊| 599/613 [18:50<00:26,  1.87s/it]63.89179275512695
100%|██████████| 613/613 [19:16<00:00,  1.89s/it]


In [9]:
# Final evaluation after training. The list of three scores returned by test()
# is displayed as the cell output; the training loop treats index 1 as the L2 score.
final_scores = test(net, test_data_loader, logger, save_output=True)
final_scores

100%|██████████| 135/135 [00:16<00:00,  8.39it/s]
average error: [0.8613018095794566, 0.1459625192366771, 26.391563159013245]


[0.8613018095794566, 0.1459625192366771, 26.391563159013245]

In [9]:
# Persist only the learned parameters (the state dict); reload them with
# ModelSpatial().load_state_dict(torch.load(weights_path)).
weights_path = 'chong.pth'
torch.save(net.state_dict(), weights_path)

In [10]:
# Pickle the whole module object. Unlike the state-dict file, reloading this
# requires the model class (models.chong.ModelSpatial) to be importable.
full_model_path = 'chong_model.pth'
torch.save(net, full_model_path)

In [2]:
import cv2

# Quick check of a sample camera frame's dimensions (height, width, channels).
sample_image_path = '/home/shashimal/Downloads/cam00000_img00000.jpg'
image = cv2.imread(sample_image_path, cv2.IMREAD_COLOR)
# cv2.imread signals a missing/unreadable file by returning None rather than
# raising. Fail with a clear message instead of an AttributeError on .shape.
if image is None:
    raise FileNotFoundError(f'Could not read image: {sample_image_path}')
print(image.shape)

(1080, 1920, 3)
