<a href="https://colab.research.google.com/github/MohamedAliRashad/Crystal/blob/master/Crystal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Welcome to Crystal**

[GitHub](https://github.com/MohamedAliRashad/Crystal) - [YouTube](https://www.youtube.com/watch?v=96E6hB2DE5w)

This notebook has two goals:


1.   Provide a working (fully tested) environment for trying out **Crystal**
2.   Explain the different stages of our pipeline for anyone interested in adding to it or trying new things.



## Try it out yourself

For those who are just interested in using **Crystal**

### Set Up Everything Needed

**Note:** The working directory will be Crystal.

In [0]:
!git clone https://github.com/MohamedAliRashad/Crystal.git
%cd Crystal
!pip3 install -r requirements.txt

In [0]:
#@title Download the video from YouTube
video_url = "https://www.youtube.com/watch?v=W2X_p4DDK3s" #@param {type:"string"}

!youtube-dl -f 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/mp4' {video_url} -o video

### Just, Run ^ ^

**Note:** The output is written to ```/content/``` under the name ```video.mp4```

In [0]:
from main import main

# main requires only 2 arguments:
#       Video Path
#       Output Path
# And that's it

main("video.mp4", "/content/")

## Learn how it ticks

For those who want to learn how the pipeline works or improve on it.

### First Step: Extract frames

**Note:** The audio can be extracted too, by uncommenting the ```ffmpeg``` line at the end of the cell.

In [0]:
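import os
import os.path as osp
import imageio
from tqdm import tqdm

video_file = "video.mp4"
rate = None  # frames to keep per second of video; None keeps every frame

# Frames are written to a folder named after the video (here: ./video)
out_dir = osp.splitext(osp.basename(video_file))[0]
os.makedirs(out_dir, exist_ok=True)  # do not fail when the cell is re-run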

reader = imageio.get_reader(video_file)
meta_data = reader.get_meta_data()
fps = meta_data['fps']
n_frames = meta_data['nframes']

j = 0
for i, img in tqdm(enumerate(reader), total=n_frames):
    if rate is None or i % int(round(fps / rate)) == 0:
        imageio.imsave(osp.join(out_dir, '%08d.jpg' % j), img)
        j = j + 1

# !ffmpeg -i video.mp4 -vn -acodec copy output-audio.aac

# !rm {video_file}
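
The sampling rule in the loop above keeps every `round(fps / rate)`-th frame when `rate` is set, and every frame when it is `None`. A minimal sketch of that index arithmetic, with no video I/O involved; `kept_frame_indices` is a hypothetical helper name and not part of Crystal:

```python
def kept_frame_indices(n_frames, fps, rate=None):
    """Return the indices of the frames the extraction loop would save."""
    if rate is None:
        return list(range(n_frames))
    # The max(1, ...) guard is an addition of this sketch: it avoids a
    # division by zero when rate is much larger than fps.
    step = max(1, int(round(fps / rate)))
    return [i for i in range(n_frames) if i % step == 0]


# A 30 fps clip sampled at 10 frames per second keeps every third frame.
print(kept_frame_indices(10, fps=30, rate=10))
```

Leaving `rate` at `None`, as the cell does, keeps the full frame sequence, which is what the interpolation step expects.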

### Second Step: Run DAIN

**DAIN** or ```Depth-Aware Video Frame Interpolation``` is the first module in our pipeline. It smooths out the video by increasing its frame rate.

It combines a depth estimator (an hourglass network) with an optical flow estimator (PWC-Net). The depth information helps the model handle occlusions when estimating the motion in the video, so the intermediate frames can be interpolated in a smarter way than by simply blending neighbouring frames.

[Project](https://sites.google.com/view/wenbobao/dain) **|** [Paper](http://arxiv.org/abs/1904.00830)


(**default:** doubling the FPS)
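
Doubling the frame rate means inserting one interpolated frame between every pair of consecutive frames, so N input frames become 2N - 1 output frames. The sketch below only illustrates that numbering, which is also the scheme the cell below uses for its output file names. No interpolation is performed, and `doubled_frame_plan` is a hypothetical helper:

```python
def doubled_frame_plan(n_frames):
    """List (output_index, source) pairs for a 2x frame-rate increase.

    source is ("original", i) for input frame i, or
    ("interpolated", i, i + 1) for the frame synthesized between i and i + 1.
    """
    plan = []
    for i in range(n_frames):
        plan.append((2 * i, ("original", i)))
        if i + 1 < n_frames:
            plan.append((2 * i + 1, ("interpolated", i, i + 1)))
    return plan


for out_idx, source in doubled_frame_plan(3):
    print(str(out_idx).zfill(6) + ".jpg", source)
```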

In [0]:
import cv2
from pathlib import Path
from imageio import imread, imsave
import torch
import numpy as np

# Get the DAIN Module ready
def load_DAIN():
    # Let the magic happen
    from DAIN.DAIN import DAIN
    module = DAIN()
    
    # load the weights online
    from torch.hub import load_state_dict_from_url
    state_dict = load_state_dict_from_url("http://vllab1.ucmerced.edu/~wenbobao/DAIN/best.pth")
    module.load_state_dict(state_dict)

    return module

# Forward the video frames through the model
# Note: inframes and outframes must be Path objects
def infer_DAIN(model, meta_data, inframes, outframes):

    model.cuda() # use the GPU

    frames = sorted(inframes.glob("*.jpg"))

    # Scale the frames down while maintaining the aspect ratio so the GPU does not run out of memory
    # scale_percent is the percentage of the original size to keep (100 leaves the size unchanged)
    scale_percent = 100
    width = int(meta_data["size"][0] * scale_percent / 100)
    height = int(meta_data["size"][1] * scale_percent / 100)
    dim = (width, height)
    model.eval()

    j = 0
    for i in tqdm(range(len(frames) - 1)):

        # Read each frame once and resize it to the working resolution
        # (the original re-read the frames afterwards, which discarded the resize)
        image1 = cv2.resize(imread(frames[i]), dim, interpolation=cv2.INTER_AREA)
        image2 = cv2.resize(imread(frames[i + 1]), dim, interpolation=cv2.INTER_AREA)
        
        X0 = torch.from_numpy(np.transpose(image1, (2, 0, 1)).astype("float32") / 255.0).type(torch.cuda.FloatTensor)
        X1 = torch.from_numpy(np.transpose(image2, (2, 0, 1)).astype("float32") / 255.0).type(torch.cuda.FloatTensor)
        y_ = torch.FloatTensor()

        intWidth = X0.size(2)
        intHeight = X0.size(1)
        channel = X0.size(0)

        if intWidth != ((intWidth >> 7) << 7):
            intWidth_pad = ((intWidth >> 7) + 1) << 7  # more than necessary
            intPaddingLeft = int((intWidth_pad - intWidth) / 2)
            intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
        else:
            intWidth_pad = intWidth
            intPaddingLeft = 32
            intPaddingRight = 32

        if intHeight != ((intHeight >> 7) << 7):
            intHeight_pad = ((intHeight >> 7) + 1) << 7  # more than necessary
            intPaddingTop = int((intHeight_pad - intHeight) / 2)
            intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
        else:
            intHeight_pad = intHeight
            intPaddingTop = 32
            intPaddingBottom = 32

        pader = torch.nn.ReplicationPad2d(
            [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom]
        )

        torch.set_grad_enabled(False)
        X0 = torch.unsqueeze(X0, 0)
        X1 = torch.unsqueeze(X1, 0)

        X0 = pader(X0).cuda()
        X1 = pader(X1).cuda()

        y_s, offset, filter = model(torch.stack((X0, X1), dim=0))

        y_ = y_s[1]

        X0 = X0.data.cpu().numpy()
        y_ = y_.data.cpu().numpy()
        offset = [offset_i.data.cpu().numpy() for offset_i in offset]
        filter = [filter_i.data.cpu().numpy() for filter_i in filter] if filter[0] is not None else None
        X1 = X1.data.cpu().numpy()

        X0 = np.transpose(
            255.0
            * X0.clip(0, 1.0)[
                0,
                :,
                intPaddingTop : intPaddingTop + intHeight,
                intPaddingLeft : intPaddingLeft + intWidth,
            ],
            (1, 2, 0),
        )
        y_ = np.transpose(
            255.0
            * y_.clip(0, 1.0)[
                0,
                :,
                intPaddingTop : intPaddingTop + intHeight,
                intPaddingLeft : intPaddingLeft + intWidth,
            ],
            (1, 2, 0),
        )
        offset = [
            np.transpose(
                offset_i[
                    0,
                    :,
                    intPaddingTop : intPaddingTop + intHeight,
                    intPaddingLeft : intPaddingLeft + intWidth,
                ],
                (1, 2, 0),
            )
            for offset_i in offset
        ]
        filter = (
            [
                np.transpose(
                    filter_i[
                        0,
                        :,
                        intPaddingTop : intPaddingTop + intHeight,
                        intPaddingLeft : intPaddingLeft + intWidth,
                    ],
                    (1, 2, 0),
                )
                for filter_i in filter
            ]
            if filter is not None
            else None
        )
        X1 = np.transpose(
            255.0
            * X1.clip(0, 1.0)[
                0,
                :,
                intPaddingTop : intPaddingTop + intHeight,
                intPaddingLeft : intPaddingLeft + intWidth,
            ],
            (1, 2, 0),
        )

        imsave(os.path.join(str(outframes), str(j).zfill(6) + ".jpg"), cv2.resize(image1, meta_data["size"], interpolation=cv2.INTER_AREA))
        imsave(os.path.join(str(outframes), str(j+1).zfill(6) + ".jpg"), cv2.resize(np.round(y_).astype(np.uint8), meta_data["size"], interpolation=cv2.INTER_AREA))
        j = j + 2
        
    imsave(os.path.join(str(outframes), str(j).zfill(6) + ".jpg"), cv2.resize(image2, meta_data["size"], interpolation=cv2.INTER_AREA))
    meta_data["fps"] = meta_data["fps"]*2

    return meta_data


model = load_DAIN()
inframes = Path("./video/")
outframes = Path("./tmp/")
outframes.mkdir(parents=True, exist_ok=True)
meta_data = infer_DAIN(model, meta_data, inframes, outframes)
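
The padding arithmetic inside `infer_DAIN` rounds each spatial dimension up to the next multiple of 128 (`(x >> 7) << 7` clears the low seven bits) and splits the extra pixels between the two sides. A dimension that is already a multiple of 128 still receives a fixed 32 pixels per side. A standalone sketch of that calculation, where `dain_padding` is a hypothetical name:

```python
def dain_padding(size):
    """Return (pad_before, pad_after) for one dimension, as in infer_DAIN."""
    if size != ((size >> 7) << 7):
        padded = ((size >> 7) + 1) << 7  # next multiple of 128
        before = (padded - size) // 2
        after = padded - size - before
    else:
        before = after = 32  # already aligned: fixed border on each side
    return before, after


print(dain_padding(720))   # (24, 24)
print(dain_padding(1280))  # (32, 32)
```

The padding is cropped away again after the forward pass, which is what the `intPaddingTop : intPaddingTop + intHeight` slices in the cell above do.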

### Third Step: Run EDVR


**EDVR** or ```Video Restoration with Enhanced Deformable Convolutional Networks``` is the second module in our pipeline. It enhances the video by increasing its resolution and by sharpening and deblurring the frames.

**EDVR** won all four tracks of the NTIRE 2019 video restoration and enhancement challenges with its **Pyramid, Cascading and Deformable (PCD) alignment module** and **Temporal and Spatial Attention (TSA) fusion module**.

[Project](https://xinntao.github.io/projects/EDVR) **|** [Paper](https://arxiv.org/abs/1905.02716)

(**default:** Two Stage Enhancement)
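
The checkpoint is chosen by the pair (`data_mode`, `stage`). The selection that the function below spells out with `if`/`elif` branches can also be read as a lookup table. This sketch uses the checkpoint names from the `Weights` dictionary; `MODEL_NAMES` and `pick_model_name` are hypothetical names, and unlike the function below the sketch rejects any stage other than 1 or 2:

```python
MODEL_NAMES = {
    ("Vid4", 1): "EDVR_Vimeo90K_SR_L",
    ("sharp_bicubic", 1): "EDVR_REDS_SR_L",
    ("sharp_bicubic", 2): "EDVR_REDS_SR_Stage2",
    ("blur_bicubic", 1): "EDVR_REDS_SRblur_L",
    ("blur_bicubic", 2): "EDVR_REDS_SRblur_Stage2",
    ("blur", 1): "EDVR_REDS_deblur_L",
    ("blur", 2): "EDVR_REDS_deblur_Stage2",
    ("blur_comp", 1): "EDVR_REDS_deblurcomp_L",
    ("blur_comp", 2): "EDVR_REDS_deblurcomp_Stage2",
}


def pick_model_name(data_mode, stage):
    """Return the checkpoint name for a (data_mode, stage) pair."""
    key = (data_mode, stage)
    if key not in MODEL_NAMES:
        raise ValueError(
            "Unsupported combination: {!r}, stage {}".format(data_mode, stage)
        )
    return MODEL_NAMES[key]


print(pick_model_name("sharp_bicubic", 1))
```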

In [0]:
import glob

import cv2
import torch
from torch.hub import load_state_dict_from_url
from tqdm import tqdm
import os.path as osp
from pathlib import Path

from SR_EDVR.EDVR_arch import EDVR
from SR_EDVR.utils.data_utils import index_generation, read_img_seq
from SR_EDVR.utils.util import flipx4_forward, mkdirs, single_forward, tensor2img, preProcess

Weights = {
    "EDVR_REDS_SR_L": "https://drive.google.com/uc?export=download&id=1PYULZmtpsmY4Wx8M9f4owdLIwcwQFEmi",
    "EDVR_REDS_deblur_L": "https://drive.google.com/uc?export=download&id=1ZCl0aU8isEnUCsUYv9rIZZQrGo7vBFUH",
    "EDVR_REDS_deblurcomp_L": "https://drive.google.com/uc?export=download&id=1SGVehpZt4WL_X8Jh6blyqmHpc8DdImgv",
    "EDVR_REDS_SRblur_L": "https://drive.google.com/uc?export=download&id=18ev7Zx_10-C8-0tAVAe_BpYeLHpr_ChE",
    "EDVR_Vimeo90K_SR_L": "https://drive.google.com/uc?export=download&id=1I7x87ee3E1DoFVgMxX09nfIb2tdUdE3x",
    "EDVR_REDS_SR_Stage2": "https://drive.google.com/uc?export=download&id=1kfArevFT8hzbUT2QWXFmUl983LTebQGP",
    "EDVR_REDS_deblur_Stage2": "https://drive.google.com/uc?export=download&id=1Y1y6v40dL74Kgf5fxbGd0QC010LFCBYz",
    "EDVR_REDS_deblurcomp_Stage2": "https://drive.google.com/uc?export=download&id=1G466gQ1rRl8MUKSEbtaR0U5xgIWdsG66",
    "EDVR_REDS_SRblur_Stage2": "https://drive.google.com/uc?export=download&id=13c-VxMdf8h7MGX-_y4xamxo1hhOMYzsH",
}


def SuperResolution(inframes, outframes, stage, data_mode, use_gpu=True):
	"""
	Perform a super-resolution step on a folder of frames.

	Args
	----
		inframes(str, Path): folder with the frames to enhance
		outframes(str, Path): the output directory
		stage(int): the stage to use (1 or 2)
		data_mode(str): the kind of processing wanted
			Vid4: SR
			REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
					blur (deblur-clean), blur_comp (deblur-compression).
		use_gpu(bool): run on the GPU when one is available
	"""

	flip_test = False
	inframes = Path(inframes)
	outframes = str(outframes)
	device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"

	################## Model ##################
	model_name = None
	if data_mode == "Vid4":
			if stage == 1:
					model_name = "EDVR_Vimeo90K_SR_L"
			else:
					raise ValueError("Vid4 does not support stage 2.")
	elif data_mode == "sharp_bicubic":
			if stage == 1:
					model_name = "EDVR_REDS_SR_L"
			else:
					model_name = "EDVR_REDS_SR_Stage2"
	elif data_mode == "blur_bicubic":
			if stage == 1:
					model_name = "EDVR_REDS_SRblur_L"
			else:
					model_name = "EDVR_REDS_SRblur_Stage2"
	elif data_mode == "blur":
			if stage == 1:
					model_name = "EDVR_REDS_deblur_L"
			else:
					model_name = "EDVR_REDS_deblur_Stage2"
	elif data_mode == "blur_comp":
			if stage == 1:
					model_name = "EDVR_REDS_deblurcomp_L"
			else:
					model_name = "EDVR_REDS_deblurcomp_Stage2"
	else:
			raise NotImplementedError

	print("Model Used: ", model_name)

	if data_mode == "Vid4":
			N_in = 7  # use N_in images to restore one HR image
	else:
			N_in = 5

	predeblur, HR_in = False, False
	back_RBs = 40
	if data_mode == "blur_bicubic":
			predeblur = True
	if data_mode == "blur" or data_mode == "blur_comp":
			predeblur, HR_in = True, True
	if stage == 2:
			HR_in = True
			back_RBs = 20

	# Initialize the model
	model = EDVR(128, N_in, 8, 5, back_RBs, predeblur=predeblur, HR_in=HR_in)

	#### evaluation
	crop_border = 0
	border_frame = N_in // 2  # border frames when evaluate
	# temporal padding mode
	if data_mode in ("Vid4", "sharp_bicubic"):
			padding = "new_info"
	else:
			padding = "replicate"
	save_imgs = True

	#### set up the models
	state_dict = load_state_dict_from_url(Weights[model_name], model_dir=model_name)
	model.load_state_dict(state_dict, strict=True)
	model.eval()
	model = model.to(device)

	img_path_l = sorted(inframes.glob("*"))

	# preprocess images (needed for blurred models)
	if predeblur:
			preProcess(img_path_l, 16)
	else:
			preProcess(img_path_l, 4)

	imgs_LQ = read_img_seq(inframes)
	max_idx = len(img_path_l)

	# process each image
	for img_idx, img_path in enumerate(tqdm(img_path_l)):
			img_name = osp.splitext(osp.basename(img_path))[0]
			select_idx = index_generation(img_idx, max_idx, N_in, padding=padding)
			imgs_in = (
					imgs_LQ.index_select(0, torch.LongTensor(select_idx))
					.unsqueeze(0)
					.to(device)
			)

			if flip_test:
					output = flipx4_forward(model, imgs_in)
			else:
					output = single_forward(model, imgs_in)
			output = tensor2img(output.squeeze(0))

			if save_imgs:
					cv2.imwrite(osp.join(outframes, "{}.jpg".format(img_name)), output)


if __name__ == "__main__":
	temp = Path("./tmp/")  # frames produced by the DAIN step
	temp2 = Path("./tmp2/")  # enhanced frames are written here
	temp2.mkdir(parents=True, exist_ok=True)
	SuperResolution(temp, temp2, 1, "sharp_bicubic")
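
Each output frame is restored from a window of `N_in` neighbouring frames centred on it, so frames near the start and end of the sequence need temporal padding. The sketch below shows the idea for the `"replicate"` mode only, by clamping out-of-range indices to the first or last frame. It is a simplified illustration and not the code of `index_generation`; `replicate_window` is a hypothetical name:

```python
def replicate_window(center, n_total, n_in):
    """Indices of the n_in frames around center, clamped to [0, n_total - 1]."""
    half = n_in // 2
    return [
        min(max(i, 0), n_total - 1)
        for i in range(center - half, center + half + 1)
    ]


print(replicate_window(0, 10, 5))  # [0, 0, 0, 1, 2]
print(replicate_window(9, 10, 5))  # [7, 8, 9, 9, 9]
```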


### Fourth Step: Rebuild the Video

**Note:** The video will have no audio (we are still working on it)

In [0]:
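import ffmpeg  # assumed to be the ffmpeg-python package
from pathlib import Path


def build_video(frames_dir, save_dir, meta_data, audio_path=None):
    """
    Construct a video from frames
    
    Args
    ----
        frames_dir(str, Path): folder path to frames
        save_dir(str, Path): folder where the video is written
        meta_data(dict): needs "fps" and "name" (the output file name)
        audio_path(str, Path): currently unused, the video is built without audio

    """
    save_dir = Path(save_dir)
    out_path = str(save_dir / meta_data["name"])

    # The image2 demuxer expects a file name pattern, not a folder; the frames
    # written by the previous steps are named 000000.jpg, 000001.jpg, ...
    frames_pattern = str(Path(frames_dir) / "%06d.jpg")

    # Use ffmpeg to reconstruct the video
    ffmpeg.input(
        frames_pattern, format="image2", vcodec="mjpeg", framerate=meta_data["fps"]
    ).output(out_path, crf=17, vcodec="libx264").run(capture_stdout=True)

# Make sure the output file name is set (assumed here to be video.mp4,
# the output name mentioned earlier in this notebook)
meta_data.setdefault("name", "video.mp4")
build_video(temp2, "/content/", meta_data)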
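
One possible way to restore the sound is to keep the audio track extracted from the source video (the commented-out `ffmpeg` line in the first step) and mux it into the rebuilt video without re-encoding. This is only a sketch under that assumption and not part of Crystal; `mux_audio_command` and the file names are hypothetical. It builds the `ffmpeg` argument list and leaves running it to the caller:

```python
def mux_audio_command(video_path, audio_path, out_path):
    """Build an ffmpeg command that copies video and audio into one file."""
    return [
        "ffmpeg",
        "-i", str(video_path),   # input 0: the rebuilt, silent video
        "-i", str(audio_path),   # input 1: the extracted audio track
        "-map", "0:v:0",         # take the video stream from input 0
        "-map", "1:a:0",         # take the audio stream from input 1
        "-c", "copy",            # copy both streams, no re-encoding
        "-shortest",             # stop at the end of the shorter stream
        str(out_path),
    ]


cmd = mux_audio_command("/content/video.mp4", "output-audio.aac", "/content/video_audio.mp4")
print(" ".join(cmd))
```

The command could then be run with `subprocess.run(cmd, check=True)`. Because the interpolation step doubles both the number of frames and the frame rate, the duration of the video stays roughly the same, so the original audio should remain in sync.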
