Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ihnatova committed Aug 23, 2017
0 parents commit 60cbc28
Show file tree
Hide file tree
Showing 18 changed files with 887 additions and 0 deletions.
Empty file added dped/.gitkeep
Empty file.
68 changes: 68 additions & 0 deletions load_dataset.py
@@ -0,0 +1,68 @@
from __future__ import print_function
from scipy import misc
import os
import numpy as np
import sys

def load_test_data(phone, dped_dir, IMAGE_SIZE):
    """Load all phone/DSLR test patch pairs as flattened, scaled arrays.

    Patches are read from ``<dped_dir>/<phone>/test_data/patches/<phone>/``
    (inputs) and ``<dped_dir>/<phone>/test_data/patches/canon/`` (targets).
    Files are expected to be named ``0.jpg``, ``1.jpg``, ... with the same
    index in both directories; the number of pairs is the number of regular
    files found in the phone directory.

    Args:
        phone: Phone model name; used as a directory name.
        dped_dir: Root directory of the dataset. A trailing path separator
            is optional.
        IMAGE_SIZE: Number of values in one flattened patch.

    Returns:
        A ``(test_data, test_answ)`` tuple of arrays of shape
        ``(NUM_TEST_IMAGES, IMAGE_SIZE)``: the phone patches and the
        corresponding DSLR patches, each divided by 255.
    """
    # Join path components instead of concatenating strings so that a
    # dped_dir without a trailing separator still resolves correctly.
    patches_root = os.path.join(dped_dir, str(phone), 'test_data', 'patches')
    test_directory_phone = os.path.join(patches_root, str(phone))
    test_directory_dslr = os.path.join(patches_root, 'canon')

    NUM_TEST_IMAGES = len([name for name in os.listdir(test_directory_phone)
                           if os.path.isfile(os.path.join(test_directory_phone, name))])

    def _read_flat(directory, index):
        # Read "<index>.jpg", flatten it to one row and scale it by 1/255.
        # The float16 cast is kept from the original implementation.
        pixels = np.asarray(misc.imread(os.path.join(directory, str(index) + '.jpg')))
        return np.float16(np.reshape(pixels, [1, IMAGE_SIZE])) / 255

    test_data = np.zeros((NUM_TEST_IMAGES, IMAGE_SIZE))
    test_answ = np.zeros((NUM_TEST_IMAGES, IMAGE_SIZE))

    for i in range(0, NUM_TEST_IMAGES):
        test_data[i, :] = _read_flat(test_directory_phone, i)
        test_answ[i, :] = _read_flat(test_directory_dslr, i)

        # Lightweight progress indicator, rewritten in place on one line.
        if i % 100 == 0:
            print(str(round(i * 100 / NUM_TEST_IMAGES)) + "% done", end="\r")

    return test_data, test_answ


def load_batch(phone, dped_dir, TRAIN_SIZE, IMAGE_SIZE):
    """Load a batch of phone/DSLR training patch pairs as flattened arrays.

    Patches are read from ``<dped_dir>/<phone>/training_data/<phone>/``
    (inputs) and ``<dped_dir>/<phone>/training_data/canon/`` (targets).
    Files are expected to be named ``<index>.jpg`` with the same index in
    both directories.

    Args:
        phone: Phone model name; used as a directory name.
        dped_dir: Root directory of the dataset. A trailing path separator
            is optional.
        TRAIN_SIZE: Number of pairs to load. ``-1`` loads every pair in
            index order; any other value draws that many distinct indices
            at random (without replacement).
        IMAGE_SIZE: Number of values in one flattened patch.

    Returns:
        A ``(train_data, train_answ)`` tuple of arrays of shape
        ``(TRAIN_SIZE, IMAGE_SIZE)``: the phone patches and the
        corresponding DSLR patches, each divided by 255.
    """
    # Join path components instead of concatenating strings so that a
    # dped_dir without a trailing separator still resolves correctly.
    training_root = os.path.join(dped_dir, str(phone), 'training_data')
    train_directory_phone = os.path.join(training_root, str(phone))
    train_directory_dslr = os.path.join(training_root, 'canon')

    NUM_TRAINING_IMAGES = len([name for name in os.listdir(train_directory_phone)
                               if os.path.isfile(os.path.join(train_directory_phone, name))])

    # if TRAIN_SIZE == -1 then load all images
    if TRAIN_SIZE == -1:
        TRAIN_SIZE = NUM_TRAINING_IMAGES
        TRAIN_IMAGES = np.arange(0, TRAIN_SIZE)
    else:
        TRAIN_IMAGES = np.random.choice(np.arange(0, NUM_TRAINING_IMAGES), TRAIN_SIZE, replace=False)

    def _read_flat(directory, index):
        # Read "<index>.jpg", flatten it to one row and scale it by 1/255.
        # The float16 cast is kept from the original implementation.
        pixels = np.asarray(misc.imread(os.path.join(directory, str(index) + '.jpg')))
        return np.float16(np.reshape(pixels, [1, IMAGE_SIZE])) / 255

    train_data = np.zeros((TRAIN_SIZE, IMAGE_SIZE))
    train_answ = np.zeros((TRAIN_SIZE, IMAGE_SIZE))

    for row, img in enumerate(TRAIN_IMAGES):
        train_data[row, :] = _read_flat(train_directory_phone, img)
        train_answ[row, :] = _read_flat(train_directory_dslr, img)

        # Report progress after every 100 loaded pairs.
        loaded = row + 1
        if loaded % 100 == 0:
            print(str(round(loaded * 100 / TRAIN_SIZE)) + "% done", end="\r")

    return train_data, train_answ
