In [1]:
import numpy as np
import tensorflow as tf
import os
import scipy.misc
from easydict import EasyDict as edict
from WESPE_DIV2K import *
from dataloader.dataloader_DIV2K import *
from ops import *
from utils import *

%reload_ext autoreload
%autoreload 2

config = edict()

# training parameters
config.batch_size = 32
config.patch_size = 100
config.mode = "RGB"
config.channels = 3
config.content_layer = 'relu2_2'  # originally relu5_4 in DPED
config.learning_rate = 1e-4
config.augmentation = True  # data augmentation (flip, rotation)
config.test_every = 500  # iterations between test-PSNR reports
config.train_iter = 20000

# Set to True to build and evaluate an already-trained checkpoint without
# training it (the slow training-data load is then unnecessary).
config.test_only = False

# weights for loss
config.w_content = 0.2  # reconstruction (originally 1)
config.w_color = 40  # gan color (originally 5e-3)
config.w_texture = 3  # gan texture (originally 5e-3)
config.w_tv = 1 / 400  # total variation (originally 400); true division -> 0.0025

config.model_name = "WESPE_DIV2K"

# directories
# Dataset roots default to the original author's machine; point them elsewhere
# with the DPED_ROOT / SISR_DATASETS_ROOT environment variables instead of
# editing the paths below.
DPED_ROOT = os.environ.get("DPED_ROOT", "/home/johnyi/Downloads/dped")
SISR_DATASETS_ROOT = os.environ.get(
    "SISR_DATASETS_ROOT", "/home/johnyi/deeplearning/research/SISR_Datasets")

config.dataset_name = "iphone"
dataset_name = str(config.dataset_name)

config.train_path_phone = os.path.join(
    DPED_ROOT, dataset_name, "training_data", dataset_name, "*.jpg")
config.train_path_DIV2K = os.path.join(SISR_DATASETS_ROOT, "train", "DIV2K", "*.png")

config.test_path_phone_patch = os.path.join(
    DPED_ROOT, dataset_name, "test_data", "patches", dataset_name, "*.jpg")
config.test_path_phone_image = os.path.join(
    SISR_DATASETS_ROOT, "test", "DPED", "sample_images", "original_images",
    dataset_name, "*.jpg")

config.vgg_dir = "../vgg_pretrained/imagenet-vgg-verydeep-19.mat"

config.result_dir = os.path.join("./result", config.model_name)
config.result_img_dir = os.path.join(config.result_dir, "samples")
config.checkpoint_dir = os.path.join(config.result_dir, "model")

def _ensure_dir(path):
    """Create ``path`` (and any missing parents) unless it already exists.

    Prints a notice only when the directory is actually created. Raises
    FileExistsError if ``path`` exists but is not a directory.
    """
    if not os.path.exists(path):
        print("creating dir...", path)
    # exist_ok makes this a no-op for existing directories and avoids the
    # check-then-create race of a bare exists()/makedirs() pair.
    os.makedirs(path, exist_ok=True)


# checkpoint_dir is nested inside result_dir, so creating it creates result_dir too.
for _output_dir in (config.checkpoint_dir, config.result_dir, config.result_img_dir):
    _ensure_dir(_output_dir)

config.sample_dir = "samples_DIV2K"
_ensure_dir(config.sample_dir)

  from ._conv import register_converters as _register_converters


In [3]:
# Load the DPED phone photos and the DIV2K images described by `config`
# (slow: the iphone set took ~3 minutes in the run recorded below).
(dataset_phone,
 dataset_DIV2K) = load_dataset(config)

Dataset: iphone, 160471 images
DIV2K: 900 images
160471 images loaded! setting took: 174.4145s


In [10]:
# Sanity check: draw one batch from each dataset, starting at index 0.
sampled_batches = get_batch(dataset_phone, dataset_DIV2K, config, start=0)
phone_batch, DIV2K_batch = sampled_batches
print('done!')

done!


In [2]:
# build WESPE model
tf.reset_default_graph()
# Evaluation-only runs hand the model empty datasets so the dataset-loading
# cells can be skipped. Training runs must keep the datasets loaded above --
# overwriting them unconditionally (as this cell used to) starves training.
if getattr(config, "test_only", False):
    dataset_phone = []
    dataset_DIV2K = []
sess = tf.Session()
model = WESPE(sess, config, dataset_phone, dataset_DIV2K)

Completed building generator. Number of variables: 52
Discriminator-color (none)
Discriminator-texture
Discriminator-color (none)
Discriminator-texture
Completed building color discriminator. Number of variables: 22
Completed building texture discriminator. Number of variables: 22


In [17]:
# train generator & discriminator together, resuming from the latest checkpoint.
# Skipped in evaluation-only mode, where the model was built without training data.
if not getattr(config, "test_only", False):
    model.train(load = True)

Loading checkpoints from  ./result/WESPE_DIV2K/model/iphone
INFO:tensorflow:Restoring parameters from ./result/WESPE_DIV2K/model/iphone/WESPE_DIV2K
 [*] Load SUCCESS
Iteration 0, runtime: 0.644 s, generator loss: 217.397369
Loss per component: content 668.107178, color 1.781722, texture 3.429668, tv 887.214111
(runtime: 1.640 s) Average test PSNR for 200 random test image patches: phone-enhanced 20.819, phone-reconstructed 36.158
Iteration 500, runtime: 251.566 s, generator loss: 166.596878
Loss per component: content 600.031555, color 0.856101, texture 3.239120, tv 1051.661499
(runtime: 2.006 s) Average test PSNR for 200 random test image patches: phone-enhanced 22.866, phone-reconstructed 37.617
Iteration 1000, runtime: 503.152 s, generator loss: 248.586487
Loss per component: content 675.082520, color 2.475869, texture 3.992794, tv 1022.739258
(runtime: 1.486 s) Average test PSNR for 200 random test image patches: phone-enhanced 22.565, phone-reconstructed 36.537
Iteration 1500, run

KeyboardInterrupt: 

In [3]:
# test trained model on random test patches and full-size test images.
# After training in this session the in-memory weights are evaluated directly.
# In evaluation-only mode the graph was just rebuilt, so the trained weights
# must be restored from the checkpoint first -- testing the fresh graph with
# load=False is what produced the ~12 dB PSNR recorded below.
N_TEST_PATCHES = 200
N_TEST_IMAGES = 4
model.test_generator(N_TEST_PATCHES, N_TEST_IMAGES,
                     load = getattr(config, "test_only", False))

(runtime: 2.868 s) Average test PSNR for 200 random test image patches: phone-enhanced 12.057, phone-reconstructed 12.148
(runtime: 63.354 s) Average test PSNR for 4 random full test images: original-enhanced 9.778, original-reconstructed 9.894


In [13]:
# save trained model. Skipped in evaluation-only mode so an existing checkpoint
# is never overwritten by weights that were not trained in this session.
if not getattr(config, "test_only", False):
    model.save()