In [None]:
import os

# NOTE: the star import is what brings `helpers`, `models`, `losses` and
# `train_gan` into scope; it stays ahead of the TensorFlow imports so that
# `keras` / `tf` below keep taking precedence over any same-named exports.
from dlsr import *
from tensorflow import keras
import tensorflow as tf

# Call the dlsr `helpers.config` hook once, before any data or model is created.
# NOTE(review): the meaning of the positional `True` flag is defined inside
# dlsr.helpers and is not visible from this notebook — confirm there.
helpers.config(True)

In [None]:
# --- shared settings -------------------------------------------------------
image_size = 96   # passed to the data pipeline, the discriminator and the perceptual loss
batch_size = 16   # passed to the data pipeline and to train_gan()
epochs = 500      # passed to train_gan()

# --- per-network settings --------------------------------------------------
# gen_n is forwarded as `generator_n`, dis_n as `discriminator_n`.
# NOTE(review): presumably the optimiser learning rates — confirm in train_gan().
# Both networks currently use the same value.
gen_n = dis_n = 1e-4

In [None]:
# Build the training dataset from the settings defined above.
# NOTE(review): repeat_count=20 and scale=4 are hard-coded here; scale=4
# presumably matches the 4x generator ("srgan4x") loaded later — confirm.
training_data = helpers.get_training_data(
    image_size=image_size,
    batch_size=batch_size,
    repeat_count=20,
    scale=4,
)

In [None]:
# create discriminator (built fresh, sized by `image_size`)
discriminator = models.discriminator(image_size=image_size)

# load generator: rebuild the architecture from its saved JSON description,
# then restore the saved weights into it.
# The `with` block guarantees the file handle is closed even if read() raises.
with open('./results/saved-models/srgan4x/architecture.json', 'r') as json_file:
    loaded_model_json = json_file.read()
generator = tf.keras.models.model_from_json(loaded_model_json)
generator.load_weights("./results/saved-models/srgan4x/generator.h5")

# create SRGAN: wraps both networks in the single model passed to train_gan()
srgan = models.SRGAN(discriminator=discriminator, generator=generator)

# Loss objects handed to train_gan().

# from_logits=False: predictions are treated as probabilities, not raw logits.
discriminator_loss = keras.losses.BinaryCrossentropy(from_logits=False)

# The generator's perceptual loss is constructed with the discriminator and
# the configured image size.
generator_loss = losses.PerceptualLoss(
    discriminator=discriminator,
    image_size=image_size,
)

# History object initialised with the names of the series it will hold;
# it is passed to train_gan() via `history=` and plotted afterwards.
history = helpers.History(
    [
        # losses (training, then validation)
        "d_loss", "g_loss",
        "val_d_loss", "val_g_loss",
        # accuracies (training, then validation)
        "d_accuracy", "g_accuracy",
        "val_d_accuracy", "val_g_accuracy",
    ]
)

In [None]:
# Run training on the SRGAN built above. `train_gan` comes from the dlsr star
# import; every argument is one of the objects/settings defined in earlier cells.
# NOTE(review): presumably fills `history` with the d_/g_ loss and accuracy
# series as it runs — confirm in train_gan().
train_gan(
    srgan=srgan,
    training_data=training_data,
    epochs=epochs,
    batch_size=batch_size,
    generator_loss_fn=generator_loss,
    generator_n=gen_n,
    discriminator_loss_fn=discriminator_loss,
    discriminator_n=dis_n,
    history=history,
)

In [None]:
# Persist the generator: architecture as JSON plus its weights, in a separate
# folder ("srgan4x_2") so the model loaded from "srgan4x" is not overwritten.
save_dir = "./results/saved-models/srgan4x_2"

# Create the target folder up front: open(..., "w") does not create missing
# directories, and failing here would discard the whole training run.
os.makedirs(save_dir, exist_ok=True)

architecture = generator.to_json()
with open(save_dir + "/architecture.json", "w") as f:
    f.write(architecture)
generator.save_weights(save_dir + "/generator.h5")

In [None]:
# Plot the recorded series: four calls, each pairing a training series with
# its validation counterpart. The dict keys are names registered when the
# History object was created.
# NOTE(review): the dict values are presumably the display labels — confirm
# in helpers.History.plot.

# generator loss
history.plot({
    "g_loss": "Generator Loss",
    "val_g_loss": "Validation Generator Loss"
})
# discriminator loss
history.plot({
    "d_loss": "Discriminator Loss",
    "val_d_loss": "Validation Discriminator Loss"
})
# generator accuracy
history.plot({
    "g_accuracy": "Generator Accuracy",
    "val_g_accuracy": "Validation Generator Accuracy"
})
# discriminator accuracy
history.plot({
    "d_accuracy": "Discriminator Accuracy",
    "val_d_accuracy": "Validation Discriminator Accuracy"
})