In [1]:
import cv2
import natsort
import os
import torch
import numpy as np

from tonic import transforms
from pathlib import Path
from tqdm import tqdm
from torch.utils.data import DataLoader
from EvEye.utils.scripts.load_config import load_config
from EvEye.model.model_factory import make_model
from EvEye.dataset.dataset_factory import make_dataset
from EvEye.dataset.DavisEyeCenter.losses import process_detector_prediction
from EvEye.utils.tonic.slicers.SliceEventsAtIndices import slice_events_at_timepoints
from EvEye.utils.tonic.functional.ToFrameStack import to_frame_stack_numpy
from EvEye.utils.processor.TxtProcessor import TxtProcessor
from EvEye.utils.visualization.visualization import save_image
from EvEye.utils.visualization.visualization import visualize

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the experiment configuration and pull out everything later cells need:
# the label file path, the dataset object, the RGB frame directory and the
# directory that rendered images are written to.
config_path = '/mnt/data2T/junyuan/eye-tracking/configs/TestTextDavisEyeDataset_TennSt.yaml'
config = load_config(config_path)

dataset_config = config['dataset']
labels_path = dataset_config['label_path']
dataset = make_dataset(dataset_config)

test_config = config['test']
rgb_path = test_config['rgb_path']
output_path = test_config['output_path']

# Create the output directory up front; an existing directory is not an error.
os.makedirs(output_path, exist_ok=True)

In [3]:
# Build the network from the 'model' config section and restore its weights
# from the checkpoint named in the 'test' section.
device = config["test"]["map_location"]
model = make_model(config['model'])

# Pass the configured device as map_location so checkpoint tensors are remapped
# while loading; without it a checkpoint saved on a different device can fail
# to deserialise on this machine.
checkpoint = torch.load(config["test"]["ckpt_path"], map_location=device)
model.load_state_dict(checkpoint["state_dict"])
model.to(device)

# The network contains BatchNorm3d layers, so switch to evaluation mode before
# running inference. eval() returns the module, so the cell still displays the
# model summary as its output.
model.eval()

