# SRGAN

In [1]:
import os
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
import numpy as np
from data import DIV2K
from model.srgan import generator, discriminator
from model.srgan1 import SRGAN, Discriminator
from train import SrganTrainer, SrganGeneratorTrainer

%matplotlib inline

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. Iterating over the device list (rather than indexing [0]) avoids an
# IndexError on machines without a GPU and also covers multi-GPU machines.
# set_memory_growth returns None, so its result is not stored.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

# Fix TensorFlow's global seed so TF random ops (e.g. weight initialisation)
# are repeatable across runs.
seed_value = 1234
tf.random.set_seed(seed_value)

In [2]:
# Directory holding model weights (needed for the demo); created if missing.
weights_dir = 'weights/srgan'
os.makedirs(weights_dir, exist_ok=True)


def weights_file(filename):
    """Return the path of `filename` inside `weights_dir`."""
    return os.path.join(weights_dir, filename)

## Datasets

You don't need to download the DIV2K dataset manually: the required parts are downloaded automatically by the `DIV2K` class. By default, DIV2K images are stored in the `.div2k` folder in the project's root directory.

In [3]:
# Bicubic-downgraded x4 DIV2K; the two loaders differ only in `subset`
# (built in train, valid order).
div2k_train, div2k_valid = (
    DIV2K(scale=4, subset=subset, downgrade='bicubic')
    for subset in ('train', 'valid')
)

In [4]:
# Both pipelines share batch size and augmentation settings; the validation
# pipeline is additionally limited to a single pass (repeat_count=1).
# NOTE(review): random_transform=True on the validation set presumably applies
# random augmentation, so each evaluation would see different inputs and PSNR
# would not be comparable between evaluations. Consider random_transform=False
# (which may require batch_size=1 if it disables cropping; confirm in data.DIV2K).
_dataset_kwargs = dict(batch_size=16, random_transform=True)
train_ds = div2k_train.dataset(**_dataset_kwargs)
valid_ds = div2k_valid.dataset(repeat_count=1, **_dataset_kwargs)

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


## Training

### Pre-trained models

If you want to skip training and directly run the demo below, download [weights-srgan.tar.gz](https://martin-krasser.de/sisr/weights-srgan.tar.gz) and extract the archive in the project's root directory. This will create a folder `weights/srgan` containing the weights of the pre-trained models.

### Generator pre-training

In [5]:
generator = SRGAN()

In [6]:
pre_trainer = SrganGeneratorTrainer(model=generator, checkpoint_dir=f'.ckpt/pre_generator')
pre_trainer.train(train_ds,
                  valid_ds.take(1000),
                  steps=200000, 
                  evaluate_every=1000, 
                  save_best_only=False)

pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

1000/200000: loss = 716.943, PSNR = 23.302647 (90.83s)


ResourceExhaustedError:  OOM when allocating tensor with shape[16,256,48,48] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node srgan/upsample_1/conv2d_38/Conv2D (defined at /home/aryan/Source/AI/vision/super-resolution/model/srgan1.py:36) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_step_2146]

Errors may have originated from an input operation.
Input Source operations connected to node srgan/upsample_1/conv2d_38/Conv2D:
 srgan/upsample/p_re_lu_18/add (defined at /home/aryan/Source/AI/vision/super-resolution/model/srgan1.py:38)

Function call stack:
train_step


In [8]:
# Reload the latest pre-training checkpoint from the trainer's checkpoint
# directory (the output below shows it resumed at step 200000).
# NOTE(review): this does not rewrite weights/srgan/pre_generator.h5. If the
# training cell above stopped before its save_weights call (e.g. the OOM shown
# there), the GAN stage below loads whatever pre_generator.h5 already exists.
pre_trainer.restore()

Model restored from checkpoint at step 200000.


### Generator fine-tuning (GAN)

