In [3]:
# esrGAN upscaling
import tensorflow as tf
from tensorflow import keras
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
from keras.applications import VGG19
from tqdm.notebook import tqdm
import keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import joblib, os, cv2, time

In [2]:
# get training data
# Build the paths instead of chdir-ing into 'train': if a load fails, the
# kernel's working directory is left untouched and the cell can simply be re-run.
# NOTE: joblib.load unpickles the file contents - only load .sav files you trust.
rawX = joblib.load(os.path.join('train', 'lrImgs.sav')) # rawX/X2 are just np arrays containing low-res/high-res images
rawX2 = joblib.load(os.path.join('train', 'hrImgs.sav'))

# batch training data
rawX = np.array(rawX)
rawX2 = np.array(rawX2)

m = rawX.shape[0] # number of low-res images
batchSize = 4
X = tf.data.Dataset.from_tensor_slices(rawX).batch(batchSize)
X2 = tf.data.Dataset.from_tensor_slices(rawX2).batch(batchSize)
# materialise the batches so train() can pick one by index
listX = list(X.as_numpy_iterator())
listY = list(X2.as_numpy_iterator())

FileNotFoundError: [Errno 2] No such file or directory: 'lrImgs.sav'

In [None]:
# custom blocks used in esrGAN

class RB(tf.keras.layers.Layer):
  """Plain residual block: conv -> ReLU -> conv, added back onto the input.

  Both convolutions use 64 filters, a 3x3 kernel and 'same' padding, so the
  incoming tensor needs 64 channels for the final addition to line up.
  """

  def __init__(self):
    super().__init__()
    self.cv1 = Conv2D(64, 3, padding='same')
    self.cv2 = Conv2D(64, 3, padding='same')

  def call(self, inputs):
    """Return inputs + cv2(ReLU(cv1(inputs)))."""
    branch = self.cv1(inputs)
    branch = ReLU()(branch)
    branch = self.cv2(branch)
    # skip connection around the two convolutions
    return Add()([inputs, branch])

class DenseBlock(tf.keras.layers.Layer):
  """Densely connected block of five 3x3, 64-filter convolutions.

  Each of the first four convolutions is followed by a LeakyReLU and its
  result is concatenated with the block input and every earlier
  concatenation; the fifth convolution maps the last concatenation back to
  64 channels and is returned without an activation.
  """

  def __init__(self):
    super().__init__()
    self.cv1 = Conv2D(64, 3, padding='same')
    self.cv2 = Conv2D(64, 3, padding='same')
    self.cv3 = Conv2D(64, 3, padding='same')
    self.cv4 = Conv2D(64, 3, padding='same')
    self.cv5 = Conv2D(64, 3, padding='same')

  def call(self, inputs):
    """Run the dense stack on `inputs` and return the cv5 output."""
    # `carried` holds everything the next concatenation must include:
    # the block input followed by each previous concatenation, in order.
    # NOTE: every earlier concatenation already contains `inputs`, so the
    # input features appear more than once in the later concatenations.
    carried = [inputs]
    current = inputs
    for conv in (self.cv1, self.cv2, self.cv3, self.cv4):
      activated = conv(current)
      activated = LeakyReLU()(activated)
      current = Concatenate()(carried + [activated])
      carried.append(current)

    return self.cv5(current)

