# **Imports & Hyperparameters**

In [None]:
import tf_keras
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

import main.utilities.utils as us
import main.models.III_EDSR.architectures.architectures_edsr as arcs

from tf_keras.layers import *
from tf_keras.optimizers import *
from tf_keras.utils import plot_model
from tf_keras.preprocessing.image import *

**Hyperparameters** \
These are the hyperparameters we are free to choose.

In [None]:
# ---- Super-resolution setup ----
BATCH_SIZE = 16   # (lowres, highres) pairs per batch
SCALE = 4         # upscaling factor passed to the random-crop helper

# ---- Dataset ----
LABEL_SIZE = 96   # high-resolution crop size passed to us.random_crop(hr_crop_size=...)

# ---- Training ----
EPOCHS = 50

# ---- Reproducibility ----
SEED = 42
tf.random.set_seed(SEED)  # global TensorFlow seed only

# **Load Data**

In [None]:
# Fetch DIV2K through TensorFlow Datasets, using the bicubic x4 degradation
# config (low-res inputs are bicubic-downsampled by a factor of 4).
div2k_data = tfds.image.Div2k(config="bicubic_x4")
div2k_data.download_and_prepare()

# Both splits are loaded as (lowres, highres) tuples (as_supervised=True).
train = div2k_data.as_dataset(split="train", as_supervised=True)
val = div2k_data.as_dataset(split="validation", as_supervised=True)

# Cache each split so later epochs do not re-read and re-decode the files.
train_cache, val_cache = train.cache(), val.cache()

**Inspect Dataset Sizes** (number of training and validation image pairs)

In [None]:
tuple(len(split) for split in (train, val))  # (n_train, n_val)

# **Create Dataset & Preprocess Data**

In [None]:
def dataset_object(dataset_cache, training=True, batch_size=None,
                   hr_crop_size=None, scale=None,
                   num_parallel_calls=tf.data.AUTOTUNE):
    """Build a batched, prefetched tf.data pipeline of (lowres, highres) pairs.

    Every pair is passed through ``us.random_crop``. When ``training`` is True
    the crops additionally go through ``us.random_rotate`` and
    ``us.flip_left_right``, and the dataset is repeated forever (so
    ``model.fit`` needs an explicit ``steps_per_epoch``).

    Args:
        dataset_cache: ``tf.data.Dataset`` yielding ``(lowres, highres)``
            tuples (e.g. a cached DIV2K split loaded with ``as_supervised=True``).
        training: If True, augment and repeat indefinitely; if False, produce
            a single finite pass (used for validation).
        batch_size: Pairs per batch. Defaults to the global ``BATCH_SIZE``.
        hr_crop_size: High-resolution crop size forwarded to
            ``us.random_crop``. Defaults to the global ``LABEL_SIZE``.
        scale: Upscaling factor forwarded to ``us.random_crop``. Defaults to
            the global ``SCALE``.
        num_parallel_calls: Parallelism of the ``map`` stages. Defaults to
            ``tf.data.AUTOTUNE`` so tf.data picks it at runtime.

    Returns:
        A batched, prefetched ``tf.data.Dataset`` of ``(lowres, highres)``.
    """
    # Resolve defaults at call time. This matches the previous behaviour of
    # reading the globals directly inside the body.
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    hr_crop_size = LABEL_SIZE if hr_crop_size is None else hr_crop_size
    scale = SCALE if scale is None else scale

    # Random crops are applied to both training and validation data.
    ds = dataset_cache.map(
        lambda lowres, highres: us.random_crop(
            lowres, highres, hr_crop_size=hr_crop_size, scale=scale
        ),
        num_parallel_calls=num_parallel_calls,
    )

    # Augmentation only for the training split.
    if training:
        ds = ds.map(us.random_rotate, num_parallel_calls=num_parallel_calls)
        ds = ds.map(us.flip_left_right, num_parallel_calls=num_parallel_calls)

    ds = ds.batch(batch_size)

    # Repeat the training data so the dataset's cardinality becomes infinite.
    if training:
        ds = ds.repeat()

    # Prefetch so later batches are prepared while the current one is consumed.
    # The buffer counts batches (prefetch runs after batch). AUTOTUNE replaces
    # the former BATCH_SIZE // 2, which coupled the buffer to the batch size
    # and would request a zero-sized buffer for BATCH_SIZE == 1.
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    return ds

