In [None]:
import torch
from torchvision import transforms
import torch.nn as nn
torch.multiprocessing.set_sharing_strategy('file_system')

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

import torchvision.models as models
from torchvision.ops import FeaturePyramidNetwork
import torchvision.transforms.functional as TF


from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader

import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
import warnings
import pickle
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import sys
import csv
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from pathlib import PureWindowsPath, PurePosixPath
from datetime import datetime
from numpy.linalg import norm
warnings.simplefilter(action='ignore', category=FutureWarning)

# Kaggle input directories appended to the module search path so that code /
# weights shipped as Kaggle datasets can be imported.
package_paths = [
    '/kaggle/input/mm-gaze-target-prediction-weights',
    '/kaggle/input/mm-gaze-target-prediction',
]

for pth in package_paths:
    sys.path.append(pth)
    
# Copy the read-only package into /kaggle/working under an importable name.
# NOTE(review): presumably this is what lets `mmgazetargetprediction.utils`
# resolve from the working directory -- confirm.
!cp -r /kaggle/input/mm-gaze-target-prediction/ /kaggle/working/mmgazetargetprediction/

In [None]:
# efficientnet_pytorch provides the EfficientNet backbones used by FeatureExtractor
# (EfficientNet.from_pretrained / extract_endpoints).
# NOTE(review): unpinned install; consider `%pip install -q efficientnet_pytorch==<version>`.
!pip install efficientnet_pytorch
from efficientnet_pytorch import EfficientNet

In [None]:
from mmgazetargetprediction.utils import evaluation, misc

In [None]:
# SPDX-FileCopyrightText: 2022 Idiap Research Institute <contact@idiap.ch>
# SPDX-FileContributor: Anshul Gupta <anshul.gupta@idiap.ch>
# SPDX-License-Identifier: GPL-3.0
# =============================================================================
# model config
# =============================================================================
# ---- resolutions -------------------------------------------------------------
input_resolution = 224     # side length images are resized to by the input transforms
output_resolution = 64     # output map side length (default for the bbox regression head)

# ---- model switches ----------------------------------------------------------
cone_mode = 'early'        # fusion of person information: 'early' or 'late'
modality_dropout = True    # only used for the attention model
pred_inout = True          # set True for VideoAttentionTarget
privacy = False            # set True to train/test the privacy-sensitive model
use_amp = False            # pytorch amp: faster training, lower memory use

# ---- run selection -----------------------------------------------------------
modality = 'attention'     # alternative: 'pose'
extended_model_present = True
use_sparse_dataset = False
flag_run_on_goo_dataset = True
human_centric_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/human-centric.pt'

# ---- bounding-box post-processing --------------------------------------------
Bbox_topk_no_of_bboxes = 3
Bbox_confidence_score_threshold = 0.2

# image folder: sparse or full GOO-Real frame set
image_dir = (
    '/kaggle/input/goorealdataset/finalrealdatasetImgsV3Sparsed/finalrealdatasetImgsV3Sparsed'
    if use_sparse_dataset
    else '/kaggle/input/goorealdataset/finalrealdatasetImgsV3/finalrealdatasetImgsV3'
)

In [None]:
# returns gaze cone; resnet/efficientnet + prediction head
class HumanCentric(nn.Module):
    """Head-crop encoder that predicts a 2D gaze direction and a gaze cone map.

    A torchvision backbone (ResNet-18 or EfficientNet-B0, final layer removed)
    embeds the head crop; a small MLP with a Tanh output regresses a 2D
    direction, which is normalised to unit length and projected onto a gaze
    field to form the cone.
    """

    def __init__(self, backbone = 'resnet'):
        super(HumanCentric, self).__init__()

        self.backbone = backbone
        self.feature_dim = 512  # hidden width of the direction MLP
        # Build Network Base: pretrained backbone without its last child module
        if backbone == 'resnet':
            self.base_head = models.resnet18(pretrained=True)
            self.base_head = nn.Sequential(*list(self.base_head.children())[:-1])
        elif backbone == 'efficientnet':
            self.base_head = models.efficientnet_b0(pretrained=True)
            self.base_head = nn.Sequential(*list(self.base_head.children())[:-1])
        else:
            # explicit raise: an `assert False` is stripped under `python -O`
            raise ValueError('Incorrect backbone, please choose from [resnet, efficientnet]')

        # Build Network Head
        num_outputs = 2
        self.num_outputs = num_outputs
        embedding_dim = self._infer_embedding_dim()
        self.head_new = nn.Sequential(
                        nn.Linear(embedding_dim, self.feature_dim),
                        nn.GELU(),
                        nn.Linear(self.feature_dim, num_outputs),
                        nn.Tanh())

    def _infer_embedding_dim(self, input_size=224):
        """Return the channel count produced by ``base_head`` for a 3-channel image.

        The probe uses a zero tensor, no autograd and eval mode, so the pretrained
        BatchNorm running statistics are not updated by it; the module's previous
        train/eval mode is restored afterwards.
        """
        was_training = self.base_head.training
        self.base_head.eval()
        try:
            with torch.no_grad():
                probe = self.base_head(torch.zeros((1, 3, input_size, input_size)))
        finally:
            self.base_head.train(was_training)
        return probe.size(1)

    def forward(self, head, gaze_field):
        """Predict the gaze cone for a batch of head crops.

        args:
        head: head-crop images, (N, 3, H, W)
        gaze_field: (N, 2, h, w) field of offsets; the channel count must equal
            num_outputs (2) for the projection below

        returns:
        gaze_cone: (N, 1, h, w) GELU of the field projected on the direction
        normalized_direction: (N, 2) unit-length gaze direction
        head_embedding: (N, C) copy of the backbone embedding
        """
        h = self.base_head(head).squeeze(dim=-1).squeeze(dim=-1)  # (N, C)
        head_embedding = h.clone()

        direction = self.head_new(h)
        # convert to unit vector; the clamp only affects an exactly-zero
        # direction, which would otherwise give 0/0 = NaN
        norm = direction.norm(dim=1, keepdim=True).clamp_min(1e-12)
        normalized_direction = direction / norm

        # generate gaze field map: per-pixel dot product with the direction
        batch_size, channel, height, width = gaze_field.size()
        gaze_field = gaze_field.permute([0, 2, 3, 1]).contiguous()
        gaze_field = gaze_field.view([batch_size, -1, self.num_outputs])
        gaze_field = torch.matmul(gaze_field, normalized_direction.view([batch_size, self.num_outputs, 1]))
        gaze_cone = gaze_field.view([batch_size, height, width, 1])
        gaze_cone = gaze_cone.permute([0, 3, 1, 2]).contiguous()

        gaze_cone = nn.functional.gelu(gaze_cone)
        return gaze_cone, normalized_direction, head_embedding


# efficientnet followed by an FPN
class FeatureExtractor(nn.Module):
    """EfficientNet backbone whose multi-scale endpoints are fused by a 64-channel FPN."""

    # Channel counts of the EfficientNet endpoints reduction_2..reduction_5 fed to
    # the FPN, per supported backbone variant.
    _FPN_IN_CHANNELS = {
        'efficientnet-b0': [24, 40, 112, 320],
        'efficientnet-b1': [24, 40, 112, 320],
        'efficientnet-b2': [24, 48, 120, 352],
        'efficientnet-b3': [32, 48, 136, 384],
    }
    _FPN_ENDPOINTS = ('reduction_2', 'reduction_3', 'reduction_4', 'reduction_5')

    def __init__(self, backbone_name):

        '''
        args:
        backbone_name: name of the backbone to be used; one of
            'efficientnet-b0' .. 'efficientnet-b3'

        raises:
        ValueError: for an unsupported backbone name (checked before any
            pretrained weights are downloaded)
        '''

        super(FeatureExtractor, self).__init__()

        if backbone_name not in self._FPN_IN_CHANNELS:
            raise ValueError('Unsupported backbone {!r}; choose from {}'.format(
                backbone_name, sorted(self._FPN_IN_CHANNELS)))
        self.backbone = EfficientNet.from_pretrained(backbone_name)
        self.fpn = FeaturePyramidNetwork(self._FPN_IN_CHANNELS[backbone_name], 64)

    def forward(self, x):
        """Return the FPN output at the 'reduction_2' level for input images x."""
        features = self.backbone.extract_endpoints(x)

        # select the four endpoints consumed by the FPN, in order
        fpn_features = OrderedDict((name, features[name]) for name in self._FPN_ENDPOINTS)

        # the finest FPN level (reduction_2) is the upsampled feature map used downstream
        return self.fpn(fpn_features)['reduction_2']


# simple prediction head that takes the features and gaze cone to regress the gaze target heatmap
class PredictionHead(nn.Module):
    """Dilated conv decoder that regresses a single-channel 64x64 map.

    Six 3x3 (dilation 3) conv + BatchNorm + GELU stages keep the spatial size;
    widths are inchannels for stages 1-4, then inchannels//2 and inchannels//4,
    followed by a 1x1 projection to one channel (conv7).
    """

    # number of dilated conv+BN stages before the final 1x1 projection
    _NUM_STAGES = 6

    def __init__(self, inchannels):
        super(PredictionHead, self).__init__()

        self.act = nn.GELU()
        widths = [inchannels] * 5 + [inchannels // 2, inchannels // 4]
        for stage in range(1, self._NUM_STAGES + 1):
            c_in, c_out = widths[stage - 1], widths[stage]
            # padding == dilation keeps H and W unchanged for a 3x3 kernel
            setattr(self, 'conv%d' % stage, nn.Conv2d(c_in, c_out, 3, padding=3, dilation=3))
            setattr(self, 'bn%d' % stage, nn.BatchNorm2d(c_out))
        self.conv7 = nn.Conv2d(widths[-1], 1, 1)

    def forward(self, x):
        """Resize features to 64x64 and decode them into a (N, 1, 64, 64) map."""
        x = nn.functional.interpolate(x, size=(64, 64), mode='bilinear', align_corners=False)
        for stage in range(1, self._NUM_STAGES + 1):
            conv = getattr(self, 'conv%d' % stage)
            bn = getattr(self, 'bn%d' % stage)
            x = self.act(bn(conv(x)))
        return self.conv7(x)
        
# compress modality spatially
class CompressModality(nn.Module):
    """Compress a feature map into one 512-d vector per sample.

    Three stride-2 3x3 conv + BatchNorm + GELU stages (in_channels -> 128 -> 256
    -> 512, no padding) followed by a global max pool over the remaining map.
    """

    def __init__(self, in_channels):
        super(CompressModality, self).__init__()

        self.act = nn.GELU()

        self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, stride=2)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, stride=2)
        self.bn3 = nn.BatchNorm2d(512)

    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, 512)."""
        x = self.act(self.bn1(self.conv1(x)))
        x = self.act(self.bn2(self.conv2(x)))
        x = self.act(self.bn3(self.conv3(x)))
        # Global max pool over the whole (H', W') map. A MaxPool2d whose kernel is
        # only the height covers the full map only when it is square; for
        # non-square maps it either errors or ignores trailing columns.
        x = nn.functional.adaptive_max_pool2d(x, 1)
        return x.flatten(start_dim=1)
    

# predicts in vs out gaze; CompressModality + Linear
class InvsOut(nn.Module):
    """Binary in-frame vs out-of-frame gaze classifier.

    Scene features are compressed to a 512-d vector by CompressModality,
    concatenated with the head embedding and passed through a small MLP ending
    in a sigmoid. The first Linear expects 1024 inputs, so the head embedding
    must be 512-d.
    """

    def __init__(self, in_channels):

        '''
        args:
        in_channels: number of input channels
        '''

        super(InvsOut, self).__init__()
        self.compress_inout = CompressModality(in_channels)
        self.inout = nn.Sequential(
            nn.Linear(1024, 256),
            nn.GELU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, head_embedding):
        """Return (N, 1) in-frame probabilities for feature map x and head_embedding."""
        fused = torch.cat((self.compress_inout(x), head_embedding), dim=1)
        return self.inout(fused)
    

# baseline model that takes a single modality and the gaze cone as input to predict a gaze target heatmap
class BaselineModel(nn.Module):
    """Single-modality model: EfficientNet+FPN features fused with the gaze cone -> heatmap."""

    def __init__(self, backbone_name, modality, cone_mode='early', pred_inout=False,
                 human_centric_weights_path=None):

        '''
        args:
        backbone_name: name of the backbone to be used; ex. 'efficientnet-b0'
        modality: input modality; 'depth' removes two stem input channels
        cone_mode: early or late fusion of person information {'early', 'late'}
        pred_inout: predict an in vs out of frame gaze label
        human_centric_weights_path: checkpoint for the HumanCentric branch;
            defaults to the notebook-level `human_centric_weights`
        '''

        super(BaselineModel, self).__init__()
        self.feature_extractor = FeatureExtractor(backbone_name)
        self.prediction_head = PredictionHead(64)
        self.human_centric = HumanCentric()
        # load weights
        weights_path = human_centric_weights if human_centric_weights_path is None else human_centric_weights_path
        # NOTE: torch.load unpickles the checkpoint -- only load trusted files.
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
        # load_state_dict then copies the values into the existing parameters.
        state_dict = torch.load(weights_path, map_location='cpu')['model_state_dict']
        self.human_centric.load_state_dict(state_dict, strict=False)

        # add additional channels
        self.cone_mode = cone_mode
        if cone_mode=='early':
            # widen the stem by 2 input channels (gaze cone + head mask), initialised
            # from copies of the first two existing input-channel filters
            input_layer = self.feature_extractor.backbone._conv_stem.weight
            self.feature_extractor.backbone._conv_stem.weight = torch.nn.Parameter(torch.cat([input_layer.clone(), input_layer.clone()[:,0:2,:,:]], dim=1))
        elif cone_mode=='late':
            # 64 FPN channels + gaze cone + head mask -> 64 channels
            self.cat_conv = nn.Conv2d(66, 64, 3, padding=1)

        # drop additional channels
        if modality == 'depth':
            self.feature_extractor.backbone._conv_stem.weight = torch.nn.Parameter(self.feature_extractor.backbone._conv_stem.weight[:,0:-2,:,:])

        self.pred_inout = pred_inout
        if pred_inout:
            self.in_vs_out_head = InvsOut(64)

    def forward(self, img, face, gaze_field, head_mask):
        """
        returns: (hm, direction, in_vs_out); in_vs_out is a zero tensor of shape
        (N,) on img's device unless pred_inout is enabled.
        """
        batch_size = img.shape[0]
        # placeholder used when the in/out head is disabled; allocated on the
        # input's device rather than hard-coded to CUDA
        in_vs_out = torch.zeros(batch_size, device=img.device)

        # get gaze cone
        gaze_cone, direction, head_embedding = self.human_centric(face, gaze_field)

        if self.cone_mode=='early':
            x = torch.cat([img, gaze_cone, head_mask], dim=1)
        else:
            x = img

        # extract the features
        x = self.feature_extractor(x)

        if self.cone_mode=='late':
            x = torch.cat([x, gaze_cone, head_mask], dim=1)
            x = self.cat_conv(x)

        # apply the prediction head to get the heatmap
        hm = self.prediction_head(x)

        # apply the in vs out head
        if self.pred_inout:
            in_vs_out = self.in_vs_out_head(x, head_embedding)

        return hm, direction, in_vs_out


# attention based model. multiple modalities processed separately. output feature maps are weighted and added using predicted attention weights to predict a gaze target heatmap
class AttentionModelCombined(nn.Module):
    """Three-branch (image, depth, pose) gaze-target model with modality attention.

    Each modality has its own EfficientNet+FPN extractor. Branch features are
    projected to "values", each value map is compressed to a 512-d vector, and a
    softmax over the three modalities weights the values before they are summed
    and decoded into a heatmap.
    """

    def __init__(self, cone_mode='early', pred_inout=False, human_centric_weights_path=None):

        '''
        args:
        cone_mode: early or late fusion of person information {'early', 'late'}
        pred_inout: predict an in vs out of frame gaze label
        human_centric_weights_path: checkpoint for the HumanCentric branch;
            defaults to the notebook-level `human_centric_weights`
        '''

        super(AttentionModelCombined, self).__init__()
        # (larger b2/b1/b1 variants were also tried for image/depth/pose)
        self.feature_extractor_image = FeatureExtractor('efficientnet-b1')
        self.feature_extractor_depth = FeatureExtractor('efficientnet-b0')
        self.feature_extractor_pose = FeatureExtractor('efficientnet-b0')
        num_modalities = 3

        self.bn_image = nn.BatchNorm2d(64)
        self.bn_depth = nn.BatchNorm2d(64)
        self.bn_pose = nn.BatchNorm2d(64)

        # late fusion appends gaze cone + head mask (2 channels) before the value projections
        additional_channels = 2 if cone_mode=='late' else 0
        self.Wv_image = nn.Conv2d(64+additional_channels, 64, kernel_size=3, padding=1)
        self.Wv_depth = nn.Conv2d(64+additional_channels, 64, kernel_size=3, padding=1)
        self.Wv_pose = nn.Conv2d(64+additional_channels, 64, kernel_size=3, padding=1)

        self.compress_image = CompressModality(64)
        self.compress_depth = CompressModality(64)
        self.compress_pose = CompressModality(64)
        # explicit dim=1 (softmax over modalities); an implicit dim is deprecated and warns
        self.attention_layer = nn.Sequential(nn.Linear(512*num_modalities, num_modalities),
                                             nn.Softmax(dim=1)
                                             )

        self.human_centric = HumanCentric()
        # load weights
        weights_path = human_centric_weights if human_centric_weights_path is None else human_centric_weights_path
        # NOTE: torch.load unpickles the checkpoint -- only load trusted files.
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host;
        # load_state_dict then copies the values into the existing parameters.
        state_dict = torch.load(weights_path, map_location='cpu')['model_state_dict']
        self.human_centric.load_state_dict(state_dict, strict=False)

        # add additional channels (gaze cone + head mask) to every branch's stem
        self.cone_mode = cone_mode
        if cone_mode=='early':
            for extractor in (self.feature_extractor_image,
                              self.feature_extractor_depth,
                              self.feature_extractor_pose):
                self._append_stem_channels(extractor)

        # drop additional channels: the depth stem keeps 2 fewer input channels than
        # the other branches (presumably single-channel depth input -- confirm)
        depth_stem = self.feature_extractor_depth.backbone._conv_stem
        depth_stem.weight = torch.nn.Parameter(depth_stem.weight[:,0:-2,:,:])

        self.prediction_head = PredictionHead(64)
        self.output_act = nn.ReLU()

        self.pred_inout = pred_inout
        if pred_inout:
            self.in_vs_out_head = InvsOut(64)

    @staticmethod
    def _append_stem_channels(extractor, num_extra=2):
        """Widen extractor's stem conv by copying its first num_extra input-channel filters."""
        stem = extractor.backbone._conv_stem
        weight = stem.weight
        stem.weight = torch.nn.Parameter(torch.cat([weight.clone(), weight.clone()[:,0:num_extra,:,:]], dim=1))

    def forward(self, x, face, gaze_field, head_mask):
        """
        args:
        x: sequence of (image, depth, pose) input tensors, indexed 0/1/2
        face: head crop passed to the HumanCentric branch
        gaze_field: gaze field passed to HumanCentric to build the cone
        head_mask: head channel concatenated alongside the gaze cone

        returns: (hm, direction, in_vs_out, att, x_image)
        hm: ReLU'd heatmap from the prediction head
        att: (N, 3, 1, 1, 1) modality weights (image, depth, pose)
        x_image: image-branch features after BN (plus cone/mask channels in late
            mode); returned instead of the fused map for the bbox head
        in_vs_out: zeros of shape (N,) on x[0]'s device unless pred_inout is set
        """
        batch_size = x[0].shape[0]
        # placeholder when the in/out head is disabled; allocated on the input's
        # device rather than hard-coded to CUDA
        in_vs_out = torch.zeros(batch_size, device=x[0].device)

        # get gaze cone
        gaze_cone, direction, head_embedding = self.human_centric(face, gaze_field)

        # extract the features
        if self.cone_mode=='early':
            x_image = torch.cat([x[0], gaze_cone, head_mask], dim=1)
            x_depth = torch.cat([x[1], gaze_cone, head_mask], dim=1)
            x_pose = torch.cat([x[2], gaze_cone, head_mask], dim=1)
        else:
            x_image = x[0]
            x_depth = x[1]
            x_pose = x[2]

        x_image = self.bn_image(self.feature_extractor_image(x_image))
        x_depth = self.bn_depth(self.feature_extractor_depth(x_depth))
        x_pose = self.bn_pose(self.feature_extractor_pose(x_pose))

        if self.cone_mode=='late':
            x_image = torch.cat([x_image, gaze_cone, head_mask], dim=1)
            x_depth = torch.cat([x_depth, gaze_cone, head_mask], dim=1)
            x_pose = torch.cat([x_pose, gaze_cone, head_mask], dim=1)

        # get the values
        v_image = self.Wv_image(x_image)
        v_depth = self.Wv_depth(x_depth)
        v_pose = self.Wv_pose(x_pose)

        # get attention weights
        att_image = self.compress_image(v_image)
        att_depth = self.compress_depth(v_depth)
        att_pose = self.compress_pose(v_pose)
        att = torch.cat([att_image, att_depth, att_pose], dim=1)
        # (N, 3) -> (N, 3, 1, 1, 1) so att[:, i] broadcasts over (N, C, H, W)
        att = self.attention_layer(att).unsqueeze(2).unsqueeze(3).unsqueeze(4)

        # weight values and fuse
        x = v_image * att[:, 0] + v_depth * att[:, 1] + v_pose * att[:, 2]

        # apply the prediction head
        hm = self.output_act(self.prediction_head(x))

        # apply the in vs out head
        if self.pred_inout:
            in_vs_out = self.in_vs_out_head(x, head_embedding)

        return hm, direction, in_vs_out, att, x_image

