In [None]:
import glob
import os
import random
import time

import cv2
import dm_pix as pix
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import linen as nn
from flax import serialization
from tqdm import tqdm

# Settings: network hyperparameters read by ResBlock and ArtCNN at call time.
#   filters     - channel count of every Conv except ArtCNN's final 4-channel one
#   blocks      - number of ResBlocks in the ArtCNN trunk
#   kernel_size - spatial kernel shared by every Conv layer
filters, blocks, kernel_size = 64, 8, (3, 3)

class ResBlock(nn.Module):
    """Residual block: three Conv layers with SiLU between them, plus an identity skip.

    Width and kernel come from the module-level ``filters`` / ``kernel_size``
    settings. The skip adds the block input to the last conv's
    ``filters``-channel output, so the input is expected to carry ``filters``
    channels as well.
    """

    @nn.compact
    def __call__(self, input):
        # Convs are created in the order conv -> silu -> conv -> silu -> conv, so Flax
        # auto-names them Conv_0..Conv_2 exactly as before (checkpoint compatible).
        h = input
        for layer_idx in range(3):
            if layer_idx > 0:
                h = nn.silu(h)
            h = nn.Conv(features=filters, kernel_size=kernel_size)(h)
        return input + h

class ArtCNN(nn.Module):
    """2x super-resolution network operating on channels-last (NHWC) images.

    Layout: stem Conv -> ``blocks`` ResBlocks -> Conv, with a long skip from the
    stem. A final 4-channel Conv is rearranged by ``pix.depth_to_space(x, 2)``
    into one channel at twice the height and width. The output is clipped to
    [0, 1]. In this notebook the input is single-channel luma.
    """

    @nn.compact
    def __call__(self, input):
        # Submodule creation order (Conv, ResBlock x blocks, Conv, Conv) is kept so the
        # auto-generated parameter names match previously saved checkpoints.
        stem = nn.Conv(features=filters, kernel_size=kernel_size)(input)
        trunk = stem
        for _ in range(blocks):
            trunk = ResBlock()(trunk)
        fused = nn.Conv(features=filters, kernel_size=kernel_size)(trunk) + stem
        sub_pixels = nn.Conv(features=4, kernel_size=kernel_size)(fused)  # 2*2 sub-pixel channels
        return jnp.clip(pix.depth_to_space(sub_pixels, 2), 0.0, 1.0)

# Build the model, print a layer/FLOP summary, and initialise parameters
# from a dummy batch of one 128x128 single-channel image (NHWC).
model = ArtCNN()
dummy_input = jnp.ones((1, 128, 128, 1))
print(model.tabulate(jax.random.key(0), dummy_input, compute_flops=True, compute_vjp_flops=True))

key = jax.random.PRNGKey(0)
variables = model.init(key, dummy_input)
params = variables['params']  # only the 'params' collection is trained below

In [2]:
# Resume from a saved Flax msgpack checkpoint when one is present. On a fresh
# kernel or session the file may not exist; keep the current `params` instead
# of crashing, so Restart & Run All still works.
checkpoint_path = "artcnn_params.msgpack"
if os.path.exists(checkpoint_path):
    with open(checkpoint_path, "rb") as f:
        # from_bytes uses the existing `params` pytree as the structure template.
        params = serialization.from_bytes(params, f.read())
    print(f"Loaded parameters from {checkpoint_path}")
else:
    print(f"{checkpoint_path} not found; keeping the current parameters")

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
# NOTE(review): this deletes /content/HR and /content/LR. Under Restart & Run All it runs
# right after the unzip cell, so the glob('/content/HR/*.png') in the training-setup cell
# would presumably find no files. Nothing in this notebook reads /content/LR. This looks
# like a manual cleanup step (e.g. before re-extracting); confirm, then move or skip it.
!rm -rf /content/HR
!rm -rf /content/LR

In [None]:
def data_generator(filelist, batch_size):
    """Yield (low-res, high-res) luma training batches forever, reshuffling each pass.

    Each image is read with OpenCV, converted to single-channel luma in [0, 1],
    and cropped to even height/width. It is paired with a 2x downscale of itself:
    INTER_LINEAR_EXACT at exactly 0.5 averages each 2x2 block, i.e. a box filter.

    Args:
        filelist: Sequence of image paths. It is copied before shuffling, so the
            caller's list is never reordered.
        batch_size: Maximum images per batch. The last batch of a pass is smaller
            when len(filelist) is not a multiple of batch_size.

    Yields:
        (train_in_batch, train_ref_batch): float32 arrays shaped (B, H/2, W/2, 1)
        and (B, H, W, 1). Images within one batch must share a size to be stacked.

    Raises:
        ValueError: if filelist is empty or batch_size < 1. Raised on the first next().
        FileNotFoundError: if OpenCV cannot read one of the files.
    """
    files = list(filelist)  # private copy: shuffling must not mutate the caller's list
    if not files:
        # An empty list would otherwise make `while True` spin forever without yielding.
        raise ValueError("data_generator: filelist is empty")
    if batch_size < 1:
        raise ValueError(f"data_generator: batch_size must be >= 1, got {batch_size}")
    n_samples = len(files)

    while True:
        random.shuffle(files)

        for i in range(0, n_samples, batch_size):
            batch_files = files[i:i + batch_size]  # slicing clamps at the end of the list
            train_ref_batch = []
            train_in_batch = []

            for path in batch_files:
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                if image is None:
                    # cv2.imread reports unreadable/missing files by returning None.
                    raise FileNotFoundError(f"data_generator: could not read image {path!r}")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # Luma
                image = image.astype(np.float32) / 255.0
                image = np.clip(image, 0.0, 1.0)

                # Crop to even dimensions so the 0.5x downscale is exact. The model's 2x
                # output then matches the reference shape (odd sizes broke the loss).
                height, width = image.shape
                image = image[:height - height % 2, :width - width % 2]

                ref_image = np.expand_dims(image.copy(), axis=-1)  # Luma
                train_ref_batch.append(ref_image)

                in_image = cv2.resize(image.copy(), None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT)  # Box downscale
                in_image = np.clip(in_image, 0.0, 1.0)
                in_image = np.expand_dims(in_image, axis=-1)  # Luma
                train_in_batch.append(in_image)

            yield np.array(train_in_batch), np.array(train_ref_batch)

