In [1]:
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, PReLU, ZeroPadding2D, Add
from tensorflow.keras.callbacks import Callback, EarlyStopping
from keras import backend as K
from keras.callbacks import ModelCheckpoint
import numpy as np
import os
import cv2
from PIL import Image

In [2]:
# Log the TensorFlow version this notebook was executed with (output below).
print(tf.__version__)

2.12.0


### Global variables

In [3]:
# --- Dataset locations ---------------------------------------------------
TRAIN_DIR = '/content/drive/MyDrive/SRCNN/Train'
VAL_DIR = '/content/drive/MyDrive/SRCNN/Test/Set14'   # validation: Set14
TEST_DIR = '/content/drive/MyDrive/SRCNN/Test/Set5'   # test: Set5

# --- Upscaling factors ---------------------------------------------------
# Training patches are generated once for every factor in `scales`.
# NOTE(review): SCALE is not referenced by the cells shown here; presumably
# the factor intended for evaluation — confirm before removing.
SCALE = 3
scales = [2, 3, 4]

# --- Patch geometry ------------------------------------------------------
# The network input is the bicubic-upscaled image, so input and label
# patches share the same spatial size.
SIZE_INPUT = 33
SIZE_LABEL = 33
STRIDE = 14  # pixels between neighbouring training-patch origins

### Dataset

In [4]:
def crop(image, factor):
    """Trim `image` so that its height and width are multiples of `factor`.

    Surplus rows are dropped from the bottom and surplus columns from the
    right; any trailing axes (e.g. channels) are kept whole. This guarantees
    that downscaling by `factor` leaves no remainder.
    """
    remainder = np.mod(image.shape, factor)
    keep_rows = image.shape[0] - remainder[0]
    keep_cols = image.shape[1] - remainder[1]
    return image[:keep_rows, :keep_cols]

