# Image Super-Resolution using an Efficient Sub-Pixel CNN (ESPCN)

<img src="https://i.imgur.com/Wsnp5mR.png" width=1000/>

- [source paper](https://arxiv.org/abs/1609.05158)
- [reference source](https://keras.io/examples/vision/super_resolution_sub_pixel/)

In [None]:
import os
import math
import cv2
import random
import numpy as np

import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.preprocessing.image import img_to_array

from IPython.display import display

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

import PIL

# Load data: BSDS500 dataset

In [None]:
# BSDS500 archive. get_file downloads it (or reuses a cached copy) and,
# because untar=True, returns the path of the extracted "BSR" directory.
dataset_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
data_dir = tf.keras.utils.get_file(
    fname="BSR",
    origin=dataset_url,
    untar=True,
)
# All image splits (train/val/test) live under <root_dir>/images/<split>.
root_dir = os.path.join(data_dir, "BSDS500/data")

In [None]:
# Side length, in pixels, of the square high-resolution target images.
crop_size = 300
# Factor by which the model enlarges each spatial dimension.
upscale_factor = 3
# Side length of the square low-resolution model inputs.
input_size = crop_size // upscale_factor
# Number of (input, target) pairs per batch.
batch_size = 8

In [None]:
# Sorted full paths of every test-split file whose name contains ".jpg";
# used for the per-epoch preview image and the final evaluation.
dataset = os.path.join(root_dir, 'images', 'test')
test_img_paths = sorted(
    os.path.join(dataset, name)
    for name in os.listdir(dataset)
    if '.jpg' in name
)

# Create datasets: resize images to fixed-size targets and downsample them into model inputs

In [None]:
def process_input(inputs, input_size, upscale_factor):
    """Downsample ``inputs`` to a square ``input_size`` x ``input_size`` image.

    Area interpolation is used, which averages source pixels and so acts as
    a simple anti-aliasing filter when shrinking.

    ``upscale_factor`` is accepted for call-site compatibility but is not
    used; the target size is given directly by ``input_size``.
    """
    target_size = [input_size] * 2
    return tf.image.resize(inputs, target_size, method="area")


def data_generater(dataset):
    """Yield ``(low_res, high_res)`` image pairs from one BSDS500 split.

    Args:
        dataset: Name of the split directory under ``<root_dir>/images``
            (e.g. ``'train'`` or ``'val'``). ``tf.data.Dataset.from_generator``
            hands ``args`` to the generator as ``bytes``; plain ``str`` is
            accepted as well.

    Yields:
        ``(low_res, high_res)`` where ``high_res`` is the RGB image resized
        (not cropped; aspect ratio is not preserved) to
        ``crop_size`` x ``crop_size`` and scaled to [0, 1] as float32, and
        ``low_res`` is ``process_input`` applied to it (area-downsampled to
        ``input_size`` x ``input_size``).

    Files are visited in a freshly shuffled order on every pass. Files that
    OpenCV cannot decode are skipped instead of aborting the input pipeline.
    """
    if isinstance(dataset, bytes):
        dataset = dataset.decode()
    split_dir = os.path.join(root_dir, 'images', dataset)

    datalist = [file for file in os.listdir(split_dir) if '.jpg' in file]
    random.shuffle(datalist)

    for file in datalist:
        image = cv2.imread(os.path.join(split_dir, file))
        if image is None:
            # cv2.imread signals an unreadable/corrupt file by returning None
            # rather than raising; cvtColor would then fail with a cryptic error.
            continue
        # OpenCV decodes to BGR channel order; the model works in RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (crop_size, crop_size))
        # Produce float32 directly to match the dataset's output_signature.
        image = image.astype("float32") / 255.0
        yield process_input(image, input_size, upscale_factor), image

In [None]:
# Every element is a (low-res input, high-res target) pair of 3-channel
# float32 images whose height and width are left unspecified.
_pair_signature = (
    tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
)

train_ds, valid_ds = [
    tf.data.Dataset.from_generator(
        data_generater,
        output_signature=_pair_signature,
        args=[split_name],
    )
    for split_name in ('train', 'val')
]

In [None]:
# Group pairs into batches of `batch_size` and keep up to 32 batches
# prepared ahead of the consumer.
train_ds, valid_ds = [
    ds.batch(batch_size).prefetch(buffer_size=32)
    for ds in (train_ds, valid_ds)
]

In [None]:
# Show one training batch: first every low-res input, then every high-res
# target, printing each image's shape before displaying it.
for batch in train_ds.take(1):
    for image_stack in batch:
        for picture in image_stack:
            print(picture.shape)
            display(array_to_img(picture))

# Build a model

In [None]:
def get_model(upscale_factor=3, channels=3):
    """Build the ESPCN super-resolution network.

    Four ReLU convolutions with 'same' padding keep the spatial size of the
    input. The last one emits ``channels * upscale_factor**2`` feature maps,
    which ``tf.nn.depth_to_space`` rearranges into an image
    ``upscale_factor`` times larger in each spatial dimension.

    Args:
        upscale_factor: Spatial enlargement factor.
        channels: Number of image channels in both input and output.

    Returns:
        A ``tf.keras.Model`` accepting inputs of any height and width.
    """
    # (number of filters, kernel side) for each convolution, in order.
    conv_specs = [
        (64, 5),
        (64, 3),
        (32, 3),
        (channels * (upscale_factor ** 2), 3),
    ]

    inputs = tf.keras.Input(shape=(None, None, channels))
    features = inputs
    for n_filters, kernel in conv_specs:
        conv = tf.keras.layers.Conv2D(n_filters, (kernel, kernel),
                                      activation='relu', padding='same')
        features = conv(features)

    # Sub-pixel shuffle: move channel blocks into spatial positions.
    outputs = tf.nn.depth_to_space(features, upscale_factor)
    return tf.keras.Model(inputs, outputs)


model = get_model(upscale_factor=upscale_factor, channels=3)

# Define callbacks to monitor training

In [None]:
def plot_results(img, prefix, title, zoom_region=(200, 300, 100, 200)):
    """Plot ``img`` with a 2x zoomed inset and save it as ``<prefix>-<title>.png``.

    Args:
        img: Anything ``img_to_array`` accepts (PIL image or array) with pixel
            values in [0, 255].
        prefix: Filename prefix (e.g. an epoch label or image index); it is
            converted with ``str``.
        title: Figure title, also used as the filename suffix.
        zoom_region: ``(x1, x2, y1, y2)`` pixel limits of the inset. Because
            the image is drawn row-flipped with ``origin="lower"``, ``y`` is
            measured from the bottom edge of the image. The default matches
            the previously hard-coded region.
    """
    img_array = img_to_array(img).astype("float32") / 255.0

    fig, ax = plt.subplots()
    # Flipping rows and drawing with origin="lower" looks like a normal
    # imshow, but the y axis grows upward, which is how zoom_region is given.
    ax.imshow(img_array[::-1], origin="lower")
    ax.set_title(title)

    # zoom-factor: 2.0, location: upper-left
    axins = zoomed_inset_axes(ax, 2, loc=2)
    axins.imshow(img_array[::-1], origin="lower")

    x1, x2, y1, y2 = zoom_region
    axins.set_xlim(x1, x2)
    axins.set_ylim(y1, y2)
    # Hide the inset's tick labels explicitly rather than via pyplot's
    # implicit "current axes".
    axins.tick_params(labelleft=False, labelbottom=False)

    # Draw the box and connector lines linking the inset to its source area.
    mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")
    fig.savefig(str(prefix) + "-" + title + ".png")
    plt.show()
    # Release the figure: with non-interactive backends plt.show() does not,
    # and repeated calls (e.g. from the training callback) would pile up.
    plt.close(fig)


def get_lowres_image(img, upscale_factor):
    """Shrink a PIL image by ``upscale_factor`` with bicubic resampling.

    Width and height are each floor-divided by ``upscale_factor``; the
    result serves as the low-resolution model input.
    """
    width, height = img.size
    target_size = (width // upscale_factor, height // upscale_factor)
    return img.resize(target_size, PIL.Image.Resampling.BICUBIC)


def upscale_image(model, img):
    """Super-resolve one image with ``model`` and return pixels in [0, 255].

    ``img`` is converted with ``img_to_array``, scaled to [0, 1], given a
    leading batch axis of size 1 and passed to ``model.predict``. The single
    prediction is rescaled to [0, 255] and clipped to that range.
    """
    normalized = img_to_array(img).astype("float32") / 255.0
    single_batch = normalized[np.newaxis, ...]
    predicted = model.predict(single_batch)[0]
    return np.clip(predicted * 255.0, 0, 255)

In [None]:
class ESPCNCallback(tf.keras.callbacks.Callback):
    """Print mean validation PSNR each epoch and plot a preview every 20 epochs.

    The preview input is the first test image, downscaled once at
    construction time with ``get_lowres_image``.
    """

    def __init__(self):
        super().__init__()
        source_img = load_img(test_img_paths[0])
        self.test_img = get_lowres_image(source_img, upscale_factor)

    def on_epoch_begin(self, epoch, logs=None):
        # Fresh PSNR accumulator for this epoch's validation batches.
        self.psnr = []

    def on_epoch_end(self, epoch, logs=None):
        print("Mean PSNR for epoch: %.2f" % (np.mean(self.psnr)))
        if epoch % 20:
            return
        prediction = upscale_image(self.model, self.test_img)
        plot_results(prediction, "epoch-" + str(epoch), "prediction")

    def on_test_batch_end(self, batch, logs=None):
        # PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1 for [0, 1] images; the
        # compiled loss in this notebook is mean squared error.
        # NOTE(review): depending on the Keras version, logs["loss"] here may
        # be the running mean over batches so far, not this batch's loss.
        mse = logs["loss"]
        self.psnr.append(10 * math.log10(1 / mse))

In [None]:
# Stop once the training loss has not improved for 10 consecutive epochs.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="loss",
    patience=10,
)

# Keep only the weights with the lowest training loss seen so far.
checkpoint_filepath = "/tmp/checkpoint"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor="loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
)

# Fresh, untrained model for the training run below.
model = get_model(upscale_factor=upscale_factor, channels=3)
model.summary()

callbacks = [
    ESPCNCallback(),
    early_stopping_callback,
    model_checkpoint_callback,
]

loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Train the model

In [None]:
epochs = 100

model.compile(loss=loss_fn, optimizer=optimizer)

model.fit(
    train_ds,
    epochs=epochs,
    validation_data=valid_ds,
    callbacks=callbacks,
    verbose=2,
)

# Restore the weights the checkpoint callback saved at the lowest training loss.
model.load_weights(checkpoint_filepath)

# Run model prediction and plot the results

In [None]:
total_bicubic_psnr = 0.0
total_test_psnr = 0.0

# Fixed evaluation subset of the test split; the averages below divide by its
# actual size rather than a hard-coded count.
eval_img_paths = test_img_paths[50:60]

for index, test_img_path in enumerate(eval_img_paths):
    img = load_img(test_img_path)
    lowres_input = get_lowres_image(img, upscale_factor)
    # Size of the model output (low-res size times the upscale factor).
    w = lowres_input.size[0] * upscale_factor
    h = lowres_input.size[1] * upscale_factor
    # Reference image resized (not cropped) to exactly the output size so the
    # PSNR comparisons are pixel-aligned.
    highres_img = img.resize((w, h))
    prediction = upscale_image(model, lowres_input)
    # Baseline: the low-res input upsampled with PIL's default resize filter.
    lowres_img = lowres_input.resize((w, h))

    lowres_img_arr = img_to_array(lowres_img)
    highres_img_arr = img_to_array(highres_img)
    predict_img_arr = img_to_array(prediction)
    # Convert the scalar tensors to Python floats for accumulation/printing.
    bicubic_psnr = float(
        tf.image.psnr(lowres_img_arr, highres_img_arr, max_val=255))
    test_psnr = float(
        tf.image.psnr(predict_img_arr, highres_img_arr, max_val=255))

    total_bicubic_psnr += bicubic_psnr
    total_test_psnr += test_psnr

    print("PSNR of low resolution image and high resolution image is %.4f"
          % bicubic_psnr)
    print("PSNR of predict and high resolution is %.4f" % test_psnr)
    plot_results(lowres_img, index, "lowres")
    plot_results(highres_img, index, "highres")
    plot_results(prediction, index, "prediction")

num_eval_images = len(eval_img_paths)
if num_eval_images:
    print("Avg. PSNR of lowres images is %.4f"
          % (total_bicubic_psnr / num_eval_images))
    print("Avg. PSNR of reconstructions is %.4f"
          % (total_test_psnr / num_eval_images))
else:
    print("No images in test_img_paths[50:60]; skipping average PSNR.")

In [None]:
# Persist the trained network in HDF5 format. NOTE: despite the "SRCNN"
# filename, this is the ESPCN model built by get_model.
model.save(filepath='SRCNN_rgb.h5')