In [None]:
#@title Mount google drive

# Mount Google Drive and change into the project directory.
# Absolute paths keep the cell safe to re-run: a relative `./drive/...` path
# only resolves while the working directory is still `/content`, so a second
# run (already inside the project folder) would fail.
import os

from google.colab import drive

PROJECT_DIR = '/content/drive/MyDrive/gaze_estimation'

# Step out of the Drive mount before force-remounting it, so the working
# directory does not sit inside the filesystem that is being remounted.
os.chdir('/content')
drive.mount('/content/drive', force_remount=True)

os.chdir(PROJECT_DIR)
os.getcwd()

In [None]:
#@title Import required modules

import os
import time
import copy
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm.auto import tqdm
from sklearn.linear_model import LinearRegression


from hglm.precision_module import large_precision_module
from loss.h_likelihood_precision import nhll_correlated_precision
from util import make_reproducibility, TensorDataset, convert_to_xyz, mae
from networks import *
from hglm.hglm_correlated_precision import correlated_precision_without_val

In [None]:
#@title Load preprocessed & subsampled data (LOOCV)

# Per-subject arrays for leave-one-subject-out CV. The LOOCV loop indexes the
# first axis by subject, so all four arrays must agree on that axis.
# DATA_DIR is relative to the project directory entered when Drive is mounted.
DATA_DIR = '../mpii_dataset'


def _load_float_tensor(name):
    """Load `DATA_DIR/<name>.npy` and return it as a float32 torch tensor."""
    return torch.as_tensor(np.load(f'{DATA_DIR}/{name}.npy'), dtype=torch.float)


ids = np.load(f'{DATA_DIR}/loocv_ids.npy')
images = _load_float_tensor('loocv_images')
hps = _load_float_tensor('loocv_2d_hps')
gazes = _load_float_tensor('loocv_2d_gazes')

# Fail early if the subject axis is inconsistent, rather than mid-way through CV.
assert len(ids) == len(images) == len(hps) == len(gazes), \
    'loaded arrays disagree on the number of subjects (first axis)'

In [None]:
# Experiment configuration shared by every LOOCV fold.

# Use the first GPU when one is attached; otherwise fall back to CPU instead of
# failing later when tensors are moved to a CUDA device that does not exist.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print('WARNING: CUDA is not available; training on CPU may be very slow.')

seed = 10                       # base seed; fold `i` is run with SEED = seed + i
experiment_name = 'mpii_loocv'  # fold `i` is tagged f'{experiment_name}_{i}'

# Pretraining schedule (forwarded to correlated_precision_without_val).
pretrain_iter = 1
m_pretrain_epoch = 50
v_pretrain_epoch = 20

# Main training-loop settings (forwarded to correlated_precision_without_val).
max_iter = 150
mean_epoch = 10
v_step_iter = 100
patience = 10

# Optimisation settings.
mean_lr = 5e-3
variance_lr = 1e-3
batch_size = 1000
weight_decay = 0

# Model size and evaluation options.
hidden_features = 500
test_unseen = True
large_test = True

In [None]:
res_list = list()  # collects one result per LOOCV fold, appended in fold order

In [None]:
# Leave-one-subject-out cross-validation: for each subject (first axis of the
# per-subject arrays), train on every other subject and test on the held-out one.
# `res_list` is rebuilt here so re-running this cell does not append a second
# set of fold results after those left over from an earlier run.
res_list = []
n_subjects = len(ids)  # one fold per subject
assert n_subjects == len(images) == len(hps) == len(gazes), \
    'per-subject arrays disagree on the number of subjects'

for looid in tqdm(range(n_subjects), desc='LOOCV folds'):
    # Training data: all subjects except `looid`, flattened to one sample axis
    # (images are reshaped to 36x60, head poses and gazes to 2 values each).
    train_ids = np.concatenate([ids[:looid], ids[looid + 1:]]).reshape(-1)
    train_images = torch.cat([images[:looid], images[looid + 1:]]).reshape(-1, 36, 60)
    train_hps = torch.cat([hps[:looid], hps[looid + 1:]]).reshape(-1, 2)
    train_gazes = torch.cat([gazes[:looid], gazes[looid + 1:]]).reshape(-1, 2)

    # Test data: the held-out subject only.
    test_ids = ids[looid]
    test_images = images[looid]
    test_hps = hps[looid]
    test_gazes = gazes[looid]

    res_list.append(correlated_precision_without_val(
        train_ids, train_images, train_hps, train_gazes,
        test_ids, test_images, test_hps, test_gazes,
        ResNet_batchnorm.ResNet_batchnorm, hidden_features=hidden_features, K=2,
        mean_lr=mean_lr, variance_lr=variance_lr, weight_decay=weight_decay, batch_size=batch_size,
        pretrain_iter=pretrain_iter, m_pretrain_epoch=m_pretrain_epoch, v_pretrain_epoch=v_pretrain_epoch, max_iter=max_iter, mean_epoch=mean_epoch, v_step_iter=v_step_iter, patience=patience,
        device=device, experiment_name=f'{experiment_name}_{looid}', SEED=seed + looid,
        normalize=True, deg=True, test_unseen=test_unseen, weighted=True, verbose=False, large_test=large_test))