In [None]:
class RegressionHeadBbox(nn.Module):
    """Conv head regressing a 4-channel box map (TTFNet-style size regression).

    Input features are bilinearly resized to (out_size, out_size), passed through
    four 3x3 conv + BatchNorm + GELU stages and a 1x1 conv to 4 channels.
    Note: the final activation is GELU, so outputs can be slightly negative.
    """

    def __init__(self, inchannels, out_size=None):
        '''
        args:
        inchannels: channel count of the input feature map
        out_size: side length of the output map; defaults to inchannels, which
            keeps the historical behaviour (e.g. 64 channels -> 64x64 map)
        '''
        super(RegressionHeadBbox, self).__init__()

        self.act = nn.GELU()
        self.inchannels = inchannels
        self.out_size = inchannels if out_size is None else out_size
        self.conv1 = nn.Conv2d(inchannels, inchannels, 3, padding=1, dilation=1)
        self.bn1 = nn.BatchNorm2d(inchannels)
        self.conv2 = nn.Conv2d(inchannels, inchannels, 3, padding=1, dilation=1)
        self.bn2 = nn.BatchNorm2d(inchannels)
        self.conv3 = nn.Conv2d(inchannels, inchannels, 3, padding=1, dilation=1)
        self.bn3 = nn.BatchNorm2d(inchannels)
        self.conv4 = nn.Conv2d(inchannels, inchannels, 3, padding=1, dilation=1)
        self.bn4 = nn.BatchNorm2d(inchannels)
        self.out_conv = nn.Conv2d(inchannels, 4, 1)

    def forward(self, x):
        """Map (N, inchannels, H, W) features to a (N, 4, out_size, out_size) tensor."""
        # resize the features to the output resolution
        x = nn.functional.interpolate(x, size=(self.out_size, self.out_size),
                                      mode='bilinear', align_corners=False)
        x = self.act(self.bn1(self.conv1(x)))

        # regress the box map
        x = self.act(self.bn2(self.conv2(x)))
        x = self.act(self.bn3(self.conv3(x)))
        x = self.act(self.bn4(self.conv4(x)))
        x = self.act(self.out_conv(x))

        return x
    
class attentionModelBboxHead(nn.Module):
    """Wrap a gaze model and add a TTFNet-style box regression head.

    The regression head is applied to the 5th output of ``gaze_model`` (for
    AttentionModelCombined that is the image-branch feature map), which must
    have ``output_resolution`` channels.
    """
    def __init__(self, gaze_model, output_resolution=output_resolution, wh_offset_base=16):
        super(attentionModelBboxHead, self).__init__()
        self.gaze_model = gaze_model
        self.reg_head_bbox = RegressionHeadBbox(output_resolution)
        # TTFNet "Size Regression" (p. 4): s is a fixed scalar used to enlarge the
        # predicted results for easier optimization; the paper sets s = 16.
        self.wh_offset_base = wh_offset_base

    def forward(self, x, face, gaze_field, head_mask):
        """
        x, face, gaze_field, head_mask: same inputs as the wrapped gaze model.

        Returns:
        (hm, direction, in_vs_out, att, outr): the gaze model's first four outputs
        plus outr, the 4-channel regression map multiplied by wh_offset_base,
        e.g. (N, 4, 64, 64) for output_resolution = 64.
        """
        hm, direction, in_vs_out, att, x_out = self.gaze_model(x, face, gaze_field, head_mask)
        outr = self.reg_head_bbox(x_out) * self.wh_offset_base
        return hm, direction, in_vs_out, att, outr
        

In [None]:
# Load the evaluation annotations and join them with the per-image pose CSV.
# The inner merge on "filename" keeps only frames that have a pose row.
#
# Other splits used in earlier experiments (swap the paths below to use them):
#   GOO-Real test:        /kaggle/input/goorealdataset/testrealhumansNew.pickle
#                         /kaggle/input/goorealposeanddepth/goo-real-test-pose/kaggle/working/goo-real-pose/master-pose.csv
#   GOO-Real sparse test: /kaggle/input/goorealdataset/testrealhumansSparsedNew.pickle
#                         /kaggle/input/goorealposeanddepth/goo-real-sparse-test-pose/kaggle/working/goo-real-sparse-test-pose/master-pose.csv
#   GOO-Real one-shot:    /kaggle/input/goorealdataset/oneshotrealhumansNew.pickle
#                         /kaggle/input/goorealposeanddepth/goo-real-train-pose/kaggle/working/goo-real-pose/master-pose.csv
#   GOO-Synth test:       /kaggle/input/goosynthtestdataset/goosynth_test_v2_no_segm.pkl
#                         /kaggle/input/goosynthtestposeanddepth/goo-synth-test-pose/kaggle/working/goo-synth-test-pose/master-pose.csv
#                         (with image_dir = '/kaggle/input/goosynthtestdataset/goo-synth-test-images/images')
#
# NOTE: pd.read_pickle unpickles arbitrary objects -- only point it at trusted files.

# Pose CSV columns; the CSVs are read with these explicit names (no header row).
column_names = ['filename', 'left_eye_x', 'left_eye_y', 'right_eye_x', 'right_eye_y', 'pose_min_x', 'pose_min_y', 'pose_max_x', 'pose_max_y']

if flag_run_on_goo_dataset:
    # GOO-Real "val" split
    obj_test = pd.read_pickle('/kaggle/input/goorealdataset/valrealhumansNew.pickle', compression='infer')
    df_rg_test = pd.DataFrame.from_records(obj_test)
    df_rg_test = df_rg_test[['filename','width','height','ann','gaze_item','gazeIdx','gaze_cx','gaze_cy','hx','hy']]
    csv_path_test = '/kaggle/input/goorealposeanddepth/goo-real-val-pose/kaggle/working/goo-real-pose/master-pose.csv'
    df_pose_test = pd.read_csv(csv_path_test, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")

    df_test = pd.merge(df_rg_test, df_pose_test, on="filename")
    # presumably a source-dataset id; 0 for every row of this split
    df_test['dataset_id'] = np.zeros(len(df_test), dtype=np.int8)
    print(len(df_test))
else:
    # RetailGaze test split
    obj_test = pd.read_pickle('/kaggle/input/retailgaze/RetailGaze_V2_seg/RetailGaze_V3_2_test.pickle', compression='infer')
    df_rg_test = pd.DataFrame.from_records(obj_test)
    csv_path_test = '/kaggle/input/retailgazedepthandpose/retailgaze-test-pose/kaggle/working/retailgaze-pose/master-pose.csv'
    df_pose_test = pd.read_csv(csv_path_test, sep=',', names=column_names, index_col=False, encoding="utf-8-sig")
    df_test = pd.merge(df_rg_test, df_pose_test, on="filename")
    print(len(df_test))


In [None]:
df_test.head(n=5)  # preview the merged annotation + pose frame

In [None]:
# Frame size constants; the *_gazeutils pair is the default grid for the gaze helpers.
WIDTH, HEIGHT = 640, 480
WIDTH_gazeutils, HEIGHT_gazeutils = 960, 720


def generate_data_field(eye_point, width=WIDTH_gazeutils, height=HEIGHT_gazeutils):
    """Build a (2, height, width) float32 field of offsets from the eye.

    eye_point is (x, y) and between 0 and 1 (fractions of width / height).
    Channel 0 holds (col - x*width) / width and channel 1 holds
    (row - y*height) / height. Offsets are not normalised to unit length.
    """
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    field = np.stack((cols, rows)).astype(np.float32)

    eye_x, eye_y = eye_point
    eye_px = np.array([eye_x * width, eye_y * height]).reshape([2, 1, 1]).astype(np.float32)

    field -= eye_px
    field[0] /= width
    field[1] /= height
    return field


def generate_gaze_cone(gaze_field, normalized_direction, width=WIDTH_gazeutils, height=HEIGHT_gazeutils):
    """Project a (2, height, width) gaze field onto a 2-vector and keep the positive part.

    Returns the squeezed (height, width) map of per-pixel dot products, with
    non-positive values zeroed.
    """
    per_pixel = np.ascontiguousarray(gaze_field.transpose([1, 2, 0])).reshape([-1, 2])
    projection = np.matmul(per_pixel, normalized_direction.reshape([2, 1]))
    cone = np.ascontiguousarray(projection.reshape([height, width, 1]).transpose([2, 0, 1]))

    # zero out pixels that lie behind the gaze direction
    cone = cone * (cone > 0).astype(int)

    return cone.squeeze()

def _get_transform():
    """RGB pipeline: resize to (input_resolution, input_resolution), to tensor,
    then normalise with the standard ImageNet channel mean/std."""
    resize = transforms.Resize((input_resolution, input_resolution))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return transforms.Compose([resize, transforms.ToTensor(), normalize])

def _get_transform_modality():
    """Non-RGB pipeline (presumably depth/pose maps): resize to
    (input_resolution, input_resolution) and convert to tensor, no normalisation."""
    steps = [
        transforms.Resize((input_resolution, input_resolution)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

def _get_object_transform():
    """Fixed 256x256 pipeline (presumably for object crops): resize, to tensor,
    then normalise with the standard ImageNet channel mean/std."""
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def get_head_box_channel(x_min, y_min, x_max, y_max, width, height, resolution, coordconv=False):
    """Rasterise a head box (pixel coords in a width x height image) onto a
    (resolution, resolution) float32 tensor.

    Box corners are scaled to the grid, truncated to int and clipped to
    [0, resolution-1]; the max corner is exclusive when slicing.
    coordconv=False: 1 inside the box, 0 elsewhere.
    coordconv=True: (row + col) ramp divided by its maximum, with the box zeroed.
    """
    corners = np.array([x_min/width, y_min/height, x_max/width, y_max/height])*resolution
    bx0, by0, bx1, by1 = np.clip(corners.astype(int), 0, resolution-1)

    if not coordconv:
        channel = np.zeros((resolution, resolution), dtype=np.float32)
        channel[by0:by1, bx0:bx1] = 1
        return torch.from_numpy(channel)

    axis = np.arange(resolution, dtype=np.float32)
    ramp = np.squeeze(np.add.outer(axis, axis))
    channel = ramp / float(np.max(ramp))
    channel[by0:by1, bx0:bx1] = 0
    return torch.from_numpy(channel)

def to_numpy(tensor):
    """Return tensor as a numpy array.

    Torch tensors are moved to CPU and converted; numpy objects are returned
    unchanged. Anything else raises ValueError.
    """
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    if type(tensor).__module__ == 'numpy':
        return tensor
    raise ValueError("Cannot convert {} to numpy array"
                     .format(type(tensor)))


def to_torch(ndarray):
    """Return ndarray as a torch tensor.

    numpy objects go through torch.from_numpy; torch tensors are returned
    unchanged. Anything else raises ValueError.
    """
    from_numpy = type(ndarray).__module__ == 'numpy'
    if not from_numpy and not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return torch.from_numpy(ndarray) if from_numpy else ndarray

def draw_labelmap(img, pt, sigma, type='Gaussian'):
    """Add a 2D Gaussian (or Cauchy) bump centred at pt to a heatmap and renormalise.

    Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py

    args:
    img: 2D heatmap (torch tensor or numpy array), indexed [row, col]
    pt: (x, y) centre in img pixel coordinates
    sigma: spread; the kernel spans 6*sigma+1 pixels per side
    type: 'Gaussian' (un-normalised, peak value 1) or 'Cauchy'

    returns: torch tensor of img plus the kernel, divided by its maximum. If the
    kernel lies entirely outside img, img is returned unchanged (as a tensor).

    Note: the kernel is added in place to img's buffer before normalisation; for
    a CPU tensor or a numpy array this modifies the caller's data.

    raises: ValueError if type is neither 'Gaussian' nor 'Cauchy'.
    """
    img = to_numpy(img)

    # Check that any part of the kernel is in-bounds
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] < 0 or br[1] < 0):
        # If not, just return the image as is
        return to_torch(img)

    # Generate the kernel
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    if type == 'Gaussian':
        # not normalized: the center value equals 1
        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    elif type == 'Cauchy':
        g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
    else:
        raise ValueError("Unsupported labelmap type: {!r}".format(type))

    # Usable kernel range
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Image range
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] += g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    img = img/np.max(img) # normalize heatmap so it has max value of 1
    return to_torch(img)

def multi_hot_targets(gaze_pts, out_res):
    """Return an (h, w) float map with a single 1 at the gaze point.

    args:
    gaze_pts: (x, y) as fractions of the map size; a negative x marks "no
        target", in which case the map stays all zeros
    out_res: (w, h) as python ints or 0-d integer tensors

    The target index is int(frac * size), clamped to [0, size-1] on both axes so
    that a point on or outside the border never wraps to the opposite edge.
    """
    w, h = (int(v) for v in out_res)
    target_map = np.zeros((h, w))
    if gaze_pts[0] >= 0:
        x = int(gaze_pts[0] * float(w))
        y = int(gaze_pts[1] * float(h))
        x = min(max(x, 0), w - 1)
        y = min(max(y, 0), h - 1)
        target_map[y, x] = 1
    return target_map

In [None]:
# ttfnet regression 
def bbox_areas(bboxes, keep_axis=False):
    """Inclusive-pixel areas of (x_min, y_min, x_max, y_max) rows.

    Each area is (y_max - y_min + 1) * (x_max - x_min + 1); with keep_axis=True
    the result gets a trailing singleton dimension.
    """
    heights = bboxes[:, 3] - bboxes[:, 1] + 1
    widths = bboxes[:, 2] - bboxes[:, 0] + 1
    areas = heights * widths
    return areas[:, None] if keep_axis else areas