class RRDB(tf.keras.layers.Layer):
  """Residual-in-residual dense block: three DenseBlocks with scaled skips.

  Each DenseBlock output is multiplied by its own learnable scalar
  (beta1..beta3) before being added to the running sum. Finally the running
  sum and the block input are scaled by beta4 and beta5 and added together.
  """

  def __init__(self):
    super(RRDB, self).__init__()
    self.db1 = DenseBlock()
    self.db2 = DenseBlock()
    self.db3 = DenseBlock()
    # The betas start as normal samples divided by three so the activations
    # don't get too large.
    self.beta1 = self._scaledBeta()
    self.beta2 = self._scaledBeta()
    self.beta3 = self._scaledBeta()
    self.beta4 = self._scaledBeta()
    self.beta5 = self._scaledBeta()

  @staticmethod
  def _scaledBeta():
    """Create one trainable scalar initialised to normal() / 3."""
    # Scale the *initial value*, not the Variable: dividing a tf.Variable
    # returns a plain tensor, which the layer would neither track nor train.
    return tf.Variable(tf.random.normal(shape=()) / 3, trainable=True)

  def call(self, inputs):
    """Apply the three scaled dense blocks and the outer scaled residual."""
    db1 = self.db1(inputs)
    db1 = tf.scalar_mul(self.beta1, db1)
    add1 = Add()([inputs, db1])

    db2 = self.db2(add1)
    db2 = tf.scalar_mul(self.beta2, db2)
    add2 = Add()([add1, db2])

    db3 = self.db3(add2)
    db3 = tf.scalar_mul(self.beta3, db3)
    add3 = Add()([add2, db3])

    # outer residual: both the trunk and the block input get their own scale
    add3 = tf.scalar_mul(self.beta4, add3)
    inps = tf.scalar_mul(self.beta5, inputs)
    add4 = Add()([inps, add3])
    return add4

def buildDBlock(inp):
  """Append one discriminator stage to `inp` and return the resulting tensor.

  The stage is a stride-2 3x3 convolution with 32 filters, followed by batch
  normalisation, LeakyReLU and 20% dropout.
  """
  out = Conv2D(32, 3, strides=2, padding='same')(inp)
  for stage in (BatchNormalization, LeakyReLU, lambda: Dropout(0.2)):
    # each layer is created right before it is applied, in pipeline order
    out = stage()(out)
  return out

In [None]:
def genGen():
  """Build the generator: 32x32x3 low-res input -> upscaled RGB image.

  Layout: 9x9 stem conv -> three trunk blocks -> 3x3 conv -> concatenation
  with the stem output -> two (conv 256, UpSampling2D) stages -> 3x3 conv ->
  9x9 sigmoid conv with 3 output channels. With UpSampling2D's default factor
  the two upsampling stages turn 32x32 into the 128x128 that genDisc expects.

  Returns:
    A Keras Model named 'generator'.
  """
  inp = Input((32, 32, 3))
  # keep the stem output under its own name: it is reused for the long skip
  cv1 = Conv2D(64, 9, padding='same')(inp)

  layer = cv1
  for _ in range(3): # customize whichever block for your esrGAN
    layer = RRDB()(layer)
    #layer = DenseBlock()(layer)
    #layer = RB()(layer)

  layer = Conv2D(64, 3, padding='same')(layer)
  # long skip connection from the stem to the end of the trunk
  layer = Concatenate()([cv1, layer])

  layer = Conv2D(256, 3, padding='same')(layer)
  layer = UpSampling2D()(layer)
  layer = Conv2D(256, 3, padding='same')(layer)
  layer = UpSampling2D()(layer)

  layer = Conv2D(64, 3, padding='same')(layer)
  output = Conv2D(3, 9, padding='same', activation='sigmoid')(layer) # did sigmoid since it seemed to work the best for me

  generator = Model(inp, output, name='generator')
  return generator

def genDisc():
  """Build the discriminator: 128x128x3 image -> one unbounded (linear) score.

  A 3x3 conv with LeakyReLU is followed by five buildDBlock stages, then the
  features are flattened and fed to a single linear Dense unit.

  Returns:
    A Keras Model named 'discriminator'.
  """
  inp = Input((128, 128, 3))
  features = Conv2D(32, 3, padding='same')(inp)
  features = LeakyReLU()(features)

  # five strided stages, each built by buildDBlock
  for _ in range(5):
    features = buildDBlock(features)

  features = Flatten()(features)
  score = Dense(1, activation='linear')(features)

  return Model(inp, score, name='discriminator')

