In [53]:
import numpy as np
import tensorflow as tf
import os
import scipy.misc
from easydict import EasyDict as edict
from WESPE import *
from utils import *
from dataloader import *
from ops import *

%reload_ext autoreload
%autoreload 2

config = edict()
# training parameters
config.batch_size = 32
config.patch_size = 100
config.mode = "RGB"
config.channels = 3
config.content_layer = 'relu2_2' # originally relu5_4 in DPED
config.learning_rate = 1e-4
config.augmentation = True #data augmentation (flip, rotation)

# weights for loss
config.w_content = 0.1 # reconstruction (originally 1)
config.w_color = 20 # gan color (originally 5e-3)
config.w_texture = 3 # gan texture (originally 5e-3)
config.w_tv = 1/400 # total variation (originally 400)

config.model_name = "WESPE"

# directories
config.dataset_name = "iphone"
config.train_path_phone = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"training_data",str(config.dataset_name),"*.jpg")
config.train_path_dslr = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"training_data/canon/*.jpg")
config.test_path_phone_patch = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"test_data/patches",str(config.dataset_name),"*.jpg")
config.test_path_dslr_patch = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"test_data/patches/canon/*.jpg")
config.test_path_phone_image = os.path.join("/home/johnyi/deeplearning/research/SISR_Datasets/test/DPED/sample_images/original_images",str(config.dataset_name),"*.jpg")
config.test_path_dslr_image = os.path.join("/home/johnyi/deeplearning/research/SISR_Datasets/test/DPED/sample_images/original_images/canon/*.jpg")

config.vgg_dir = "../vgg_pretrained/imagenet-vgg-verydeep-19.mat"

config.result_dir = os.path.join("./result", config.model_name)
config.result_img_dir = os.path.join(config.result_dir, "samples")
config.checkpoint_dir = os.path.join(config.result_dir, "model")

if not os.path.exists(config.checkpoint_dir):
    print("creating dir...", config.checkpoint_dir)
    os.makedirs(config.checkpoint_dir)
if not os.path.exists(config.result_dir):
    print("creating dir...", config.result_dir)
    os.makedirs(config.result_dir)
if not os.path.exists(config.result_img_dir):
    print("creating dir...", config.result_img_dir)
    os.makedirs(config.result_img_dir)
    
config.sample_dir = "samples"
if not os.path.exists(config.sample_dir):
    print("creating dir...", config.sample_dir)
    os.makedirs(config.sample_dir)

In [2]:
# load dataset
dataset_phone, dataset_dslr = load_dataset(config)

Dataset: iphone, 160471 image pairs
160471 image pairs loaded! setting took: 200.6160s


In [54]:
# build WESPE model
tf.reset_default_graph()
# uncomment this when only trying to test the model
#dataset_phone = []
#dataset_dslr = []
sess = tf.Session()
model = WESPE(sess, config, dataset_phone, dataset_dslr)

Completed building generator. Number of variables: 52
Discriminator-color (none)
Discriminator-texture
Discriminator-color (none)
Discriminator-texture
Completed building color discriminator. Number of variables: 22
Completed building texture discriminator. Number of variables: 22


In [58]:
# train generator & discriminator together
model.train(load = False)

 Overall training starts from beginning
Iteration 0, runtime: 0.645 s, generator loss: 38.146297
Loss per component: content 185.846390, color 0.740438, texture 0.542597, tv 1250.047607
(runtime: 1.421 s) Average test PSNR for 200 random test image patches: phone-enhanced 27.653, phone-reconstructed 40.409, dslr-enhanced 17.888


KeyboardInterrupt: 

In [57]:
# test trained model
model.test_generator(200, 14, load = False)

(runtime: 1.361 s) Average test PSNR for 200 random test image patches: phone-enhanced 28.234, phone-reconstructed 41.767, dslr-enhanced 18.012
(runtime: 104.316 s) Average test PSNR for 14 random full test images: original-enhanced 26.213, original-reconstructed 37.944


In [13]:
# save trained model
model.save()