In [24]:
import numpy as np
import tensorflow as tf
import os
import scipy.misc
from easydict import EasyDict as edict
from WESPE import *
from utils import *
from dataloader import *
from ops import *

%reload_ext autoreload
%autoreload 2

config = edict()
# training parameters
config.batch_size = 50
config.patch_size = 100
config.mode = "RGB" #YCbCr
config.channels = 3
config.content_layer = 'relu5_4'
config.learning_rate = 1e-4
config.augmentation = True #data augmentation (flip, rotation)

# weights for loss
config.w_color = 1.2 # gaussian blur + mse (originally 0.1)
config.w_texture = 1 # gan (originally 0.4)
config.w_content = 2 # vgg19 (originally 1)
config.w_tv = 1/400 # total variation (originally 400)

config.model_name = "WESPE"

# directories
config.dataset_name = "iphone"
config.train_path_phone = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"training_data",str(config.dataset_name),"*.jpg")
config.train_path_dslr = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"training_data/canon/*.jpg")
config.test_path_phone_patch = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"test_data/patches",str(config.dataset_name),"*.jpg")
config.test_path_dslr_patch = os.path.join("/home/johnyi/Downloads/dped",str(config.dataset_name),"test_data/patches/canon/*.jpg")
config.test_path_phone_image = os.path.join("/home/johnyi/deeplearning/research/SISR_Datasets/test/DPED/sample_images/original_images",str(config.dataset_name),"*.jpg")
config.test_path_dslr_image = os.path.join("/home/johnyi/deeplearning/research/SISR_Datasets/test/DPED/sample_images/original_images/canon/*.jpg")

config.vgg_dir = "../vgg_pretrained/imagenet-vgg-verydeep-19.mat"

config.result_dir = os.path.join("./result")
config.result_img_dir = os.path.join(config.result_dir, "samples")
config.checkpoint_dir = os.path.join(config.result_dir, "model")

if not os.path.exists(config.checkpoint_dir):
    print("creating dir...", config.checkpoint_dir)
    os.makedirs(config.checkpoint_dir)
if not os.path.exists(config.result_dir):
    print("creating dir...", config.result_dir)
    os.makedirs(config.result_dir)
if not os.path.exists(config.result_img_dir):
    print("creating dir...", config.result_img_dir)
    os.makedirs(config.result_img_dir)
    
config.sample_dir = "samples"
if not os.path.exists(config.sample_dir):
    print("creating dir...", config.sample_dir)
    os.makedirs(config.sample_dir)

In [3]:
# Load the phone / DSLR training sets described by `config`.
# This is the slow step (~200 s for the iphone split in the recorded run),
# so it lives in its own cell, separate from model construction.
(dataset_phone, dataset_dslr) = load_dataset(config)

Dataset: iphone, 160471 image pairs
160471 image pairs loaded! setting took: 201.7262s


In [25]:
# Build the WESPE model on a fresh graph and session.
# When this cell is re-run, the session from the previous run would otherwise
# stay open (leaking its resources). The TF docs also warn against resetting
# the default graph while a session is still in use, so close that session first.
if globals().get("sess") is not None:
    sess.close()
tf.reset_default_graph()

# Set TEST_ONLY = True to build the model without training data, e.g. when
# only evaluating a saved checkpoint. The dataset-loading cell can then be skipped.
TEST_ONLY = False
train_phone, train_dslr = ([], []) if TEST_ONLY else (dataset_phone, dataset_dslr)

sess = tf.Session()
model = WESPE(sess, config, train_phone, train_dslr)

Completed building generator. Number of variables: 26
Completed building discriminator. Number of variables: 22


In [26]:
# Warm up the discriminator on (phone, DSLR) pairs before joint training.
# load=False: start from scratch rather than restoring a checkpoint.
model.pretrain_discriminator(load=False)

 Discriminator training starts from beginning
Iteration 0, runtime: 0.373 s, discriminator loss: 1.373628
Discriminator test accuracy: phone: 126/200, dslr: 123/200
Iteration 2000, runtime: 75.340 s, discriminator loss: 0.877606
Discriminator test accuracy: phone: 176/200, dslr: 157/200
Iteration 4000, runtime: 149.913 s, discriminator loss: 0.795663
Discriminator test accuracy: phone: 178/200, dslr: 175/200
Iteration 6000, runtime: 224.705 s, discriminator loss: 0.796582
Discriminator test accuracy: phone: 170/200, dslr: 186/200
Iteration 8000, runtime: 299.593 s, discriminator loss: 0.793576
Discriminator test accuracy: phone: 137/200, dslr: 191/200
pretraining complete


In [27]:
# Measure discriminator accuracy on (phone, DSLR) test patches after
# restoring the latest checkpoint (load=True).
n_disc_test_patches = 200  # matches the "/200" denominators in the printed accuracy
model.test_discriminator(n_disc_test_patches, load=True)

Loading checkpoints from  ./result/model/iphone
INFO:tensorflow:Restoring parameters from ./result/model/iphone/WESPE
 [*] Load SUCCESS
Discriminator test accuracy: phone: 177/200, dslr: 166/200


In [28]:
# Jointly train generator and discriminator, resuming from the saved
# checkpoint (load=True). The recorded run was stopped with a KeyboardInterrupt.
model.train(load=True)

Loading checkpoints from  ./result/model/iphone
INFO:tensorflow:Restoring parameters from ./result/model/iphone/WESPE
 [*] Load SUCCESS
Iteration 0, runtime: 1.560 s, generator loss: 53.031921
Loss per component: color 14.725231, texture 5.036114, content 14.990595, tv 137.734848
Dricriminator test accuracy: phone: 174/200, dslr: 176/200, enhanced: 93/200
(runtime: 5.921 s) Average test PSNR for 200 random test image patches: phone-enhanced 12.118, dslr-enhanced 12.029
Iteration 1000, runtime: 436.964 s, generator loss: 21.188353
Loss per component: color 2.035095, texture 1.803439, content 7.411040, tv 848.287354
Dricriminator test accuracy: phone: 97/200, dslr: 169/200, enhanced: 114/200
(runtime: 5.673 s) Average test PSNR for 200 random test image patches: phone-enhanced 20.849, dslr-enhanced 20.692


KeyboardInterrupt: 

In [29]:
# Evaluate the generator (discriminator accuracy + PSNR on random test patches)
# after restoring the latest checkpoint.
# NOTE(review): the meaning of the second argument (14) is defined in
# WESPE.test_generator; it is not visible here. TODO confirm and name it.
n_gen_test_patches = 200  # matches "200 random test image patches" in the log
model.test_generator(n_gen_test_patches, 14, load=True)

Loading checkpoints from  ./result/model/iphone
INFO:tensorflow:Restoring parameters from ./result/model/iphone/WESPE
 [*] Load SUCCESS
Dricriminator test accuracy: phone: 99/200, dslr: 178/200, enhanced: 108/200
(runtime: 6.146 s) Average test PSNR for 200 random test image patches: phone-enhanced 21.075, dslr-enhanced 20.306


KeyboardInterrupt: 

In [13]:
# save trained model
# Calls WESPE.save on the current model. It presumably writes a checkpoint
# under config.checkpoint_dir; confirm in WESPE.save.
# NOTE(review): this cell's execution count (13) is lower than the cells above,
# so in the recorded run it did NOT execute after the latest training. Re-run it
# after an interrupted train() if those weights should be kept.
model.save()