def generator(data_path):
    """Return a zero-argument generator function over the images in `data_path`.

    The returned callable (passed to ``tf.data.Dataset.from_generator``)
    yields ``(input_patch, label_patch)`` pairs for every file in the folder
    and every factor in the module-level ``scales`` list:

    * the label is the Y (luminance) channel of the image in YCrCb, cropped so
      its size is divisible by the factor;
    * the input is that label bicubically downscaled by the factor and then
      bicubically upscaled back to the label size;
    * both are cut into aligned ``SIZE_INPUT x SIZE_INPUT`` patches whose
      origins are ``STRIDE`` pixels apart.

    Pixel values are not rescaled (0-255 range).
    """
    # Sorted so the patch stream is reproducible across runs/filesystems.
    files = sorted(os.listdir(data_path))

    def img_gen():
        for file in files:
            path = os.path.join(data_path, file)
            # Close the file handle promptly, and normalise to 3-channel RGB:
            # grayscale ('L') or palette ('P') images decode to 2-D arrays,
            # which cv2.cvtColor(..., COLOR_RGB2YCrCb) rejects.
            with Image.open(path) as img:
                rgb = np.asarray(img.convert('RGB'), dtype=np.uint8)
            ycrcb = cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb).astype(float)
            y_channel = ycrcb[:, :, 0]
            for scale in scales:
                img_label = tf.expand_dims(crop(y_channel, scale), axis=-1)
                h, w, _ = img_label.shape

                # crop() made h and w divisible by `scale`, so // is exact.
                down_sample = tf.image.resize(img_label, (h // scale, w // scale), method='bicubic')
                img_input = tf.image.resize(down_sample, (h, w), method='bicubic')

                for x in range(0, h - SIZE_INPUT + 1, STRIDE):
                    for y in range(0, w - SIZE_INPUT + 1, STRIDE):
                        sub_img_input = img_input[x:x + SIZE_INPUT, y:y + SIZE_INPUT]
                        sub_img_label = img_label[x:x + SIZE_INPUT, y:y + SIZE_INPUT]
                        yield (sub_img_input, sub_img_label)
    return img_gen

def create_dataset(path):
    """Wrap the patch generator for the folder `path` in a ``tf.data.Dataset``.

    Elements are ``(input, label)`` pairs of float32 tensors with shapes
    ``(SIZE_INPUT, SIZE_INPUT, 1)`` and ``(SIZE_LABEL, SIZE_LABEL, 1)``.
    """
    input_spec = tf.TensorSpec(shape=(SIZE_INPUT, SIZE_INPUT, 1))
    label_spec = tf.TensorSpec(shape=(SIZE_LABEL, SIZE_LABEL, 1))
    return tf.data.Dataset.from_generator(
        generator(path),
        output_signature=(input_spec, label_spec),
    )

In [6]:
# Patches are produced by a Python generator (image decode + two bicubic
# resizes per scale), so prefetch() lets the input pipeline run ahead of the
# training step instead of stalling it on every batch. Element values and
# batching are unchanged.
train = create_dataset(TRAIN_DIR).shuffle(30000).batch(64).prefetch(tf.data.AUTOTUNE)
val = create_dataset(VAL_DIR).shuffle(1000).batch(64).prefetch(tf.data.AUTOTUNE)
# Test (Set5) pipeline, currently disabled:
# test = create_dataset(TEST_DIR).shuffle(1000).batch(256)

In [7]:
def PSNR(y_true, y_pred, max_pixel=255.0):
    """Peak signal-to-noise ratio in dB, usable as a Keras metric.

    Computes ``10 * log10(max_pixel**2 / MSE)``, with the MSE averaged over
    every element of the batch. `max_pixel` defaults to 255 because the data
    pipeline keeps pixel values in the 0-255 range.

    ``np.log10`` cannot operate on symbolic tensors, so the original form
    failed once Keras traced the metric inside a ``tf.function`` during
    ``fit``/``evaluate``. The base-10 log is computed as ``ln(x) / ln(10)``
    with the backend's ``K.log``, which works both eagerly and in graph mode.
    """
    mse = K.mean(K.square(y_pred - y_true))
    # float(...) keeps the constant a Python scalar so it adopts the tensor's
    # dtype instead of forcing a float32/float64 mismatch.
    ln10 = float(np.log(10.0))
    return 10.0 * K.log((max_pixel ** 2) / mse) / ln10

### Model

In [8]:
class VDSR(tf.keras.Model):
    """Keras model whose training step applies adjustable gradient clipping.

    ``train_step`` clips every gradient element to
    ``[-CLIP_THETA / lr, CLIP_THETA / lr]``, where ``lr`` is the optimizer's
    current learning rate, before applying the update. Loss and metrics are
    taken from ``compile()`` as usual.
    """

    # Numerator of the clipping bound (theta in the VDSR paper's
    # "adjustable gradient clipping").
    CLIP_THETA = 0.5

    def train_step(self, data):
        """Run one optimisation step on a batch ``(x, y)``.

        Returns:
            Dict mapping each tracked metric name (including the loss) to
            its current value.
        """
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # The bound scales inversely with the learning rate, so for plain SGD
        # the per-element step lr * grad stays within +/- CLIP_THETA.
        clip_value = self.CLIP_THETA / self.optimizer.learning_rate
        # Variables that do not affect the loss get a None gradient, which
        # tf.clip_by_value would reject; pass None through unchanged so the
        # optimizer can skip those variables.
        clipped_grads = [
            grad if grad is None else tf.clip_by_value(grad, -clip_value, clip_value)
            for grad in gradients
        ]
        self.optimizer.apply_gradients(zip(clipped_grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss).
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

In [9]:
def create_model(learning_rate=1e-4):
    """Build and compile the VDSR super-resolution network.

    Architecture: one 3x3/64 ReLU conv, 19 more 3x3/64 ReLU convs, and a final
    linear 3x3 conv to a single channel that predicts a residual. The residual
    is added to the (bicubic-upscaled) input, so the network only has to learn
    the missing high-frequency detail.

    NOTE(review): that is 21 conv layers in total; the VDSR paper uses 20
    (18 middle layers). Kept at 21 so existing checkpoints still load.

    Args:
        learning_rate: Adam step size.

    Returns:
        A ``VDSR`` model compiled with Adam and MSE loss.
    """
    inputs = Input(shape=(SIZE_INPUT, SIZE_INPUT, 1))
    x = Conv2D(64, 3, kernel_initializer='he_normal', padding='same', activation='relu')(inputs)
    for _ in range(19):
        x = Conv2D(64, 3, kernel_initializer='he_normal', padding='same', activation='relu')(x)
    # The residual has to be able to darken pixels as well as brighten them,
    # so this layer is linear: a ReLU here would clamp the correction to >= 0.
    residual = Conv2D(1, 3, kernel_initializer='he_normal', padding='same')(x)
    outputs = Add()([inputs, residual])
    model = VDSR(inputs=inputs, outputs=outputs, name="VDSR")
    model.compile(
        # 0.1 is the paper's *SGD* rate; Adam's per-step update is roughly
        # `learning_rate` in size regardless of gradient scale, so 0.1 makes
        # it diverge. 1e-4 is a conventional Adam rate for SR networks.
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="mse",
    )
    return model

In [15]:
# Build a freshly compiled model (`m` is reused by the training cell) and
# print its layer-by-layer summary.
m = create_model()
m.summary()

Model: "VDSR"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 33, 33, 1)]  0           []                               
                                                                                                  
 conv2d_21 (Conv2D)             (None, 33, 33, 64)   640         ['input_2[0][0]']                
                                                                                                  
 conv2d_22 (Conv2D)             (None, 33, 33, 64)   36928       ['conv2d_21[0][0]']              
                                                                                                  
 conv2d_23 (Conv2D)             (None, 33, 33, 64)   36928       ['conv2d_22[0][0]']              
                                                                                               

In [17]:
# Write the model to `file_path` only when val_loss reaches a new minimum,
# so the file always holds the best epoch seen so far.
file_path = "/content/drive/MyDrive/SRCNN/saved_model/vdsr.h5"
checkpoint = ModelCheckpoint(
    filepath=file_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

In [18]:
def scheduler(epoch, lr, step=20, factor=0.1):
    """Step-decay schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Keeps the learning rate unchanged for the first `step` epochs, then
    multiplies it by `factor` at the start of every `step`-th epoch
    (epochs 20, 40, 60, ... with the defaults). Because Keras passes in the
    optimizer's current rate, the decays compound.

    Args:
        epoch: Zero-based index of the epoch about to start.
        lr: Learning rate currently set on the optimizer.
        step: Number of epochs between decays.
        factor: Multiplier applied at each decay.

    Returns:
        The learning rate to use for `epoch`.
    """
    if epoch > 0 and epoch % step == 0:
        return lr * factor
    return lr
# Callback that asks `scheduler` for the learning rate at each epoch start.
lr_decay = tf.keras.callbacks.LearningRateScheduler(schedule=scheduler)

In [None]:
# Train for 80 epochs, validating on Set14 after each one and keeping the
# best checkpoint. Optional callbacks, disabled for this run:
#   EarlyStopping(monitor='val_loss', patience=5)
#   lr_decay  (step learning-rate decay defined above)
fit_callbacks = [checkpoint]
m.fit(
    train,
    validation_data=val,
    epochs=80,
    callbacks=fit_callbacks,
)

### Test results

In [None]:
def merge(path_ori, srcnn, scale):
    """Super-resolve the luminance channel of one image by tiling.

    The Y channel (of YCrCb) of the image at `path_ori` is cropped to a
    multiple of `scale`, bicubically downscaled by `scale` and upscaled back
    to form the degraded input. That input is cut into ``SIZE_INPUT``-pixel
    tiles spaced ``SIZE_LABEL`` apart, run through `srcnn`, and the predicted
    tiles are stitched back together in the same layout.

    Args:
        path_ori: Path to the ground-truth image file.
        srcnn: Model whose ``predict`` maps a batch of
            ``(SIZE_INPUT, SIZE_INPUT, 1)`` tiles to
            ``(SIZE_LABEL, SIZE_LABEL, 1)`` tiles.
        scale: Integer downscaling factor.

    Returns:
        ``(img_ori, img_input, img_sr)``: the cropped ground truth and the
        bicubic input (both ``(h, w, 1)`` tensors), and the stitched
        prediction as a float array of shape
        ``(n_rows * SIZE_LABEL, n_cols * SIZE_LABEL, 1)``. Border pixels not
        covered by a whole tile are absent from `img_sr`, so it can be
        smaller than the other two.

    Raises:
        ValueError: if the cropped image is too small to hold one tile.
    """
    # Close the file promptly and force 3-channel RGB so grayscale/palette
    # images don't break cv2.cvtColor(..., COLOR_RGB2YCrCb).
    with Image.open(path_ori) as img:
        rgb = np.asarray(img.convert('RGB'), dtype=np.uint8)
    y_channel = cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb).astype(float)[:, :, 0]
    img_ori = tf.expand_dims(crop(y_channel, scale), axis=-1)
    h, w, _ = img_ori.shape

    # crop() made h and w divisible by `scale`, so // is exact.
    down_sample = tf.image.resize(img_ori, (h // scale, w // scale), method='bicubic')
    img_input = tf.image.resize(down_sample, (h, w), method='bicubic')

    row_starts = range(0, h - SIZE_INPUT + 1, SIZE_LABEL)
    col_starts = range(0, w - SIZE_INPUT + 1, SIZE_LABEL)
    n_rows, n_cols = len(row_starts), len(col_starts)
    if n_rows == 0 or n_cols == 0:
        raise ValueError(
            f"image {path_ori!r} is {h}x{w} after cropping; need at least "
            f"{SIZE_INPUT}x{SIZE_INPUT} to form one tile")

    # Row-major tile order: every tile of the first tile-row, then the next.
    patches = [img_input[x:x + SIZE_INPUT, y:y + SIZE_INPUT]
               for x in row_starts for y in col_starts]
    res = srcnn.predict(tf.stack(patches))

    # (n_rows*n_cols, S, S, 1) -> (n_rows, S, n_cols, S, 1) -> (n_rows*S, n_cols*S, 1)
    img_sr = (np.asarray(res, dtype=float)
              .reshape(n_rows, n_cols, SIZE_LABEL, SIZE_LABEL, 1)
              .transpose(0, 2, 1, 3, 4)
              .reshape(n_rows * SIZE_LABEL, n_cols * SIZE_LABEL, 1))

    return img_ori, img_input, img_sr