In [None]:
import dataclasses
import pathlib

import cv2
import numpy as np
import PIL.Image
import torch
import tqdm
from rpg_e2vid.utils.inference_utils import events_to_voxel_grid
from rpg_e2vid.utils.loading_utils import load_model

import const
import utils

In [None]:
def to_displayable(img, converted: int | None = cv2.COLOR_BGR2RGB) -> PIL.Image.Image:
    """Turn an image array into a ``PIL.Image`` for notebook display.

    Args:
        img: Image array. When ``converted`` is given it is passed to
            ``cv2.cvtColor`` first.
        converted: OpenCV colour-conversion code applied before display, or
            ``None`` to skip the conversion. Defaults to BGR -> RGB.

    Returns:
        The image as an 8-bit ``PIL.Image.Image``.
    """
    if converted is not None:
        img = cv2.cvtColor(img, converted)
    # Clip BEFORE narrowing to uint8: casting first wraps out-of-range values
    # (e.g. 300 -> 44), after which clipping to [0, 255] can no longer help.
    return PIL.Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))


def showarray(img, converted: int | None = cv2.COLOR_BGR2RGB) -> None:
    """Show ``img`` inline in the notebook.

    The array is first turned into a PIL image by ``to_displayable`` (with the
    same ``converted`` colour-conversion code) and then handed to ``display``.
    """
    rendered = to_displayable(img, converted)
    display(rendered)

In [None]:
# Directory holding the recording's .npy dumps.
DATA_PATH = pathlib.Path("../data/raw/carla/").joinpath("fullsynced")

# Raw event array; its "t", "x", "y" and "p" fields are read further down.
events_orig = np.load(DATA_PATH.joinpath("all_events.npy"))
# Frames and their timestamps, judging by the file names.
bgr_frames = np.load(DATA_PATH.joinpath("bgr_frames.npy"))
bgr_timestamps = np.load(DATA_PATH.joinpath("bgr_timestamps.npy"))

In [None]:
# Pack the four event fields into one (t, x, y, p) row per event, as int64.
event_fields = [events_orig[name] for name in ("t", "x", "y", "p")]
events = np.stack(event_fields).T.astype(np.int64)

# Each event's timestamp divided by 1e6 and truncated to an integer
# (whole seconds if "t" is in microseconds — TODO confirm the unit).
event_timestamps = (events[:, 0] / 1e6).astype(np.int64)

In [None]:
# How many events share each distinct value of event_timestamps; np.unique
# returns the counts in ascending order of those values.
ts_counts = np.unique(event_timestamps, return_counts=True)[1]

In [None]:
# Load the pretrained E2VID network, move it to the configured device and put
# it in evaluation mode: this notebook only runs inference, so normalisation
# and dropout layers (if the architecture has any) must not use training
# behaviour. Chained so the cell does not echo the module repr.
model = load_model("../pretrained/E2VID_lightweight.pth.tar").to(const.DEVICE).eval()

  raw_model = torch.load(path_to_model, map_location=device)


Using TransposedConvLayer (fast, with checkerboard artefacts)


In [None]:
# Reconstruct one intensity frame per event window with `model`.
# NOTE(review): the 30, 30, 0 arguments are passed through unchanged; their
# meaning is defined by utils.EventWindowIterator — confirm there.
eit = utils.EventWindowIterator(events, ts_counts, 30, 30, 0)
prev = None  # second output of the model, fed back in on the next window
rec_frames = []
rec_ts = []
for window in tqdm.tqdm(eit, total=len(eit)):
    # The timestamp of the window's last event labels the reconstructed frame.
    rec_ts.append(window[-1, 0])
    # Hand the voxeliser a float64 copy rather than the int64 slice. Upstream
    # rpg_e2vid's events_to_voxel_grid writes normalised timestamps back into
    # its input (verify against the installed version); on an integer array
    # those would be truncated to whole bins and the temporal interpolation
    # lost. astype() also copies, so `events` itself is left untouched.
    vg = events_to_voxel_grid(window.astype(np.float64), 5, 640, 480)
    vg = torch.from_numpy(vg).unsqueeze(0).float().to(const.DEVICE)
    with torch.no_grad():
        pred, prev = model(vg, prev)
    # Scale to 8-bit, clipping first so a value slightly outside [0, 1] cannot
    # wrap around when narrowed to uint8.
    pred = np.clip(pred.squeeze().cpu().numpy() * 255, 0, 255).astype(np.uint8)
    rec_frames.append(pred)