136 changes: 136 additions & 0 deletions models.py
@@ -0,0 +1,136 @@
import tensorflow as tf

def resnet(input_image):
    """Build the image-enhancement generator network.

    Layout, all created under the "generator" variable scope:
      * one 9x9 convolution (3 -> 64 channels) with ReLU,
      * four residual blocks, each made of two 3x3 convolutions followed by
        instance normalisation and ReLU, with the block input added back,
      * two plain 3x3 convolutions with ReLU,
      * one 9x9 convolution (64 -> 3 channels) followed by
        ``tanh(.) * 0.58 + 0.5``.

    Args:
        input_image: Input tensor with 3 channels in the last dimension
            (required by the first kernel shape).

    Returns:
        The enhanced image tensor.
    """

    def _conv_with_bias(inputs, kernel_shape, index):
        # Create W<index>/b<index> (kernel first, then bias) and apply them.
        kernel = weight_variable(kernel_shape, name="W%d" % index)
        bias = bias_variable([kernel_shape[3]], name="b%d" % index)
        return conv2d(inputs, kernel) + bias

    with tf.variable_scope("generator"):

        net = tf.nn.relu(_conv_with_bias(input_image, [9, 9, 3, 64], 1))

        # Residual blocks 1-4 use variables W2/W3, W4/W5, W6/W7 and W8/W9.
        # Variables are created in the same order as a hand-unrolled version,
        # so auto-generated variable names are unaffected.
        for first_index in (2, 4, 6, 8):
            block_input = net
            net = tf.nn.relu(_instance_norm(_conv_with_bias(net, [3, 3, 64, 64], first_index)))
            net = tf.nn.relu(_instance_norm(_conv_with_bias(net, [3, 3, 64, 64], first_index + 1))) + block_input

        # Convolutional layers without normalisation
        for index in (10, 11):
            net = tf.nn.relu(_conv_with_bias(net, [3, 3, 64, 64], index))

        # Final projection back to 3 channels
        enhanced = tf.nn.tanh(_conv_with_bias(net, [9, 9, 64, 3], 12)) * 0.58 + 0.5

    return enhanced

def adversarial(image_):
    """Build the discriminator network under the "discriminator" scope.

    Five strided convolution layers are followed by a 1024-unit fully
    connected layer with leaky ReLU and a 2-way softmax output.

    Args:
        image_: Input image tensor. The hard-coded flatten size
            ``128 * 7 * 7`` fixes the spatial size the last convolution must
            produce.

    Returns:
        Tensor of softmax probabilities with 2 columns.
    """
    # (num_filters, filter_size, stride) for the layers after the first one.
    remaining_layers = ((128, 5, 2), (192, 3, 1), (192, 3, 1), (128, 3, 2))

    with tf.variable_scope("discriminator"):

        # Only the first layer skips instance normalisation.
        net = _conv_layer(image_, 48, 11, 4, batch_nn = False)
        for num_filters, filter_size, stride in remaining_layers:
            net = _conv_layer(net, num_filters, filter_size, stride)

        flat_size = 128 * 7 * 7
        flattened = tf.reshape(net, [-1, flat_size])

        fc_weights = tf.Variable(tf.truncated_normal([flat_size, 1024], stddev=0.01))
        fc_bias = tf.Variable(tf.constant(0.01, shape=[1024]))
        hidden = leaky_relu(tf.matmul(flattened, fc_weights) + fc_bias)

        out_weights = tf.Variable(tf.truncated_normal([1024, 2], stddev=0.01))
        out_bias = tf.Variable(tf.constant(0.01, shape=[2]))
        logits = tf.matmul(hidden, out_weights) + out_bias

        adv_out = tf.nn.softmax(logits)

    return adv_out

