# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os


# When the kernel is not started next to the 'data' folder, move two directories up
# (presumably the project root, so that 'data' and the `src` package resolve).
if not os.path.exists('data'):
    # os.path.join keeps the relative path valid on every OS; the previous literal
    # "..\\..\\" is only a valid path on Windows.
    new_directory_path = os.path.join(os.pardir, os.pardir)
    os.chdir(new_directory_path)

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import math
import h5py
from tqdm import tqdm

## My Modules

In [2]:
from src.datasets import CocoFreeView
from src.preprocess.simulation import gen_gaze, downsample
from src.preprocess.noise import add_random_center_correlated_radial_noise

# Code
## Utils

In [3]:
def search(mask, fixation, side = 'right'):
    '''
    Find the position of a value inside a 1-D mask.

    Args:
        mask: array-like exposing ``shape[0]`` and integer indexing.
        fixation: value to look for.
        side: 'left' scans from the last position towards the first one (so the
            last occurrence is returned); any other value scans from the first
            position forwards (so the first occurrence is returned).

    Returns:
        The index of the first match in scan order, or -1 when there is none.
    '''
    length = mask.shape[0]
    if side == 'left':
        positions = range(length - 1, -1, -1)
    else:
        positions = range(length)
    return next((pos for pos in positions if mask[pos] == fixation), -1)

def test_segment_is_inside(x, si,ei,gaze, fixation_mask):
    '''
    Debug helper: print whether a downsampled segment covers its selected fixations.

    The first sample of fixation ``si`` and the last sample of fixation ``ei`` are
    located in ``fixation_mask`` (which stores fixation index + 1, hence the ``+ 1``),
    and row 2 of ``x`` is compared against row 2 of ``gaze`` at those positions
    (row 2 is presumably the timestamp - TODO confirm).

    Args:
        x: downsampled segment, indexed as ``x[2, 0]`` and ``x[2, -1]``.
        si: zero-based index of the first fixation of the segment.
        ei: zero-based index of the last fixation of the segment.
        gaze: original-rate signal, indexed as ``gaze[2, idx]``.
        fixation_mask: 1-D mask aligned with ``gaze``.

    Returns:
        None. The verdict is printed.
    '''
    sidx = search(fixation_mask, si + 1, side = 'right')
    eidx = search(fixation_mask, ei + 1, side = 'left')
    missing = False
    if sidx == -1:
        print(f'❌ Start Fixation not found: si:{si + 1} \n {fixation_mask}')
        missing = True
    if eidx == -1:
        # report the end fixation that was actually searched (previously printed si)
        print(f'❌ End Fixation not found: ei:{ei + 1} \n {fixation_mask}')
        missing = True
    if missing:
        # an index of -1 would silently select the last gaze sample and produce a
        # meaningless Pass/Outside verdict, so stop here
        return
    # 200 is the tolerance added to the last downsampled value
    # (presumably the downsample interval - TODO confirm)
    if x[2,0] <= gaze[2,sidx] and (x[2,-1] + 200) >= gaze[2,eidx]:
        print(f'✅Pass: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')
    else:
        print(f'❌Outside: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')

## Dataset

