In [1]:
import numpy as np
import tensorflow as tf
import os
import scipy.misc
from easydict import EasyDict as edict
from WESPE import *
from utils import *
from dataloader.dataloader import *
from ops import *
import sys

# Restrict TensorFlow to GPU 0.
# NOTE(review): this runs after `import tensorflow` (and the project imports) above. It presumably
# still takes effect because CUDA is not initialised until the first tf.Session() is created, but
# it is safer to set CUDA_VISIBLE_DEVICES before importing tensorflow — TODO confirm / move up.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Auto-reload imported modules (e.g. WESPE, utils, dataloader, ops) before each cell executes,
# so edits to the project .py files are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

config = edict()

# ---- training parameters ----
config.batch_size = 32
config.patch_size = 100
config.mode = "RGB"
config.channels = 3
config.content_layer = 'relu2_2' # originally relu5_4 in DPED
config.learning_rate = 1e-4
config.augmentation = True # data augmentation (flip, rotation)
config.test_every = 500 # iterations between the test-PSNR reports seen in the training log
config.train_iter = 100000

# ---- loss weights ----
# In the training log, "generator loss" matches
#   w_content*content + w_color*color + w_texture*texture + w_tv*tv
# computed from the "Loss per component" line printed right after it.
config.w_content = 0.1 # reconstruction (originally 1)
config.w_color = 20 # gan color (originally 5e-3)
config.w_texture = 3 # gan texture (originally 5e-3)
# total variation (originally 400). Written as 0.005 (== 1/200) rather than `1/200` so the
# weight cannot silently become 0 under integer division (e.g. Python 2).
config.w_tv = 0.005

config.model_name = "WESPE_DPED"

# ---- data directories ----
# Machine-specific root of the DPED download; every dataset path below is derived from it,
# so moving to another host only requires editing this one line.
dped_root = "/mnt/sde/palparmar/dataset/dped"
pre_path = os.path.join(dped_root, "dped")
sample_root = os.path.join(dped_root, "sample_images", "original_images")

config.dataset_name = "iphone"
phone_dir = str(config.dataset_name)  # folder name of the phone images
dslr_dir = "canon"                    # folder name of the DSLR images
config.train_path_phone = os.path.join(pre_path, phone_dir, "training_data", phone_dir, "*.jpg")
config.train_path_dslr = os.path.join(pre_path, phone_dir, "training_data", dslr_dir, "*.jpg")
config.test_path_phone_patch = os.path.join(pre_path, phone_dir, "test_data", "patches", phone_dir, "*.jpg")
config.test_path_dslr_patch = os.path.join(pre_path, phone_dir, "test_data", "patches", dslr_dir, "*.jpg")
config.test_path_phone_image = os.path.join(sample_root, phone_dir, "*.jpg")
config.test_path_dslr_image = os.path.join(sample_root, dslr_dir, "*.jpg")

config.vgg_dir = "../vgg_pretrained/imagenet-vgg-verydeep-19.mat"

# ---- output directories ----
config.result_dir = os.path.join("./result", config.model_name)
config.result_img_dir = os.path.join(config.result_dir, "samples")
config.checkpoint_dir = os.path.join(config.result_dir, "model")
config.sample_dir = "samples"

# Create parents before children so each "creating dir..." message corresponds to a directory
# this loop actually created (creating checkpoint_dir first would implicitly create result_dir,
# and its message would never print). exist_ok=True tolerates a directory appearing between
# the existence check and makedirs instead of raising FileExistsError.
for out_dir in (config.result_dir, config.checkpoint_dir, config.result_img_dir, config.sample_dir):
    if not os.path.exists(out_dir):
        print("creating dir...", out_dir)
        os.makedirs(out_dir, exist_ok=True)

In [2]:
# Load the phone and DSLR training sets described by `config` (slow: ~230 s for the iphone set).
dataset_phone, dataset_dslr = load_dataset(config)

