In [None]:
import tensorflow as tf
from tensorflow import keras
from matplotlib import pyplot as plt
import numpy as np
import os

plt.rcParams["figure.figsize"] = (7, 7)

# Seed NumPy's and TensorFlow's global RNGs so the random file sample, the
# dataset shuffles and the weight initialisation repeat on Restart & Run All.
# (Some GPU kernels may still introduce small run-to-run differences.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Define the TensorFlow data pipeline

In [None]:
def load_img(file, channels=3):
    """Read an image file and return it as a float32 tensor scaled to [0, 1].

    Args:
        file: path to the image (a Python str or a scalar tf.string tensor).
        channels: number of colour channels to decode to. Defaults to 3 so that
            grayscale files are expanded to RGB, matching the 3-channel model
            inputs. Pass 0 to keep the channel count stored in the file.

    Returns:
        A (height, width, channels) float32 tensor with values in [0, 1].
    """
    raw = tf.io.read_file(file)
    # Without an explicit channel count decode_jpeg keeps the file's own layout,
    # so a single grayscale image would break batching and the (…, 3) model input.
    img = tf.io.decode_jpeg(raw, channels=channels)
    # convert_image_dtype rescales integer pixels (uint8 0..255) to floats in [0, 1].
    img = tf.image.convert_image_dtype(img, tf.float32)

    return img

def upsample(img, size):
    """Resize `img` to a square `size` x `size` image with bicubic interpolation.

    Bicubic interpolation can overshoot its input range, so the result is clipped
    to [0, 1] (assumes `img` holds values in [0, 1], as produced by load_img).
    """
    resized = tf.image.resize(
        img,
        size=(size, size),
        method=tf.image.ResizeMethod.BICUBIC,
        antialias=False,
        preserve_aspect_ratio=False,
    )
    return tf.clip_by_value(resized, clip_value_min=0.0, clip_value_max=1.0)


def pipeline(filename):
    """Map a bare image filename to its (model input, ground truth) image pair.

    The same filename is looked up in two directories: the module-level
    `Y_source` (model input) and `X_source` (target). Both must be defined
    before this function is first called.
    """
    return (
        load_img(Y_source + filename),  # model input
        load_img(X_source + filename),  # ground-truth target
    )

def PSNR(img, truth, max_val=1):
    """Peak signal-to-noise ratio, in dB, between two images.

    Used as a Keras metric: the function name produces the logged keys 'PSNR'
    and 'val_PSNR' that the callbacks and history plot read, so keep the name.
    PSNR is symmetric in its two image arguments, so Keras passing
    (y_true, y_pred) in that order is fine. `max_val` is the largest possible
    pixel value (1 for images scaled to [0, 1]).
    """
    return tf.image.psnr(img, truth, max_val=max_val)

## Define source directories and load the image file names

In [None]:
# Image directories: every filename sampled from Y_source must also exist in X_source.
Y_source = '../input/oxfordpet196/images-upsampled-196/'  # model inputs (low-res images)
X_source = '../input/oxfordpet196/images-cropped-196/'    # ground-truth high-res images
# NOTE(review): the models take input_size x input_size inputs, but this directory
# name suggests 196x196 images — confirm Y_source matches the model being trained.

# Location used by both ModelCheckpoint and the final model.save.
model_save_loc = './SRResNet-SubPixel/saved-model'

input_size = 96    # side length, in pixels, of the square model input
n_samples = 3_000  # number of files drawn at random from Y_source

In [None]:
# Collect the image filenames and draw a random subset of them.
# os.listdir returns entries in an arbitrary, filesystem-dependent order; sorting
# first makes the random draw depend only on the RNG state.
filenames = sorted(os.listdir(Y_source))
if n_samples > len(filenames):
    # np.random.choice(..., replace=False) cannot draw more items than exist.
    print(f'Only {len(filenames)} files in {Y_source}; using all of them instead of {n_samples}.')
rand_filenames = np.random.choice(filenames, min(n_samples, len(filenames)), replace=False)
filenames_dataset = tf.data.Dataset.from_tensor_slices(rand_filenames)

## Create image dataset using pipeline, shuffle it, split it, and batch it

In [None]:
# Split the sampled filenames into train / val / test BEFORE decoding any images.
# A take()/skip() split on a dataset shuffled with the default
# reshuffle_each_iteration=True is re-drawn on every pass, so each epoch (and each
# evaluation) would see different train/val/test members and test images would
# leak into training. Shuffling once with a fixed order and seed keeps the split stable.
dataset_size = len(filenames_dataset)
train_size = int(0.85 * dataset_size)
val_size = int(0.05 * dataset_size)  # informational: val receives whatever train/test leave over
test_size = int(0.1 * dataset_size)

split_seed = 42  # explicit so the split is identical even if tf.data rebuilds the graph
files_fixed = filenames_dataset.shuffle(
    max(1, dataset_size), seed=split_seed, reshuffle_each_iteration=False
)
files_train = files_fixed.take(train_size)
files_rest = files_fixed.skip(train_size)
files_test = files_rest.take(test_size)
files_val = files_rest.skip(test_size)

batch_size = 32


