### 1. Dependencies

In [1]:
import torch
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import argparse
import os
from datetime import datetime
import shutil
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import cv2

from utils_logging import setup_logger

### 2. Choose between Recasens or GazeNet

- The idea is that you can switch models by swapping the module imports, e.g.
`models.recasens`, `dataloader.recasens`, `training.train_recasens`, etc.,
- with their counterparts for the other model:
`models.gazenet`, `dataloader.gazenet`, `training.train_gazenet`

In [2]:
# Model, checkpoint helpers, datasets and train/test routines (the `chong` variants).
from models.chong import ModelSpatial
# Import the helpers from the package itself: naming `models.__init__` explicitly
# makes Python execute the package's __init__.py a second time as a separate module.
from models import save_checkpoint, resume_checkpoint
from dataloader.chong import GazeDataset, GooDataset
from dataloader import chong_imutils
from training.train_chong import train, test, GazeOptimizer

  from .collection import imread_collection_wrapper


In [3]:
# The logger writes the training and test errors to a .log file under ./logs/.
logger = setup_logger(
    name='first_logger',
    log_dir='./logs/',
    log_file='train_chong_gooreal.log',
    log_format='%(asctime)s %(levelname)s %(message)s',
    verbose=True,
)

### 3. Dataloaders
- Choose either GazeDataset (Gazefollow dataset) or GooDataset (GooSynth/GooReal), and run only the matching dataloader cell below.
- Set the paths to the image directories and the pickle (annotation) files. For Gazefollow, images_dir and test_images_dir should be the same, and both should point to the directory containing the train and test folders.

In [4]:
# ---- GOO-Synth: datasets and loaders ----
batch_size = 32
workers = 12  # NOTE(review): not passed to the loaders below, which set num_workers themselves

# Image roots and annotation pickles for the train / test splits.
images_dir = '/hdd/HENRI/goosynth/1person/GazeDatasets/'
pickle_path = '/hdd/HENRI/goosynth/picklefiles/trainpickle2to19human.pickle'
test_images_dir = '/hdd/HENRI/goosynth/test/'
test_pickle_path = '/hdd/HENRI/goosynth/picklefiles/testpickle120.pickle'

# Training split: shuffled, full batch size.
train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)

# Test split: fixed order, half the training batch size.
test_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size // 2,
    shuffle=False,
    num_workers=8,
)

Number of Images: 172800
Number of Images: 19200


In [4]:
# ---- GOO-Real: datasets and loaders ----
batch_size = 4
workers = 12  # NOTE(review): not passed to the loaders below, which set num_workers themselves

# Train and test share one image directory; only the annotation pickles differ.
images_dir = '/home/shashimal/Desktop/FYPGaze/gooreal/finalrealdatasetImgsV2/'
pickle_path = '/home/shashimal/Desktop/FYPGaze/gooreal/oneshotrealhumansNew.pickle'
test_images_dir = '/home/shashimal/Desktop/FYPGaze/gooreal/finalrealdatasetImgsV2/'
test_pickle_path = '/home/shashimal/Desktop/FYPGaze/gooreal/testrealhumansNew.pickle'

# Training split: shuffled, full batch size.
train_set = GooDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)

# Test split: fixed order, half the training batch size.
test_set = GooDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size // 2,
    shuffle=False,
    num_workers=8,
)

Number of Images: 2450




Number of Images: 2146




In [15]:
# Ask PyTorch to release the unused GPU memory it is caching (e.g. after an
# interrupted or failed run) before continuing in the same kernel.
torch.cuda.empty_cache()


In [4]:
# ---- GazeFollow: datasets and loaders ----

batch_size = 32
workers = 12  # NOTE(review): not passed to the loaders below, which set num_workers themselves
testbatchsize = 16  # batch size of the test loader

# Train and test share one root directory; annotations come from .mat files.
images_dir = '/home/eee198/Documents/datasets/GazeFollowData/'
pickle_path = '/home/eee198/Documents/datasets/GazeFollowData/train_annotations.mat'
test_images_dir = '/home/eee198/Documents/datasets/GazeFollowData/'
test_pickle_path = '/home/eee198/Documents/datasets/GazeFollowData/test_annotations.mat'

# Training split: shuffled, full batch size.
train_set = GazeDataset(images_dir, pickle_path, 'train')
train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)

# Test split: fixed order. Uses the declared `testbatchsize`, which was
# previously ignored in favour of a hard-coded batch_size//2 (the same value, 16).
test_set = GazeDataset(test_images_dir, test_pickle_path, 'test')
test_data_loader = DataLoader(
    dataset=test_set,
    batch_size=testbatchsize,
    shuffle=False,
    num_workers=8,
)

### 4. Load Model and Set Training Hyperparameters
- For Gazefollow, the model requires the alexnet_places365 pretrained model, provided here: https://urlzs.com/ytKK3
- When resuming training, set `resume_training` to True and set `resume_path` to the path of the saved model.
- The logging module set up earlier (`logger`) is used here to save training and testing errors.

In [5]:
# Initial weights file; fetch it once with the command below if it is missing.
#!wget https://www.dropbox.com/s/s9y65ajzjz4thve/initial_weights_for_spatial_training.pt
init_weights = 'initial_weights_for_spatial_training.pt'

# Training hyperparameters
start_epoch = 0
max_epoch = 5
learning_rate = 3e-4