def build_vgg():
  """Return a VGG19 feature extractor used for the perceptual loss.

  Loads ImageNet-pretrained VGG19 without its classifier head for 128x128x3
  inputs and exposes the output of layer index 6 as the model's only output.
  """
  backbone = VGG19(input_shape=(128, 128, 3), include_top=False, weights="imagenet")
  entry = backbone.layers[0].output
  # author's note: "vgg-54 is [15]" - swap the index to use deeper features
  # (which VGG layer each index maps to is not verified here)
  featureOutputs = [backbone.layers[6].output]

  return Model(entry, featureOutputs)

In [None]:
# loss functions - I haven't found a way to get VGG activations before activation functions

def dra(y1Pred, y2Pred):
  """Mean of sigmoid(y1Pred - y2Pred) over all elements.

  Values above 0.5 mean y1Pred scores are, on balance, higher than y2Pred.
  NOTE(review): the average is taken after the sigmoid; a relativistic
  *average* discriminator would normally average y2Pred before it - confirm
  which form is intended.
  """
  relative = tf.math.sigmoid(y1Pred - y2Pred)
  return K.mean(relative)

def discLoss(truePred, fakePred):
  """Discriminator loss built from dra().

  Args:
    truePred: discriminator scores for real images.
    fakePred: discriminator scores for generated images.

  Returns:
    -log(dra(true, fake)) - log(1 - dra(fake, true)); small when real images
    score higher than generated ones.
  """
  realTerm = tf.math.log(dra(truePred, fakePred))
  fakeTerm = tf.math.log(1 - dra(fakePred, truePred))
  return (-K.mean(realTerm)) + (-K.mean(fakeTerm))

def genLoss(truePred, fakePred, trueVGG, fakeVGG, y, fakeImgs, lb=5e-2, eta=1e-2):
  """Generator loss: perceptual + lb * adversarial + eta * pixel L1.

  Args:
    truePred: discriminator scores for the real high-res images.
    fakePred: discriminator scores for the generated images.
    trueVGG: VGG features of the real images.
    fakeVGG: VGG features of the generated images.
    y: real high-res images.
    fakeImgs: generated images.
    lb: weight of the adversarial term.
    eta: weight of the mean-absolute-error term between y and fakeImgs.

  Returns:
    Scalar loss tensor.
  """
  mse = MeanSquaredError()
  mae = MeanAbsoluteError()
  percepLoss = mse(trueVGG, fakeVGG)

  # Mirror image of discLoss: the generator wants real images to look *less*
  # real than its fakes, hence `1 - dra(...)` on the real term. Without the
  # `1 -` this term is identical to the discriminator's and works against
  # the generator.
  realLoss = -K.mean(tf.math.log(1 - dra(truePred, fakePred)))
  fakeLoss = -K.mean(tf.math.log(dra(fakePred, truePred)))
  adLoss = lb * (realLoss + fakeLoss)

  normLoss = eta * mae(y, fakeImgs)

  return percepLoss + adLoss + normLoss

In [None]:
def step(batch, y):
  """Run one training step on a single batch.

  Uses the module-level genModel, discModel, vgg, genOpt, discOpt and lb.

  Args:
    batch: low-res input images for the generator.
    y: matching high-res target images.

  Returns:
    (dloss, gloss): discriminator and generator losses for this batch.
  """
  global genModel, discModel, vgg, genOpt, discOpt
  # Only the forward pass and the losses are recorded. Gradients are taken
  # after leaving the context so neither tape records the other's backward
  # pass or the optimizer updates.
  with tf.GradientTape() as gtape, tf.GradientTape() as dtape:
    fakes = genModel(batch, training=True)
    truePreds = discModel(y, training=True)
    fakePreds = discModel(fakes, training=True)
    trueVGG = vgg(y, training=False)
    fakeVGG = vgg(fakes, training=False)

    dloss = discLoss(truePreds, fakePreds)
    gloss = genLoss(truePreds, fakePreds, trueVGG, fakeVGG, y, fakes, lb=lb)

  gradGen = gtape.gradient(gloss, genModel.trainable_variables)
  genOpt.apply_gradients(zip(gradGen, genModel.trainable_variables))
  if dloss > 15: # discriminator seems to train faster so i cripple it to not get mode collapse
    gradDisc = dtape.gradient(dloss, discModel.trainable_variables)
    discOpt.apply_gradients(zip(gradDisc, discModel.trainable_variables))

  return dloss, gloss

