In [1]:
import os
import sys
sys.path.append('../')
import torch
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.utils as utils
import argparse
import time
import numpy as np
import cv2
from PIL import Image, ImageOps
from photo_wct import PhotoWCT, TORCH_MODE, get_dev_vec
from photo_smooth import Propagator
from collections import namedtuple
# Model checkpoint paths for PhotoWCT: for every level 1..5 there is an
# encoder entry ('vggN') followed by its matching decoder entry ('decoderN').
vgg_paths = {}
for level in range(1, 6):
    vgg_paths['vgg%d' % level] = '../models/vgg_normalised_conv%d_1_mask.t7' % level
    vgg_paths['decoder%d' % level] = '../models/feature_invertor_conv%d_1_mask.t7' % level

# PhotoWCT takes an argparse-like object, so expose the paths as attributes
# of a namedtuple whose field order follows the dict's insertion order.
vgg_key_list = list(vgg_paths)
vgg_arg_class = namedtuple('VggArgs', vgg_key_list)
vgg_args = vgg_arg_class(**vgg_paths)
# Instantiate the stylization network (built from the vgg_args checkpoints)
# and the smoothing/propagation step applied to its output.
p_wct, p_pro = PhotoWCT(vgg_args), Propagator()

def process_image(cont_img, styl_img, 
                  content_image_path, 
                  output_image_path = 'cur_image.png',
                 resize_dim = (256, 256)):
    """Stylize ``cont_img`` with ``styl_img`` and then run propagation on it.

    Both PIL images are resized to ``resize_dim`` and passed (without
    segmentation masks) through the module-level PhotoWCT model ``p_wct``.
    The raw stylized image is written to ``output_image_path`` and handed,
    together with the file at ``content_image_path``, to the module-level
    propagator ``p_pro``. Elapsed time of both stages is printed.

    Args:
        cont_img: PIL content image.
        styl_img: PIL style image.
        content_image_path: path of the content image file passed to
            ``p_pro.process``.
        output_image_path: where the intermediate stylized image is saved
            (it is read back by ``p_pro.process``).
        resize_dim: size tuple passed to ``PIL.Image.resize``
            (PIL interprets it as ``(width, height)``).

    Returns:
        ``(stylized, propagated)``: the stylized image as a channels-last
        numpy float array (assumes ``p_wct.transform`` returns a 1xCxHxW
        batch -- TODO confirm), and the propagated result as a PIL image whose
        last axis is reversed relative to ``p_pro.process``'s output
        (presumably BGR -> RGB -- NOTE(review): confirm).
    """
    import contextlib

    # torch >= 0.4 ignores ``volatile=True`` (emitting a warning) and would
    # record an autograd graph for the whole transform, wasting memory during
    # inference. Use ``torch.no_grad`` there; keep volatile Variables only as
    # the fallback for older torch releases that lack ``no_grad``.
    has_no_grad = hasattr(torch, 'no_grad')

    def _to_model_input(img):
        # PIL image -> resized 1xCxHxW tensor on the configured device.
        batch = transforms.ToTensor()(img.resize(resize_dim, Image.BICUBIC)).unsqueeze(0)
        batch = get_dev_vec(batch)
        return batch if has_no_grad else Variable(batch, volatile=True)

    cont_tensor = _to_model_input(cont_img)
    styl_tensor = _to_model_input(styl_img)

    # No segmentation masks are used: pass empty arrays.
    cont_seg = np.asarray([])
    styl_seg = np.asarray([])

    # contextlib.suppress() with no exception types is a do-nothing context.
    grad_ctx = torch.no_grad() if has_no_grad else contextlib.suppress()
    start_style_time = time.perf_counter()
    with grad_ctx:
        stylized_img = p_wct.transform(cont_tensor, styl_tensor, cont_seg, styl_seg)
    print('Elapsed time in stylization: %f' % (time.perf_counter() - start_style_time))

    new_style_image = stylized_img.data.cpu().float()
    # The propagator reads the stylized image back from disk.
    utils.save_image(new_style_image, output_image_path, nrow=1)

    start_propagation_time = time.perf_counter()
    out_img = p_pro.process(output_image_path, content_image_path)
    print('Elapsed time in propagation: %f' % (time.perf_counter() - start_propagation_time))

    # Reverse the last (channel) axis of the propagator output.
    out_img = Image.fromarray(np.array(out_img)[:, :, ::-1])
    # Drop the batch axis and move channels last: (C, H, W) -> (H, W, C).
    stylized_np = new_style_image.numpy().squeeze().swapaxes(0, 2).swapaxes(0, 1)
    return stylized_np, out_img