def weight_variable(shape, name):
    """Create a named variable initialised from a truncated normal (stddev 0.01).

    Args:
        shape: Shape of the variable.
        name: Name given to the created ``tf.Variable``.

    Returns:
        The new ``tf.Variable``.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.01), name=name)

def bias_variable(shape, name):
    """Create a named variable filled with the constant 0.01.

    Args:
        shape: Shape of the variable.
        name: Name given to the created ``tf.Variable``.

    Returns:
        The new ``tf.Variable``.
    """
    return tf.Variable(tf.constant(0.01, shape=shape), name=name)

def conv2d(x, W):
    """Convolve ``x`` with kernel ``W`` using stride 1 and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_strides, padding='SAME')

def leaky_relu(x, alpha = 0.2):
    """Return the element-wise maximum of ``alpha * x`` and ``x``."""
    scaled = alpha * x
    return tf.maximum(scaled, x)

def _conv_layer(net, num_filters, filter_size, strides, batch_nn=True):
    """Apply a strided convolution, bias, leaky ReLU and optional instance norm.

    Args:
        net: Input tensor.
        num_filters: Number of output channels.
        filter_size: Side length of the square kernel.
        strides: Stride used along both spatial dimensions.
        batch_nn: When true, instance normalisation is applied after the
            activation.

    Returns:
        The output tensor of the layer.
    """
    # The kernel variable is created before the bias variable; keep that order.
    kernel = _conv_init_vars(net, num_filters, filter_size)
    bias = tf.Variable(tf.constant(0.01, shape=[num_filters]))

    convolved = tf.nn.conv2d(net, kernel, [1, strides, strides, 1], padding='SAME')
    activated = leaky_relu(convolved + bias)

    if not batch_nn:
        return activated
    return _instance_norm(activated)

def _instance_norm(net):
    """Normalise ``net`` over axes 1 and 2, then apply a learned scale and shift.

    Mean and variance are computed over axes 1 and 2 of the 4-D input. The
    scale (initialised to ones) and shift (initialised to zeros) each hold
    one value per entry of the last dimension.

    Args:
        net: 4-D input tensor.

    Returns:
        The normalised, scaled and shifted tensor.
    """
    # Unpacking four values also enforces that the input is 4-D.
    _, _, _, channels = [dim.value for dim in net.get_shape()]

    mu, sigma_sq = tf.nn.moments(net, [1,2], keep_dims=True)

    # Shift is created before scale; keep that order.
    shift = tf.Variable(tf.zeros([channels]))
    scale = tf.Variable(tf.ones([channels]))

    # Small constant keeps the denominator away from zero.
    epsilon = 1e-3
    std = (sigma_sq + epsilon)**(.5)
    normalized = (net - mu) / std

    return scale * normalized + shift

def _conv_init_vars(net, out_channels, filter_size, transpose=False):
    """Create the kernel variable for a square convolution over ``net``.

    Args:
        net: 4-D input tensor; its last dimension gives the input channels.
        out_channels: Number of output channels.
        filter_size: Side length of the square kernel.
        transpose: When true, the two channel dimensions of the kernel shape
            are swapped.

    Returns:
        A float32 ``tf.Variable`` initialised from a truncated normal
        (stddev 0.01, seed 1).
    """
    _, _, _, in_channels = [dim.value for dim in net.get_shape()]

    channel_dims = [out_channels, in_channels] if transpose else [in_channels, out_channels]
    weights_shape = [filter_size, filter_size] + channel_dims

    initial = tf.truncated_normal(weights_shape, stddev=0.01, seed=1)
    return tf.Variable(initial, dtype=tf.float32)
