# Training Pipeline

# Imports
## Pip Packages

In [1]:
import os


# When the current working directory has no 'data' entry, move two levels up
# so that relative paths such as 'data/...' resolve
# (presumably to the project root — confirm against the repository layout).
# The path is built from os.pardir with os.path.join so it works on every OS;
# a literal "..\\..\\" is only understood by Windows.
if not os.path.exists('data'):
    new_directory_path = os.path.join(os.pardir, os.pardir)
    os.chdir(new_directory_path)

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import math
import h5py
from tqdm import tqdm

## My Modules

In [2]:
from src.parsers import CocoFreeView
from src.preprocess.simulation import gen_gaze, downsample
from src.preprocess.noise import add_random_center_correlated_radial_noise

# Code
## Utils

In [26]:
def search(mask, fixation, side = 'right'):
    """Locate one end of the run of entries equal to `fixation` inside `mask`.

    Args:
        mask: 1-D indexable exposing `.shape[0]` (e.g. a numpy array).
        fixation: value to look for.
        side: 'left' walks from the last index down to 0 and therefore yields
            the LAST matching index; any other value (default 'right') walks
            from index 0 upwards and yields the FIRST matching index.

    Returns:
        The matching index, or -1 when `fixation` does not occur in `mask`.
    """
    length = mask.shape[0]
    if side == 'left':
        candidates = range(length - 1, -1, -1)
    else:
        candidates = range(length)
    # First hit in scan order wins; -1 is the "not found" sentinel.
    return next((pos for pos in candidates if mask[pos] == fixation), -1)

def test_segment_is_inside(index, x, si,ei,gaze, fixation_mask):
    sidx = search(fixation_mask, si + 1, side = 'right')
    eidx = search(fixation_mask, ei + 1, side = 'left')
    if sidx == -1:
        print(f'{index}❌ Start Fixation not found: si:{si + 1} \n {fixation_mask}')
        return
    if eidx == -1:
        print(f'{index}❌ End Fixation not found: si:{ei + 1} \n {fixation_mask}')
        return
    if x[2,0] <= gaze[2,sidx] and (x[2,-1] + 200) >= gaze[2,eidx]:
        # print(f'{index}✅Pass: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')
        return 
    else:
        print(f'{index}❌Outside: DS [{x[2,0]},{x[2,-1]}] Ori [{gaze[2,sidx]},{gaze[2,eidx]}]')


def location_test(index, si, ei, gaze, fixation_mask, fixations, max_allowed_dist=100):
    """Debug check: are gaze samples close to the fixation they are labelled with?

    For every sample between the first sample of fixation `si` and the last
    sample of fixation `ei`, the Euclidean distance between the gaze point
    (rows 0 and 1 of `gaze`) and its fixation point (rows 0 and 1 of
    `fixations`) is computed. Samples labelled 0 in the mask are skipped.

    Args:
        index: identifier of the sample, only used as a prefix in messages.
        si: zero-based index of the first selected fixation.
        ei: zero-based index of the last selected fixation.
        gaze: array indexed as [row, sample].
        fixation_mask: per-sample labels; value k + 1 refers to column k of
            `fixations`, value 0 is skipped.
        fixations: array indexed as [row, fixation].
        max_allowed_dist: distance above which the check fails (default 100,
            the value previously hard-coded; units are those of the
            coordinates, presumably pixels — confirm).

    Returns:
        None. The outcome is only printed: a failure message on the first
        sample that is too far, otherwise the largest distance found.
    """
    sidx = search(fixation_mask, si + 1, side = 'right')
    eidx = search(fixation_mask, ei + 1, side = 'left')
    if sidx == -1:
        print(f'{index}❌ Start Fixation not found: si:{si + 1} \n {fixation_mask}')
        return
    if eidx == -1:
        # Label fixed: this message reports the END index, so it is tagged 'ei'.
        print(f'{index}❌ End Fixation not found: ei:{ei + 1} \n {fixation_mask}')
        return
    max_dist = 0
    for i in range(sidx, eidx + 1):
        if fixation_mask[i] == 0:
            continue
        # Mask labels are shifted by one so that 0 can mean "no fixation".
        f_index = fixation_mask[i] - 1
        fx = fixations[0,f_index]
        fy = fixations[1,f_index]
        gx = gaze[0,i]
        gy = gaze[1,i]
        dist = math.sqrt((fx - gx)**2 + (fy - gy)**2)
        max_dist = max(max_dist, dist)
        if dist > max_allowed_dist:
            print(f'{index}❌ Fixation point too far from gaze point at index {i}: Fixation({fx},{fy}) Gaze({gx},{gy}) Dist:{dist}')
            return
    # print(f'✅ All fixation points are within acceptable distance from gaze points between indices {sidx} and {eidx}.')
    print(f'✅ Max distance between fixation and gaze points: {max_dist}')


