In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import cv2
import tensorflow as tf

# Network hyperparameters shared by the model-building code below.
kernel_size = 3   # spatial size of every convolution kernel
act = "silu"      # activation applied inside the residual blocks
filters = 96      # feature channels of the trunk convolutions
blocks = 16       # number of residual blocks stacked in the trunk

# Backend-agnostic alternative (e.g. for the JAX backend), written with keras.ops only:
# class DepthToSpace(keras.layers.Layer):
#     def __init__(self, block_size):
#         super().__init__()
#         self.block_size = block_size

#     def call(self, input, training=None):
#         batch, height, width, depth = keras.ops.shape(input)
#         depth = depth // (self.block_size**2)
#         x = keras.ops.reshape(input, [batch, height, width, self.block_size, self.block_size, depth])
#         x = keras.ops.transpose(x, [0, 1, 3, 2, 4, 5])
#         x = keras.ops.reshape(x, [batch, height * self.block_size, width * self.block_size, depth])
#         if not training:
#             x = keras.ops.clip(x, 0.0, 1.0)
#         return x

@keras.saving.register_keras_serializable()
class DepthToSpace(keras.layers.Layer):
    """Pixel-shuffle upsampling layer (TensorFlow backend).

    Rearranges channel data into spatial blocks with ``tf.nn.depth_to_space``.
    Whenever ``training`` is falsy (``False`` or ``None``) the result is also
    clipped to ``[0, 1]``; during training the values are left unclipped.

    Args:
        block_size: Upscaling factor passed to ``tf.nn.depth_to_space``.
            Defaults to 2, the value this layer previously hard-coded.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...). Accepting them lets Keras rebuild
            the layer from its saved config.
    """

    def __init__(self, block_size=2, **kwargs):
        super().__init__(**kwargs)
        self.block_size = block_size

    def call(self, x, training=None):
        """Upsample ``x``; clip to [0, 1] unless ``training`` is truthy."""
        x = tf.nn.depth_to_space(x, self.block_size)
        if not training:
            x = keras.ops.clip(x, 0.0, 1.0)
        return x

    def get_config(self):
        """Return the layer config, including ``block_size``, for serialization."""
        config = super().get_config()
        config.update({"block_size": self.block_size})
        return config

def res_block(input, filters=filters, kernel_size=kernel_size):
    """Build one residual block on top of ``input``.

    Two convolutions with the ``act`` activation are followed by a third
    convolution without activation; the result is added back to ``input``.

    Args:
        input: Keras tensor entering the block; it is also the skip branch.
        filters: Number of filters for each of the three convolutions.
        kernel_size: Kernel size for each of the three convolutions.

    Returns:
        The Keras tensor produced by the final ``Add`` layer.
    """
    conv_args = dict(filters=filters, kernel_size=kernel_size, padding='same')

    hidden = input
    for _ in range(2):
        hidden = keras.layers.Conv2D(activation=act, **conv_args)(hidden)

    # Last convolution is linear so the residual can take either sign.
    residual = keras.layers.Conv2D(**conv_args)(hidden)
    return keras.layers.Add()([residual, input])

# Assemble the network: stem conv -> residual trunk -> fusion conv -> upsampler.
inputs = keras.layers.Input(shape=(None, None, 1))  # one channel, any height/width
conv0 = keras.layers.Conv2D(
    filters=filters,
    kernel_size=kernel_size,
    padding='same',
)(inputs)

# Residual trunk: `blocks` residual blocks applied one after another.
x = conv0
for _ in range(blocks):
    x = res_block(x)

# Feature fusion: one more convolution, then a long skip connection from the stem.
conv1 = keras.layers.Conv2D(
    filters=filters,
    kernel_size=kernel_size,
    padding='same',
)(x)
mix = keras.layers.Add()([conv1, conv0])

# Upsampler: 4 channels are rearranged by DepthToSpace into a 2x larger single-channel image.
features = keras.layers.Conv2D(
    filters=4,
    kernel_size=kernel_size,
    padding='same',
)(mix)
outputs = DepthToSpace()(features)