def _to_image_batches(files, shuffle):
    """Decode a filename dataset into batched, prefetched (input, target) image pairs."""
    if shuffle:
        # Reshuffling inside the fixed training split each epoch is safe, and it
        # only buffers filename strings, not decoded images.
        files = files.shuffle(max(1, len(files)))
    return (
        files.map(pipeline, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


# Full image dataset in the fixed split order (unbatched), kept for ad-hoc inspection.
dataset = files_fixed.map(pipeline)

dataset_train = _to_image_batches(files_train, shuffle=True)
dataset_val = _to_image_batches(files_val, shuffle=False)
dataset_test = _to_image_batches(files_test, shuffle=False)


## Define the model, or import a previously trained one

In [None]:
# Stop once validation PSNR has not improved for 15 epochs, rolling the model
# back to the best weights seen.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_PSNR',
    mode='max',
    patience=15,
    restore_best_weights=True,
)

# At the end of each epoch, save the model to model_save_loc when validation PSNR
# improves on the best value so far. initial_value_threshold seeds that "best"
# value, so nothing is saved until val_PSNR exceeds 23.0 dB.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=model_save_loc,
    monitor='val_PSNR',
    mode='max',
    save_best_only=True,
    save_freq='epoch',
    initial_value_threshold=23.0,
)

callbacks = [early_stopping, checkpoint]

In [None]:
from tensorflow.keras import layers


def res_net_block(x, filters, filter_size):
    """Two-convolution residual block: relu(conv(relu(conv(x))) + shortcut(x)).

    Args:
        x: input Keras tensor, channels-last (..., C).
        filters: number of output channels of both convolutions.
        filter_size: kernel size of both convolutions ('same' padding keeps H and W).

    Returns:
        A Keras tensor with the same spatial size as `x` and `filters` channels.

    When C already equals `filters` the shortcut is the identity. Otherwise a
    1x1 convolution projects `x` to `filters` channels, because layers.Add
    cannot add tensors whose channel counts differ.
    """
    shortcut = x
    in_channels = x.shape[-1]
    if in_channels is not None and in_channels != filters:
        shortcut = layers.Convolution2D(filters, 1, padding='same')(x)

    # Layer 1
    x = layers.Convolution2D(filters, filter_size, padding='same', activation='relu')(x)
    # Layer 2 (linear: the relu is applied after the residual addition)
    x = layers.Convolution2D(filters, filter_size, padding='same')(x)
    # Add the residual, then activate
    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)

    return x


def conv_block(x, filters, filter_size):
    """Apply a 'same'-padded convolution followed by a separate ReLU layer.

    Returns a tensor with the spatial size of `x` and `filters` channels.
    """
    conv = layers.Convolution2D(filters, filter_size, padding='same')
    relu = layers.Activation('relu')
    return relu(conv(x))

In [None]:
# Basic ResNet model: no upsampling inside the network.
# Inputs must already be upsampled to the final output size with some kind of
# interpolation; input and output have the same spatial size.

inputs = keras.Input(shape=(input_size, input_size, 3))

x = inputs
for stage_filters, stage_kernel in ((16, 9), (32, 3)):
    # Each stage: a plain conv that sets the channel count, then two residual blocks.
    x = conv_block(x, stage_filters, stage_kernel)
    for _block in range(2):
        x = res_net_block(x, stage_filters, 3)

