In [None]:
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers

AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Report the library versions in use and the GPUs TensorFlow can see.
print(f"Tensorflow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(gpu_devices)

# Download the DIV2K Dataset

In [None]:
# Load the DIV2K bicubic x8 splits; with as_supervised=True every example is
# a (low-res, high-res) image pair.
DIV2K_CONFIG = 'div2k/bicubic_x8'

train, train_info = tfds.load(
    DIV2K_CONFIG, split='train', as_supervised=True, with_info=True
)
val, val_info = tfds.load(
    DIV2K_CONFIG, split='validation', as_supervised=True, with_info=True
)

# Cached views of both splits, used to build the input pipelines below.
train_cache = train.cache()
val_cache = val.cache()

# Explore the DIV2K Dataset

In [None]:
# Report how many examples each split contains.
num_train_examples = train_info.splits['train'].num_examples
num_val_examples = val_info.splits['validation'].num_examples

print(f"Number of training examples: {num_train_examples}")
print(f"Number of validation examples: {num_val_examples}")

In [None]:
# Preview a few (low-res, high-res) pairs from the training split.
# Read from `train` rather than `train_cache`: iterating only part of a cached
# dataset makes tf.data throw the partial cache away and log a warning.
NUM_PREVIEW_PAIRS = 6
samples = list(train.take(NUM_PREVIEW_PAIRS).as_numpy_iterator())

# One row per pair: low-res on the left, high-res on the right.
fig, axes = plt.subplots(len(samples), 2, figsize=(20, 38), squeeze=False)

for (lowres_ax, highres_ax), (lowres, highres) in zip(axes, samples):
    # Low-resolution (LR) image; axes stay on so the pixel dimensions are visible.
    lowres_ax.imshow(lowres.astype("uint8"))
    lowres_ax.set_title(f"LR {lowres.shape}", fontsize=14)
    lowres_ax.axis("on")

    # High-resolution (HR) image
    highres_ax.imshow(highres.astype("uint8"))
    highres_ax.set_title(f"HR {highres.shape}", fontsize=14)
    highres_ax.axis("on")

fig.tight_layout()
plt.show()


# Image Augmentation

In [None]:
def flip_left_right(lowres_img, highres_img):
    """Randomly mirror a low-res/high-res image pair horizontally.

    A single uniform draw decides for both images, so the pair stays aligned:
    a draw below 0.5 returns the pair unchanged, otherwise both images are
    flipped left-to-right.
    """
    coin = tf.random.uniform(shape=(), maxval=1)

    def _keep():
        return lowres_img, highres_img

    def _mirror():
        return (
            tf.image.flip_left_right(lowres_img),
            tf.image.flip_left_right(highres_img),
        )

    return tf.cond(coin < 0.5, true_fn=_keep, false_fn=_mirror)


def random_rotate(lowres_img, highres_img):
    """Rotate both images by the same random multiple of 90 degrees.

    The number of quarter turns is drawn once (an integer in 0..3) and applied
    to the low-res and high-res image alike so they remain aligned.
    """
    quarter_turns = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)

    rotated_lowres = tf.image.rot90(lowres_img, quarter_turns)
    rotated_highres = tf.image.rot90(highres_img, quarter_turns)
    return rotated_lowres, rotated_highres