def calc_region(bbox, ratio, featmap_size=None):
    """Calculate a proportional bbox region.
    Each corner is moved towards the opposite one by `ratio` of the box extent
    (rounded to long), so the region stays centred on the box.
    Args:
        bbox (Tensor): single box (x1, y1, x2, y2)
        ratio (float): Ratio of the output region.
        featmap_size (tuple): (height, width) used to clip coordinates to the map.
    Returns:
        tuple: x1, y1, x2, y2
    """
    def blend(a, b, w_a, w_b):
        return torch.round(w_a * a + w_b * b).long()

    keep = 1 - ratio
    x1 = blend(bbox[0], bbox[2], keep, ratio)
    y1 = blend(bbox[1], bbox[3], keep, ratio)
    x2 = blend(bbox[0], bbox[2], ratio, keep)
    y2 = blend(bbox[1], bbox[3], ratio, keep)
    if featmap_size is None:
        return (x1, y1, x2, y2)

    max_y, max_x = featmap_size[0] - 1, featmap_size[1] - 1
    return (x1.clamp(min=0, max=max_x),
            y1.clamp(min=0, max=max_y),
            x2.clamp(min=0, max=max_x),
            y2.clamp(min=0, max=max_y))

def gaussian_2d(shape, sigma_x=1, sigma_y=1):
    """Un-normalised 2-D Gaussian kernel of the given (rows, cols) shape, peak 1 at the centre.

    Values below ``eps * max`` are zeroed so the kernel has a compact support.
    """
    half_rows, half_cols = ((dim - 1.) / 2. for dim in shape)
    ys, xs = np.ogrid[-half_rows:half_rows + 1, -half_cols:half_cols + 1]

    kernel = np.exp(-(xs * xs / (2 * sigma_x * sigma_x) + ys * ys / (2 * sigma_y * sigma_y)))
    negligible = np.finfo(kernel.dtype).eps * kernel.max()
    kernel[kernel < negligible] = 0
    return kernel

def draw_truncate_gaussian(heatmap, center, h_radius, w_radius, k=1):
    """Max-blend a (2*h_radius+1, 2*w_radius+1) Gaussian into ``heatmap`` at ``center``.

    The kernel uses sigma = size / 6 per axis and is cropped where it would leave
    the map. ``heatmap`` (2-D, rows x cols) is updated in place and also returned.

    Args:
        heatmap: 2-D tensor to draw into.
        center: (x, y) peak location; converted with int().
        h_radius, w_radius: kernel half-sizes in rows / columns.
        k: scale applied to the kernel before blending.
    """
    kernel_h = 2 * h_radius + 1
    kernel_w = 2 * w_radius + 1
    kernel = heatmap.new_tensor(
        gaussian_2d((kernel_h, kernel_w), sigma_x=kernel_w / 6, sigma_y=kernel_h / 6))

    cx = int(center[0])
    cy = int(center[1])
    map_h, map_w = heatmap.shape[0:2]

    # how far the kernel may extend from the centre in each direction while staying on the map
    span_left = min(cx, w_radius)
    span_right = min(map_w - cx, w_radius + 1)
    span_top = min(cy, h_radius)
    span_bottom = min(map_h - cy, h_radius + 1)

    region = heatmap[cy - span_top:cy + span_bottom, cx - span_left:cx + span_right]
    patch = kernel[h_radius - span_top:h_radius + span_bottom,
                   w_radius - span_left:w_radius + span_right]
    if min(patch.shape) > 0 and min(region.shape) > 0:
        # `region` is a view, so writing through out= updates `heatmap` itself
        torch.max(region, patch * k, out=region)
    return heatmap

def target_single_image(gt_boxes, feat_shape=(64,64)):
    """Build TTFNet-style training targets for the gazed-at object box.

    Only the first box (``gt_boxes[0]``) is rendered.

    Args:
        gt_boxes: tensor (num_gt, 4) of (x_min, y_min, x_max, y_max) normalised to
            [0, 1]. They are converted to feature-map cells here: x by the map
            width, y by the map height.
        feat_shape: (height, width) of the target maps.
    Returns:
        heatmap: (1, h, w) Gaussian centred on the box centre, peak value 1.
        box_target: (4, h, w) box corners in feature-map cells at the positive
            locations, -1 elsewhere.
        reg_weight: (1, h, w) float regression weights, non-zero only at the
            positive locations. With the 'log' setting they sum to the log of the
            box area computed by ``bbox_areas`` on the normalised box.
    """
    # Fixed configuration (the reference TTFNet uses alpha = beta = 0.54).
    wh_area_process = 'log'
    output_h, output_w = feat_shape
    heatmap_channel = 1
    wh_gaussian = True
    wh_agnostic = True
    alpha = 1.5
    beta = 1.5

    heatmap = gt_boxes.new_zeros((heatmap_channel, output_h, output_w))
    fake_heatmap = gt_boxes.new_zeros((output_h, output_w))
    box_target = gt_boxes.new_ones((4, output_h, output_w)) * -1
    reg_weight = gt_boxes.new_zeros((1, output_h, output_w), dtype=torch.float)

    # per-box regression weight derived from the box area
    if wh_area_process == 'log':
        boxes_area_topk_log = bbox_areas(gt_boxes).log()
    elif wh_area_process == 'sqrt':
        boxes_area_topk_log = bbox_areas(gt_boxes).sqrt()
    else:
        boxes_area_topk_log = bbox_areas(gt_boxes)
    if wh_area_process == 'norm':
        boxes_area_topk_log[:] = 1.

    # normalised -> feature-map cells: x scales with the width, y with the height
    # (scaling everything by output_h was only correct for square maps)
    scale = gt_boxes.new_tensor([output_w, output_h, output_w, output_h])
    feat_gt_boxes = gt_boxes * scale
    feat_gt_boxes[:, [0, 2]] = torch.clamp(feat_gt_boxes[:, [0, 2]], min=0,
                                           max=output_w - 1)
    feat_gt_boxes[:, [1, 3]] = torch.clamp(feat_gt_boxes[:, [1, 3]], min=0,
                                           max=output_h - 1)
    feat_hs = feat_gt_boxes[:, 3] - feat_gt_boxes[:, 1]
    feat_ws = feat_gt_boxes[:, 2] - feat_gt_boxes[:, 0]

    # integer peak from the (unclamped) box centre, kept on the map so that a centre
    # lying exactly on the far border still produces a Gaussian instead of nothing
    centers = torch.stack([(gt_boxes[:, 0] + gt_boxes[:, 2]) / 2,
                           (gt_boxes[:, 1] + gt_boxes[:, 3]) / 2], dim=1)
    ct_ints = (centers * scale[:2]).to(torch.int)
    ct_ints[:, 0].clamp_(0, output_w - 1)
    ct_ints[:, 1].clamp_(0, output_h - 1)

    h_radiuses_alpha = (feat_hs / 2. * alpha).int()
    w_radiuses_alpha = (feat_ws / 2. * alpha).int()
    if wh_gaussian and alpha != beta:
        h_radiuses_beta = (feat_hs / 2. * beta).int()
        w_radiuses_beta = (feat_ws / 2. * beta).int()

    if not wh_gaussian:
        # positive (centre) region: the central `beta` fraction of each box, computed in
        # feature-map cells and clipped to the map by calc_region
        r1 = (1 - beta) / 2
        ctr_x1s, ctr_y1s, ctr_x2s, ctr_y2s = calc_region(
            feat_gt_boxes.transpose(0, 1), r1, featmap_size=(output_h, output_w))

    # only the first box is rendered
    k = 0
    cls_id = 0
    fake_heatmap.zero_()
    draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                           h_radiuses_alpha[k].item(), w_radiuses_alpha[k].item())
    heatmap[cls_id] = torch.max(heatmap[cls_id], fake_heatmap)

    if wh_gaussian:
        if alpha != beta:
            fake_heatmap.zero_()
            draw_truncate_gaussian(fake_heatmap, ct_ints[k],
                                   h_radiuses_beta[k].item(),
                                   w_radiuses_beta[k].item())
        # regression targets live wherever the Gaussian is non-zero
        box_target_inds = fake_heatmap > 0
    else:
        ctr_x1, ctr_y1, ctr_x2, ctr_y2 = ctr_x1s[k], ctr_y1s[k], ctr_x2s[k], ctr_y2s[k]
        box_target_inds = torch.zeros_like(fake_heatmap, dtype=torch.bool)
        box_target_inds[ctr_y1:ctr_y2 + 1, ctr_x1:ctr_x2 + 1] = True

    # absolute box corners in feature-map cells (unclamped)
    box_value = gt_boxes[k][:, None] * scale[:, None]
    if wh_agnostic:
        box_target[:, box_target_inds] = box_value
    else:
        box_target[(cls_id * 4):((cls_id + 1) * 4), box_target_inds] = box_value

    if wh_gaussian:
        # Gaussian-shaped weights normalised to sum to the box's area weight
        local_heatmap = fake_heatmap[box_target_inds].float()
        ct_div = local_heatmap.sum()
        local_heatmap *= boxes_area_topk_log[k]
        reg_weight[cls_id, box_target_inds] = local_heatmap / ct_div
    else:
        reg_weight[cls_id, box_target_inds] = \
            boxes_area_topk_log[k] / box_target_inds.sum().float()

    return heatmap, box_target, reg_weight



In [None]:
def non_max_suppression_fast(boxes, overlapThresh=0.6):
    """Greedy NMS over ``boxes[0]`` (e.g. the bbox tensor of one ``(bboxes, labels)`` result).

    Boxes are visited in descending order of their bottom edge (y2), not by score.
    A remaining box is discarded when its intersection with the visited box covers
    more than ``overlapThresh`` of the remaining box's own area (not IoU). Areas use
    the inclusive ``+ 1`` pixel convention.

    Args:
        boxes: indexable whose ``[0]`` element is an (N, >=4) array/tensor of
            x1, y1, x2, y2; extra columns (e.g. a score) are carried along.
        overlapThresh: suppression threshold on that overlap ratio.
    Returns:
        The kept rows of ``boxes[0]`` in visiting order, or ``[]`` if it is empty.
    """
    candidates = boxes[0]
    if len(candidates) == 0:
        return []

    x1, y1, x2, y2 = (candidates[:, col] for col in range(4))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    # ascending by y2: the last entry is always the next box to keep
    order = np.argsort(y2)
    kept = []
    while len(order) > 0:
        tail = len(order) - 1
        current = order[tail]
        kept.append(current)
        others = order[:tail]

        inter_w = np.maximum(0, np.minimum(x2[current], x2[others]) - np.maximum(x1[current], x1[others]) + 1)
        inter_h = np.maximum(0, np.minimum(y2[current], y2[others]) - np.maximum(y1[current], y1[others]) + 1)
        ratio = (inter_w * inter_h) / areas[others]

        # remove the visited box together with every box it overlaps too much
        drop = np.concatenate(([tail], np.where(ratio > overlapThresh)[0]))
        order = np.delete(order, drop)

    return candidates[kept]

def simple_nms(heat, kernel=3, out_heat=None):
    """Keep only local maxima of ``heat`` (kernel x kernel window); everything else is zeroed.

    ``heat`` is a 4-D (N, C, H, W) tensor as required by ``max_pool2d``. When
    ``out_heat`` is given, its values (not ``heat``'s) are masked and returned.
    """
    local_max = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1,
                                         padding=(kernel - 1) // 2)
    is_peak = (local_max == heat).float()
    source = heat if out_heat is None else out_heat
    return source * is_peak

def _topk(scores, topk):
    batch, cat, height, width = scores.size()

    # both are (batch, 1, topk)
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), topk)

    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()

    # both are (batch, topk). select topk from 80*topk
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), topk)
    
    topk_clses = (topk_ind / topk).int()
    topk_ind = topk_ind.unsqueeze(2)
    topk_inds = topk_inds.view(batch, -1, 1).gather(1, topk_ind).view(batch, topk)
    topk_ys = topk_ys.view(batch, -1, 1).gather(1, topk_ind).view(batch, topk)
    topk_xs = topk_xs.view(batch, -1, 1).gather(1, topk_ind).view(batch, topk)
    
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

def get_bboxes_ttfnet(pred_heatmap, pred_wh, rescale=False, img_metas=None,
                      topk=None, score_thr=None):
    """Decode TTFNet heatmap + box-regression outputs into scored boxes.

    Args:
        pred_heatmap: (B, 1, H, W) or (B, H, W) centre heatmap. Its values are used
            as scores as-is (no sigmoid is applied here).
        pred_wh: (B, 4, H, W) regression map; its channels are read as distances
            (left, top, right, bottom) from each peak to the box sides.
        rescale: divide boxes by ``img_metas[i]['scale_factor']``.
        img_metas: per-image dicts, required when ``rescale`` is True.
        topk: peaks kept per image before thresholding (default: global
            ``Bbox_topk_no_of_bboxes``).
        score_thr: minimum score to keep a box (default: global
            ``Bbox_confidence_score_threshold``).
    Returns:
        list with one ``(bboxes, labels)`` tuple per image. ``bboxes`` is (n, 5) with
        columns x1, y1, x2, y2, score in heatmap cells; ``labels`` is (n,).
    Raises:
        ValueError: ``rescale`` is True but ``img_metas`` is missing.
    """
    if topk is None:
        topk = Bbox_topk_no_of_bboxes
    if score_thr is None:
        score_thr = Bbox_confidence_score_threshold
    if rescale and img_metas is None:
        raise ValueError("get_bboxes_ttfnet: rescale=True requires img_metas")

    if pred_heatmap.dim() == 3:
        # add the channel axis: (B, H, W) -> (B, 1, H, W)
        pred_heatmap = pred_heatmap.unsqueeze(1)

    down_ratio = 1

    batch, cat, height, width = pred_heatmap.size()
    pred_heatmap = pred_heatmap.detach()
    wh = pred_wh.detach()

    # keep only local maxima, then the top-k of those
    heat = simple_nms(pred_heatmap)
    scores, inds, clses, ys, xs = _topk(heat, topk=topk)

    xs = xs.view(batch, topk, 1) * down_ratio
    ys = ys.view(batch, topk, 1) * down_ratio

    # gather the 4 regression values at every selected peak
    wh = wh.permute(0, 2, 3, 1).contiguous()
    wh = wh.view(wh.size(0), -1, wh.size(3))
    inds = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), wh.size(2))
    wh = wh.gather(1, inds).view(batch, topk, 4)

    clses = clses.view(batch, topk, 1).float()
    scores = scores.view(batch, topk, 1)

    bboxes = torch.cat([xs - wh[..., [0]], ys - wh[..., [1]],
                        xs + wh[..., [2]], ys + wh[..., [3]]], dim=2)

    # boxes live in heatmap coordinates, so clip to the heatmap's own extent
    img_h, img_w = height * down_ratio, width * down_ratio
    result_list = []
    for batch_i in range(batch):
        scores_keep = (scores[batch_i] > score_thr).squeeze(-1)

        scores_per_img = scores[batch_i][scores_keep]
        bboxes_per_img = bboxes[batch_i][scores_keep]
        labels_per_img = clses[batch_i][scores_keep]
        bboxes_per_img[:, 0::2] = bboxes_per_img[:, 0::2].clamp(min=0, max=img_w - 1)
        bboxes_per_img[:, 1::2] = bboxes_per_img[:, 1::2].clamp(min=0, max=img_h - 1)

        if rescale:
            scale_factor = img_metas[batch_i]['scale_factor']
            bboxes_per_img /= bboxes_per_img.new_tensor(scale_factor)

        bboxes_per_img = torch.cat([bboxes_per_img, scores_per_img], dim=1)
        result_list.append((bboxes_per_img, labels_per_img.squeeze(-1)))

    return result_list

def bbox_overlaps_ttfnet(bboxes1, bboxes2, mode='iou', is_aligned=False):
    """Overlap between two sets of (x1, y1, x2, y2) boxes plus a size-weighted union score.

    If ``is_aligned`` is False, every box of bboxes1 is compared with every box of
    bboxes2; otherwise box i is compared with box i only. Areas use the inclusive
    ``+ 1`` convention.

    Args:
        bboxes1 (Tensor): shape (m, 4).
        bboxes2 (Tensor): shape (n, 4); if is_aligned is True, m and n must be equal.
        mode (str): "iou" (intersection over union) or "iof" (intersection over
            the area of the bboxes1 box).
    Returns:
        tuple (ious, wuocs). Both have shape (m, n) when not aligned and (m,) when
        aligned; for empty input they are (m, n) / (m, 1) empty tensors.
        ``wuocs = min(a1/a2, a2/a1) * union / enclosing_area`` for each compared pair,
        independent of ``mode``.
    """
    assert mode in ['iou', 'iof']

    rows = bboxes1.size(0)
    cols = bboxes2.size(0)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        empty = bboxes1.new(rows, 1) if is_aligned else bboxes1.new(rows, cols)
        # always return the (ious, wuocs) pair so callers can unpack uniformly
        return empty, empty.clone()

    area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (
        bboxes1[:, 3] - bboxes1[:, 1] + 1)
    area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (
        bboxes2[:, 3] - bboxes2[:, 1] + 1)

    if is_aligned:
        b1, a1 = bboxes1, area1
    else:
        # broadcasting (m, 1, .) against (n, .) makes every quantity below (m, n)
        b1, a1 = bboxes1[:, None, :], area1[:, None]

    lt = torch.max(b1[..., :2], bboxes2[..., :2])
    rb = torch.min(b1[..., 2:], bboxes2[..., 2:])
    wh = (rb - lt + 1).clamp(min=0)
    overlap = wh[..., 0] * wh[..., 1]
    union = a1 + area2 - overlap

    if mode == 'iou':
        ious = overlap / union
    else:
        ious = overlap / a1

    # smallest box enclosing each compared pair (computed pairwise, not element-wise)
    enc_lt = torch.min(b1[..., :2], bboxes2[..., :2])
    enc_rb = torch.max(b1[..., 2:], bboxes2[..., 2:])
    enc_wh = enc_rb - enc_lt + 1
    enclosing_area = enc_wh[..., 0] * enc_wh[..., 1]

    size_ratio = torch.min(a1 / area2, area2 / a1)
    wuocs = size_ratio * (union / enclosing_area)

    return ious, wuocs

