# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os


# When the notebook is started from a sub-folder (no local 'data' directory),
# move two levels up — presumably to the project root; confirm for your layout.
# os.path.join keeps the relative path portable across operating systems.
if not os.path.exists('data'):
    new_directory_path = os.path.join('..', '..')
    os.chdir(new_directory_path)

from torch.utils.data import Dataset
import numpy as np
import math
import h5py


## My Modules

In [2]:
from src.datasets import CocoFreeView
from src.simulation import gen_gaze, downsample
from src.noise import add_random_center_correlated_radial_noise

# Code
## Data

In [3]:
def search(mask, fixation, side = 'right'):
    """Locate ``fixation`` in the 1-D ``mask``.

    With ``side='left'`` the mask is scanned from the end, so the index of the
    LAST matching element is returned; any other value of ``side`` scans from
    the start and returns the FIRST match.

    Returns:
        The matching index, or -1 when ``fixation`` does not occur in ``mask``.
    """
    length = mask.shape[0]
    if side == 'left':
        candidates = range(length - 1, -1, -1)
    else:
        candidates = range(length)
    return next((pos for pos in candidates if mask[pos] == fixation), -1)

def test_segment_is_inside(x, si,ei,gaze, fixation_mask, tolerance=200):
    """Debug check: print whether fixations ``si..ei`` fall inside the span of window ``x``.

    Row 2 of the extracted (downsampled) window ``x`` is compared with row 2 of
    the original-rate ``gaze`` at the first sample of fixation ``si`` and the last
    sample of fixation ``ei``. Row 2 is presumably the timestamp row — TODO confirm
    against ``src.simulation``.

    Args:
        x: extracted window, indexed as ``x[2, 0]`` / ``x[2, -1]``.
        si: 0-based index of the first fixation expected inside the window.
        ei: 0-based index of the last fixation expected inside the window.
        gaze: original-rate gaze, indexed as ``gaze[2, sample]``.
        fixation_mask: per-sample fixation index shifted by +1 (0 marks saccade
            samples), which is why ``si + 1`` / ``ei + 1`` are searched.
        tolerance: slack added to the window's last value before comparing with
            the end of fixation ``ei`` (default 200, the previously hard-coded value).

    Returns:
        True when the fixations are inside the window, False otherwise
        (including when a fixation cannot be found in ``fixation_mask``).
    """
    # First sample of the start fixation and last sample of the end fixation.
    sidx = search(fixation_mask, si + 1, side = 'right')
    eidx = search(fixation_mask, ei + 1, side = 'left')
    if sidx == -1:
        print(f'❌ Start Fixation not found: si:{si + 1} \n {fixation_mask}')
    if eidx == -1:
        print(f'❌ End Fixation not found: ei:{ei + 1} \n {fixation_mask}')
    if sidx == -1 or eidx == -1:
        # An index of -1 would silently read the last gaze sample below and
        # produce a meaningless Pass/Outside verdict, so stop here.
        return False
    inside = bool(x[2,0] <= gaze[2,sidx] and (x[2,-1] + tolerance) >= gaze[2,eidx])
    if inside:
        print(f'✅Pass: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')
    else:
        print(f'❌Outside: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')
    return inside

In [35]:
# TODO Compute image embeddings just once

