## SRGAN (Super-Resolution Generative Adversarial Network)

A TensorFlow implementation of the paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" (https://arxiv.org/abs/1609.04802). This implementation differs from the original paper in several ways:

1. The MNIST data set is used for convenience.
2. The MSE loss is replaced with a GAN loss, and the discriminator receives image pairs (tuples) as input.
3. A sub-pixel CNN is used instead of deconvolution (see http://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Shi_Real-Time_Single_Image_CVPR_2016_paper.pdf).

Existing CNN-based super-resolution methods mainly use an MSE loss, which makes the super-resolved images look blurry. Replacing the MSE loss with gradients from a GAN may prevent this blurriness, and this is the key idea of the paper.
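
The blurring effect of an MSE loss can be illustrated with a small NumPy sketch. The toy data here is hypothetical: two equally plausible sharp edges for the same low-resolution input. Under that assumption, the prediction that minimizes the expected MSE is the per-pixel average of the candidates, which is a soft edge rather than either sharp one.

```python
import numpy as np

# Two equally plausible sharp high-resolution rows for the same blurry input:
# a step edge located at index 3 or at index 4 (hypothetical toy data).
edge_at_3 = np.array([0, 0, 0, 1, 1, 1, 1, 1], dtype=float)
edge_at_4 = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=float)
targets = np.stack([edge_at_3, edge_at_4])

def expected_mse(pred, targets):
    """Mean squared error of one prediction averaged over all plausible targets."""
    return float(np.mean((targets - pred) ** 2))

# The prediction that minimizes expected MSE is the per-pixel mean of the targets.
mse_optimal = targets.mean(axis=0)

print(mse_optimal)                         # index 3 holds 0.5: a soft, blurred edge
print(expected_mse(mse_optimal, targets))  # lower than for either sharp candidate
print(expected_mse(edge_at_3, targets))
```

A discriminator, by contrast, can penalize the averaged output for not looking like any real sample, which is the motivation for the adversarial loss used below.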

#### Credits

Credit for this code goes to https://github.com/buriburisuri/.

In [1]:
import sugartensor as tf
import numpy as np

In [2]:
# set log level to debug
tf.sg_verbosity(10)

In [3]:
# hyper parameters
batch_size = 32

In [4]:
# MNIST input tensor (with QueueRunner)
data = tf.sg_data.Mnist(batch_size=batch_size)

# input images
x = data.train.image

# low resolution image
x_small = tf.image.resize_bicubic(x, (14, 14))
x_nearest = tf.image.resize_images(x_small, (28, 28), tf.image.ResizeMethod.NEAREST_NEIGHBOR)

# generator labels (all ones)
y = tf.ones(batch_size, dtype=tf.sg_floatx)

# discriminator labels (half 1s, half 0s)
y_disc = tf.concat([y, y * 0], 0, name='concat')

Extracting ./asset/data/mnist/train-images-idx3-ubyte.gz
Extracting ./asset/data/mnist/train-labels-idx1-ubyte.gz
Extracting ./asset/data/mnist/t10k-images-idx3-ubyte.gz
Extracting ./asset/data/mnist/t10k-labels-idx1-ubyte.gz
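
The shapes and labels built in the cell above can be checked with a plain NumPy sketch. It uses a small hypothetical batch and synthetic pixel values, and it only mirrors the nearest-neighbor upsampling step (for an exact 2x factor, repeating each pixel is assumed to match nearest-neighbor resizing); the bicubic downsampling is not reproduced here.

```python
import numpy as np

batch_size = 4  # small hypothetical batch; the notebook uses 32

# Stand-in for a batch of 14x14 low-resolution images (NHWC layout).
x_small = np.arange(batch_size * 14 * 14, dtype=np.float32).reshape(batch_size, 14, 14, 1)

# Nearest-neighbor 2x upsampling: repeat every pixel along height and width.
x_nearest = x_small.repeat(2, axis=1).repeat(2, axis=2)

# Generator targets: every generated sample should be judged "real" (1).
y = np.ones(batch_size, dtype=np.float32)

# Discriminator targets: first half real (1), second half fake (0).
y_disc = np.concatenate([y, y * 0], axis=0)

print(x_nearest.shape)  # (4, 28, 28, 1)
print(y_disc)           # [1. 1. 1. 1. 0. 0. 0. 0.]
```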


In [5]:
# generator network
with tf.sg_context(name='generator', act='relu', bn=True):
    gen = (x_small
           .sg_conv(dim=32)
           .sg_conv()
           .sg_conv(dim=4, act='sigmoid', bn=False)
           .sg_periodic_shuffle(factor=2))
    
# add image summary
tf.sg_summary_image(gen)
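
The last generator layer emits 4 channels at 14x14, and the periodic shuffle rearranges them into a single 28x28 channel. The NumPy function below is a minimal sketch of that sub-pixel rearrangement for one output channel. The name `periodic_shuffle` and the channel ordering are assumptions made for illustration; the ordering used inside `sg_periodic_shuffle` may differ.

```python
import numpy as np

def periodic_shuffle(x, factor=2):
    """Rearrange (N, H, W, factor*factor) into (N, H*factor, W*factor, 1)."""
    n, h, w, c = x.shape
    assert c == factor * factor, "this sketch handles a single output channel"
    x = x.reshape(n, h, w, factor, factor)  # split channels into (row offset, col offset)
    x = x.transpose(0, 1, 3, 2, 4)          # -> (N, H, row offset, W, col offset)
    return x.reshape(n, h * factor, w * factor, 1)

# One 1x1 "image" with 4 channels becomes one 2x2 image with 1 channel.
tiny = np.array([1, 2, 3, 4], dtype=np.float32).reshape(1, 1, 1, 4)
print(periodic_shuffle(tiny)[0, :, :, 0])

# Shapes matching the generator: 14x14x4 feature map -> 28x28x1 image.
features = np.zeros((32, 14, 14, 4), dtype=np.float32)
print(periodic_shuffle(features).shape)  # (32, 28, 28, 1)
```

Because the upsampling is a pure rearrangement of values, all learned computation happens at the low resolution, which is the efficiency argument made for sub-pixel convolution.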

In [6]:
# input image pairs
x_real_pair = tf.concat([x_nearest, x], 3, name='concat')
x_fake_pair = tf.concat([x_nearest, gen], 3, name='concat')

In [7]:
# create discriminator

# create real + fake image input
xx = tf.concat([x_real_pair, x_fake_pair], 0, name='concat')

with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu'):
    # discriminator part
    disc = (xx.sg_conv(dim=64)
              .sg_conv(dim=128)
              .sg_flatten()
              .sg_dense(dim=1024)
              .sg_dense(dim=1, act='linear')
              .sg_squeeze())
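
To make the tensor layout concrete, here is a NumPy sketch of the pairing and batching steps with random stand-in data. The helper `same_conv_out` is hypothetical, and the flattened size it yields assumes the convolutions use SAME padding, which the notebook does not state explicitly.

```python
import numpy as np

batch_size = 32
rng = np.random.default_rng(0)

# Stand-ins for the tensors in the notebook (NHWC layout, random values).
x = rng.random((batch_size, 28, 28, 1), dtype=np.float32)          # real high-res
gen = rng.random((batch_size, 28, 28, 1), dtype=np.float32)        # generator output
x_nearest = rng.random((batch_size, 28, 28, 1), dtype=np.float32)  # upsampled low-res

# Pair each candidate with its conditioning image along the channel axis.
x_real_pair = np.concatenate([x_nearest, x], axis=3)
x_fake_pair = np.concatenate([x_nearest, gen], axis=3)

# Stack real pairs first, fake pairs second, along the batch axis.
xx = np.concatenate([x_real_pair, x_fake_pair], axis=0)

def same_conv_out(size, stride):
    """Spatial size after a strided convolution, assuming SAME padding."""
    return -(-size // stride)  # ceiling division

h = same_conv_out(same_conv_out(28, 2), 2)
print(xx.shape)     # (64, 28, 28, 2)
print(h * h * 128)  # flattened feature count under the SAME-padding assumption
```

The real-first, fake-second order along the batch axis is what lets the discriminator labels `y_disc` (ones followed by zeros) line up with the rows of `xx`.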

In [8]:
# loss and train ops
loss_disc = tf.reduce_mean(disc.sg_bce(target=y_disc))  # discriminator loss
loss_gen = tf.reduce_mean(disc.sg_reuse(input=x_fake_pair).sg_bce(target=y))  # generator loss

train_disc = tf.sg_optim(loss_disc, lr=0.0001, category='discriminator')  # discriminator train ops
train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator')  # generator train ops
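
The two losses can be written out in NumPy as follows. This is a sketch that assumes `sg_bce` computes sigmoid cross-entropy on the raw logits produced by the final linear layer; the function name `bce_with_logits` and the logit values are made up for illustration.

```python
import numpy as np

def bce_with_logits(logits, targets):
    """Numerically stable sigmoid cross-entropy, one value per sample."""
    return np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

# Hypothetical discriminator logits for 2 real pairs followed by 2 fake pairs.
logits = np.array([2.0, 1.0, -1.0, 0.5])
y_disc = np.array([1.0, 1.0, 0.0, 0.0])

# Discriminator: push real pairs toward 1 and fake pairs toward 0.
loss_disc = bce_with_logits(logits, y_disc).mean()

# Generator: only the fake half matters, and its target is 1 (fool the discriminator).
loss_gen = bce_with_logits(logits[2:], np.ones(2)).mean()

print(loss_disc, loss_gen)
```

The generator loss reuses the discriminator on the fake pairs only, with all-ones targets, so its gradient pushes generated images toward whatever the discriminator currently scores as real.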

In [9]:
# training

# define the alternating training function
@tf.sg_train_func
def alt_train(sess, opt):
    l_disc = sess.run([loss_disc, train_disc])[0]  # training discriminator
    l_gen = sess.run([loss_gen, train_gen])[0]  # training generator
    return np.mean(l_disc) + np.mean(l_gen)

# do training
alt_train(log_interval=10, max_ep=20, ep_size=data.train.num_batch, early_stop=False)

INFO:tensorflow:Restoring parameters from asset/train/model.ckpt-17385


I 0325:18:43:55.548:sg_train.py:327] Training started from epoch[010]-step[17385].


INFO:tensorflow:global_step/sec: 0


train:  14%|███▍                    | 247/1718 [00:09<00:57, 25.76b/s]

INFO:tensorflow:global_step/sec: 8.65219


train:  18%|████▏                   | 303/1718 [00:19<01:31, 15.52b/s]

INFO:tensorflow:global_step/sec: 11.2995


train:  21%|████▉                   | 357/1718 [00:29<01:53, 12.04b/s]

INFO:tensorflow:global_step/sec: 10.6968


train:  24%|█████▋                  | 409/1718 [00:39<02:06, 10.35b/s]

INFO:tensorflow:global_step/sec: 10.5035


train:  27%|██████▍                 | 461/1718 [00:49<02:15,  9.31b/s]

INFO:tensorflow:global_step/sec: 10.3964


train:  30%|███████▏                | 511/1718 [00:59<02:20,  8.59b/s]

INFO:tensorflow:global_step/sec: 10.0028


train:  33%|███████▊                | 563/1718 [01:09<02:22,  8.10b/s]

INFO:tensorflow:global_step/sec: 10.3985


train:  36%|████████▌               | 615/1718 [01:19<02:22,  7.72b/s]

INFO:tensorflow:global_step/sec: 10.302


train:  39%|█████████▎              | 667/1718 [01:29<02:21,  7.44b/s]

INFO:tensorflow:global_step/sec: 10.3997


I 0325:18:48:46.786:sg_train.py:301] 	Epoch[010:gs=20411] - loss = 1.599364
I 0325:18:49:26.361:sg_train.py:301] 	Epoch[011:gs=20821] - loss = 1.594539
I 0325:18:54:25.047:sg_train.py:301] 	Epoch[012:gs=23847] - loss = 1.687448
I 0325:18:55:07.348:sg_train.py:301] 	Epoch[013:gs=24257] - loss = 1.639389
I 0325:19:00:36.803:sg_train.py:301] 	Epoch[014:gs=27283] - loss = 1.640004
I 0325:19:01:22.037:sg_train.py:301] 	Epoch[015:gs=27693] - loss = 1.704872
I 0325:19:06:52.349:sg_train.py:301] 	Epoch[016:gs=30719] - loss = 1.690638
I 0325:19:07:37.229:sg_train.py:301] 	Epoch[017:gs=31129] - loss = 1.807905
I 0325:19:13:07.412:sg_train.py:301] 	Epoch[018:gs=34155] - loss = 1.869398
I 0325:19:13:52.517:sg_train.py:301] 	Epoch[019:gs=34565] - loss = 1.836846
I 0325:19:19:28.311:sg_train.py:301] 	Epoch[020:gs=37591] - loss = 1.920649
I 0325:19:19:30.469:sg_train.py:368] Training finished at epoch[20]-step[37591].
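
The logged value is the sum of the discriminator and generator losses returned by `alt_train`. As a reference point for reading it, the sketch below computes that sum for an undecided discriminator that outputs a logit of 0 (probability 0.5) for every pair. It assumes both losses are sigmoid cross-entropy; the helper `bce_from_logit` is hypothetical.

```python
import math

def bce_from_logit(logit, target):
    """Sigmoid cross-entropy for a single logit and a 0/1 target."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

# An undecided discriminator: logit 0 (probability 0.5) for every pair.
loss_disc = 0.5 * (bce_from_logit(0.0, 1.0) + bce_from_logit(0.0, 0.0))
loss_gen = bce_from_logit(0.0, 1.0)

reference = loss_disc + loss_gen
print(round(reference, 4))  # 1.3863
```

This value is 2 ln 2, so a combined loss near 1.39 would correspond to a discriminator that cannot tell real pairs from generated ones.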
