In [None]:
!pip install --upgrade flax
!pip install --upgrade orbax-checkpoint
!pip install --upgrade jax
!pip install --upgrade dm_pix
!pip install --upgrade treescope

import jax
import flax
import optax
import cv2
import glob
import treescope
import dm_pix as pix
import numpy as np
import jax.numpy as jnp
from flax import nnx

# Settings: architecture hyper-parameters, read by ResBlock / ArtCNN each time
# they are constructed.
#   filters     - channel width of every hidden convolution
#   blocks      - number of residual blocks in ArtCNN
#   kernel_size - spatial size of every convolution kernel
filters, blocks = 64, 8
kernel_size = (3, 3)

class ResBlock(nnx.Module):
    """Residual block: conv -> ReLU -> conv -> ReLU -> conv, plus an identity skip.

    All three convolutions map ``filters`` -> ``filters`` channels with the
    module-level ``kernel_size``, so the skip connection can add the block
    input directly to its output.
    """

    def __init__(self, *, rngs: nnx.Rngs):
        # The attribute names (conv0, conv1, conv2) are the parameter-tree keys
        # used by checkpoints, and the creation order fixes which RNG draws
        # initialize each conv. Keep both stable.
        for attr_name in ('conv0', 'conv1', 'conv2'):
            setattr(self, attr_name,
                    nnx.Conv(filters, filters, kernel_size=kernel_size, rngs=rngs))

    def __call__(self, input):
        hidden = self.conv0(input)
        hidden = self.conv1(nnx.relu(hidden))
        hidden = self.conv2(nnx.relu(hidden))
        # No activation after the last conv; the skip is added to the raw output.
        return hidden + input

class ArtCNN(nnx.Module):
    """2x super-resolution network for single-channel images.

    Pipeline:
      1. head conv: 1 -> ``filters`` channels;
      2. ``blocks`` residual blocks;
      3. a ``filters`` -> ``filters`` conv plus a long skip back to the head output;
      4. a conv to 4 channels;
      5. ``pix.depth_to_space`` with block size 2, which rearranges the 4
         channels into 2x2 spatial blocks;
      6. a clip to [0, 1].
    """

    def __init__(self, *, rngs: nnx.Rngs):
        # Creation order fixes which RNG draws initialize each conv, and the
        # attribute names are the parameter-tree (checkpoint) keys; keep both.
        self.conv0 = nnx.Conv(1, filters, kernel_size=kernel_size, rngs=rngs)
        residual_stack = []
        for _ in range(blocks):
            residual_stack.append(ResBlock(rngs=rngs))
        # NOTE(review): newer Flax NNX releases may require lists of Modules to
        # be wrapped (e.g. nnx.List); confirm against the installed version.
        self.res_blocks = residual_stack
        self.conv1 = nnx.Conv(filters, filters, kernel_size=kernel_size, rngs=rngs)
        self.feats_conv = nnx.Conv(filters, 4, kernel_size=kernel_size, rngs=rngs)

    def __call__(self, input):
        skip = self.conv0(input)
        hidden = skip
        for res_block in self.res_blocks:
            hidden = res_block(hidden)
        # Long skip: add the head features back before projecting to 4 channels.
        hidden = self.conv1(hidden) + skip
        upscaled = pix.depth_to_space(self.feats_conv(hidden), 2)
        return jnp.clip(upscaled, 0.0, 1.0)

# Turn on treescope's interactive rendering (arrays are auto-visualized), then
# build the model from a fixed seed and render its structure.
treescope.basic_interactive_setup(autovisualize_arrays=True)
model = ArtCNN(rngs=nnx.Rngs(0))
nnx.display(model)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
!rm -rf /content/HR
!rm -rf /content/LR

In [None]:
def _to_unit_nhwc(images):
    """Stack same-sized uint8 grayscale images into a float32 (N, H, W, 1) array in [0, 1]."""
    # np.stack raises a clear error if the images differ in size.
    stacked = np.stack(images).astype(np.float32) / 255.0
    return np.expand_dims(np.clip(stacked, 0.0, 1.0), axis=-1)