def random_crop(lowres_img, highres_img, hr_crop_size=224, scale=8):
    """Cut matching random square patches out of a low-res/high-res pair.

    The crop origin is drawn in low-res pixel coordinates and multiplied by
    `scale` to address the same region of the high-res image.

    Args:
        lowres_img: Low-resolution image; its first two axes are indexed as
            (height, width).
        highres_img: High-resolution counterpart of `lowres_img`.
        hr_crop_size: Side length of the high-res patch.
        scale: Resolution ratio between the two images.

    Returns:
        Tuple `(lowres_patch, highres_patch)` with side lengths
        `hr_crop_size // scale` and `hr_crop_size` respectively.
    """
    lr_patch = hr_crop_size // scale
    lr_dims = tf.shape(lowres_img)[:2]  # (height, width)

    # maxval is exclusive, hence the +1 so the last valid offset can be drawn.
    lr_x = tf.random.uniform(
        shape=(), maxval=lr_dims[1] - lr_patch + 1, dtype=tf.int32
    )
    lr_y = tf.random.uniform(
        shape=(), maxval=lr_dims[0] - lr_patch + 1, dtype=tf.int32
    )

    # Same origin expressed in high-res pixels.
    hr_x = lr_x * scale
    hr_y = lr_y * scale

    lowres_patch = lowres_img[lr_y : lr_y + lr_patch, lr_x : lr_x + lr_patch]
    highres_patch = highres_img[
        hr_y : hr_y + hr_crop_size, hr_x : hr_x + hr_crop_size
    ]
    return lowres_patch, highres_patch

# Create the TensorFlow Dataset

In [None]:
batch_size = 10


def dataset_object(dataset_cache, training=True, hr_crop_size=224, scale=8, batch=None):
    """Turn a dataset of (low-res, high-res) image pairs into batched patches.

    Every pair is randomly cropped to aligned patches. Training pipelines
    additionally apply random rotation and horizontal flipping and repeat
    indefinitely.

    Note that the random crop is applied to validation data as well, so the
    validation patches (and therefore the validation loss) vary between
    evaluation passes.

    Args:
        dataset_cache: Dataset yielding (low-res, high-res) image pairs.
        training: When True, add augmentation and repeat the dataset forever.
        hr_crop_size: Side length of the high-res patches.
        scale: Resolution ratio between high-res and low-res images.
        batch: Number of pairs per batch; defaults to the module-level
            `batch_size` when left as None.

    Returns:
        The batched, prefetched dataset.
    """
    if batch is None:
        batch = batch_size

    ds = dataset_cache.map(
        lambda lowres, highres: random_crop(
            lowres, highres, hr_crop_size=hr_crop_size, scale=scale
        ),
        num_parallel_calls=AUTOTUNE,
    )

    if training:
        ds = ds.map(random_rotate, num_parallel_calls=AUTOTUNE)
        ds = ds.map(flip_left_right, num_parallel_calls=AUTOTUNE)

    # Batching Data
    ds = ds.batch(batch)

    if training:
        # Repeat after batching so the training dataset never runs out.
        ds = ds.repeat()

    return ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
# Build the training and validation pipelines from the cached datasets
# (`train_cache` / `val_cache` come from calling `.cache()` on the loaded splits).
train_ds = dataset_object(train_cache, training=True)
val_ds = dataset_object(val_cache, training=False)

# Visualize the Data

Let's take a look at some high-res image patches and their corresponding low-res patches. Because both are drawn with the same figure size in `matplotlib.pyplot`, we can compare them at the same displayed size and see how pixelated the low-res versions are.

In [None]:
# Pull one batch of augmented (low-res, high-res) patches from the training pipeline.
lowres, highres = next(iter(train_ds))

# Draw the first 9 patches of each batch in its own 3x3 figure:
# high-resolution patches first, then the matching low-resolution patches.
for patch_batch in (highres, lowres):
    plt.figure(figsize=(20, 20))
    for idx in range(9):
        patch = patch_batch[idx]
        plt.subplot(3, 3, idx + 1)
        plt.imshow(patch.numpy().astype("uint8"))
        plt.title(patch.shape)
        plt.axis("off")

# Construct the Model

