In [None]:
!pip install tensorflow-addons
!pip install tensorflow==2.11.0

In [None]:
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, PReLU, ZeroPadding2D
from tensorflow.keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
import tensorflow_addons as tfa
import numpy as np
import os
import cv2
from PIL import Image


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [None]:
print(tf.__version__)
# Fail fast if the pinned TensorFlow is not the one loaded in this kernel
# (e.g. a different version was already imported before the install cell ran).
assert tf.__version__.startswith('2.11'), f'expected TensorFlow 2.11.x, got {tf.__version__}'

2.11.0


### Global variables

In [None]:
TRAIN_DIR = '/content/drive/MyDrive/SRCNN/Train'
TEST_DIR = '/content/drive/MyDrive/SRCNN/Test/Set5'
VAL_DIR = '/content/drive/MyDrive/SRCNN/Test/Set14'

SCALE = 3                         # upscaling factor between LR input and HR output
SIZE_INPUT = 7                    # side of a low-resolution training patch
SIZE_LABEL = SIZE_INPUT * SCALE   # side of the matching high-resolution patch (21);
                                  # must equal the Conv2DTranspose output size
STRIDE = 5                        # step between training patches, in LR pixels

SEED = 42
tf.keras.utils.set_random_seed(SEED)  # seeds Python `random`, NumPy and TensorFlow

### Dataset

In [None]:
def augmentation(src_dir=None, out_dir=None, scales=(0.9, 0.8, 0.7, 0.6)):
    """Write rescaled and rotated copies of every image in the training folder.

    For each readable image in `src_dir` and each factor in `scales`, the image
    is bicubic-resized to `factor` times its size and then rotated clockwise by
    90, 180 and 270 degrees. Each rotation is saved as
    ``<stem>.0<factor*10>.<angle>.bmp`` (e.g. ``t1.09.90.bmp``).

    Args:
        src_dir: directory of source images; defaults to TRAIN_DIR.
        out_dir: directory the copies are written to; defaults to `src_dir`, so
            that `create_dataset(TRAIN_DIR)` picks them up.
        scales: downscaling factors, each in (0, 1).
    """
    src_dir = TRAIN_DIR if src_dir is None else src_dir
    out_dir = src_dir if out_dir is None else out_dir
    os.makedirs(out_dir, exist_ok=True)

    # Snapshot the listing before writing, and skip files produced by an earlier
    # run, so re-running this cell does not augment augmented images.
    sources = [name for name in sorted(os.listdir(src_dir))
               if not name.endswith(('.90.bmp', '.180.bmp', '.270.bmp'))]

    for file in sources:
        img = cv2.imread(os.path.join(src_dir, file))
        if img is None:  # not something OpenCV can read (sub-directory, stray file)
            continue
        h, w = img.shape[:2]
        stem = os.path.splitext(file)[0]
        for scale in scales:
            resized = tf.image.resize(img, (int(h * scale), int(w * scale)),
                                      method='bicubic', preserve_aspect_ratio=True)
            # resize returns float32 and bicubic can overshoot [0, 255]:
            # round and clip explicitly before the 8-bit BMP write.
            rotated = np.clip(np.rint(resized.numpy()), 0, 255).astype(np.uint8)
            for angle in (90, 180, 270):
                # Rotations are cumulative: 90, then 180, then 270 degrees.
                rotated = cv2.rotate(rotated, cv2.ROTATE_90_CLOCKWISE)
                name = f'{stem}.0{int(round(scale * 10))}.{angle}.bmp'
                cv2.imwrite(os.path.join(out_dir, name), rotated)

In [None]:
def crop(image, factor):
    """Trim bottom/right edges so height and width are multiples of `factor`.

    Ensures there is no remainder when the image is later downscaled by
    `factor`. Works for 2-D (H, W) and 3-D (H, W, C) arrays; any axes after the
    first two are kept whole.

    Args:
        image: array indexed as image[row, col, ...].
        factor: positive integer scale factor.

    Returns:
        A view of `image` with shape (H - H % factor, W - W % factor, ...).
    """
    height, width = image.shape[:2]
    return image[:height - height % factor, :width - width % factor]

