In [1]:
import argparse
import json
import os
import mlflow
import torch
from torch.utils.data import DataLoader
from model.BaselineEyeTrackingModel import CNN_GRU
from model.RecurrentVisionTransformer import RVT
from utils.training_utils import train_epoch, validate_epoch, top_k_checkpoints
from utils.metrics import weighted_MSELoss
from dataset import ThreeETplus_Eyetracking, ScaleLabel, NormalizeLabel, TemporalSubsample, SliceLongEventsToShort, EventSlicesToVoxelGrid, SliceByTimeEventsTargets, RandomSpatialAugmentor
import tonic.transforms as transforms
from tonic import SlicedDataset, DiskCachedDataset


# Notebook stand-in for a command-line parser: the options are declared inline
# and exposed as attributes of an argparse.Namespace. Values loaded from the
# JSON file named by `config_file` may override any of them afterwards.
args = argparse.Namespace(
    config_file="./configs/rvt_2_layered_test.json",
    checkpoint="mlruns/775203291142996437/81d8e4a993c94f1692b65eb1c0bcabf4/artifacts/model_best_ep265_val_loss_0.0101.pth",
    # Sequence lengths and strides used when slicing recordings.
    train_length=30,
    val_length=30,
    train_stride=15,
    val_stride=30,
    # Augmentation settings, later handed to RandomSpatialAugmentor as `augm_config`.
    data_augmentation={
        "random": {
            "prob_hflip": 0.5,
            "rotate": {"prob": 0, "min_angle_deg": 2, "max_angle_deg": 6},
            "zoom": {
                "prob": 0.8,
                "zoom_in": {"weight": 8, "factor": {"min": 1, "max": 1.5}},
                "zoom_out": {"weight": 2, "factor": {"min": 1, "max": 1.2}},
            },
        },
        "stream": {
            "prob_hflip": 0.5,
            "rotate": {"prob": 0, "min_angle_deg": 2, "max_angle_deg": 6},
            "zoom": {
                "prob": 0.5,
                "zoom_out": {"factor": {"min": 1, "max": 1.2}},
            },
        },
    },
)

