In [None]:
import glob
import os
import random
import time

import cv2
import dm_pix as pix
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import treescope
from flax import nnx
from tqdm import tqdm

# Settings
# Channel width used by every hidden convolution in ResBlock and ArtCNN.
filters = 64
# Number of ResBlock modules stacked inside ArtCNN.
blocks = 8
# Spatial kernel size shared by all convolutions in the network.
kernel_size = (3, 3)

class DepthToSpace(nnx.Module):
    """Parameter-free layer wrapping ``pix.depth_to_space`` with a block size of 2."""

    def __init__(self, *, rngs: nnx.Rngs):
        # `rngs` is accepted so the constructor matches the other modules in
        # this file; it is not used because the layer has no parameters.
        self.block_size = 2

    def __call__(self, input):
        """Rearrange ``input`` with ``pix.depth_to_space`` using ``self.block_size``."""
        return pix.depth_to_space(input, self.block_size)

class ResBlock(nnx.Module):
    """Residual block: three ``filters``-wide convolutions plus an identity skip."""

    def __init__(self, *, rngs: nnx.Rngs):
        def make_conv():
            # All three layers share one configuration (filters -> filters).
            return nnx.Conv(filters, filters, kernel_size=kernel_size, rngs=rngs)

        # Created in conv0, conv1, conv2 order so the shared `rngs` stream is
        # consumed in the same sequence on every construction.
        self.conv0 = make_conv()
        self.conv1 = make_conv()
        self.conv2 = make_conv()

    def __call__(self, input):
        """Apply conv -> SiLU -> conv -> SiLU -> conv and add ``input`` back."""
        hidden = nnx.silu(self.conv0(input))
        hidden = nnx.silu(self.conv1(hidden))
        residual = self.conv2(hidden)  # no activation before the skip connection
        return residual + input

class ArtCNN(nnx.Module):
    """Single-channel 2x upscaling network.

    Structure: an input convolution (1 -> ``filters`` channels), ``blocks``
    residual blocks, a convolution whose output is added to the input
    convolution's features (long skip), a 4-channel convolution, and a
    depth-to-space layer with block size 2. The result is clipped to [0, 1].
    """

    def __init__(self, *, rngs: nnx.Rngs):
        # Layers are created in a fixed order because they all draw their
        # initial parameters from the same `rngs` stream.
        self.conv0 = nnx.Conv(1, filters, kernel_size=kernel_size, rngs=rngs)

        res_blocks = []
        for _ in range(blocks):
            res_blocks.append(ResBlock(rngs=rngs))
        self.res_blocks = res_blocks

        self.conv1 = nnx.Conv(filters, filters, kernel_size=kernel_size, rngs=rngs)
        # 4 output channels feed the block-size-2 depth-to-space layer below.
        self.feats_conv = nnx.Conv(filters, 4, kernel_size=kernel_size, rngs=rngs)
        self.depth_to_space = DepthToSpace(rngs=rngs)

    def __call__(self, input):
        """Run the network on ``input`` and return the prediction clipped to [0, 1].

        ``input`` must have a single channel, since ``conv0`` is built with one
        input feature (presumably laid out as (batch, height, width, 1) —
        confirm against the caller).
        """
        shallow = self.conv0(input)

        deep = shallow
        for block in self.res_blocks:
            deep = block(deep)

        # Long skip connection around the whole residual stack.
        fused = self.conv1(deep) + shallow

        upscaled = self.depth_to_space(self.feats_conv(fused))
        return jnp.clip(upscaled, 0.0, 1.0)

# Build the network with a fixed seed (0) for its parameter initialisation.
model = ArtCNN(rngs=nnx.Rngs(0))
# Turn on treescope's interactive rendering (with automatic array
# visualisation) and display the freshly initialised model.
treescope.basic_interactive_setup(autovisualize_arrays=True)
nnx.display(model)

In [None]:
# Restore a previously saved checkpoint into the model, if one exists.
# The checkpoint is only written by the save cell further down, so on a fresh
# runtime there is nothing to restore yet; in that case keep the freshly
# initialised `model` instead of failing the whole run.
checkpoint_path = '/content/checkpoints/state'