def generator(data_path):
    """Return a zero-argument callable that yields (LR patch, HR patch) pairs.

    `tf.data.Dataset.from_generator` calls the returned function once per pass
    over the data. For every regular file in `data_path` (listed once, when
    `generator` itself is called), it:

      1. loads the image as RGB and keeps the Y channel of YCrCb (float, 0-255);
      2. crops height/width to multiples of SCALE -> HR label image (H, W, 1);
      3. bicubic-downsamples the label by SCALE -> LR input image;
      4. slides a SIZE_INPUT x SIZE_INPUT window over the LR image with step
         STRIDE and yields it together with the SIZE_LABEL x SIZE_LABEL HR
         window starting at SCALE times the LR offset.

    Args:
        data_path: directory containing the images.

    Returns:
        A generator function yielding ``(lr_patch, hr_patch)`` tensors.
    """
    files = sorted(
        name for name in os.listdir(data_path)
        # skip sub-directories (e.g. .ipynb_checkpoints) that Image.open can't read
        if os.path.isfile(os.path.join(data_path, name))
    )

    def img_gen():
        for file in files:
            with Image.open(os.path.join(data_path, file)) as im:
                # convert('RGB') turns grayscale / palette images (2-D arrays,
                # which cvtColor(RGB2YCrCb) rejects) into 3 channels; `with`
                # closes the file handle.
                rgb = np.array(im.convert('RGB'), dtype=np.uint8)
            ycrcb = np.array(cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb), dtype=float)
            img_label = tf.expand_dims(crop(ycrcb, SCALE)[:, :, 0], axis=-1)
            h, w, _ = img_label.shape

            down_sample = tf.image.resize(img_label, (h // SCALE, w // SCALE), method='bicubic')
            h_lr, w_lr, _ = down_sample.shape

            for x in range(0, h_lr - SIZE_INPUT + 1, STRIDE):
                x_hr = x * SCALE
                for y in range(0, w_lr - SIZE_INPUT + 1, STRIDE):
                    y_hr = y * SCALE
                    yield (down_sample[x:x + SIZE_INPUT, y:y + SIZE_INPUT],
                           img_label[x_hr:x_hr + SIZE_LABEL, y_hr:y_hr + SIZE_LABEL])

    return img_gen

def create_dataset(path):
    """Wrap `generator(path)` in a tf.data.Dataset of (LR, HR) patch pairs.

    Elements are tensors of shape (SIZE_INPUT, SIZE_INPUT, 1) and
    (SIZE_LABEL, SIZE_LABEL, 1) with TensorSpec's default dtype (float32).
    No shuffling or batching is applied here; callers chain those.
    """
    lr_spec = tf.TensorSpec(shape=(SIZE_INPUT, SIZE_INPUT, 1),)
    hr_spec = tf.TensorSpec(shape=(SIZE_LABEL, SIZE_LABEL, 1),)
    return tf.data.Dataset.from_generator(
        generator(path),
        output_signature=(lr_spec, hr_spec),
    )

In [None]:
def PSNR(y_true, y_pred, max_pixel=255.0):
    """Peak signal-to-noise ratio in dB; usable as a Keras metric fn(y_true, y_pred).

    PSNR = 10 * log10(max_pixel**2 / MSE), with the MSE averaged over every
    element of the batch. Returns +inf when the prediction is exact (MSE == 0).

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor, same shape as `y_true`.
        max_pixel: largest possible pixel value (255 for 8-bit data).
    """
    diff = tf.cast(y_pred, tf.float32) - tf.cast(y_true, tf.float32)
    mse = tf.reduce_mean(tf.square(diff))
    # TensorFlow has no log10 op; use log10(x) = ln(x) / ln(10).
    return 10.0 * tf.math.log((max_pixel ** 2) / mse) / tf.math.log(10.0)

In [None]:
# Write rescaled (x0.9/0.8/0.7/0.6) and rotated (90/180/270 degree) copies of
# the training images with cv2.imwrite.
# NOTE(review): check where `augmentation` saves them — create_dataset(TRAIN_DIR)
# below only reads files inside TRAIN_DIR.
augmentation()

In [None]:
# The generator yields patches image by image, so a large shuffle buffer is
# needed to mix patches from different images. Each element holds
# 7*7 + 21*21 float32 values (~2 KB), so a full buffer is ~1.2 GB of RAM.
SHUFFLE_BUFFER = 600_000
BATCH_SIZE = 256

train = (
    create_dataset(TRAIN_DIR)
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation order does not change the averaged val_loss, so it is not shuffled.
val = create_dataset(VAL_DIR).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
# test = create_dataset(TEST_DIR).batch(BATCH_SIZE)

### Model

In [None]:
def create_model(d=56, s=12, m=4):  # kernel sizes 5-1-3
    """Build an FSRCNN(d, s, m) super-resolution model.

    Layers (all convolutions 'same'-padded, so the spatial size is kept until
    the final deconvolution):
        feature extraction  Conv 5x5, d filters + PReLU
        shrinking           Conv 1x1, s filters + PReLU
        mapping             m x (Conv 3x3, s filters + PReLU)
        expanding           Conv 1x1, d filters + PReLU
        deconvolution       Conv2DTranspose 9x9, 1 filter, stride SCALE

    Input:  (SIZE_INPUT, SIZE_INPUT, 1) low-resolution patch.
    Output: (SIZE_INPUT*SCALE, SIZE_INPUT*SCALE, 1) high-resolution patch.

    Args:
        d: number of feature-extraction / expanding filters.
        s: number of shrinking / mapping filters.
        m: number of 3x3 mapping layers.

    Returns:
        An uncompiled tf.keras.Model.
    """
    def conv(filters, kernel_size):
        # Conv weights drawn from N(0, 0.001).
        return Conv2D(filters, kernel_size, padding='same',
                      kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.001))

    def prelu():
        # One learnable slope per channel, shared over height and width. Keras'
        # default (shared_axes=None) learns one slope per pixel per channel
        # (H*W*C parameters) and ties the layer to the training patch size.
        return PReLU(shared_axes=[1, 2])

    inputs = Input(shape=(SIZE_INPUT, SIZE_INPUT, 1))
    x = prelu()(conv(d, 5)(inputs))   # feature extraction
    x = prelu()(conv(s, 1)(x))        # shrinking
    for _ in range(m):                # mapping: every 3x3 conv gets its own PReLU
        x = prelu()(conv(s, 3)(x))
    x = prelu()(conv(d, 1)(x))        # expanding
    # Transposed convolution upsamples by SCALE: SIZE_INPUT -> SIZE_INPUT * SCALE.
    outputs = Conv2DTranspose(1, 9, strides=SCALE, padding='same')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Build FSRCNN with its default hyper-parameters (d=56, s=12, m=4) and print
# the layer table.
model = create_model(d=56, s=12, m=4)
model.summary()

Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 7, 7, 1)]         0         
                                                                 
 conv2d_21 (Conv2D)          (None, 7, 7, 56)          1456      
                                                                 
 p_re_lu_12 (PReLU)          (None, 7, 7, 56)          2744      
                                                                 
 conv2d_22 (Conv2D)          (None, 7, 7, 12)          684       
                                                                 
 p_re_lu_13 (PReLU)          (None, 7, 7, 12)          588       
                                                                 
 conv2d_23 (Conv2D)          (None, 7, 7, 12)          1308      
                                                                 
 conv2d_24 (Conv2D)          (None, 7, 7, 12)          1308

