In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import cv2
import tensorflow as tf

# Settings
filters = 64  # number of Conv2D filters (feature channels) in the trunk and in every residual block
blocks = 8  # number of residual blocks stacked in the trunk
act = "relu"  # activation of the first two convolutions inside each residual block
kernel_size = 3  # spatial kernel size used by every Conv2D in the network

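# Alternative DepthToSpace built only from keras.ops, for use with the JAX backend
# (the active class below calls tf.nn directly). Note it requires block_size, e.g. DepthToSpace(2).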
# class DepthToSpace(keras.layers.Layer):
#     def __init__(self, block_size):
#         super().__init__()
#         self.block_size = block_size

#     def call(self, input):
#         batch, height, width, depth = keras.ops.shape(input)
#         depth = depth // (self.block_size**2)
#         x = keras.ops.reshape(input, [batch, height, width, self.block_size, self.block_size, depth])
#         x = keras.ops.transpose(x, [0, 1, 3, 2, 4, 5])
#         x = keras.ops.reshape(x, [batch, height * self.block_size, width * self.block_size, depth])
#         x = keras.ops.clip(x, 0.0, 1.0)
#         return x

@keras.saving.register_keras_serializable(package="Custom")
class DepthToSpace(keras.layers.Layer):
    """Sub-pixel upsampling: move channel blocks into spatial blocks, then clip to [0, 1].

    Wraps ``tf.nn.depth_to_space`` (NHWC): an input of shape (batch, H, W, C) becomes
    (batch, H * block_size, W * block_size, C // block_size**2). The clip keeps the
    output in the same [0, 1] range the training targets are normalised to.

    Args:
        block_size: Spatial upscaling factor. Defaults to 2, the previously hard-coded
            value, so ``DepthToSpace()`` behaves exactly as before.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``, ``dtype``, ...).
            Accepting them lets Keras re-create the layer from its saved config.
    """

    def __init__(self, block_size=2, **kwargs):
        super().__init__(**kwargs)
        self.block_size = block_size

    def call(self, x):
        x = tf.nn.depth_to_space(x, self.block_size)
        x = keras.ops.clip(x, 0.0, 1.0)
        return x

    def get_config(self):
        # Include block_size with the base Layer config so model.save() / load_model() round-trip.
        config = super().get_config()
        config.update({"block_size": self.block_size})
        return config

def res_block(input, filters=filters, kernel_size=kernel_size):
    """Append one residual block to the graph and return its output tensor.

    The residual branch is three same-padded Conv2D layers: the first two use the
    module-level ``act`` activation and the last is linear. The branch result is
    added back onto ``input``, so ``input`` must already have ``filters`` channels.
    """
    branch = input
    for activation in (act, act, None):
        branch = keras.layers.Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            padding='same',
            activation=activation,
        )(branch)
    return keras.layers.Add()([branch, input])

# Build the model:
# shallow conv -> `blocks` residual blocks -> fusion conv + long skip -> sub-pixel upsampler.
# NOTE(review): Keras auto-names layers in creation order (conv2d, conv2d_1, ...); saved weights
# are presumably matched against that structure, so keep the order of layer creation stable.
# Single-channel input of any height/width.
inputs = keras.layers.Input(shape=(None,None,1))
conv0 = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, padding='same')(inputs)

# Residual trunk: chain `blocks` residual blocks, each feeding the next.
x = conv0
for _ in range(blocks):
    x = res_block(x)

# Feature Fusion
# One more conv, then a long skip connection back to the shallow features from conv0.
conv1 = keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, padding='same')(x)
mix = keras.layers.Add()([conv1, conv0])

# Upsampler
# 4 channels = 2*2 sub-pixel positions x 1 output channel; DepthToSpace() rearranges them
# into a 2x larger single-channel image clipped to [0, 1].
features = keras.layers.Conv2D(filters=4, kernel_size=kernel_size, padding='same')(mix)
outputs = DepthToSpace()(features)

# Defining the model
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
keras.utils.plot_model(model, show_shapes=True)
# Optional: resume from previously saved weights (skip_mismatch skips layers whose weights don't match).
# model.load_weights("/content/r8f64.keras", skip_mismatch=True)

In [None]:
from google.colab import drive
drive.mount('/content/drive')
# !cp /content/drive/MyDrive/tmp/r8f64.keras /content/r8f64.keras
!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
# Single Dataset Gray
# Build (LR, HR) training pairs from one folder of HR PNGs: each image is converted to
# grayscale, kept as the reference, and downscaled 2x (INTER_LINEAR_EXACT) to form its input.
import glob
import cv2

filelist = sorted(glob.glob('/content/HR/*.png'))
if not filelist:
    raise FileNotFoundError("No PNG files found in /content/HR")
train_ref = []
train_in = []

for myFile in filelist:
    image = cv2.imread(myFile, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None (instead of raising) for missing or unreadable files.
        raise IOError(f"Could not read image: {myFile}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Crop to even height/width so the half-size input upsamples 2x back to exactly the
    # reference size; an odd dimension would make model output and target disagree.
    # (For the 4x path below, crop to a multiple of 4 instead.)
    height, width = image.shape[:2]
    image = image[:height - height % 2, :width - width % 2]
    train_ref.append(image)
    image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT)
    # image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT) # for 4x
    train_in.append(image)

# Stack to float32 in [0, 1] with a trailing channel axis: (N, H, W) -> (N, H, W, 1).
# np.array can only stack the list if all images share one size.
train_ref = np.array(train_ref).astype(np.float32) / 255.0
train_ref = np.clip(train_ref, 0.0, 1.0)
train_ref = np.expand_dims(train_ref, axis=-1)
print(train_ref.shape)

train_in = np.array(train_in).astype(np.float32) / 255.0
train_in = np.clip(train_in, 0.0, 1.0)
train_in = np.expand_dims(train_in, axis=-1)
print(train_in.shape)

In [None]:
# Separate HR and LR
# Load pre-made training pairs: grayscale LR inputs from /content/LR and grayscale HR
# references from /content/HR. Pairs are matched purely by sorted filename order.
import glob
import cv2

filelist1 = sorted(glob.glob('/content/LR/*.png'))
filelist2 = sorted(glob.glob('/content/HR/*.png'))
if not filelist1 or not filelist2:
    raise FileNotFoundError("Need PNG files in both /content/LR and /content/HR")
if len(filelist1) != len(filelist2):
    # Pairing is positional, so one missing file would silently misalign every later pair.
    raise ValueError(f"LR/HR count mismatch: {len(filelist1)} LR vs {len(filelist2)} HR images")

train_in = []
for myFile in filelist1:
    image = cv2.imread(myFile, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None (instead of raising) for missing or unreadable files.
        raise IOError(f"Could not read image: {myFile}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT) # for 4x
    train_in.append(image)
# float32 in [0, 1] with a trailing channel axis: (N, H, W) -> (N, H, W, 1).
train_in = np.array(train_in).astype(np.float32) / 255.0
train_in = np.clip(train_in, 0.0, 1.0)
train_in = np.expand_dims(train_in, axis=-1)
print(train_in.shape)

train_ref = []
for myFile in filelist2:
    image = cv2.imread(myFile, cv2.IMREAD_COLOR)
    if image is None:
        raise IOError(f"Could not read image: {myFile}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    train_ref.append(image)
train_ref = np.array(train_ref).astype(np.float32) / 255.0
train_ref = np.clip(train_ref, 0.0, 1.0)
train_ref = np.expand_dims(train_ref, axis=-1)
print(train_ref.shape)

# The network's DepthToSpace layer upsamples exactly 2x, so HR must be twice the LR size;
# fail here with a clear message rather than deep inside model.fit.
expected_ref_hw = (2 * train_in.shape[1], 2 * train_in.shape[2])
if tuple(train_ref.shape[1:3]) != expected_ref_hw:
    raise ValueError(f"HR size {tuple(train_ref.shape[1:3])} is not 2x LR size {tuple(train_in.shape[1:3])}")

In [None]:
# Train the model
# L1 (mean absolute error) loss between the upscaled output and the HR reference,
# optimised with AdamW at a small fixed learning rate.
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=2.5e-5),
    loss=keras.losses.MeanAbsoluteError(),
)
history = model.fit(
    x=train_in,
    y=train_ref,
    batch_size=16,
    epochs=100,
    verbose=1,
)

In [None]:
# Make a single prediction
# Upscale one image (as grayscale) with the trained model and write it as an 8-bit PNG.
# `lr_image` avoids shadowing the builtin `input` for the rest of the session.
lr_image = cv2.imread('/content/downscaled.png', cv2.IMREAD_COLOR)
if lr_image is None:
    # cv2.imread returns None (instead of raising) for missing or unreadable files.
    raise FileNotFoundError("Could not read /content/downscaled.png")
lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2GRAY)
lr_image = np.clip(lr_image.astype(np.float32) / 255.0, 0.0, 1.0)
# (H, W) -> (1, H, W, 1): add the batch and channel axes the model expects.
lr_batch = lr_image[np.newaxis, :, :, np.newaxis]

pred = model.predict(lr_batch)
# Take the single batch item / single channel, then map [0, 1] to rounded 8-bit values.
pred = np.clip(pred[0, :, :, 0], 0.0, 1.0)
pred = np.around(pred * 255.0).astype(np.uint8)

# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite('/content/prediction.png', pred):
    raise IOError("Failed to write /content/prediction.png")

In [None]:
# Make a single RGB prediction
# The model is single-channel, so each colour plane is upscaled independently and the
# planes are reassembled in OpenCV's B, G, R order before writing.
bgr_image = cv2.imread('/content/aoko.png', cv2.IMREAD_COLOR)
if bgr_image is None:
    # cv2.imread returns None (instead of raising) for missing or unreadable files.
    raise FileNotFoundError("Could not read /content/aoko.png")
bgr_image = np.clip(bgr_image.astype(np.float32) / 255.0, 0.0, 1.0)

# (H, W, 3) -> (3, H, W, 1): one batch item per colour plane, each with a single channel.
planes = np.transpose(bgr_image, (2, 0, 1))[..., np.newaxis]
# One predict call for all planes; batch_size=1 keeps peak memory the same as
# predicting each plane separately.
pred_planes = model.predict(planes, batch_size=1)

# (3, H', W', 1) -> (H', W', 3), still B, G, R. Explicit indexing (instead of squeeze)
# stays correct even for 1-pixel-high/wide images; contiguous layout for cv2.imwrite.
pred = np.ascontiguousarray(np.transpose(pred_planes[..., 0], (1, 2, 0)))
pred = np.around(np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)

# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite('/content/prediction.png', pred):
    raise IOError("Failed to write /content/prediction.png")

In [None]:
!pip install tf2onnx
!pip install onnx

import tensorflow as tf
import tf2onnx
import onnx

input_signature = [tf.TensorSpec([1, None, None, 1], tf.float32, name='input')]
onnx_model, _ = tf2onnx.convert.from_keras(model=model, input_signature=input_signature, inputs_as_nchw=['input'], outputs_as_nchw=['depth_to_space'])
onnx.save(onnx_model, "/content/r8f64.onnx")

model.save('/content/r8f64.keras')
!cp /content/r8f64.keras /content/drive/MyDrive/tmp/r8f64.keras
!cp /content/r8f64.onnx /content/drive/MyDrive/tmp/r8f64.onnx