# Load hyperparameters from the JSON configuration file named by args.config_file.
if args.config_file:
    # JSON text is UTF-8 by specification; decode it explicitly instead of
    # relying on the platform's default text encoding.
    with open(args.config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # Every top-level key of the config becomes (or overwrites) an attribute on
    # args, so values from the file take precedence over the inline defaults.
    for key, value in config.items():
        setattr(args, key, value)

# Pull the settings used below out of args (which by now also carries the
# values read from the config file).
factor = args.spatial_factor
temp_subsample_factor = args.temporal_subsample_factor
test_length = args.test_length
data_dir = args.data_dir
test_stride = args.test_stride
n_time_bins = args.n_time_bins
# NOTE: the spelling "normaization" is the attribute name coming from the
# config, so it has to stay as-is.
voxel_grid_ch_normaization = args.voxel_grid_ch_normaization

# Label pipeline: scale, temporally subsample, then normalise the labels.
# The 640x480 base resolution is hard-coded here and in the voxel-grid
# sensor size below.
label_transform = transforms.Compose([
    ScaleLabel(factor),
    TemporalSubsample(temp_subsample_factor),
    NormalizeLabel(pseudo_width=640*factor, pseudo_height=480*factor)
])

# Raw test split; events are spatially downsampled by `factor`.
test_data_orig = ThreeETplus_Eyetracking(save_to=data_dir, split="test",
                transform=transforms.Downsample(spatial_factor=factor),
                target_transform=label_transform)

# Event time window covered by one sequence of `test_length` steps.
# NOTE(review): 10000 is presumably the base label period in microseconds -- confirm.
slicing_time_window = test_length*int(10000/temp_subsample_factor)  # microseconds

# Non-overlapping windows; incomplete slices are kept.
test_slicer = SliceByTimeEventsTargets(slicing_time_window, overlap=0,
                                       seq_length=test_length, seq_stride=test_stride, include_incomplete=True)

# Each sliced sequence is cut into short per-step event slices and converted
# into a voxel-grid representation.
post_slicer_transform = transforms.Compose([
    SliceLongEventsToShort(time_window=int(10000/temp_subsample_factor), overlap=0, include_incomplete=True),
    EventSlicesToVoxelGrid(sensor_size=(int(640*factor), int(480*factor), 2),
                           n_time_bins=n_time_bins, per_channel_normalize=voxel_grid_ch_normaization)
])

test_data = SlicedDataset(test_data_orig, test_slicer, transform=post_slicer_transform)

# os.cpu_count() can return None, and "count - 2" becomes negative on a
# single-CPU machine; DataLoader rejects negative num_workers. Clamp at 0,
# which makes the DataLoader load data in the main process.
test_loader = DataLoader(test_data, batch_size=1, shuffle=False,
                         num_workers=max(0, (os.cpu_count() or 1) - 2))


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Raw validation split; events are spatially downsampled by `factor` and the
# labels go through the same label_transform as the test split.
val_data_orig = ThreeETplus_Eyetracking(save_to=args.data_dir, split="val",
                        transform=transforms.Downsample(spatial_factor=factor),
                        target_transform=label_transform)

# Then we slice the event recordings into sub-sequences.
# The event time window is derived from val_length so that it matches the
# seq_length handed to the slicer below, mirroring how the test slicer pairs
# test_length with its window. (It was previously derived from train_length,
# which only agrees when train_length == val_length.)
slicing_time_window = args.val_length*int(10000/temp_subsample_factor)  # microseconds

# The validation set is sliced into non-overlapping sequences; incomplete
# trailing slices are dropped.
val_slicer = SliceByTimeEventsTargets(slicing_time_window, overlap=0,
                seq_length=args.val_length, seq_stride=args.val_stride, include_incomplete=False)

# After slicing the raw event recordings into sub-sequences, each sub-sequence
# is cut into short per-step slices and turned into an event voxel grid.
post_slicer_transform = transforms.Compose([
    SliceLongEventsToShort(time_window=int(10000/temp_subsample_factor), overlap=0, include_incomplete=True),
    EventSlicesToVoxelGrid(sensor_size=(int(640*factor), int(480*factor), 2),
                           n_time_bins=args.n_time_bins, per_channel_normalize=args.voxel_grid_ch_normaization)
])

# Tonic's SlicedDataset applies the slicer to every recording; the slice
# metadata is stored under metadata_path, keyed by val_length / val_stride / n_time_bins.
val_data = SlicedDataset(val_data_orig, val_slicer, transform=post_slicer_transform, metadata_path=f"./metadata/3et_val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}")

# NOTE(review): random spatial augmentation is attached to the *validation*
# data here; presumably this is deliberate, to visualise augmented samples -- confirm.
augmentation = RandomSpatialAugmentor(dataset_wh=(1, 1), augm_config=args.data_augmentation)

# Cache the dataset to disk to speed up repeated passes. The first pass is
# slow, the following ones are fast.
val_data = DiskCachedDataset(val_data, cache_path=f'./cached_dataset/val_vl_{args.val_length}_vs{args.val_stride}_ch{args.n_time_bins}', transforms=augmentation)

# Finally we wrap the dataset with a PyTorch DataLoader.
# os.cpu_count() can return None, and "count - 2" becomes negative on a
# single-CPU machine; DataLoader rejects negative num_workers. Clamp at 0,
# which makes the DataLoader load data in the main process.
val_loader = DataLoader(val_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=max(0, (os.cpu_count() or 1) - 2))


Metadata read from ./metadata/3et_val_vl_30_vs30_ch3/slice_metadata.h5.


In [3]:
# Instantiate the architecture named in the config and move it to the target device.
# NOTE(review): eval() executes whatever string args.architecture holds; this is
# only acceptable while the config file is trusted. Consider an explicit
# name -> class mapping (e.g. over CNN_GRU / RVT) if configs can come from elsewhere.
model = eval(args.architecture)(args).to(args.device)

# A checkpoint is mandatory: this notebook only runs a trained model.
if not args.checkpoint:
    raise ValueError("Please provide a checkpoint file.")

# map_location remaps the stored tensors onto args.device, so a checkpoint
# saved on a GPU still loads when the notebook runs on a different device
# (for example CPU-only).
model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
    

In [11]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

def plot_voxel_grid_as_rgb_to_html(voxel_grid, title, predictions=None, targets=None):
    """Render a voxel-grid sequence as a JavaScript/HTML animation.

    Args:
        voxel_grid: array laid out as (N, C, H, W); axis 1 is moved last so
            each frame is shown with ``imshow`` as an (H, W, C) image.
            Matplotlib needs C to be 3 or 4 to draw it as RGB(A).
        title: text shown above every frame.
        predictions: optional per-frame ``(x, y)`` pairs drawn as red dots.
            They are multiplied by the frame width/height, so they are
            presumably normalised coordinates -- confirm against the model output.
        targets: optional per-frame ``(x, y)`` pairs drawn as green dots,
            scaled the same way as ``predictions``.

    Returns:
        The HTML string produced by ``FuncAnimation.to_jshtml()``.
    """
    # None instead of mutable list defaults; an omitted argument means "no overlay".
    predictions = [] if predictions is None else predictions
    targets = [] if targets is None else targets

    frames = np.moveaxis(voxel_grid, 1, -1)  # N, C, H, W -> N, H, W, C
    n_frames, height, width = frames.shape[:3]
    fig, ax = plt.subplots()

    def update(i):
        # ax.clear() removes the previous frame's dots, but it also wipes the
        # title and tick settings, so all of them are re-applied per frame.
        ax.clear()
        ax.imshow(frames[i, :, :, :])
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        if len(predictions) > 0:
            x, y = predictions[i]
            ax.plot(x*width, y*height, 'ro')
        if len(targets) > 0:
            x, y = targets[i]
            ax.plot(x*width, y*height, 'go')
        return ax,

    try:
        update(0)  # draw the first frame up front, as the static figure state
        ani = FuncAnimation(fig, update, frames=range(n_frames), blit=False)
        return ani.to_jshtml()
    finally:
        # Always release the figure, even if rendering fails, so repeated calls
        # in a loop do not accumulate open figures.
        plt.close(fig)

# Build the gallery page from fragments collected in a list and joined once at
# the end; each animation fragment is large, so repeated string concatenation
# would copy the growing document on every iteration.
html_parts = ["""
<html>
<head>
<title>Animation Gallery</title>
</head>
<body>
<h1>Animation Gallery</h1>
"""]

# Inference only: put the model in evaluation mode and disable gradient
# tracking while the batches are processed.
model.eval()
with torch.no_grad():
    # Requires val_loader, model and args from the earlier cells.
    # Only the first sample of every batch is visualised; `i` is the batch index.
    for i, (voxel_grid, target) in enumerate(val_loader):
        voxel_grid = voxel_grid.to(args.device)
        pred = model(voxel_grid).detach().cpu().numpy()[0]
        voxel_grid_np = voxel_grid[0, :, :, :, :].cpu().numpy()

        # Min-max normalise to [0, 1] for display. A constant grid has zero
        # range; show it as all zeros instead of dividing by zero.
        grid_min = voxel_grid_np.min()
        grid_range = voxel_grid_np.max() - grid_min
        if grid_range > 0:
            voxel_grid_np = (voxel_grid_np - grid_min) / grid_range
        else:
            voxel_grid_np = np.zeros_like(voxel_grid_np)

        # Only the first two label columns are passed on as the (x, y) target.
        html_parts.append(
            plot_voxel_grid_as_rgb_to_html(voxel_grid_np, f"Voxel grid {i}", pred, target[0][:, :2])
        )

# Close HTML document
html_parts.append("""
</body>
</html>
""")
html_doc = "".join(html_parts)

# Save the HTML document, creating the output directory if it is missing.
os.makedirs("web", exist_ok=True)
with open("web/animation_gallery.html", "w", encoding="utf-8") as f:
    f.write(html_doc)