In [None]:
class GooRealDataset(Dataset):
    """GOO-Real gaze-target dataset yielding RGB, pose and depth inputs plus TTFNet targets.

    Each ``df`` row must provide:
      * ``filename``: image path relative to ``image_dir``, possibly with Windows separators;
      * ``hx`` / ``hy``: eye point in pixels, divided by the 640x480 resized image size
        (TODO confirm the annotations are in that frame);
      * ``ann``: dict with ``bboxes`` and ``labels``; the last box is the head box;
      * ``gazeIdx``: index of the gazed-at box in ``ann['bboxes']``;
      * ``gaze_item``: class of the gazed-at object.

    Relies on notebook-level names: ``image_dir``, ``_get_object_transform``,
    ``use_sparse_dataset``, ``privacy``, ``cone_mode``, ``modality_dropout``,
    ``get_head_box_channel``, ``generate_data_field``, ``target_single_image``, ``unnorm``.
    """

    def __init__(self, df, get_transform, get_transform_modality, input_size=input_resolution,
                 output_size=output_resolution, test=False, modality='image', imshow=False,
                 pose_dir=None, depth_dir=None):
        """
        Args:
            df: annotation DataFrame (see class docstring).
            get_transform: transform applied to the RGB image and the face crop.
            get_transform_modality: transform applied to the pose and depth images.
            input_size, output_size: network input / heatmap resolution.
            test: evaluation mode (no augmentation, different return tuple).
            modality: 'attention' enables modality dropout during training.
            imshow: save debug visualisations to the working directory.
            pose_dir, depth_dir: roots of the pose / depth dumps. Required for
                training; in test mode they default to the Kaggle GOO-Real folders.
        """
        self.df_data = df
        self.image_dir = image_dir
        self.transform = get_transform
        self.transform_modality = get_transform_modality
        self.get_object_transform = _get_object_transform()
        self.input_size = input_size
        self.output_size = output_size
        self.imshow = imshow
        self.test = test
        self.modality = modality
        self.pose_dir = pose_dir
        self.depth_dir = depth_dir

    def __len__(self):
        return len(self.df_data.index)

    def _modality_dirs(self):
        """Return the (pose_dir, depth_dir) roots used for the current sample.

        Explicit constructor arguments win. In test mode missing ones fall back to the
        GOO-Real validation folders, or the sparse-test folders when the global
        ``use_sparse_dataset`` is set. The training split has no default.
        """
        pose_dir, depth_dir = self.pose_dir, self.depth_dir
        if pose_dir is not None and depth_dir is not None:
            return pose_dir, depth_dir
        if not self.test:
            # previously both were silently None here and os.path.join raised a TypeError
            raise ValueError("GooRealDataset: pose_dir and depth_dir must both be given "
                             "for training; there is no default for the training split")
        # Other roots used before: goo-real-test-{pose,depth} and goo-synth-test-{pose,depth}
        # under /kaggle/input/goorealposeanddepth and /kaggle/input/goosynthtestposeanddepth.
        if use_sparse_dataset:
            default_pose = '/kaggle/input/goorealposeanddepth/goo-real-sparse-test-pose/kaggle/working/goo-real-sparse-test-pose'
            default_depth = '/kaggle/input/goorealposeanddepth/goo-real-sparse-test-depth/kaggle/working/goo-real-sparse-depth'
        else:
            default_pose = '/kaggle/input/goorealposeanddepth/goo-real-val-pose/kaggle/working/goo-real-pose'
            default_depth = '/kaggle/input/goorealposeanddepth/goo-real-val-depth/kaggle/working/goo-real-depth'
        return (pose_dir if pose_dir is not None else default_pose,
                depth_dir if depth_dir is not None else default_depth)

    @staticmethod
    def _open_pose_and_depth(pose_dir, depth_dir, key_filename, stem):
        """Open ``<stem>-pose.png`` (falling back to ``-pose.jpg``) and the ``<stem>.png`` depth map."""
        pose_png = os.path.join(pose_dir, key_filename, stem + '-pose.png')
        if os.path.exists(pose_png):
            pose = Image.open(pose_png)
        else:
            pose = Image.open(os.path.join(pose_dir, key_filename, stem + '-pose.jpg'))
        depth = Image.open(os.path.join(depth_dir, key_filename, stem + '.png'))
        return pose, depth

    def __getitem__(self, index):
        """Load, augment (training only) and encode one sample.

        Returns (test):  img, face, pose, depth, gaze_field, gt_direction, head_channel,
            gaze_heatmap, cont_gaze, imsize, path, eye_point, all_object_bboxes,
            all_object_bboxes_class, gazed_object_bbox, gazed_object_class
        Returns (train): img, face, pose, depth, gaze_field, gt_direction, head_channel,
            gaze_heatmap, path, gaze_inside, dropped, box_target_cp, reg_weight_cp
        """
        gaze_inside = True

        row = self.df_data.iloc[index]

        # annotation paths may use Windows separators; normalise to POSIX
        path = PureWindowsPath(row['filename']).as_posix()
        filename_no_extension = os.path.splitext(os.path.basename(path))[0]
        subfolder_path_str = os.path.splitext(path)[0].split("/")
        # pose/depth dumps are stored under the first two components of the image path
        key_filename = subfolder_path_str[0] + '/' + subfolder_path_str[1]

        img = Image.open(os.path.join(self.image_dir, path))
        img = img.convert('RGB').resize((640, 480))

        width, height = img.size

        eye_x = row.hx / width
        eye_y = row.hy / height

        # the last annotated box is the head box
        x_min, y_min, x_max, y_max = np.array(row['ann']['bboxes'][-1])

        # gazed-at object box/class, plus every object box except the head box
        gaze_obj_x_min, gaze_obj_y_min, gaze_obj_x_max, gaze_obj_y_max = np.array(row['ann']['bboxes'][row.gazeIdx])
        gazed_object_class = np.array(row.gaze_item)
        all_object_bboxes = np.array(row['ann']['bboxes'])[:-1]
        all_object_bboxes_class = np.array(row['ann']['labels'])[:-1]

        # CenterNet/TTFNet style: the gaze point is the centre of the gazed-at box
        gaze_x = ((gaze_obj_x_min + gaze_obj_x_max) / 2) / width
        gaze_y = ((gaze_obj_y_min + gaze_obj_y_max) / 2) / height

        # expand face bbox a bit (x_max / y_max use the already-expanded minima)
        k = 0.05
        x_min -= k * abs(x_max - x_min)
        y_min -= k * abs(y_max - y_min)
        x_max += k * abs(x_max - x_min)
        y_max += k * abs(y_max - y_min)

        x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])

        pose_dir, depth_dir = self._modality_dirs()
        pose, depth = self._open_pose_and_depth(pose_dir, depth_dir, key_filename,
                                                filename_no_extension)

        if self.imshow:
            img.save("origin_img.jpg")

        if self.test:
            imsize = torch.IntTensor([width, height])
            if privacy:
                img = Image.fromarray(np.uint8(np.zeros((height, width, 3)) * 255))
        else:
            ## data augmentation

            # Jitter (expansion-only) bounding box size
            if np.random.random_sample() <= 0.5:
                k = np.random.random_sample() * 0.2
                x_min -= k * abs(x_max - x_min)
                y_min -= k * abs(y_max - y_min)
                x_max += k * abs(x_max - x_min)
                y_max += k * abs(y_max - y_min)

            # Random Crop
            if np.random.random_sample() <= 0.5:
                # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
                crop_x_min = np.min([gaze_x * width, x_min, x_max])
                crop_y_min = np.min([gaze_y * height, y_min, y_max])
                crop_x_max = np.max([gaze_x * width, x_min, x_max])
                crop_y_max = np.max([gaze_y * height, y_min, y_max])

                # Randomly select a random top left corner
                if crop_x_min >= 0:
                    crop_x_min = np.random.uniform(0, crop_x_min)
                if crop_y_min >= 0:
                    crop_y_min = np.random.uniform(0, crop_y_min)

                # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
                crop_width_min = crop_x_max - crop_x_min
                crop_height_min = crop_y_max - crop_y_min
                crop_width_max = width - crop_x_min
                crop_height_max = height - crop_y_min
                # Randomly select a width and a height
                crop_width = np.random.uniform(crop_width_min, crop_width_max)
                crop_height = np.random.uniform(crop_height_min, crop_height_max)

                # Crop it (pose/depth are cropped with the same image-space window)
                img = TF.crop(img, crop_y_min, crop_x_min, crop_height, crop_width)
                pose = TF.crop(pose, crop_y_min, crop_x_min, crop_height, crop_width)
                depth = TF.crop(depth, crop_y_min, crop_x_min, crop_height, crop_width)

                # Record the crop's (x, y) offset
                offset_x, offset_y = crop_x_min, crop_y_min

                # convert coordinates into the cropped frame
                x_min, y_min, x_max, y_max = x_min - offset_x, y_min - offset_y, x_max - offset_x, y_max - offset_y
                gaze_x, gaze_y = (gaze_x * width - offset_x) / float(crop_width), \
                                 (gaze_y * height - offset_y) / float(crop_height)
                eye_x, eye_y = (eye_x * width - offset_x) / float(crop_width), \
                               (eye_y * height - offset_y) / float(crop_height)

                gaze_obj_x_min -= offset_x
                gaze_obj_y_min -= offset_y
                gaze_obj_x_max -= offset_x
                gaze_obj_y_max -= offset_y

                width, height = crop_width, crop_height

            # Random flip
            if np.random.random_sample() <= 0.5:
                img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
                pose = pose.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
                depth = depth.transpose(Image.Transpose.FLIP_LEFT_RIGHT)

                x_max_2 = width - x_min
                x_min_2 = width - x_max
                x_max = x_max_2
                x_min = x_min_2
                gaze_x = 1 - gaze_x
                eye_x = 1 - eye_x

                # flip the GT gazed object bbox
                gaze_obj_x_max_2 = width - gaze_obj_x_min
                gaze_obj_x_min_2 = width - gaze_obj_x_max
                gaze_obj_x_max = gaze_obj_x_max_2
                gaze_obj_x_min = gaze_obj_x_min_2

            # Random color change
            if np.random.random_sample() <= 0.5:
                img = TF.adjust_brightness(img, brightness_factor=np.random.uniform(0.5, 1.5))
                img = TF.adjust_contrast(img, contrast_factor=np.random.uniform(0.5, 1.5))
                img = TF.adjust_saturation(img, saturation_factor=np.random.uniform(0, 1.5))

        if cone_mode == 'early':
            cone_resolution = input_resolution
        else:
            cone_resolution = input_resolution // 4

        head_channel = get_head_box_channel(x_min, y_min, x_max, y_max, width, height,
                                            resolution=cone_resolution, coordconv=False).unsqueeze(0)

        # Crop the face
        if privacy:
            face = pose.crop((int(x_min), int(y_min), int(x_max), int(y_max)))
        else:
            face = img.crop((int(x_min), int(y_min), int(x_max), int(y_max)))

        # modality dropout
        height, width = int(height), int(width)
        num_modalities = 3
        dropped = np.zeros(num_modalities)
        if not self.test and self.modality == 'attention':
            if modality_dropout:
                # keep one modality (never the RGB image in privacy mode)
                if privacy:
                    modality_idx = 1
                else:
                    modality_idx = 0
                m_keep = np.random.randint(modality_idx, num_modalities)

                if (np.random.rand() <= 0.2) and m_keep != 0:
                    img = Image.fromarray(np.uint8(np.random.rand(height, width, 3) * 255))
                    dropped[0] = 1
                if (np.random.rand() <= 0.2) and m_keep != 1:
                    depth = Image.fromarray(np.uint8(np.random.rand(height, width) * 255))
                    dropped[1] = 1
                if (np.random.rand() <= 0.2) and m_keep != 2:
                    pose = Image.fromarray(np.uint8(np.random.rand(height, width, 3) * 255))
                    dropped[2] = 1

            if privacy:
                img = Image.fromarray(np.uint8(np.zeros((height, width, 3))))
                dropped[0] = 1

        gazed_object_bbox = np.array([gaze_obj_x_min, gaze_obj_y_min, gaze_obj_x_max, gaze_obj_y_max])

        # generate new gaze field (for human-centric branch)
        eye_point = np.array([eye_x, eye_y])
        gaze = np.array([gaze_x, gaze_y])
        gt_direction = np.array([-1.0, -1.0])
        if gaze_inside:
            gt_direction = gaze - eye_point
            direction_norm = np.linalg.norm(gt_direction)
            # normalise every non-zero vector (the old `mean() != 0` test skipped e.g. [a, -a])
            if direction_norm > 0:
                gt_direction = gt_direction / direction_norm

        gaze_field = generate_data_field(eye_point, width=cone_resolution, height=cone_resolution)
        # normalise per pixel, avoiding division by a (near) zero norm
        field_norm = np.sqrt(np.sum(gaze_field ** 2, axis=0)).reshape([1, cone_resolution, cone_resolution])
        field_norm = np.maximum(field_norm, 0.1)
        gaze_field /= field_norm

        if self.transform is not None:
            img = self.transform(img)
            face = self.transform(face)

            pose = self.transform_modality(pose)
            depth = self.transform_modality(depth)

        ## for TTFNet
        gazed_object_box_for_cp = torch.tensor([gaze_obj_x_min / width, gaze_obj_y_min / height,
                                                gaze_obj_x_max / width, gaze_obj_y_max / height]).unsqueeze(0)
        heatmap_cp, box_target_cp, reg_weight_cp = target_single_image(
            gazed_object_box_for_cp, feat_shape=(self.output_size, self.output_size))

        # heatmap used for the deconv prediction
        gaze_heatmap = torch.zeros(self.output_size, self.output_size)
        if self.test:
            if gaze_x != -1:
                gaze_heatmap = heatmap_cp.squeeze(0).float()
        else:
            gaze_heatmap = heatmap_cp.squeeze(0).float()

        if gaze_inside:
            cont_gaze = [gaze_x, gaze_y]
        else:
            cont_gaze = [-1, -1]
        cont_gaze = torch.FloatTensor(cont_gaze)

        if self.imshow:
            # draw from a copy so the returned `img` stays the transformed tensor
            fig = plt.figure(111)
            vis_img = 255 - unnorm(img.numpy()) * 255
            vis_img = np.clip(vis_img, 0, 255)
            plt.imshow(np.transpose(vis_img, (1, 2, 0)))
            # cv2.resize needs numpy arrays, not torch tensors
            plt.imshow(cv2.resize(gaze_heatmap.numpy(), (self.input_size, self.input_size)), cmap='jet', alpha=0.3)
            plt.imshow(cv2.resize((1 - head_channel.squeeze(0)).numpy(), (self.input_size, self.input_size)), alpha=0.2)
            plt.savefig('viz_aug.png')
            plt.close(fig)

        if self.test:
            return img, face, pose, depth, gaze_field, gt_direction, head_channel, gaze_heatmap, cont_gaze, imsize, path, eye_point, all_object_bboxes, all_object_bboxes_class, gazed_object_bbox, gazed_object_class
        else:
            return img, face, pose, depth, gaze_field, gt_direction, head_channel, gaze_heatmap, path, gaze_inside, dropped, box_target_cp, reg_weight_cp

    