outputs = layers.Convolution2D(3, 3, activation='linear', padding='same')(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mae', metrics=[PSNR])

In [None]:
# ResNet model with learnable upsampling via a transposed convolution.
# Residual features are computed at the input resolution. A stride-2 transposed
# convolution then upsamples them, and a final 3x3 convolution maps to RGB.
# With 'valid' padding the output side is (input_size - 1) * 2 + 6,
# i.e. 96 -> 196, the same output size as the sub-pixel model.

inputs = keras.Input(shape=(input_size, input_size, 3))
x = conv_block(inputs, 16, 9)
x = res_net_block(x, 16, 3)
x = res_net_block(x, 16, 3)
x = conv_block(x, 32, 3)
x = res_net_block(x, 32, 3)
x = res_net_block(x, 32, 3)
x = layers.Conv2DTranspose(32, kernel_size=6, strides=2, activation='relu')(x)  # 96x96 -> 196x196
outputs = layers.Convolution2D(3, 3, activation='linear', padding='same')(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mae', metrics=[PSNR])

In [None]:
# CNN model with learnable upsampling via a sub-pixel convolution:
# a convolution produces 3 * 2**2 channels, which depth_to_space(2)
# rearranges into a 3-channel map at twice the spatial resolution.

inputs = keras.Input(shape=(input_size, input_size, 3))

# 'valid' 3x3 transposed conv grows 96x96 -> 98x98, so the 2x sub-pixel step gives 196x196.
x = layers.Conv2DTranspose(8, kernel_size=3, activation='relu')(inputs)
x = conv_block(x, 16, 9)

x_skip = x  # long skip connection, re-added after the residual stages
for _block in range(3):
    x = res_net_block(x, 16, 3)

x = conv_block(x, 32, 5)
for _block in range(3):
    x = res_net_block(x, 32, 3)

x = conv_block(x, 16, 5)  # back to 16 channels so it can be added to x_skip
x = layers.Add()([x, x_skip])
x = conv_block(x, 32, 5)

# Sub-pixel convolution: 12 channels -> depth_to_space(2) -> 3 channels at 2x size.
x = layers.Convolution2D(3 * 2 ** 2, 5, activation='relu', padding='same')(x)
x = tf.nn.depth_to_space(x, 2)
outputs = layers.Convolution2D(3, 3, activation='linear', padding='same')(x)

model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='mae', metrics=[PSNR])

In [None]:
# Resume training from a previously saved model when one exists. The guard lets a
# fresh Restart & Run All (nothing saved yet) keep the freshly built model instead
# of crashing inside load_model.
if os.path.exists(model_save_loc):
    model = keras.models.load_model(
        model_save_loc,
        custom_objects={'res_net_block': res_net_block, 'PSNR': PSNR},
    )
else:
    print(f'No saved model at {model_save_loc}; using the freshly built model.')

In [None]:
hists = []  # one keras History per training run; the plotting cell concatenates them
model.summary()

## Train the model, plot the PSNR trend, and run some test examples

In [None]:
hist = model.fit(
    dataset_train,
    validation_data=dataset_val,
    epochs=100,
    callbacks=callbacks,
)
hists.append(hist)

In [None]:
# Concatenate the per-epoch PSNR values of every training run recorded in `hists`.
hist_train, hist_val = [], []
for hist in hists:
    hist_train += hist.history['PSNR']
    hist_val += hist.history['val_PSNR']

# Plot both curves on one explicitly created Axes, with labels so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(hist_train, label='train PSNR')
ax.plot(hist_val, label='val PSNR')
ax.set(title='PSNR per epoch (all training runs)', xlabel='epoch', ylabel='PSNR (dB)')
ax.legend()
plt.show()

In [None]:
from tensorflow.python.ops.numpy_ops import np_config

# Give tf.Tensor NumPy-style methods, type promotion and indexing.
# NOTE(review): no visible cell obviously depends on this — confirm it is still needed.
np_config.enable_numpy_behavior()

def show(img, title=''):
    """Display `img` in the current figure with `title` and no axes, then show it."""
    plt.imshow(img)
    plt.title(title)
    plt.axis('off')
    plt.show()

In [None]:
# See the SRResNet's performance (Bicubic vs Model vs Ground Truth) on a few test examples.
# Assumes the model's input and output have the same size (the basic ResNet), so the
# already-upsampled input is compared with the ground truth directly.
for Y, X in dataset_test.take(5):
    # Only the first image of each batch is displayed, so only that one is predicted.
    pred_mat = model.predict(Y[:1], verbose=0)[0]
    Y = Y[0]
    X = X[0]

    show(Y.numpy(), title=f'Bicubic | PSNR:{round(PSNR(Y, X).numpy(),2)} dB')
    # The linear output layer can leave [0, 1]; clip for display only (PSNR uses the raw output).
    show(np.clip(pred_mat, 0.0, 1.0), title=f'SRResNet | PSNR:{round(PSNR(pred_mat, X).numpy(),2)} dB')
    show(X.numpy(), title='Ground Truth')

In [None]:
# See the SRResNet-SubPix's performance (Bicubic vs Model vs Ground Truth) on a few test examples.
# Here the model input is low-resolution, so it is bicubic-upsampled to the
# ground-truth size before being compared with it.
for Y, X in dataset_test.take(5):
    # Only the first image of each batch is displayed, so only that one is predicted.
    pred_mat = model.predict(Y[:1], verbose=0)[0]
    Y = Y[0]
    X = X[0]

    # Upsample Y with bicubic interpolation to the ground-truth size (assumes square targets).
    Y = upsample(Y, X.shape[0])

    show(Y.numpy(), title=f'Bicubic | PSNR:{round(PSNR(Y, X).numpy(),2)} dB')
    # The linear output layer can leave [0, 1]; clip for display only (PSNR uses the raw output).
    show(np.clip(pred_mat, 0.0, 1.0), title=f'SRResNet-SubPix | PSNR:{round(PSNR(pred_mat, X).numpy(),2)} dB')
    show(X.numpy(), title='Ground Truth')

In [None]:
# Evaluate on the held-out test split. evaluate() returns [loss, PSNR] in the order
# given to model.compile (loss='mae', metrics=[PSNR]).
# Don't bind `_` here: in IPython it holds the previous cell's output.
test_mae, test_psnr = model.evaluate(dataset_test, verbose=0)
acc = test_psnr  # previous name, kept for any cell that still reads `acc`
print(f'Test PSNR: {round(test_psnr, 2)} dB')

## Save the model

In [None]:
# Persist the final model to model_save_loc, the same path ModelCheckpoint writes to.
# NOTE(review): this overwrites the best checkpoint. In some Keras versions
# restore_best_weights only rolls back when EarlyStopping actually stops training,
# so if all epochs ran this may save last-epoch weights — confirm that is intended.
keras.models.save_model(model, model_save_loc)