class PathCocoFreeViewDataset(Dataset):
    '''
    The noisy and downsampled simulated eye-tracking and the section of the scanpath that fits entirely in that part.

    Each item is a pair ``(x, y)``:
      * ``x`` -- the downsampled simulated gaze (reshaped to 3 rows) with random
        center-correlated radial noise added; when ``sample_size != -1`` only a
        random window of ``sample_size`` downsampled samples is returned.
      * ``y`` -- the scanpath fixations (reshaped to 3 rows); when a window is
        extracted, only the fixations lying completely inside it (possibly none,
        in which case ``y`` has 3 rows and 0 columns).

    The simulated data is generated once from ``CocoFreeView`` and cached in
    ``<data_path>/dataset.hdf5``.
    '''

    def __init__(self,data_path = os.path.join('data', 'Coco FreeView'), sample_size=-1, sampling_rate=60, downsample=200, min_scanpath_duration = 3000, max_fixation_duration = 1200, log = False, debug = False):
        """
        Args:
            data_path: directory holding (or that will receive) ``dataset.hdf5``.
            sample_size: downsampled samples per returned window; -1 returns the
                whole scanpath.
            sampling_rate: rate passed to ``gen_gaze``; ``1000 / sampling_rate``
                is used as the original sample period (presumably Hz with ms
                timestamps — confirm).
            downsample: ``down_time_step`` passed to ``downsample``.
            min_scanpath_duration: items whose last row-2 gaze value is below this
                (or below the span of one window) are dropped in preprocessing.
            max_fixation_duration: items with a row-2 fixation value above this
                are dropped in preprocessing.
            log: print preprocessing progress.
            debug: run the ``test_segment_is_inside`` check on every windowed item.
        """
        super().__init__()
        self.sampling_rate = sampling_rate
        self.sample_size = sample_size # 90% larger than 20 at downsample 200
        self.downsample = downsample
        self.min_scanpath_duration = min_scanpath_duration
        self.max_fixation_duration = max_fixation_duration
        self.log = log
        self.debug = debug
        self.h5_path = os.path.join(data_path, 'dataset.hdf5')
        if not os.path.exists(self.h5_path):
            if self.log:
                print('gen data file not found')
                print('generating data')
            self.__preprocess()
        # NOTE(review): the handle stays open for the dataset's lifetime; h5py
        # handles are generally not safe to share across forked DataLoader
        # workers — consider opening lazily per worker if num_workers > 0.
        self.data = h5py.File(self.h5_path,'r')

    def __len__(self):
        """Number of cached items (length of the ``down_gaze`` HDF5 dataset)."""
        # TODO change len for new data structure
        return self.data['down_gaze'].shape[0]

    def __preprocess(self):
        """Simulate gaze for every CocoFreeView item, filter it and cache it to ``self.h5_path``.

        Items whose last row-2 gaze value is below ``max(min_scanpath_duration,
        (sample_size - 1) * downsample)`` or whose fixations contain a row-2 value
        above ``max_fixation_duration`` are skipped. Every kept array is stored
        flattened in a variable-length HDF5 dataset named after its key.

        Raises:
            ValueError: if no item survives the filtering (nothing to cache).
        """
        # TODO gen all the clean samples
        data = CocoFreeView()
        gen_data = []
        original_data_count = len(data)
        for index in range(original_data_count):
            gaze, fixations, fixation_mask = gen_gaze(data,
                                                        index, self.sampling_rate,
                                                        get_scanpath=True,
                                                        get_fixation_mask=True)
            down_gaze = downsample(gaze, down_time_step=self.downsample)
            # Row 2 is compared with duration thresholds: presumably gaze timestamps
            # and fixation durations respectively — confirm in src.simulation.
            if (gaze[2,-1] < max(self.min_scanpath_duration, (self.sample_size - 1)*self.downsample) or
                fixations[2].max() > self.max_fixation_duration) :
                continue
            gen_data.append({'down_gaze': down_gaze,
                             'fixations': fixations,
                             'fixation_mask': fixation_mask,
                             'gaze': gaze})
        if self.log:
            removed = original_data_count - len(gen_data)
            percent = (removed/original_data_count)*100 if original_data_count else 0.0
            print(f'Removed: {removed} - {percent}% ')
        if not gen_data:
            raise ValueError(f'No samples left after filtering; nothing to save to {self.h5_path}')
        # Write to a temporary file and rename at the end, so an interrupted or
        # failed run never leaves a partial cache that __init__ would later treat
        # as complete.
        tmp_path = self.h5_path + '.tmp'
        with h5py.File(tmp_path, 'w') as f:
            # One 1-D variable-length dataset per key, one row per kept item;
            # the element dtype is taken from the first item.
            first = gen_data[0]
            k_dset = {k: f.create_dataset(k, (len(gen_data),), dtype=h5py.vlen_dtype(first[k].dtype))
                      for k in first}
            for i, item in enumerate(gen_data):
                for k, value in item.items():
                    # vlen rows are 1-D: multi-dimensional arrays are flattened
                    # here and reshaped to 3 rows in __getitem__.
                    k_dset[k][i] = value.flatten() if value.ndim > 1 else value
        os.replace(tmp_path, self.h5_path)

        if self.log:
            print('generated items saved')

    def __extract_random_period(self, size, noisy_samples, fixations, fixation_mask):
        """Cut a random window of ``size`` downsampled samples and the fixations fully inside it.

        Args:
            size: window length in downsampled samples.
            noisy_samples: downsampled gaze, windowed along axis 1.
            fixations: scanpath with fixations along axis 1.
            fixation_mask: per original-rate sample, the fixation index shifted by
                +1 (0 marks saccade samples).

        Returns:
            ``(x, y, start_fixation, end_fixation)`` with 0-based inclusive fixation
            indices, or ``(x, fixations[:, :0], -1, -1)`` when no fixation lies
            completely inside the window.
        """
        down_idx = np.random.randint(0, noisy_samples.shape[1] - size + 1, 1, dtype = int)[0]
        x = noisy_samples[:, down_idx:down_idx + size]
        # Empty result keeps the same number of rows as a non-empty y.
        no_fixation = (x, fixations[:, :0], -1, -1)
        # Number of original-rate samples per downsampled sample.
        conversion_factor = self.downsample/(1000/self.sampling_rate)
        ori_idx = math.floor(down_idx*conversion_factor)
        ori_size = math.ceil((size - 1)*conversion_factor)
        # Clamp so a window reaching the end of the recording cannot index past the mask.
        last_idx = min(ori_idx + ori_size, fixation_mask.shape[0] - 1)
        # Find the first fixation; if it started before the window it is only
        # partially included, so take the next one.
        if fixation_mask[ori_idx] > 0:
            if (ori_idx - 1) >= 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]:
                start_fixation = fixation_mask[ori_idx] + 1
            else:
                start_fixation = fixation_mask[ori_idx]
        else:
            # The window starts on a saccade: look forward for the first fixation.
            current_idx = ori_idx + 1
            while current_idx < last_idx and fixation_mask[current_idx] == 0:
                current_idx += 1
            if current_idx >= last_idx:
                # `>=` (not `==`) also covers size == 1, where the scan starts past last_idx.
                return no_fixation
            start_fixation = fixation_mask[current_idx]
        # Find the last fixation; if it continues after the window, take the previous one.
        if fixation_mask[last_idx] > 0:
            if (last_idx + 1) < fixation_mask.shape[0] and fixation_mask[last_idx + 1] == fixation_mask[last_idx]:
                end_fixation = fixation_mask[last_idx] - 1
            else:
                end_fixation = fixation_mask[last_idx]
        else:
            current_idx = last_idx - 1
            while current_idx > ori_idx and fixation_mask[current_idx] == 0:
                current_idx -= 1
            end_fixation = fixation_mask[current_idx]
        if end_fixation < start_fixation:
            # e.g. the whole window lies inside one fixation: none is fully included.
            return no_fixation
        # The mask is stored shifted by +1 so that saccade samples can be 0.
        start_fixation -= 1
        end_fixation -= 1
        y = fixations[:, start_fixation: end_fixation + 1]
        return x, y, start_fixation, end_fixation

    def __getitem__(self, index):
        """Return ``(x, y)`` for item ``index``; see the class docstring."""
        down_gaze = self.data['down_gaze'][index].reshape((3,-1))
        fixations = self.data['fixations'][index].reshape((3,-1))
        x = down_gaze
        y = fixations
        if self.sample_size != -1:
            fixation_mask = self.data['fixation_mask'][index]
            x, y, start_fixation, end_fixation = self.__extract_random_period(self.sample_size,
                                                x,
                                                fixations,
                                                fixation_mask)
            # The consistency check prints for every sample and needs an extra
            # HDF5 read, so it only runs when the dataset was built with debug=True.
            if self.debug and start_fixation != -1:
                gaze = self.data['gaze'][index].reshape((3,-1))
                test_segment_is_inside(x, start_fixation, end_fixation, gaze, fixation_mask)

        # [320//2, 512//2] is presumably the image center (half of a 320x512
        # stimulus) — confirm against the noise model.
        x, _ = add_random_center_correlated_radial_noise(x, [320//2, 512//2], 1/16,
                                                                  radial_corr=.2,
                                                                  radial_avg_norm=4.13,
                                                                  radial_std=3.5,
                                                                  center_noise_std=100,
                                                                  center_corr=.3,
                                                                  center_delta_norm=300,
                                                                  center_delta_r=.3)
        return x, y
    


In [36]:
# Full scanpaths (no windowing); builds the HDF5 cache on first use and logs the filtering stats.
dataset = PathCocoFreeViewDataset(log=True, sample_size=-1)


In [37]:
from tqdm import tqdm

# Smoke test: fetch every item once (HDF5 read, noise, and windowing when
# sample_size != -1) so any read/processing error surfaces here.
for i in tqdm(range(len(dataset))):
    dataset[i]

100%|██████████| 39869/39869 [00:15<00:00, 2630.05it/s]