In [None]:
class retailGazeDataset(Dataset):
    """RetailGaze dataset yielding RGB image, head crop, pose, depth and gaze targets.

    Each row of ``df`` is read for ``filename``, ``seg_mask``, ``gaze_cx``/``gaze_cy``
    and ``left_eye_x/y``/``right_eye_x/y`` (all in image pixels), and ``ann``, whose
    first value is unpacked as the head box ``[x_min, y_min, x_max, y_max]``.
    Pose and depth maps are read from pre-computed directories that differ between
    the test and train splits.
    """
    def __init__(self,df, get_transform, get_transform_modality, input_size=input_resolution, output_size=output_resolution, test=False, imshow=False):
        self.df_data =  df
        self.image_dir = '/kaggle/input/retailgaze/RetailGaze_V2_seg/RetailGaze_V2/'
        self.get_transform = get_transform
        self.transform_modality = get_transform_modality
        self.input_size = input_size
        self.output_size = output_size
        self.imshow = imshow
        self.test = test
    def __len__(self):
        return len(self.df_data.index)

    def __getitem__(self,idx):
        """Load, optionally augment (train only), and encode one sample.

        Returns (test mode):
            img, head, pose, depth, gaze_field, gt_direction, head_channel,
            gaze_heatmap, cont_gaze, imsize, gaze_inside, path, path_seg_mask
        Returns (train mode):
            img, head, pose, depth, gaze_field, gt_direction, head_channel,
            gaze_heatmap, path, gaze_inside, dropped
        """
        gaze_inside = True
        row = self.df_data.iloc[idx]
        path = row.filename
        path_seg_mask = row.seg_mask
        filename_no_extension = os.path.splitext(os.path.basename(row.filename))[0]

        # pose/depth maps live under the first two path components of the image path
        subfolder_path_str = os.path.splitext(path)[0].split("/")
        key_filename = subfolder_path_str[0]+'/'+subfolder_path_str[1]

        img = Image.open(os.path.join(self.image_dir, row.filename))
        img = img.convert('RGB')
        frame_org = img
        width, height = img.size

        # Can give -1 in dataset if not know for test
        gaze_x = row.gaze_cx / width
        gaze_y = row.gaze_cy / height

        """
        From the paper section 3.2, gaze cone generator
        The gaze cone generator produces Ico by drawing a cone from the 
        subject’s eyes location p eye (i.e. eye mid-point if available 
        from the pose modality; otherwise, using a prototypal location 
        in the head bounding box) along the direction of g2D.
        """
        eye_x = ((row.left_eye_x+row.right_eye_x) * 0.5)/width
        eye_y = ((row.left_eye_y+row.right_eye_y) * 0.5)/height

        # expand face bbox a bit
        head_box_org = list(row.ann.values())[0]
        x_min, y_min, x_max, y_max = head_box_org

        k = 0.05
        x_min -= k * abs(x_max - x_min)
        y_min -= k * abs(y_max - y_min)
        x_max += k * abs(x_max - x_min)
        y_max += k * abs(y_max - y_min)
        x_min, y_min, x_max, y_max = map(float, [x_min, y_min, x_max, y_max])

        # point to val dir
        if self.test: 

            self.pose_dir = '/kaggle/input/retailgazedepthandpose/retailgaze-test-pose/kaggle/working/retailgaze-pose'
            self.depth_dir = '/kaggle/input/retailgazedepthandpose/retailgaze-test-depth/kaggle/working/retailgaze-depth'

        # point to train dir
        else:
            self.pose_dir = '/kaggle/input/retailgazedepthandpose/retailgaze-train-pose/kaggle/working/retailgaze-pose'
            self.depth_dir = '/kaggle/input/retailgazedepthandpose/retailgaze-train-depth/kaggle/working/retailgaze-depth'

        # read pose
        pose = Image.open(os.path.join(self.pose_dir, key_filename,filename_no_extension+'-pose.jpg'))

        # read depth
        depth = Image.open(os.path.join(self.depth_dir,key_filename,filename_no_extension+'.png'))

        if self.test:
            imsize = torch.IntTensor([width, height])
        else:
            ## data augmentation                          
            # Jitter (expansion-only) bounding box size
            if np.random.random_sample() <= 0.5:
                k = np.random.random_sample() * 0.2
                x_min -= k * abs(x_max - x_min)
                y_min -= k * abs(y_max - y_min)
                x_max += k * abs(x_max - x_min)
                y_max += k * abs(y_max - y_min)

            # Random Crop
            if np.random.random_sample() <= 0.5:
                # Calculate the minimum valid range of the crop that doesn't exclude the face and the gaze target
                crop_x_min = np.min([gaze_x * width, x_min, x_max])
                crop_y_min = np.min([gaze_y * height, y_min, y_max])
                crop_x_max = np.max([gaze_x * width, x_min, x_max])
                crop_y_max = np.max([gaze_y * height, y_min, y_max])

                # Randomly select a random top left corner
                if crop_x_min >= 0:
                    crop_x_min = np.random.uniform(0, crop_x_min)
                if crop_y_min >= 0:
                    crop_y_min = np.random.uniform(0, crop_y_min)

                # Find the range of valid crop width and height starting from the (crop_x_min, crop_y_min)
                crop_width_min = crop_x_max - crop_x_min
                crop_height_min = crop_y_max - crop_y_min
                crop_width_max = width - crop_x_min
                crop_height_max = height - crop_y_min
                # Randomly select a width and a height
                crop_width = np.random.uniform(crop_width_min, crop_width_max)
                crop_height = np.random.uniform(crop_height_min, crop_height_max)

                # Crop it
                img = TF.crop(img, crop_y_min, crop_x_min, crop_height, crop_width)
                pose = TF.crop(pose, crop_y_min, crop_x_min, crop_height, crop_width)
                depth = TF.crop(depth, crop_y_min, crop_x_min, crop_height, crop_width)

                # Record the crop's (x, y) offset
                offset_x, offset_y = crop_x_min, crop_y_min

                # convert coordinates into the cropped frame
                x_min, y_min, x_max, y_max = x_min - offset_x, y_min - offset_y, x_max - offset_x, y_max - offset_y
                gaze_x, gaze_y = (gaze_x * width - offset_x) / float(crop_width), \
                                 (gaze_y * height - offset_y) / float(crop_height)
                eye_x, eye_y = (eye_x * width - offset_x) / float(crop_width), \
                                 (eye_y * height - offset_y) / float(crop_height)

                width, height = crop_width, crop_height

            # Random flip
            if np.random.random_sample() <= 0.5:
                img = img.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
                pose = pose.transpose(Image.Transpose.FLIP_LEFT_RIGHT)
                depth = depth.transpose(Image.Transpose.FLIP_LEFT_RIGHT)

                x_max_2 = width - x_min
                x_min_2 = width - x_max
                x_max = x_max_2
                x_min = x_min_2
                gaze_x = 1 - gaze_x
                eye_x = 1 - eye_x

            # Random color change
            if np.random.random_sample() <= 0.5:
                img = TF.adjust_brightness(img, brightness_factor=np.random.uniform(0.5, 1.5))
                img = TF.adjust_contrast(img, contrast_factor=np.random.uniform(0.5, 1.5))
                img = TF.adjust_saturation(img, saturation_factor=np.random.uniform(0, 1.5))

        if cone_mode=='early':
            cone_resolution = input_resolution
        else:
            cone_resolution = input_resolution // 4

        head_box = [x_min, y_min, x_max, y_max]
        head = img.crop((head_box)) # head crop
        head_channel = get_head_box_channel(head_box[0], head_box[1], head_box[2], head_box[3], width, height,
                                                    resolution=input_resolution).unsqueeze(0)

        # modality dropout
        height, width = np.int32(height), np.int32(width)
        num_modalities = 3
        dropped = np.zeros(num_modalities)

        if not self.test:
            if modality_dropout:
                # keep one modality
                modality_idx = 0
                m_keep = np.random.randint(modality_idx, num_modalities)

                if (np.random.rand() <= 0.2) and m_keep!=0:
                    img = Image.fromarray(np.uint8(np.random.rand(height, width, 3)*255))
                    dropped[0] = 1
                if (np.random.rand() <= 0.2) and m_keep!=1:
                    depth = Image.fromarray(np.uint8(np.random.rand(height, width)*255))
                    dropped[1] = 1
                if (np.random.rand() <= 0.2) and m_keep!=2:
                    pose = Image.fromarray(np.uint8(np.random.rand(height, width, 3)*255))
                    dropped[2] = 1

        # generate new gaze field (for human-centric branch)
        eye_point = np.array([eye_x, eye_y])
        gaze = np.array([gaze_x, gaze_y])
        gt_direction = np.array([-1.0, -1.0])
        if gaze_inside:
            gt_direction = gaze - eye_point
            # Guard on the vector length, not its mean: e.g. [a, -a] has mean 0
            # but a non-zero length and must still be normalized.
            direction_norm = np.linalg.norm(gt_direction)
            if direction_norm > 0:
                gt_direction = gt_direction / direction_norm

        gaze_field = generate_data_field(eye_point, width=cone_resolution, height=cone_resolution)
        # normalize each field vector to unit length (floor of 0.1 avoids blow-up near the eye)
        field_norm = np.sqrt(np.sum(gaze_field ** 2, axis=0)).reshape([1, cone_resolution, cone_resolution])
        field_norm = np.maximum(field_norm, 0.1)
        gaze_field /= field_norm

        head = self.get_transform(head) # transform inputs
        img = self.get_transform(img)
        pose = self.transform_modality(pose)
        depth = self.transform_modality(depth)
        # Note: depth was saved in 8 bits (not 16 bits), so no /65535 rescaling is applied.

        # generate the heat map used for deconv prediction
        gaze_heatmap = torch.zeros(self.output_size, self.output_size)  # set the size of the output

        if gaze_x != -1:
            gaze_heatmap = draw_labelmap(gaze_heatmap, [gaze_x * self.output_size, gaze_y * self.output_size],3,type='Gaussian')
        if gaze_inside:
            cont_gaze = [gaze_x, gaze_y]
        else:
            cont_gaze = [-1, -1]
        cont_gaze = torch.FloatTensor(cont_gaze)

        if self.imshow:
            # Debug visualisation: overlay the target heatmap and the head mask on the
            # original frame. cv2.resize takes (width, height), which is PIL's .size order.
            # After a training-time crop the overlays are in the crop's frame, so they
            # only line up exactly in test mode.
            frame_size = frame_org.size
            fig = plt.figure(111)
            plt.imshow(frame_org)
            plt.imshow(cv2.resize(gaze_heatmap.numpy(), frame_size), cmap='jet', alpha=0.3)
            plt.imshow(cv2.resize(1 - head_channel.squeeze(0).numpy(), frame_size), alpha=0.2)
            plt.savefig('viz_aug.png')
            plt.close(fig)  # avoid accumulating open figures across samples

        # for some images seg mask is None to avoid error give a dummy value
        if path_seg_mask is None:
            path_seg_mask = 0
        if self.test:
            return img, head, pose, depth, gaze_field, gt_direction, head_channel, gaze_heatmap, cont_gaze, imsize, gaze_inside, path, path_seg_mask
        else:
            return img, head, pose, depth, gaze_field, gt_direction, head_channel, gaze_heatmap, path, gaze_inside, dropped

In [None]:

# Sanity check: pull a single validation sample and visualise every encoded input.


def _show_chw(batched_tensor, channel=None):
    """Show item 0 of a batched CHW tensor as an HWC image, or one of its channels."""
    hwc = batched_tensor.squeeze(0).numpy().transpose(1, 2, 0)
    plt.imshow(hwc if channel is None else hwc[:, :, channel])
    plt.show()


transform = _get_transform()
transform_modality = _get_transform_modality()

if flag_run_on_goo_dataset:
    val_dataset = GooRealDataset(df_test, transform, transform_modality,
                                 input_size=input_resolution, output_size=output_resolution,
                                 test=True, modality='attention', imshow=False)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=1,
                                             shuffle=True, num_workers=0)
    encoded_inputs = next(iter(val_loader))
    (img, face, pose, depth, gaze_field, gt_direction, head_channel, gaze_heatmap,
     cont_gaze, imsize, path, eye_point, all_object_bboxes, all_object_bboxes_class,
     gazed_object_bbox, gazed_object_class) = encoded_inputs

    for modality_tensor in (img, depth, pose, face):
        _show_chw(modality_tensor)
    for field_channel in (0, 1):
        _show_chw(gaze_field, field_channel)
    print(gt_direction)
    _show_chw(head_channel)
    plt.imshow(gaze_heatmap.squeeze(0).numpy())
    plt.show()
    for value in (cont_gaze, imsize, path, eye_point,
                  all_object_bboxes.shape, all_object_bboxes_class.shape,
                  gazed_object_bbox, gazed_object_class):
        print(value)

else:
    ds = retailGazeDataset(df_test, transform, transform_modality,
                           input_size=input_resolution, output_size=output_resolution,
                           test=True, imshow=False)
    dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=0)
    encoded_inputs = next(iter(dl))
    (img, head, pose, depth, gaze_field, gt_direction, head_channel, gaze_heatmap,
     cont_gaze, imsize, gaze_inside, path, path_seg_mask) = encoded_inputs

    for modality_tensor in (img, head, pose, depth):
        _show_chw(modality_tensor)
    for field_channel in (0, 1):
        _show_chw(gaze_field, field_channel)
    print(gt_direction)
    _show_chw(head_channel)
    plt.imshow(gaze_heatmap.squeeze(0).numpy())
    plt.show()
    for value in (cont_gaze, imsize, path):
        print(value)