Empty file added models/.gitkeep
Empty file.
Binary file added models_orig/blackberry_orig.data-00000-of-00001
Binary file not shown.
Binary file added models_orig/blackberry_orig.index
Binary file not shown.
Binary file added models_orig/iphone_orig.data-00000-of-00001
Binary file not shown.
Binary file added models_orig/iphone_orig.index
Binary file not shown.
Binary file added models_orig/sony_orig.data-00000-of-00001
Binary file not shown.
Binary file added models_orig/sony_orig.index
Binary file not shown.
Empty file added results/.gitkeep
Empty file.
86 changes: 86 additions & 0 deletions ssim.py
@@ -0,0 +1,86 @@
import numpy as np
from scipy import signal
from scipy.ndimage.filters import convolve
import tensorflow as tf


def _FSpecialGauss(size, sigma):

radius = size // 2
offset = 0.0
start, stop = -radius, radius + 1

if size % 2 == 0:
offset = 0.5
stop -= 1

x, y = np.mgrid[offset + start:stop, offset + start:stop]
g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2)))

return g / g.sum()


def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):

img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
_, height, width, _ = img1.shape

size = min(filter_size, height, width)
sigma = size * filter_sigma / filter_size if filter_size else 0

if filter_size:

window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
mu1 = signal.fftconvolve(img1, window, mode='valid')
mu2 = signal.fftconvolve(img2, window, mode='valid')
sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')

else:

mu1, mu2 = img1, img2
sigma11 = img1 * img1
sigma22 = img2 * img2
sigma12 = img1 * img2

mu11 = mu1 * mu1
mu22 = mu2 * mu2
mu12 = mu1 * mu2
sigma11 -= mu11
sigma22 -= mu22
sigma12 -= mu12

c1 = (k1 * max_val) ** 2
c2 = (k2 * max_val) ** 2
v1 = 2.0 * sigma12 + c2
v2 = sigma11 + sigma22 + c2

ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
cs = np.mean(v1 / v2)

return ssim, cs


def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, weights=None):
    """Return the multi-scale SSIM score between two images.

    SSIM and its contrast-structure term are evaluated at several scales.
    Between scales both images are smoothed with a 2x2 box filter and
    subsampled by a factor of two along axes 1 and 2. The result is the
    product of the contrast-structure terms of all but the last scale and
    the SSIM of the last scale, each raised to its weight.

    Args:
        img1: 4-D array-like; axes 1 and 2 are height and width.
        img2: Array-like with the same shape as ``img1``.
        max_val: Dynamic range of the images.
        filter_size: Side length of the Gaussian window, or 0 to disable
            filtering.
        filter_sigma: Standard deviation of the Gaussian window.
        k1: Constant of the luminance term.
        k2: Constant of the contrast-structure term.
        weights: One weight per scale (list, tuple or NumPy array). ``None``
            or an empty sequence selects the default five-scale weights.

    Returns:
        The MS-SSIM score as a float.

    Raises:
        ValueError: If the inputs differ in shape or are not 4-D.
    """
    # Test for None/empty explicitly: the truth value of a multi-element
    # NumPy array is ambiguous, so `weights if weights else ...` would raise
    # for array inputs.
    if weights is None or np.size(weights) == 0:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights = np.asarray(weights, dtype=np.float64)
    levels = weights.size

    im1, im2 = [np.asarray(x).astype(np.float64) for x in (img1, img2)]

    if im1.shape != im2.shape:
        raise ValueError("Input images must have the same shape (%s vs. %s)."
                         % (im1.shape, im2.shape))
    if im1.ndim != 4:
        raise ValueError("Input images must have four dimensions, not %d." % im1.ndim)

    downsample_filter = np.ones((1, 2, 2, 1)) / 4.0

    mssim = []
    mcs = []

    for level in range(levels):

        ssim, cs = _SSIMForMultiScale(im1, im2, max_val=max_val, filter_size=filter_size,
                                      filter_sigma=filter_sigma, k1=k1, k2=k2)
        mssim.append(ssim)
        mcs.append(cs)

        # Nothing consumes a further downsampled pair after the last scale.
        if level < levels - 1:
            filtered = [convolve(im, downsample_filter, mode='reflect') for im in (im1, im2)]
            im1, im2 = [x[:, ::2, ::2, :] for x in filtered]

    mssim = np.array(mssim)
    mcs = np.array(mcs)

    return np.prod(mcs[0:levels-1] ** weights[0:levels-1]) * (mssim[levels-1] ** weights[levels-1])