if os.path.isdir(checkpoint_path):
    # Abstract (shape/dtype-only) model: gives Orbax the target structure to
    # restore into without allocating real parameters.
    abstract_model = nnx.eval_shape(lambda: ArtCNN(rngs=nnx.Rngs(0)))
    graphdef, abstract_state = nnx.split(abstract_model)

    checkpointer = ocp.StandardCheckpointer()
    state_restored = checkpointer.restore(checkpoint_path, abstract_state)
    nnx.display(state_restored)

    # Recombine the static graph definition with the restored parameters.
    model = nnx.merge(graphdef, state_restored)
else:
    print(f"No checkpoint found at {checkpoint_path}; keeping the freshly initialised model.")

In [None]:
# Mount Google Drive (Colab only) so the dataset archives stored there can be read.
from google.colab import drive
drive.mount('/content/drive')

# Copy both archives to the local disk under /content, then extract them.
# NOTE(review): unzip extracts into the current working directory; the training
# setup reads PNGs from /content/HR, so this presumably relies on the working
# directory being /content and on the archives containing an HR/ folder — confirm.
!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
# Cleanup: delete the extracted dataset folders.
# NOTE(review): the training setup globs /content/HR/*.png, so running this
# cell in sequence (e.g. "Run all") removes the training images before they
# are used — presumably meant to be run manually only; confirm.
!rm -rf /content/HR
!rm -rf /content/LR

