In [1]:
from pathlib import Path
import numpy as np
import torch
import os
from natsort import natsorted
from tqdm import tqdm
import tonic.functional as toc

from EvEye.utils.scripts.CacheFrameStack import load_memmap
from EvEye.utils.dvs_common_utils.representation.TorchFrameStack import (
    TorchFrameStack,
)

In [2]:
class CacheDavisEyeCenterDataset:
    def __init__(
        self,
        root_path: Path | str,
        split="train", 
        time_window=40000, 
        frames_per_segment=50,
        spatial_downsample=(2, 2),
        events_interpolation="bilinear",  # 'bilinear', 'nearest', 'causal_linear'
    ):
        assert time_window == 40000
        self.root_path = Path(root_path)
        self.split = split
        self.time_window = time_window
        self.frames_per_segment = frames_per_segment
        self.time_window_per_segment = time_window * frames_per_segment
        self.spatial_downsample = spatial_downsample
        self.events_interpolation = events_interpolation
        
        self.events, self.labels = [], []
        self.num_frames_list, self.num_segments_list = [], []

        self.data_base_path: Path = self.root_path / self.split / "cached_data"
        self.label_base_path: Path = self.root_path / self.split / "cached_label"
        self.data_paths: list = natsorted(
            self.data_base_path.glob("events_batch_*.memmap")
        )
        self.label_paths: list = natsorted(
            self.label_base_path.glob("labels_batch_*.memmap")
        )
        self.data_info_paths: list = natsorted(
            self.data_base_path.glob("events_info_batch_*.txt")
        )
        self.label_info_paths: list = natsorted(
            self.label_base_path.glob("labels_info_batch_*.txt")
        )
        self.data_indices_paths: list = natsorted(
            self.data_base_path.glob("events_indices_batch_*.memmap")
        )
        self.label_indices_paths: list = natsorted(
            self.label_base_path.glob("labels_indices_batch_*.memmap")
        )
        self.data_indices_info_paths = natsorted(
            self.data_base_path.glob("events_indices_info_batch_*.txt")
        )
        self.label_indices_info_paths = natsorted(
            self.label_base_path.glob("labels_indices_info_batch_*.txt")
        )
        for (
            data_path,
            label_path,
            data_info_path,
            label_info_path,
            data_indices_path,
            label_indices_path,
            data_indices_info_path,
            label_indices_info_path,
        ) in tqdm(
            zip(
                self.data_paths,
                self.label_paths,
                self.data_info_paths,
                self.label_info_paths,
                self.data_indices_paths,
                self.label_indices_paths,
                self.data_indices_info_paths,
                self.label_indices_info_paths,
            ),
            total=len(self.data_paths),
            desc="Loading data...",
        ):
            events = load_memmap(data_path, data_info_path)
            events_indices = load_memmap(data_indices_path, data_indices_info_path)
            labels = load_memmap(label_path, label_info_path)
            labels_indices = load_memmap(
                label_indices_path, label_indices_info_path
            )
            for indice in events_indices:
                event = events[:, indice[0] : indice[1]]
                self.events.append(event)

            for indice in labels_indices:
                num_frames = indice[1] - indice[0]
                self.num_frames_list.append(num_frames)
                self.num_segments_list.append(num_frames // frames_per_segment)
                label = labels[:, indice[0] : indice[1]]
                self.labels.append(label)
        self.total_segments = sum(self.num_segments_list)

    def get_index(self, file_lens, index):
        file_lens_cumsum = np.cumsum(np.array(file_lens))
        file_id = np.searchsorted(file_lens_cumsum, index, side="right")
        sample_id = index - file_lens_cumsum[file_id - 1] if file_id > 0 else index

        return file_id, sample_id
    
    def __len__(self):
        return self.total_segments

    def __getitem__(self, index):
        file_id, segment_id = self.get_index(self.num_segments_list, index)
        event, label = self.events[file_id], self.labels[file_id]
        start_time = (
                label[0][0] + segment_id * self.time_window * self.frames_per_segment
            )
        end_time = start_time + self.time_window * self.frames_per_segment

        start_event_id = np.searchsorted(event[3], start_time, side="left")
        end_event_id = np.searchsorted(event[3], end_time, side="left")
        event_segment = event[:, start_event_id:end_event_id]
        event_segment = np.array(event_segment)
        event_segment[-1] -= start_time
        num_frames = self.frames_per_segment
        event_segment = torch.from_numpy(event_segment)
        # print(event_segment.shape)
        event_frame = TorchFrameStack(
                events=event_segment,
                size=(
                    260 // self.spatial_downsample[0],
                    346 // self.spatial_downsample[1],
                ),
                num_frames=num_frames,
                spatial_downsample=self.spatial_downsample,
                temporal_downsample=self.time_window,
                mode=self.events_interpolation,
            )
        event_frame = event_frame.moveaxis(0, 1)
        event_frame = event_frame.numpy()

        start_label_id = segment_id * self.frames_per_segment
        end_label_id = start_label_id + self.frames_per_segment
        label_segment = label[:, start_label_id:end_label_id]
        label_x = (label_segment[1] / 2).round()
        label_y = (label_segment[2] / 2).round()
        label_coord = np.vstack([label_x, label_y])

        closeness = 1- np.array(label_segment[3])

        return event_frame, label_coord, closeness

        

In [3]:
# Build the segment dataset from the cached training split
# (40000-unit frames, 50 frames per segment, 2x spatial downsampling).
dataset_config = dict(
    split="train",
    time_window=40000,
    frames_per_segment=50,
    spatial_downsample=(2, 2),
    events_interpolation="bilinear",  # 'bilinear', 'nearest', 'causal_linear'
)
dataset = CacheDavisEyeCenterDataset(
    root_path="/mnt/data2T/junyuan/eye-tracking/testDataset", **dataset_config
)

Loading data...:   0%|          | 0/1 [00:00<?, ?it/s]

Loading data...: 100%|██████████| 1/1 [00:00<00:00, 40.71it/s]


In [5]:
# Shape of the event-frame stack of the first segment.
np.shape(dataset[0][0])

(2, 50, 130, 173)

In [22]:
# Optional smoke test: push the dataset through a DataLoader and report the
# batch shapes/dtypes. Disabled by default because it materializes every
# segment; set the flag to True to run it.
RUN_DATALOADER_SMOKE_TEST = False
if RUN_DATALOADER_SMOKE_TEST:
    from torch.utils.data import DataLoader

    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
    for i, (x, y, z) in enumerate(dataloader):
        print(f"Batch {i}:")
        print(f"Data shape: {x.shape}")
        print(f"Data dtype: {x.dtype}")
        print(f"Label shape: {y.shape}")
        print(f"Label dtype: {y.dtype}")
        print(f"Close shape: {z.shape}")
        print(f"Close dtype: {z.dtype}")
        print()

In [23]:
# Number of fixed-length segments exposed by the dataset.
num_segments = len(dataset)
num_segments

154

In [24]:
# Label (x, y) of frame 35 in segment 0. Fetch the segment once: every
# dataset[...] call rebuilds the full event-frame stack, so indexing it twice
# doubled the work.
segment0_coords = dataset[0][1]
segment0_coords[0][35], segment0_coords[1][35]

(82.0, 68.0)

In [25]:
output_data_path = "/mnt/data2T/junyuan/eye-tracking/np_data"
output_label_path = "/mnt/data2T/junyuan/eye-tracking/np_label"
output_close_path = "/mnt/data2T/junyuan/eye-tracking/np_close"
# Directories in the same order as the (data, label, close) tuple of an item.
output_dirs = (output_data_path, output_label_path, output_close_path)
for output_dir in output_dirs:
    os.makedirs(output_dir, exist_ok=True)

# Export every segment as three aligned .npy files named by segment index.
for i in tqdm(range(len(dataset)), desc="Saving data"):
    data, label, close = dataset[i]
    for output_dir, array in zip(output_dirs, (data, label, close)):
        np.save(f"{output_dir}/{i}.npy", array)

Saving data: 100%|██████████| 154/154 [00:26<00:00,  5.84it/s]


In [26]:
# Re-load segment 0 from the export directory written above; reusing
# output_data_path keeps the location defined in one place.
data = np.load(f"{output_data_path}/0.npy")
data.shape

(2, 50, 130, 173)

In [27]:
# Unpack segment 0 once instead of recomputing it for every component.
first_frames, first_coords, first_closeness = dataset[0]
first_frames.shape, first_coords.shape, first_closeness.shape

((2, 50, 130, 173), (2, 50), (50,))

In [28]:
# Round-trip check: the exported .npy must equal a freshly computed segment 0.
# Assert so that "Run All" fails loudly instead of just displaying False.
roundtrip_ok = np.array_equal(data, dataset[0][0])
assert roundtrip_ok, "exported frames for segment 0 differ from dataset[0][0]"
roundtrip_ok

True

In [29]:
def merge_arrays(arrays, dtype=None):
    """Concatenate 2-D arrays along axis 1 and record where each one landed.

    Args:
        arrays: Non-empty sequence of 2-D arrays (memmap slices work too) that
            all have the same number of rows.
        dtype: dtype of the merged array. ``None`` (default) uses the common
            promoted dtype of the inputs instead of silently casting
            everything to float64.

    Returns:
        ``(merged_array, indices)``: ``merged_array`` has shape
        ``(rows, total columns)`` and ``indices`` is a list of ``(start, end)``
        column ranges, one per input, so that ``merged_array[:, start:end]``
        equals that input.

    Raises:
        ValueError: if ``arrays`` is empty, an input is not 2-D, or the
            inputs disagree on their row count (which would otherwise
            broadcast silently for 1-row inputs).
    """
    if len(arrays) == 0:
        raise ValueError("merge_arrays() needs at least one array")

    num_rows = arrays[0].shape[0]
    common_dtype = arrays[0].dtype
    for position, arr in enumerate(arrays):
        if arr.ndim != 2 or arr.shape[0] != num_rows:
            raise ValueError(
                f"array {position} has shape {arr.shape}; expected 2-D with "
                f"{num_rows} rows"
            )
        # Pairwise promotion avoids np.result_type's limit on argument count.
        common_dtype = np.promote_types(common_dtype, arr.dtype)
    if dtype is None:
        dtype = common_dtype

    total_length = sum(arr.shape[1] for arr in arrays)
    # np.empty is safe: the loop below writes every column exactly once.
    merged_array = np.empty((num_rows, total_length), dtype=dtype)
    indices = []
    current_index = 0
    for arr in tqdm(arrays, desc="Merging arrays"):
        end_index = current_index + arr.shape[1]
        merged_array[:, current_index:end_index] = arr
        indices.append((current_index, end_index))
        current_index = end_index
    return merged_array, indices

def create_memmap(data, data_file, info_file):
    """Persist ``data`` to a raw memmap file plus a plain-text header.

    The header written to ``info_file`` consists of the two lines
    ``Data shape: <shape tuple>`` and ``Data dtype: <dtype>``, which
    ``load_memmap`` parses to reopen the file.

    Returns:
        The writable ``np.memmap`` (mode ``"w+"``) backing ``data_file``.
    """
    header = (
        f"Data shape: {data.shape}\n",
        f"Data dtype: {data.dtype}\n",
    )
    backing = np.memmap(data_file, dtype=data.dtype, mode="w+", shape=data.shape)
    backing[:] = data
    backing.flush()  # push the bytes to disk before anyone re-opens the file
    with open(info_file, "w") as header_file:
        header_file.writelines(header)
    return backing

def load_memmap(data_file, info_file):
    """Open a read-only ``np.memmap`` described by a two-line header file.

    ``info_file`` must start with ``Data shape: <tuple>`` followed by
    ``Data dtype: <dtype>`` (the format ``create_memmap`` in this notebook
    writes).

    NOTE: this definition shadows the ``load_memmap`` imported from
    ``EvEye.utils.scripts.CacheFrameStack`` in the first cell. Any later call,
    including ``CacheDavisEyeCenterDataset.__init__`` if the dataset is rebuilt,
    resolves to this version.

    Returns:
        ``np.memmap`` opened with ``mode="r"``.
    """
    with open(info_file, "r") as f:
        lines = f.readlines()
    shape_str = lines[0].strip().split(": ")[1]
    dtype_str = lines[1].strip().split(": ")[1]
    # str(shape) renders 1-D shapes as "(n,)"; skip the empty piece left by the
    # trailing comma instead of feeding "" to int().
    shape = tuple(
        int(dim) for dim in shape_str.strip("()").split(",") if dim.strip()
    )
    dtype = np.dtype(dtype_str)
    return np.memmap(data_file, dtype=dtype, mode="r", shape=shape)

In [30]:
# Concatenate the per-recording arrays along the sample axis; the returned
# (start, end) pairs let each recording be sliced back out.
(merged_events, events_indices), (merged_labels, labels_indices) = (
    merge_arrays(dataset.events),
    merge_arrays(dataset.labels),
)

Merging arrays: 100%|██████████| 2/2 [00:00<00:00,  3.82it/s]
Merging arrays: 100%|██████████| 2/2 [00:00<00:00, 9078.58it/s]


In [31]:
# Write each merged array and its per-recording index table to a memmap in
# the current working directory, each with a shape/dtype header file.
events_memmap, labels_memmap, events_indices_memmap, labels_indices_memmap = (
    create_memmap(array, data_file, info_file)
    for array, data_file, info_file in (
        (merged_events, 'events.memmap', 'events_info.txt'),
        (merged_labels, 'labels.memmap', 'labels_info.txt'),
        (np.array(events_indices), 'events_indices.memmap', 'events_indices_info.txt'),
        (np.array(labels_indices), 'labels_indices.memmap', 'labels_indices_info.txt'),
    )
)

In [32]:
# Re-open the four memmaps just written, read-only, from their header files.
events, events_indices, labels, labels_indices = (
    load_memmap(data_file, info_file)
    for data_file, info_file in (
        ('events.memmap', 'events_info.txt'),
        ('events_indices.memmap', 'events_indices_info.txt'),
        ('labels.memmap', 'labels_info.txt'),
        ('labels_indices.memmap', 'labels_indices_info.txt'),
    )
)

In [33]:
# Shapes of the reloaded merged event and label arrays.
tuple(array.shape for array in (events, labels))

((4, 18522866), (4, 7764))

In [34]:
# (start, end) column range of each recording inside the merged label array.
labels_indices

memmap([[   0, 5070],
        [5070, 7764]])

In [35]:
# Slice the second recording back out of the merged label array.
label = labels[:, slice(*labels_indices[1])]

In [36]:
np.shape(label)

(4, 2694)