# Load every HR image as grayscale (the training target) and synthesize its
# 0.5x-downscaled counterpart (the model input).
filelist = sorted(glob.glob('/content/HR/*.png'))
if not filelist:
    raise FileNotFoundError("No PNG files found in /content/HR; run the unzip cell first.")

train_ref = []
train_in = []

for path in filelist:
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:  # cv2.imread reports unreadable/missing files by returning None
        raise IOError(f"cv2 could not read {path}")
    hr = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    # Crop to even height/width so the 0.5x input is exactly half the HR size.
    # Otherwise the model's 2x output would not match the reference shape.
    height, width = hr.shape
    hr = hr[:height - height % 2, :width - width % 2]
    train_ref.append(hr)
    train_in.append(cv2.resize(hr, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT))

train_ref = _to_unit_nhwc(train_ref)
print(train_ref.shape)
train_ref = jnp.array(train_ref)

train_in = _to_unit_nhwc(train_in)
print(train_in.shape)
train_in = jnp.array(train_in)

In [None]:
# Training hyper-parameters and loss history (one entry per batch).
batch_size, num_epochs = 8, 5
metrics_history = {'train_loss': []}

# AdamW with learning rate 2.5e-5 over the current `model`'s parameters.
# NOTE(review): flax is installed unpinned via `pip install --upgrade`; newer
# releases changed nnx.Optimizer (required `wrt=`, update(model, grads)).
# Confirm this call matches the installed version.
optimizer = nnx.Optimizer(model, optax.adamw(2.5e-5))

