# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os


if not os.path.exists('data'):
    new_directory_path = "..\\..\\"
    os.chdir(new_directory_path)

from torch.utils.data import Dataset
import numpy as np
import math
import h5py
from tqdm import tqdm

## My Modules

In [2]:
from src.datasets import CocoFreeView
from src.simulation import gen_gaze, downsample
from src.noise import add_random_center_correlated_radial_noise

# Code
## Utils

In [3]:
def search(mask, fixation, side = 'right'):
    start = 0
    stop = mask.shape[0]
    step = 1
    if side == 'left':
        start = mask.shape[0] - 1 
        stop = -1
        step = -1

    for i in range(start,stop, step):
        if mask[i] == fixation:
            return i
    return -1

def test_segment_is_inside(x, si,ei,gaze, fixation_mask):
    sidx = search(fixation_mask, si + 1, side = 'right')
    eidx = search(fixation_mask, ei + 1, side = 'left')
    if sidx == -1:
        print(f'❌ Start Fixation not found: si:{si + 1} \n {fixation_mask}')
    if eidx == -1:
        print(f'❌ End Fixation not found: si:{si + 1} \n {fixation_mask}')
    # print(fixation_mask.shape)
    # print(si)
    # print(ei)
    # print(sidx)
    # print(eidx)
    if x[2,0] <= gaze[2,sidx] and (x[2,-1] + 200) >= gaze[2,eidx]:
        print(f'✅Pass: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')
    else:
        print(f'❌Outside: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')

## Dataset

