In [None]:
!pip install tensorflow-addons
!pip install tensorflow==2.11.0

In [None]:
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
import tensorflow_addons as tfa
import numpy as np
import os
import cv2
from PIL import Image

In [None]:
# Sanity check: show which TensorFlow build the kernel actually imported
# (the install cell pins 2.11.0).
print(tf.__version__)

2.11.0


### Global variables

In [None]:
# Dataset folders (paths under /content/drive/MyDrive -- presumably a mounted
# Google Drive; adjust if the data lives elsewhere).
TRAIN_DIR = '/content/drive/MyDrive/SRCNN/Train'
TEST_DIR = '/content/drive/MyDrive/SRCNN/Test/Set5'
VAL_DIR = '/content/drive/MyDrive/SRCNN/Test/Set14'

# Down/up-sampling factor used to synthesise the low-resolution inputs.
SCALE = 3
# Side length (pixels) of the square input patches.
SIZE_INPUT = 33
# Side length (pixels) of the square label patches. The test-time stride always
# equals SIZE_LABEL so that the predicted patches are consecutive.
SIZE_LABEL = 21
# Step (pixels) between neighbouring training patches.
STRIDE = 14
# Border lost on each side between an input patch and its label patch.
PAD = abs(SIZE_INPUT - SIZE_LABEL) // 2

### Dataset

In [None]:
def crop(image, factor):
    """Trim ``image`` so its first two dimensions are multiples of ``factor``.

    Rows are dropped from the bottom and columns from the right, so that
    scaling by ``factor`` leaves no remainder.

    Args:
        image: array-like exposing ``.shape`` and 2-D slicing, with at least
            two dimensions.
        factor: integer the height and width must be divisible by.

    Returns:
        The top-left slice of ``image`` with the adjusted height and width.
    """
    leftover = np.mod(image.shape, factor)
    limits = np.subtract(image.shape, leftover)
    bottom, right = limits[0], limits[1]
    return image[:bottom, :right]

