In [1]:
import os
import sys
sys.path.append('../modules')

import numpy as np
import cv2

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
# config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.8
session = InteractiveSession(config=config)

from srgan_d import SRGAN
from generator import ImageDataGenerator

Using TensorFlow backend.


In [2]:
r = 2
input_dir = f'../data/train/{r}/low/'
tgt_dir = f'../data/train/{r}/high/'

input_dir_val = f'../data/val/{r}/low/'
tgt_dir_val = f'../data/val/{r}/high/'


test_file = 'UBMk30rjy0o_17675_42.jpg'
test_low_path =  input_dir_val + test_file
test_high_path = tgt_dir_val + test_file
test_out_dir = '../data/out/'

model_dir = '../models/'
vgg_name = 'vgg16_notop.hdf5'
vgg_path = model_dir + vgg_name

input_shape = (67, 120, 3)
tgt_shape = (134, 240, 3)

batch_size = 8

In [3]:
srgan = SRGAN(vgg_path, batch_size, input_shape, tgt_shape)
gen = ImageDataGenerator(input_dir, tgt_dir, batch_size)
gen_valid = ImageDataGenerator(input_dir_val, tgt_dir_val, batch_size)

  self.shared_generator = Network(input=layers[0], output=layers[-1], name='generator')
  self.shared_discriminator = Network(input=layers[0], output=layers[-1], name='discriminator')


In [4]:
srgan.generator.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 67, 120, 3)   0                                            
__________________________________________________________________________________________________
G_Head_conv0 (Conv2D)           (None, 67, 120, 128) 3584        input_1[0][0]                    
__________________________________________________________________________________________________
G_Head_norm0 (BatchNormalizatio (None, 67, 120, 128) 512         G_Head_conv0[0][0]               
__________________________________________________________________________________________________
G_Head_act0 (LeakyReLU)         (None, 67, 120, 128) 0           G_Head_norm0[0][0]               
____________________________________________________________________________________________

In [5]:
for e in range(10001):
    X_low, X_high = gen.flow_from_directory().__next__()
    srgan.train(X_low, X_high)
    if e % 100 == 0:
        print('epoch', e)
        srgan.valid(X_low, X_high)
        
        test_img_low = (cv2.imread(test_low_path) / 255.0)[np.newaxis, :, :, :]
        test_img_high = (cv2.imread(test_high_path) / 255.0)[np.newaxis, :, :, :]

        test_img_sr = srgan.generator.predict(test_img_low)
        test_img_sr_int = (test_img_sr * 255).astype('u1')[0]
        test_out_path = test_out_dir + test_file.split('.')[0] + f'_test_{e}.jpg'
        cv2.imwrite(test_out_path, test_img_sr_int)

epoch 0
d_loss_real:1.1153713464736938, d_loss_fake:0.3533744215965271, vgg_loss:2.694037914276123
epoch 100
d_loss_real:0.7907151579856873, d_loss_fake:0.6056246161460876, vgg_loss:0.3184126019477844
epoch 200
d_loss_real:0.7231326103210449, d_loss_fake:0.6633784770965576, vgg_loss:0.2529701888561249
epoch 300
d_loss_real:0.6891871690750122, d_loss_fake:0.7421635389328003, vgg_loss:0.20926277339458466
epoch 400
d_loss_real:0.7078758478164673, d_loss_fake:0.6904424428939819, vgg_loss:0.15136684477329254
epoch 500
d_loss_real:0.7165847420692444, d_loss_fake:0.679253101348877, vgg_loss:0.34675827622413635
epoch 600
d_loss_real:0.6935920715332031, d_loss_fake:0.7042490839958191, vgg_loss:0.1743207722902298
epoch 700
d_loss_real:0.7668333053588867, d_loss_fake:0.6550408601760864, vgg_loss:0.14352600276470184
epoch 800
d_loss_real:0.6856063604354858, d_loss_fake:0.677096962928772, vgg_loss:0.10492296516895294
epoch 900
d_loss_real:0.6821091175079346, d_loss_fake:0.6897879838943481, vgg_loss

epoch 8000
d_loss_real:1.2220957279205322, d_loss_fake:0.4755988121032715, vgg_loss:0.0844523087143898
epoch 8100
d_loss_real:1.898169994354248, d_loss_fake:1.2870672941207886, vgg_loss:0.13337342441082
epoch 8200
d_loss_real:1.3197476863861084, d_loss_fake:0.7352728843688965, vgg_loss:0.0652821734547615
epoch 8300
d_loss_real:1.3366966247558594, d_loss_fake:0.901583194732666, vgg_loss:0.15638545155525208
epoch 8400
d_loss_real:1.7198586463928223, d_loss_fake:0.4400041401386261, vgg_loss:0.16600608825683594
epoch 8500
d_loss_real:1.2895734310150146, d_loss_fake:0.3643782138824463, vgg_loss:0.06728056818246841
epoch 8600
d_loss_real:1.5158324241638184, d_loss_fake:0.8132302761077881, vgg_loss:0.09798571467399597
epoch 8700
d_loss_real:1.1983940601348877, d_loss_fake:0.49901652336120605, vgg_loss:0.08764973282814026
epoch 8800
d_loss_real:1.1435515880584717, d_loss_fake:1.4745426177978516, vgg_loss:0.19253893196582794
epoch 8900
d_loss_real:0.8378617763519287, d_loss_fake:0.4766204059123

In [6]:
out_model_name = '20200521_generator.h5'
srgan.generator.save(model_dir + out_model_name, include_optimizer=False)