In [18]:
# Fine-tune a fresh model.srgan generator, initialised from the pre-trained
# weights, adversarially against a freshly built discriminator.
gan_generator = generator()
gan_generator.load_weights(weights_file('pre_generator.h5'))
gan_discriminator = discriminator()

gan_trainer = SrganTrainer(generator=gan_generator,
                           discriminator=gan_discriminator)
gan_trainer.train(train_ds, steps=50000)

50/50000, perceptual loss = 0.0865, discriminator loss = 1.8275
100/50000, perceptual loss = 0.0807, discriminator loss = 1.1918
150/50000, perceptual loss = 0.0820, discriminator loss = 1.2629
200/50000, perceptual loss = 0.0803, discriminator loss = 0.4169
250/50000, perceptual loss = 0.0905, discriminator loss = 0.6568
300/50000, perceptual loss = 0.0769, discriminator loss = 1.0456
350/50000, perceptual loss = 0.0809, discriminator loss = 0.9208
400/50000, perceptual loss = 0.0872, discriminator loss = 0.3032
450/50000, perceptual loss = 0.0843, discriminator loss = 0.8855
500/50000, perceptual loss = 0.0848, discriminator loss = 0.5697
550/50000, perceptual loss = 0.0830, discriminator loss = 0.7960
600/50000, perceptual loss = 0.0833, discriminator loss = 0.5764
650/50000, perceptual loss = 0.0860, discriminator loss = 0.8572
700/50000, perceptual loss = 0.0929, discriminator loss = 0.5304
750/50000, perceptual loss = 0.0917, discriminator loss = 0.3951
800/50000, perceptual loss

6300/50000, perceptual loss = 0.0884, discriminator loss = 0.6178
6350/50000, perceptual loss = 0.0848, discriminator loss = 0.4522
6400/50000, perceptual loss = 0.0859, discriminator loss = 0.5688
6450/50000, perceptual loss = 0.0850, discriminator loss = 0.3792
6500/50000, perceptual loss = 0.0841, discriminator loss = 0.3680
6550/50000, perceptual loss = 0.0787, discriminator loss = 0.2644
6600/50000, perceptual loss = 0.0845, discriminator loss = 0.7636
6650/50000, perceptual loss = 0.0837, discriminator loss = 0.5472
6700/50000, perceptual loss = 0.0868, discriminator loss = 0.8105
6750/50000, perceptual loss = 0.0836, discriminator loss = 0.4608
6800/50000, perceptual loss = 0.0818, discriminator loss = 0.3820
6850/50000, perceptual loss = 0.0857, discriminator loss = 0.8046
6900/50000, perceptual loss = 0.0863, discriminator loss = 0.2444
6950/50000, perceptual loss = 0.0848, discriminator loss = 0.6393
7000/50000, perceptual loss = 0.0798, discriminator loss = 0.6420
7050/50000

12500/50000, perceptual loss = 0.0867, discriminator loss = 0.4057
12550/50000, perceptual loss = 0.0866, discriminator loss = 0.4955
12600/50000, perceptual loss = 0.0841, discriminator loss = 0.3872
12650/50000, perceptual loss = 0.0842, discriminator loss = 0.4327
12700/50000, perceptual loss = 0.0820, discriminator loss = 0.2488
12750/50000, perceptual loss = 0.0822, discriminator loss = 0.3350
12800/50000, perceptual loss = 0.0902, discriminator loss = 0.3519
12850/50000, perceptual loss = 0.0877, discriminator loss = 0.2982
12900/50000, perceptual loss = 0.0817, discriminator loss = 0.3425
12950/50000, perceptual loss = 0.0792, discriminator loss = 0.5284
13000/50000, perceptual loss = 0.0789, discriminator loss = 0.2688
13050/50000, perceptual loss = 0.0866, discriminator loss = 0.5037
13100/50000, perceptual loss = 0.0845, discriminator loss = 0.5694
13150/50000, perceptual loss = 0.0850, discriminator loss = 0.2497
13200/50000, perceptual loss = 0.0798, discriminator loss = 0.

