In [1]:
import os
import torch
import numpy as np
import cv2

from tqdm import tqdm
from pathlib import Path
from EvEye.utils.tonic.functional.ToFrameStack import to_frame_stack_numpy
from EvEye.utils.tonic.slicers.SliceEventsAtIndices import slice_events_at_timepoints
from EvEye.utils.processor.TxtProcessor import TxtProcessor
from EvEye.utils.visualization.visualization import (
    visualize,
    save_image,
    load_image,
    draw_contour,
)
from tqdm import tqdm
from natsort import natsorted
from EvEye.model.model_factory import make_model
from EvEye.utils.scripts.process_model_output import process_model_output
from EvEye.utils.scripts.load_config import load_config
import albumentations as A
from albumentations.pytorch import ToTensorV2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
time_window = 10000
sensor_size = (346, 260, 2)

In [3]:
output_base_path = Path("/mnt/data2T/junyuan/eye-tracking/outputs/EventMasks")
root_txt_path = Path(
    "/mnt/data2T/junyuan/eye-tracking/datasets/DavisEyeCenterDataset/test/data"
)
root_rgb_path = Path(
    "/mnt/data2T/junyuan/eye-tracking/datasets/DavisEyeCenterDatasetFrames"
)
user_paths = natsorted([f for f in root_rgb_path.iterdir() if f.is_dir()])

In [4]:
config_path = Path("/mnt/data2T/junyuan/eye-tracking/configs/OutputGroundTruth.yaml")
config = load_config(config_path)
model = make_model(config["model"])
model.load_state_dict(
    torch.load(config["test"]["ckpt_path"], map_location="cuda:0")["state_dict"]
)
model = model.cuda()
model.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