## Dataset

In [25]:
def extract_random_period(size, noisy_samples, fixations, fixation_mask, sampling_rate, downsample):
    """Cut a random window of downsampled samples and the fixations fully inside it.

    A start column is drawn uniformly at random, the window is mapped back to
    indices of the original-rate `fixation_mask`, and the first/last fixation
    that are not cut by the window borders are selected. A fixation is treated
    as cut when the mask sample just outside the window carries the same label.

    Args:
        size: number of downsampled columns in the window.
        noisy_samples: array indexed as [row, column] (downsampled signal).
        fixations: array indexed as [row, fixation].
        fixation_mask: 1-D per-sample labels at the original rate; fixation k
            is stored as k + 1 and 0 marks samples outside any fixation.
        sampling_rate: original sampling rate (the code uses
            `1000 / sampling_rate`, so presumably Hz with times in ms).
        downsample: interval of one downsampled column, in the same time unit
            (presumably ms).

    Returns:
        (x, y, start_fixation, end_fixation):
        x is `noisy_samples[:, start:start + size]`;
        y is `fixations[:, start_fixation:end_fixation + 1]`;
        the indices are zero-based Python ints. When the window contains no
        fixation sample, y has zero columns and both indices are -1. y can
        also be empty with other index values when no fixation fits entirely.

    Raises:
        ValueError: from `np.random.randint` when `noisy_samples` has fewer
            than `size` columns.
    """
    down_idx = np.random.randint(0, noisy_samples.shape[1] - size + 1, 1, dtype = int)[0]
    x = noisy_samples[:, down_idx:down_idx + size]
    # Empty target that keeps the row dimension of a regular target, so callers
    # that transpose/pad targets see a consistent (rows, 0) shape.
    no_fixations = fixations[:, :0]

    # get the values in the original sampling rate
    conversion_factor = downsample/(1000/sampling_rate)
    ori_idx = math.floor(down_idx*conversion_factor)
    ori_size = math.ceil((size - 1)*conversion_factor)
    # Clamp to the mask so a window mapped past the recording cannot raise IndexError.
    last_idx = min(ori_idx + ori_size, fixation_mask.shape[0] - 1)
    if ori_idx > last_idx:
        # The window starts beyond the mask (or the mask is empty).
        return x, no_fixations, -1, -1

    # get the first fixation and if it is not completely included get the next one
    if fixation_mask[ori_idx] > 0:
        started_before = (ori_idx - 1) >= 0 and fixation_mask[ori_idx - 1] == fixation_mask[ori_idx]
        start_fixation = int(fixation_mask[ori_idx]) + (1 if started_before else 0)
    else:
        # if the first value is a saccade look for the first fixation
        current_idx = ori_idx + 1
        while current_idx < last_idx and fixation_mask[current_idx] == 0:
            current_idx += 1
        # '>=' (not '==') so a one-sample window (ori_size == 0) is also caught.
        if current_idx >= last_idx:
            # there is no fixation inside the window
            return x, no_fixations, -1, -1
        start_fixation = int(fixation_mask[current_idx])

    # search the last fixation
    if fixation_mask[last_idx] > 0:
        continues_after = (last_idx + 1) < fixation_mask.shape[0] and fixation_mask[last_idx + 1] == fixation_mask[last_idx]
        end_fixation = int(fixation_mask[last_idx]) - (1 if continues_after else 0)
    else:
        current_idx = last_idx - 1
        while current_idx > ori_idx and fixation_mask[current_idx] == 0:
            current_idx -= 1
        end_fixation = int(fixation_mask[current_idx])

    # the mask is saved shifted by one in order to assign 0 to the saccade samples
    start_fixation -= 1
    end_fixation -= 1
    y = fixations[:, start_fixation: end_fixation + 1]
    return x, y, start_fixation, end_fixation