def loss_fn(pred, target):
    """Mean absolute error (L1) between prediction and target, averaged over all elements."""
    abs_err = jnp.abs(pred - target)
    return abs_err.mean()

def forward(params, input, target):
    """Run the global `model` with `params` on `input` and return its L1 loss against `target`.

    Takes params as its first argument so jax.value_and_grad differentiates with respect to them.
    """
    return loss_fn(model.apply({'params': params}, input), target)

@jax.jit
def train_step(params, opt_state, input, target):
    """One jitted optimisation step. Returns (new_params, new_opt_state, loss).

    NOTE: the global `optimizer` is captured when this function is first traced.
    After redefining `optimizer` (e.g. a new learning_rate), re-run this cell so
    the jit cache does not keep using the old one.
    """
    grad_fn = jax.value_and_grad(forward)
    loss, grads = grad_fn(params, input, target)
    updates, next_opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), next_opt_state, loss

# Training configuration. Re-running this cell resets the optimizer state
# but keeps whatever `params` currently hold.
learning_rate = 0.000025
optimizer = optax.adamw(learning_rate)
opt_state = optimizer.init(params)
epochs = 5
batch_size = 8

random.seed(0)  # make data_generator's shuffle order reproducible across runs

filelist = sorted(glob.glob('/content/HR/*.png'))
if not filelist:
    raise FileNotFoundError("no training images found matching /content/HR/*.png")

# Ceil division: data_generator yields a final partial batch each pass, so this makes
# one "epoch" exactly one full pass. Floor division let epochs drift out of step with
# the passes, and gave 0 steps when there were fewer files than batch_size.
steps_per_epoch = (len(filelist) + batch_size - 1) // batch_size
train_generator = data_generator(filelist, batch_size)

In [None]:
# Train for `epochs` epochs of `steps_per_epoch` batches each, reporting the mean loss per epoch.
if steps_per_epoch <= 0:
    # Otherwise the loop does nothing and the average below raises ZeroDivisionError.
    raise ValueError("steps_per_epoch is 0: no training batches (check the dataset and batch_size)")

for epoch in range(epochs):
    start_time = time.time()
    epoch_loss = 0.0

    with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{epochs}", unit="step") as pbar:
        for _ in range(steps_per_epoch):
            batch_in, batch_ref = next(train_generator)
            params, opt_state, loss = train_step(params, opt_state, batch_in, batch_ref)
            step_loss = float(loss)  # convert to a Python float once per step
            epoch_loss += step_loss

            pbar.set_postfix(loss=step_loss)
            pbar.update(1)

    avg_epoch_loss = epoch_loss / steps_per_epoch
    elapsed = time.time() - start_time
    print(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_epoch_loss:.6f}, Time = {elapsed:.2f}s")

In [9]:
# Save the current params as a Flax msgpack checkpoint.
# Serialize before touching the file, write to a temp file, then swap it in with
# os.replace. A serialization error or an interrupted write can then no longer
# truncate an existing good checkpoint (opening it with "wb" first would).
checkpoint_bytes = serialization.to_bytes(params)
checkpoint_tmp = "artcnn_params.msgpack.tmp"
with open(checkpoint_tmp, "wb") as f:
    f.write(checkpoint_bytes)
os.replace(checkpoint_tmp, "artcnn_params.msgpack")

In [None]:
# Make a single prediction
input = cv2.imread('/content/downscaled.png', cv2.IMREAD_COLOR)
input = cv2.cvtColor(input, cv2.COLOR_BGR2GRAY, 0)
input = np.array(input).astype(np.float32) / 255.0
input = np.clip(input, 0.0, 1.0)
input = np.expand_dims(input, axis=0)
input = np.expand_dims(input, axis=-1)
input = jnp.array(input)

pred = model.apply({'params': params}, input)
pred = np.array(pred)
pred = np.clip(pred, 0.0, 1.0)
pred = np.squeeze(pred)
pred = pred * 255.0
pred = np.squeeze((np.around(pred)).astype(np.uint8))

cv2.imwrite('/content/prediction.png', pred)