In [None]:
def test(model_weights, val_loader, batch_size=48, device=0, mode='dict', save_path=None):
    """Evaluate a gaze-target model on ``val_loader``, print and return summary metrics.

    Args:
        model_weights: checkpoint path (``mode='pt'``) or an already-loaded checkpoint
            dict (``mode='dict'``). The dict is read for ``'model'``, ``'modality'``,
            ``'cone_mode'``, ``'pred_inout'`` and, for non-attention models,
            ``'backbone_name'``.
        val_loader: loader yielding the 16-element test-mode tuple unpacked in the loop
            below (image, face, pose, depth, ..., gazed_object_bbox, gazed_object_class).
        batch_size: unused; kept for backward compatibility (the loader sets the batch size).
        device: device passed to ``.to``.
        mode: ``'pt'`` or ``'dict'`` (see ``model_weights``).
        save_path: if given, directory where ``output_gazefollow.pkl`` is written.

    Returns:
        ``(AUC, min_dist, avg_dist, avg_ang, object_acc_%)``, plus the bbox-head
        object accuracy (%) as a sixth element when ``extended_model_present``.

    Raises:
        ValueError: if ``mode`` is neither ``'pt'`` nor ``'dict'``.
    """
    # Load model
    print("Constructing model")
    if mode == 'pt':
        # torch.load unpickles the file: only load trusted checkpoints.
        pretrained_dict = torch.load(model_weights)
    elif mode == 'dict':
        pretrained_dict = model_weights
    else:
        raise ValueError(f"Unknown mode {mode!r}; expected 'pt' or 'dict'")

    if extended_model_present:
        if pretrained_dict['modality'] == 'attention':
            model_base = AttentionModelCombined(cone_mode=pretrained_dict['cone_mode'], pred_inout=pretrained_dict['pred_inout'])
        else:
            model_base = BaselineModel(pretrained_dict['backbone_name'], pretrained_dict['modality'], cone_mode=pretrained_dict['cone_mode'], pred_inout=pretrained_dict['pred_inout'])
        model_base.cuda().to(device)
        # wrap with the bbox head; its extra output (val_pred_wh) is decoded by get_bboxes_ttfnet below
        model = attentionModelBboxHead(model_base)
    else:
        if pretrained_dict['modality'] == 'attention':
            model = AttentionModelCombined(cone_mode=pretrained_dict['cone_mode'], pred_inout=pretrained_dict['pred_inout'])
        else:
            model = BaselineModel(pretrained_dict['backbone_name'], pretrained_dict['modality'], cone_mode=pretrained_dict['cone_mode'], pred_inout=pretrained_dict['pred_inout'])

    model.cuda().to(device)
    # start from the model's own state so keys absent from the checkpoint keep their init values
    model_dict = model.state_dict()
    model_dict.update(pretrained_dict['model'])
    model.load_state_dict(model_dict)

    print('Evaluation in progress ...')
    model.train(False)
    gt_gaze = []; pred_hm = []; image_size = [] ; paths = []; pred_att = []; directions = []
    gt_eye_point = []
    all_object_bboxes_list = []
    all_object_bboxes_class_list = []
    gazed_object_bbox_list = []
    gazed_object_class_list = []
    all_result_list_ttfnet = []
    with torch.no_grad():
        for val_batch, (val_img, val_face, val_pose, val_depth, val_gaze_field, val_gt_direction, val_head_channel, val_gaze_heatmap, cont_gaze, imsize, path, eye_point, all_object_bboxes, all_object_bboxes_class, gazed_object_bbox, gazed_object_class) in tqdm(enumerate(val_loader), total=len(val_loader)):

            val_images = val_img.cuda().to(device)
            val_faces = val_face.cuda().to(device)
            val_head_channels = val_head_channel.cuda().to(device)
            val_gaze_fields = val_gaze_field.cuda().to(device)
            val_depth_maps = val_depth.cuda().to(device)
            val_pose_maps = val_pose.cuda().to(device)
            val_gt_direction = val_gt_direction.cuda().to(device)

            # choose input modality
            if pretrained_dict['modality'] == 'image':
                model_input = val_images
            elif pretrained_dict['modality'] == 'pose':
                model_input = val_pose_maps
            elif pretrained_dict['modality'] == 'depth':
                model_input = val_depth_maps
            elif pretrained_dict['modality'] == 'attention':
                model_input = [val_images, val_depth_maps, val_pose_maps]
            if pretrained_dict['modality'] == 'attention':
                if extended_model_present:
                    val_gaze_heatmap_pred, direction, val_inout_pred, val_att, val_pred_wh = model(model_input, val_faces, val_gaze_fields, val_head_channels)
                else:
                    val_gaze_heatmap_pred, direction, val_inout_pred, val_att = model(model_input, val_faces, val_gaze_fields, val_head_channels)
                pred_att.extend(val_att.cpu().numpy())
            else:
                # NOTE(review): val_pred_wh is only produced by the attention branch; an
                # extended non-attention model would fail in get_bboxes_ttfnet below — confirm.
                val_gaze_heatmap_pred, direction, val_inout_pred = model(model_input, val_faces, val_gaze_fields, val_head_channels)
            val_gaze_heatmap_pred = val_gaze_heatmap_pred.squeeze(1)

            gt_gaze.extend(cont_gaze)
            gt_eye_point.extend(eye_point)
            pred_hm.extend(val_gaze_heatmap_pred.cpu().numpy())
            image_size.extend(imsize)
            paths.extend(path)
            directions.extend(direction.cpu().numpy())
            all_object_bboxes_list.extend(all_object_bboxes.cpu().numpy().astype(int))
            all_object_bboxes_class_list.extend(all_object_bboxes_class.cpu().numpy().astype(int))

            gazed_object_bbox_list.extend(gazed_object_bbox.cpu().numpy().astype(int))
            gazed_object_class_list.extend(gazed_object_class.cpu().numpy().astype(int))
            if extended_model_present:
                result_list_ttfnet = get_bboxes_ttfnet(val_gaze_heatmap_pred,val_pred_wh)
                # Only image 0 of the batch is kept (one entry per batch), so the
                # per-sample metrics assume batch_size == 1.
                if torch.numel(result_list_ttfnet[0][0]) > 0:
                    all_result_list_ttfnet.extend([result_list_ttfnet[0][0].cpu().numpy()])
                else:
                    # placeholder row when no box is predicted
                    all_result_list_ttfnet.extend([np.expand_dims(np.array([-1,-1,-1,-1,-1]), axis=0)])

            if val_batch % 100 == 0:
                # Periodic qualitative check on item 0 of the batch.
                path = str(path[0])
                print(path)
                imsize = imsize[0].int().numpy()
                gt_object_bbox = gazed_object_bbox.squeeze(0).cpu().numpy().astype(int)
                print(gt_object_bbox)
                if extended_model_present:
                    print(all_result_list_ttfnet[-1])  # the entry appended for this batch
                    print(torch.numel(result_list_ttfnet[0][0])/5)

                img1 = np.array(Image.open(os.path.join(image_dir, path)).convert('RGB').resize((640,480)))
                # cv2.rectangle draws in place, so the prediction panel needs its own copy
                img2 = img1.copy()
                print(f'In vs Out: {val_inout_pred.cpu().numpy()}')
                if pred_att:
                    print(f'Attention weights image | depth | pose: {val_att[:, 0].cpu().numpy()} | {val_att[:, 1].cpu().numpy()} | {val_att[:, 2].cpu().numpy()}')
                fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(20, 15))
                ax1.axis('off')
                ax2.axis('off')
                ax1.set_title('GT',size=24,fontweight="bold")
                ax1.imshow(img1)
                ax1.imshow(cv2.rectangle(img1, (gt_object_bbox[0],gt_object_bbox[1]), (gt_object_bbox[2],gt_object_bbox[3]), (0,255,0), 2))
                ax1.imshow(cv2.resize(val_gaze_heatmap.squeeze(0).cpu().numpy(),(WIDTH,HEIGHT)), cmap='jet', alpha=0.5)

                ax2.set_title('Pred',size=24,fontweight="bold")
                ax2.imshow(img2)
                if extended_model_present and torch.numel(result_list_ttfnet[0][0]) > 0:
                    pred_boxes = result_list_ttfnet[0][0].cpu().numpy()
                    for index in range(int(torch.numel(result_list_ttfnet[0][0])/5)):
                        cur_box = pred_boxes[index]
                        # rescale the predicted box from heatmap resolution to image size
                        pred_object_box_ttfnet = torch.tensor([cur_box[0]*imsize[0]/output_resolution,
                                                            cur_box[1]*imsize[1]/output_resolution,
                                                            cur_box[2]*imsize[0]/output_resolution,
                                                            cur_box[3]*imsize[1]/output_resolution]).int().numpy()
                        ax2.imshow(cv2.rectangle(img2, (pred_object_box_ttfnet[0],pred_object_box_ttfnet[1]), (pred_object_box_ttfnet[2],pred_object_box_ttfnet[3]), (255,0,0), 2))
                ax2.imshow(cv2.resize(val_gaze_heatmap_pred.cpu().numpy().transpose(1,2,0), (WIDTH, HEIGHT)), cmap='jet', alpha=0.5)
                plt.show()
                plt.close(fig)

    if extended_model_present:
        (AUC, min_dist, avg_dist, avg_ang, count_matching_object_class,
         count_matching_object_class_ap_50, avg_ang_hm, avg_pred_hm_energy_in_gt_bbox,
         avg_pred_hm_energy_in_gt_cat) = compute_metrics_extend(pred_hm, gt_gaze, image_size, gt_eye_point, all_object_bboxes_list, all_object_bboxes_class_list, gazed_object_class_list, all_result_list_ttfnet, directions, gazed_object_bbox_list)
    else:
        AUC, min_dist, avg_dist, avg_ang, count_matching_object_class = compute_metrics(pred_hm, gt_gaze, image_size, gt_eye_point, all_object_bboxes_list, all_object_bboxes_class_list, gazed_object_class_list)

    if save_path is not None:
        output = {}
        if pretrained_dict['modality'] == 'attention':
            output['pred_att'] = pred_att
        output['pred_hm'] = pred_hm
        output['gt_gaze'] = gt_gaze
        output['gazed_object_bbox'] = gazed_object_bbox_list
        output['paths'] = paths
        output['AUC'] = AUC
        output['min_dist'] = min_dist
        output['avg_dist'] = avg_dist
        output['pred_direction'] = directions
        output['avg_ang'] = avg_ang
        # avg_ang_hm and the energy metrics exist only for the extended model
        if extended_model_present:
            output['avg_ang_hm'] = avg_ang_hm
        output['count_matching_object_class'] = count_matching_object_class
        if extended_model_present:
            output['avg_pred_hm_energy_in_gt_bbox'] = avg_pred_hm_energy_in_gt_bbox
            output['count_matching_object_class_ap_50'] = count_matching_object_class_ap_50
            output['avg_pred_hm_energy_in_gt_cat'] = avg_pred_hm_energy_in_gt_cat
            output['pred_bboxes_per_img'] = all_result_list_ttfnet
        with open(os.path.join(save_path, 'output_gazefollow.pkl'), 'wb') as fp:
            pickle.dump(output, fp)

    final_AUC = np.mean(AUC)
    final_min_dist = np.mean(min_dist)
    final_avg_dist = np.mean(avg_dist)
    final_avg_ang = np.mean(avg_ang)
    final_count_matching_object_class = np.mean(count_matching_object_class) * 100
    if extended_model_present:
        final_avg_ang_hm = np.mean(avg_ang_hm)
        final_count_matching_object_class_ap_50 = np.mean(count_matching_object_class_ap_50) * 100
        final_avg_pred_hm_energy_in_gt_bbox = np.mean(avg_pred_hm_energy_in_gt_bbox)
        final_avg_pred_hm_energy_in_gt_cat = np.mean(avg_pred_hm_energy_in_gt_cat)

    if pred_att:
        avg_attention_weights = [sum(x) / len(x) for x in zip(*pred_att)]
        print(f'Avg. Attention weights image | depth | pose: {avg_attention_weights[0]} | {avg_attention_weights[1]} | {avg_attention_weights[2]}')

    if extended_model_present:
        print("\tAUC:{:.4f}\t min dist:{:.4f}\t avg dist:{:.4f}\t avg ang:{:.4f}\t Object prediction Acc. (%):{:.4f}\t BBoX head Object prediction Acc. (%):{:.4f} \t avg ang hm:{:.4f} \t avg_pred_hm_energy_in_gt_bbox (%): {:.4f} \t avg_pred_hm_energy_in_gt_cat (%): {:.4f}".format(
          final_AUC,
          final_min_dist,
          final_avg_dist,
          final_avg_ang,
          final_count_matching_object_class,
          final_count_matching_object_class_ap_50,
          final_avg_ang_hm,
          final_avg_pred_hm_energy_in_gt_bbox,
          final_avg_pred_hm_energy_in_gt_cat))

        return final_AUC, final_min_dist, final_avg_dist, final_avg_ang, final_count_matching_object_class, final_count_matching_object_class_ap_50

    print("\tAUC:{:.4f}\t min dist:{:.4f}\t avg dist:{:.4f}\t avg ang:{:.4f} \t Object prediction Acc.:{:.4f}".format(
          final_AUC,
          final_min_dist,
          final_avg_dist,
          final_avg_ang,
          final_count_matching_object_class))

    return final_AUC, final_min_dist, final_avg_dist, final_avg_ang, final_count_matching_object_class

def compute_metrics(pred_hm, gt_gaze, image_size, gt_eye_point, all_object_bboxes_list, all_object_bboxes_class_list, gazed_object_class_list, eps = 1e-8):
    """Per-sample gaze-following metrics for the base (non bbox-head) model.

    For each sample: AUC of the heatmap resized to the image size; min and average L2
    distance between the heatmap arg-max (normalized by ``output_resolution``) and the
    ground-truth gaze point(s); the angle (degrees) between the ground-truth and
    predicted gaze directions from the eye point; and whether the object chosen from the
    heatmap energy has the ground-truth class.

    Args:
        pred_hm: predicted heatmaps, one 2-D array per sample (presumably
            ``output_resolution`` x ``output_resolution`` — confirm).
        gt_gaze, gt_eye_point: per-sample tensors; entries equal to -1 are treated as padding.
        image_size: per-sample ``(width, height)`` tensors.
        all_object_bboxes_list, all_object_bboxes_class_list, gazed_object_class_list:
            per-sample object annotations forwarded to the object-matching helper.
        eps: unused; kept for backward compatibility.

    Returns:
        ``(AUC, min_dist, avg_dist, avg_ang, count_matching_object_class)`` as numpy arrays.
    """
    AUC = []; min_dist = []; avg_dist = []
    avg_ang = []
    count_matching_object_class = []
    # go through each data point and record AUC, min dist, avg dist
    for b_i in tqdm(range(len(gt_gaze))):
        # remove padding and recover valid ground truth points
        valid_gaze = gt_gaze[b_i]
        valid_gaze = valid_gaze[valid_gaze != -1].view(-1,2)

        valid_eye_point = gt_eye_point[b_i]
        valid_eye_point = valid_eye_point[valid_eye_point != -1].view(-1,2)

        # AUC: area under curve of ROC
        pm = pred_hm[b_i]
        multi_hot = multi_hot_targets(gt_gaze[b_i], image_size[b_i])
        scaled_heatmap = cv2.resize(pm, (image_size[b_i][0].item(), image_size[b_i][1].item()))
        auc_score = evaluation.auc(scaled_heatmap, multi_hot)
        AUC.append(auc_score)
        # min distance: minimum among all possible pairs of <ground truth point, predicted point>
        pred_x, pred_y = evaluation.argmax_pts(pm)
        norm_p = [pred_x/float(output_resolution), pred_y/float(output_resolution)]
        all_distances = []
        for gaze in valid_gaze:
            all_distances.append(evaluation.L2_dist(gaze, norm_p))
        min_dist.append(min(all_distances))
        # average distance: distance between the predicted point and human average point
        mean_gt_gaze = torch.mean(valid_gaze, 0)
        avg_distance = evaluation.L2_dist(mean_gt_gaze, norm_p)
        avg_dist.append(avg_distance)
        mean_gt_gaze_direction = mean_gt_gaze - valid_eye_point
        mean_pred_gaze_direction = torch.tensor(norm_p) - valid_eye_point

        # Float round-off can push a cosine similarity just past ±1, where acos is NaN.
        cos_sim = cos_sim_func(mean_gt_gaze_direction, mean_pred_gaze_direction).clamp(-1.0, 1.0)
        avg_ang.append(torch.rad2deg(torch.acos(cos_sim)).item())

        cur_all_object_bboxes_list = np.array(all_object_bboxes_list[b_i])
        cur_all_object_bboxes_class_list = np.array(all_object_bboxes_class_list[b_i])
        cur_gazed_object_class = np.array(gazed_object_class_list[b_i])
        count_matching_object_class.extend(
            match_object_cat_after_get_predicted_bbox_from_energy(
                scaled_heatmap,
                cur_all_object_bboxes_list,
                cur_all_object_bboxes_class_list,
                cur_gazed_object_class))

    return np.array(AUC), np.array(min_dist), np.array(avg_dist), np.abs(np.array(avg_ang)),  np.array(count_matching_object_class)


def compute_metrics_extend(pred_hm, gt_gaze, image_size, gt_eye_point, all_object_bboxes_list, 
                    all_object_bboxes_class_list, gazed_object_class_list, 
                    all_result_list_ttfnet, directions, gazed_object_bbox_list, eps = 1e-8):
    """Per-sample metrics for the extended (bbox-head) model.

    Adds to the base metrics:
      * ``avg_ang`` uses the model's predicted direction (``directions``) instead of
        the heatmap arg-max direction;
      * ``avg_ang_hm`` is the angle from the heatmap arg-max direction;
      * object-class matching from the regressed boxes (``all_result_list_ttfnet``);
      * percentage of heatmap energy inside the GT box and inside GT-class boxes.

    Args:
        all_result_list_ttfnet: one array of predicted boxes per sample (indexed by
            sample, so the caller must append exactly one entry per sample).
        directions: per-sample predicted 2-D gaze directions.
        gazed_object_bbox_list: per-sample GT box ``[x_min, y_min, x_max, y_max]``.
        eps: added to the denominator of the heatmap-direction cosine similarity.
        Remaining args as in ``compute_metrics``.

    Returns:
        ``(AUC, min_dist, avg_dist, avg_ang, count_matching_object_class,
        count_matching_object_class_ap_50, avg_ang_hm, avg_pred_hm_energy_in_gt_bbox,
        avg_pred_hm_energy_in_gt_cat)`` as numpy arrays.
    """
    AUC = []; min_dist = []; avg_dist = []
    avg_ang = []
    avg_ang_hm = []
    count_matching_object_class = []
    count_matching_object_class_ap_50 = []
    avg_pred_hm_energy_in_gt_bbox = []
    avg_pred_hm_energy_in_gt_cat = []
    # go through each data point and record AUC, min dist, avg dist
    for b_i in tqdm(range(len(gt_gaze))):

        # remove padding and recover valid ground truth points
        valid_gaze = gt_gaze[b_i]
        valid_gaze = valid_gaze[valid_gaze != -1].view(-1,2)

        valid_eye_point = gt_eye_point[b_i]
        valid_eye_point = valid_eye_point[valid_eye_point != -1].view(-1,2)

        valid_pred_direction = torch.tensor(directions[b_i])
        # AUC: area under curve of ROC
        pm = pred_hm[b_i]
        multi_hot = multi_hot_targets(gt_gaze[b_i], image_size[b_i])
        scaled_heatmap = cv2.resize(pm, (image_size[b_i][0].item(), image_size[b_i][1].item()))
        auc_score = evaluation.auc(scaled_heatmap, multi_hot)
        AUC.append(auc_score)
        # min distance: minimum among all possible pairs of <ground truth point, predicted point>
        pred_x, pred_y = evaluation.argmax_pts(pm)
        norm_p = [pred_x/float(output_resolution), pred_y/float(output_resolution)]
        all_distances = []
        for gaze in valid_gaze:
            all_distances.append(evaluation.L2_dist(gaze, norm_p))
        min_dist.append(min(all_distances))
        # average distance: distance between the predicted point and human average point
        mean_gt_gaze = torch.mean(valid_gaze, 0)
        avg_distance = evaluation.L2_dist(mean_gt_gaze, norm_p)
        avg_dist.append(avg_distance)

        mean_gt_gaze_direction = mean_gt_gaze - valid_eye_point
        mean_pred_gaze_direction = torch.tensor(norm_p) - valid_eye_point
        # Float round-off can push a cosine similarity just past ±1, where acos is NaN.
        cos_sim = cos_sim_func(mean_gt_gaze_direction, valid_pred_direction).clamp(-1.0, 1.0)
        avg_ang.append(torch.rad2deg(torch.acos(cos_sim)).item())

        # angle between GT direction and the heatmap arg-max direction
        mean_pred_gaze_direction = mean_pred_gaze_direction.squeeze().numpy()
        mean_gt_gaze_direction = mean_gt_gaze_direction.squeeze().numpy()
        norm_f = (mean_pred_gaze_direction[0] ** 2 + mean_pred_gaze_direction[1] ** 2) ** 0.5
        norm_gt = (mean_gt_gaze_direction[0] ** 2 + mean_gt_gaze_direction[1] ** 2) ** 0.5

        f_cos_sim = (mean_pred_gaze_direction[0] * mean_gt_gaze_direction[0] + 
                     mean_pred_gaze_direction[1] * mean_gt_gaze_direction[1]) / \
                    (norm_gt * norm_f + eps)
        f_cos_sim = np.maximum(np.minimum(f_cos_sim, 1.0), -1.0)
        f_angle = np.arccos(f_cos_sim) * 180 / np.pi
        avg_ang_hm.append(f_angle)

        cur_all_object_bboxes_list = np.array(all_object_bboxes_list[b_i])
        cur_all_object_bboxes_class_list = np.array(all_object_bboxes_class_list[b_i])
        cur_gazed_object_class = np.array(gazed_object_class_list[b_i])
        count_matching_object_class.extend(
            match_object_cat_after_get_predicted_bbox_from_energy(
                scaled_heatmap,
                cur_all_object_bboxes_list, 
                cur_all_object_bboxes_class_list,
                cur_gazed_object_class))

        cur_all_result_list_ttfnet = np.array(all_result_list_ttfnet[b_i])
        count_matching_object_class_ap_50.extend(
            match_object_cat_from_ttffnet_regression(
            cur_all_result_list_ttfnet,
            image_size[b_i],
            cur_all_object_bboxes_list,
            cur_all_object_bboxes_class_list,
            cur_gazed_object_class))

        avg_pred_hm_energy_in_gt_bbox.append(
            get_avg_energy_by_gtbox_predheatmap(
                gazed_object_bbox_list[b_i],
                scaled_heatmap))

        avg_pred_hm_energy_in_gt_cat.append(
            get_avg_energy_in_gtcat_predheatmap(
                cur_all_object_bboxes_list,
                scaled_heatmap,
                cur_gazed_object_class,
                cur_all_object_bboxes_class_list))

    return np.array(AUC), np.array(min_dist), np.array(avg_dist), np.abs(np.array(avg_ang)),  np.array(count_matching_object_class), np.array(count_matching_object_class_ap_50), np.abs(np.array(avg_ang_hm)), np.array(avg_pred_hm_energy_in_gt_bbox), np.array(avg_pred_hm_energy_in_gt_cat)