TennSt(
  (backbone): Sequential(
    (0): TemporalBlock(
      (block): Sequential(
        (0): Conv3d(2, 8, kernel_size=(5, 1, 1), stride=(1, 1, 1), bias=False)
        (1): GroupNormBlock(
          (gn_block): Sequential(
            (0): CausalGroupNormBlock(4, 8, eps=1e-05, affine=True)
            (1): ActivateLayer(
              (act_layer): ReLU()
            )
          )
        )
      )
    )
    (1): SpatialBlock(
      (block): Sequential(
        (0): Conv3d(8, 16, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), bias=False)
        (1): BatchNormBlock(
          (bn_block): Sequential(
            (0): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): ActivateLayer(
              (act_layer): ReLU()
            )
          )
        )
      )
    )
    (2): TemporalBlock(
      (block): Sequential(
        (0): Conv3d(16, 32, kernel_size=(5, 1, 1), stride=(1, 1, 1), bias=False)
        (1): GroupNormBlock(
   

In [4]:
# Fetch sample 0 a single time and reuse it: indexing the dataset twice would
# build the sample twice and, if __getitem__ is not deterministic, could pair
# the model input with a visualisation of a different draw.
sample = dataset[0]
event_frames = sample[0].to(device)  # model input, moved to the inference device
event_frames_vis = visualize(sample[1])  # renderable version of the second sample element

# Ground-truth labels parsed from the text file referenced by the config.
labels = TxtProcessor(labels_path).load_labels_from_txt()

In [5]:
# Display the dimensions of the visualisation array as the cell output.
event_frames_vis.shape

(5009, 260, 346, 3)

In [6]:
# Scale factors applied to the two prediction rows. They match the 260 x 346
# spatial dimensions reported for event_frames_vis.
# NOTE(review): presumably pred holds normalised (x, y) coordinates - confirm.
FRAME_WIDTH = 346
FRAME_HEIGHT = 260

# No gradients are needed for evaluation; disabling autograd avoids building a
# graph across the whole streamed sequence.
with torch.no_grad():
    pred, inference_times = model.streaming_inference(model, event_frames)
    pred = process_detector_prediction(pred)

pred = pred.squeeze(0)
pred[0] *= FRAME_WIDTH
pred[1] *= FRAME_HEIGHT

# Transpose so that each row is one prediction, then truncate to integers
# (astype truncates toward zero rather than rounding).
predictions_numpy = pred.detach().cpu().numpy().T.astype(np.int32)



In [7]:
# Save inference times: write them to a single-column CSV file.

import pandas as pd

inference_times_path = "/mnt/data2T/junyuan/eye-tracking/outputs/InferenceTimes"
file_name = "FixedCount5000-Down-Aug-NoFlip.csv"
file_path = "/".join((inference_times_path, file_name))

# Make sure the target directory exists before writing into it.
os.makedirs(inference_times_path, exist_ok=True)

df = pd.DataFrame(inference_times, columns=['Inference Time'])
df.to_csv(file_path, index=False)  # index=False: do not write the row index

In [None]:
# """Draw inference time"""

# import seaborn as sns
# import matplotlib.pyplot as plt

# inference_times = inference_times[1:]
# figsize = (10, 20)

# plt.figure(figsize=figsize)
# sns.lineplot(x=range(1, len(inference_times) + 1), y=inference_times)
# plt.xlabel('Inference Count')
# plt.ylabel('Inference Time')
# plt.title('Inference Time vs. Inference Count Line Plot')
# plt.show()

# plt.figure(figsize=figsize)
# sns.violinplot(y=inference_times)
# plt.xlabel('Inference Time')
# plt.ylabel('Density')
# plt.title('Violin Plot of Inference Time')
# plt.show()

# plt.figure(figsize=figsize)
# sns.boxplot(y=inference_times)
# plt.xlabel('Inference Time')
# plt.ylabel('Box Plot')
# plt.title('Box Plot of Inference Time')
# plt.show()

In [None]:
# Collect the names of the '.png' files found directly inside rgb_path ...
images_path = list(filter(lambda name: name.endswith('.png'), os.listdir(rgb_path)))

# ... then turn them into paths under rgb_path and order them with a natural
# sort, so numeric parts of the file names are compared as numbers.
images = natsort.natsorted(
    list(map(lambda name: os.path.join(rgb_path, name), images_path))
)
# if delete_first_flag:
#     images = images[1:]

In [None]:
# There must be exactly one prediction per label.
assert predictions_numpy.shape[0] == labels.shape[0]

# Number of surplus RGB images relative to the labels.
delta = len(images) - labels.shape[0]

# A negative delta would make the slice below silently keep only the trailing
# images instead of failing, so reject it here with a clear message.
assert delta >= 0, (
    f"found {len(images)} RGB images but {labels.shape[0]} labels; "
    "cannot align images with labels"
)

# Drop the first `delta` images so the list has one entry per label.
images = images[delta:]

# Cell output: the three counts, which should now all be equal.
predictions_numpy.shape[0], labels.shape[0], len(images)

In [None]:
# Draw the labelled and predicted centres on every RGB frame and on the matching
# event-frame visualisation, place the two side by side and save the result.
assert predictions_numpy.shape[0] == labels.shape[0] == len(images)

MARKER_RADIUS = 3
LABEL_COLOR = (0, 255, 0)
PREDICTION_COLOR = (255, 255, 255)
FILLED = -1  # a negative thickness makes cv2.circle draw a filled disc

for index in tqdm(range(len(images)), desc="Saving images ..."):
    image = cv2.imread(images[index])
    if image is None:
        # cv2.imread signals an unreadable or missing file by returning None
        # rather than raising; fail here with the offending path instead of
        # with an opaque error inside cv2.circle.
        raise FileNotFoundError(f"cv2.imread could not read image: {images[index]}")

    center_x, center_y = labels['x'][index], labels['y'][index]
    pred_x, pred_y = predictions_numpy[index]
    label_point = (int(center_x), int(center_y))
    prediction_point = (int(pred_x), int(pred_y))

    # Annotate the event-frame visualisation (this updates event_frames_vis).
    event_frames_vis[index] = cv2.circle(
        event_frames_vis[index], label_point, MARKER_RADIUS, LABEL_COLOR, FILLED
    )
    event_frames_vis[index] = cv2.circle(
        event_frames_vis[index], prediction_point, MARKER_RADIUS, PREDICTION_COLOR, FILLED
    )

    # Annotate the RGB frame with the same two markers.
    image = cv2.circle(image, label_point, MARKER_RADIUS, LABEL_COLOR, FILLED)
    image = cv2.circle(image, prediction_point, MARKER_RADIUS, PREDICTION_COLOR, FILLED)

    # Side-by-side layout: RGB frame on the left, event frame on the right.
    combined_image = np.concatenate([image, event_frames_vis[index]], axis=1)
    combined_image = cv2.cvtColor(combined_image, cv2.COLOR_BGR2RGB)
    save_image(combined_image, f"{output_path}/{index:04}.png")