# Training pipeline: augmented and repeated. Validation pipeline: one finite pass.
train_ds = dataset_object(train_cache)
val_ds = dataset_object(val_cache, training=False)

**Visualize the Data**

In [None]:
def show_image_grid(images, title, rows=3, cols=3):
    """Show the first ``rows * cols`` images of a batch on a grid.

    Each subplot is titled with that image's shape, so the HR/LR size
    difference is visible at a glance. Grid cells beyond the batch size are
    left blank instead of raising an IndexError.

    Args:
        images: Batched image tensor (eager). Values are cast to uint8 for
            display, which assumes they are already in the 0-255 range
            (TODO confirm against us.random_crop's output dtype/range).
        title: Figure-level title.
        rows, cols: Grid layout.

    Returns:
        The created matplotlib Figure.
    """
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10))
    fig.suptitle(title)
    n_images = images.shape[0]
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i < n_images:
            img = images[i]
            ax.imshow(img.numpy().astype("uint8"))
            ax.set_title(str(img.shape))
    return fig


# Take one batch from the (augmented) training pipeline.
lowres, highres = next(iter(train_ds))

show_image_grid(highres, "High-resolution crops");
show_image_grid(lowres, "Low-resolution crops");

# **Model**

**1. Create the Model**

In [None]:
# Build the EDSR model: 64 filters, 16 residual blocks.
model = arcs.EDSR_Vanilla(num_filters=64, num_of_residual_blocks=16)

# Run name, also used as the TensorBoard log sub-directory. It is derived from
# EPOCHS so the "fitN" label cannot go stale when the training length changes.
# NOTE: Keras exposes `name` as a read-only property, so the private `_name`
# attribute is assigned directly to rename the already-built model.
model._name = f"EDSR_Vanilla_fit{EPOCHS}"

**2. Compile the Model**

In [None]:
# Choose the loss and optimizer for the model.

# L1 loss (mean absolute error between predicted and target HR images).
my_loss = "mae"

# Piecewise-constant learning rate, measured in optimizer steps:
# 1e-4 until step LR_BOUNDARY_STEP, then 5e-5. (5e-5 is the same float the
# former expression `0.0001 - 0.00005` produced.)
LR_BOUNDARY_STEP = 5000
LR_VALUES = [1e-4, 5e-5]

my_opt = Adam(
    learning_rate=tf_keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[LR_BOUNDARY_STEP], values=LR_VALUES
    )
)

model.compile(
    loss=my_loss,
    optimizer=my_opt,
    metrics=[us.PSNR_metric, us.SSIM_metric],
)

In [None]:
# Print the layer summary, then render the model graph left-to-right with tensor shapes.
model.summary()
plot_model(model, rankdir="LR", show_shapes=True)

**3. Fit the Model**

In [None]:
# TensorBoard callback: each run logs into its own sub-directory named after
# the model; weight histograms are written every epoch.
tb_callback = tf_keras.callbacks.TensorBoard(
    log_dir=f"logs/{model.name}",
    histogram_freq=1,
)

In [None]:

# Fit. The training pipeline repeats indefinitely, so steps_per_epoch defines
# how many batches make up one epoch.
STEPS_PER_EPOCH = 200

# Keep the History object so loss/metric curves can be inspected after training.
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val_ds,
    callbacks=[tb_callback],
)

# **References**

[1] Image Super-Resolution Using EDSR and WDSR - "https://sumittagadiya.medium.com/image-super-resolution-using-edsr-and-wdsr-f4de0b00e039"
[2] Enhanced Deep Residual Networks for single-image super-resolution - "https://keras.io/examples/vision/edsr/"