# Energy-aggregation measure adapted from the GaTector paper.
def get_avg_energy_by_gtbox_predheatmap(box, heatmap, width=640, height=480):
    """Percentage of the predicted heatmap's total energy that lies inside one GT box.

    Args:
        box: ground-truth box ``[xmin, ymin, xmax, ymax]`` with inclusive corners,
            expressed in the same pixel grid as ``heatmap``. The coordinates are
            used directly as slice indices, so they must be integers.
            NOTE(review): an older note said GT boxes are given at 640x480 and asked
            to check a scale factor; the heatmap must be at that resolution — confirm.
        heatmap: 2-D predicted heatmap, indexed ``[row, col]`` i.e. ``[y, x]``.
        width, height: unused; kept for interface compatibility.

    Returns:
        ``100 * energy_in_box / total_energy``, or ``0.`` when the total energy is
        not positive.
    """
    x_lo, y_lo, x_hi, y_hi = box[0], box[1], box[2], box[3]
    # Rows are y and columns are x; "+ 1" makes the max corner inclusive.
    inside = np.sum(heatmap[y_lo:y_hi + 1, x_lo:x_hi + 1])
    total = heatmap.sum()
    return (inside / total) * 100 if total > 0 else 0.

def get_avg_energy_in_gtcat_predheatmap(all_object_bboxes, scaled_pred_heatmap, gazed_object_class, all_object_bboxes_class, width=640, height=480):
    """Percentage of predicted-heatmap energy inside all GT boxes of the gazed category.

    Args:
        all_object_bboxes: array of GT boxes, one ``[xmin, ymin, xmax, ymax]`` row per
            object (inclusive integer corners, same pixel grid as the heatmap).
        scaled_pred_heatmap: 2-D predicted heatmap, indexed ``[y, x]``.
        gazed_object_class: category id of the gazed object.
        all_object_bboxes_class: array of category ids aligned with ``all_object_bboxes``.
        width, height: unused; kept for interface compatibility.

    Returns:
        ``min(100 * energy_in_category_boxes / total_energy, 100)``, or ``0.`` when
        the total energy is not positive. GT boxes may overlap, so the summed
        energy can exceed the total; that is why the result is clamped to 100.
    """
    same_class_idx = np.where(all_object_bboxes_class == gazed_object_class)[0]
    class_boxes = all_object_bboxes[same_class_idx]
    total = scaled_pred_heatmap.sum()
    if not total > 0:
        return 0.

    in_class = 0.
    for b in class_boxes:
        # Rows are y and columns are x; "+ 1" makes the max corner inclusive.
        in_class += np.sum(scaled_pred_heatmap[b[1]:b[3] + 1, b[0]:b[2] + 1])
    return np.minimum((in_class / total) * 100, 100)

# Score every GT box by the mean heatmap energy inside it and take the
# highest-scoring box as the predicted gaze target.

def match_object_cat_after_get_predicted_bbox_from_energy(scaled_pred_heatmap,all_object_bboxes, all_object_bboxes_class, gazed_object_class):
    """Return ``[1.]`` if the GT box with the highest mean energy has the gazed category.

    Args:
        scaled_pred_heatmap: 2-D predicted heatmap, indexed ``[y, x]``, in the same
            pixel grid as the boxes (original image resolution).
        all_object_bboxes: iterable of ``[xmin, ymin, xmax, ymax]`` GT boxes with
            inclusive integer corners.
        all_object_bboxes_class: category ids aligned with ``all_object_bboxes``.
        gazed_object_class: category id of the gazed object.

    Returns:
        ``[1.]`` on a category match, else ``[0.]``. Only boxes whose mean energy
        is strictly greater than 0 can be selected; on ties the first box wins.
        When no box qualifies, ``[0.]`` is returned.
    """
    best_energy = 0.
    found = False
    winner_class = None
    for idx, box in enumerate(all_object_bboxes):
        x_min, y_min, x_max, y_max = box
        area = (x_max + 1 - x_min) * (y_max + 1 - y_min)
        # Rows are y and columns are x; "+ 1" makes the max corner inclusive.
        energy = np.sum(scaled_pred_heatmap[y_min: y_max + 1, x_min: x_max + 1]) / area
        if energy > best_energy:
            best_energy = energy
            found = True
            winner_class = all_object_bboxes_class[idx]

    if not found:
        return [0.]
    return [1.] if int(winner_class) == int(gazed_object_class) else [0.]



'''
# Iterate through all GT Object category and calc total mean energy of this category. 
# Predict the Object category with max total energy
def match_object_cat_after_get_predicted_bbox_from_energy(scaled_pred_heatmap,all_object_bboxes, all_object_bboxes_class, gazed_object_class):
    """
    Use pred heatmap to find the GT object bboxes with maximum energy as gaze target
    """
    max_energy_in_cat = 0.
    pred_box_class = None
    cur_all_object_bboxes = all_object_bboxes
    cur_all_object_bboxes_class = all_object_bboxes_class
    # All Bboxes are in original image resolution
    cur_pred_heatmap = scaled_pred_heatmap
    cur_gazed_object_class = gazed_object_class
    
    #Object categories in the scene is from 1-24
    for cat_ind in range(1,25):
        cat_boxes = all_object_bboxes[np.where(all_object_bboxes_class == cat_ind)[0]]
        total_mean_energy_in_cat = 0.0
        for index in range (len(cat_boxes)):
            xmin, ymin, xmax, ymax = cat_boxes[index]
         
            # axis are flipped in the heatmap
            no_of_pixels_in_box = (xmax+1-xmin) * (ymax+1-ymin)
        
            total_mean_energy_in_cat += np.sum(cur_pred_heatmap[ymin: ymax + 1, xmin: xmax + 1])/(no_of_pixels_in_box)
            
            
        if total_mean_energy_in_cat > max_energy_in_cat:
            max_energy_in_cat = total_mean_energy_in_cat
            pred_box_class = cat_ind

    if pred_box_class is not None:
        if int(pred_box_class) == int(cur_gazed_object_class):
            return [1.] 
        else:
            return [0.]
    else:
        return [0.]
'''

    
def match_object_cat_from_ttffnet_regression(cur_all_result_list_ttfnet, cur_image_size, all_object_bboxes, all_object_bboxes_class, gazed_object_class, iou_thr=0.5):
    """Check whether any regressed (TTFNet-head) box hits a GT object of the gazed category.

    Args:
        cur_all_result_list_ttfnet: rows whose first four entries are a predicted box
            ``[x1, y1, x2, y2]`` on the ``output_resolution`` grid. A row whose first
            coordinate is -1 terminates the list of valid predictions.
        cur_image_size: image size; ``[0]`` rescales x coordinates and ``[1]`` rescales
            y coordinates from the ``output_resolution`` grid to image pixels.
        all_object_bboxes: GT boxes in image pixels, one ``[xmin, ymin, xmax, ymax]`` row each.
        all_object_bboxes_class: category ids aligned with ``all_object_bboxes``.
        gazed_object_class: category id of the gazed object.
        iou_thr: minimum IoU between a predicted box and its best-overlapping GT box
            (0.5 matches the ``*_ap_50`` name used by the caller).

    Returns:
        ``[1.]`` as soon as a predicted box's best-overlapping GT box has the gazed
        category with IoU >= ``iou_thr``; otherwise ``[0.]`` (also when there are
        no GT boxes).
    """
    # GT boxes are the same for every prediction: convert them once.
    gt_boxes = torch.tensor(all_object_bboxes)
    if gt_boxes.numel() == 0:
        # Nothing to match against (argmax over an empty IoU tensor would raise).
        return [0.]

    for cur_box in cur_all_result_list_ttfnet:
        x1, y1, x2, y2 = cur_box[:4]
        # A -1 coordinate marks the end of the valid predictions.
        if x1 == -1:
            break
        # Rescale from the output_resolution grid to image pixels.
        pred_box = torch.tensor([x1 * cur_image_size[0] / output_resolution,
                                 y1 * cur_image_size[1] / output_resolution,
                                 x2 * cur_image_size[0] / output_resolution,
                                 y2 * cur_image_size[1] / output_resolution]).int().unsqueeze(dim=0)
        ious, _wuocs = bbox_overlaps_ttfnet(gt_boxes, pred_box, mode='iou', is_aligned=False)
        # A single predicted box gives exactly one IoU per GT box.
        ious = ious.reshape(-1)
        best_idx = int(torch.argmax(ious))
        best_iou = float(ious[best_idx])
        # Threshold the IoU *value*, not the index of the best box.
        if best_iou >= iou_thr and int(all_object_bboxes_class[best_idx]) == int(gazed_object_class):
            return [1.]

    return [0.]

if __name__ == "__main__" and flag_run_on_goo_dataset:
    # Evaluate one trained checkpoint on the GOO-Real test split (df_test).
    # Checkpoints evaluated so far (results are recorded in the next cell):
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model_epoch_14.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-ttfnet-gaussian_epoch_7.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-ttfnet-gaussian-only-pose-modality_epoch_20.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-ttfnet-complete_2_epoch_12.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-ttfnet-focal_loss_only_epoch_15.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-energy_aggr_loss_epoch_20.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-plus-synth-original-model-energy_aggr_loss_epoch_10.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-energy_wh_loss_epoch_4.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-energy_aggr_loss_plus_focal_loss_epoch_13.pt'
    model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-focal-ciou-regr_epoch_27.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-efficientnet_b2_epoch_32.pt'
    device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Loading Data")

    transform = _get_transform()
    transform_modality = _get_transform_modality()

    val_dataset = GooRealDataset(df_test, transform, transform_modality,
                                 input_size=input_resolution, output_size=output_resolution,
                                 modality=modality, test=True)
    # Evaluation should not depend on sample order. A fixed order (shuffle=False)
    # keeps the saved predictions and any plotted samples reproducible across runs,
    # in line with the cudnn determinism flags set at the top of the notebook.
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2)
    # NOTE(review): not referenced in this cell; kept in case test() reads it as a
    # global — confirm before removing.
    cos_sim_func = nn.CosineSimilarity(dim=1, eps=1e-8)
    # Pass the device selected above instead of a hard-coded 0 (0 means 'cuda:0').
    if extended_model_present:
        final_AUC, final_min_dist, final_avg_dist, final_avg_ang, final_count_matching_object_class, final_count_matching_object_class_ap_50 = test(model_weights, val_loader, batch_size=1, device=device, mode='pt', save_path='/kaggle/working/')

    else:
        final_AUC, final_min_dist, final_avg_dist, final_avg_ang, final_count_matching_object_class = test(model_weights, val_loader, batch_size=1, device=device, mode='pt', save_path='/kaggle/working/')

In [None]:
"""
goo-real-original-model_epoch_14.pt
Avg. Attention weights image | depth | pose: [[[0.44971883]]] | [[[0.35991648]]] | [[[0.19036518]]]
AUC:0.9553	 min dist:0.1160	 avg dist:0.1160	 avg ang:19.0602 Object prediction Acc.:35.6011
"""

"""
goo-real-original-model-ttfnet-gaussian_epoch_7.pt
Avg. Attention weights image | depth | pose: [[[0.5265263]]] | [[[0.31016862]]] | [[[0.16330457]]]
	AUC:0.9210	 min dist:0.1117	 avg dist:0.1117	 avg ang:18.1694 	 Object prediction Acc.:38.4902

Sparse dataset
Avg. Attention weights image | depth | pose: [[[0.5288933]]] | [[[0.2894723]]] | [[[0.18163462]]]
	AUC:0.9206	 min dist:0.1272	 avg dist:0.1272	 avg ang:22.1647 	 Object prediction Acc.:37.4355
"""

"""
goo-real-original-model-ttfnet-gaussian-only-pose-modality_epoch_20.pt

Sparse dataset
AUC:0.8718	 min dist:0.1497	 avg dist:0.1497	 avg ang:23.6209 	 Object prediction Acc.:31.0671
"""

"""
goo-real-original-model-ttfnet-complete_2_epoch_12.pt
Avg. Attention weights image | depth | pose: [[[0.64857745]]] | [[[0.23286861]]] | [[[0.11855373]]]
AUC:0.9849	 min dist:0.1115	 avg dist:0.1115	 avg ang:18.9836	 Object prediction Acc. (%):38.0242	 BBoX head Object prediction Acc. (%):33.4110
Sparse dataset
Avg. Attention weights image | depth | pose: [[[0.68183804]]] | [[[0.19980803]]] | [[[0.11835343]]]
AUC:0.9870	 min dist:0.1256	 avg dist:0.1256	 avg ang:22.4442	 Object prediction Acc. (%):36.2306	 BBoX head Object prediction Acc. (%):30.7229
"""

"""
goo-real-original-model-ttfnet-focal_loss_only_epoch_15.pt
Avg. Attention weights image | depth | pose: [[[0.6077914]]] | [[[0.17469546]]] | [[[0.21751298]]]
AUC:0.9748	 min dist:0.1069	 avg dist:0.1069	 avg ang:20.1690	 Object prediction Acc. (%):41.6123	 BBoX head Object prediction Acc. (%):38.0708 	 avg ang hm:18.3665 	 avg_pred_hm_energy_in_gt_bbox (%): 12.9769

Sparse dataset
Avg. Attention weights image | depth | pose: [[[0.6220978]]] | [[[0.16036293]]] | [[[0.21753924]]]
AUC:0.9842	 min dist:0.1232	 avg dist:0.1232	 avg ang:24.2822	 Object prediction Acc. (%):38.7263	 BBoX head Object prediction Acc. (%):33.3046 	 avg ang hm:21.7696 	 avg_pred_hm_energy_in_gt_bbox (%): 14.6222
"""

"""
goo-real-original-model-energy_aggr_loss_epoch_20.pt
Avg. Attention weights image | depth | pose: [[[0.48699257]]] | [[[0.29781172]]] | [[[0.2151948]]]
AUC:0.6280	 min dist:0.1056	 avg dist:0.1056	 avg ang:19.5412	 Object prediction Acc. (%):43.8024	 BBoX head Object prediction Acc. (%):34.5760 	 avg ang hm:18.0903 	 avg_pred_hm_energy_in_gt_bbox (%): 20.9762
# After non-RELU trained weights and RELU inference 
Avg. Attention weights image | depth | pose: [[[0.4869935]]] | [[[0.29781204]]] | [[[0.21519478]]]
AUC:0.7840	 min dist:0.1056	 avg dist:0.1056	 avg ang:19.5412	 Object prediction Acc. (%):43.8956	 BBoX head Object prediction Acc. (%):34.5760 	 avg ang hm:18.0903 	 avg_pred_hm_energy_in_gt_bbox (%): 20.6925 	 avg_pred_hm_energy_in_gt_cat (%): 44.6343

Train 
Avg. Attention weights image | depth | pose: [[[0.5025198]]] | [[[0.29043576]]] | [[[0.2070445]]]
AUC:0.9966	 min dist:0.0211	 avg dist:0.0211	 avg ang:3.5655	 Object prediction Acc. (%):90.4082	 BBoX head Object prediction Acc. (%):88.3673 	 avg ang hm:2.7669 	 avg_pred_hm_energy_in_gt_bbox (%): 56.2714 	 avg_pred_hm_energy_in_gt_cat (%): 88.5585
    
Sparse dataset
Avg. Attention weights image | depth | pose: [[[0.51403946]]] | [[[0.2729661]]] | [[[0.21299438]]]
AUC:0.6862	 min dist:0.1205	 avg dist:0.1205	 avg ang:23.5692	 Object prediction Acc. (%):39.5009	 BBoX head Object prediction Acc. (%):25.8176 	 avg ang hm:21.2358 	 avg_pred_hm_energy_in_gt_bbox (%): 26.0571
"""

## TODO: check whether an output ReLU was applied to the predicted heatmap (hm)

"""
goo-real-plus-synth-original-model-energy_aggr_loss_epoch_10.pt
Avg. Attention weights image | depth | pose: [[[0.5025191]]] | [[[0.29169896]]] | [[[0.2057821]]]
AUC:0.7544	 min dist:0.1179	 avg dist:0.1179	 avg ang:22.6881	 Object prediction Acc. (%):39.3290	 BBoX head Object prediction Acc. (%):23.2992 	 avg ang hm:21.0754 	 avg_pred_hm_energy_in_gt_bbox (%): nan

GOO Synth test dataset (the one that was used for training) is used for evaluation
Avg. Attention weights image | depth | pose: [[[0.3271901]]] | [[[0.36095533]]] | [[[0.31185308]]]
AUC:0.7144	 min dist:0.1610	 avg dist:0.1610	 avg ang:18.5566	 Object prediction Acc. (%):30.1927	 BBoX head Object prediction Acc. (%):24.5104 	 avg ang hm:41.6436 	 avg_pred_hm_energy_in_gt_bbox (%): 19.0564 	 avg_pred_hm_energy_in_gt_cat (%): 27.9017
    
"""

"""
goo-real-original-model-energy_wh_loss_epoch_4.pt
Avg. Attention weights image | depth | pose: [[[0.5299074]]] | [[[0.29873222]]] | [[[0.17135933]]]
AUC:0.7459	 min dist:0.1104	 avg dist:0.1104	 avg ang:19.6288	 Object prediction Acc. (%):42.8705	 BBoX head Object prediction Acc. (%):42.1715 	 avg ang hm:18.8043 	 avg_pred_hm_energy_in_gt_bbox (%): 21.0518 	 avg_pred_hm_energy_in_gt_cat (%): 44.8470
"""

