In [None]:
#!pip install --upgrade keras
#!pip install --upgrade tensorflow
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np
import cv2
import tensorflow as tf

# For JAX backend
# class DepthToSpace(keras.layers.Layer):
#     def __init__(self, block_size):
#         super().__init__()
#         self.block_size = block_size

#     def call(self, input):
#         batch, height, width, depth = keras.ops.shape(input)
#         depth = depth // (self.block_size**2)
#         x = keras.ops.reshape(input, [batch, height, width, self.block_size, self.block_size, depth])
#         x = keras.ops.transpose(x, [0, 1, 3, 2, 4, 5])
#         x = keras.ops.reshape(x, [batch, height * self.block_size, width * self.block_size, depth])
#         x = keras.activations.relu(x, max_value=1.0)
#         return x

class DepthToSpace(keras.layers.Layer):
    """Pixel-shuffle upsampler: moves channel blocks into space, then clamps to [0, 1].

    Wraps ``tf.nn.depth_to_space``: an input of shape ``(N, H, W, C * b * b)`` becomes
    ``(N, H * b, W * b, C)``. The result is clipped because the network is trained on
    images normalized to [0, 1].

    Args:
        block_size: Spatial upscaling factor ``b``. Defaults to 2, which matches the
            4-channel ``features`` layer feeding it and the 2x shaders generated later.
        **kwargs: Standard ``keras.layers.Layer`` arguments (``name``, ``dtype``, ...).
            Accepting them, together with ``get_config``, lets Keras rebuild the layer
            when a saved model is loaded (pass ``custom_objects={"DepthToSpace": DepthToSpace}``).
    """

    def __init__(self, block_size=2, **kwargs):
        super().__init__(**kwargs)
        self.block_size = block_size

    def call(self, x):
        x = tf.nn.depth_to_space(x, self.block_size)
        return tf.clip_by_value(x, 0.0, 1.0)

    def get_config(self):
        # Serialize the constructor argument so from_config() can call cls(**config).
        config = super().get_config()
        config.update({"block_size": self.block_size})
        return config

# Settings
filters = 32  # feature channels of every hidden convolution (shown as F{filters} in shader names)


def conv3x3(tensor, activation=None):
    """Apply a newly created 3x3, 'same'-padded Conv2D with `filters` output channels."""
    layer = keras.layers.Conv2D(filters=filters, kernel_size=3, padding='same', activation=activation)
    return layer(tensor)


# Build the model. Layers are created in the same order as before, so Keras'
# auto-generated names (conv2d, conv2d_1, ..., conv2d_6) that the shader
# generators rely on stay the same.
inputs = keras.layers.Input(shape=(None, None, 1))  # one channel, any spatial size
conv0 = conv3x3(inputs)

# Internal convolutions (ReLU)
conv1 = conv3x3(conv0, activation='relu')
conv2 = conv3x3(conv1, activation='relu')
conv3 = conv3x3(conv2, activation='relu')
conv4 = conv3x3(conv3, activation='relu')

# Feature fusion: global skip connection back to conv0
mix_global = conv3x3(conv4)
add_global = keras.layers.Add()([mix_global, conv0])

# Upsampler: 4 channels rearranged into a 2x2 spatial block
features = keras.layers.Conv2D(filters=4, kernel_size=3, padding='same')(add_global)
outputs = DepthToSpace()(features)

# Defining the model
model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()
keras.utils.plot_model(model, show_shapes=True)
# model.load_weights("/content/c4f32.keras", skip_mismatch=True)

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/tmp/c4f32.keras /content/c4f32.keras
!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
# Single Dataset Gray: build (LR, HR) training pairs from one folder of HR images by
# downscaling each image by 2x in memory.
import glob
import cv2

filelist = sorted(glob.glob('/content/HR/*.png'))
train_ref = []
train_in = []

