In [None]:
import os
import tensorflow as tf
from glob import glob
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, PReLU, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.utils import Sequence

In [None]:
def FSRCNN(scale=4, d=56, s=12, m=4, input_shape=(None, None, 1)):
    """Build an FSRCNN super-resolution network as a Keras ``Model``.

    Layer stack (all convolutions use 'same' padding):
      * feature extraction: 5x5 conv with ``d`` filters
      * shrinking:          1x1 conv with ``s`` filters
      * mapping:            ``m`` 3x3 convs with ``s`` filters
      * expanding:          1x1 conv with ``d`` filters
      * reconstruction:     9x9 transposed conv, stride ``scale``, 1 output channel
    Every conv except the final transposed conv is followed by a PReLU whose
    slope is shared over axes 1 and 2.

    Args:
        scale: Upsampling stride of the final transposed convolution.
        d: Filter count of the feature-extraction and expanding layers.
        s: Filter count of the shrinking and mapping layers.
        m: Number of 3x3 mapping layers.
        input_shape: Shape passed to ``Input`` (default: any H x W, 1 channel).

    Returns:
        A ``Model`` named "FSRCNN".
    """

    def conv_prelu(tensor, filters, kernel):
        # Conv followed by a PReLU with one learned slope per channel.
        tensor = Conv2D(filters, kernel_size=kernel, padding='same')(tensor)
        return PReLU(shared_axes=[1, 2])(tensor)

    inputs = Input(shape=input_shape)

    # (filters, kernel_size) for each conv+PReLU stage, in network order.
    layer_specs = [(d, 5), (s, 1)] + [(s, 3)] * m + [(d, 1)]

    features = inputs
    for filters, kernel in layer_specs:
        features = conv_prelu(features, filters, kernel)

    outputs = Conv2DTranspose(1, kernel_size=9, strides=scale, padding='same')(features)
    return Model(inputs=inputs, outputs=outputs, name="FSRCNN")

In [None]:
class DIV2KDataset(Sequence):
    """Keras ``Sequence`` yielding ``(lr_batch, hr_batch)`` grayscale patch pairs.

    LR inputs are synthesized on the fly. For each HR image, a random square
    of side ``patch_size * scale`` is cropped and bicubic-downscaled to
    ``patch_size``. The ``*.png`` files in ``lr_dir`` are only listed, never
    decoded. The argument is kept so the constructor signature is unchanged.

    Args:
        lr_dir: Directory of LR ``*.png`` files (listed into ``lr_files`` only).
        hr_dir: Directory of HR ``*.png`` files used as the crop source.
        patch_size: Side length of the LR patch fed to the model.
        batch_size: Maximum number of HR images per batch.
        scale: Upscaling factor; the HR crop side is ``patch_size * scale``.

    Raises:
        ValueError: if ``hr_dir`` contains no ``*.png`` files.
    """

    def __init__(self, lr_dir, hr_dir, patch_size=48, batch_size=16, scale=4):
        # Keras 3 (where Sequence is PyDataset) expects the base initializer to run.
        super().__init__()
        self.lr_files = sorted(glob(os.path.join(lr_dir, '*.png')))
        self.hr_files = sorted(glob(os.path.join(hr_dir, '*.png')))
        if not self.hr_files:
            # Fail early rather than on the first batch with an opaque tf.stack error.
            raise ValueError(f"No HR .png files found in {hr_dir!r}")
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.scale = scale

    def __len__(self):
        """Number of batches per epoch: full HR batches, but at least one."""
        # Count HR files: they are the only files actually read in __getitem__.
        return max(1, len(self.hr_files) // self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(lr, hr)`` float32 tensors.

        ``lr`` has shape ``(B, patch_size, patch_size, 1)`` and ``hr`` has
        shape ``(B, patch_size*scale, patch_size*scale, 1)``. ``B`` is at most
        ``batch_size``, because images smaller than the crop are skipped.
        Values are in [0, 1].

        Raises:
            ValueError: if no image in this batch is large enough to crop.
        """
        start = idx * self.batch_size
        batch_hr = self.hr_files[start:start + self.batch_size]
        crop_size = self.patch_size * self.scale

        lr_batch = []
        hr_batch = []

        for hr_path in batch_hr:
            hr = tf.io.decode_png(tf.io.read_file(hr_path), channels=1)
            # uint8 -> float32 rescaled to [0, 1].
            hr = tf.image.convert_image_dtype(hr, tf.float32)

            hr_shape = tf.shape(hr)
            hr_h = hr_shape[0]
            hr_w = hr_shape[1]

            if hr_h < crop_size or hr_w < crop_size:
                print(f"⚠️ Skipping small image: {hr_path}")
                continue

            # Crop from HR
            hr_crop = tf.image.random_crop(hr, [crop_size, crop_size, 1])

            # Downscale to LR. Bicubic can overshoot, so clamp back to the HR value range.
            lr_crop = tf.image.resize(hr_crop, [self.patch_size, self.patch_size], method='bicubic')
            lr_crop = tf.clip_by_value(lr_crop, 0.0, 1.0)

            lr_batch.append(lr_crop)
            hr_batch.append(hr_crop)

        if not lr_batch:
            raise ValueError(
                f"Batch {idx} has no image of at least {crop_size}x{crop_size} pixels"
            )

        return tf.stack(lr_batch), tf.stack(hr_batch)


In [None]:
def psnr_metric(y_true, y_pred):
    """Keras metric: per-image PSNR, assuming pixel values lie in [0, 1]."""
    peak_value = 1.0  # matches the [0, 1] float images produced by the data pipeline
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

In [None]:
# Upscaling factor, LR patch side, images per batch, training epochs.
scale, patch_size, batch_size, epochs = 4, 48, 16, 50

In [None]:
# DIV2K directories. Each split's HR folder must match its LR folder:
# training uses the train_* data and validation uses the valid_* data.
# (The HR folders were previously swapped, so the model trained on the
# validation images.) LR folders are only listed by DIV2KDataset;
# LR patches are synthesized from the HR crops.
train_lr = '../DIV2K/train_LR_X4/DIV2K_train_LR_bicubic/X4'
train_hr = '../DIV2K/train_HR/DIV2K_train_HR'
val_lr = '../DIV2K/valid_LR_X4/DIV2K_valid_LR_bicubic/X4'
val_hr = '../DIV2K/valid_HR/DIV2K_valid_HR'

In [None]:
# Both splits use identical cropping/batching settings.
dataset_kwargs = dict(patch_size=patch_size, batch_size=batch_size, scale=scale)
train_gen = DIV2KDataset(train_lr, train_hr, **dataset_kwargs)
val_gen = DIV2KDataset(val_lr, val_hr, **dataset_kwargs)

In [None]:
# Single-channel FSRCNN trained with pixel-wise MSE; PSNR is tracked as a metric.
optimizer = Adam(learning_rate=1e-4)
model = FSRCNN(scale=scale, input_shape=(None, None, 1))
model.compile(
    optimizer=optimizer,
    loss=MeanSquaredError(),
    metrics=[psnr_metric],
)

In [None]:
model.fit(
    x=train_gen,
    validation_data=val_gen,
    epochs=epochs,
)

In [None]:
# Legacy HDF5 format. Reloading with compile=True will probably need
# custom_objects={'psnr_metric': psnr_metric}, since the metric is a plain function.
model_path = 'fsrcnn_DIV2K_x4.h5'
model.save(model_path)
print(f"Model saved as '{model_path}'")