104 changes: 104 additions & 0 deletions test_model.py
@@ -0,0 +1,104 @@
# python test_model.py model=iphone_orig dped_dir=dped/ test_subset=full iteration=all resolution=orig use_gpu=true

from scipy import misc
import numpy as np
import tensorflow as tf
from models import resnet
import utils
import os
import sys

# process command arguments
phone, dped_dir, test_subset, iteration, resolution, use_gpu = utils.process_test_model_args(sys.argv)

# get all available image resolutions
res_sizes = utils.get_resolutions()

# get the specified image resolution
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE = utils.get_specified_res(res_sizes, phone, resolution)

# disable gpu if specified
config = tf.ConfigProto(device_count={'GPU': 0}) if use_gpu == "false" else None

# create placeholders for input images
x_ = tf.placeholder(tf.float32, [None, IMAGE_SIZE])
x_image = tf.reshape(x_, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])

# generate enhanced image
enhanced = resnet(x_image)

# full-size test photos of the selected phone ("_orig" only marks the
# pre-trained models and is not part of the dataset directory name)
test_dir = dped_dir + phone.replace("_orig", "") + "/test_data/full_size_test_images/"

# directory that receives the resulting .png images
output_dir = "visual_results/"


def _enhance_and_save(sess, photo, output_prefix):
    """Enhance one test photo and save the result plus a before/after image.

    Writes ``<output_prefix>_enhanced.png`` and
    ``<output_prefix>_before_after.png``.
    """
    # load the test image, resize it and scale it by 1/255
    image = np.float16(misc.imresize(misc.imread(test_dir + photo), res_sizes[phone])) / 255

    # presumably crops the image to the requested resolution if necessary -
    # see utils.extract_crop
    image_crop = utils.extract_crop(image, resolution, phone, res_sizes)
    image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])

    # get enhanced image
    enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
    enhanced_image = np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])

    # input crop and enhanced result side by side
    before_after = np.hstack((image_crop, enhanced_image))

    # save the results as .png images
    misc.imsave(output_prefix + "_enhanced.png", enhanced_image)
    misc.imsave(output_prefix + "_before_after.png", before_after)


with tf.Session(config=config) as sess:

    # make sure the output directory exists before any image is written
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    test_photos = [f for f in os.listdir(test_dir) if os.path.isfile(test_dir + f)]

    if test_subset == "small":
        # use five first images only
        test_photos = test_photos[0:5]

    # A single Saver serves every restore below. Creating a new one per
    # checkpoint would add another set of save/restore ops to the graph
    # each time.
    saver = tf.train.Saver()

    if phone.endswith("_orig"):

        # load pre-trained model
        saver.restore(sess, "models_orig/" + phone)

        for photo in test_photos:

            print("Testing original " + phone.replace("_orig", "") + " model, processing image " + photo)

            photo_name = photo.rsplit(".", 1)[0]
            _enhance_and_save(sess, photo, output_dir + phone + "_" + photo_name)

    else:

        # the file count is halved - presumably each saved checkpoint
        # consists of two files; verify against the training script
        num_saved_models = len([f for f in os.listdir("models/") if f.startswith(phone + "_iteration")]) // 2

        if iteration == "all":
            # NOTE(review): yields 1000, 2000, ..., (num_saved_models - 1) * 1000;
            # confirm this matches the iterations the training script saves
            iteration = np.arange(1, num_saved_models) * 1000
        else:
            iteration = [int(iteration)]

        for i in iteration:

            # load the model saved at this iteration
            saver.restore(sess, "models/" + phone + "_iteration_" + str(i) + ".ckpt")

            for photo in test_photos:

                print("iteration " + str(i) + ", processing image " + photo)

                photo_name = photo.rsplit(".", 1)[0]
                _enhance_and_save(sess, photo, output_dir + phone + "_" + photo_name + "_iteration_" + str(i))

0 comments on commit 60cbc28

Please sign in to comment.