# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os


# If there is no ./data folder here (e.g. the notebook is run from a
# sub-directory), move two levels up, presumably to the project root so that
# `data/` and the `src` package can be found — confirm against repo layout.
# os.path.join keeps this portable: a literal "..\\..\\" only works on Windows.
if not os.path.exists('data'):
    new_directory_path = os.path.join("..", "..")
    os.chdir(new_directory_path)

from torch.utils.data import Dataset
import numpy as np

## My Modules

In [2]:
from src.datasets import CocoFreeView
from src.simulation import gen_gaze, downsample
from src.noise import add_random_center_correlated_radial_noise

# Code
## Data

In [None]:
# TODO Compute image embeddings just once
import math

class PathCocoFreeViewDataset(Dataset):
    '''
    Noisy, downsampled simulated eye-tracking paired with its scanpath.

    Each item is ``(x, y)``. ``x`` is the simulated gaze of one
    ``CocoFreeView`` entry after downsampling and random noise. ``y`` is the
    scanpath (fixations) from the same simulation. When ``sample_size`` is not
    -1, ``x`` is a random window of ``sample_size`` downsampled samples, and
    ``y`` keeps only the fixations that lie entirely inside that window.

    Args:
        sample_size: number of downsampled samples per window, or -1 to
            return the whole sequence.
        sampling_rate: rate passed to ``gen_gaze``. The window conversion
            treats it as samples per second (original period = 1000 / rate).
        downsample: passed to ``downsample`` as ``down_time_step``. The window
            conversion treats it as a time step in milliseconds.
    '''

    def __init__(self, sample_size=-1, sampling_rate=60, downsample=200):
        super().__init__()
        self.data = CocoFreeView()
        self.sampling_rate = sampling_rate
        # Author's note: more than 90% of the items yield at least 20 samples
        # after downsampling with the default time step of 200.
        self.sample_size = sample_size
        self.downsample = downsample

    def __len__(self):
        return len(self.data)

    def __extract_random_period(self, size, noisy_samples, fixations, fixation_mask):
        '''
        Pick a random window of ``size`` downsampled samples and the fixations inside it.

        Args:
            size: window length in downsampled samples.
            noisy_samples: downsampled gaze, indexed as ``[:, t]``.
            fixations: scanpath with one fixation per column (``[:, k]``).
            fixation_mask: one entry per original-rate sample. 0 marks a
                saccade sample; any other value is the 1-based index of the
                fixation the sample belongs to.

        Returns:
            ``(x, y, ori_idx)``:
            - ``x`` is ``noisy_samples[:, down_idx:down_idx + size]``.
            - ``y`` holds the columns of ``fixations`` whose samples all fall
              inside the window. It has zero columns (same row count) when
              there are none.
            - ``ori_idx`` is the original-rate index where the window starts.

        Raises:
            ValueError: if ``noisy_samples`` has fewer than ``size`` samples.
        '''
        n_down = noisy_samples.shape[1]
        if n_down < size:
            raise ValueError(
                f'Cannot extract a window of {size} samples from a sequence of {n_down}'
            )
        down_idx = int(np.random.randint(0, n_down - size + 1))
        x = noisy_samples[:, down_idx:down_idx + size]
        # Zero-column result, so "no fixation" keeps the same 2-D layout as y.
        empty = fixations[:, :0]

        # Map the window onto the original sampling rate: one downsampled step
        # spans `downsample` ms, one original sample spans 1000/sampling_rate ms.
        # NOTE(review): assumes downsampled sample 0 aligns with original sample 0.
        conversion_factor = self.downsample / (1000 / self.sampling_rate)
        ori_idx = math.floor(down_idx * conversion_factor)
        ori_size = math.ceil(size * conversion_factor)
        # Exclusive end of the window. It is clamped so that rounding up near
        # the end of the sequence cannot index past the mask.
        ori_end = min(ori_idx + ori_size, fixation_mask.shape[0])
        if ori_idx >= ori_end:
            return x, empty, ori_idx
        last_idx = ori_end - 1

        # First fixation fully contained in the window.
        if fixation_mask[ori_idx] > 0:
            if ori_idx > 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]:
                # The fixation under the first sample began before the window.
                start_fixation = fixation_mask[ori_idx] + 1
            else:
                start_fixation = fixation_mask[ori_idx]
        else:
            # The window starts on a saccade: advance to the first fixation sample.
            current_idx = ori_idx
            while current_idx < ori_end and fixation_mask[current_idx] == 0:
                current_idx += 1
            if current_idx == ori_end:
                # No fixation at all inside the window.
                return x, empty, ori_idx
            start_fixation = fixation_mask[current_idx]

        # Last fixation fully contained in the window.
        if fixation_mask[last_idx] > 0:
            if (last_idx + 1 < fixation_mask.shape[0]
                    and fixation_mask[last_idx + 1] == fixation_mask[last_idx]):
                # The fixation under the last sample continues after the window.
                end_fixation = fixation_mask[last_idx] - 1
            else:
                end_fixation = fixation_mask[last_idx]
        else:
            # The window ends on a saccade: walk back to the last fixation sample.
            current_idx = last_idx
            while current_idx > ori_idx and fixation_mask[current_idx] == 0:
                current_idx -= 1
            end_fixation = fixation_mask[current_idx]

        # Mask values are 1-based (0 is reserved for saccades), so fixation k
        # lives in column k - 1. The slice is empty when start > end.
        y = fixations[:, start_fixation - 1:end_fixation]
        return x, y, ori_idx

    def __getitem__(self, index):
        '''
        Simulate, downsample and add noise to the gaze of item ``index``.

        Returns:
            ``(x, y)``: the noisy downsampled gaze and its fixations. When
            ``sample_size`` is not -1, ``x`` is a random window of the noisy
            gaze and ``y`` the fixations fully contained in it.
        '''
        # TODO save the simulated scanpath without noise, since it can be reused
        gaze, fixations, fixation_mask = gen_gaze(self.data,
                                                  index, self.sampling_rate,
                                                  get_scanpath=True,
                                                  get_fixation_mask=True)
        down_gaze = downsample(gaze, down_time_step=self.downsample)
        # Noise centre [320//2, 512//2]: presumably the middle of the stimulus
        # image — confirm against the dataset's image size.
        noisy_gaze, _ = add_random_center_correlated_radial_noise(down_gaze, [320//2, 512//2], 1/16,
                                                                  radial_corr=.2,
                                                                  radial_avg_norm=4.13,
                                                                  radial_std=3.5,
                                                                  center_noise_std=100,
                                                                  center_corr=.3,
                                                                  center_delta_norm=300,
                                                                  center_delta_r=.3)
        if self.sample_size == -1:
            # Return the noisy gaze, consistent with the windowed branch below.
            return noisy_gaze, fixations
        x, y, _ = self.__extract_random_period(self.sample_size,
                                               noisy_gaze,
                                               fixations,
                                               fixation_mask)
        return x, y

# Smoke test: build the dataset with 4-sample windows and display one random item.
dataset = PathCocoFreeViewDataset(sample_size=4)
n_items = len(dataset)
idx = np.random.randint(0, n_items, size=1)[0]
dataset[idx]

number


(array([[ 250.33891,  247.37654,  252.8097 ,  461.75394],
        [ 188.30936,  184.91881,  176.77838,   59.3595 ],
        [1816.6666 , 2016.6666 , 2216.6667 , 2416.6667 ]], dtype=float32),
 array([[219.33714286, 230.24761905, 359.58857143],
        [209.00571429, 201.2647619 , 135.10095238],
        [ 80.        , 192.        , 208.        ]]))