for my_file in filelist:
    image = cv2.imread(my_file, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None; fail with the file name.
        raise FileNotFoundError(f"cv2.imread could not read {my_file}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY, 0)
    # Crop to even dimensions so the model's 2x output of the LR image has exactly
    # the HR shape. Odd sizes would round the half-size resize and mismatch.
    image = image[: image.shape[0] // 2 * 2, : image.shape[1] // 2 * 2]
    image = np.clip(image.astype(np.float32) / 255.0, 0.0, 1.0)
    train_ref.append(image)
    # Keep the LR image in its own variable: reassigning `train_in` would replace the
    # list being built, and the LR image (not the HR one) is the network input.
    image_lr = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT)
    train_in.append(np.clip(image_lr, 0.0, 1.0))

train_ref = np.expand_dims(np.array(train_ref), axis=-1)  # (count, H, W, 1)
print(train_ref.shape)

train_in = np.expand_dims(np.array(train_in), axis=-1)  # (count, H/2, W/2, 1)
print(train_in.shape)

In [None]:
# Separate HR and LR: load pre-made LR inputs and HR references from two folders.
# Inputs and references are paired only by sorted file-name order.
import glob
import cv2


def load_gray_stack(pattern):
    """Load all files matching `pattern` (sorted by name) as grayscale in [0, 1].

    Returns a float32 array of shape (count, H, W, 1). np.array needs every image
    to share H and W to stack them.
    """
    planes = [
        cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2GRAY, 0)
        for path in sorted(glob.glob(pattern))
    ]
    stack = np.clip(np.array(planes).astype(np.float32) / 255.0, 0.0, 1.0)
    return np.expand_dims(stack, axis=-1)


train_in = load_gray_stack('/content/LR/*.png')
print(train_in.shape)

train_ref = load_gray_stack('/content/HR/*.png')
print(train_ref.shape)

In [None]:
# Train the model with AdamW and an L1 (mean absolute error) loss.
optimizer = keras.optimizers.AdamW(learning_rate=0.00005)
loss_fn = keras.losses.MeanAbsoluteError()
model.compile(optimizer=optimizer, loss=loss_fn)
history = model.fit(
    train_in,
    train_ref,
    epochs=500,
    batch_size=8,
    verbose=1,
)

In [None]:
# Make a single prediction: upscale one grayscale image by 2x and write it to disk.
lr_image = cv2.imread('/content/downscaled.png', cv2.IMREAD_COLOR)
lr_image = cv2.cvtColor(lr_image, cv2.COLOR_BGR2GRAY, 0)
# Normalize to [0, 1] and add batch and channel axes -> (1, H, W, 1).
lr_batch = np.clip(lr_image.astype(np.float32) / 255.0, 0.0, 1.0)[np.newaxis, ..., np.newaxis]

pred = model.predict(lr_batch)
# Back to an 8-bit (2H, 2W) image: clamp, drop singleton axes, rescale, round.
pred = np.around(np.squeeze(np.clip(pred, 0.0, 1.0)) * 255.0).astype(np.uint8)

cv2.imwrite('/content/prediction.png', pred)

In [None]:
# Make a single RGB prediction: the model is single-channel, so run it on the B, G
# and R planes independently and stack the results back into a colour image.
bgr = cv2.imread('/content/aoko.png', cv2.IMREAD_COLOR)
bgr = np.clip(np.array(bgr).astype(np.float32) / 255.0, 0.0, 1.0)

# One predict call per plane, in B, G, R order. Each result is (1, 2H, 2W, 1).
channel_preds = [model.predict(plane[np.newaxis, ..., np.newaxis]) for plane in cv2.split(bgr)]
pred = np.stack(channel_preds, axis=-1)
pred = np.around(np.squeeze(np.clip(pred, 0.0, 1.0)) * 255.0).astype(np.uint8)

cv2.imwrite('/content/prediction.png', pred)

In [None]:
#Generate fragment shader
import numpy as np

def generate_shader_code(current_layer, previous_layer, channels_in, channels_out):
    """Generate mpv-style fragment-shader hook passes for one Keras layer.

    Each pass produces 4 output channels (one RGBA texture), so a layer with
    ``channels_out`` channels yields ``ceil(channels_out / 4)`` passes. Its input is
    read from ``ceil(channels_in / 4)`` textures saved by the previous layer.

    Args:
        current_layer: Layer to translate. Conv layers (name containing "conv2d")
            must provide ``get_weights()`` returning ``[kernel, bias]``, with the
            kernel indexed ``[ky, kx, in, out]``. A layer whose name contains
            "depth" becomes the final 2x pixel-shuffle pass.
        previous_layer: Layer feeding ``current_layer``. "input_layer" is mapped to
            the LUMA plane. An "add" layer is expanded into its two operands
            (conv2d and conv2d_5), whose sum is computed in the shader.
        channels_in: Number of input channels of the layer.
        channels_out: Number of output channels of the layer.

    Returns:
        GLSL source for all passes of this layer, as one string.

    Note: reads the module-level ``filters`` for the ``//!DESC`` line.
    """
    passes_in = int(np.ceil(channels_in / 4))
    passes_out = int(np.ceil(channels_out / 4))

    # Map the Keras input layer to LUMA locally. Do not assign to
    # previous_layer.name: that would rename a layer of the live model.
    prev_name = "LUMA" if previous_layer.name == "input_layer" else previous_layer.name
    cur_name = current_layer.name
    is_conv = "conv2d" in cur_name
    is_depth = "depth" in cur_name
    # conv2d_1 ... conv2d_4 are the ReLU-activated internal convolutions.
    is_relu = any(n in cur_name for n in ("conv2d_1", "conv2d_2", "conv2d_3", "conv2d_4"))
    relu_tag = "-ReLU" if is_relu else ""
    desc_line = f"//!DESC ArtCNN C4F{filters} ({cur_name.title().replace('_', '-')}{relu_tag})\n"

    if is_conv:
        # Fetch weights once; get_weights() copies every tensor on each call.
        layer_weights = current_layer.get_weights()
        kernel, bias = layer_weights[0], layer_weights[1]

    shader_code = ""
    for pass_idx in range(passes_out):
        out_lo, out_hi = pass_idx * 4, (pass_idx + 1) * 4
        shader_code += desc_line
        shader_code += "//!HOOK LUMA\n"

        if prev_name == "LUMA":
            shader_code += f"//!BIND {prev_name}\n"
        elif "add" in prev_name:
            # The Add layer has no pass of its own: bind both operands and sum them below.
            for i in range(passes_in):
                shader_code += f"//!BIND conv2d_{i}\n"
                shader_code += f"//!BIND conv2d_5_{i}\n"
        elif is_conv:
            for i in range(passes_in):
                shader_code += f"//!BIND {prev_name}_{i}\n"
        elif is_depth:
            shader_code += f"//!BIND {prev_name}_0\n"

        if is_depth:
            # Final pass renders at 2x size and is not saved as an intermediate texture.
            shader_code += "//!WIDTH LUMA.w 2.0 *\n"
            shader_code += "//!HEIGHT LUMA.h 2.0 *\n"
        else:
            shader_code += f"//!SAVE {cur_name}_{pass_idx}\n"
            shader_code += "//!WIDTH LUMA.w\n"
            shader_code += "//!HEIGHT LUMA.h\n"

        shader_code += "//!COMPONENTS 4\n"
        # RPN condition: run only when upscaling by more than 1.3x in both axes.
        shader_code += "//!WHEN OUTPUT.w LUMA.w / 1.3 > OUTPUT.h LUMA.h / 1.3 > *\n\n"
        shader_code += "vec4 hook() {\n"

        if is_conv:
            biases_str = ", ".join(str(w) for w in bias[out_lo:out_hi].flatten())
            shader_code += f"    vec4 result = vec4({biases_str});\n"

            for z in range(passes_in):
                for y in range(-1, 2):
                    for x in range(-1, 2):
                        # (in, out) slice of the kernel tap at offset (x, y). Row-major
                        # flattening fills a GLSL mat4 column by column, so column j holds
                        # input channel j's weights and mat4 * vec4(input) gives the 4 outputs.
                        weights = kernel[y + 1, x + 1, z * 4:(z + 1) * 4, out_lo:out_hi]
                        weights_str = ", ".join(str(w) for w in weights.flatten())
                        if not weights_str:
                            continue
                        if prev_name == "LUMA":
                            shader_code += f"    result += vec4({weights_str}) * {prev_name}_texOff(vec2({x}, {y})).x;\n"
                        elif "add" in prev_name:
                            shader_code += f"    result += mat4({weights_str}) * (conv2d_5_{z}_texOff(vec2({x}, {y})) + conv2d_{z}_texOff(vec2({x}, {y})));\n"
                        else:
                            shader_code += f"    result += mat4({weights_str}) * {prev_name}_{z}_texOff(vec2({x}, {y}));\n"

            if is_relu:
                shader_code += "    return max(result, vec4(0.0));\n"
            else:
                shader_code += "    return result;\n"

        elif is_depth:
            # Pick the channel i0.y * 2 + i0.x for this output pixel's position in its 2x2 block.
            shader_code += "    vec4 result = vec4(0.0, 0.0, 0.0, 1.0);\n"
            shader_code += f"    vec2 f0 = fract({prev_name}_0_pos * {prev_name}_0_size);\n"
            shader_code += "    ivec2 i0 = ivec2(f0 * vec2(2.0));\n"
            shader_code += f"    result.x = {prev_name}_0_tex((vec2(0.5) - f0) * {prev_name}_0_pt + {prev_name}_0_pos)[i0.y * 2 + i0.x];\n"
            shader_code += "    return clamp(result, 0.0, 1.0);\n"

        shader_code += "}\n\n"
    return shader_code

################################################################################
# Assemble fragment.glsl: the license header followed by the passes of each layer.
# The Add layer emits no pass of its own; its operands are fused into the next conv.
filters = model.layers[1].filters  # model.layers[0] is the input layer
FRAGMENT_LICENSE = """// MIT License

// Copyright (c) 2024 Joao Chrisostomo

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

"""


def _fragment_channels(layer_name):
    """Return (channels_in, channels_out) for a layer, or None if it emits no pass."""
    if layer_name == "conv2d":
        return 1, filters
    if "conv2d_6" in layer_name:
        return filters, 4
    if "conv2d" in layer_name:
        return filters, filters
    if "depth_to_space" in layer_name:
        return 4, 1
    return None


shader_code = FRAGMENT_LICENSE
for prev_layer, layer in zip(model.layers, model.layers[1:]):
    channels = _fragment_channels(layer.name)
    if channels is not None:
        shader_code += generate_shader_code(layer, prev_layer, *channels)

print(shader_code)
with open("fragment.glsl", mode="w") as f:
    f.write(shader_code)

In [None]:
#Generate compute shader
import math
import numpy as np

def find_rect(area):
    """Factor ``area`` into the most nearly square pair ``(w, h)`` with ``w >= h``.

    Used to lay out ``area`` RGBA tiles (4 channels each) as a 2-D grid.

    Args:
        area: Positive integral count. An int, or an integral float such as the
            result of ``np.ceil``.

    Returns:
        Tuple of ints ``(w, h)`` with ``w * h == area`` and ``w >= h``.

    Raises:
        ValueError: if ``area`` is not a positive integer value. Previously such
            input silently returned None.
    """
    n = int(area)
    if n != area or n < 1:
        raise ValueError(f"find_rect expects a positive integer area, got {area!r}")
    # The largest divisor not exceeding sqrt(n) gives the squarest factorization;
    # h == 1 always divides, so the loop always returns.
    for h in range(math.isqrt(n), 0, -1):
        if n % h == 0:
            return n // h, h

def get_tile_off(i, w):
    """Return the (x, y) coordinates of tile ``i`` in a row-major grid ``w`` tiles wide."""
    row, col = divmod(i, w)
    return col, row

def generate_shader_code(current_layer, previous_layer, channels_in, channels_out):
    """Generate one compute-shader hook pass for a Keras layer.

    All output channels of a layer are written by a single pass. Each group of 4
    channels is one RGBA texel, and the groups are laid out as a ``tile_w x tile_h``
    grid per source pixel, so the saved texture is ``LUMA.w * tile_w`` by
    ``LUMA.h * tile_h``. Convolution inputs are first staged in workgroup ``shared``
    memory (``inp``); then the 3x3 taps are accumulated.

    Args:
        current_layer: Layer to translate. Conv layers (name containing "conv2d")
            must provide ``get_weights()`` returning ``[kernel, bias]``, with the
            kernel indexed ``[ky, kx, in, out]``. A layer whose name contains
            "depth" becomes the final 2x pixel-shuffle pass.
        previous_layer: Layer feeding ``current_layer``. "input_layer" is mapped
            to LUMA. An "add" layer is expanded into its operands conv2d + conv2d_5.
        channels_in: Number of input channels of the layer.
        channels_out: Number of output channels of the layer.

    Returns:
        GLSL source of the pass, as a string.

    Note: reads the module-level ``filters``, ``find_rect`` and ``get_tile_off``.
    """
    # Map the Keras input layer to LUMA locally. Do not assign to
    # previous_layer.name: that would rename a layer of the live model.
    prev_name = "LUMA" if previous_layer.name == "input_layer" else previous_layer.name
    cur_name = current_layer.name
    # conv2d_1 ... conv2d_4 are the ReLU-activated internal convolutions.
    is_relu = any(n in cur_name for n in ["conv2d_1", "conv2d_2", "conv2d_3", "conv2d_4"])

    threads_w, threads_h = 12, 16
    tile_w, tile_h = find_rect(np.ceil(channels_out / 4.0))
    tile_in_w, tile_in_h = find_rect(np.ceil(channels_in / 4.0))

    shader_code = ""
    if is_relu:
        shader_code += f"//!DESC ArtCNN C4F{filters} ({cur_name.title().replace('_', '-')}-ReLU)\n"
    else:
        shader_code += f"//!DESC ArtCNN C4F{filters} ({cur_name.title().replace('_', '-')})\n"
    shader_code += f"//!COMPUTE {threads_w * tile_w} {threads_h * tile_h} {threads_w} {threads_h}\n"
    shader_code += "//!HOOK LUMA\n"

    if prev_name == "LUMA":
        shader_code += f"//!BIND {prev_name}\n"
    elif "add" in prev_name:
        # The Add layer has no pass of its own: bind both operands and sum them on load.
        shader_code += "//!BIND conv2d\n"
        shader_code += "//!BIND conv2d_5\n"
    elif "conv2d" in cur_name or "depth" in cur_name:
        shader_code += f"//!BIND {prev_name}\n"

    if "depth" in cur_name:
        shader_code += "//!WIDTH LUMA.w 2.0 *\n"
        shader_code += "//!HEIGHT LUMA.h 2.0 *\n"
    else:
        shader_code += f"//!SAVE {cur_name}\n"
        shader_code += f"//!WIDTH LUMA.w {float(tile_w)} *\n"
        shader_code += f"//!HEIGHT LUMA.h {float(tile_h)} *\n"

    shader_code += "//!COMPONENTS 4\n"
    shader_code += "//!WHEN OUTPUT.w LUMA.w / 1.3 > OUTPUT.h LUMA.h / 1.3 > *"
    # Use float16 storage and arithmetic when the extension is available.
    shader_code += """
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable
#ifdef GL_EXT_shader_explicit_arithmetic_types_float16
#\tdefine V4 f16vec4
#\tdefine M4 f16mat4
#\tdefine F float16_t
#else
#\tdefine V4 vec4
#\tdefine M4 mat4
#\tdefine F float
#endif

"""

    if "conv2d" in cur_name:
        # Fetch weights once; get_weights() copies every tensor on each call.
        layer_weights = current_layer.get_weights()
        kernel, bias = layer_weights[0], layer_weights[1]

        # LUMA is a single channel: store scalars and use vec4 weights instead of mat4.
        storage = "V4"
        weights_storage = "M4"
        load_suffix = ""
        if prev_name == "LUMA":
            storage = "F"
            weights_storage = "V4"
            load_suffix = ".x"

        shader_code += "const ivec2 ksize = ivec2(3, 3);\n"
        shader_code += "const ivec2 offset = ksize / 2;\n"
        shader_code += "const ivec2 wg_size = ivec2(gl_WorkGroupSize);\n"
        shader_code += "const ivec2 isize = wg_size + ksize - 1;\n"
        shader_code += f"shared {storage} inp[{tile_in_w * tile_in_h}][isize.y][isize.x];\n"

        # Cooperative load of the workgroup's input patch (plus a 1-pixel apron) into shared memory.
        shader_code += "void hook() {\n"
        shader_code += "    const uvec2 local_xy = gl_LocalInvocationID.xy;\n"
        shader_code += "    ivec2 base = ivec2(gl_WorkGroupID) * wg_size;\n"
        shader_code += "    for (uint y = local_xy.y; y < isize.y; y += wg_size.y) {\n"
        shader_code += "        for (uint x = local_xy.x; x < isize.x; x += wg_size.x) {\n"
        shader_code += f"            const ivec2 input_base = (base + ivec2(x,y) - offset) * ivec2({tile_in_w}, {tile_in_h});\n"
        for z in range(tile_in_w * tile_in_h):
            x, y = get_tile_off(z, tile_in_w)
            if "add" in prev_name:
                shader_code += f"            inp[{z}][y][x] = {storage}(conv2d_5_mul * texelFetch(conv2d_5_raw, input_base + ivec2({x}, {y}), 0) + conv2d_mul * texelFetch(conv2d_raw, input_base + ivec2({x}, {y}), 0){load_suffix});\n"
            else:
                shader_code += f"            inp[{z}][y][x] = {storage}({prev_name}_mul * texelFetch({prev_name}_raw, input_base + ivec2({x}, {y}), 0){load_suffix});\n"

        shader_code += "        }\n"
        shader_code += "    }\n"
        shader_code += "\n    barrier();\n"

        for pass_idx in range(tile_w * tile_h):
            biases = bias[pass_idx*4:(pass_idx+1)*4]
            biases_str = ", ".join(str(w) for w in biases.flatten())
            shader_code += f"    V4 result{pass_idx} = V4({biases_str});\n"

        for z in range(tile_in_w * tile_in_h):
            for y in range(0, 3):
                for x in range(0, 3):
                    shader_code += f"    const {storage} inp_{z}_{x}_{y} = inp[{z}][local_xy.y + {y}][local_xy.x + {x}];\n"
            for pass_idx in range(tile_w * tile_h):
                for y in range(0, 3):
                    for x in range(0, 3):
                        weights = kernel[y, x, z*4:(z+1)*4, pass_idx*4:(pass_idx+1)*4]
                        weights_str = ", ".join(str(w) for w in weights.flatten())
                        if weights_str:
                            shader_code += f"    result{pass_idx} += {weights_storage}({weights_str}) * inp_{z}_{x}_{y};\n"

        shader_code += f"    const ivec2 output_base = ivec2(gl_GlobalInvocationID) * ivec2({tile_w}, {tile_h});\n"
        for pass_idx in range(tile_w * tile_h):
            x, y = get_tile_off(pass_idx, tile_w)
            if is_relu:
                shader_code += f"    imageStore(out_image, output_base + ivec2({x}, {y}), max(result{pass_idx}, {storage}(0.0)));\n"
            else:
                shader_code += f"    imageStore(out_image, output_base + ivec2({x}, {y}), result{pass_idx});\n"

    elif "depth" in cur_name:
        # Pick the channel i0.y * 2 + i0.x for this output pixel's position in its 2x2 block.
        shader_code += "void hook() {\n"
        shader_code += "    vec4 result = vec4(0.0, 0.0, 0.0, 1.0);\n"
        shader_code += f"    vec2 f0 = fract({prev_name}_pos * {prev_name}_size);\n"
        shader_code += "    ivec2 i0 = ivec2(f0 * vec2(2.0));\n"
        shader_code += f"    result.x = {prev_name}_tex((vec2(0.5) - f0) * {prev_name}_pt + {prev_name}_pos)[i0.y * 2 + i0.x];\n"
        shader_code += "    imageStore(out_image, ivec2(gl_GlobalInvocationID), clamp(result, 0.0, 1.0));\n"
    shader_code += "}\n\n"

    return shader_code

################################################################################
# Assemble compute.glsl: the license header followed by one pass per layer.
# The Add layer emits no pass of its own; its operands are fused into the next conv.
filters = model.layers[1].filters  # model.layers[0] is the input layer
COMPUTE_LICENSE = """// MIT License

// Copyright (c) 2024 Joao Chrisostomo, Kacper Michajłow

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:

// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

"""


def _compute_channels(layer_name):
    """Return (channels_in, channels_out) for a layer, or None if it emits no pass."""
    exact = {"conv2d": (1, filters), "conv2d_6": (filters, 4), "depth_to_space": (4, 1)}
    if layer_name in exact:
        return exact[layer_name]
    if "conv2d_" in layer_name:
        return filters, filters
    return None


shader_code = COMPUTE_LICENSE
for prev_layer, layer in zip(model.layers, model.layers[1:]):
    channels = _compute_channels(layer.name)
    if channels is not None:
        shader_code += generate_shader_code(layer, prev_layer, *channels)

print(shader_code)
with open("compute.glsl", mode="w", encoding="utf-8") as f:
    f.write(shader_code)

In [None]:
!pip install tf2onnx
!pip install onnx

import tensorflow as tf
import tf2onnx
import onnx

input_signature = [tf.TensorSpec([1, None, None, 1], tf.float32, name='input')]
onnx_model, _ = tf2onnx.convert.from_keras(model=model, input_signature=input_signature, inputs_as_nchw=['input'], outputs_as_nchw=['depth_to_space'])
onnx.save(onnx_model, "/content/c4f32_dft.onnx")

model.save('/content/c4f32.keras')
!cp /content/c4f32.keras /content/drive/MyDrive/tmp/c4f32.keras
!cp /content/c4f32.onnx /content/drive/MyDrive/tmp/c4f32.onnx