In [1]:
## Created by Wentinn Liao

# CS180 Final Project: Eulerian Video Magnification

In [2]:
#@title Configure Jupyter Notebook
%load_ext autoreload
%autoreload 2

In [3]:
#@title Library Setup
import base64
import io
import math
import numpy as np
import os
import scipy as sc
import skimage as sk
import skimage.io as skio
import cv2
import torch
from typing import *
from matplotlib import pyplot as plt
from IPython.display import HTML
from IPython import display as ipythondisplay

In [4]:
#@title Utilities
def read_image(imname: str) -> torch.Tensor:
    return torch.IntTensor(skio.imread(imname))
    # return sk.img_as_float(np.array(img.convert('RGBA')))

def im_rescale(im):
    lo = torch.min(im)
    hi = torch.max(im)
    return (im - lo) / (hi - lo)

def im_saturate(im):
    return torch.stack([im_rescale(im[:, :, c]) for c in range(im.shape[2])], dim=2)

def multiply_outer(v: np.ndarray, arr: np.ndarray, axis=None):
    if axis is None:
        axis = v.ndim
    arr_ = arr.transpose(*range(axis, arr.ndim), *range(axis))
    return (arr_ * v).transpose(*range(arr.ndim - axis, arr.ndim), *range(arr.ndim - axis))

def plot_cycle(ax, points: np.ndarray, **kwargs):
    cycled_points = np.vstack([points, points[:1]])
    ax.plot(*cycled_points.T, **kwargs)

def show_video(video_name: str):
    if os.path.exists(video_name):
        video = io.open(video_name, 'r+b').read()
        encoded = base64.b64encode(video)
        ipythondisplay.display(HTML(data='''<video alt="test" autoplay
            loop controls style="height: 400px;">
            <source src="data:video/mp4;base64,{0}" type="video/mp4" />
        </video>'''.format(encoded.decode('ascii'))))
    else:
        print("Could not find video")

def color(z: float, scale: float=120.) -> np.ndarray:
    k = 2 * np.pi * z / scale
    return (1 + np.asarray([np.sin(k), np.sin(k + 2 * np.pi / 3), np.sin(k + 4 * np.pi / 3)], dtype=float)) / 2

In [5]:
import time

class Timer(object):
    indent = 0
    p = False

    def __init__(self):
        self.start_t = 0

    def start(self, name):
        if Timer.p:
            print('\t' * Timer.indent + name + ' {')
            Timer.indent += 1
        self.start_t = time.time_ns()

    def stop(self):
        if Timer.p:
            Timer.indent -= 1
            print('\t' * Timer.indent + '} ' + f'{(time.time_ns() - self.start_t) * 1e-6}ms,')
        # return time.time_ns() - self.start_t

# Part 1. Laplacian Pyramid

In [8]:
"""
baby2.mp4: 900 frames
face.mp4: 301 frames
"""
def read_video_cv2(video_name: str, n_frames: int) -> np.ndarray:
    cap = cv2.VideoCapture(video_name)
    result, i = [], 0
    while cap.isOpened() and i < n_frames:
        result.append(np.array(cap.read()[1]))
        i += 1
    return np.array(result) / 255.

baby2 = read_video_cv2('data/baby2.mp4', 900)
face = read_video_cv2('data/face.mp4', 301)

In [11]:
sk.transform.rescale(baby2, 0.5, channel_axis=(0, -1)).shape

ValueError: only a single channel axis is currently supported

In [9]:
def gaussian_and_laplacian_stack(vid: np.ndarray, depth=8, sigma=1):
    gaussian_stack, laplacian_stack = [vid], []
    for _ in range(depth + 1):
        k = int(math.ceil(3 * sigma))
        G = ((G := cv2.getGaussianKernel(2 * k + 1, sigma)) @ G.T)[None, :, :, None]

        gaussian_stack.append(sc.signal.convolve(gaussian_stack[-1], G, mode='same'))
        laplacian_stack.append(gaussian_stack[-2] - gaussian_stack[-1])

        sigma *= 2
    gaussian_stack.pop()
    return gaussian_stack, laplacian_stack

