# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os


if not os.path.exists('data'):
    new_directory_path = "..\\..\\"
    os.chdir(new_directory_path)

from torch.utils.data import Dataset
import numpy as np

## My Modules

In [2]:
from src.datasets import CocoFreeView
from src.simulation import gen_gaze, downsample
from src.noise import add_random_center_correlated_radial_noise

# Code
## Data

In [None]:
# TODO Compute image embeddings just once
import math

class PathCocoFreeViewDataset(Dataset):
    '''
    The noisy and downsampled simulated eye-tracking and the section of the scanpath that fits entirely in that part
    '''

    def __init__(self, sample_size=-1, sampling_rate=60, downsample=200):
        super().__init__()
        self.data = CocoFreeView()
        self.sampling_rate = sampling_rate
        # more than 90% generate at least 20 samples after downsampling to 200 samples per second
        self.sample_size = sample_size
        self.downsample = downsample

    def __len__(self):
        return len(self.data)
    
    def __extract_random_period(self, size, noisy_samples, fixations, fixation_mask):
        down_idx = np.random.randint(0, noisy_samples.shape[1] - size + 1, 1, dtype = int)[0]
        # get the values in the original sampling rate
        conversion_factor = self.downsample/(1000/self.sampling_rate)
        ori_idx = math.floor(down_idx*conversion_factor)
        ori_size = math.ceil(size*conversion_factor)
        last_idx = ori_idx + ori_size - 1
        # TEST
        test_final_ori = ori_idx
        test_final_last_idx = last_idx
        # get the fisrt fixation and if it is not completely included get the next one
        if fixation_mask[ori_idx] > 0:
            if (ori_idx - 1) >= 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]:
                start_fixation = fixation_mask[ori_idx] + 1
            else:
                start_fixation = fixation_mask[ori_idx]
        else:
            # if the first value is a saccade look for the first fixation
            current_idx = ori_idx
            while current_idx < (ori_idx + ori_size) and fixation_mask[current_idx] == 0:
                current_idx += 1
            if current_idx == (ori_idx + ori_size):
                # if there is not a fixation return an empty array
                return noisy_samples[:, down_idx:down_idx + size],np.array([])
            else:
                # TEST
                test_final_ori = current_idx
                start_fixation = fixation_mask[current_idx]
        # search the last fixation
        if fixation_mask[last_idx] > 0:
            if (last_idx + 1) < fixation_mask.shape[0] and fixation_mask[last_idx + 1] == fixation_mask[last_idx]:
                end_fixation = fixation_mask[last_idx] - 1
            else:
                end_fixation = fixation_mask[last_idx]
        else:
            current_idx = last_idx
            while current_idx > ori_idx and fixation_mask[current_idx] == 0:
                current_idx -= 1
            test_final_last_idx = current_idx
            end_fixation = fixation_mask[current_idx]
        # the mask are saved shifted in order to assign 0 to the saccade samples
        start_fixation -= 1
        end_fixation -= 1
        x = noisy_samples[:, down_idx:down_idx + size]
        y = fixations[:, start_fixation: end_fixation + 1]
        return x, y,test_final_ori ,test_final_last_idx, start_fixation

    def __getitem__(self, index):
        # TODO save the simulated scanpath without noise due that they can be reused
        gaze, fixations, fixation_mask = gen_gaze(self.data,
                                                        index, self.sampling_rate,
                                                        get_scanpath=True,
                                                        get_fixation_mask=True)
        down_gaze = downsample(gaze, down_time_step=self.downsample)
        noisy_gaze, _ = add_random_center_correlated_radial_noise(down_gaze, [320//2, 512//2], 1/16,
                                                                  radial_corr=.2,
                                                                  radial_avg_norm=4.13,
                                                                  radial_std=3.5,
                                                                  center_noise_std=100,
                                                                  center_corr=.3,
                                                                  center_delta_norm=300,
                                                                  center_delta_r=.3)
        x = down_gaze
        y = fixations
        if self.sample_size != -1:
            x, y,si, ei, start_fixation = self.__extract_random_period(self.sample_size,
                                                noisy_gaze,
                                                fixations,
                                                fixation_mask)
            test_y = y.copy()
            np.concat([np.full((3,1), fill_value=-1), test_y])
            print
            print(si,ei)
            print((gaze[:2,si:ei]))
            print(y[:2])
            
        return x, y

dataset = PathCocoFreeViewDataset(sample_size=4)
idx = np.random.randint(0,len(dataset),size = 1)[0]
dataset[idx]
pass

180 226
[[300.96045 300.99887 301.36194 301.14352 301.74377 319.07584 319.79782
  319.9582  319.87894 320.08444 319.7314  319.3268  319.751   320.031
  319.59302 319.95105 319.76184 319.6237  319.76886 319.64023 319.62302
  319.59106 320.0627  320.03958 319.9849  319.85526 319.9777  319.47397
  320.14484 319.45355 319.63318 319.99423 319.77042 321.05862 321.482
  321.52664 321.57693 321.65085 321.47855 321.31668 321.29333 321.53574
  321.4423  321.42206 321.28174 321.6886 ]
 [207.21892 207.04546 207.6615  206.54132 206.43504 186.51741 185.41223
  185.39552 185.55959 185.7533  185.56427 185.79056 185.50398 185.48561
  185.41002 185.70647 185.59193 185.90562 185.45753 185.29407 185.88875
  186.01297 185.26521 185.36308 185.4343  185.63461 185.85147 185.46892
  185.7002  185.64351 185.6277  185.91058 185.38362 179.83167 177.83615
  178.0057  177.98123 177.93282 178.61916 177.66818 177.76761 177.6017
  177.99258 178.34644 177.89427 178.0784 ]]
[[319.84761905 321.43238095]
 [185.63047619 17