In [1]:
import numpy as np
import tensorflow as tf
import os
import scipy.misc
from easydict import EasyDict as edict
from DPED import *
from utils import *
from ops import *

%reload_ext autoreload
%autoreload 2

# Experiment configuration: a single attribute-style dict that the later cells
# pass to load_dataset() and DPED().
config = edict()

# training parameters
config.batch_size = 50
config.patch_size = 100
config.mode = "RGB"  # alternative noted by the author: YCbCr
config.channels = 3
config.content_layer = 'relu5_4'
config.learning_rate = 1e-4
config.augmentation = True  # data augmentation (flip, rotation)

# weights for loss
config.w_color = 1.2  # gaussian blur + mse (originally 0.1)
config.w_texture = 1  # gan (originally 0.4)
config.w_content = 2  # vgg19 (originally 1)
config.w_tv = 1/400  # total variation (originally 400)

# directories
# The two dataset roots used to be repeated verbatim in every path below.
# They are now defined once and can be overridden through the environment
# variables DPED_DATA_ROOT / DPED_IMAGE_ROOT; the defaults reproduce the
# original hard-coded locations exactly.
config.dataset_name = "iphone"
dataset_tag = str(config.dataset_name)
dped_data_root = os.path.join(
    os.environ.get("DPED_DATA_ROOT", "/home/johnyi/Downloads/dped"),
    dataset_tag,
)
dped_image_root = os.environ.get(
    "DPED_IMAGE_ROOT",
    "/home/johnyi/deeplearning/research/SISR_Datasets/test/DPED/sample_images/original_images",
)

# Glob patterns: phone images live in a folder named after the dataset,
# the DSLR counterparts in a sibling "canon" folder.
config.train_path_phone = os.path.join(dped_data_root, "training_data", dataset_tag, "*.jpg")
config.train_path_dslr = os.path.join(dped_data_root, "training_data/canon/*.jpg")
config.test_path_phone_patch = os.path.join(dped_data_root, "test_data/patches", dataset_tag, "*.jpg")
config.test_path_dslr_patch = os.path.join(dped_data_root, "test_data/patches/canon/*.jpg")
config.test_path_phone_image = os.path.join(dped_image_root, dataset_tag, "*.jpg")
config.test_path_dslr_image = os.path.join(dped_image_root, "canon/*.jpg")
config.sample_dir = "samples"
config.checkpoint_dir = "checkpoint"
config.vgg_dir = "../vgg_pretrained/imagenet-vgg-verydeep-19.mat"
config.log_dir = "logs"

# Create the output directories up front.  exist_ok=True replaces the racy
# "check, then create" pattern and makes the cell safe to re-run; unlike the
# old check it raises if one of these names exists as a regular file, which
# surfaces the problem here instead of at the first write.
for out_dir in (config.checkpoint_dir, config.sample_dir, config.log_dir):
    os.makedirs(out_dir, exist_ok=True)

  from ._conv import register_converters as _register_converters


In [2]:
# Load the phone / DSLR image pairs described by `config`
# (the run recorded below took about 89 s for 160471 pairs).
loaded_pairs = load_dataset(config)
dataset_phone, dataset_dslr = loaded_pairs

Dataset: iphone, 160471 image pairs
160471 image pairs loaded! setting took: 89.3427s


In [2]:
# build DPED model
tf.reset_default_graph()

# Keep the datasets produced by the "load dataset" cell when they exist.
# Only fall back to empty placeholders when that cell was skipped, i.e. when
# the notebook is used purely to test a saved model.  The previous version
# reset both variables unconditionally, which discarded the loaded data on a
# top-to-bottom run (presumably leaving model.train() without training
# images -- confirm against DPED.train).
if "dataset_phone" not in globals():
    dataset_phone = []
if "dataset_dslr" not in globals():
    dataset_dslr = []

sess = tf.Session()
model = DPED(sess, config, dataset_phone, dataset_dslr)

Completed building generator. Number of variables: 26
Completed building discriminator. Number of variables: 22