In [None]:
# Save the full model only when val_loss reaches a new minimum.
file_path = "/content/drive/MyDrive/SRCNN/saved_model/fsrcnn.h5"
checkpoint = ModelCheckpoint(
    filepath=file_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

#### Train on the 91-image dataset

In [None]:
# Per-layer learning rates via tfa.optimizers.MultiOptimizer: Adam(1e-3) for
# every layer except the last, Adam(1e-4) for the last layer (the
# Conv2DTranspose output layer built by create_model).
optimizers = [
    tf.keras.optimizers.Adam(learning_rate=1e-3),  # all layers but the last
    tf.keras.optimizers.Adam(learning_rate=1e-4),  # last (deconvolution) layer
]
optimizers_and_layers = list(zip(optimizers, (model.layers[:-1], model.layers[-1])))
optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
model.compile(loss="mse", optimizer=optimizer)

In [None]:
history = model.fit(
    train,
    epochs=80,
    validation_data=val,
    callbacks=[
        # restore_best_weights: when early stopping triggers, roll the in-memory
        # model back to its best-val_loss epoch, so the fine-tuning stage starts
        # from the same weights the checkpoint saved rather than from weights
        # up to `patience` epochs past the best one.
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        checkpoint,
    ],
)

#### Fine-tune on the 100-image dataset

In [None]:
# Fine-tuning stage: same per-layer split as stage 1 (all layers but the last,
# then the last layer), with much smaller SGD learning rates.
# NOTE(review): the section heading refers to fine-tuning on a 100-image set,
# but this cell still trains on `train` (built from TRAIN_DIR); point it at the
# fine-tuning data if that is the intent.
optimizers = [
    tf.keras.optimizers.SGD(learning_rate=1e-6),
    tf.keras.optimizers.SGD(learning_rate=1e-8),
]
optimizers_and_layers = [(optimizers[0], model.layers[:-1]), (optimizers[1], model.layers[-1])]
optimizer = tfa.optimizers.MultiOptimizer(optimizers_and_layers)
model.compile(
    optimizer=optimizer,
    loss="mse",
)
model.fit(
    train,
    epochs=80000,  # effectively unbounded; EarlyStopping ends the run
    validation_data=val,
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
        checkpoint,
    ],
)