In [None]:
# TODO Compute image embeddings just once
# TODO Refactor the dataset class (TOO LARGE)
# TODO Review the outputs in the validation and with some tests


class FreeViewInMemory(Dataset):
    """Dataset that loads `<data_path>/dataset.hdf5` fully into RAM.

    Every top-level entry of the HDF5 file is copied into `self.data_store`
    (a dict of arrays) when the object is built. Items are served from those
    arrays and get noise augmentation applied on each access.

    Args:
        data_path: folder that contains `dataset.hdf5`.
            NOTE(review): the default uses a Windows-style separator.
        sample_size: number of downsampled columns per item; -1 (default)
            returns the whole sequence.
        log: when True, print a message once the data is loaded.
        sampling_rate: forwarded to `extract_random_period`.
        downsample_int: forwarded to `extract_random_period` as `downsample`.
    """
    def __init__(self, data_path = 'data\\Coco FreeView', sample_size=-1, log = False, sampling_rate=60, downsample_int=200):
        self.sampling_rate = sampling_rate
        self.downsample = downsample_int
        self.sample_size = sample_size
        file_path = os.path.join(data_path, 'dataset.hdf5')
        self.data_path = data_path
        self.log = log
        self.data_store = {}
        with h5py.File(file_path, 'r') as f:
            for key in f.keys():
                # [:] materialises the dataset as an in-memory array, so the
                # file can be closed right after loading.
                self.data_store[key] = f[key][:]
        if self.log:
            print('Data loaded in memory')
        # Number of items = first dimension of the 'down_gaze' array.
        self.length = self.data_store['down_gaze'].shape[0]

    def __len__(self):
        """Return the number of items held in memory."""
        return self.length

    def __getitem__(self, index):
        """
        Fetches a single sample from RAM, applies sampling and noise.

        Returns:
            (x, y): x is the (optionally windowed) 3-row downsampled gaze
            after noise augmentation; y is the 3-row fixation array, or the
            fixations selected by `extract_random_period` when
            `sample_size != -1`.
        """
        # Get the pre-loaded data for this index; each row is stored flat and
        # viewed here as 3 rows.
        x = self.data_store['down_gaze'][index].reshape((3, -1))
        fixations = self.data_store['fixations'][index].reshape((3, -1))
        y = fixations
        # The original-rate 'gaze' array is intentionally not read here: it was
        # only needed by the debug helpers (test_segment_is_inside /
        # location_test) and cost a lookup + reshape on every item.

        if self.sample_size != -1:
            fixation_mask = self.data_store['fixation_mask'][index]
            # The selected fixation indices are only useful for debugging.
            x, y, _, _ = extract_random_period(
                self.sample_size,
                x,
                fixations,
                fixation_mask,
                self.sampling_rate,
                self.downsample
            )

        # Apply noise augmentation
        # (Assuming your noise function can process a single [3, N] array)
        # NOTE(review): [320//2, 512//2] is presumably the image centre — confirm.
        x, _ = add_random_center_correlated_radial_noise(x, [320//2, 512//2], 1/16,
                                                                radial_corr=.2,
                                                                radial_avg_norm=4.13,
                                                                radial_std=3.5,
                                                                center_noise_std=100,
                                                                center_corr=.3,
                                                                center_delta_norm=300,
                                                                center_delta_r=.3)

        return x, y
    



def seq2seq_collate_fn(batch):
    """
    Collate variable-length (input, target) pairs into two padded float tensors.

    Each array is transposed so that its second axis becomes the sequence
    axis, converted to a float32 tensor, and right-padded with 0.0 up to the
    longest sequence of the batch.

    Args:
        batch: sequence of items where `item[0]` is the input array and
               `item[1]` the target array. Both are numpy arrays laid out as
               (features, length); lengths may differ between items.

    Returns:
        A tuple `(padded_inputs, padded_targets)` of tensors shaped
        (batch, max_length, features). No length or padding mask is returned.
    """
    pad_value = 0.0

    def _time_major_tensors(position):
        # (features, length) ndarray -> (length, features) float tensor
        tensors = []
        for sample in batch:
            tensors.append(torch.from_numpy(sample[position].T).float())
        return tensors

    def _pad(tensors):
        # batch_first=True puts the batch dimension in front of the length.
        return torch.nn.utils.rnn.pad_sequence(
            tensors, batch_first=True, padding_value=pad_value
        )

    encoder_side = _time_major_tensors(0)
    decoder_side = _time_major_tensors(1)

    padded_inputs = _pad(encoder_side)
    padded_targets = _pad(decoder_side)
    return padded_inputs, padded_targets



# Test


## Load Dataset

In [27]:

# Load the whole HDF5 dataset into memory; with sample_size=8 every item is a
# random window of 8 downsampled columns. log=True prints a message when done.
datasetv2 = FreeViewInMemory(sample_size= 8,log = True)


Data loaded in memory


## Speed Test

In [28]:
# Touch every item once (result discarded) to exercise __getitem__ over the
# whole dataset: random window extraction plus noise augmentation.
for i in range(len(datasetv2)):
    datasetv2[i]

✅ Max distance between fixation and gaze points: 9.222525532104251
✅ Max distance between fixation and gaze points: 8.165574525883114
✅ Max distance between fixation and gaze points: 11.398549595420299
✅ Max distance between fixation and gaze points: 8.876283314731923
✅ Max distance between fixation and gaze points: 10.326665108090921
✅ Max distance between fixation and gaze points: 10.13516233192473
✅ Max distance between fixation and gaze points: 12.64706431970652
✅ Max distance between fixation and gaze points: 8.080511985243696
✅ Max distance between fixation and gaze points: 7.619763748500588
✅ Max distance between fixation and gaze points: 10.5494122261584
✅ Max distance between fixation and gaze points: 10.009144427992833
✅ Max distance between fixation and gaze points: 10.07458594388273
✅ Max distance between fixation and gaze points: 8.7005912823575
✅ Max distance between fixation and gaze points: 9.406439442941975
✅ Max distance between fixation and gaze points: 10.1763879747

In [48]:
# Draw one specific item (noisy input window and its fixation targets) for
# manual inspection; the window position is random on every call.
x,y = datasetv2[13473]

13473✅Pass: DS [2016.6666259765625,3416.666748046875] Ori [2233.333251953125,3183.333251953125]


In [58]:
# Inspect one raw stored value: item 557's fixation array viewed as 3 rows,
# row 2 (the third row), column 14.
# NOTE(review): despite its name, `fixations` holds a single scalar here.
fixations = datasetv2.data_store['fixations'][557].reshape((3, -1))[2,14]
fixations

np.float64(6.0)

In [25]:


# Run one shuffled pass in batches of 128, loading in the main process
# (num_workers=0), to check that collation works; tqdm reports the throughput.
dataloader = DataLoader(datasetv2, batch_size=128, shuffle=True, num_workers=0, collate_fn= seq2seq_collate_fn)
for batch in tqdm(dataloader):
    # Each batch is the (padded_inputs, padded_targets) tuple built by the collate function.
    x,y = batch

100%|██████████| 312/312 [00:01<00:00, 247.26it/s]


In [None]:
# dataset = PathCocoFreeViewDatasetBatch(sample_size= 8,log = True)
# dataset.close_and_remove_data()

# dataset.close_and_remove_data()
# dataset.shuffle_dataset()


# for i in tqdm(range(len(dataset))):
#     dataset[i]

reading original data
shuffled data saved