# Build the network and move it to the GPU
print("==> Constructing model")
net = ModelSpatial()
net.cuda()

# Overlay the entries stored under the checkpoint's 'model' key onto the
# network's own state dict, then load the merged result back into the network.
print("==> Loading initial weights")
pretrained_dict = torch.load(init_weights)['model']
model_dict = net.state_dict()
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

# Optimizer wrapper; getOptimizer() is called again with each epoch number
# inside the training loop.
gaze_opt = GazeOptimizer(net, learning_rate)
optimizer = gaze_opt.getOptimizer(start_epoch)

# Optionally continue from a saved checkpoint and evaluate it right away.
# The third value returned by resume_checkpoint is discarded.
resume_training = False
resume_path = './saved_models/chong_goosynth/model_epoch25.pth.tar'
if resume_training:
    net, optimizer, _ = resume_checkpoint(net, optimizer, resume_path)
    test(net, test_data_loader, logger, save_output=True)

==> Constructing model
==> Loading initial weights


In [12]:
# Evaluate the freshly initialised network on the test loader.
test(
    net,
    test_data_loader,
    logger,
    save_output=True,
)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  0%|          | 1/1073 [00:05<1:34:29,  5.29s/it]


KeyboardInterrupt: 

### 5. Training the Model
- Decide in which epochs you want to save the model, as you might not want to save at every epoch.
- Training and test errors can be found in the logs directory set up earlier.

In [6]:
# Lowest L2 score seen so far; only read by the disabled checkpoint block below.
best_l2 = np.inf

for epoch in range(1, 5):

    # Get the optimizer to use for this epoch
    optimizer = gaze_opt.getOptimizer(epoch)

    # Run one training epoch
    print('training')
    train(net, train_data_loader, optimizer, epoch, logger)

    # Evaluation and best-model checkpointing are currently disabled:
    # scores = test(net, test_data_loader, logger)
    #
    # Save model+optimizer with best L2 Score
    # if scores[1] < best_l2:
    #     best_l2 = scores[1]
    #     save_path = './saved_models/chong_gooreal_notrained/'
    #     save_checkpoint(net, optimizer, 420, save_path)

training


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
 16%|█▌        | 99/613 [05:20<27:19,  3.19s/it]68.27586071014404
 32%|███▏      | 199/613 [10:39<22:00,  3.19s/it]67.43212947845458
 49%|████▉     | 299/613 [15:57<16:28,  3.15s/it]66.64984302520752
 65%|██████▌   | 399/613 [19:25<07:24,  2.08s/it]66.35670455932618
 81%|████████▏ | 499/613 [22:53<03:57,  2.08s/it]65.80953327178955
 98%|█████████▊| 599/613 [26:20<00:28,  2.07s/it]65.7650089263916
100%|██████████| 613/613 [26:49<00:00,  2.63s/it]


training


 16%|█▌        | 99/613 [03:27<17:42,  2.07s/it]65.20166042327881
 32%|███▏      | 199/613 [06:53<14:14,  2.06s/it]65.2870754623413
 49%|████▉     | 299/613 [10:44<11:59,  2.29s/it]64.93747440338134
 65%|██████▌   | 399/613 [14:33<08:03,  2.26s/it]65.3558052444458
 81%|████████▏ | 499/613 [18:22<04:22,  2.30s/it]64.85508255004883
 98%|█████████▊| 599/613 [22:08<00:28,  2.01s/it]64.382360496521
100%|██████████| 613/613 [22:36<00:00,  2.21s/it]

training



 16%|█▌        | 99/613 [03:21<17:17,  2.02s/it]64.03812389373779
 32%|███▏      | 199/613 [06:43<13:53,  2.01s/it]64.13457027435302
 49%|████▉     | 299/613 [10:06<10:31,  2.01s/it]63.584960823059085
 65%|██████▌   | 399/613 [13:28<07:11,  2.02s/it]63.30810253143311
 81%|████████▏ | 499/613 [16:49<03:49,  2.01s/it]63.34888717651367
 98%|█████████▊| 599/613 [20:13<00:28,  2.01s/it]63.07262031555176
100%|██████████| 613/613 [20:42<00:00,  2.03s/it]

training



 16%|█▌        | 99/613 [03:23<17:12,  2.01s/it]62.31209442138672
 32%|███▏      | 199/613 [06:43<13:01,  1.89s/it]63.033398094177244
 49%|████▉     | 299/613 [09:51<09:50,  1.88s/it]63.33745651245117
 65%|██████▌   | 399/613 [13:01<06:39,  1.87s/it]63.7181665802002
 81%|████████▏ | 499/613 [16:07<03:32,  1.87s/it]62.06017261505127
 98%|█████████▊| 599/613 [19:14<00:26,  1.87s/it]62.6815690612793
100%|██████████| 613/613 [19:40<00:00,  1.93s/it]


In [7]:
# Evaluate the trained network on the test loader.
# NOTE(review): the recorded run of this cell stopped with
# "TypeError: forward() missing 1 required positional argument: 'objects'";
# presumably test() does not pass `objects` to the model - confirm in training.train_chong.
test(
    net,
    test_data_loader,
    logger,
    save_output=True,
)

  0%|          | 0/1073 [00:02<?, ?it/s]


TypeError: forward() missing 1 required positional argument: 'objects'