In [None]:
# Residual Block
def ResBlock(inputs, num_filters=64):
    """Apply conv -> ReLU -> conv and add the result back onto `inputs`.

    Args:
        inputs: Feature map; must have `num_filters` channels so that the
            residual addition is shape-compatible.
        num_filters: Number of filters in both convolutions.

    Returns:
        The feature map after the residual addition.
    """
    x = layers.Conv2D(num_filters, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(num_filters, 3, padding="same")(x)
    return layers.Add()([inputs, x])


# Upsampling Block
def Upsampling(inputs, factor=2, num_filters=64, num_stages=3, **kwargs):
    """Upsample a feature map with repeated conv + depth-to-space stages.

    Each stage widens the channels to `num_filters * factor**2` and then
    rearranges them into a `factor`-times larger spatial grid, leaving
    `num_filters` channels. With the defaults (three stages of factor 2) the
    overall spatial upscaling is 8x.

    Args:
        inputs: Feature map to upsample.
        factor: Spatial upscaling factor of a single stage.
        num_filters: Channel count after each stage.
        num_stages: Number of conv + depth-to-space stages to stack.
        **kwargs: Extra keyword arguments forwarded to every `Conv2D`.

    Returns:
        The upsampled feature map.
    """
    x = inputs
    for _ in range(num_stages):
        x = layers.Conv2D(num_filters * (factor ** 2), 3, padding="same", **kwargs)(x)
        # Wrapped in a Lambda layer because Keras 3 does not allow symbolic
        # Keras tensors to be passed straight into TensorFlow functions.
        x = layers.Lambda(lambda t: tf.nn.depth_to_space(t, block_size=factor))(x)
    return x


def make_model(num_filters, num_of_residual_blocks):
    """Build the super-resolution network as a functional Keras model.

    The model takes images with pixel values in the 0-255 range, rescales
    them to 0-1, runs a stack of residual blocks wrapped in a long skip
    connection, upsamples, and rescales the 3-channel output back to 0-255.

    Args:
        num_filters: Number of feature channels used throughout the network.
        num_of_residual_blocks: How many residual blocks to stack.

    Returns:
        An uncompiled `keras.Model`.
    """
    input_layer = layers.Input(shape=(None, None, 3))

    x = layers.Rescaling(scale=1.0 / 255)(input_layer)

    # `x` is kept untouched for the long skip connection; `x_new` runs
    # through the residual stack.
    x = x_new = layers.Conv2D(num_filters, 3, padding="same")(x)

    for _ in range(num_of_residual_blocks):
        x_new = ResBlock(x_new, num_filters=num_filters)

    x_new = layers.Conv2D(num_filters, 3, padding="same")(x_new)
    x = layers.Add()([x, x_new])

    x = Upsampling(x, num_filters=num_filters)

    x = layers.Conv2D(3, 3, padding="same")(x)

    output_layer = layers.Rescaling(scale=255)(x)

    return keras.Model(input_layer, output_layer)

model = make_model(num_filters=64, num_of_residual_blocks=16)

# Training Configuration

In [None]:
# Adam optimizer with an initial learning rate of 1e-4.
optim_edsr = keras.optimizers.Adam(learning_rate=1e-4)

# Halve the learning rate whenever the validation loss has not improved for
# 5 epochs, but never go below 1e-6.
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

## Define Custom PSNR Metric

As mentioned earlier, PSNR is a common metric to use for image super-resolution.  TensorFlow already offers a function under its `tf.image` module.

In [None]:
def PSNR(super_resolution, high_resolution):
    """Mean peak signal-to-noise ratio between two image batches (0-255 range).

    When used as a Keras metric the function is called as `(y_true, y_pred)`,
    so the arguments arrive in the opposite order to their names. PSNR is
    symmetric in its two images, so the value is unaffected.

    Args:
        super_resolution: Batch of images with pixel values in 0-255.
        high_resolution: Batch of images of the same shape, also in 0-255.

    Returns:
        Scalar tensor: the PSNR averaged over the batch.
    """
    # tf.image.psnr rescales integer inputs to [0, 1] but leaves float inputs
    # untouched, so mixing e.g. uint8 ground truth with float predictions would
    # silently give a wrong value. Cast both to float32 up front.
    super_resolution = tf.cast(super_resolution, tf.float32)
    high_resolution = tf.cast(high_resolution, tf.float32)

    return tf.reduce_mean(
        tf.image.psnr(high_resolution, super_resolution, max_val=255)
    )

Although PSNR will be our main metric, we'll use **Mean Absolute Error** (L1 loss) as our loss function. In theory, minimizing L2 loss (**MSE**) directly maximizes PSNR, since PSNR is a decreasing function of MSE. However, multiple papers have found empirically that using L1 loss instead results in more stable convergence and better overall results, so we'll do that here.

> 🤔 **Hmmm** Remember "compressed sensing" that I mentioned in the beginning of this notebook?  It's all about L1 techniques.  Coincidence!?  😏

In [None]:
# Train on mean absolute error (L1 loss) and report PSNR alongside it.
model.compile(
    optimizer=optim_edsr,
    loss="mae",
    metrics=[PSNR],
)

In [None]:
# Checkpoint callback: checked every epoch, it writes weights only, and only
# when the validation loss is the best seen so far.
best_weights_checkpoint_path = "best-model.weights.h5"

save_best_cb = keras.callbacks.ModelCheckpoint(
    filepath=best_weights_checkpoint_path,
    monitor="val_loss",
    save_freq="epoch",
    save_weights_only=True,
    save_best_only=True,
)

# Train!

Because we used the `.repeat()` method on our training dataset, it acts like a generator and yields batches indefinitely, so `.fit()` cannot tell on its own where an epoch ends. We therefore need to pass `steps_per_epoch` to tell it how many training steps (batches) of data we consider to be one epoch.

In [None]:
# Sanity check: push a single training batch through the model and print the
# shapes of the input, the ground truth and the model output.
for lr, hr in train_ds.take(1):
    sr = model(lr)
    for label, tensor in (
        ("Low-Resolution Input:", lr),
        ("High-Resolution Ground Truth:", hr),
        ("Super-Resolution Output:", sr),
    ):
        print(label, tensor.shape)


In [None]:
# The training dataset repeats indefinitely, so `steps_per_epoch` defines how
# many batches count as one epoch.
fit_callbacks = [save_best_cb, lr_scheduler]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=20,
    steps_per_epoch=400,
    callbacks=fit_callbacks,
)

