# **Imports & Hyperparameters**

In [None]:
import os
import time
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

import main.utilities.utils as us
import main.models.VI_SRGAN.architectures.architectures_srgan as arcs

from tf_keras.layers import *
from tf_keras.optimizers import Adam
from tf_keras.optimizers.schedules import PiecewiseConstantDecay
from tf_keras.losses import MeanSquaredError, BinaryCrossentropy
from tf_keras.applications.vgg19 import VGG19
from tf_keras.models import Model
from tf_keras.metrics import Mean
from tqdm import tqdm
from PIL import Image

**Hyperparameters** \
These are the hyperparameters we are free to choose.

In [None]:
# SuperRes
BATCH_SIZE = 16   # (LR, HR) pairs per batch, for both training and validation
SCALE = 4         # super-resolution upscaling factor between LR and HR images

# DataSet
LABEL_SIZE = 96   # HR crop size used by the data pipeline and the discriminator

# Model / training schedule
TRAINING_STEPS = 5000                     # total training steps (EPOCHS * STEPSxEPOCHS)
STEPSxEPOCHS = 1000                       # training steps per epoch
EPOCHS = TRAINING_STEPS // STEPSxEPOCHS
# Presumably the number of x2 upsampling stages per scale factor (log2(scale)).
UPSAMPLESxSCALE = {
    2: 1,
    4: 2,
    8: 3,
}

# Integer division above would silently drop leftover steps; fail fast instead.
assert TRAINING_STEPS % STEPSxEPOCHS == 0, (
    f"TRAINING_STEPS ({TRAINING_STEPS}) must be a multiple of STEPSxEPOCHS ({STEPSxEPOCHS})"
)

# Random Seed
AUTOTUNE = tf.data.AUTOTUNE
SEED = 42
tf.random.set_seed(SEED)

# **Load Data**

In [None]:
# Download DIV2K from TF Datasets.
# The bicubic degradation config is derived from SCALE so the LR/HR pairs always
# match the configured upscaling factor (TFDS provides bicubic_x2/x3/x4/x8).
div2k_data = tfds.image.Div2k(config=f"bicubic_x{SCALE}")
div2k_data.download_and_prepare()

# as_supervised=True yields (lowres, highres) tuples.
# Taking train data from div2k_data object
train = div2k_data.as_dataset(split="train", as_supervised=True)
# Validation data
val = div2k_data.as_dataset(split="validation", as_supervised=True)

# cache() with no filename keeps decoded images in memory after the first pass,
# so later epochs skip re-reading them from disk.
# NOTE(review): full-resolution DIV2K images are large; watch host RAM usage.
train_cache = train.cache()
val_cache = val.cache()

**Visualize Training Data**

In [None]:
# Plot one raw (not yet cropped or augmented) (LR, HR) pair from the training split.
samples = train.take(1)
ls = []
for sample in samples:
    lr, hr = sample
    ls.extend((lr, hr))
us.plot_images(ls)

# **Create Dataset & Preprocess Data**

In [None]:
def dataset_object(dataset_cache, training=True, batch_size=None, hr_crop_size=None, scale=None):
    """Build a batched tf.data pipeline of (lowres, highres) pairs.

    Args:
        dataset_cache: tf.data.Dataset yielding (lowres, highres) image pairs.
        training: if True, apply random rotation, flip and LR JPEG-noise
            augmentation, and repeat the dataset indefinitely.
        batch_size: pairs per batch; defaults to BATCH_SIZE (read at call time).
        hr_crop_size: HR crop size passed to us.random_crop; defaults to LABEL_SIZE.
        scale: LR/HR scale factor passed to us.random_crop; defaults to SCALE.

    Returns:
        A batched, prefetched tf.data.Dataset. It is infinite when training=True,
        so training loops must bound the number of steps themselves.
    """
    # Resolve defaults at call time so edits to the config cell are picked up.
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    hr_crop_size = LABEL_SIZE if hr_crop_size is None else hr_crop_size
    scale = SCALE if scale is None else scale

    ds = dataset_cache

    # Random Crops (applied to validation as well).
    # NOTE(review): the map runs after cache(), so validation crops are presumably
    # re-drawn on every pass and per-epoch validation PSNR is measured on different crops.
    ds = ds.map(
        lambda lowres, highres: us.random_crop(lowres, highres, hr_crop_size=hr_crop_size, scale=scale),
        num_parallel_calls=AUTOTUNE,
    )

    # Augment Data (training only)
    if training:
        ds = ds.map(us.random_rotate, num_parallel_calls=AUTOTUNE)
        ds = ds.map(us.random_flip, num_parallel_calls=AUTOTUNE)
        ds = ds.map(us.random_lr_jpeg_noise, num_parallel_calls=AUTOTUNE)

    # Batching Data
    ds = ds.batch(batch_size)

    # Repeat the training data so the dataset's cardinality becomes infinite
    if training:
        ds = ds.repeat()

    # Prefetching prepares later batches while the current one is being processed
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds

train_ds = dataset_object(train_cache, training=True)
val_ds = dataset_object(val_cache, training=False)

**Visualize a Preprocessed Training Batch**

In [None]:
def show_first_images(images, title, n=3):
    """Plot up to the first `n` images of a batch in one row, each titled with its shape.

    Args:
        images: batch tensor indexable along its first axis; pixel values are
            cast to uint8 for display.
        title: figure-level title.
        n: maximum number of images to show.

    Returns:
        The matplotlib Figure, or None when the batch is empty.
    """
    count = min(n, int(images.shape[0]))
    if count == 0:
        return None
    # squeeze=False keeps `axes` 2-D even when count == 1.
    fig, axes = plt.subplots(1, count, figsize=(10, 4), squeeze=False)
    for ax, image in zip(axes[0], images[:count]):
        ax.imshow(image.numpy().astype("uint8"))
        ax.set_title(str(image.shape))
        ax.axis("off")
    fig.suptitle(title)
    return fig


# One preprocessed (cropped + augmented) batch from the training pipeline.
lowres, highres = next(iter(train_ds))

show_first_images(highres, "High Resolution Images");
show_first_images(lowres, "Low Resolution Images");

# **Model**

**Setup**

In [None]:
# - Compile Properties -
# Optimizers: both networks use Adam with one shared step schedule:
# lr = 1e-4 up to step 100_000, then 1e-5.
# NOTE(review): with TRAINING_STEPS = 5000, and presumably one optimizer update per
# step, the boundary is never reached, so the rate stays at 1e-4 for this run.
learning_rate = PiecewiseConstantDecay(
    boundaries=[100_000],
    values=[1e-4, 1e-5],
)
generator_optimizer, discriminator_optimizer = (
    Adam(learning_rate=learning_rate),
    Adam(learning_rate=learning_rate),
)

# Losses handed to gan.compile in the compile cell.
binary_cross_entropy, mean_squared_error = BinaryCrossentropy(), MeanSquaredError()

In [None]:
# Load Generator: the dense-block SRResNet variant, wrapped around the base
# arcs.Generator and initialised from a saved checkpoint.
# The filename suggests 25 epochs of pre-training at lr 5e-4 (unverified).
GENERATOR_WEIGHTS_PATH = "/content/SRResNet_DenseBlock_fit25_lr0.0005.h5"

original_model = arcs.Generator()
generator = arcs.SRResNet_DenseBlock(original_model)
generator.load_weights(GENERATOR_WEIGHTS_PATH)

# Create Discriminator, sized for the same HR crop (LABEL_SIZE) the data pipeline produces.
discriminator = arcs.Discriminator(hr_crop_size=LABEL_SIZE)

# Load VGG19 (Keras default ImageNet weights) and expose the block5_conv4 activations
# ("VGG54" features). Selecting the layer by name instead of by index (20)
# makes the choice explicit and robust.
VGG_FEATURE_LAYER = "block5_conv4"
vgg = VGG19(input_shape=(None, None, 3), include_top=False)
perceptual_model = Model(vgg.input, vgg.get_layer(VGG_FEATURE_LAYER).output)
# NOTE(review): perceptual_model is not passed to arcs.TrainingClassGAN in this
# notebook; confirm whether the training class builds its own feature extractor.