def train(epochs, steps=1000):
  """Train for `epochs` epochs of `steps` randomly chosen batches each.

  Batches are drawn with replacement from the module-level listX/listY; only
  the first m // batchSize (full) batches are eligible. The summed
  discriminator and generator losses are printed after every epoch.

  Args:
    epochs: number of epochs to run.
    steps: number of step() calls per epoch.
  """
  global listX, listY, m, batchSize
  for i in range(epochs):
    dcost = 0
    gcost = 0
    for _ in tqdm(range(steps)):
      batchInd = np.random.randint(low=0, high=m//batchSize)
      batchX = listX[batchInd]
      batchY = listY[batchInd]
      # step() takes exactly (batch, y)
      dloss, gloss = step(batchX, batchY)

      dcost += dloss
      gcost += gloss

    print('\n-----Epoch: {} | Discriminator Cost: {} | Generator Cost: {}-----\n'.format(i, dcost, gcost))

In [None]:
# frozen VGG feature extractor for the perceptual loss
vgg = build_vgg()
vgg.trainable = False

genModel = genGen()
discModel = genDisc()

# Tag used in the model directory names (here and in the save cell).
# DB = dense block, RRDB = RRDB - keep it in sync with the block picked in genGen().
mode = 'RRDB'

# load trained models here (uncomment to resume from a saved run)
# if tf.__version__ == '2.2.0':
#   genModel = tf.keras.models.load_model('models/tf_220/esrGAN_{}/gen'.format(mode))
#   discModel = tf.keras.models.load_model('models/tf_220/esrGAN_{}/disc'.format(mode))
# else:
#   genModel = tf.keras.models.load_model('models/tf_230/esrGAN_{}/gen'.format(mode))
#   discModel = tf.keras.models.load_model('models/tf_230/esrGAN_{}/disc'.format(mode))

In [None]:
# progressively increase lambda (lb, the adversarial-loss weight) and lower
# the learning rate to get sharper image quality
lb = 5e-3
#lb = 5e-2
#lb = 3e-1

# one independent Adam optimizer per network, both at the same learning rate
genOpt, discOpt = (Adam(learning_rate=1e-4) for _ in range(2))

In [None]:
# training loop - show images and train
# Runs until interrupted: each pass plots a preview grid, then trains one epoch.
# Grid rows: low-res input / generator output / high-res target.

rows, cols = 3, 5
while True:
  fig = plt.figure(figsize=(30, 15))
  axes = fig.subplots(rows, cols)
  for i in range(cols):
    # even columns always show sample i (to compare across epochs);
    # odd columns show a random sample
    sampleInd = i if i % 2 == 0 else np.random.randint(low=0, high=m)
    predInput = np.array([rawX[sampleInd]])
    pred = genModel.predict(predInput)[0]

    axes[0][i].imshow(rawX[sampleInd])
    axes[1][i].imshow(pred)
    axes[2][i].imshow(rawX2[sampleInd])

  plt.show()
  # release the figure so the endless loop doesn't accumulate open figures
  plt.close(fig)
  train(1)

In [None]:
# Save both networks under a version-specific folder; the timestamp gives
# each save its own directory.
now = time.time()
if tf.__version__ == '2.2.0':
  genPath = 'models/tf_220/esrGAN_{}_{}/gen'.format(mode, now)
  discPath = 'models/tf_220/esrGAN_{}_{}/disc'.format(mode, now)
else:
  genPath = 'models/tf_230/esrGAN_{}_{}/gen'.format(mode, now)
  discPath = 'models/tf_230/esrGAN_{}_{}/disc'.format(mode, now)

genModel.save(genPath)
discModel.save(discPath)

In [None]:
# clear out memory

import gc
# Reset Keras' global state first, then collect, so whatever the reset
# releases can be reclaimed by the same garbage-collection pass.
K.clear_session()
gc.collect()