In [1]:
import argparse, json, os, mlflow
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model.BaselineEyeTrackingModel import CNN_GRU
from utils.training_utils import train_epoch, validate_epoch, top_k_checkpoints
from utils.metrics import weighted_MSELoss
from dataset import ThreeETplus_Eyetracking, ScaleLabel, NormalizeLabel, \
    TemporalSubsample, NormalizeLabel, SliceLongEventsToShort, \
    EventSlicesToVoxelGrid, SliceByTimeEventsTargets, EventSlicesToRVT, ToBoundingBox
import tonic.transforms as transforms
from tonic import SlicedDataset, DiskCachedDataset

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
config_file = 'sliced_baseline.json'
with open(os.path.join('./configs', config_file), 'r') as f:
    config = json.load(f)
args = argparse.Namespace(**config)

In [3]:
factor = args.spatial_factor # spatial downsample factor
temp_subsample_factor = args.temporal_subsample_factor # downsampling original 100Hz label to 20Hz

# The original labels are spatially downsampled with 'factor', downsampled to 20Hz, and normalized w.r.t width and height to [0,1]
label_transform = transforms.Compose([
    ScaleLabel(factor),
    TemporalSubsample(temp_subsample_factor),
    NormalizeLabel(pseudo_width=640*factor, pseudo_height=480*factor),
    ToBoundingBox(box_width=20/640, box_height=20/640)

])

In [4]:
train_data_orig = ThreeETplus_Eyetracking(save_to=args.data_dir, split="train", \
                transform=transforms.Downsample(spatial_factor=factor), 
                target_transform=label_transform)
val_data_orig = ThreeETplus_Eyetracking(save_to=args.data_dir, split="val", \
                transform=transforms.Downsample(spatial_factor=factor),
                target_transform=label_transform)

In [15]:
30 3 h w
7429 42 442

0 - x
30 frames = 

0 - 1000, 1000 - train_data_or

t p x y

0 - 1000, 5000 eventsd, 1000 - 2000, events

SyntaxError: invalid syntax (2925462393.py, line 1)

In [5]:
slicing_time_window = args.train_length*int(10000/temp_subsample_factor) #microseconds
train_stride_time = int(10000/temp_subsample_factor*args.train_stride) #microseconds

train_slicer=SliceByTimeEventsTargets(slicing_time_window, overlap=slicing_time_window-train_stride_time, \
                seq_length=args.train_length, seq_stride=args.train_stride, include_incomplete=False)
# the validation set is sliced to non-overlapping sequences
val_slicer=SliceByTimeEventsTargets(slicing_time_window, overlap=0, \
                seq_length=args.val_length, seq_stride=args.val_stride, include_incomplete=False)

In [6]:
class ToNumpy:
    def __init__(self):
        """
        Initialize the transformation.

        Args:
        - time_window (int): The length of each sub-sequence.
        """
    def __call__(self, events):
        return np.array(events) 

In [7]:
post_slicer_transform = transforms.Compose([
    SliceLongEventsToShort(time_window=int(10000/temp_subsample_factor), overlap=0, include_incomplete=True),
    EventSlicesToVoxelGrid(sensor_size=(int(640*factor), int(480*factor), 2), \
                            n_time_bins=args.n_time_bins, per_channel_normalize=args.voxel_grid_ch_normaization)
])

In [8]:
train_data = SlicedDataset(train_data_orig, train_slicer, transform=post_slicer_transform, metadata_path=f"./metadata/3et_train_tl_{args.train_length}_ts{args.train_stride}_ch{args.n_time_bins}")
val_data = SlicedDataset(val_data_orig, val_slicer, transform=post_slicer_transform, metadata_path=f"./metadata/3et_val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}")

Metadata read from ./metadata/3et_train_tl_30_ts15_ch3/slice_metadata.h5.
Metadata read from ./metadata/3et_val_vl_30_vs30_ch3/slice_metadata.h5.


In [9]:
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, \
                            num_workers=4, pin_memory=True)
val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False, \
                        num_workers=4)

In [10]:
for a in train_loader:
    print(a[0].shape, a[1].shape)
    

(173, 3)
(30, 3, 60, 80)
(138, 3)
(30, 3, 60, 80)
(199, 3)
(30, 3, 60, 80)
(131, 3)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(179, 3)
(30, 3, 60, 80)
(133, 3)
(148, 3)
(30, 3, 60, 80)
(30, 3, 60, 80)
(151, 3)
(30, 3, 60, 80)
(206, 3)
(30, 3, 60, 80)
(153, 3)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(131, 3)
(30, 3, 60, 80)
(179, 3)
(111, 3)
(30, 3, 60, 80)
(30, 3, 60, 80)
(215, 3)
(30, 3, 60, 80)
(309, 3)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(173, 3)
(30, 3, 60, 80)
(218, 3)
(148, 3)
(30, 3, 60, 80)(30, 3, 60, 80)

(196, 3)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(211, 3)
(30, 3, 60, 80)
(133, 3)
(153, 3)
(30, 3, 60, 80)
(30, 3, 60, 80)
(218, 3)
(175, 3)
(30, 3, 60, 80)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(175, 3)
(30, 3, 60, 80)
(196, 3)
(30, 3, 60, 80)
(175, 3)
(30, 3, 60, 80)
(111, 3)
(30, 3, 60, 80)
(309, 3)
(30, 3, 60, 80)
(196, 3)(218, 3)

(30, 3, 60, 80)
(30, 3, 60, 80)
(104, 3)
(30, 3, 60, 80)
(179, 3)
(30, 3, 60, 80)
(173, 3)
(30, 3, 60, 80)


In [12]:
d

TypeError: 'DataLoader' object is not subscriptable

In [None]:
a.shape