In [None]:
# Create an instance of the subclassed training model, which wraps the generator
# and the discriminator; its train_step is called by the custom training loop.
gan = arcs.TrainingClassGAN(generator, discriminator)

In [None]:
# Compile the Model: attach both optimizers and both loss objects from the setup cell.
# The arguments are positional, so their order must match the signature of
# arcs.TrainingClassGAN.compile. NOTE(review): confirm the order there.
gan.compile(generator_optimizer, discriminator_optimizer, mean_squared_error, binary_cross_entropy)

**Custom Training Loop**

In [None]:
# Running means of the per-step losses; reset at the end of every epoch.
perceptual_loss_metric = Mean()
discriminator_loss_metric = Mean()

# Mean validation PSNR of the latest epoch, mirrored into a tf.Variable.
tensor_psnr = tf.Variable(0.0)
now = time.perf_counter()

# One super-resolved validation image is written here at the end of every epoch.
monitor_folder = f"srgan_bicubic_x{SCALE}"
os.makedirs(monitor_folder, exist_ok=True)

for epoch in range(EPOCHS):
    print(f"\n --- Start of Training Epoch {epoch + 1} ---")

    # train_ds repeats indefinitely, so take() bounds each epoch to STEPSxEPOCHS steps.
    for x_batch, y_batch in tqdm(train_ds.take(STEPSxEPOCHS), total=STEPSxEPOCHS):
        perceptual_loss, discriminator_loss = gan.train_step(x_batch, y_batch)
        perceptual_loss_metric(perceptual_loss)
        discriminator_loss_metric(discriminator_loss)

    # Validation: super-resolve each batch and score only its first image.
    # NOTE(review): `hr` is the whole HR batch while `sr` is a single image; this
    # relies on us.PSNR_metric broadcasting and on element [0] being the pair
    # (hr[0], sr). Confirm that this is the intended comparison.
    psnr_values = []
    for lr, hr in val_ds:
        # verbose=0 stops predict() from printing a progress bar for every batch.
        sr = generator.predict(lr, verbose=0)[0]
        # Quantize to valid 8-bit pixel values before scoring and saving.
        sr = tf.cast(tf.round(tf.clip_by_value(sr, 0, 255)), tf.uint8)
        psnr_values.append(us.PSNR_metric(hr, sr)[0])

    if not psnr_values:
        raise ValueError("val_ds produced no batches; cannot compute validation PSNR")
    psnr = tf.reduce_mean(psnr_values)

    # Save the last validation SR image as a visual progress snapshot.
    Image.fromarray(sr.numpy()).save(f"{monitor_folder}/{epoch + 1}.png")

    duration = time.perf_counter() - now
    now = time.perf_counter()

    print(
        f"{epoch + 1}/{EPOCHS}, psnr = {float(psnr):.4f}, "
        f"perceptual loss = {perceptual_loss_metric.result():.4f}, "
        f"discriminator loss = {discriminator_loss_metric.result():.4f} ({duration:.2f}s)"
    )

    perceptual_loss_metric.reset_state()
    discriminator_loss_metric.reset_state()

    tensor_psnr.assign(psnr)

In [None]:
# Save the fine-tuned generator (path presumably on a Colab-mounted Google Drive).
# The filename appears to encode EPOCHS (5) and the learning rate (1e-4); keep it in sync.
# NOTE(review): the ".h5" suffix selects HDF5, which Keras supports only for
# Functional/Sequential models; confirm arcs.SRResNet_DenseBlock is not a subclassed Model.
generator_save_path = "/content/drive/MyDrive/GAN/GAN_LRs/DenseBlock/GAN_DenseBlock_fit5_lr0.0001.h5"
generator.save(generator_save_path)

# **References**

[1] image-super-resolution - "https://github.com/jlaihong/image-super-resolution"
[2] Implementing SRResnet/SRGAN Super-Resolution with Tensorflow - "https://medium.com/analytics-vidhya/implementing-srresnet-srgan-super-resolution-with-tensorflow-89900d2ec9b2"
[3] Image Super Resolution: SRResNet and SRGAN TensorFlow 2 implementation and model intuition - "https://www.youtube.com/watch?v=FwvTsx_dxn8&list=WL&index=18"