In [4]:
# TODO Compute image embeddings just once
# TODO Refactor the dataset class (TOO LARGE)
# TODO Review the outputs in the validation and with some tests
class PathCocoFreeViewDatasetBatch(Dataset):
    '''
    Batched access to the preprocessed Coco FreeView HDF5 data.

    x is the noisy and downsampled simulated eye-tracking signal and y is the section
    of the scanpath (fixations) that fits entirely in that part of the signal.

    Two access paths are provided:
      * ``get_single_item(i)`` reads sample ``i`` from the original file.
      * ``dataset[i]`` reads batch ``i`` (``batch_size`` consecutive rows) from the
        shuffled copy of the file, which is written by ``shuffle_dataset``.

    HDF5 handles are opened lazily; ``close_and_remove_data`` releases them.
    '''

    def __init__(self,
                 data_path = 'data\\Coco FreeView',
                 sample_size=-1,
                 sampling_rate=60,
                 downsample_int=200,
                 batch_size=128,
                 min_scanpath_duration = 3000,
                 max_fixation_duration = 1200,
                 log = False,
                 debug = False):
        '''
        Store the configuration and derive the HDF5 paths. No file is opened here.

        Args:
            data_path: folder that contains ``dataset.hdf5``.
            sample_size: number of downsampled samples per random window;
                -1 keeps the whole sequence.
            sampling_rate: rate of the original signal (used as ``1000/sampling_rate``,
                so presumably Hz with durations in ms - TODO confirm).
            downsample_int: interval of the downsampled signal (same unit caveat).
            batch_size: number of rows returned by ``__getitem__``.
            min_scanpath_duration: stored only; not used by this class.
            max_fixation_duration: stored only; not used by this class.
            log: print progress messages in ``shuffle_dataset``.
            debug: stored only; not used by this class.
        '''
        super().__init__()
        self.sampling_rate = sampling_rate
        self.sample_size = sample_size # 90% larger than 20 at downsample 200
        self.downsample = downsample_int
        self.min_scanpath_duration = min_scanpath_duration
        self.max_fixation_duration = max_fixation_duration
        self.log = log
        self.batch_size = batch_size
        self.debug = debug
        self.ori_path = os.path.join(data_path, 'dataset.hdf5')
        self.ori_data = None
        if not os.path.exists(self.ori_path):
            # only a hint: construction continues and later reads will fail
            print('Execute preprocess')
        self.shuffled_path = self.ori_path.replace('.hdf5', '_shuffled.hdf5')
        self.shuffled_data = None
        self.json_dataset = None

    def __len__(self):
        '''Number of batches (rows of 'down_gaze' divided by batch_size, rounded up).'''
        with h5py.File(self.ori_path,'r') as ori_data:
            return math.ceil(ori_data['down_gaze'].shape[0]/self.batch_size)

    def sample_count(self):
        '''Number of individual samples (rows of 'down_gaze') in the original file.'''
        with h5py.File(self.ori_path,'r') as ori_data:
            return ori_data['down_gaze'].shape[0]

    def __extract_random_period(self, size, noisy_samples, fixations, fixation_mask):
        '''
        Cut a random window of ``size`` downsampled samples and the fixations that
        lie completely inside it.

        Args:
            size: window length in downsampled samples.
            noisy_samples: 2-D array; the window is taken along axis 1.
            fixations: 2-D array with one column per fixation.
            fixation_mask: 1-D mask at the original sampling rate holding
                fixation index + 1 for fixation samples and 0 for saccade samples.

        Returns:
            ``(x, y, start_fixation, end_fixation)`` where x is the window, y the
            matching columns of ``fixations`` and the last two values the zero-based
            first/last fixation index. When the window starts on a saccade and no
            fixation is found before its last sample, y has no columns (it keeps the
            rows of ``fixations``) and both indices are -1.
        '''
        down_idx = np.random.randint(0, noisy_samples.shape[1] - size + 1, 1, dtype = int)[0]
        x = noisy_samples[:, down_idx:down_idx + size]
        # get the values in the original sampling rate
        conversion_factor = self.downsample/(1000/self.sampling_rate)
        ori_idx = math.floor(down_idx*conversion_factor)
        ori_size = math.ceil((size - 1)*conversion_factor)
        last_idx = ori_idx + ori_size
        # get the first fixation and if it is not completely included get the next one
        if fixation_mask[ori_idx] > 0:
            if (ori_idx - 1) >= 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]:
                # the fixation began before the window: skip to the following one
                start_fixation = fixation_mask[ori_idx] + 1
            else:
                start_fixation = fixation_mask[ori_idx]
        else:
            # if the first value is a saccade look for the first fixation
            current_idx = ori_idx + 1
            while current_idx < last_idx and fixation_mask[current_idx] == 0:
                current_idx += 1
            if current_idx == last_idx:
                # no fixation in the window: return an empty target that keeps the
                # row layout of `fixations`, so it has the same rank as every other
                # target (np.array([]) was 1-D and can break padding downstream)
                return x, fixations[:, :0], -1, -1
            start_fixation = fixation_mask[current_idx]
        # search the last fixation
        if fixation_mask[last_idx] > 0:
            if (last_idx + 1) < fixation_mask.shape[0] and fixation_mask[last_idx + 1] == fixation_mask[last_idx]:
                # the fixation continues after the window: fall back to the previous one
                end_fixation = fixation_mask[last_idx] - 1
            else:
                end_fixation = fixation_mask[last_idx]
        else:
            current_idx = last_idx - 1
            while current_idx > ori_idx and fixation_mask[current_idx] == 0:
                current_idx -= 1
            end_fixation = fixation_mask[current_idx]
        # the mask are saved shifted in order to assign 0 to the saccade samples
        start_fixation -= 1
        end_fixation -= 1
        y = fixations[:, start_fixation: end_fixation + 1]
        return x, y, start_fixation, end_fixation

    def get_single_item(self, index):
        '''
        Read one sample from the original file, optionally window it, and add noise.

        The original file is opened on first use and kept open until
        ``close_and_remove_data`` is called.

        Returns:
            ``(x, y)``: the noisy downsampled signal and its fixations.
        '''
        # reading is inefficient because it reads the samples one by one
        if self.ori_data is None:
            self.ori_data = h5py.File(self.ori_path,'r')
        down_gaze = self.ori_data['down_gaze'][index].reshape((3,-1))
        fixations = self.ori_data['fixations'][index].reshape((3,-1))
        x = down_gaze
        y = fixations
        if self.sample_size != -1:
            fixation_mask = self.ori_data['fixation_mask'][index]
            # test_segment_is_inside(x, start_fixation, end_fixation, gaze, fixation_mask)
            # can be used here to validate the window while debugging
            x, y, start_fixation, end_fixation = self.__extract_random_period(self.sample_size,
                                                x,
                                                fixations,
                                                fixation_mask)

        x, _ = add_random_center_correlated_radial_noise(x, [320//2, 512//2], 1/16,
                                                                  radial_corr=.2,
                                                                  radial_avg_norm=4.13,
                                                                  radial_std=3.5,
                                                                  center_noise_std=100,
                                                                  center_corr=.3,
                                                                  center_delta_norm=300,
                                                                  center_delta_r=.3)
        return x, y

    def __getitem__(self, index):
        '''
        Read batch ``index`` from the shuffled file, window every sample when
        ``sample_size`` is set, and add noise to the whole batch at once.

        The shuffled file is opened for the call and closed again before returning.

        Returns:
            ``(x_batch, y_batch)``: the output of the noise function for the list of
            inputs, and the list of per-sample fixation arrays.
        '''
        # reading is 3x faster with batch size 128
        # but can't use the workers of the torch.dataloader (epoch in 4.9)
        if self.shuffled_data is None:
            self.shuffled_data = h5py.File(self.shuffled_path,'r')
        batch_size = self.batch_size
        rows = slice(index*batch_size, (index + 1)*batch_size)
        down_gaze = self.shuffled_data['down_gaze'][rows]
        fixations = self.shuffled_data['fixations'][rows]
        if self.sample_size != -1:
            # the mask is only needed to cut the random windows
            fixation_mask = self.shuffled_data['fixation_mask'][rows]
            vals = (down_gaze,fixations,fixation_mask)
        else:
            vals = (down_gaze,fixations)
        x_batch = []
        y_batch = []
        for value in zip(*vals):
            x = value[0].reshape((3,-1))
            y = value[1].reshape((3,-1))
            if self.sample_size != -1:
                fixation_mask = value[2]
                x, y, start_fixation, end_fixation = self.__extract_random_period(self.sample_size,
                                                    x,
                                                    y,
                                                    fixation_mask)

            x_batch.append(x)
            y_batch.append(y)
        x_batch, _ = add_random_center_correlated_radial_noise(x_batch, [320//2, 512//2], 1/16,
                                                                radial_corr=.2,
                                                                radial_avg_norm=4.13,
                                                                radial_std=3.5,
                                                                center_noise_std=100,
                                                                center_corr=.3,
                                                                center_delta_norm=300,
                                                                center_delta_r=.3)
        self.shuffled_data.close()
        self.shuffled_data = None

        return x_batch, y_batch

    def shuffle_dataset(self):
        '''
        Write a row-shuffled copy of the original file to ``self.shuffled_path``.

        All four datasets are loaded into memory, permuted with the same random
        index and written to a new file (any previous shuffled file is overwritten).
        '''
        # a file cannot be recreated while this instance still holds it open
        if self.shuffled_data is not None:
            self.shuffled_data.close()
            self.shuffled_data = None

        with h5py.File(self.ori_path,'r') as ori_data:
            dataset_names = ['down_gaze', 'fixations', 'fixation_mask', 'gaze']
            if self.log:
                print('reading original data')
            original_data = {name: ori_data[name][:] for name in dataset_names}
        # one permutation shared by every dataset keeps the rows aligned
        idx = np.arange(original_data['down_gaze'].shape[0])
        np.random.shuffle(idx)
        for name in dataset_names:
            original_data[name] = original_data[name][idx]

        with h5py.File(self.shuffled_path, 'w') as f_out:
            for name, data in original_data.items():
                f_out.create_dataset(
                    name,
                    data=data,
                )
        if self.log:
            print('shuffled data saved')

    def close_and_remove_data(self):
        '''Close the original and shuffled HDF5 handles if they are open.'''
        if self.ori_data is not None:
            self.ori_data.close()
            self.ori_data = None
        if self.shuffled_data is not None:
            self.shuffled_data.close()
            self.shuffled_data = None
    


In [5]:
class InMemoryDataset(Dataset):
    '''
    Per-sample dataset that loads every dataset of ``dataset.hdf5`` into RAM once.

    Each item is ``(x, y)``: the noisy downsampled signal (optionally a random window
    of ``sample_size`` samples) and the fixations that fit entirely inside it.
    '''

    def __init__(self, data_path = 'data\\Coco FreeView', sample_size=-1, log = False, sampling_rate=60, downsample_int=200):
        '''
        Read the whole HDF5 file into ``self.data_store``.

        Args:
            data_path: folder that contains ``dataset.hdf5``.
            sample_size: number of downsampled samples per random window;
                -1 keeps the whole sequence.
            log: print a message once the data is loaded.
            sampling_rate: rate of the original signal (used as ``1000/sampling_rate``,
                so presumably Hz with durations in ms - TODO confirm).
            downsample_int: interval of the downsampled signal (same unit caveat).
        '''
        self.sampling_rate = sampling_rate
        self.downsample = downsample_int
        self.sample_size = sample_size
        file_path = os.path.join(data_path, 'dataset.hdf5')
        self.data_path = data_path
        self.log = log
        self.data_store = {}
        with h5py.File(file_path, 'r') as f:
            for key in f.keys():
                self.data_store[key] = f[key][:] # [:] reads all data
        if self.log:
            print('Data loaded in memory')

    def __len__(self):
        '''Number of samples (rows of 'down_gaze').'''
        return self.data_store['down_gaze'].shape[0]

    def __extract_random_period(self, size, noisy_samples, fixations, fixation_mask):
        '''
        Cut a random window of ``size`` downsampled samples and the fixations that
        lie completely inside it.

        Args:
            size: window length in downsampled samples.
            noisy_samples: 2-D array; the window is taken along axis 1.
            fixations: 2-D array with one column per fixation.
            fixation_mask: 1-D mask at the original sampling rate holding
                fixation index + 1 for fixation samples and 0 for saccade samples.

        Returns:
            ``(x, y, start_fixation, end_fixation)`` where x is the window, y the
            matching columns of ``fixations`` and the last two values the zero-based
            first/last fixation index. When the window starts on a saccade and no
            fixation is found before its last sample, y has no columns (it keeps the
            rows of ``fixations``) and both indices are -1.
        '''
        down_idx = np.random.randint(0, noisy_samples.shape[1] - size + 1, 1, dtype = int)[0]
        x = noisy_samples[:, down_idx:down_idx + size]
        # get the values in the original sampling rate
        conversion_factor = self.downsample/(1000/self.sampling_rate)
        ori_idx = math.floor(down_idx*conversion_factor)
        ori_size = math.ceil((size - 1)*conversion_factor)
        last_idx = ori_idx + ori_size
        # get the first fixation and if it is not completely included get the next one
        if fixation_mask[ori_idx] > 0:
            if (ori_idx - 1) >= 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]:
                # the fixation began before the window: skip to the following one
                start_fixation = fixation_mask[ori_idx] + 1
            else:
                start_fixation = fixation_mask[ori_idx]
        else:
            # if the first value is a saccade look for the first fixation
            current_idx = ori_idx + 1
            while current_idx < last_idx and fixation_mask[current_idx] == 0:
                current_idx += 1
            if current_idx == last_idx:
                # no fixation in the window: return an empty target that keeps the
                # row layout of `fixations`, so a collate function that transposes
                # and pads the targets sees the same rank for every sample
                # (np.array([]) was 1-D)
                return x, fixations[:, :0], -1, -1
            start_fixation = fixation_mask[current_idx]
        # search the last fixation
        if fixation_mask[last_idx] > 0:
            if (last_idx + 1) < fixation_mask.shape[0] and fixation_mask[last_idx + 1] == fixation_mask[last_idx]:
                # the fixation continues after the window: fall back to the previous one
                end_fixation = fixation_mask[last_idx] - 1
            else:
                end_fixation = fixation_mask[last_idx]
        else:
            current_idx = last_idx - 1
            while current_idx > ori_idx and fixation_mask[current_idx] == 0:
                current_idx -= 1
            end_fixation = fixation_mask[current_idx]
        # the mask are saved shifted in order to assign 0 to the saccade samples
        start_fixation -= 1
        end_fixation -= 1
        y = fixations[:, start_fixation: end_fixation + 1]
        return x, y, start_fixation, end_fixation

    def __getitem__(self, index):
        """
        Fetches a single sample from RAM, applies sampling and noise.

        Returns:
            ``(x, y)``: the noisy downsampled signal and its fixations.
        """
        # Get the pre-loaded data for this index
        down_gaze = self.data_store['down_gaze'][index].reshape((3, -1))
        fixations = self.data_store['fixations'][index].reshape((3, -1))

        x = down_gaze
        y = fixations

        if self.sample_size != -1:
            fixation_mask = self.data_store['fixation_mask'][index]
            x, y, start_fixation, end_fixation = self.__extract_random_period(
                self.sample_size,
                x,
                fixations,
                fixation_mask
            )

        # Apply noise augmentation
        # (assumes the noise function can process a single [3, N] array)
        x, _ = add_random_center_correlated_radial_noise(x, [320//2, 512//2], 1/16,
                                                                radial_corr=.2,
                                                                radial_avg_norm=4.13,
                                                                radial_std=3.5,
                                                                center_noise_std=100,
                                                                center_corr=.3,
                                                                center_delta_norm=300,
                                                                center_delta_r=.3)

        return x, y

# Test


## Load Dataset

In [17]:
# HDF5-backed dataset; sample_size=8 makes every item a random window of 8 downsampled samples.
# No file is opened until the first read.
dataset = PathCocoFreeViewDatasetBatch(sample_size= 8,log = True)


In [24]:
# Release any HDF5 handles the dataset opened lazily.
dataset.close_and_remove_data()

In [7]:
# Same windowing as `dataset`, but every HDF5 dataset is read into RAM at construction time.
datasetv2 = InMemoryDataset(sample_size= 8,log = True)

Data loaded in memory


## Speed Test

In [7]:
# Per-sample read speed of the HDF5-backed dataset (one row at a time from the original file).
# NOTE(review): needs the `dataset` cell under "Load Dataset" to have run first; the
# saved output below is a NameError from executing this cell on a kernel without it.
for i in tqdm(range(dataset.sample_count())):
    dataset.get_single_item(i)

NameError: name 'dataset' is not defined

In [9]:
# Same per-sample loop on the in-memory dataset, for comparison with the HDF5-backed one.
for i in tqdm(range(len(datasetv2))):
    datasetv2[i]

100%|██████████| 39869/39869 [00:00<00:00, 49042.97it/s]


In [11]:


def seq2seq_collate_fn(batch):
    """
    Build one zero-padded batch from variable-length (input, target) pairs.

    Every element of ``batch`` is read as ``item[0]`` (input) and ``item[1]``
    (target). Both are NumPy arrays; each one is transposed, converted to a
    float tensor and padded along its first axis up to the longest sequence
    found in the current batch.

    Args:
        batch: A list of (input_seq, target_seq) pairs of NumPy arrays.

    Returns:
        A tuple ``(padded_inputs, padded_targets)`` of batch-first tensors.
    """
    pad_value = 0.0
    pad = torch.nn.utils.rnn.pad_sequence

    def as_tensors(position):
        # transpose so the sequence axis comes first, then cast to float32
        return [torch.from_numpy(sample[position].T).float() for sample in batch]

    encoder_seqs = as_tensors(0)
    decoder_seqs = as_tensors(1)

    # batch_first=True puts the batch dimension in front of the padded length
    return (
        pad(encoder_seqs, batch_first=True, padding_value=pad_value),
        pad(decoder_seqs, batch_first=True, padding_value=pad_value),
    )


In [12]:


# Throughput of the in-memory dataset through a DataLoader: batches of 128 samples,
# reshuffled, loaded in the main process and padded by seq2seq_collate_fn.
dataloader = DataLoader(datasetv2, batch_size=128, shuffle=True, num_workers=0, collate_fn= seq2seq_collate_fn)
for batch in tqdm(dataloader):
    # padded inputs and padded targets; only unpacked here, not used further
    x,y = batch

100%|██████████| 312/312 [00:01<00:00, 244.08it/s]


## Batch Test

In [18]:
# Close any open HDF5 handles, then write the shuffled copy that dataset[i] reads from.
dataset.close_and_remove_data()
dataset.shuffle_dataset()

reading original data
shuffled data saved


In [19]:
# TODO The multiprocessing dataloader (sharing the hdf5 files through multitable)

# Batch read speed: each dataset[i] returns one whole batch from the shuffled file,
# so len(dataset) counts batches, not samples.
for i in tqdm(range(len(dataset))):
    dataset[i]

100%|██████████| 312/312 [00:01<00:00, 308.94it/s]