In [4]:
# TODO Compute image embeddings just once
# TODO Refactor the dataset class (TOO LARGE)
# TODO Review the outputs in the validation and with some tests
class PathCocoFreeViewDatasetBatch(Dataset):
    '''
    The noisy and downsampled simulated eye-tracking and the section of the scanpath that fits entirely in that part
    '''

    def __init__(self,
                 data_path = 'data\\Coco FreeView',
                 sample_size=-1,
                 sampling_rate=60,
                 downsample=200,
                 batch_size=128,
                 min_scanpath_duration = 3000,
                 max_fixation_duration = 1200,
                 log = False,
                 debug = False):
        super().__init__()
        self.sampling_rate = sampling_rate
        self.sample_size = sample_size # 90% larger than 20 at downsample 200
        self.downsample = downsample
        self.min_scanpath_duration = min_scanpath_duration
        self.max_fixation_duration = max_fixation_duration
        self.log = log
        self.batch_size = batch_size
        self.debug = debug
        self.ori_path = os.path.join(data_path, 'dataset.hdf5')
        self.ori_data = None        
        if not os.path.exists(self.ori_path):
            if self.log:
                print('gen data file not found')
            self.__preprocess()
        self.shuffled_path = self.ori_path.replace('.hdf5', '_shuffled.hdf5')
        self.shuffled_data = None
        self.json_dataset = None
        


    def __len__(self):
        with h5py.File(self.ori_path,'r') as ori_data:
            return ori_data['down_gaze'].shape[0]//self.batch_size
    
    def sample_count(self):
        with h5py.File(self.ori_path,'r') as ori_data:
            return ori_data['down_gaze'].shape[0]


    def __preprocess(self):
        # TODO gen all the clean samples
        data = self.json_dataset
        gen_data = []
        original_data_count = len(data)
        for index in range(original_data_count):
            gaze, fixations, fixation_mask = gen_gaze(data,
                                                        index, self.sampling_rate,
                                                        get_scanpath=True,
                                                        get_fixation_mask=True)
            down_gaze = downsample(gaze, down_time_step=self.downsample)
            if (gaze[2,-1] < max(self.min_scanpath_duration, (self.sample_size - 1)*self.downsample) or
                fixations[2].max() > self.max_fixation_duration) :
                continue
            gen_data.append({'down_gaze': down_gaze,
                             'fixations': fixations,
                             'fixation_mask': fixation_mask,
                             'gaze': gaze})
        if self.log:
            removed = original_data_count - len(gen_data)
            print(f'Removed: {removed} - {(removed/original_data_count)*100}% ')
        # save
        with h5py.File(self.ori_path, 'w') as f:
            # Create datasets with shape (43000,) and the vlen dtype
            item = gen_data[0]
            k_dset = dict()
            for k in item.keys():
                k_dset[k] = f.create_dataset(k, len(gen_data), dtype= h5py.special_dtype(vlen= item[k].dtype))


            # Loop and store each item
            for i, item in enumerate(gen_data):
                for k in item.keys():
                    if item[k].ndim > 1:
                        k_dset[k][i] = item[k].flatten()
                    else:
                        k_dset[k][i] = item[k]
        
        if self.log:
            print('generated items saved')


    def __extract_random_period(self, size, noisy_samples, fixations, fixation_mask):
        down_idx = np.random.randint(0, noisy_samples.shape[1] - size + 1, 1, dtype = int)[0]
        # get the values in the original sampling rate
        conversion_factor = self.downsample/(1000/self.sampling_rate)
        ori_idx = math.floor(down_idx*conversion_factor)
        ori_size = math.ceil((size - 1)*conversion_factor)
        last_idx = ori_idx + ori_size
        # TEST
        # get the fisrt fixation and if it is not completely included get the next one
        if fixation_mask[ori_idx] > 0:
            if (ori_idx - 1) >= 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]:
                start_fixation = fixation_mask[ori_idx] + 1
            else:
                start_fixation = fixation_mask[ori_idx]
        else:
            # if the first value is a saccade look for the first fixation
            current_idx = ori_idx + 1
            while current_idx < (ori_idx + ori_size) and fixation_mask[current_idx] == 0:
                current_idx += 1
            if current_idx == (ori_idx + ori_size):
                # if there is not a fixation return an empty array
                return noisy_samples[:, down_idx:down_idx + size],np.array([]), -1,-1
            else:
                # TEST
                start_fixation = fixation_mask[current_idx]
        # search the last fixation
        if fixation_mask[last_idx] > 0:
            if (last_idx + 1) < fixation_mask.shape[0] and fixation_mask[last_idx + 1] == fixation_mask[last_idx]:
                end_fixation = fixation_mask[last_idx] - 1
            else:
                end_fixation = fixation_mask[last_idx]
        else:
            current_idx = last_idx - 1
            while current_idx > ori_idx and fixation_mask[current_idx] == 0:
                current_idx -= 1
            end_fixation = fixation_mask[current_idx]
        # the mask are saved shifted in order to assign 0 to the saccade samples
        start_fixation -= 1
        end_fixation -= 1
        x = noisy_samples[:, down_idx:down_idx + size]
        y = fixations[:, start_fixation: end_fixation + 1]
        return x, y, start_fixation, end_fixation

    def gen_single_item(self, index):
        if self.json_dataset is None:
            self.json_dataset = CocoFreeView()
        data = self.json_dataset
        gaze, fixations, fixation_mask = gen_gaze(data,
                                                        index, self.sampling_rate,
                                                        get_scanpath=True,
                                                        get_fixation_mask=True)
        x = downsample(gaze, down_time_step=self.downsample)
        y = fixations
        if self.sample_size != -1:
            x, y, start_fixation, end_fixation = self.__extract_random_period(self.sample_size,
                                                x,
                                                y,
                                                fixation_mask)
            
        x, _ = add_random_center_correlated_radial_noise(x, [320//2, 512//2], 1/16,
                                                                  radial_corr=.2,
                                                                  radial_avg_norm=4.13,
                                                                  radial_std=3.5,
                                                                  center_noise_std=100,
                                                                  center_corr=.3,
                                                                  center_delta_norm=300,
                                                                  center_delta_r=.3)
        return x, y


    def get_single_item(self, index):
        # reading is inefficient because is reading from memory one by one
        if self.ori_data is None:
            self.ori_data = h5py.File(self.ori_path,'r')
        down_gaze = self.ori_data['down_gaze'][index].reshape((3,-1))
        fixations = self.ori_data['fixations'][index].reshape((3,-1))
        x = down_gaze
        y = fixations
        if self.sample_size != -1:
            fixation_mask = self.ori_data['fixation_mask'][index]
            x, y, start_fixation, end_fixation = self.__extract_random_period(self.sample_size,
                                                x,
                                                fixations,
                                                fixation_mask)
            # if start_fixation != -1:
            #     gaze = self.ori_data['gaze'][index].reshape((3,-1))
            #     test_segment_is_inside(x,start_fixation, end_fixation,gaze, fixation_mask)

        x, _ = add_random_center_correlated_radial_noise(x, [320//2, 512//2], 1/16,
                                                                  radial_corr=.2,
                                                                  radial_avg_norm=4.13,
                                                                  radial_std=3.5,
                                                                  center_noise_std=100,
                                                                  center_corr=.3,
                                                                  center_delta_norm=300,
                                                                  center_delta_r=.3)
        return x, y
    
    def __getitem__(self, index):
        # reading is 3x faster with batch size 128 
        # but can´t use the workers of the torch.dataloader (epoch in 4.9)
        if self.shuffled_data is None:
            self.shuffled_data = h5py.File(self.shuffled_path,'r')
        batch_size = self.batch_size
        down_gaze = self.shuffled_data['down_gaze'][index*batch_size:(index + 1)*batch_size]
        fixations = self.shuffled_data['fixations'][index*batch_size:(index + 1)*batch_size]
        vals = None
        if self.sample_size != -1:
            fixation_mask = self.shuffled_data['fixation_mask'][index*batch_size:(index + 1)*batch_size]
            # gaze = self.data['gaze'][index]
            # vals = (down_gaze,fixations,fixation_mask, gaze)
            vals = (down_gaze,fixations,fixation_mask)
        else:
            vals = (down_gaze,fixations)
        x_batch = []
        y_batch = []
        for value in zip(*vals):
            x = value[0].reshape((3,-1))        
            y = value[1].reshape((3,-1))
            if self.sample_size != -1:
                fixation_mask = value[2]
                x, y, start_fixation, end_fixation = self.__extract_random_period(self.sample_size,
                                                    x,
                                                    y,
                                                    fixation_mask)
                # if start_fixation != -1:
                #     gaze = value[3].reshape((3,-1))
                #     test_segment_is_inside(x,start_fixation, end_fixation,gaze, fixation_mask)

            x_batch.append(x)
            y_batch.append(y)
        x_batch, _ = add_random_center_correlated_radial_noise(x_batch, [320//2, 512//2], 1/16,
                                                                radial_corr=.2,
                                                                radial_avg_norm=4.13,
                                                                radial_std=3.5,
                                                                center_noise_std=100,
                                                                center_corr=.3,
                                                                center_delta_norm=300,
                                                                center_delta_r=.3)
        # self.close_and_remove_data()
        self.shuffled_data.close()
        self.shuffled_data = None
        
        return x_batch, y_batch
    
    def shuffle_dataset(self):
        
        with h5py.File(self.ori_path,'r') as ori_data:
            dataset_names = ['down_gaze', 'fixations', 'fixation_mask', 'gaze']
            if self.log:
                print('reading original data')
            original_data = {name: ori_data[name][:] for name in dataset_names}
        idx = np.arange(original_data['down_gaze'].shape[0])
        np.random.shuffle(idx)
        for name in dataset_names:
            original_data[name] = original_data[name][idx]

        with h5py.File(self.shuffled_path, 'w') as f_out:
            for name, data in original_data.items():
                f_out.create_dataset(
                    name,
                    data=data,
                )
        if self.log:
            print('shuffled data saved')

    def close_and_remove_data(self):
        if self.ori_data is not None:
            # if self.log:
            #     print('closing original data file')
            self.ori_data.close()
            self.ori_data = None
        if self.shuffled_data is not None:
            # if self.log:
            #     print('closing shuffled data file')
            self.shuffled_data.close()
            self.shuffled_data = None
    


# Test


## Load Dataset

In [5]:
dataset = PathCocoFreeViewDatasetBatch(sample_size= 8,log = True)


## Speed Test

In [9]:
for i in tqdm(range(dataset.sample_count())):
    dataset.get_single_item(i)

100%|██████████| 39869/39869 [00:19<00:00, 2000.98it/s]


In [None]:
for i in tqdm(range(dataset.sample_count())):
    try:
        dataset.gen_single_item(i)
    except Exception as e:
        print(f"Error processing item {i}: {e}")

  end_fixation -= 1
  y = fixations[:, start_fixation: end_fixation + 1]
  3%|▎         | 1344/39869 [00:00<00:13, 2758.24it/s]

Error processing item 893: high <= 0


  5%|▌         | 2183/39869 [00:00<00:13, 2772.92it/s]

Error processing item 1682: high <= 0


  7%|▋         | 2736/39869 [00:01<00:13, 2738.32it/s]

Error processing item 2493: high <= 0
Error processing item 2877: high <= 0


  9%|▉         | 3550/39869 [00:01<00:13, 2661.92it/s]

Error processing item 3219: high <= 0
Error processing item 3516: high <= 0


 12%|█▏        | 4644/39869 [00:01<00:13, 2636.43it/s]

Error processing item 4288: high <= 0


 14%|█▎        | 5454/39869 [00:02<00:12, 2655.22it/s]

Error processing item 4966: high <= 0


 17%|█▋        | 6824/39869 [00:02<00:12, 2722.65it/s]

Error processing item 6425: high <= 0
Error processing item 6488: high <= 0
Error processing item 6975: high <= 0


 20%|█▉        | 7931/39869 [00:02<00:11, 2741.62it/s]

Error processing item 7404: high <= 0
Error processing item 7511: high <= 0


 21%|██▏       | 8478/39869 [00:03<00:11, 2702.62it/s]

Error processing item 8009: high <= 0
Error processing item 8283: high <= 0


 23%|██▎       | 9298/39869 [00:03<00:11, 2706.75it/s]

Error processing item 8995: high <= 0
Error processing item 9118: high <= 0


 26%|██▌       | 10387/39869 [00:03<00:11, 2617.98it/s]

Error processing item 10065: high <= 0
Error processing item 10206: high <= 0
Error processing item 10373: high <= 0


 28%|██▊       | 11186/39869 [00:04<00:11, 2600.04it/s]

Error processing item 10762: high <= 0


 33%|███▎      | 13118/39869 [00:04<00:09, 2755.50it/s]

Error processing item 12769: high <= 0
Error processing item 13011: high <= 0
Error processing item 13111: high <= 0
Error processing item 13126: high <= 0


 36%|███▌      | 14251/39869 [00:05<00:09, 2792.31it/s]

Error processing item 13721: high <= 0


 37%|███▋      | 14806/39869 [00:05<00:09, 2707.18it/s]

Error processing item 14562: high <= 0
Error processing item 14819: high <= 0


 39%|███▊      | 15372/39869 [00:05<00:09, 2539.08it/s]

Error processing item 15138: high <= 0
Error processing item 15412: high <= 0


 41%|████▏     | 16463/39869 [00:06<00:08, 2671.99it/s]

Error processing item 15979: high <= 0
Error processing item 16031: high <= 0
Error processing item 16405: high <= 0


 43%|████▎     | 17000/39869 [00:06<00:08, 2663.91it/s]

Error processing item 16587: high <= 0


 45%|████▍     | 17803/39869 [00:06<00:08, 2642.08it/s]

Error processing item 17393: high <= 0


 47%|████▋     | 18622/39869 [00:07<00:07, 2699.55it/s]

Error processing item 18219: high <= 0
Error processing item 18457: high <= 0


 49%|████▊     | 19423/39869 [00:07<00:07, 2566.49it/s]

Error processing item 18995: high <= 0


 51%|█████     | 20215/39869 [00:07<00:07, 2591.94it/s]

Error processing item 19784: high <= 0


 52%|█████▏    | 20740/39869 [00:07<00:07, 2416.80it/s]

Error processing item 20415: high <= 0
Error processing item 20448: high <= 0
Error processing item 20508: high <= 0


 55%|█████▍    | 21790/39869 [00:08<00:07, 2578.08it/s]

Error processing item 21324: high <= 0


 56%|█████▌    | 22309/39869 [00:08<00:06, 2552.05it/s]

Error processing item 21924: high <= 0
Error processing item 21997: high <= 0
Error processing item 22030: high <= 0


 59%|█████▊    | 23397/39869 [00:08<00:06, 2677.98it/s]

Error processing item 23068: high <= 0
Error processing item 23414: high <= 0


 61%|██████▏   | 24499/39869 [00:09<00:05, 2675.75it/s]

Error processing item 24149: high <= 0


 63%|██████▎   | 25033/39869 [00:09<00:05, 2608.86it/s]

Error processing item 24722: high <= 0
Error processing item 24944: high <= 0
Error processing item 25253: high <= 0


 65%|██████▍   | 25849/39869 [00:09<00:05, 2666.58it/s]

Error processing item 25289: high <= 0
Error processing item 25304: high <= 0
Error processing item 25389: high <= 0
Error processing item 25426: high <= 0
Error processing item 25592: high <= 0


 66%|██████▌   | 26116/39869 [00:09<00:05, 2643.45it/s]

Error processing item 25941: high <= 0


 73%|███████▎  | 29152/39869 [00:11<00:03, 2712.18it/s]

Error processing item 28740: high <= 0
Error processing item 28833: high <= 0


 78%|███████▊  | 31103/39869 [00:11<00:03, 2772.68it/s]

Error processing item 30526: high <= 0


 84%|████████▎ | 33354/39869 [00:12<00:02, 2802.07it/s]

Error processing item 32918: high <= 0


 87%|████████▋ | 34489/39869 [00:12<00:01, 2804.63it/s]

Error processing item 33965: high <= 0
Error processing item 34001: high <= 0


 88%|████████▊ | 35045/39869 [00:13<00:01, 2736.13it/s]

Error processing item 34543: high <= 0
Error processing item 35014: high <= 0


 89%|████████▉ | 35595/39869 [00:13<00:01, 2529.94it/s]

Error processing item 35319: high <= 0
Error processing item 35356: high <= 0


 92%|█████████▏| 36664/39869 [00:13<00:01, 2528.80it/s]

Error processing item 36227: high <= 0
Error processing item 36473: high <= 0


 93%|█████████▎| 37200/39869 [00:14<00:01, 2593.84it/s]

Error processing item 36891: high <= 0


 97%|█████████▋| 38581/39869 [00:14<00:00, 2733.68it/s]

Error processing item 38182: high <= 0


 99%|█████████▉| 39420/39869 [00:14<00:00, 2766.59it/s]

Error processing item 38934: high <= 0
Error processing item 39121: high <= 0


100%|██████████| 39869/39869 [00:15<00:00, 2652.27it/s]


## Batch Test

In [6]:
# TODO The multiprocessing dataloader (sharing the hdf5 files through multitable)

dataset.close_and_remove_data()
dataset.shuffle_dataset()
for i in tqdm(range(len(dataset))):
    dataset[i]

reading original data
shuffled data saved


100%|██████████| 311/311 [00:03<00:00, 103.32it/s]


In [7]:
from torch.utils.data import DataLoader

dataset.close_and_remove_data()
dataset.shuffled_data
dataset.close_and_remove_data()
dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=3)
for batch in tqdm(dataloader):
    pass


closing data files
closing data files


  0%|          | 0/311 [00:05<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 2144, 11824, 18660) exited unexpectedly