18650/50000, perceptual loss = 0.0848, discriminator loss = 0.4496
18700/50000, perceptual loss = 0.0816, discriminator loss = 0.3503
18750/50000, perceptual loss = 0.0795, discriminator loss = 0.2721
18800/50000, perceptual loss = 0.0787, discriminator loss = 0.4634
18850/50000, perceptual loss = 0.0847, discriminator loss = 0.3734
18900/50000, perceptual loss = 0.0840, discriminator loss = 0.3859
18950/50000, perceptual loss = 0.0818, discriminator loss = 0.3211
19000/50000, perceptual loss = 0.0828, discriminator loss = 0.5334
19050/50000, perceptual loss = 0.0788, discriminator loss = 0.2273
19100/50000, perceptual loss = 0.0872, discriminator loss = 0.6527
19150/50000, perceptual loss = 0.0769, discriminator loss = 0.2446
19200/50000, perceptual loss = 0.0839, discriminator loss = 0.4675
19250/50000, perceptual loss = 0.0800, discriminator loss = 0.3118
19300/50000, perceptual loss = 0.0843, discriminator loss = 0.4749
19350/50000, perceptual loss = 0.0831, discriminator loss = 0.

24800/50000, perceptual loss = 0.0805, discriminator loss = 0.4170
24850/50000, perceptual loss = 0.0843, discriminator loss = 0.2260
24900/50000, perceptual loss = 0.0814, discriminator loss = 0.4599
24950/50000, perceptual loss = 0.0748, discriminator loss = 0.3773
25000/50000, perceptual loss = 0.0819, discriminator loss = 0.2417
25050/50000, perceptual loss = 0.0799, discriminator loss = 0.2420
25100/50000, perceptual loss = 0.0790, discriminator loss = 0.2235
25150/50000, perceptual loss = 0.0867, discriminator loss = 0.5850
25200/50000, perceptual loss = 0.0818, discriminator loss = 0.4213
25250/50000, perceptual loss = 0.0823, discriminator loss = 0.3130
25300/50000, perceptual loss = 0.0845, discriminator loss = 0.2623
25350/50000, perceptual loss = 0.0790, discriminator loss = 0.4147
25400/50000, perceptual loss = 0.0842, discriminator loss = 0.4970
25450/50000, perceptual loss = 0.0889, discriminator loss = 0.2325
25500/50000, perceptual loss = 0.0797, discriminator loss = 0.

30950/50000, perceptual loss = 0.0849, discriminator loss = 0.4530
31000/50000, perceptual loss = 0.0819, discriminator loss = 0.2992
31050/50000, perceptual loss = 0.0822, discriminator loss = 0.2305
31100/50000, perceptual loss = 0.0762, discriminator loss = 0.2601
31150/50000, perceptual loss = 0.0810, discriminator loss = 0.2419
31200/50000, perceptual loss = 0.0775, discriminator loss = 0.2453
31250/50000, perceptual loss = 0.0824, discriminator loss = 0.2623
31300/50000, perceptual loss = 0.0803, discriminator loss = 0.3935
31350/50000, perceptual loss = 0.0793, discriminator loss = 0.3473
31400/50000, perceptual loss = 0.0840, discriminator loss = 0.3245
31450/50000, perceptual loss = 0.0811, discriminator loss = 0.2839
31500/50000, perceptual loss = 0.0792, discriminator loss = 0.3244
31550/50000, perceptual loss = 0.0835, discriminator loss = 0.2165
31600/50000, perceptual loss = 0.0762, discriminator loss = 0.2634
31650/50000, perceptual loss = 0.0811, discriminator loss = 0.

