In [1]:
from Utilities.io import DataLoader
from Utilities.lossMetric import *
from Utilities.trainVal import MinMaxGame
from Models.RRDBNet import RRDBNet
from Models.GAN import Discriminator

### Load in the training dataset
We used the Artificial Mercosur License Plates dataset and the Chinese City Parking Dataset for this project.

In [3]:
import numpy as np
import glob
PATH = 'Data/pre/192_96' # only use images with shape 192 by 96 for training
SEED = 42             # fixes the train/val split so it is identical after a kernel restart
TRAIN_FRACTION = 0.8  # share of the unique images used for training
AUG_REPEATS = 3       # data augmentation: each image is listed this many times
                      # (presumably DataLoader varies brightness/contrast per sample -- confirm)

# sorted(): glob returns files in filesystem-dependent order, which would defeat the seed.
uniqueFiles = sorted(glob.glob(PATH + '/*.jpg'))
if not uniqueFiles:
    # Fail here rather than let training run on an empty dataset.
    raise FileNotFoundError(f"No .jpg images found in '{PATH}'")

splitRng = np.random.default_rng(SEED)
splitRng.shuffle(uniqueFiles)

# Split the UNIQUE images first and only then repeat them. Repeating before the
# split puts copies of the same image in both subsets, so validation metrics
# would be measured on images the generator was trained on.
splitAt = int(len(uniqueFiles) * TRAIN_FRACTION)
train = uniqueFiles[:splitAt] * AUG_REPEATS
val = uniqueFiles[splitAt:] * AUG_REPEATS
# Scatter the repeated copies so they are not adjacent in the file lists.
splitRng.shuffle(train)
splitRng.shuffle(val)
files = train + val  # kept so the name `files` still holds every (repeated) path

# NOTE(review): `loader` is not used in this cell; kept in case a later cell refers to it.
loader = DataLoader()
trainData = DataLoader().load(train, batchSize=16)
valData = DataLoader().load(val, batchSize=64)

In [12]:
# Sanity check: display the validation dataset's repr to inspect its element_spec.
valData

<PrefetchDataset element_spec=(TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name=None))>

### Training

In [4]:
# Instantiate the three networks used by the training cells below.
# Discriminator: passed to MinMaxGame in the adversarial training stage.
discriminator = Discriminator()
# Feature extractor used inside contentLoss and passed to MinMaxGame.
# NOTE(review): buildExtractor is presumably provided by the star import from
# Utilities.lossMetric -- confirm.
extractor = buildExtractor()
# Generator network that is pre-trained and then passed to MinMaxGame.
# NOTE(review): blockNum=10 presumably sets the number of RRDB blocks -- confirm.
generator = RRDBNet(blockNum=10)

In [5]:
def contentLoss(y_true, y_pred):
    """Content loss used to pre-train the generator.

    Sums two terms:
      * the mean absolute error between ``y_true`` and ``y_pred``;
      * 0.1 times the mean squared error between the ``extractor`` features
        of ``y_pred`` and ``y_true`` (the "VGG loss" of the original note).

    Relies on the notebook-level ``extractor`` and on the ``tf`` / ``tfk``
    names being in scope.
    """
    perceptualTerm = tf.reduce_mean(
        tfk.losses.mse(extractor(y_pred), extractor(y_true))
    )
    pixelTerm = tf.reduce_mean(tfk.losses.mae(y_true, y_pred))
    return 0.1 * perceptualTerm + pixelTerm

# Stage 1: pre-train the generator on its own with the content loss.
optimizer = tfk.optimizers.Adam(learning_rate=1e-3)
generator.compile(
    loss=contentLoss,
    optimizer=optimizer,
    metrics=[psnr, ssim],
)
# Only one epoch is run here; setting epochs to about 20 works well.
# Once the model reaches PSNR=20 / SSIM=0.65, the min-max game can start.
history = generator.fit(
    x=trainData,
    validation_data=valData,
    epochs=1,
    steps_per_epoch=300,
    validation_steps=100,
)

  6/300 [..............................] - ETA: 4:29:01 - loss: 3.0402 - psnr: 4.9814 - ssim: 0.0120

### Generative adversarial network training

In [6]:
# Settings handed to MinMaxGame.train for the adversarial stage.
PARAMS = {
    'lrGenerator': 1e-4,
    'lrDiscriminator': 1e-4,
    'epochs': 1,
    'stepsPerEpoch': 500,
    'valSteps': 100,
}
# Stage 2: the pre-trained generator plays against the discriminator.
game = MinMaxGame(generator, discriminator, extractor)
log, valLog = game.train(trainData, valData, PARAMS)

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0), HTML(value='')))




### Save the model
Because the models are defined as subclasses of `tf.keras.Model`, they cannot be safely serialized as whole models.  
Therefore, please save only the weights and follow the instructions in tutorial 1 to reload the model.  
You can find my pretrained model in the *Pretrained* folder.

In [7]:
# Uncomment and replace YOUR_PATH with the destination for the weights:
# generator.save_weights(YOUR_PATH, save_format='tf')