## Load Best Training Weights

In [None]:
# Restore the weights stored at the checkpoint path, i.e. the ones the
# checkpoint callback kept for the lowest validation loss.
model.load_weights(best_weights_checkpoint_path)

In [None]:
# Save the model in its current state to an HDF5 (`.h5`) file.
model.save("best-model.h5")

In [None]:
def plot_results(lowres, preds):
    """Show a low-resolution image and its super-resolved version side by side.

    Args:
        lowres: Low-resolution input image, in a form `imshow` can draw.
        preds: Super-resolution prediction for `lowres`.
    """
    # Exactly two panels: the previous 1x3 layout left its first column empty.
    fig, (lowres_ax, preds_ax) = plt.subplots(1, 2, figsize=(24, 14))

    lowres_ax.imshow(lowres)
    lowres_ax.set_title("Low resolution")

    preds_ax.imshow(preds)
    preds_ax.set_title("Prediction")

    plt.show()

## Upscale Helper Function

In [None]:
def upscale_image(lowres):
    """Super-resolve one (H, W, C) image with the notebook's global `model`.

    The image gets a batch axis, is run through the model in inference mode,
    and the output is clipped to 0-255, rounded and cast to uint8 before the
    batch axis is removed again. The output size is whatever the model
    produces; for the model built in this notebook (three 2x upsampling
    stages) that is (8H, 8W, C).
    """
    batched = tf.expand_dims(lowres, axis=0)
    prediction = model(batched, training=False)

    # Map the float output onto valid 8-bit pixel values.
    pixels = tf.cast(tf.round(tf.clip_by_value(prediction, 0, 255)), tf.uint8)

    return tf.squeeze(pixels, axis=0)

## Upscale Eye Candy 😮

In [None]:
# Super-resolve the first 8 validation images at full size and show each one
# next to its low-resolution input.
for lowres, _ in val.take(8):
    SR = upscale_image(lowres)
    plot_results(lowres, SR)