In [None]:
import os
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
import imageio
from utils.event_utils import brightness_increment_image
import torchvision


def inner_double_integral(bii):
    """Return per-timestamp log-intensity offsets relative to the middle timestamp.

    ``bii`` is a stack of 2N brightness-increment images where ``bii[k]`` is
    the increment between timestamps k and k+1 of a window of 2N+1
    timestamps (this is how the module-level script builds it). The result
    has 2N+1 entries along axis 0:

        out[t] = -sum(bii[t:N])   for t < N
        out[N] = 0
        out[t] =  sum(bii[N:t])   for t > N

    Args:
        bii: array whose leading axis has even length 2N; the remaining axes
            are carried through unchanged.

    Returns:
        Array of shape ``(2N + 1,) + bii.shape[1:]``.
    """
    assert bii.shape[0] % 2 == 0, "bii must contain an even number of increment images"
    N = bii.shape[0] // 2

    # Left of the middle: suffix sums left[i] = sum(bii[i:N]), negated because
    # we go backwards in time. A reversed cumulative sum makes this O(N)
    # instead of re-summing every slice.
    left = -np.cumsum(bii[:N][::-1], axis=0)[::-1]
    # Right of the middle: prefix sums right[i] = sum(bii[N:N + 1 + i]).
    right = np.cumsum(bii[N:], axis=0)
    # The middle timestamp has zero offset from itself. It is built from the
    # input's trailing shape so that N == 0 (empty bii) still works.
    middle = np.zeros((1,) + bii.shape[1:], dtype=np.result_type(left, right))
    return np.concatenate([left, middle, right], axis=0)


def deblur_double_integral(blurry, bii, idx=0):
    """Reconstruct the sharp image at timestamp ``idx`` via the event double integral (EDI).

    The blurry frame is modelled as the mean, over the 2N+1 timestamps of the
    exposure window, of the middle (reference) sharp image scaled by ``exp``
    of each timestamp's log-intensity offset from ``inner_double_integral``.
    Dividing the blurry frame by that mean factor recovers the middle image.
    Any other timestamp is then obtained by re-applying its own offset.

    Args:
        blurry: blurry image; must broadcast against ``bii.shape[1:]``.
        bii: stack of 2N brightness-increment images (see
            ``inner_double_integral``).
        idx: timestamp to reconstruct, in ``[0, 2N]``; ``N`` is the middle one.

    Returns:
        The reconstructed sharp image.

    Raises:
        ValueError: if ``idx`` is outside ``[0, 2N]``.
    """
    N = bii.shape[0] // 2
    if not 0 <= idx <= 2 * N:
        raise ValueError(f"idx must be in [0, {2 * N}], got {idx}")

    offsets = inner_double_integral(bii)

    # Sharp image at the middle timestamp: blurry / mean(exp(offsets)).
    latent = (2 * N + 1) * blurry / np.exp(offsets).sum(axis=0)
    # offsets[idx] is the log-intensity change from the middle timestamp to
    # idx: -sum(bii[idx:N]) before it and sum(bii[N:idx]) after it (note: not
    # bii[N:idx + 1]). One expression therefore covers both sides and works
    # for any N.
    return latent * np.exp(offsets[idx])

# Root of the Ev-DeblurNeRF Blender data, relative to the working directory.
DATA_ROOT = "data/ev-deblurnerf_blender"
# Scene whose blurry images are deblurred and whose EDI results are written.
SCENE = "blurwine"
# NOTE(review): the events and exposure timestamps come from "blurfactory",
# while the blurry images and the outputs use "blurwine". Deblurring one
# scene with another scene's events looks like a leftover from a previous
# run. Confirm this, and if so set EVENT_SCENE = SCENE.
EVENT_SCENE = "blurfactory"
OUT_DIR = os.path.join("/data/sjlee/DiET-GS/upload_data/ev-deblurnerf_blender", SCENE, "images_edi")
# Each exposure window is sampled at this many evenly spaced timestamps,
# giving NUM_STAMPS - 1 brightness-increment images per blurry frame.
NUM_STAMPS = 9


def _load_pickle(path):
    """Return the object pickled in *path*.

    pickle can execute arbitrary code on load, so only use this on trusted
    local data such as the preprocessed event files below.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


events_dir = os.path.join(DATA_ROOT, EVENT_SCENE, "events")
start = _load_pickle(os.path.join(events_dir, "image_start.pickle"))
end = _load_pickle(os.path.join(events_dir, "image_end.pickle"))
ev = _load_pickle(os.path.join(events_dir, "events.pickle"))
id_to_coords = _load_pickle(os.path.join(events_dir, "id_to_coords.pickle"))

# NUM_STAMPS evenly spaced timestamps across every exposure window, flattened.
all_tms = torch.tensor(np.concatenate(
    [np.linspace(tms_start, tms_end, NUM_STAMPS) for tms_start, tms_end in zip(start, end)]
))
ev_tms = ev[:, 1]  # column 1 of the event array is used as the timestamp

# Number of blurry frames, derived from the data instead of hard-coding 29.
N = len(all_tms) // NUM_STAMPS

# For every timestamp: the first event at/after it (left) and the first
# event strictly after it (right). searchsorted requires ev_tms to be sorted.
idx_events_left = torch.searchsorted(ev_tms, all_tms).reshape(N, NUM_STAMPS)
idx_events_right = torch.searchsorted(ev_tms, all_tms, side="right").reshape(N, NUM_STAMPS)

# save_image does not create missing directories.
os.makedirs(OUT_DIR, exist_ok=True)

for j in range(N):
    image_path = os.path.join(DATA_ROOT, SCENE, "images", f"{j:02d}.png")
    # `with` closes the file handle. convert("RGB") drops an alpha channel or
    # expands grayscale, so the image broadcasts against the 3-channel
    # increment images built below.
    with Image.open(image_path) as img:
        blurry_image = np.array(img.convert("RGB")) / 255.

    bii_list = []
    for i in range(NUM_STAMPS - 1):
        # Events with timestamps in [t_i, t_{i+1}] for frame j.
        ev_ = ev[idx_events_left[j, i]:idx_events_right[j, i + 1]]
        x, y = id_to_coords[ev_[:, 0].long()].T.cpu().numpy()
        p = ev_[:, 2].cpu().numpy()

        # 600, 400 presumably give the image width/height; TODO confirm
        # against brightness_increment_image's signature.
        bii = brightness_increment_image(x, y, p, 600, 400, 0.2, 0.2, interpolate=True, threshold=True)
        # Replicate the single-channel increment across 3 color channels.
        bii_list.append(bii[:, :, None].repeat(3, axis=-1))
    bii = np.stack(bii_list, axis=0)

    for i in range(NUM_STAMPS):
        edi = torch.from_numpy(deblur_double_integral(blurry_image, bii, idx=i))
        # save_image expects channels-first (C, H, W).
        torchvision.utils.save_image(edi.permute(2, 0, 1), os.path.join(OUT_DIR, f"{j}_{i}.png"))
