In [1]:
import tensorflow as tf
import _init_paths
import os
import numpy as np
from functools import partial
import tensorflow.keras.models as KM
import tensorflow.keras.backend as K
from utils import DisplayPlot
from mydata.loader_coco import COCO
from mymodel.wgangp import WGanGP
from mymodel.srgen import SrGen
from mymodel.utils import Vgg, Resnet, Discriminator
import cfg

In [2]:
import argparse
def GetTrainingOptions(
    weightAdv = 0.1,
    weightGP = 10,
    weightFeat = 10,
    weightPixel = 0.1,
    devices = '2',
    batch = 12,
    ratio = 2,
    lr_g = 5e-5,
    lr_d = 1e-4,
    epochs = 100,
    size_train = 1000,
    size_val = 1,
    dir_results = '.',
):
    """Collect the training hyper-parameters into one ``argparse.Namespace``.

    Every argument is stored unchanged as an attribute of the same name; no
    validation or conversion is performed.

    Args:
        weightAdv: weight for the adversarial loss.
        weightGP: weight for the gradient penalty.
        weightFeat: weight for the feature reconstruction loss.
        weightPixel: weight for the pixel-wise mean absolute error.
        devices: GPU device id(s) to use, as a string (it is later assigned
            to the ``CUDA_VISIBLE_DEVICES`` environment variable).
        batch: batch size.
        ratio: resolution ratio between low and high resolution images.
        lr_g: learning rate for the generator.
        lr_d: learning rate for the discriminator.
        epochs: number of epochs to train.
        size_train: amount of training data per epoch (this notebook passes
            it to ``fit`` as ``steps_per_epoch``).
        size_val: amount of validation data per epoch (this notebook passes
            it to ``fit`` as ``validation_steps``).
        dir_results: root directory for saving results.

    Returns:
        argparse.Namespace: one attribute per argument, in signature order.
    """
    # Keyword order mirrors the signature so the namespace keeps the same
    # attribute order as assigning the fields one by one.
    return argparse.Namespace(
        weightAdv=weightAdv,
        weightGP=weightGP,
        weightFeat=weightFeat,
        weightPixel=weightPixel,
        devices=devices,
        batch=batch,
        ratio=ratio,
        lr_g=lr_g,
        lr_d=lr_d,
        epochs=epochs,
        size_train=size_train,
        size_val=size_val,
        dir_results=dir_results,
    )


In [3]:
# Run configuration: keep every default of GetTrainingOptions except the epoch
# count (200 instead of 100) and the feature reconstruction loss weight
# (20 instead of 10).
opt = GetTrainingOptions(epochs=200, weightFeat=20)

In [4]:
# Restrict the process to the GPU id(s) named in opt.devices. os.environ only
# accepts strings, which is why opt.devices is kept as a string ('2' by default).
# NOTE(review): tensorflow has already been imported when this cell runs; this
# presumably still takes effect because no GPU work has happened yet — confirm,
# or move the assignment ahead of the tensorflow import.
os.environ["CUDA_VISIBLE_DEVICES"] = opt.devices

### Dataset

In [5]:
# COCO loader (mydata.loader_coco) configured with the dataset root from cfg,
# the low/high resolution ratio and the batch size taken from opt.
ds = COCO(root=cfg.PATH_COCO, ratio=opt.ratio, batch_size=opt.batch)
# GetDataset() returns a pair; the first item is later fed to model.fit as the
# training input and the second as validation_data.
ds_tr, ds_val = ds.GetDataset()

### Model

In [6]:
# Generator built for the dataset's low-resolution input shape.
generator = SrGen(input_shape=ds.GetShapeLow())
# Load previously saved generator weights. The path is relative to the current
# working directory.
# NOTE(review): presumably this raises if save_model/weights.h5 is absent — confirm.
generator.load_weights(os.path.join('save_model', 'weights.h5'))
# The discriminator and the feature network are both built for the
# high-resolution shape.
discriminator = Discriminator(input_shape=ds.GetShapeHigh())
# NOTE(review): presumably the network behind the feature reconstruction loss
# weighted by opt.weightFeat — confirm in mymodel.utils / mymodel.wgangp.
feat = Resnet(input_shape=ds.GetShapeHigh())

### Main

In [8]:
# One RMSprop optimizer per network, each with its own learning rate from opt.
optimizer_g = tf.keras.optimizers.RMSprop(opt.lr_g)
optimizer_d = tf.keras.optimizers.RMSprop(opt.lr_d)
# WGanGP receives the three networks plus the whole options namespace.
model = WGanGP(discriminator, generator, feat, opt)
# compile() here takes the two optimizers by keyword (project-specific signature).
model.compile(optimizer_g=optimizer_g, optimizer_d=optimizer_d)
# Callback handed to model.fit in the next cell; constructed with the 'Imgs'
# root and the dataset loader.
# NOTE(review): the name `display` shadows IPython's built-in display() helper
# for the rest of the notebook — consider renaming together with its use in
# the fit call.
display = DisplayPlot(root='Imgs', ds=ds)

In [None]:
# Launch training for opt.epochs epochs on ds_tr, validating on ds_val and
# invoking the DisplayPlot callback.
# NOTE(review): assuming WGanGP follows the Keras Model.fit contract,
# steps_per_epoch and validation_steps count batches, so opt.size_train and
# opt.size_val are numbers of batches per epoch rather than numbers of
# samples — confirm.
model.fit(ds_tr,
          steps_per_epoch = opt.size_train,
          epochs = opt.epochs, 
          verbose = 1,
          validation_data = ds_val,
          validation_steps = opt.size_val,
          callbacks = [display])

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200


Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200


Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200