# Wrap the graph in a Model, print its layers and draw its diagram.
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
keras.utils.plot_model(model, show_shapes=True)

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/tmp/r16f96.keras /content/r16f96.keras
!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
# Data Augmentation
import cv2
import numpy as np
import glob
import os
from pathlib import Path
from tqdm import tqdm

# Offline augmentation: every source image in /content/HR is replaced by eight
# variants (4 rotations x {original, horizontally flipped}).
rotations = [0, 90, 180, 270]
hr_dir = Path('/content/HR')

# Stem suffixes appended to the files this cell writes. The "_r" separator keeps
# output names unambiguous: with plain concatenation "1.png" rotated by 0 would
# be written as "10.png" and could overwrite a different source image.
augmented_suffixes = tuple(
    f"_r{rotation}{flip_tag}" for rotation in rotations for flip_tag in ("", "f")
)

filelist = sorted(glob.glob('/content/HR/*.png'))

for myFile in tqdm(filelist):
    stem = Path(myFile).stem

    # Skip outputs of a previous run so re-executing the cell does not augment
    # the augmented images again. A source file whose own name already ends in
    # one of these suffixes is therefore left untouched.
    if stem.endswith(augmented_suffixes):
        continue

    img = cv2.imread(myFile, cv2.IMREAD_UNCHANGED)

    if img is None:
        print(f"Error reading image: {myFile}")
        continue

    all_written = True
    for flip_tag, base_img in (("", img), ("f", cv2.flip(img, 1))):
        for rotation in rotations:
            out_path = hr_dir / f"{stem}_r{rotation}{flip_tag}.png"
            rotated_img = np.rot90(base_img, rotation // 90)
            if not cv2.imwrite(str(out_path), rotated_img):
                print(f"Error writing image: {out_path}")
                all_written = False

    # Delete the source only once all eight variants are on disk; the
    # "_r0" variant is the unmodified copy that takes its place.
    if all_written:
        os.remove(myFile)

In [None]:
import tensorflow as tf
import glob
import cv2
import random

def psnr_metric(y_true, y_pred):
    """Peak signal-to-noise ratio metric for images with a peak value of 1.0.

    Args:
        y_true: Reference images.
        y_pred: Predicted images.

    Returns:
        The result of ``tf.image.psnr`` for the two inputs.
    """
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

def data_generator(filelist, batch_size):
    """Endlessly yield ``(low_res_batch, high_res_batch)`` luma training pairs.

    Each file is read with OpenCV, converted to grayscale and scaled to
    ``[0, 1]`` float32. That image is the target; the network input is the same
    image downscaled by 2 in each dimension. The file order is reshuffled at
    the start of every pass over the data.

    Args:
        filelist: Paths of the high-resolution training images. The caller's
            list is not modified.
        batch_size: Number of images per batch (>= 1). The last batch of a pass
            may be smaller when the file count is not a multiple of it.

    Yields:
        Tuple ``(inputs, targets)`` of float32 arrays shaped
        ``(n, h/2, w/2, 1)`` and ``(n, h, w, 1)``.
        NOTE(review): stacking assumes every image in a batch has the same
        size - confirm against the dataset.

    Raises:
        ValueError: If ``batch_size`` is below 1 or ``filelist`` is empty
            (either would otherwise loop forever without yielding). Raised on
            the first ``next()`` because this is a generator function.
        OSError: If an image cannot be read.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Private copy: shuffling must not reorder the list owned by the caller.
    files = list(filelist)
    if not files:
        raise ValueError("data_generator needs at least one image file")
    n_samples = len(files)

    while True:
        random.shuffle(files)

        for start in range(0, n_samples, batch_size):
            train_ref_batch = []
            train_in_batch = []

            for file in files[start:start + batch_size]:
                image = cv2.imread(file, cv2.IMREAD_COLOR)
                if image is None:
                    raise OSError(f"Could not read image: {file}")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # Luma
                image = image.astype(np.float32) / 255.0

                # Trim odd sizes by one pixel so the target is exactly twice
                # the size of the downscaled input; even sizes are unchanged.
                height, width = image.shape
                image = np.ascontiguousarray(
                    image[:height - height % 2, :width - width % 2]
                )

                # At a factor of exactly 0.5, linear interpolation averages
                # each 2x2 block (box downscale).
                in_image = cv2.resize(
                    image, None, fx=0.5, fy=0.5,
                    interpolation=cv2.INTER_LINEAR_EXACT,
                )
                in_image = np.clip(in_image, 0.0, 1.0)

                # Append the trailing channel axis expected by the model.
                train_ref_batch.append(image[..., np.newaxis])
                train_in_batch.append(in_image[..., np.newaxis])

            yield np.array(train_in_batch), np.array(train_ref_batch)


# Collect the training images and set up batching.
filelist = sorted(glob.glob('/content/HR/*.png'))
batch_size = 8
steps_per_epoch = len(filelist) // batch_size  # whole batches per pass over the files
train_generator = data_generator(filelist, batch_size)

# Mean-absolute-error loss with AdamW, PSNR reported as a metric.
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=0.000025),
    loss=keras.losses.MeanAbsoluteError(),
    metrics=[psnr_metric],
)

# Start from the staged checkpoint; skip_mismatch=True is passed so layers whose
# weights do not match the checkpoint are skipped.
model.load_weights("/content/r16f96.keras", skip_mismatch=True)

In [None]:
# Train for 10 epochs; the generator is endless, so steps_per_epoch bounds each epoch.
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=10,
    verbose=1,
)

In [None]:
# Make a single prediction
# The image is preprocessed like the training inputs: BGR -> grayscale luma,
# scaled to [0, 1] float32. Local names avoid shadowing the builtin `input`.
lr_path = '/content/downscaled.png'
lr_image = cv2.imread(lr_path, cv2.IMREAD_COLOR)
if lr_image is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f"Could not read image: {lr_path}")

# No third positional argument: in the Python binding that slot is `dst`.
luma = cv2.cvtColor(lr_image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

# Add the batch and channel axes: (H, W) -> (1, H, W, 1).
luma_batch = luma[np.newaxis, ..., np.newaxis]

pred = model.predict(luma_batch)
# Index the batch/channel axes explicitly; np.squeeze would also drop a
# height or width of 1.
pred = np.clip(pred[0, ..., 0], 0.0, 1.0)
pred = np.around(pred * 255.0).astype(np.uint8)

if not cv2.imwrite('/content/prediction.png', pred):
    raise OSError("Could not write /content/prediction.png")

In [None]:
# Make a single RGB prediction
# The single-channel model is applied to each colour channel independently and
# the three results are stacked back together. Local names avoid shadowing the
# builtin `input`.
rgb_path = '/content/aoko.png'
bgr_image = cv2.imread(rgb_path, cv2.IMREAD_COLOR)
if bgr_image is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f"Could not read image: {rgb_path}")
bgr_image = bgr_image.astype(np.float32) / 255.0

pred_channels = []
for channel in cv2.split(bgr_image):  # OpenCV channel order: B, G, R
    # Add the batch and channel axes: (H, W) -> (1, H, W, 1).
    channel_batch = channel[np.newaxis, ..., np.newaxis]
    # Index the batch/channel axes explicitly; np.squeeze would also drop a
    # height or width of 1.
    pred_channels.append(model.predict(channel_batch)[0, ..., 0])

# Stack in the same B, G, R order that cv2.imwrite expects.
pred = np.stack(pred_channels, axis=-1)
pred = np.clip(pred, 0.0, 1.0)
pred = np.around(pred * 255.0).astype(np.uint8)

if not cv2.imwrite('/content/prediction.png', pred):
    raise OSError("Could not write /content/prediction.png")

In [None]:
!pip install --upgrade onnxsim
!pip install tf2onnx
!pip install onnx

import tensorflow as tf
import tf2onnx
import onnx

input_signature = [tf.TensorSpec([1, None, None, 1], tf.float32, name='input')]
onnx_model, _ = tf2onnx.convert.from_keras(model=model, input_signature=input_signature, inputs_as_nchw=['input'], outputs_as_nchw=['depth_to_space'])
onnx.save(onnx_model, "/content/r16f96_dev.onnx")
!onnxsim r16f96_dev.onnx r16f96_dev.onnx

model.save('/content/r16f96.keras')
!cp /content/r16f96.keras /content/drive/MyDrive/tmp/r16f96.keras
!cp /content/r16f96.onnx /content/drive/MyDrive/tmp/r16f96.onnx