In [13]:
import numpy as np
import tensorflow as tf
import os
import scipy.misc
from easydict import EasyDict as edict
from WESPE import *
from utils import *
from dataloader import *
from ops import *

# Reload edited project modules (WESPE, utils, dataloader, ops) automatically before
# each cell executes, so changes to those .py files apply without a kernel restart.
%reload_ext autoreload
%autoreload 2

# Experiment configuration shared by every later cell; fields are readable both as
# attributes (config.batch_size) and as dict keys (config["batch_size"]).
config = edict(
    # --- training parameters ---
    batch_size=30,
    patch_size=100,
    mode="RGB",
    channels=3,
    content_layer='relu5_4',
    learning_rate=1e-4,
    augmentation=True,  # data augmentation (flip, rotation)
    # --- loss weights ---
    w_content=1,        # reconstruction (originally 1)
    w_color=1e-2,       # gan color (originally 5e-3)
    w_texture=1e-2,     # gan texture (originally 5e-3)
    w_tv=1 / 400,       # total variation (originally 400); true division -> 0.0025
    # --- naming ---
    model_name="WESPE",
)

# directories
config.dataset_name = "iphone"
config.dslr_name = "canon"  # DSLR sub-folder name used in every DSLR path below

# Machine-specific roots are defined once and may be overridden via environment
# variables, so the notebook can run on another machine without editing each path.
config.dped_dir = os.environ.get("DPED_DIR", "/home/johnyi/Downloads/dped")
config.sample_images_dir = os.environ.get(
    "DPED_SAMPLE_IMAGES_DIR",
    "/home/johnyi/deeplearning/research/SISR_Datasets/test/DPED/sample_images/original_images",
)

# Every path component is passed to os.path.join separately (no embedded "/").
_dataset_dir = os.path.join(config.dped_dir, str(config.dataset_name))
# training patches
config.train_path_phone = os.path.join(_dataset_dir, "training_data", str(config.dataset_name), "*.jpg")
config.train_path_dslr = os.path.join(_dataset_dir, "training_data", config.dslr_name, "*.jpg")
# test patches
config.test_path_phone_patch = os.path.join(_dataset_dir, "test_data", "patches", str(config.dataset_name), "*.jpg")
config.test_path_dslr_patch = os.path.join(_dataset_dir, "test_data", "patches", config.dslr_name, "*.jpg")
# non-patch sample images
config.test_path_phone_image = os.path.join(config.sample_images_dir, str(config.dataset_name), "*.jpg")
config.test_path_dslr_image = os.path.join(config.sample_images_dir, config.dslr_name, "*.jpg")

config.vgg_dir = "../vgg_pretrained/imagenet-vgg-verydeep-19.mat"

config.result_dir = os.path.join("./result", config.model_name)
config.result_img_dir = os.path.join(config.result_dir, "samples")
config.checkpoint_dir = os.path.join(config.result_dir, "model")

# Create the output directories. Parents are listed before their children so every
# "creating dir..." message corresponds to a directory this loop really creates
# (making checkpoint_dir first would silently create result_dir as its parent).
config.sample_dir = "samples"
for _out_dir in (config.result_dir, config.checkpoint_dir,
                 config.result_img_dir, config.sample_dir):
    if not os.path.exists(_out_dir):
        print("creating dir...", _out_dir)
    # exist_ok=True tolerates the directory appearing between the check and this call;
    # it still raises FileExistsError if the path exists but is not a directory.
    os.makedirs(_out_dir, exist_ok=True)

In [2]:
# Load the paired phone / DSLR training data described by `config`
# (slow: about 145 s for the full iphone set in the recorded run).
(dataset_phone, dataset_dslr) = load_dataset(config)

Dataset: iphone, 160471 image pairs
160471 image pairs loaded! setting took: 145.3811s


In [14]:
# build WESPE model
# Re-running this cell would otherwise leave the previous tf.Session open (and the
# resources it holds allocated); close it before the default graph is reset.
_previous_sess = globals().get("sess")
if _previous_sess is not None:
    _previous_sess.close()
