In [None]:
import torch
import numpy as np
import pandas as pd

from EvEye.utils.scripts.load_config import load_config
from EvEye.dataset.dataset_factory import make_dataset
from EvEye.model.model_factory import make_model
from EvEye.utils.scripts.load_config import load_config
from EvEye.dataset.DavisEyeCenter.losses import process_detector_prediction

In [None]:
# Absolute path of the YAML configuration file; it is handed to load_config()
# in the next cell to build the dataset and the model.
config_path = '/mnt/data2T/junyuan/eye-tracking/configs/TestTextDavisEyeDataset_TennSt.yaml'

In [None]:
# Build the test dataset and the model described by the YAML configuration.
config = load_config(config_path)
testDataset = make_dataset(config['dataset'])
model = make_model(config['model'])

# Resolve the target device first so the checkpoint tensors are remapped onto
# it while loading. Without map_location, torch.load restores every tensor to
# the device it was saved from, which fails on a machine lacking that device.
device = config["test"]["map_location"]
model.load_state_dict(
    torch.load(config["test"]["ckpt_path"], map_location=device)["state_dict"]
)

# Inference only: switch to evaluation mode (matters for layers such as
# dropout / batch norm, if the model contains any) before moving to the device.
model.eval()
model.to(device)

In [None]:
# Take the first item of the test dataset, prepend a dimension of size 1
# (presumably the batch dimension -- TODO confirm) and move it to the device.
first_sample = testDataset[0]
event_frames = first_sample.unsqueeze(0).to(device)

In [None]:
# Gradients are not needed for prediction; disabling autograd avoids building
# a computation graph during the forward pass.
with torch.no_grad():
    pred = model.streaming_inference(model, event_frames)
    pred = process_detector_prediction(pred)

# Drop the leading dimension of size 1, then scale row 0 by 346 and row 1 by
# 260 (presumably normalised x / y -> pixel coordinates; after the transpose
# these rows become the "x" and "y" columns written to the CSV below).
pred = pred.squeeze(0)
pred[0] *= 346
pred[1] *= 260

# One row per prediction, truncated to int32.
predictions_numpy = pred.detach().cpu().numpy().T.astype(np.int32)

In [None]:
# Prepend a running index (0 .. N-1) as the first column of the predictions.
# NOTE: this rebinds predictions_numpy, so re-running the cell adds the column again.
arange = np.arange(len(predictions_numpy))
predictions_numpy = np.concatenate(
    (arange[:, np.newaxis], predictions_numpy),
    axis=1,
)

In [None]:
# Wrap the (row_id, x, y) array in a DataFrame and write it out without the
# pandas index column.
submission_columns = ["row_id", "x", "y"]
df = pd.DataFrame(data=predictions_numpy, columns=submission_columns)
df.to_csv("submission.csv", index=False)

In [None]:
import cv2
import natsort
import os
import torch
import numpy as np

from tonic import transforms
from pathlib import Path
from tqdm import tqdm
from EvEye.utils.tonic.functional.CutMaxCount import cut_max_count
from EvEye.utils.scripts.load_config import load_config
from EvEye.model.model_factory import make_model
from EvEye.dataset.DavisEyeCenter.losses import process_detector_prediction
from EvEye.utils.tonic.slicers.SliceEventsAtIndices import slice_events_at_timepoints
from EvEye.utils.tonic.functional.ToFrameStack import to_frame_stack_numpy
from EvEye.utils.processor.TxtProcessor import TxtProcessor
from EvEye.utils.visualization.visualization import save_image
from EvEye.utils.visualization.visualization import visualize

In [None]:
# --- Inputs -----------------------------------------------------------------
# YAML configuration used to build the model and locate the checkpoint.
config_path = '/mnt/data2T/junyuan/eye-tracking/configs/TestTextDavisEyeDataset_TennSt.yaml'
# Event stream and centre labels of one recording (plain-text files).
txt_path = Path("/mnt/data2T/junyuan/eye-tracking/datasets/DavisEyeCenterDataset/test/data/user43_left_session_1_0_1_events.txt")
label_path = Path("/mnt/data2T/junyuan/eye-tracking/datasets/DavisEyeCenterDataset/test/label/user43_left_session_1_0_1_centers.txt")
# Directory that is scanned for the recording's .png frames further below.
rgb_path = '/mnt/data2T/junyuan/eye-tracking/datasets/DavisEyeCenterDatasetFrames/user43_left_session_1_0_1'