rec_ts = np.array(rec_ts)
rec_frames = np.stack(rec_frames)

100%|██████████| 832/832 [00:55<00:00, 15.00it/s]


In [None]:
# Write the reconstructed single-channel frames to an MP4 at 30 fps.
# The frame size is taken from the data (assumes rec_frames is laid out as
# (N, height, width), which is what stacking the 2-D predictions yields).
frame_height, frame_width = rec_frames.shape[1:3]
out = cv2.VideoWriter(
    "out.mp4",
    cv2.VideoWriter_fourcc(*"mp4v"),
    30,
    (frame_width, frame_height),  # OpenCV takes (width, height)
    isColor=False,
)
# cv2.VideoWriter does not raise when the file/codec cannot be opened; without
# this check every write() below would silently do nothing.
if not out.isOpened():
    raise RuntimeError("cv2.VideoWriter could not open out.mp4 for writing")
try:
    for rec_frame in tqdm.tqdm(rec_frames):
        out.write(rec_frame)
finally:
    # Always finalise the container, even if writing a frame fails.
    out.release()

100%|██████████| 832/832 [00:00<00:00, 940.82it/s]


In [None]:
# Load the second set of pretrained weights, move the network to the
# configured device and switch it to evaluation mode, since it is only used
# for inference below. Chained so the cell does not echo the module repr.
model2 = (
    utils.load_model_2("../pretrained/better_e2vid_weights_v5.pth")
    .to(const.DEVICE)
    .eval()
)

Using skip: <function skip_sum at 0x000001F2BD89F400>
Using UpsampleConvLayer (slow, but no checkerboard artefacts)
Kernel size 5
Skip type sum
norm none


  model2.load_state_dict(torch.load(path))


In [None]:
# Reconstruct one intensity frame per event window with `model2`.
# NOTE(review): the 30, 30, 0 arguments are passed through unchanged; their
# meaning is defined by utils.EventWindowIterator — confirm there.
eit = utils.EventWindowIterator(events, ts_counts, 30, 30, 0)
rec_frames = []
rec_ts = []
for window in tqdm.tqdm(eit, total=len(eit)):
    # The timestamp of the window's last event labels the reconstructed frame.
    rec_ts.append(window[-1, 0])
    # Hand the voxeliser a float64 copy rather than the int64 slice. Upstream
    # rpg_e2vid's events_to_voxel_grid writes normalised timestamps back into
    # its input (verify against the installed version); on an integer array
    # those would be truncated to whole bins and the temporal interpolation
    # lost. astype() also copies, so `events` itself is left untouched.
    vg = events_to_voxel_grid(window.astype(np.float64), 5, 640, 480)
    vg = torch.from_numpy(vg).unsqueeze(0).float().to(const.DEVICE)
    with torch.no_grad():
        # model2 returns a mapping; the reconstruction is under "image".
        pred = model2(vg)["image"]
    # Scale to 8-bit, clipping first so a value slightly outside [0, 1] cannot
    # wrap around when narrowed to uint8.
    pred = np.clip(pred.squeeze().cpu().numpy() * 255, 0, 255).astype(np.uint8)
    rec_frames.append(pred)

rec_ts = np.array(rec_ts)
rec_frames = np.stack(rec_frames)

100%|██████████| 832/832 [01:01<00:00, 13.45it/s]


In [None]:
# Write the reconstructed single-channel frames to an MP4 at 30 fps.
# The frame size is taken from the data (assumes rec_frames is laid out as
# (N, height, width), which is what stacking the 2-D predictions yields).
frame_height, frame_width = rec_frames.shape[1:3]
out = cv2.VideoWriter(
    "out_pp.mp4",
    cv2.VideoWriter_fourcc(*"mp4v"),
    30,
    (frame_width, frame_height),  # OpenCV takes (width, height)
    isColor=False,
)
# cv2.VideoWriter does not raise when the file/codec cannot be opened; without
# this check every write() below would silently do nothing.
if not out.isOpened():
    raise RuntimeError("cv2.VideoWriter could not open out_pp.mp4 for writing")
try:
    for rec_frame in tqdm.tqdm(rec_frames):
        out.write(rec_frame)
finally:
    # Always finalise the container, even if writing a frame fails.
    out.release()

100%|██████████| 832/832 [00:01<00:00, 772.76it/s]