In [10]:
# Warm up the discriminator on (phone, dslr) pairs before adversarial training.
# With load=False the recorded log reports that training starts from the beginning.
model.pretrain_discriminator(load=False)

 Discriminator training starts from beginning
Iteration 0, runtime: 0.350 s, discriminator loss: 1.376139
Discriminator test accuracy: phone: 126/200, dslr: 121/200
Iteration 2000, runtime: 73.813 s, discriminator loss: 0.955353
Discriminator test accuracy: phone: 119/200, dslr: 183/200
Iteration 4000, runtime: 147.454 s, discriminator loss: 1.039908
Discriminator test accuracy: phone: 175/200, dslr: 143/200
Iteration 6000, runtime: 221.263 s, discriminator loss: 0.528191
Discriminator test accuracy: phone: 174/200, dslr: 175/200
Iteration 8000, runtime: 294.902 s, discriminator loss: 0.675135
Discriminator test accuracy: phone: 177/200, dslr: 165/200
pretraining complete


In [3]:
# Evaluate the discriminator on (phone, dslr) pairs after restoring the checkpoint.
# The first argument appears to be the number of test samples: the recorded
# output reports accuracies out of 200 -- confirm against DPED.test_discriminator.
n_test_samples = 200
model.test_discriminator(n_test_samples, load=True)

Loading checkpoints from  checkpoint/iphone
INFO:tensorflow:Restoring parameters from checkpoint/iphone/DPED
 [*] Load SUCCESS
Discriminator test accuracy: phone: 160/200, dslr: 130/200


In [12]:
# Joint generator + discriminator training, resuming from the saved checkpoint
# (the recorded log shows checkpoint/iphone being restored first).
model.train(load=True)

Loading checkpoints from  checkpoint/iphone
INFO:tensorflow:Restoring parameters from checkpoint/iphone/DPED
 [*] Load SUCCESS
Iteration 0, runtime: 2.042 s, generator loss: 55.007935
Loss per component: color 14.718228, texture 12.986241, content 11.960189, tv 175.774673
Dricriminator test accuracy: phone: 93/200, dslr: 198/200, enhanced: 149/200
(runtime: 2.144 s) Average test PSNR for 200 test image patches: phone-enhanced 11.953, dslr-enhanced 11.768
Iteration 1000, runtime: 442.805 s, generator loss: 19.725391
Loss per component: color 1.934178, texture 1.590602, content 6.660661, tv 996.982178
Dricriminator test accuracy: phone: 143/200, dslr: 159/200, enhanced: 144/200
(runtime: 2.125 s) Average test PSNR for 200 test image patches: phone-enhanced 22.126, dslr-enhanced 20.383
Iteration 2000, runtime: 884.879 s, generator loss: 23.790648
Loss per component: color 2.261600, texture 2.301839, content 8.285510, tv 881.546875
Dricriminator test accuracy: phone: 155/200, dslr: 142/200

KeyboardInterrupt: 

In [3]:
# Evaluate the trained generator after restoring the checkpoint.
# Judging by the recorded output, the two counts are the number of random test
# patches and of random full-size test images -- confirm against
# DPED.test_generator.
n_test_patches, n_full_images = 200, 14
model.test_generator(n_test_patches, n_full_images, load=True)

Loading checkpoints from  checkpoint/iphone
INFO:tensorflow:Restoring parameters from checkpoint/iphone/DPED
 [*] Load SUCCESS
Dricriminator test accuracy: phone: 158/200, dslr: 134/200, enhanced: 148/200
(runtime: 2.716 s) Average test PSNR for 200 random test image patches: phone-enhanced 21.092, dslr-enhanced 20.136
(runtime: 100.793 s) Average test PSNR for 14 random full test images: phone-enhanced 21.092


In [13]:
# save trained model
# Delegates to DPED.save(); the destination is chosen inside DPED and is not
# visible from this notebook (the load steps above read from
# checkpoint/iphone -- presumably the same location, TODO confirm).
model.save()