# --- Output -----------------------------------------------------------------
# Destination directory for the rendered comparison images; created on demand.
output_path = '/mnt/data2T/junyuan/eye-tracking/outputs/InferenceResultsTest_4'
os.makedirs(name=output_path, exist_ok=True)

In [None]:
# Load the raw event stream and the centre labels from their text files.
events = TxtProcessor(txt_path).load_events_from_txt()
labels = TxtProcessor(label_path).load_labels_from_txt()

# Build the model described by the YAML configuration.
config = load_config(config_path)
model = make_model(config['model'])

# Resolve the target device first so the checkpoint tensors are remapped onto
# it while loading. Without map_location, torch.load restores every tensor to
# the device it was saved from, which fails on a machine lacking that device.
device = config["test"]["map_location"]
model.load_state_dict(
    torch.load(config["test"]["ckpt_path"], map_location=device)["state_dict"]
)

# Inference only: switch to evaluation mode (matters for layers such as
# dropout / batch norm, if the model contains any) before moving to the device.
model.eval()
model.to(device)

In [None]:
# Preprocessing parameters shared by the cells below.

# Sensor geometry; the first two entries are also scaled by spatial_factor
# to obtain the downsampled sensor size.
sensor_size = (346, 260, 2)
spatial_factor = 0.5

# Length of the event window preceding the first label, in the same unit as
# the 't' fields of events / labels (presumably microseconds -- TODO confirm).
time_window = 40_000

# Options forwarded to to_frame_stack_numpy() and cut_max_count().
events_interpolation = "causal_linear"
max_count = 5

In [None]:
# Spatially downsample the event stream with tonic's Downsample transform.
downsampler = transforms.Downsample(spatial_factor=spatial_factor)
events_downsampled = downsampler(events)

# Matching sensor size: the first two entries are scaled by spatial_factor
# (truncated to int), the third entry is carried over unchanged.
sensor_size_downsampled = tuple(
    int(extent * spatial_factor) for extent in sensor_size[:2]
) + (int(sensor_size[2]),)

In [None]:
# Time span of the first frame: it ends at the first label and starts
# time_window earlier, but never before the first recorded event.
first_event_time = events['t'][0]
end_time_first = labels['t'][0]
start_time_first = max(first_event_time, end_time_first - time_window)

# If that span is empty (the first label is not later than the first event),
# drop the first label and recompute the span from the next one. The flag is
# used further below to drop the first image as well.
# NOTE: this rebinds `labels`, so the cell is not idempotent when re-run.
delete_first_flag = bool(start_time_first >= end_time_first)
if delete_first_flag:
    labels = labels[1:]
    end_time_first = labels['t'][0]
    start_time_first = max(first_event_time, end_time_first - time_window)

assert start_time_first < end_time_first

In [None]:
# Frame stack built from the downsampled events; this is the model input.
last_label_time = labels['t'][-1]
num_other_frames = labels.shape[0] - 1

# First segment: events between start_time_first and end_time_first -> 1 frame.
event_segment_first_downsampled = slice_events_at_timepoints(
    events_downsampled, start_time_first, end_time_first
)
frame_first_downsampled = to_frame_stack_numpy(
    event_segment_first_downsampled, sensor_size_downsampled, 1, events_interpolation
)

# Remaining segment: events from end_time_first up to the last label time,
# converted with a frame count of (number of labels - 1).
event_segment_others_downsampled = slice_events_at_timepoints(
    events_downsampled, end_time_first, last_label_time
)
frame_others_downsampled = to_frame_stack_numpy(
    event_segment_others_downsampled,
    sensor_size_downsampled,
    num_other_frames,
    events_interpolation,
)