In [10]:
g, l = gaussian_and_laplacian_stack(baby2)

(900, 352, 640, 3) (1, 7, 7, 1)
(900, 352, 640, 3) (1, 13, 13, 1)


KeyboardInterrupt: 

# Part 2. Temporal Filtering

In [None]:
def global_tone_map(im: torch.Tensor):
    return im / (1 + im)

In [None]:
global_im = global_tone_map(E)

plt.imshow(global_im)
plt.show()

In [None]:
def local_tone_map(im: torch.Tensor, dR: float=5.):
    intensity = torch.mean(im, dim=-1, keepdim=True)
    chrominance = im / intensity
    L = torch.log(intensity)
    B = torch.tensor(cv2.bilateralFilter(L.numpy(), 60, 240, 120))[:, :, None]
    B_ = -dR * im_rescale(-B)
    O = torch.exp((L - B) + B_)
    return im_saturate((chrominance * O) ** 0.5)

def adaptive_histogram(im: torch.Tensor, clipLimit: float=2., tileGridSize: Sequence[int]=(5, 5)):
    im = sk.img_as_ubyte(im)
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    return torch.tensor(sk.img_as_float(torch.stack([torch.tensor(clahe.apply(im[:, :, c])) for c in range(3)], dim=-1)))

In [None]:
def hdr(ims: torch.Tensor, dts: torch.Tensor) -> Dict[str, torch.Tensor]:
    g, E = radiance_map(ims, dts)
    global_im = global_tone_map(E)
    return {
        'radiance_map': global_im,
        'g': g,
        'im': adaptive_histogram(global_im)
    }

In [None]:
skio.imshow(sk.img_as_ubyte(hdr(ims, dts)['im']))
skio.show()

In [None]:
def save_result(ims: torch.Tensor, dts: torch.Tensor, name: str):
    hdr_result = hdr(ims, dts)

    stacked_ims = torch.cat([
        torch.cat([ims[0], ims[1]], dim=1),
        torch.cat([ims[2], torch.IntTensor(sk.img_as_ubyte(hdr_result['im']))], dim=1)
    ], dim=0)

    skio.imsave(f'../images/{name}_radiance_map.png', sk.img_as_ubyte(hdr_result['radiance_map']))

    plt.plot(hdr_result['g'][0], color='red')
    plt.plot(hdr_result['g'][1], color='blue')
    plt.plot(hdr_result['g'][2], color='green')
    plt.title('Exposure Map')
    plt.axis('on')
    plt.savefig(f'../images/{name}_exposure_map.png', bbox_inches='tight', pad_inches=0)
    plt.show()

    skio.imsave(f'../images/{name}_hdr.png', sk.img_as_ubyte(stacked_ims))
    plt.imshow(stacked_ims.numpy())
    plt.axis('off')
    plt.show()

In [None]:
save_result(ims, dts, 'memorial')

In [None]:
plt.rcParams['figure.figsize'] = (20.0, 20.0)

names = {
    'data/Train0000_0210': [3, 51, 112, 193],
    'data/Train0211_0420': [291, 343],
    'data/Train0421_0625': [466, 575]
}
for dir_name, nums in names.items():
    for i in nums:
        s = str(i).zfill(4)
        ims = torch.stack([
            read_image(f'{dir_name}/{s}_short.png'),
            read_image(f'{dir_name}/{s}_medium.png'),
            read_image(f'{dir_name}/{s}_long.png')
        ])
        print(s)
        dts = torch.pow(2, torch.FloatTensor(np.load(f'{dir_name}/{s}_exposures.npy')))
        save_result(ims, dts, s)

In [None]:
for dir_name, nums in names.items():
    str_nums = {str(i).zfill(4) for i in nums}
    for fname in os.listdir(dir_name):
        if fname[:4] not in str_nums:
            os.remove(f'{dir_name}/{fname}')