def create_batches(x, y, batch_size, rng=None):
    """Yield aligned ``(x_batch, y_batch)`` mini-batches along the leading axis.

    Args:
        x: inputs; indexed along axis 0.
        y: targets; must have the same leading length as ``x``.
        batch_size: samples per batch (>= 1). The final batch may be smaller.
        rng: optional ``np.random.Generator``. When given, sample order is
            shuffled. Indices are permuted and then gathered per batch, so the
            full dataset is never copied at once. When omitted, batches follow
            the original order.

    Raises:
        ValueError: if ``batch_size < 1`` or ``x`` and ``y`` differ in length.
            Because this is a generator, the error surfaces on the first
            iteration.
    """
    num_samples = x.shape[0]
    if y.shape[0] != num_samples:
        raise ValueError(f"x has {num_samples} samples but y has {y.shape[0]}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if rng is None:
        for i in range(0, num_samples, batch_size):
            yield x[i:i + batch_size], y[i:i + batch_size]
    else:
        order = rng.permutation(num_samples)
        for i in range(0, num_samples, batch_size):
            idx = order[i:i + batch_size]
            yield x[idx], y[idx]

@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one jitted optimization step on a batch and return its L1 loss.

    Args:
        model: the NNX module being trained.
        optimizer: the ``nnx.Optimizer`` built around ``model``; it applies the
            gradients.
        x: input batch fed to ``model``.
        y: target batch, same shape as ``model(x)``.

    Returns:
        Scalar mean absolute error ``mean(|y - model(x)|)``, measured before
        the update.
    """
    def l1_loss(m):
        return jnp.mean(jnp.abs(y - m(x)))

    loss_value, gradients = nnx.value_and_grad(l1_loss)(model)
    optimizer.update(gradients)
    return loss_value

# Training Loop
for epoch in range(num_epochs):
    print(f"\nStarting epoch {epoch + 1} / {num_epochs}")
    epoch_loss = []  # Track loss for this epoch

    for step, (x, y) in enumerate(create_batches(train_in, train_ref, batch_size)):
        loss = train_step(model, optimizer, x, y)

        metrics_history['train_loss'].append(loss)
        epoch_loss.append(loss)

        if step % 50 == 0:  # Print every 50 steps
            print(f"[train] epoch: {epoch + 1}, step: {step}, batch loss: {loss:.6f}")

    # Print epoch summary
    epoch_avg_loss = jnp.mean(jnp.array(epoch_loss))
    print(f"Epoch {epoch + 1} complete. Average loss: {epoch_avg_loss:.6f}")
    save_checkpoint()

In [None]:
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

def save_checkpoint():
    ckpt_dir = ocp.test_utils.erase_and_create_empty('/content/checkpoints/')

    _, state = nnx.split(model)
    print("Creating Checkpoint")
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(ckpt_dir / 'state', state)
    print("Sending Checkpoint to Google Drive")
    !zip -r /content/checkpoints.zip /content/checkpoints/*
    !cp /content/checkpoints.zip /content/drive/MyDrive/tmp/checkpoints.zip

save_checkpoint()

In [None]:
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

# Restore trained weights from the local Orbax checkpoint at
# /content/checkpoints/state (the path save_checkpoint() writes to) and rebuild
# the global `model` from them.
checkpointer = ocp.StandardCheckpointer()

# Build the model abstractly to get its graph definition plus a state template
# whose leaves are abstract arrays. The template gives the expected tree
# structure, shapes and dtypes for the restore. It must match the saved
# architecture (same `filters` / `blocks` / `kernel_size` settings).
abstract_model = nnx.eval_shape(lambda: ArtCNN(rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
print('The abstract NNX state (all leaves are abstract arrays):')
nnx.display(abstract_state)

state_restored = checkpointer.restore('/content/checkpoints/state', abstract_state)
print('NNX State restored: ')
nnx.display(state_restored)

# Rebind the global `model` to a concrete module holding the restored weights.
# NOTE: an `optimizer` created earlier was built around the previous `model`
# object. Re-run the optimizer cell before training this restored model further.
model = nnx.merge(graphdef, state_restored)

In [None]:
# Make a single prediction
input = cv2.imread('/content/downscaled.png', cv2.IMREAD_COLOR)
input = cv2.cvtColor(input, cv2.COLOR_BGR2GRAY, 0)
input = np.array(input).astype(np.float32) / 255.0
input = np.clip(input, 0.0, 1.0)
input = np.expand_dims(input, axis=0)
input = np.expand_dims(input, axis=-1)

pred = model(jnp.array(input))
pred = np.clip(np.array(pred), 0.0, 1.0)
pred = np.squeeze(pred)
pred = pred * 255.0
pred = np.squeeze((np.around(pred)).astype(np.uint8))

cv2.imwrite('/content/prediction_jax.png', pred)

In [None]:
!pip install --upgrade tf2onnx
!pip install --upgrade onnx
!pip install --upgrade onnxsim

import jax
import jax.numpy as jnp
import tf2onnx
import onnx
import tensorflow as tf
from jax.experimental import jax2tf

# Convert JAX function to TensorFlow function
tf_function = tf.function(
    jax2tf.convert(model, polymorphic_shapes=["(1, h, w, 1)"], enable_xla=False),
    input_signature=[tf.TensorSpec([1, None, None, 1], dtype=tf.float32, name="input")]
)

# Export TensorFlow function to ONNX
onnx_model_path = "ArtCNN_R8F64_Flax.onnx"

# Note: input_signature must match the input signature used for tf.function()
onnx_model, _ = tf2onnx.convert.from_function(
    tf_function,
    input_signature=[tf.TensorSpec([1, None, None, 1], dtype=tf.float32, name="input")],
    output_path="ArtCNN_R8F64_Flax.onnx",  # File path for the ONNX model
    inputs_as_nchw=['input'],  # Specify input tensor name as 'input' (ensure NCHW format)
    outputs_as_nchw=['output'],  # Specify output tensor name as 'depth_to_space' (NCHW format)
    opset=17  # Use opset 13 to avoid issues with unsupported operations
)

!onnxsim ArtCNN_R8F64_Flax.onnx ArtCNN_R8F64_Flax.onnx