In [None]:
""" Imports """

from PIL import Image
import skvideo.io
import time
import os

import torch

from utils import *

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Style name: selects the models/<stage>_subnet_<MODEL>.pt checkpoints and is
# embedded in the generated_videos/*_<MODEL>.mp4 output file names.
MODEL: str = "still_life"

In [None]:
""" Select correct device """

# Run on the default CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
""" Load trained model """

def _load_subnet(stage):
    """Load the pickled `<stage>_subnet` for the current MODEL, in eval mode, on `device`.

    Checkpoints are deserialized onto the CPU first (map_location='cpu') and only
    then moved to `device`, so GPU-saved files also load on CPU-only machines.
    """
    # torch.load unpickles an entire module object here (the result has .eval()),
    # so only load checkpoints from a trusted source.
    # NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
    # which rejects whole pickled modules; if loading fails, these trusted files
    # may need weights_only=False.
    checkpoint_path = 'models/{}_subnet_{}.pt'.format(stage, MODEL)
    return torch.load(checkpoint_path, map_location='cpu').eval().to(device)

style_subnet = _load_subnet('style')
enhance_subnet = _load_subnet('enhance')
refine_subnet = _load_subnet('refine')

In [None]:
""" Transform video """

# Per-frame input pipeline: uint8 image array -> tensor -> normalized tensor.
# `transforms` and `tensor_normalizer` presumably come from `from utils import *`.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    tensor_normalizer()])

frames_256, frames_512, frames_1024, frames_orig = [], [], [], []
videogen = skvideo.io.vread("videos/test_vid.mp4")
n_frames = len(videogen)  # loop-invariant; computed once instead of per frame
start = time.time()
for count, frame in enumerate(videogen, start=1):
    frames_orig.append(Image.fromarray(frame))
    with torch.no_grad():
        # Hierarchical cascade: each stage consumes the previous stage's output.
        # The suffixes name the intended output resolution (not verified here).
        y_frame_256, _ = style_subnet(preprocess(frame).unsqueeze(0).to(device))
        y_frame_512, _ = enhance_subnet(y_frame_256)
        # Final stage is the refine subnet. Re-running enhance_subnet here would
        # leave refine_subnet unused and not produce the refined output.
        y_frame_1024, _ = refine_subnet(y_frame_512)

    frames_256.append(recover_frame(y_frame_256))
    frames_512.append(recover_frame(y_frame_512))
    frames_1024.append(recover_frame(y_frame_1024))
    elapsed = time.time() - start
    print("Frame {} of {} processed. Total time: {:.2f}s, time per frame: {:.2f}s".format(
            count, n_frames, elapsed, elapsed / count))

In [None]:
""" Save videos """

def _write_video(path, frames):
    """Encode `frames` into a video file at `path` with the yuv420p pixel format.

    Creates the parent directory if it does not exist yet. The writer is always
    closed, even if encoding a frame fails, so the ffmpeg subprocess is not leaked.

    Args:
        path: output file path (e.g. "generated_videos/foo.mp4").
        frames: iterable of frames accepted by FFmpegWriter.writeFrame.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    writer = skvideo.io.FFmpegWriter(path, outputdict={"-pix_fmt":"yuv420p"})
    try:
        for frame in frames:
            writer.writeFrame(frame)
    finally:
        writer.close()

# final (highest-resolution) stylized output
_write_video("generated_videos/generated_vid_" + MODEL + ".mp4", frames_1024)

# original (unstylized) input video, for comparison
_write_video("generated_videos/original_vid_" + MODEL + ".mp4", frames_orig)