In [2]:
import ipywidgets as ipw
from glob import glob
def find_images(image_root, subdirs=('style', 'content', ''), pattern='*.png'):
    """Return the image files matching ``pattern`` in each of ``subdirs``.

    Groups are concatenated in ``subdirs`` order ('' means ``image_root``
    itself). Within each directory the paths are sorted, because ``glob``
    returns matches in arbitrary order; sorting keeps the dropdown ordering,
    and therefore its default selection, reproducible across machines.
    """
    return [path
            for sub in subdirs
            for path in sorted(glob(os.path.join(image_root, sub, pattern)))]

all_images = find_images(os.path.join('..', 'images'))

In [3]:
import matplotlib.pyplot as plt
# figA shows the selected content/style inputs; figB shows the stylized and
# propagated outputs. Both are rendered to PNG bytes for the Image widgets.
figA, (ax1, ax2) = plt.subplots(1, 2, figsize = (12, 6))
figB, (ax3, ax4) = plt.subplots(1, 2, figsize = (12, 6))
ax1.set_title('Not Loaded')
# Render a low-dpi "Not Loaded" placeholder that both viewers start with.
figA.savefig('junk_figure.png', dpi = 50)
with open('junk_figure.png', 'rb') as placeholder_file:
    def_img = placeholder_file.read()
load_image = ipw.Image(layout = ipw.Layout(width = "800px"), 
                      value = def_img)
out_image = ipw.Image(layout = ipw.Layout(width = "800px"), 
                      value = def_img)
# Selectors for the content file, style file, and square working resolution.
cont_path = ipw.Dropdown(options = all_images)
style_path = ipw.Dropdown(options = all_images)
res_level = ipw.Dropdown(options = [64, 128, 256, 384], value = 64)

In [4]:
def update_image(*args):
    """Widget callback: reload the selected images and re-run the pipeline.

    ``args`` (the change notification from ``observe``) is ignored. Reads the
    current ``cont_path`` / ``style_path`` / ``res_level`` widget values,
    shows the inputs in ``figA`` -> ``load_image``, shows the placeholder in
    ``out_image`` while ``process_image`` runs at a square ``res_level``
    resolution, then shows the stylized and propagated results in
    ``figB`` -> ``out_image``.
    """
    def _png_bytes(fig, path):
        # Save ``fig`` to ``path`` and return the file's bytes, closing it.
        fig.savefig(path)
        with open(path, 'rb') as png_file:
            return png_file.read()

    cont_img = Image.open(cont_path.value).convert('RGB')
    styl_img = Image.open(style_path.value).convert('RGB')

    # imshow on a reused Axes stacks a new image on top of the old ones, so
    # earlier (possibly larger) images would accumulate and stay visible
    # around a smaller new one. Start every panel from a clean Axes.
    for ax in (ax1, ax2, ax3, ax4):
        ax.clear()

    ax1.imshow(cont_img)
    ax1.set_title('Content')
    ax2.imshow(styl_img)
    ax2.set_title('Style')
    load_image.value = _png_bytes(figA, 'load_figure.png')
    # Show the placeholder while the stylization runs.
    out_image.value = def_img

    resize_dim = (res_level.value, res_level.value)
    c_img, d_img = process_image(cont_img = cont_img,
                                 styl_img = styl_img,
                                 content_image_path = cont_path.value,
                                 resize_dim = resize_dim)
    # Scale the float stylized image by 255 and clip to uint8 for display.
    ax3.imshow(np.clip(c_img*255, 0, 255).astype(np.uint8))
    ax3.set_title('Transferred')
    ax4.imshow(d_img)
    ax4.set_title('Transferred and Propagated')
    out_image.value = _png_bytes(figB, 't_figure.png')

# Re-run the whole pipeline whenever any of the three selectors changes value.
for selector in (cont_path, style_path, res_level):
    selector.observe(update_image, names='value')

In [5]:
# Top row: one labelled column per selector; below it, the input and output viewers.
selectors = ipw.HBox([
    ipw.VBox([ipw.Label(value = caption), control])
    for caption, control in [('Content File:', cont_path),
                             ('Style File:', style_path),
                             ('Image Resolution:', res_level)]
])
ipw.VBox([selectors, ipw.VBox([load_image, out_image])])

Elapsed time in stylization: 0.754482
Elapsed time in propagation: 0.148189
Elapsed time in stylization: 2.280497
Elapsed time in propagation: 0.762288
Elapsed time in stylization: 2.165730
Elapsed time in propagation: 0.710015