37100/50000, perceptual loss = 0.0836, discriminator loss = 0.1655
37150/50000, perceptual loss = 0.0851, discriminator loss = 0.3524
37200/50000, perceptual loss = 0.0827, discriminator loss = 0.3330
37250/50000, perceptual loss = 0.0872, discriminator loss = 0.0932
37300/50000, perceptual loss = 0.0823, discriminator loss = 0.0888
37350/50000, perceptual loss = 0.0782, discriminator loss = 0.5133
37400/50000, perceptual loss = 0.0799, discriminator loss = 0.1517
37450/50000, perceptual loss = 0.0725, discriminator loss = 0.2423
37500/50000, perceptual loss = 0.0815, discriminator loss = 0.3906
37550/50000, perceptual loss = 0.0774, discriminator loss = 0.2371
37600/50000, perceptual loss = 0.0810, discriminator loss = 0.3706
37650/50000, perceptual loss = 0.0787, discriminator loss = 0.3114
37700/50000, perceptual loss = 0.0850, discriminator loss = 0.3646
37750/50000, perceptual loss = 0.0815, discriminator loss = 0.1679
37800/50000, perceptual loss = 0.0801, discriminator loss = 0.

43250/50000, perceptual loss = 0.0804, discriminator loss = 0.5974
43300/50000, perceptual loss = 0.0748, discriminator loss = 0.1358
43350/50000, perceptual loss = 0.0810, discriminator loss = 0.2360
43400/50000, perceptual loss = 0.0858, discriminator loss = 0.2212
43450/50000, perceptual loss = 0.0823, discriminator loss = 0.4945
43500/50000, perceptual loss = 0.0773, discriminator loss = 0.1337
43550/50000, perceptual loss = 0.0812, discriminator loss = 0.3839
43600/50000, perceptual loss = 0.0770, discriminator loss = 0.1204
43650/50000, perceptual loss = 0.0826, discriminator loss = 0.2110
43700/50000, perceptual loss = 0.0798, discriminator loss = 0.2285
43750/50000, perceptual loss = 0.0805, discriminator loss = 0.1879
43800/50000, perceptual loss = 0.0793, discriminator loss = 0.3308
43850/50000, perceptual loss = 0.0792, discriminator loss = 0.0845
43900/50000, perceptual loss = 0.0868, discriminator loss = 0.1022
43950/50000, perceptual loss = 0.0820, discriminator loss = 0.

49400/50000, perceptual loss = 0.0791, discriminator loss = 0.4128
49450/50000, perceptual loss = 0.0755, discriminator loss = 0.2852
49500/50000, perceptual loss = 0.0797, discriminator loss = 0.0586
49550/50000, perceptual loss = 0.0796, discriminator loss = 0.1406
49600/50000, perceptual loss = 0.0797, discriminator loss = 0.0517
49650/50000, perceptual loss = 0.0782, discriminator loss = 0.1698
49700/50000, perceptual loss = 0.0823, discriminator loss = 0.1579
49750/50000, perceptual loss = 0.0842, discriminator loss = 0.3854
49800/50000, perceptual loss = 0.0811, discriminator loss = 0.2287
49850/50000, perceptual loss = 0.0791, discriminator loss = 0.1196
49900/50000, perceptual loss = 0.0897, discriminator loss = 0.1590
49950/50000, perceptual loss = 0.0782, discriminator loss = 0.3169
50000/50000, perceptual loss = 0.0830, discriminator loss = 0.1306


In [19]:
# Persist both GAN-trained networks (generator first, then discriminator).
for net, weights_name in ((gan_trainer.generator, 'gan_generator.h5'),
                          (gan_trainer.discriminator, 'gan_discriminator.h5')):
    net.save_weights(weights_file(weights_name))

## Test: pre-train the `model.srgan` generator (SRResNet)

In [None]:
from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
# Re-importing `generator` above rebinds it to the model.srgan factory in case
# an earlier cell shadowed it with a model instance.
# Use a checkpoint directory separate from the model.srgan1 SRGAN pre-training
# run, which writes '.ckpt/pre_generator'. The trainer restores from its
# checkpoint directory (see pre_trainer.restore() above), and that run's
# checkpoints would not match generator()'s variables.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator_srresnet')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too).
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
# weights_file() keeps the path consistent with the rest of the notebook.
pre_trainer.model.save_weights(weights_file('pre_generator.h5'))