"""

Avg. Attention weights image | depth | pose: [[[0.50555205]]] | [[[0.2576175]]] | [[[0.2368302]]]
AUC:0.9558	 min dist:0.1102	 avg dist:0.1102	 avg ang:24.2152	 Object prediction Acc. (%):40.0746	 BBoX head Object prediction Acc. (%):6.6636 	 avg ang hm:20.9180 	 avg_pred_hm_energy_in_gt_bbox (%): 16.2313 	 avg_pred_hm_energy_in_gt_cat (%): 40.9557
"""

"""
ttfnet bbox prediction head inp -> ximg
goo-real-original-model-focal-ciou-regr_epoch_27.pt
Avg. Attention weights image | depth | pose: [[[0.52287644]]] | [[[0.2664772]]] | [[[0.21064633]]]
AUC:0.9801	 min dist:0.1093	 avg dist:0.1093	 avg ang:22.8425	 Object prediction Acc. (%):40.9133	 BBoX head Object prediction Acc. (%):41.0531 	 avg ang hm:20.0835 	 avg_pred_hm_energy_in_gt_bbox (%): 14.1286 	 avg_pred_hm_energy_in_gt_cat (%): 37.3826

# Object prediction Acc. -> computed with the product category that has the max total mean energy
# BBoX head Object prediction Acc. -> a hit if any predicted BBox (at most 5 BBoxes per image,
# keeping all predicted BBoxes with conf > 0.1) matches a GT BBox
Avg. Attention weights image | depth | pose: [[[0.522876]]] | [[[0.2664773]]] | [[[0.21064635]]]
AUC:0.9801	 min dist:0.1093	 avg dist:0.1093	 avg ang:22.8425	 Object prediction Acc. (%):39.7950	 BBoX head Object prediction Acc. (%):61.4632 	 avg ang hm:20.0835 	 avg_pred_hm_energy_in_gt_bbox (%): 14.1286 	 avg_pred_hm_energy_in_gt_cat (%): 37.3826

# Test dataset -> GOO Inf
# Object prediction Acc. -> computed with the BBox that has the max total mean energy
# BBoX head Object prediction Acc. -> a hit if any predicted BBox (at most 3 BBoxes per image,
# keeping all predicted BBoxes with conf > 0.2) matches a GT BBox
Avg. Attention weights image | depth | pose: [[[0.5275465]]] | [[[0.25919157]]] | [[[0.21326162]]]
AUC:0.9809	 min dist:0.1060	 avg dist:0.1060	 avg ang:20.0947	 Object prediction Acc. (%):43.9682	 BBoX head Object prediction Acc. (%):53.2762 	 avg ang hm:17.9558 	 avg_pred_hm_energy_in_gt_bbox (%): 14.7562 	 avg_pred_hm_energy_in_gt_cat (%): 39.720

Train Acc
Avg. Attention weights image | depth | pose: [[[0.5368456]]] | [[[0.2633405]]] | [[[0.19981423]]]
AUC:0.9993	 min dist:0.0216	 avg dist:0.0216	 avg ang:4.7542	 Object prediction Acc. (%):89.3878	 BBoX head Object prediction Acc. (%):89.8367 	 avg ang hm:3.0586 	 avg_pred_hm_energy_in_gt_bbox (%): 35.1186 	 avg_pred_hm_energy_in_gt_cat (%): 71.9576
"""

In [None]:
def _build_retailgaze_model(pretrained_dict, device):
    """Instantiate the network described by a checkpoint dict and move it to ``device``.

    Uses ``AttentionModelCombined`` for the 'attention' modality and ``BaselineModel``
    otherwise. When the global ``extended_model_present`` is set, the network is
    wrapped in ``attentionModelBboxHead``.
    """
    if pretrained_dict['modality'] == 'attention':
        core = AttentionModelCombined(cone_mode=pretrained_dict['cone_mode'],
                                      pred_inout=pretrained_dict['pred_inout'])
    else:
        core = BaselineModel(pretrained_dict['backbone_name'], pretrained_dict['modality'],
                             cone_mode=pretrained_dict['cone_mode'],
                             pred_inout=pretrained_dict['pred_inout'])
    if extended_model_present:
        core.cuda().to(device)
        model = attentionModelBboxHead(core)
    else:
        model = core
    model.cuda().to(device)
    return model


def _select_retailgaze_model_input(modality, images, depth_maps, pose_maps):
    """Return the network input for ``modality``; raise ValueError for unknown modalities."""
    if modality == 'image':
        return images
    if modality == 'pose':
        return pose_maps
    if modality == 'depth':
        return depth_maps
    if modality == 'attention':
        return [images, depth_maps, pose_maps]
    raise ValueError(f"Unknown modality: {modality!r}")


def _plot_retailgaze_sample(path, path_seg_mask, gt_heatmap, pred_heatmap, inout_pred, att,
                            image_dir='/kaggle/input/retailgaze/RetailGaze_V2_seg/RetailGaze_V2/'):
    """Show the GT heatmap (left) and the predicted heatmap (right) over the source image.

    ``path`` / ``path_seg_mask`` are collated batch entries; only element 0 is used
    (the caller runs with batch size 1). ``att`` is the attention-weight tensor, or
    None for single-modality models. Uses the global ``WIDTH`` / ``HEIGHT`` as the
    display size of the heatmap overlays.
    """
    sample_path = str(path[0])
    print(sample_path)
    # Context manager closes the underlying file handle.
    with Image.open(os.path.join(image_dir, sample_path)) as pil_img:
        img = np.array(pil_img.convert('RGB'))

    print(f'In vs Out: {inout_pred.cpu().numpy()}')
    if att is not None:
        print(f'Attention weights image | depth | pose: {att[:, 0].cpu().numpy()} | {att[:, 1].cpu().numpy()} | {att[:, 2].cpu().numpy()}')

    img_mask = None
    if path_seg_mask:
        mask_path = str(path_seg_mask[0])
        print(mask_path)
        img_mask = cv2.imread(os.path.join(image_dir, mask_path))
        if img_mask is not None:
            # OpenCV loads BGR; matplotlib expects RGB.
            img_mask = cv2.cvtColor(img_mask, cv2.COLOR_BGR2RGB)

    gt_overlay = cv2.resize(gt_heatmap.squeeze(0).cpu().numpy(), (WIDTH, HEIGHT))
    pred_overlay = cv2.resize(pred_heatmap.cpu().numpy().transpose(1, 2, 0), (WIDTH, HEIGHT))

    fig, axes = plt.subplots(1, 2, figsize=(20, 15))
    for ax, title, overlay in zip(axes, ('GT', 'Pred'), (gt_overlay, pred_overlay)):
        ax.axis('off')
        ax.set_title(title, size=24, fontweight="bold")
        ax.imshow(img)
        if img_mask is not None:
            ax.imshow(img_mask, alpha=0.25)
        ax.imshow(overlay, cmap='jet', alpha=0.25)
    plt.show()
    # Release the figure; this is called repeatedly inside the evaluation loop.
    plt.close(fig)


def test(model_weights, val_loader=None, device=None, batch_size=1, mode='dict', save_path=None):
    """Evaluate a gaze-target checkpoint on a RetailGaze loader and print summary metrics.

    NOTE(review): this redefines the ``test`` name used by the GOO evaluation cell;
    re-running that cell after this one calls this version instead.

    Args:
        model_weights: checkpoint path (``mode='pt'``) or an already loaded checkpoint
            dict (``mode='dict'``). The dict is read for 'modality', 'cone_mode',
            'pred_inout', 'model' and, for non-attention models, 'backbone_name'.
        val_loader: loader yielding 13-tuples (img, face, pose, depth, gaze_field,
            gt_direction, head_channel, gaze_heatmap, cont_gaze, imsize, gaze_inside,
            path, path_seg_mask).
        device: passed to ``.to()`` after each ``.cuda()`` call.
        batch_size: unused; kept for interface compatibility.
        mode: 'pt' or 'dict', see ``model_weights``.
        save_path: if not None, directory where ``output_retailgaze.pkl`` is written.

    Returns:
        ``(final_AUC, final_distance, final_ap, final_avg_ang_dist)``.

    Raises:
        ValueError: unknown ``mode`` or checkpoint modality.
    """
    plot_figure = True
    plot_every = 50  # plot one sample every `plot_every` batches
    print("Constructing model")
    if mode == 'pt':
        # torch.load unpickles the file: only load trusted checkpoints.
        pretrained_dict = torch.load(model_weights)
    elif mode == 'dict':
        pretrained_dict = model_weights
    else:
        raise ValueError(f"mode must be 'pt' or 'dict', got {mode!r}")
    modality = pretrained_dict['modality']

    model = _build_retailgaze_model(pretrained_dict, device)
    # Overlay checkpoint weights on the freshly initialised state dict, so keys
    # absent from the checkpoint keep their initial values.
    model_dict = model.state_dict()
    model_dict.update(pretrained_dict['model'])
    model.load_state_dict(model_dict)

    print('Evaluation in progress ...')
    model.train(False)
    gt_gaze, pred_hm, gt_hm, image_size, paths, pred_att = [], [], [], [], [], []
    directions, gt_directions = [], []
    in_vs_out_groundtruth, in_vs_out_pred = [], []
    with torch.no_grad():
        for count, (val_img, val_face, val_pose, val_depth, val_gaze_field, val_gt_direction,
                    val_head_channel, val_gaze_heatmap, cont_gaze, imsize, gaze_inside,
                    path, path_seg_mask) in enumerate(tqdm(val_loader, total=len(val_loader)), start=1):
            val_images = val_img.cuda().to(device)
            val_faces = val_face.cuda().to(device)
            val_head_channels = val_head_channel.cuda().to(device)
            val_gaze_fields = val_gaze_field.cuda().to(device)
            val_depth_maps = val_depth.cuda().to(device)
            val_pose_maps = val_pose.cuda().to(device)
            val_gt_direction = val_gt_direction.cuda().to(device)
            gt_hm.extend(val_gaze_heatmap)

            model_input = _select_retailgaze_model_input(modality, val_images, val_depth_maps, val_pose_maps)
            val_att = None
            if modality == 'attention':
                if extended_model_present:
                    # The bbox-head wrapper returns a fifth output (presumably box
                    # width/height); it is not used here.
                    val_gaze_heatmap_pred, direction, val_inout_pred, val_att, _ = model(
                        model_input, val_faces, val_gaze_fields, val_head_channels)
                else:
                    val_gaze_heatmap_pred, direction, val_inout_pred, val_att = model(
                        model_input, val_faces, val_gaze_fields, val_head_channels)
                pred_att.extend(val_att.cpu().numpy())
            else:
                val_gaze_heatmap_pred, direction, val_inout_pred = model(
                    model_input, val_faces, val_gaze_fields, val_head_channels)
            val_gaze_heatmap_pred = val_gaze_heatmap_pred.squeeze(1)

            gt_gaze.extend(cont_gaze.cpu().numpy())
            pred_hm.extend(val_gaze_heatmap_pred.cpu().numpy())
            image_size.extend(imsize.cpu().numpy())
            paths.extend(path)
            directions.extend(direction.cpu().numpy())
            gt_directions.extend(val_gt_direction.cpu().numpy())
            # in vs out classification
            in_vs_out_groundtruth.extend(gaze_inside.float().numpy())
            in_vs_out_pred.extend(val_inout_pred.cpu().numpy())

            if plot_figure and count % plot_every == 0:
                _plot_retailgaze_sample(path, path_seg_mask, val_gaze_heatmap,
                                        val_gaze_heatmap_pred, val_inout_pred, val_att)

    AUC, distance, cos_sim = compute_metrics(pred_hm, gt_hm, gt_gaze, directions, gt_directions)
    if save_path is not None:
        # One explicit entry per key.
        output = {
            'pred_att': pred_att,
            'pred_hm': pred_hm,
            'gt_gaze': gt_gaze,
            'paths': paths,
            'direction': directions,
            'gt_inout': in_vs_out_groundtruth,
            'AUC': AUC,
            'gt_directions': gt_directions,
            'pred_inout': in_vs_out_pred,
            'cos_sim': cos_sim,
        }
        with open(os.path.join(save_path, 'output_retailgaze.pkl'), 'wb') as fp:
            pickle.dump(output, fp)

    final_AUC = torch.mean(torch.tensor(AUC))
    final_distance = torch.mean(torch.tensor(distance))
    final_ap = evaluation.ap(in_vs_out_groundtruth, in_vs_out_pred)
    final_avg_ang_dist = np.mean(np.rad2deg(np.arccos(cos_sim)))
    # Attention weights exist only for the 'attention' modality.
    if pred_att:
        avg_attention_weights = [sum(x) / len(x) for x in zip(*pred_att)]
        print(f'Avg. Attention weights image | depth | pose: {avg_attention_weights[0]} | {avg_attention_weights[1]} | {avg_attention_weights[2]}')
    print("\tAUC:{:.4f}\t Avg. L2 dist:{:.4f}\t in vs out AP:{:.4f}\t Avg. Angular Dist.:{:.4f}".format(
          final_AUC,
          final_distance,
          final_ap,
          final_avg_ang_dist))

    return final_AUC, final_distance, final_ap, final_avg_ang_dist


def compute_metrics(pred_hm, gt_hm, gt_gaze, directions, gt_directions, eps=1e-8):
    """Per-sample AUC, L2 distance and direction cosine similarity for in-frame samples.

    Samples whose GT gaze point has mean -1 are skipped; -1 is treated as the
    "not in frame" marker. Before the loop, the number of such samples is printed.
    After the loop, the angle (degrees) of the mean cosine similarity is printed
    as a diagnostic (nan when no sample is in frame).

    Args:
        pred_hm: predicted heatmaps (2-D arrays, one per sample).
        gt_hm: GT heatmaps (torch tensors, one per sample); any value > 0 counts as
            a positive pixel for AUC.
        gt_gaze: GT gaze points ``(x, y)``, normalized by ``output_resolution``.
        directions, gt_directions: predicted / GT gaze direction vectors.
        eps: added to the norm product to avoid division by zero
            (same convention as the GOO metrics).

    Returns:
        ``(AUC, distance, cos_sim)`` as numpy arrays over the in-frame samples.
    """
    AUC, distance, cos_sim = [], [], []
    out_of_frame = [gt_gaze[i].mean() == -1 for i in range(len(gt_gaze))]
    print(np.array(out_of_frame).sum())
    for b_i in tqdm(range(len(gt_hm))):
        if gt_gaze[b_i].mean() == -1:
            continue
        # Binarize the GT heatmap so it can serve as AUC labels.
        multi_hot = misc.to_numpy((gt_hm[b_i] > 0).float() * 1)

        pm = pred_hm[b_i]
        scaled_heatmap = cv2.resize(pm, (output_resolution, output_resolution))
        AUC.append(evaluation.auc(scaled_heatmap, multi_hot))

        # L2 distance between the GT point and the heatmap argmax (normalized coords).
        gaze_x, gaze_y = gt_gaze[b_i]
        pred_x, pred_y = evaluation.argmax_pts(pm)
        norm_p = [pred_x / output_resolution, pred_y / output_resolution]
        distance.append(evaluation.L2_dist([gaze_x, gaze_y], norm_p).item())

        pred_dir, gt_dir = directions[b_i], gt_directions[b_i]
        cos = np.dot(pred_dir, gt_dir) / (norm(pred_dir) * norm(gt_dir) + eps)
        cos_sim.append(np.clip(cos, -1.0, 1.0))

    mean_cos = np.mean(cos_sim) if cos_sim else np.nan
    print(np.rad2deg(np.arccos(mean_cos)))
    return np.array(AUC), np.array(distance), np.array(cos_sim)

if __name__ == "__main__" and not(flag_run_on_goo_dataset):
    # Evaluate one trained checkpoint on the RetailGaze test split (df_test).
    # Checkpoints evaluated so far (results are recorded in the next cell):
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-weights/epoch_13.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-weights/attention-videoatttarget.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-ttfnet-focal_loss_only_epoch_15.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-ttfnet-gaussian_epoch_7.pt'
    #model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-energy_aggr_loss_epoch_20.pt'
    model_weights = '/kaggle/input/mm-gaze-target-prediction-new-weights/goo-real-original-model-focal-ciou-regr_epoch_27.pt'

    device = ('cuda' if torch.cuda.is_available() else 'cpu')
    print("Loading Data")

    transform = _get_transform()
    transform_modality = _get_transform_modality()

    val_dataset = retailGazeDataset(df_test, transform, transform_modality,
                                    input_size=input_resolution, output_size=output_resolution,
                                    test=True, imshow=False)
    # Evaluation should not depend on sample order. A fixed order (shuffle=False)
    # keeps the periodically plotted samples reproducible across runs, in line with
    # the cudnn determinism flags set at the top of the notebook.
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0)

    final_AUC, final_distance, final_ap, final_avg_ang_dist = test(model_weights, val_loader, device=device, batch_size=1, mode='pt')

In [None]:
"""
goo-real-original-model-energy_aggr_loss_epoch_20.pt
64.28667002474117
Avg. Attention weights image | depth | pose: [[[0.43945652]]] | [[[0.3071239]]] | [[[0.2534196]]]
AUC:0.6414	 Avg. L2 dist:0.2216	 in vs out AP:1.0000	 Avg. Angular Dist.:20.8008
"""

"""
goo-real-original-model-focal-ciou-regr_epoch_27.pt
64.69747255415196
Avg. Attention weights image | depth | pose: [[[0.47156763]]] | [[[0.25737855]]] | [[[0.27105403]]]
AUC:0.8151	 Avg. L2 dist:0.2219	 in vs out AP:1.0000	 Avg. Angular Dist.:21.3683
"""