In [None]:
import torch
import numpy as np
import albumentations as A

from pathlib import Path
from EvEye.model.DavisEyeEllipse.EPNet.EPNet import EPNet
from EvEye.utils.tonic.functional.ToFrameStack import to_frame_stack_numpy
from EvEye.utils.cache.MemmapCacheStructedEvents import *
from EvEye.utils.visualization.visualization import *
from EvEye.utils.tonic.functional.CutMaxCount import cut_max_count
from EvEye.dataset.DavisEyeEllipse.utils import *

In [None]:
# Index of the cached sample that is loaded and pushed through the network.
index = 100

data_path = Path("/mnt/data2T/junyuan/Datasets/FixedTime10000Dataset/train/cached_data")
ellipse_path = Path(
    "/mnt/data2T/junyuan/Datasets/FixedTime10000Dataset/train/cached_ellipse"
)
model_path = Path(
    "/mnt/data2T/junyuan/eye-tracking/logs/EPNet_FixedTime10000/version_0/checkpoints/epochepoch=51-val_lossval_loss=12.4443.ckpt"
)

# Prefer the first GPU, but fall back to the CPU so a fresh "Run All" still
# works on a machine where CUDA is not available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

ellipse = convert_to_ellipse(load_ellipse(index, ellipse_path))
# NOTE(review): the third argument (5000) is passed positionally; presumably
# the number of events to load for this sample -- confirm against load_event_segment.
event_segment = load_event_segment(index, data_path, 5000)
model = EPNet(input_channels=2)
# map_location="cpu": without it torch.load restores every tensor onto the
# device it was saved from, which fails when that device does not exist here.
# The weights are moved to `device` by model.to(device) below.
model.load_state_dict(torch.load(model_path, map_location="cpu")["state_dict"])
model.eval()
model.to(device)

In [None]:
def event_to_frame(
    event_segment,
    sensor_size=(346, 260, 2),
    events_interpolation='causal_linear',
    weight=10,
):
    """
    Accumulate one segment of events into a single channel-last frame.

    The whole segment is collapsed into one time bin that spans from the
    first to the last timestamp found in ``event_segment['t']``.

    Args:
        event_segment (np.array): structured array of events with fields ['x', 'y', 'p', 't']
        sensor_size (tuple): size of the sensor (width, height, n_channels)
        events_interpolation (str): interpolation mode forwarded to the frame builder
        weight (int): per-event weight forwarded to the frame builder

    Returns:
        event_frame (np.array): accumulated frame with the channel axis last,
            i.e. (height, width, n_channels) -- (260, 346, 2) for the default
            sensor size, assuming the frame builder returns the stack
            channel-first.

    """
    timestamps = event_segment['t']
    frame_stack = to_frame_stack_numpy(
        events=event_segment,
        sensor_size=sensor_size,
        n_time_bins=1,
        mode=events_interpolation,
        start_time=timestamps[0],
        end_time=timestamps[-1],
        weight=weight,
    )
    # Only one time bin was requested, so drop that leading axis.
    channel_first = frame_stack.squeeze(0)
    # NOTE(review): the return value is discarded, so this only has an effect
    # if cut_max_count clips its argument in place -- confirm.
    cut_max_count(channel_first, 255)

    # Move the leading (channel) axis to the end: CHW -> HWC.
    return np.moveaxis(channel_first, 0, -1)

In [None]:
# Accumulate the loaded event segment into one frame using the default settings.
event_frame = event_to_frame(event_segment)

In [None]:
def pre_process(event_frame):
    """
    Turn a channel-last event frame into a batched float32 network input.

    Steps: resize to 256x256, scale by 1/255, move channels first, prepend a
    batch axis of size one and wrap the result as a torch tensor.

    Args:
        event_frame (np.array): HWC frame, e.g. (260, 346, 2)

    Returns:
        torch.Tensor: float32 tensor of shape (1, C, 256, 256)
    """
    # HWC, e.g. (260, 346, 2) -> (256, 256, 2)
    resize = A.Compose([A.Resize(256, 256)])
    resized = resize(image=event_frame)["image"]

    # Scale by 1/255 and store as float32.
    scaled = (resized / 255.0).astype(np.float32)

    # (256, 256, C) -> (C, 256, 256) -> (1, C, 256, 256)
    channel_first = np.moveaxis(scaled, -1, 0)
    batched = np.expand_dims(channel_first, axis=0)

    # np.array -> torch.tensor
    return torch.from_numpy(batched)

In [None]:
# Build the batched network input and place it on the inference device.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# the following cell reads it.
input = pre_process(event_frame).to(device)
input.shape

In [None]:
# Inference only: disable autograd so no graph is recorded for the forward pass.
with torch.no_grad():
    pred = model(input)

In [None]:
def dict2cpu(cuda_dict):
    """
    Return a new dict with the same keys whose values are ``value.cpu()``.

    The input dict itself is not modified.
    """
    return {name: tensor.cpu() for name, tensor in cuda_dict.items()}

In [None]:
# pred = dict2cpu(pred)
# pred

In [None]:
class PredDecoder:
    """
    Decode the raw head outputs of the network into a fixed number of detections.

    ``pred`` is a dict holding four (b, c, h, w) maps:
        "hm":  per-class centre heatmap logits (sigmoid is applied here),
        "reg": 2-channel sub-cell offset added to the integer peak position,
        "ab":  2-channel map, copied into output columns 2 and 3,
        "ang": 1-channel map, copied into output column 4.

    ``decode`` returns a (b, K, 7) tensor whose last axis is
    [x, y, ab[0], ab[1], ang, score, class]; x and y are expressed in heatmap
    cells (peak index plus the predicted offset).
    """

    def __init__(self, pred):
        self.pred = pred
        # Out-of-place sigmoid: the in-place variant would overwrite pred["hm"],
        # so building a second decoder from the same dict (e.g. re-running the
        # cell) would squash the heatmap twice.
        self.hm = pred["hm"].sigmoid()
        self.ang = pred["ang"]
        self.ab = pred["ab"]
        self.reg = pred["reg"]

    def transpose_feat(self, feat):
        """Flatten the spatial axes and move channels last: (b, c, h, w) -> (b, h*w, c)."""
        # feat.shape: (b, c, h, w)
        b, c, h, w = feat.size()
        # feat.shape: (b, c, h, w) -> (b, c, h*w)
        feat = feat.view(b, c, h * w)
        # feat.shape: (b, c, h*w) -> (b, h*w, c)
        feat = feat.permute(0, 2, 1).contiguous()

        return feat

    def gather_feat(self, feat, ind, mask=None):
        """
        Pick rows of ``feat`` along dim 1 according to ``ind``.

        Args:
            feat (torch.Tensor): (b, n_src, c) source rows
            ind (torch.Tensor): (b, n) integer row indices into dim 1 of ``feat``
            mask (torch.Tensor, optional): (b, n) boolean selector; when given,
                only the selected rows are kept and the result is (n_kept, c)

        Returns:
            torch.Tensor: (b, n, c), or (n_kept, c) when ``mask`` is given
        """
        # feat.shape: (b, n_src, c)
        feat_b, hw, c = feat.size()
        # ind.shape: (b, n)
        ind_b, n = ind.size()
        assert feat_b == ind_b
        b = ind_b
        # ind.shape: (b, n) -> (b, n, 1) -> (b, n, c)
        ind = ind.unsqueeze(2).expand(b, n, c)
        # feat.shape: (b, n_src, c) -> (b, n, c)
        feat = feat.gather(1, ind)

        if mask is not None:
            mask = mask.unsqueeze(2).expand_as(feat)
            feat = feat[mask]
            feat = feat.view(-1, c)

        return feat

    def nms(self, heatmap, kernel=3):
        """Keep only local maxima: every cell that is not the max of its kernel window is zeroed."""
        pad = (kernel - 1) // 2
        hmax = torch.nn.functional.max_pool2d(
            heatmap, (kernel, kernel), stride=1, padding=pad
        )
        keep = (hmax == heatmap).float()
        heatmap = heatmap * keep

        return heatmap

    def topk(self, heatmap, K=100):
        """
        Select the K highest-scoring cells over all classes of each sample.

        Args:
            heatmap (torch.Tensor): (b, c, h, w) scores
            K (int): number of peaks to return; must not exceed h*w

        Returns:
            tuple: (scores, inds, clses, ys, xs), each of shape (b, K).
                ``inds`` are flat h*w positions, ``clses`` the channel index of
                each peak, ``ys``/``xs`` the integer cell coordinates as floats.
        """
        # heatmap.shape: (b, c, h, w)
        b, c, h, w = heatmap.size()
        # Stage 1: K best cells of every class. (b, c, h, w) -> (b, c, h*w) -> (b, c, K)
        heatmap = heatmap.view(b, c, -1)
        topk_scores, topk_inds = torch.topk(heatmap, K)
        topk_inds = topk_inds % (h * w)
        topk_ys = (topk_inds / w).int().float()
        topk_xs = (topk_inds % w).int().float()
        # Stage 2: K best among the c*K per-class candidates. This must rank the
        # stage-1 scores, not the full heatmap: the resulting indices address the
        # (b, c*K) candidate list that is gathered from below, and dividing them
        # by K recovers the class.
        topk_score, topk_ind = torch.topk(topk_scores.view(b, -1), K)
        topk_clses = (topk_ind / K).int()
        topk_inds = self.gather_feat(topk_inds.view(b, -1, 1), topk_ind).view(b, K)
        topk_ys = self.gather_feat(topk_ys.view(b, -1, 1), topk_ind).view(b, K)
        topk_xs = self.gather_feat(topk_xs.view(b, -1, 1), topk_ind).view(b, K)

        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def decode(self, K=100):
        """
        Turn the stored head outputs into K detections per sample.

        Args:
            K (int): number of detections to keep per sample

        Returns:
            torch.Tensor: (b, K, 7) with columns
                [x, y, ab[0], ab[1], ang, score, class], ordered by descending score.
        """
        b = self.hm.size(0)
        hm = self.nms(self.hm)
        # Forward K so the reshapes below match the number of selected peaks.
        scores, inds, clses, ys, xs = self.topk(hm, K=K)

        reg = self.transpose_feat(self.reg)
        reg = self.gather_feat(reg, inds)
        reg = reg.view(b, K, 2)

        # Refine the integer peak position with the predicted offset.
        xs = xs.view(b, K, 1) + reg[:, :, 0:1]
        ys = ys.view(b, K, 1) + reg[:, :, 1:2]

        ab = self.transpose_feat(self.ab)
        ab = self.gather_feat(ab, inds)
        ab = ab.view(b, K, 2)

        ang = self.transpose_feat(self.ang)
        ang = self.gather_feat(ang, inds)
        ang = ang.view(b, K, 1)

        clses = clses.view(b, K, 1).float()
        scores = scores.view(b, K, 1)
        bboxes = torch.cat([xs, ys, ab[..., 0:1], ab[..., 1:2], ang], dim=2)

        detections = torch.cat([bboxes, scores, clses], dim=2)

        return detections

In [None]:
# Decode the raw head outputs into detections using the decoder's default K.
det = PredDecoder(pred).decode()

In [None]:
# Quick sanity check of the decoded result: Python type, device and dtype.
type(det), det.device, det.dtype

In [None]:
# Dump the full detections tensor as plain text.
print(det)

In [None]:
import os
import torch

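# Set environment variable (uncomment the line below to make CUDA kernel launches synchronous while debugging)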
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [None]:
# det.detach().cpu().numpy()