tf.reset_default_graph()

# Set to True when only testing a restored model: the training data are then not
# needed, so empty datasets are passed in (the dataset-loading cell can be skipped).
test_only = False
if test_only:
    dataset_phone, dataset_dslr = [], []

sess = tf.Session()
model = WESPE(sess, config, dataset_phone, dataset_dslr)

Completed building generator. Number of variables: 52
Discriminator-color
Discriminator-texture
Discriminator-color
Discriminator-texture
Completed building color discriminator. Number of variables: 22
Completed building texture discriminator. Number of variables: 22


In [26]:
# Pretrain the discriminator on (phone, dslr) pairs; load=False starts from scratch
# instead of restoring a saved checkpoint.
model.pretrain_discriminator(load=False)

 Discriminator training starts from beginning
Iteration 0, runtime: 0.373 s, discriminator loss: 1.373628
Discriminator test accuracy: phone: 126/200, dslr: 123/200
Iteration 2000, runtime: 75.340 s, discriminator loss: 0.877606
Discriminator test accuracy: phone: 176/200, dslr: 157/200
Iteration 4000, runtime: 149.913 s, discriminator loss: 0.795663
Discriminator test accuracy: phone: 178/200, dslr: 175/200
Iteration 6000, runtime: 224.705 s, discriminator loss: 0.796582
Discriminator test accuracy: phone: 170/200, dslr: 186/200
Iteration 8000, runtime: 299.593 s, discriminator loss: 0.793576
Discriminator test accuracy: phone: 137/200, dslr: 191/200
pretraining complete


In [27]:
# Measure discriminator accuracy on 200 phone and 200 DSLR test samples,
# restoring the saved checkpoint first (load=True).
model.test_discriminator(200, load=True)

Loading checkpoints from  ./result/model/iphone
INFO:tensorflow:Restoring parameters from ./result/model/iphone/WESPE
 [*] Load SUCCESS
Discriminator test accuracy: phone: 177/200, dslr: 166/200


In [15]:
# Jointly train the generator and the discriminators; load=False starts from
# scratch rather than resuming from a checkpoint.
model.train(load=False)

 Overall training starts from beginning
Iteration 0, runtime: 2.099 s, generator loss: 8.374959
Loss per component: content 7.893203, color 0.722424, texture 0.721582, tv 186.926590
(runtime: 2.228 s) Average test PSNR for 200 random test image patches: phone-enhanced 11.616, dslr-enhanced 11.605
Iteration 1000, runtime: 524.675 s, generator loss: 4.602940
Loss per component: content 3.050100, color 32.443714, texture 31.426334, tv 365.655731
(runtime: 1.697 s) Average test PSNR for 200 random test image patches: phone-enhanced 19.423, dslr-enhanced 15.144
Iteration 2000, runtime: 1047.664 s, generator loss: 5.027886
Loss per component: content 3.573137, color 37.247185, texture 35.456203, tv 291.086121
(runtime: 2.610 s) Average test PSNR for 200 random test image patches: phone-enhanced 19.322, dslr-enhanced 14.500


KeyboardInterrupt: 

In [None]:
# Evaluate the trained generator (PSNR on 200 random test patches) after restoring
# the checkpoint (load=True).
# NOTE(review): the meaning of the second argument (4) is not visible here —
# presumably the number of sample images to write; confirm in WESPE.test_generator.
model.test_generator(200, 4, load=True)

Loading checkpoints from  ./result/WESPE/model/iphone
INFO:tensorflow:Restoring parameters from ./result/WESPE/model/iphone/WESPE
 [*] Load SUCCESS
(runtime: 2.236 s) Average test PSNR for 200 random test image patches: phone-enhanced 19.429, dslr-enhanced 14.398


In [13]:
# Persist the current model state (presumably a checkpoint under
# config.checkpoint_dir — confirm in WESPE.save).
model.save()