# **Imports & Hyperparameters**

In [None]:
import tf_keras
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

import main.utilities.utils as us
import main.models.V_SRResNet.architectures.architectures_srresnet as arcs

from tf_keras.layers import *
from tf_keras.optimizers import Adam
from tf_keras.losses import MeanSquaredError, MeanAbsoluteError

**Hyperparameters** \
These are the hyperparameters we can choose freely.

In [None]:
# SuperRes
BATCH_SIZE = 16  # images per batch (training and validation)
SCALE      = 4   # upscaling factor between low- and high-resolution images

# DataSet
LABEL_SIZE = 96  # side length, in pixels, of the high-resolution crops

# Model
TRAINING_STEPS = 25000
STEPSxEPOCHS   = 1000                            # steps per epoch passed to model.fit
EPOCHS         = TRAINING_STEPS // STEPSxEPOCHS  # number of epochs needed to reach TRAINING_STEPS
# Upscaling factor -> number of upsampling stages.
# NOTE(review): not used in this notebook; presumably consumed by the architecture code — verify.
UPSAMPLESxSCALE = {
    2: 1,
    4: 2,
    8: 3
}

# Random Seed
AUTOTUNE = tf.data.AUTOTUNE
SEED     = 42
# Seeds Python's `random`, NumPy and TensorFlow in a single call, so data-augmentation
# helpers stay reproducible no matter which of these generators they draw from.
tf_keras.utils.set_random_seed(SEED)

# **Load Data**

In [None]:
# Download DIV2K from TF Datasets.
# The bicubic degradation config is derived from SCALE, so the LR/HR pairs always match the
# upscaling factor used for cropping and by the model (bicubic_x4 when SCALE == 4).
assert SCALE in (2, 3, 4, 8), f"DIV2K has no bicubic config for SCALE={SCALE}"
div2k_data = tfds.image.Div2k(config=f"bicubic_x{SCALE}")
div2k_data.download_and_prepare()

# Training split; as_supervised=True yields (low_res, high_res) tuples
train = div2k_data.as_dataset(split="train", as_supervised=True)
# Validation split
val = div2k_data.as_dataset(split="validation", as_supervised=True)

# Cache the datasets (in memory, since no filename is given) so later epochs skip re-reading them
train_cache = train.cache()
val_cache = val.cache()

**Visualize Training Data**

In [None]:
# Show one raw (uncropped, un-augmented) training pair: the low-resolution input
# followed by its high-resolution target.
samples = train.take(1)
ls = []
for sample in samples:
  lr, hr = sample
  ls.extend([lr, hr])
us.plot_images(ls)

# **Create Dataset & Preprocess Data**

In [None]:
def dataset_object(dataset_cache, training=True, batch_size=None, hr_crop_size=None, scale=None):
  """Build a batched, prefetched input pipeline of (low_res, high_res) image pairs.

  Args:
    dataset_cache: tf.data.Dataset yielding (lowres, highres) image pairs.
    training: if True, apply the augmentations and repeat the dataset indefinitely.
    batch_size: images per batch; defaults to BATCH_SIZE.
    hr_crop_size: side length of the high-resolution crop; defaults to LABEL_SIZE.
    scale: upscaling factor between LR and HR images; defaults to SCALE.

  Returns:
    A tf.data.Dataset of batches; infinite when training=True, one pass otherwise.
  """
  # Resolve defaults at call time so edits to the config cell are picked up
  # without having to re-run this definition.
  if batch_size is None:
    batch_size = BATCH_SIZE
  if hr_crop_size is None:
    hr_crop_size = LABEL_SIZE
  if scale is None:
    scale = SCALE

  # Crop aligned LR/HR patches.
  # NOTE(review): this crop is presumably random, so validation metrics will vary from
  # epoch to epoch; consider a deterministic crop for the validation pipeline.
  ds = dataset_cache.map(
      lambda lowres, highres: us.random_crop(lowres, highres, hr_crop_size=hr_crop_size, scale=scale),
      num_parallel_calls=AUTOTUNE,
  )

  # Augment Data (training only)
  if training:
    ds = ds.map(us.random_rotate, num_parallel_calls=AUTOTUNE)
    ds = ds.map(us.random_flip, num_parallel_calls=AUTOTUNE)
    ds = ds.map(us.random_lr_jpeg_noise, num_parallel_calls=AUTOTUNE)

  # Batching Data
  ds = ds.batch(batch_size)

  # Repeat the training data so the dataset's cardinality becomes infinite;
  # model.fit's steps_per_epoch then defines what an epoch is.
  if training:
    ds = ds.repeat()

  # Prefetching lets later batches be prepared while the current one is being processed
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds

train_ds = dataset_object(train_cache, training=True)
val_ds = dataset_object(val_cache, training=False)

**Visualize the Data**

In [None]:
def _show_image_grid(images, title, n_rows=3, n_cols=3):
    """Plot the first n_rows * n_cols images of a batch in a grid, each titled with its shape.

    Works for batches with fewer images than grid cells: unused cells are left blank.
    Returns the matplotlib Figure.
    """
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10))
    fig.suptitle(title)
    for ax in axes.flat:
        ax.axis("off")
    # zip stops at the shorter sequence, so a small batch cannot raise IndexError
    for ax, image in zip(axes.flat, images):
        ax.imshow(image.numpy().astype("uint8"))
        ax.set_title(str(image.shape))
    return fig

# One preprocessed (cropped + augmented) training batch
lowres, highres = next(iter(train_ds))

_show_image_grid(highres, "High Resolution Images");
_show_image_grid(lowres, "Low Resolution Images");


# **Model**

**1. Create the Model**

In [None]:
# Pretrained SRResNet generator, frozen (trainable = False) so training the wrapper
# model below does not update its weights.
original_model = arcs.SRResNet_Original()
original_model.load_weights("/content/generator.h5")  # NOTE(review): Colab-style path; adjust when running elsewhere
original_model.trainable = False

# Build the "vanilla" SRResNet variant on top of the frozen generator; this is the model trained below.
model = arcs.SRResNet_Vanilla(original_model)
# The name appears to encode the run configuration; it is set through the private
# attribute because Keras exposes `name` as a read-only property.
model._name = "SRResNet_Vanilla_fit25_lr0.0001"

**2. Compile the Model**

In [None]:
# Choose the loss and the optimizer for the model:
# mean squared error between predicted and target images, minimized with Adam (lr = 1e-4).
my_loss = MeanSquaredError()
my_opt = Adam(learning_rate=0.0001)

model.compile(
    optimizer=my_opt,
    loss=my_loss,
    metrics=[us.PSNR_metric, us.SSIM_metric],
)

In [None]:
# Plot the model: print the layer summary, then draw the architecture graph
# left-to-right with shapes shown (us.plot_model is presumably a Keras plot_model wrapper).
model.summary()
us.plot_model(model, rankdir="LR", show_shapes=True)

**3. Fit the Model**

In [None]:

# Define TensorBoard Callbacks.
# Each run logs to a directory named after the model, whose name encodes the run settings.
# (Previously this used `lr`, which at this point is the leftover low-resolution image
# tensor from the "Visualize Training Data" cell, not a learning rate.)
tb_callback = tf_keras.callbacks.TensorBoard(log_dir="logs/" + model.name, histogram_freq=1)

In [None]:
# Fit: train_ds repeats forever, so steps_per_epoch is what delimits an epoch.
# val_ds is finite and no validation_steps is given, so it is evaluated in full after each epoch.
# The returned History is kept so per-epoch loss and metrics can be inspected or plotted later.
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPSxEPOCHS,
    validation_data=val_ds,
    callbacks=[tb_callback],
)

# **References**

[1] image-super-resolution - "https://github.com/jlaihong/image-super-resolution"
[2] Implementing SRResnet/SRGAN Super-Resolution with Tensorflow - "https://medium.com/analytics-vidhya/implementing-srresnet-srgan-super-resolution-with-tensorflow-89900d2ec9b2"
[3] Image Super Resolution: SRResNet and SRGAN TensorFlow 2 implementation and model intuition - "https://www.youtube.com/watch?v=FwvTsx_dxn8&list=WL&index=18"