In [7]:
def data_generator(filelist, batch_size):
    """Yield an endless stream of shuffled (input, reference) luma batches.

    Each image is read as grayscale and scaled to float32 in [0, 1]. The
    full-resolution image is the reference; the network input is the same
    image resized by a factor of 0.5 along both axes.

    Args:
        filelist: Sequence of image paths readable by ``cv2.imread``. It is
            not modified; shuffling happens on a private copy.
        batch_size: Maximum number of images per batch (must be >= 1). The
            last batch of each pass over the files may be smaller.

    Yields:
        ``(train_in_batch, train_ref_batch)``: two float32 NumPy arrays with a
        trailing channel axis of size 1, holding the half-resolution inputs
        and the full-resolution references respectively. Stacking assumes all
        images in a batch share the same dimensions — TODO confirm for the
        dataset in use.

    Raises:
        ValueError: If ``filelist`` is empty or ``batch_size`` is below 1
            (either would otherwise make the generator spin forever without
            yielding anything).
        OSError: If an image cannot be read or decoded.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Work on a copy so the caller's list keeps its original order.
    files = list(filelist)
    if not files:
        raise ValueError("filelist is empty: no training images to generate batches from")

    while True:
        random.shuffle(files)

        for start in range(0, len(files), batch_size):
            train_ref_batch = []
            train_in_batch = []

            # Slicing clamps at the end of the list, so the tail batch is
            # simply shorter.
            for file in files[start:start + batch_size]:
                image = cv2.imread(file, cv2.IMREAD_GRAYSCALE)
                if image is None:
                    # cv2.imread reports missing/undecodable files by
                    # returning None instead of raising.
                    raise OSError(f"Could not read image: {file}")
                image = image.astype(np.float32) / 255.0

                train_ref_batch.append(image[..., np.newaxis])  # Luma

                in_image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT)  # Box downscale
                in_image = np.clip(in_image, 0.0, 1.0)
                train_in_batch.append(in_image[..., np.newaxis])  # Luma

            yield np.array(train_in_batch), np.array(train_ref_batch)

def loss_fn(pred, target):
    """Return the mean absolute error (L1 loss) between ``pred`` and ``target``."""
    abs_error = jnp.absolute(target - pred)
    return jnp.mean(abs_error)

def forward(model, input, target):
    """Run ``model`` on ``input`` and return ``loss_fn`` of its prediction against ``target``."""
    return loss_fn(model(input), target)

@nnx.jit
def train_step(model, optimizer, input, target):
    """Perform one optimisation step and return the loss for this batch.

    Evaluates ``forward`` and its gradients via ``nnx.value_and_grad``, then
    passes the gradients to ``optimizer.update`` together with ``model``.
    The returned loss is the value computed before the update is applied.
    """
    loss_and_grads = nnx.value_and_grad(forward)
    step_loss, gradients = loss_and_grads(model, input, target)
    optimizer.update(model, gradients)
    return step_loss

# Training configuration.
learning_rate = 0.0001
# AdamW over the model's nnx.Param variables.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate), wrt=nnx.Param)
epochs = 5
batch_size = 8
# Sorted so the starting file order is deterministic; data_generator shuffles it.
filelist = sorted(glob.glob('/content/HR/*.png'))
# NOTE(review): floor division counts only full batches, while data_generator
# also yields a final partial batch when len(filelist) is not a multiple of
# batch_size, so an "epoch" here does not line up exactly with one pass over
# the files. If no PNGs are found this is 0.
steps_per_epoch = len(filelist) // batch_size
train_generator = data_generator(filelist, batch_size)

In [None]:
# Training loop: `epochs` rounds of `steps_per_epoch` optimisation steps each,
# with a progress bar per epoch and a summary line at the end of every epoch.

# Fail fast with a clear message instead of a ZeroDivisionError when the
# average loss is computed after an epoch with zero steps.
if steps_per_epoch < 1:
    raise ValueError(
        f"steps_per_epoch is {steps_per_epoch}: found {len(filelist)} training image(s), "
        f"but at least batch_size={batch_size} are required."
    )

for epoch in range(epochs):
    start_time = time.time()
    epoch_loss = 0.0

    with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{epochs}", unit="step") as pbar:
        for step in range(steps_per_epoch):
            batch_in, batch_ref = next(train_generator)
            loss = train_step(model, optimizer, batch_in, batch_ref)
            # Convert to a Python float once and reuse it for both the running
            # total and the progress bar.
            step_loss = float(loss)
            epoch_loss += step_loss

            pbar.set_postfix(loss=step_loss)
            pbar.update(1)

    avg_epoch_loss = epoch_loss / steps_per_epoch
    elapsed = time.time() - start_time
    print(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_epoch_loss:.6f}, Time = {elapsed:.2f}s")

In [None]:
# Save the trained parameters as an Orbax checkpoint.
_, state = nnx.split(model)
nnx.display(state)

checkpointer = ocp.StandardCheckpointer()
# force=True: this is the same path the restore cell loads from, so it already
# exists whenever training resumed from a checkpoint (or this cell is re-run);
# without it Orbax refuses to write into an existing destination.
checkpointer.save('/content/checkpoints/state', state, force=True)
# StandardCheckpointer writes in the background; block until the checkpoint is
# fully committed so it is not lost if the runtime is stopped right after.
checkpointer.wait_until_finished()

In [None]:
# Make a single prediction: run one grayscale image through the model and
# write the result to disk as an 8-bit PNG.
# (Uses `lr_image` / `lr_batch` rather than `input` so the builtin `input()`
# is not shadowed at notebook scope.)
lr_image = cv2.imread('/content/downscaled.png', cv2.IMREAD_GRAYSCALE)
if lr_image is None:
    # cv2.imread returns None for a missing or undecodable file instead of raising.
    raise FileNotFoundError("Could not read '/content/downscaled.png'")

# uint8 [0, 255] -> float32 [0, 1], then add batch and channel axes:
# (H, W) -> (1, H, W, 1).
lr_batch = np.clip(lr_image.astype(np.float32) / 255.0, 0.0, 1.0)
lr_batch = jnp.array(lr_batch[np.newaxis, ..., np.newaxis])

pred = np.array(model(lr_batch))
# Drop exactly the batch and channel axes (a blanket squeeze would also drop
# any spatial axis that happened to have length 1).
pred = np.clip(pred[0, ..., 0], 0.0, 1.0)
# float [0, 1] -> uint8 [0, 255] with rounding to nearest.
pred = np.around(pred * 255.0).astype(np.uint8)

# cv2.imwrite reports failure through its return value, not an exception.
if not cv2.imwrite('/content/prediction.png', pred):
    raise OSError("Failed to write '/content/prediction.png'")