In [5]:
transform = A.Compose(
    [
        A.Resize(height=240, width=346),
        A.Normalize(
            mean=[0.0],
            std=[1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ]
)

In [6]:
for user_path in user_paths:
    print(f"Processing {user_path.stem}...")
    rgb_path = natsorted(list(user_path.glob("*.png")))
    txt_path = Path(f"{root_txt_path}/{user_path.stem}_events.txt")
    output_path = output_base_path / user_path.stem
    os.makedirs(output_path, exist_ok=True)
    if not txt_path.exists():
        continue
    events = TxtProcessor(txt_path).load_events_from_txt()
    events_end_time = events['t'][-1]
    events_start_time = events['t'][0]
    num_frames = (events_end_time - events_start_time) // time_window
    count = 0
    for rgb in tqdm(rgb_path):
        end_time = int(rgb.stem.split("_")[1])
        start_time = end_time - time_window
        event_segment = slice_events_at_timepoints(events, start_time, end_time)
        if len(event_segment) < 1000:
            continue
        event_frame = to_frame_stack_numpy(
            event_segment,
            sensor_size,
            1,
            "causal_linear",
            start_time,
            end_time,
        )
        event_frame_vis = visualize(event_frame)
        event_frame_vis_name = str(output_path / f"{count:05}_{end_time}.png")
        save_image(event_frame_vis, event_frame_vis_name)
        image_gray = load_image(str(rgb), "grayscale")[0]
        transformed = transform(image=image_gray)
        image = transformed["image"]
        image = image.unsqueeze(0)
        with torch.no_grad():
            output = model(image.cuda())
        mask = process_model_output(output, use_softmax=True)
        mask = mask.detach().cpu().numpy()  # shape: [260, 346]
        mask = cv2.resize(
            mask,
            (sensor_size[0], sensor_size[1]),
            interpolation=cv2.INTER_NEAREST,
        )
        mask_vis = draw_contour(mask)
        save_image(mask_vis, str(output_path / f"{count:05}_{end_time}_mask.png"))
        image_name = str(output_path / f"{count:05}_{end_time}_rgb.png")
        save_image(image_gray, image_name)
        count += 1

Processing user1_left_session_1_0_1...
Processing user1_left_session_1_0_2...
Processing user1_left_session_2_0_1...
Processing user1_left_session_2_0_2...
Processing user1_right_session_1_0_1...
Processing user1_right_session_1_0_2...
Processing user1_right_session_2_0_1...
Processing user1_right_session_2_0_2...
Processing user2_left_session_1_0_1...
Processing user2_left_session_1_0_2...
Processing user2_left_session_2_0_1...
Processing user2_left_session_2_0_2...
Processing user2_right_session_1_0_1...
Processing user2_right_session_1_0_2...
Processing user2_right_session_2_0_1...
Processing user2_right_session_2_0_2...
Processing user3_left_session_1_0_1...
Processing user3_left_session_1_0_2...
Processing user3_left_session_2_0_1...
Processing user3_left_session_2_0_2...
Processing user3_right_session_1_0_1...
Processing user3_right_session_1_0_2...
Processing user3_right_session_2_0_1...
Processing user3_right_session_2_0_2...
Processing user4_left_session_1_0_1...
Processing us

100%|██████████| 5010/5010 [00:08<00:00, 622.52it/s]


Processing user43_left_session_1_0_2...


100%|██████████| 2843/2843 [00:04<00:00, 581.85it/s]


Processing user43_left_session_2_0_1...


100%|██████████| 5002/5002 [00:07<00:00, 693.47it/s]


Processing user43_left_session_2_0_2...


100%|██████████| 2834/2834 [00:04<00:00, 634.39it/s]


Processing user43_right_session_1_0_1...


100%|██████████| 4975/4975 [00:06<00:00, 727.85it/s] 


Processing user43_right_session_1_0_2...


100%|██████████| 2819/2819 [00:03<00:00, 705.27it/s]


Processing user43_right_session_2_0_1...


100%|██████████| 4977/4977 [00:05<00:00, 835.34it/s] 


Processing user43_right_session_2_0_2...


100%|██████████| 2813/2813 [00:03<00:00, 809.53it/s]


Processing user44_left_session_1_0_1...


100%|██████████| 4990/4990 [00:09<00:00, 501.60it/s]


Processing user44_left_session_1_0_2...


100%|██████████| 2802/2802 [00:05<00:00, 480.29it/s]


Processing user44_left_session_2_0_1...


100%|██████████| 4969/4969 [00:09<00:00, 536.80it/s]


Processing user44_left_session_2_0_2...


100%|██████████| 2787/2787 [00:05<00:00, 489.77it/s]


Processing user44_right_session_1_0_1...


100%|██████████| 4963/4963 [00:07<00:00, 654.56it/s]


Processing user44_right_session_1_0_2...


100%|██████████| 2771/2771 [00:03<00:00, 730.98it/s]


Processing user44_right_session_2_0_1...


100%|██████████| 4937/4937 [00:07<00:00, 680.37it/s]


Processing user44_right_session_2_0_2...


100%|██████████| 2766/2766 [00:04<00:00, 650.95it/s]


Processing user45_left_session_1_0_1...


100%|██████████| 4997/4997 [00:09<00:00, 542.31it/s]


Processing user45_left_session_1_0_2...


100%|██████████| 2875/2875 [00:06<00:00, 460.01it/s]


Processing user45_left_session_2_0_1...


100%|██████████| 4985/4985 [00:09<00:00, 503.71it/s]


Processing user46_left_session_1_0_1...


100%|██████████| 5043/5043 [00:09<00:00, 550.43it/s]


Processing user46_left_session_1_0_2...


100%|██████████| 2823/2823 [00:04<00:00, 646.79it/s]


Processing user46_left_session_2_0_1...


100%|██████████| 5035/5035 [00:10<00:00, 468.40it/s]


Processing user46_left_session_2_0_2...


100%|██████████| 2806/2806 [00:05<00:00, 532.03it/s]


Processing user46_right_session_1_0_1...


100%|██████████| 5019/5019 [00:08<00:00, 621.14it/s]


Processing user46_right_session_1_0_2...


100%|██████████| 2801/2801 [00:03<00:00, 768.52it/s] 


Processing user46_right_session_2_0_1...


100%|██████████| 5011/5011 [00:09<00:00, 519.40it/s]


Processing user46_right_session_2_0_2...


100%|██████████| 2780/2780 [00:04<00:00, 610.64it/s]


Processing user47_left_session_1_0_1...


100%|██████████| 4968/4968 [00:08<00:00, 606.02it/s]


Processing user47_left_session_1_0_2...


100%|██████████| 2785/2785 [00:05<00:00, 467.90it/s]


Processing user47_left_session_2_0_1...


100%|██████████| 5020/5020 [00:09<00:00, 544.15it/s]


Processing user47_left_session_2_0_2...


100%|██████████| 2822/2822 [00:07<00:00, 379.05it/s]


Processing user47_right_session_1_0_1...


100%|██████████| 4941/4941 [00:06<00:00, 765.92it/s] 


Processing user47_right_session_1_0_2...


100%|██████████| 2760/2760 [00:04<00:00, 683.20it/s]


Processing user47_right_session_2_0_1...


100%|██████████| 4998/4998 [00:06<00:00, 732.41it/s]


Processing user47_right_session_2_0_2...


100%|██████████| 2794/2794 [00:04<00:00, 666.96it/s]


Processing user48_left_session_1_0_1...


100%|██████████| 5060/5060 [00:10<00:00, 502.87it/s]


Processing user48_left_session_1_0_2...


100%|██████████| 2821/2821 [00:06<00:00, 441.81it/s]