### Test results

In [None]:
def merge(path_ori, srcnn, scale):
    """Super-resolve the Y channel of one image with `srcnn`, tile by tile.

    The image is converted to YCrCb. Its Y channel is cropped to multiples of
    `scale` (-> `img_ori`) and bicubic-downsampled by `scale`. The
    low-resolution image is cut into non-overlapping SIZE_INPUT x SIZE_INPUT
    tiles, the same kind of input the training generator feeds the model.
    Every tile is predicted and the high-resolution tiles are stitched back on
    the same grid.

    Args:
        path_ori: path to the ground-truth image.
        srcnn: model mapping (N, SIZE_INPUT, SIZE_INPUT, 1) LR tiles to HR tiles.
        scale: downscaling factor; should match the factor the model was built with.

    Returns:
        (img_ori, img_input, img_sr):
          img_ori   -- cropped ground-truth Y channel, shape (H, W, 1);
          img_input -- bicubic upscaling of the LR image back to (H, W, 1),
                       i.e. the bicubic baseline;
          img_sr    -- stitched model output; it covers whole tiles only, so it
                       can be smaller than (H, W) at the bottom/right edges.

    Raises:
        ValueError: if the downsampled image is smaller than one tile.
    """
    with Image.open(path_ori) as im:
        rgb = np.array(im.convert('RGB'), dtype=np.uint8)
    ycrcb = np.array(cv2.cvtColor(rgb, cv2.COLOR_RGB2YCrCb), dtype=float)
    # Crop the 3-D array and then take the Y channel, as the training generator does.
    img_ori = tf.expand_dims(crop(ycrcb, scale)[:, :, 0], axis=-1)
    h, w, _ = img_ori.shape

    down_sample = tf.image.resize(img_ori, (h // scale, w // scale), method='bicubic')
    img_input = tf.image.resize(down_sample, (h, w), method='bicubic')

    h_lr, w_lr, _ = down_sample.shape
    row_starts = range(0, h_lr - SIZE_INPUT + 1, SIZE_INPUT)
    col_starts = range(0, w_lr - SIZE_INPUT + 1, SIZE_INPUT)
    if not row_starts or not col_starts:
        raise ValueError(f'{path_ori}: downsampled size {h_lr}x{w_lr} is smaller '
                         f'than one {SIZE_INPUT}x{SIZE_INPUT} tile')

    # Tiles in row-major order: all tiles of the first row, then the next row, ...
    tiles = tf.stack([down_sample[x:x + SIZE_INPUT, y:y + SIZE_INPUT]
                      for x in row_starts for y in col_starts])
    res = srcnn.predict(tiles)

    tile_out = res.shape[1]  # HR tile side produced by the model
    n_cols = len(col_starts)
    img_sr = np.zeros((len(row_starts) * tile_out, n_cols * tile_out, res.shape[-1]))
    for i, tile in enumerate(res):
        r, c = divmod(i, n_cols)
        img_sr[r * tile_out:(r + 1) * tile_out, c * tile_out:(c + 1) * tile_out] = tile

    return img_ori, img_input, img_sr