## Demo

In [9]:
# Build two fresh generators and load the pre-trained and GAN-fine-tuned weights.
# NOTE(review): this expects `generator` to be the model.srgan factory. If the
# `generator = SRGAN()` cell was the last to bind that name, generator() raises
# the ValueError shown in the output below.
pre_generator, gan_generator = generator(), generator()

for net, weights_name in ((pre_generator, 'pre_generator.h5'),
                          (gan_generator, 'gan_generator.h5')):
    net.load_weights(weights_file(weights_name))

ValueError: The first argument to `Layer.call` must always be passed.

In [10]:
from model import resolve_single
from utils import load_image

# Super-resolve a demo image with the GAN-fine-tuned generator loaded above.
# The bare name `generator` is the model.srgan factory function (or, if the
# `generator = SRGAN()` cell bound it last, a different model). It is not the
# network with the loaded weights, so `gan_generator` is used explicitly.
lr = load_image('demo/0869x4-crop.png')
lr_batch = tf.expand_dims(lr, axis=0)  # add a batch axis of size 1
lr_batch = tf.cast(lr_batch, tf.float32)
sr_batch = gan_generator(lr_batch)

# Map the raw generator output back to 8-bit pixel values.
sr_batch = tf.clip_by_value(sr_batch, 0, 255)
sr_batch = tf.round(sr_batch)
sr_batch = tf.cast(sr_batch, tf.uint8)

In [11]:
# Drop the batch axis, then reverse the channel axis for OpenCV, which expects
# BGR (the model output is presumably RGB, as produced by load_image).
img = np.squeeze(np.array(sr_batch))[..., ::-1]

In [14]:
# Show the result in an OpenCV window; waitKey(0) blocks until a key is pressed.
display_size = (992, 944)  # cv2.resize dsize is (width, height)
cv2.imshow('hello', cv2.resize(img, display_size))
cv2.waitKey(0)
# Always close the window. Previously only ESC (27) closed it, so any other key
# left an orphaned window with no event loop servicing it.
cv2.destroyAllWindows()

In [28]:
# cv2.imwrite can report failure (bad path, unwritable directory, ...) by
# returning False rather than raising, so check the flag explicitly.
saved = cv2.imwrite('cat3.jpg', img)
if not saved:
    raise IOError('cv2.imwrite failed to write cat3.jpg')
saved

True

In [None]:
from model import resolve_single
from utils import load_image

def resolve_and_plot(lr_image_path):
    """Super-resolve one image with both generators and plot the results.

    Loads the low-resolution image at `lr_image_path` and resolves it with the
    notebook-level `pre_generator` and `gan_generator` via `resolve_single`.
    Draws LR, SR (PRE) and SR (GAN) into panels 1, 3 and 4 of a 2x2 grid on a
    20x20-inch figure; panel 2 is intentionally left empty.
    """
    lr = load_image(lr_image_path)

    # (subplot position, title, image) per panel; both models run before plotting.
    panels = [
        (1, 'LR', lr),
        (3, 'SR (PRE)', resolve_single(pre_generator, lr)),
        (4, 'SR (GAN)', resolve_single(gan_generator, lr)),
    ]

    plt.figure(figsize=(20, 20))

    for position, title, image in panels:
        plt.subplot(2, 2, position)
        plt.imshow(image)
        plt.title(title)
        # Hide tick marks; pixel coordinates add nothing to the comparison.
        plt.xticks([])
        plt.yticks([])

In [None]:
# Compare LR input, pre-trained SR, and GAN SR on the bundled demo crop.
resolve_and_plot(lr_image_path='demo/0869x4-crop.png')