# Fail fast if nothing was loaded — most likely a data path in the config cell points at the
# wrong mount or dataset name. Without this, the problem only surfaces later during
# model construction or training.
assert len(dataset_phone) > 0, "no phone training images loaded; check config.train_path_phone = " + config.train_path_phone
assert len(dataset_dslr) > 0, "no DSLR training images loaded; check config.train_path_dslr = " + config.train_path_dslr

Dataset: iphone, 160471 image pairs
160471 image pairs loaded! setting took: 231.7841s


In [3]:
# Build the WESPE graph (generator + color/texture discriminators) in a fresh default graph.
# When this cell is re-run, close the previous session first: reset_default_graph() discards
# the old graph, but an open tf.Session keeps its resources (e.g. GPU memory) until close().
if "sess" in globals():
    sess.close()
tf.reset_default_graph()

# Set TEST_ONLY = True to build the model without training data (when only testing a saved
# checkpoint); this allows skipping the slow dataset-loading cell.
TEST_ONLY = False
if TEST_ONLY:
    dataset_phone, dataset_dslr = [], []

sess = tf.Session()
model = WESPE(sess, config, dataset_phone, dataset_dslr)

Completed building generator. Number of variables: 28
Discriminator-color (none)
Discriminator-texture
Discriminator-color (none)
Discriminator-texture
Completed building color discriminator. Number of variables: 22
Completed building texture discriminator. Number of variables: 22


In [4]:
# Train generator & discriminators together from scratch (with load=False the log reports
# "Overall training starts from beginning").
model.train(load=False)

# Persist the end-of-training weights immediately. The evaluation cell below calls
# test_generator(..., load=True), which restores parameters from the checkpoint directory;
# saving here ensures it restores the final weights rather than an earlier checkpoint.
# NOTE(review): assumes model.save() writes the checkpoint that load=True restores — TODO confirm.
model.save()

 Overall training starts from beginning