event_frames_downsampled = np.concatenate(
    (frame_first_downsampled, frame_others_downsampled), axis=0
)
# Return value is ignored; presumably this clips the array in place -- TODO confirm.
cut_max_count(event_frames_downsampled, max_count, True)

# Swap the first two axes, cast to float32, prepend a dimension of size 1 and
# move the tensor to the inference device.
frames_tensor = torch.from_numpy(event_frames_downsampled).moveaxis(0, 1)
event_frames_pred = frames_tensor.to(torch.float32).unsqueeze(0).to(device)

In [None]:
# Frame stack built from the original (non-downsampled) events; these frames
# are the ones rendered in the visualisation loop at the end of the notebook.
last_label_time = labels['t'][-1]
num_other_frames = labels.shape[0] - 1

# First segment: events between start_time_first and end_time_first -> 1 frame.
event_segment_first = slice_events_at_timepoints(
    events, start_time_first, end_time_first
)
frame_first = to_frame_stack_numpy(
    event_segment_first, sensor_size, 1, events_interpolation
)

# Remaining segment: events from end_time_first up to the last label time,
# converted with a frame count of (number of labels - 1).
event_segment_others = slice_events_at_timepoints(
    events, end_time_first, last_label_time
)
frame_others = to_frame_stack_numpy(
    event_segment_others, sensor_size, num_other_frames, events_interpolation
)

event_frames = np.concatenate((frame_first, frame_others), axis=0)
# Return value is ignored; presumably this clips the array in place -- TODO confirm.
cut_max_count(event_frames, max_count, True)

In [None]:
# Sanity check: display the shapes of the full-resolution frame stack, the
# downsampled frame stack and the tensor that is fed to the model.
event_frames.shape, event_frames_downsampled.shape,  event_frames_pred.shape

In [None]:
# Gradients are not needed for prediction; disabling autograd avoids building
# a computation graph during the forward pass.
with torch.no_grad():
    pred = model.streaming_inference(model, event_frames_pred)
    pred = process_detector_prediction(pred)

# Drop the leading dimension of size 1, then scale row 0 / row 1 by the first
# two entries of sensor_size (346 and 260) instead of repeating the literals.
# Presumably this maps normalised x / y to full-resolution pixel coordinates.
pred = pred.squeeze(0)
pred[0] *= sensor_size[0]
pred[1] *= sensor_size[1]

# One row per prediction, truncated to int32.
predictions_numpy = pred.detach().cpu().numpy().T.astype(np.int32)

In [None]:
# Collect the .png files of the recording and order them naturally
# (natsort), as full paths.
images_path = [name for name in os.listdir(rgb_path) if name.endswith('.png')]
images = natsort.natsorted(os.path.join(rgb_path, name) for name in images_path)

# The first label was dropped earlier when the flag is set, so drop the first
# image too to keep images and labels the same length.
images = images[1:] if delete_first_flag else images

In [None]:
# Predictions, labels and images must line up one-to-one.
assert predictions_numpy.shape[0] == labels.shape[0] == len(images)

for index in tqdm(range(len(images)), desc="Saving images ..."):
    image = cv2.imread(images[index])
    # cv2.imread does not raise on a missing/unreadable file, it returns None;
    # fail here with the offending path instead of deep inside cv2.circle.
    if image is None:
        raise OSError(f"cv2.imread could not read image: {images[index]}")
    event_frame = visualize(event_frames[index])

    # Marker positions: label centre and predicted centre for this frame.
    center_x, center_y = labels['x'][index], labels['y'][index]
    pred_x, pred_y = predictions_numpy[index]
    markers = (
        ((int(center_x), int(center_y)), (0, 255, 0)),    # label
        ((int(pred_x), int(pred_y)), (255, 255, 255)),    # prediction
    )

    # Draw both filled markers (radius 3) on the image and on the event frame.
    for point, color in markers:
        event_frame = cv2.circle(event_frame, point, 3, color, -1)
        image = cv2.circle(image, point, 3, color, -1)

    # Place image and event frame side by side and write the result to disk.
    combined_image = np.concatenate([image, event_frame], axis=1)
    save_image(combined_image, f"{output_path}/{index:04}.png")