In [1]:
import os
# Expose only physical GPU 2 to this kernel; set before `import torch` so CUDA sees it.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

In [2]:
import cv2
import skvideo.io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm_notebook
from torch.autograd import Variable
from PIL import Image
from fast_neural_style.transformer_net import TransformerNet
from fast_neural_style.utils import recover_image, tensor_normalizer

%matplotlib inline

# Preprocess Pipeline

In [3]:
# Shorter side of each cropped frame is resized to this many pixels before stylization.
STYLE_INPUT_SIZE = 1024

preprocess = transforms.Compose(
    [
        transforms.Resize(STYLE_INPUT_SIZE),  # shorter edge -> STYLE_INPUT_SIZE, aspect kept
        transforms.ToTensor(),                # PIL image -> float CHW tensor
        tensor_normalizer(),                  # normalization from fast_neural_style.utils
    ]
)

# Set Up the Model Architecture

In [4]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# nn.Module.to() moves the parameters and returns the same module object.
transformer = TransformerNet().to(device)
transformer

TransformerNet(
  (conv1): ConvLayer(
    (reflection_pad): ReflectionPad2d((4, 4, 4, 4))
    (conv2d): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
  )
  (in1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (in2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (in3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (res1): ResidualBlock(
    (conv1): ConvLayer(
      (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
      (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
    (in1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (

# Load the Model and Stylize the Content Video

In [5]:
# Input configuration. DATA_ROOT defaults to the original project directory and can be
# overridden with the DIP_FINAL_ROOT environment variable instead of editing this cell.
DATA_ROOT = os.environ.get("DIP_FINAL_ROOT", "/tmp2/vincentwu929/DIP_final")
data_path = os.path.join(DATA_ROOT, "Sky", "content", "sky2", "Video2_HR.avi")  # input video
mask_path = os.path.join(DATA_ROOT, "Sky", "mask", "mask_HR2.png")  # compositing mask
video_name = 'Video2_HR'  # used in the output file name
# Selects the checkpoint under ./models/; one of: 'crayon', 'fountainpen', 'ZaoWouKi'
style_name = 'ZaoWouKi'

In [6]:
# Optional: uncomment to inspect the input video's metadata via skvideo's ffprobe wrapper.
#skvideo.io.ffprobe(data_path)

In [7]:
# Load the compositing mask as a boolean array: True wherever the image is non-zero
# (element-wise, per BGR channel). `np.bool` was removed in NumPy 1.24, so the builtin
# `bool` is used as the dtype.
mask_bgr = cv2.imread(mask_path)
# cv2.imread signals a missing/unreadable file by returning None rather than raising.
if mask_bgr is None:
    raise FileNotFoundError("Could not read mask image: {}".format(mask_path))
mask = mask_bgr.astype(bool)

In [8]:
# Load the trained weights for the chosen style. `map_location=device` lets a checkpoint
# saved from a GPU run be loaded when `device` fell back to the CPU (and vice versa).
save_model_path = "./models/" + style_name + "_10000_unstable_vgg19.pth"
transformer.load_state_dict(torch.load(save_model_path, map_location=device))
# Put the module in eval mode; the returned module is displayed as the cell output.
transformer.eval()

TransformerNet(
  (conv1): ConvLayer(
    (reflection_pad): ReflectionPad2d((4, 4, 4, 4))
    (conv2d): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
  )
  (in1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  )
  (in2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv3): ConvLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  )
  (in3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (res1): ResidualBlock(
    (conv1): ConvLayer(
      (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
      (conv2d): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
    (in1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (

In [9]:
# Number of frames concatenated into a single forward pass of the transformer.
BATCH_SIZE: int = 8

In [10]:
# Crop window (rows, cols) of each frame that is stylized; `mask` must have the same
# height/width as this crop. Pixels outside the window are copied through unchanged.
CROP_ROWS = slice(0, 4000)
CROP_COLS = slice(1000, 5000)
output_path = os.path.join("/tmp2/vincentwu929/DIP_final",
                           video_name + "_" + style_name + ".avi")


def _batched(frames, size):
    """Yield successive lists of up to `size` items; the final list may be shorter."""
    chunk = []
    for item in frames:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        # Emit the trailing partial batch so the last frames are not dropped.
        yield chunk


def _stylize_batch(raw_frames):
    """Stylize the crop window of each raw frame and composite it back into that frame.

    Args:
        raw_frames: list of full-size frames as yielded by ``FFmpegReader.nextFrame``
            (presumably (H, W, 3) uint8 RGB -- confirm against the reader settings).

    Returns:
        List of output frames in the same order. Inside the crop window, pixels where
        ``mask`` is True come from the stylized image; all other pixels come from the
        *same* input frame (not from another frame of the batch).
    """
    inputs = torch.cat(
        [preprocess(Image.fromarray(f[CROP_ROWS, CROP_COLS])).unsqueeze(0)
         for f in raw_frames], 0).to(device)
    styled_crops = recover_image(transformer(inputs).cpu().numpy())
    outputs = []
    for raw, styled in zip(raw_frames, styled_crops):
        crop = raw[CROP_ROWS, CROP_COLS]
        # dsize is (width, height). Interpolation must be passed by keyword: the third
        # positional parameter of cv2.resize is `dst`, so a positional INTER_CUBIC
        # was never applied as the interpolation flag.
        styled = cv2.resize(styled, (crop.shape[1], crop.shape[0]),
                            interpolation=cv2.INTER_CUBIC)
        out_img = raw.copy()
        out_img[CROP_ROWS, CROP_COLS] = np.where(mask, styled, crop)
        outputs.append(out_img)
    return outputs


videogen = skvideo.io.FFmpegReader(data_path)
writer = skvideo.io.FFmpegWriter(output_path)
frames_written = 0
try:
    with torch.no_grad():
        for raw_batch in _batched(tqdm_notebook(videogen.nextFrame()), BATCH_SIZE):
            for out_img in _stylize_batch(raw_batch):
                writer.writeFrame(out_img)
                frames_written += 1
except RuntimeError as e:
    # Best effort: keep whatever was already encoded (e.g. after a CUDA out-of-memory
    # error), but make the truncation visible rather than only echoing the message.
    print("Stylization stopped early after {} frames: {}".format(frames_written, e))
finally:
    # Always finalize the output file and release the ffmpeg reader process.
    writer.close()
    videogen.close()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