def generator(data_path):
    """Build a generator function yielding (input, label) patch pairs.

    Every regular file in ``data_path`` is opened as an image. Its luma (Y)
    channel is cropped to a multiple of ``SCALE``, then bicubically
    down-sampled by ``SCALE`` and up-sampled back to form the degraded input.
    Patches of ``SIZE_INPUT`` are cut from the degraded image every ``STRIDE``
    pixels; the matching label is the ``SIZE_LABEL`` window of the original
    image starting ``PAD`` pixels inside the input patch.

    Args:
        data_path: directory containing the image files.

    Returns:
        A zero-argument callable; each call returns a fresh iterator of
        ``(sub_img_input, sub_img_label)`` float32 tensors with shapes
        ``(SIZE_INPUT, SIZE_INPUT, 1)`` and ``(SIZE_LABEL, SIZE_LABEL, 1)``.
    """
    # os.listdir returns entries in arbitrary order and includes
    # sub-directories; sort for a reproducible order and keep regular files.
    files = sorted(
        name for name in os.listdir(data_path)
        if os.path.isfile(os.path.join(data_path, name))
    )

    def img_gen():
        for file in files:
            path = os.path.join(data_path, file)
            # The context manager closes the file handle once pixels are read.
            # convert('RGB') guarantees the 3-channel layout COLOR_RGB2YCrCb
            # needs (grayscale files would otherwise decode to 2-D arrays).
            with Image.open(path) as img_file:
                img_rgb = np.array(img_file.convert('RGB'), dtype=np.uint8)
            img_ycrcb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCrCb)
            # float32 so the label patches have the same dtype as the
            # bicubic-resized inputs and as the TensorSpecs in create_dataset.
            y_channel = img_ycrcb[:, :, 0].astype(np.float32)
            img_label = tf.expand_dims(crop(y_channel, SCALE), axis=-1)
            h, w, _ = img_label.shape

            # crop() made h and w multiples of SCALE, so // is exact.
            down_sample = tf.image.resize(img_label, (h // SCALE, w // SCALE), method='bicubic')
            img_input = tf.image.resize(down_sample, (h, w), method='bicubic')

            for x in range(0, h - SIZE_INPUT + 1, STRIDE):
                for y in range(0, w - SIZE_INPUT + 1, STRIDE):
                    sub_img_input = img_input[x:x + SIZE_INPUT, y:y + SIZE_INPUT]
                    sub_img_label = img_label[x + PAD:x + PAD + SIZE_LABEL, y + PAD:y + PAD + SIZE_LABEL]

                    yield (sub_img_input, sub_img_label)
    return img_gen

def create_dataset(path):
    """Wrap the patch generator for ``path`` in a ``tf.data.Dataset``.

    Args:
        path: directory of images handed to ``generator``.

    Returns:
        An unbatched dataset of ``(input, label)`` pairs declared as float32
        tensors of shape ``(SIZE_INPUT, SIZE_INPUT, 1)`` and
        ``(SIZE_LABEL, SIZE_LABEL, 1)``.
    """
    input_spec = tf.TensorSpec(shape=(SIZE_INPUT, SIZE_INPUT, 1), dtype=tf.float32)
    label_spec = tf.TensorSpec(shape=(SIZE_LABEL, SIZE_LABEL, 1), dtype=tf.float32)
    return tf.data.Dataset.from_generator(
        generator(path),
        output_signature=(input_spec, label_spec),
    )

In [None]:
def merge(path_ori, srcnn, scale):
    """Super-resolve the luma channel of one image patch by patch.

    The Y channel of the image is cropped to a multiple of ``scale`` and
    degraded by a bicubic down/up-sampling round trip. The degraded image is
    cut into ``SIZE_INPUT`` patches with a stride of ``SIZE_LABEL``, the
    patches are passed to ``srcnn.predict``, and the predictions are tiled
    side by side into one image.

    NOTE(review): a second ``merge`` is defined later in this notebook (under
    "Test results") and shadows this one once that cell runs -- consider
    keeping a single copy.

    Args:
        path_ori: path of the reference image file.
        srcnn: object with a ``predict`` method that maps a batch of
            ``(SIZE_INPUT, SIZE_INPUT, 1)`` patches to
            ``(SIZE_LABEL, SIZE_LABEL, 1)`` patches.
        scale: integer down/up-sampling factor.

    Returns:
        Tuple ``(img_ori, img_input, img_sr)``:
        ``img_ori`` is the cropped Y channel with shape ``(h, w, 1)``;
        ``img_input`` is its bicubic round-trip version with the same shape;
        ``img_sr`` is a numpy array of shape
        ``(n_rows * SIZE_LABEL, n_cols * SIZE_LABEL, 1)`` holding the tiled
        predictions. It is smaller than ``img_ori``: tile ``(r, c)`` comes from
        the input patch whose top-left corner is at
        ``(r * SIZE_LABEL, c * SIZE_LABEL)``.

    Raises:
        ValueError: if the cropped image is smaller than ``SIZE_INPUT`` in
            either dimension, so no patch can be extracted.
    """
    # The context manager closes the file handle once pixels are read.
    # convert('RGB') guarantees the 3-channel layout COLOR_RGB2YCrCb needs
    # (grayscale files would otherwise decode to 2-D arrays).
    with Image.open(path_ori) as img_file:
        img_rgb = np.array(img_file.convert('RGB'), dtype=np.uint8)
    img_ycrcb = np.array(cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCrCb), dtype=float)
    y_channel = img_ycrcb[:, :, 0]
    img_ori = tf.expand_dims(crop(y_channel, scale), axis=-1)
    h, w, _ = img_ori.shape

    if h < SIZE_INPUT or w < SIZE_INPUT:
        raise ValueError(
            f"image {path_ori!r} is {h}x{w} after cropping; "
            f"at least {SIZE_INPUT}x{SIZE_INPUT} is required"
        )

    # crop() made h and w multiples of scale, so the floor division is exact.
    down_sample = tf.image.resize(img_ori, (int(h // scale), int(w // scale)), method='bicubic')
    img_input = tf.image.resize(down_sample, (h, w), method='bicubic')

    # Stride equals SIZE_LABEL so consecutive predictions tile without gaps.
    row_starts = range(0, h - SIZE_INPUT + 1, SIZE_LABEL)
    col_starts = range(0, w - SIZE_INPUT + 1, SIZE_LABEL)
    patches = [
        img_input[top:top + SIZE_INPUT, left:left + SIZE_INPUT]
        for top in row_starts
        for left in col_starts
    ]
    img_in = tf.stack(patches)

    res = srcnn.predict(img_in)

    n_rows, n_cols = len(row_starts), len(col_starts)
    img_sr = np.zeros((n_rows * SIZE_LABEL, n_cols * SIZE_LABEL, 1))
    # Patches were collected row by row, so index i maps to
    # (row, column) = divmod(i, n_cols).
    for i in range(len(res)):
        r, c = divmod(i, n_cols)
        img_sr[r * SIZE_LABEL:(r + 1) * SIZE_LABEL, c * SIZE_LABEL:(c + 1) * SIZE_LABEL, :] = res[i]

    return img_ori, img_input, img_sr

In [None]:
# Number of patch pairs per batch, shared by both pipelines.
BATCH_SIZE = 256
# Size of the shuffle buffer for the training patches.
TRAIN_SHUFFLE_BUFFER = 30000

# cache(): the Python generator decodes and resizes every image on each pass;
# caching keeps the produced patches in memory so that work is done only during
# the first epoch. Remove cache() if the patches do not fit in RAM.
# prefetch(): prepares the next batch while the current one is being consumed.
train = (
    create_dataset(TRAIN_DIR)
    .cache()
    .shuffle(TRAIN_SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation data is not shuffled: the order of the examples does not change
# the validation loss, so shuffling would only add buffer-filling time.
val = (
    create_dataset(VAL_DIR)
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# test = create_dataset(TEST_DIR).batch(BATCH_SIZE)

### Model

In [None]:
def create_model(f1=9, f2=1, f3=5, n1=64, n2=32):
    """Build and compile the three-layer convolutional SRCNN model.

    The network takes ``(SIZE_INPUT, SIZE_INPUT, 1)`` patches and applies three
    unpadded convolutions; the first two use ReLU, the last is linear with a
    single output channel. Kernels are drawn from a normal distribution with
    mean 0 and standard deviation 0.001. Training uses MSE loss with SGD at a
    learning rate of 1e-4 for the first two layers and 1e-5 for the last.

    Args:
        f1: kernel size of the first convolution.
        f2: kernel size of the second convolution.
        f3: kernel size of the third convolution.
        n1: number of filters in the first convolution.
        n2: number of filters in the second convolution.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    def small_normal():
        # A fresh initializer per layer, as in a literal per-layer construction.
        return tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.001)

    first_conv = Conv2D(n1, f1, activation='relu', kernel_initializer=small_normal())
    second_conv = Conv2D(n2, f2, activation='relu', kernel_initializer=small_normal())
    last_conv = Conv2D(1, f3, kernel_initializer=small_normal())

    model = tf.keras.Sequential([
        Input(shape=(SIZE_INPUT, SIZE_INPUT, 1)),
        first_conv,
        second_conv,
        last_conv,
    ])

    # Per-layer learning rates: a larger step for the first two layers and a
    # smaller one for the final layer.
    front_sgd = tf.keras.optimizers.SGD(learning_rate=1e-4)
    back_sgd = tf.keras.optimizers.SGD(learning_rate=1e-5)
    layer_wise_optimizer = tfa.optimizers.MultiOptimizer([
        (front_sgd, model.layers[:2]),
        (back_sgd, model.layers[2]),
    ])

    model.compile(optimizer=layer_wise_optimizer, loss="mse")
    return model

In [None]:
# Hyper-parameters of the model being trained: kernel sizes f1/f2/f3 and
# filter counts n1/n2.
srcnn_hparams = dict(f1=9, f2=1, f3=5, n1=64, n2=32)
m = create_model(**srcnn_hparams)
# Print the layer-by-layer output shapes and parameter counts.
m.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 25, 25, 64)        5248      
                                                                 
 conv2d_1 (Conv2D)           (None, 25, 25, 32)        2080      
                                                                 
 conv2d_2 (Conv2D)           (None, 21, 21, 1)         801       
                                                                 
Total params: 8,129
Trainable params: 8,129
Non-trainable params: 0
_________________________________________________________________


In [None]:
# https://www.mathworks.com/matlabcentral/answers/127891-x-0-0-1-10-what-s-going-on-really

In [None]:
# Write the model to disk only when the validation loss reaches a new minimum.
file_path = "/content/drive/MyDrive/SRCNN/saved_model/srcnn3.h5"
checkpoint = ModelCheckpoint(
    filepath=file_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Train until the validation loss has not improved for 5 consecutive epochs
# (at most 800 epochs). restore_best_weights=True rolls `m` back to the weights
# of the best epoch when early stopping triggers; without it the in-memory
# model keeps the weights of the last epoch, which are up to `patience` epochs
# past the best ones saved by the checkpoint callback.
# The returned History object is kept so the loss curves can be inspected.
history = m.fit(
    train,
    epochs=800,
    validation_data=val,
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        checkpoint,
    ],
)

### Test results

In [None]:
def merge(path_ori, srcnn, scale):
    """Super-resolve the luma channel of one image patch by patch.

    The Y channel of the image is cropped to a multiple of ``scale`` and
    degraded by a bicubic down/up-sampling round trip. The degraded image is
    cut into ``SIZE_INPUT`` patches with a stride of ``SIZE_LABEL``, the
    patches are passed to ``srcnn.predict``, and the predictions are tiled
    side by side into one image.

    NOTE(review): ``merge`` is also defined earlier in this notebook (before
    the training cells); this later definition is the one in effect for the
    cells that follow -- consider keeping a single copy.

    Args:
        path_ori: path of the reference image file.
        srcnn: object with a ``predict`` method that maps a batch of
            ``(SIZE_INPUT, SIZE_INPUT, 1)`` patches to
            ``(SIZE_LABEL, SIZE_LABEL, 1)`` patches.
        scale: integer down/up-sampling factor.

    Returns:
        Tuple ``(img_ori, img_input, img_sr)``:
        ``img_ori`` is the cropped Y channel with shape ``(h, w, 1)``;
        ``img_input`` is its bicubic round-trip version with the same shape;
        ``img_sr`` is a numpy array of shape
        ``(n_rows * SIZE_LABEL, n_cols * SIZE_LABEL, 1)`` holding the tiled
        predictions. It is smaller than ``img_ori``: tile ``(r, c)`` comes from
        the input patch whose top-left corner is at
        ``(r * SIZE_LABEL, c * SIZE_LABEL)``.

    Raises:
        ValueError: if the cropped image is smaller than ``SIZE_INPUT`` in
            either dimension, so no patch can be extracted.
    """
    # The context manager closes the file handle once pixels are read.
    # convert('RGB') guarantees the 3-channel layout COLOR_RGB2YCrCb needs
    # (grayscale files would otherwise decode to 2-D arrays).
    with Image.open(path_ori) as img_file:
        img_rgb = np.array(img_file.convert('RGB'), dtype=np.uint8)
    img_ycrcb = np.array(cv2.cvtColor(img_rgb, cv2.COLOR_RGB2YCrCb), dtype=float)
    y_channel = img_ycrcb[:, :, 0]
    img_ori = tf.expand_dims(crop(y_channel, scale), axis=-1)
    h, w, _ = img_ori.shape

    if h < SIZE_INPUT or w < SIZE_INPUT:
        raise ValueError(
            f"image {path_ori!r} is {h}x{w} after cropping; "
            f"at least {SIZE_INPUT}x{SIZE_INPUT} is required"
        )

    # crop() made h and w multiples of scale, so the floor division is exact.
    down_sample = tf.image.resize(img_ori, (int(h // scale), int(w // scale)), method='bicubic')
    img_input = tf.image.resize(down_sample, (h, w), method='bicubic')

    # Stride equals SIZE_LABEL so consecutive predictions tile without gaps.
    row_starts = range(0, h - SIZE_INPUT + 1, SIZE_LABEL)
    col_starts = range(0, w - SIZE_INPUT + 1, SIZE_LABEL)
    patches = [
        img_input[top:top + SIZE_INPUT, left:left + SIZE_INPUT]
        for top in row_starts
        for left in col_starts
    ]
    img_in = tf.stack(patches)

    res = srcnn.predict(img_in)

    n_rows, n_cols = len(row_starts), len(col_starts)
    img_sr = np.zeros((n_rows * SIZE_LABEL, n_cols * SIZE_LABEL, 1))
    # Patches were collected row by row, so index i maps to
    # (row, column) = divmod(i, n_cols).
    for i in range(len(res)):
        r, c = divmod(i, n_cols)
        img_sr[r * SIZE_LABEL:(r + 1) * SIZE_LABEL, c * SIZE_LABEL:(c + 1) * SIZE_LABEL, :] = res[i]

    return img_ori, img_input, img_sr