Iteration 0, runtime: 7.493 s, generator loss: 2320.744873
Loss per component: content 22994.992188, color 0.719649, texture 0.708685, tv 945.335571
(runtime: 2.377 s) Average test PSNR for 200 random test image patches: phone-enhanced 9.827, phone-reconstructed 12.363, dslr-enhanced 10.052
Iteration 500, runtime: 128.954 s, generator loss: 153.843857
Loss per component: content 643.850464, color 3.110806, texture 8.236347, tv 506.730255
(runtime: 1.757 s) Average test PSNR for 200 random test image patches: phone-enhanced 21.995, phone-reconstructed 35.992, dslr-enhanced 16.495
Iteration 1000, runtime: 250.271 s, generator loss: -4.002808
Loss per component: content 281.453918, color -1.605125, texture -0.733044, tv 430.685883
(runtime: 1.984 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.995, phone-reconstructed 40.976, dslr-enhanced 15.499
Iteration 1500, runtime: 371.760 s, generator loss: 44.576275
Loss per compone

Iteration 14000, runtime: 3411.698 s, generator loss: 15.527098
Loss per component: content 51.148342, color 0.260687, texture 1.254297, tv 287.127289
(runtime: 1.848 s) Average test PSNR for 200 random test image patches: phone-enhanced 25.361, phone-reconstructed 44.849, dslr-enhanced 17.719
Iteration 14500, runtime: 3536.748 s, generator loss: 72.403900
Loss per component: content 79.327408, color 2.835053, texture 1.926430, tv 398.159912
(runtime: 1.805 s) Average test PSNR for 200 random test image patches: phone-enhanced 22.148, phone-reconstructed 46.075, dslr-enhanced 17.012
Iteration 15000, runtime: 3661.075 s, generator loss: 30.048107
Loss per component: content 35.184532, color 1.024876, texture 1.473321, tv 322.432343
(runtime: 1.804 s) Average test PSNR for 200 random test image patches: phone-enhanced 25.686, phone-reconstructed 47.303, dslr-enhanced 17.760
Iteration 15500, runtime: 3785.860 s, generator loss: 40.398903
Loss per component: content 45.325176, color 1.2244

Iteration 28000, runtime: 6817.550 s, generator loss: 37.295765
Loss per component: content 96.790710, color 1.013957, texture 1.861496, tv 350.613892
(runtime: 1.691 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.424, phone-reconstructed 45.298, dslr-enhanced 19.807
Iteration 28500, runtime: 6937.308 s, generator loss: 31.695557
Loss per component: content 96.743126, color 0.542471, texture 3.254153, tv 281.871735
(runtime: 1.704 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.872, phone-reconstructed 45.170, dslr-enhanced 20.052
Iteration 29000, runtime: 7057.368 s, generator loss: 41.613895
Loss per component: content 92.870102, color 1.301229, texture 1.538866, tv 337.142334
(runtime: 1.548 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.898, phone-reconstructed 45.416, dslr-enhanced 19.715
Iteration 29500, runtime: 7177.815 s, generator loss: 50.536709
Loss per component: content 76.554703, color 1.8051

Iteration 42000, runtime: 10229.622 s, generator loss: 100.362900
Loss per component: content 178.223755, color 3.448540, texture 4.004466, tv 311.267273
(runtime: 1.841 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.514, phone-reconstructed 40.994, dslr-enhanced 17.951
Iteration 42500, runtime: 10355.418 s, generator loss: 87.575302
Loss per component: content 110.704010, color 3.043996, texture 4.865963, tv 205.418503
(runtime: 1.835 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.867, phone-reconstructed 43.860, dslr-enhanced 18.842
Iteration 43000, runtime: 10481.294 s, generator loss: 82.017799
Loss per component: content 87.610672, color 3.091830, texture 3.261654, tv 327.035217
(runtime: 2.280 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.613, phone-reconstructed 44.765, dslr-enhanced 19.740
Iteration 43500, runtime: 10607.407 s, generator loss: 62.786930
Loss per component: content 94.725388, color

Iteration 56000, runtime: 13660.928 s, generator loss: 165.628479
Loss per component: content 168.570831, color 5.723551, texture 10.855145, tv 346.984589
(runtime: 1.551 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.569, phone-reconstructed 43.374, dslr-enhanced 19.966
Iteration 56500, runtime: 13779.149 s, generator loss: 89.026215
Loss per component: content 159.111725, color 2.368874, texture 8.053188, tv 315.600830
(runtime: 1.669 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.902, phone-reconstructed 42.852, dslr-enhanced 19.547
Iteration 57000, runtime: 13897.598 s, generator loss: 94.689247
Loss per component: content 175.194016, color 3.270267, texture 3.425444, tv 297.636810
(runtime: 1.798 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.485, phone-reconstructed 42.665, dslr-enhanced 19.978
Iteration 57500, runtime: 14015.568 s, generator loss: 119.941071
Loss per component: content 155.323105, c

(runtime: 1.479 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.141, phone-reconstructed 41.890, dslr-enhanced 19.536
Iteration 70000, runtime: 17057.735 s, generator loss: 172.001907
Loss per component: content 186.151047, color 5.897144, texture 11.393240, tv 252.840179
(runtime: 1.621 s) Average test PSNR for 200 random test image patches: phone-enhanced 23.144, phone-reconstructed 42.568, dslr-enhanced 19.493
Iteration 70500, runtime: 17176.023 s, generator loss: 147.670151
Loss per component: content 154.814148, color 5.292436, texture 8.300328, tv 287.809570
(runtime: 1.950 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.103, phone-reconstructed 42.698, dslr-enhanced 19.715
Iteration 71000, runtime: 17294.408 s, generator loss: 235.472137
Loss per component: content 250.476700, color 8.735840, texture 11.369350, tv 319.924927
(runtime: 1.611 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.108, phone-rec

Iteration 83500, runtime: 20346.569 s, generator loss: 273.028442
Loss per component: content 215.200623, color 10.210537, texture 15.273279, tv 295.563263
(runtime: 1.520 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.901, phone-reconstructed 42.282, dslr-enhanced 19.993
Iteration 84000, runtime: 20465.120 s, generator loss: 201.603439
Loss per component: content 209.822479, color 6.805541, texture 14.392331, tv 266.673767
(runtime: 1.449 s) Average test PSNR for 200 random test image patches: phone-enhanced 26.498, phone-reconstructed 42.415, dslr-enhanced 18.318
Iteration 84500, runtime: 20583.520 s, generator loss: 149.809845
Loss per component: content 209.441193, color 3.514869, texture 18.881466, tv 384.787323
(runtime: 1.775 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.854, phone-reconstructed 41.101, dslr-enhanced 19.981
Iteration 85000, runtime: 20702.087 s, generator loss: 123.682014
Loss per component: content 217.1189

(runtime: 2.141 s) Average test PSNR for 200 random test image patches: phone-enhanced 25.157, phone-reconstructed 42.233, dslr-enhanced 19.506
Iteration 97500, runtime: 23833.561 s, generator loss: 171.688782
Loss per component: content 288.912964, color 5.391371, texture 11.188371, tv 280.990662
(runtime: 1.982 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.773, phone-reconstructed 41.522, dslr-enhanced 18.945
Iteration 98000, runtime: 23960.951 s, generator loss: 239.539093
Loss per component: content 294.499268, color 7.731886, texture 18.007118, tv 286.017883
(runtime: 1.977 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.128, phone-reconstructed 40.069, dslr-enhanced 19.417
Iteration 98500, runtime: 24087.867 s, generator loss: 174.280823
Loss per component: content 207.518158, color 4.941877, texture 17.689381, tv 324.663177
(runtime: 1.592 s) Average test PSNR for 200 random test image patches: phone-enhanced 25.760, phone-re

In [5]:
# Evaluate the generator restored from its checkpoint (load=True): average PSNR on random test
# patches, then PSNR and per-image timing on full-size test images.
n_test_patches = 200  # random test patches for the patch-level PSNR line
n_full_images = 14    # full-size test images (each one is timed individually)
model.test_generator(n_test_patches, n_full_images, load=True)

Loading checkpoints from  ./result/WESPE_DPED/model/iphone
INFO:tensorflow:Restoring parameters from ./result/WESPE_DPED/model/iphone/WESPE_DPED
 [*] Load SUCCESS
(runtime: 1.692 s) Average test PSNR for 200 random test image patches: phone-enhanced 24.632, phone-reconstructed 40.986, dslr-enhanced 19.694
Time taken for Image 0 = 9.64775037765503
Time taken for Image 1 = 5.844906806945801
Time taken for Image 2 = 5.629693508148193
Time taken for Image 3 = 5.516461133956909
Time taken for Image 4 = 7.850811958312988
Time taken for Image 5 = 7.195908784866333
Time taken for Image 6 = 5.6406567096710205
Time taken for Image 7 = 9.198256015777588
Time taken for Image 8 = 5.500957489013672
Time taken for Image 9 = 4.953590393066406
Time taken for Image 10 = 4.408099174499512
Time taken for Image 11 = 10.791743278503418
Time taken for Image 12 = 9.654475688934326
Time taken for Image 13 = 9.406728506088257
(runtime: 454.600 s) Average test PSNR for 14 random full test images: original-enhanc

In [6]:
# Save the weights currently held by the model's session.
# NOTE(review): the evaluation cell above ran test_generator(..., load=True), which restored
# parameters from the checkpoint (see its output), so this re-saves the restored weights rather
# than the in-memory result of model.train(). If training ran past its last checkpoint, those
# iterations are not captured here — consider saving directly after model.train().
model.save()