In [None]:
!pip install flax
!pip install dm_pix
!pip install treescope

import jax
import flax
import optax
import cv2
import glob
import dm_pix as pix
import numpy as np
import jax.numpy as jnp
from flax import nnx

# Settings
# Module-level hyperparameters read by ResBlock and ArtCNN when they build their layers.
filters = 64  # channel width of the hidden feature maps
blocks = 8  # number of ResBlock instances stacked inside ArtCNN
kernel_size = (3, 3)  # kernel shape passed to every nnx.Conv

class ResBlock(nnx.Module):
    """Residual unit: three convolutions with ReLU after the first two, plus a skip.

    Each convolution maps `filters` channels to `filters` channels using the
    module-level `kernel_size`. The block's input is added back onto the output
    of the last convolution, so both must have the same shape.
    """

    def __init__(self, *, rngs: nnx.Rngs):
        """Create the three convolutions, drawing parameters from `rngs`."""
        def make_conv():
            return nnx.Conv(filters, filters, kernel_size=kernel_size, rngs=rngs)

        # Layers are created in the order conv0, conv1, conv2 so that `rngs`
        # is consumed in a fixed sequence.
        self.conv0 = make_conv()
        self.conv1 = make_conv()
        self.conv2 = make_conv()

    def __call__(self, input):
        """Apply the block to `input` and return `input` plus the residual branch."""
        hidden = self.conv0(input)
        hidden = nnx.relu(hidden)
        hidden = self.conv1(hidden)
        hidden = nnx.relu(hidden)
        residual = self.conv2(hidden)
        # No activation after the last convolution; the skip is a plain sum.
        return residual + input

class ArtCNN(nnx.Module):
    """Single-channel 2x upscaling network built from a stack of ResBlocks.

    Layout: input conv (1 -> `filters`), `blocks` residual blocks, a trunk conv,
    a long skip from the input conv, and a final conv producing 4 channels that
    `pix.depth_to_space` rearranges with block size 2.
    """

    def __init__(self, *, rngs: nnx.Rngs):
        """Build all layers, drawing parameters from `rngs` in declaration order."""
        self.conv0 = nnx.Conv(1, filters, kernel_size=kernel_size, rngs=rngs)

        stack = []
        for _ in range(blocks):
            stack.append(ResBlock(rngs=rngs))
        self.res_blocks = stack

        self.conv1 = nnx.Conv(filters, filters, kernel_size=kernel_size, rngs=rngs)
        self.feats_conv = nnx.Conv(filters, 4, kernel_size=kernel_size, rngs=rngs)

    def __call__(self, x):
        """Upscale `x` and return the result clipped to [0, 1].

        Assumes `x` is a channels-last batch with a single channel — TODO confirm
        against the caller.
        """
        stem = self.conv0(x)

        hidden = stem
        for res_block in self.res_blocks:
            hidden = res_block(hidden)

        # Long skip: the trunk output is summed with the stem features.
        trunk = self.conv1(hidden)
        features = self.feats_conv(trunk + stem)

        # Fold the 4 feature channels into a 2x2 spatial neighbourhood.
        upscaled = pix.depth_to_space(features, 2)
        return jnp.clip(upscaled, 0.0, 1.0)

# Instantiate the network with a fixed seed (0) for parameter initialisation.
model = ArtCNN(rngs=nnx.Rngs(0))
# Uncomment to render the module structure.
# nnx.display(model)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

!cp /content/drive/MyDrive/Datasets/Anime_Train_HR.zip /content/HR1.zip
!cp /content/drive/MyDrive/Datasets/Digital_Art_Train_HR.zip /content/HR2.zip
!unzip /content/HR1.zip
!unzip /content/HR2.zip

In [None]:
!rm -rf /content/HR
!rm -rf /content/LR

In [None]:
# Build the training set: each HR PNG becomes a grayscale reference image and a
# half-resolution (bilinear) copy that serves as the network input.
filelist = sorted(glob.glob('/content/HR/*.png'))
if not filelist:
    # An empty glob would otherwise produce empty arrays and a NaN loss during training.
    raise FileNotFoundError("No PNG files found in /content/HR")

train_ref = []
train_in = []

for myFile in filelist:
    image = cv2.imread(myFile, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports an unreadable file by returning None rather than raising.
        raise ValueError(f"Could not read image: {myFile}")
    # No third positional argument here: that slot is `dst`, not `dstCn`.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    if image.shape[0] % 2 or image.shape[1] % 2:
        # The model output is exactly twice the input size, so an odd-sized
        # reference can never match its prediction.
        raise ValueError(f"Image dimensions must be even, got {image.shape} for {myFile}")
    train_ref.append(image)
    image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_LINEAR_EXACT)
    train_in.append(image)


def _as_unit_batch(images):
    """Stack equally sized 8-bit grayscale images into a float32 batch.

    Values are scaled from [0, 255] to [0, 1] and a trailing channel axis of
    size 1 is appended.
    """
    batch = np.array(images).astype(np.float32) / 255.0
    batch = np.clip(batch, 0.0, 1.0)
    return np.expand_dims(batch, axis=-1)


train_ref = _as_unit_batch(train_ref)
print(train_ref.shape)
train_ref = jnp.array(train_ref)

train_in = _as_unit_batch(train_in)
print(train_in.shape)
train_in = jnp.array(train_in)

In [None]:
# Optimiser and training hyperparameters.
# AdamW with a fixed learning rate of 1e-4, bound to `model` through nnx.Optimizer.
# NOTE(review): newer Flax releases appear to require a `wrt=` argument here and
# `optimizer.update(model, grads)` in train_step — confirm against the installed version.
optimizer = nnx.Optimizer(model, optax.adamw(0.0001))
batch_size = 16  # samples per training step
num_epochs = 50  # number of passes over the training arrays
metrics_history = {'train_loss': []}  # per-step losses, appended by the training loop

def create_batches(x, y, batch_size):
    """Yield consecutive ``(x_batch, y_batch)`` slices along the first axis.

    Batches are taken in order without shuffling. The final batch is smaller
    than ``batch_size`` when the sample count is not a multiple of it.

    Args:
        x: Array-like with a ``shape`` attribute; sliced along axis 0.
        y: Sliceable sequence with the same number of samples as ``x``.
        batch_size: Positive number of samples per batch.

    Yields:
        Tuples ``(x[i:i + batch_size], y[i:i + batch_size])``.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``x`` and ``y``
            hold a different number of samples. Because this is a generator,
            the error surfaces on the first iteration.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_samples = x.shape[0]
    if len(y) != num_samples:
        # Slicing would otherwise silently pair inputs with the wrong (or no) targets.
        raise ValueError(
            f"x and y must have the same number of samples, got {num_samples} and {len(y)}"
        )

    for start in range(0, num_samples, batch_size):
        stop = start + batch_size
        yield x[start:stop], y[start:stop]

@nnx.jit
def train_step(model, optimizer, x, y):
    """Run one optimisation step on a batch and return its loss.

    The loss is the mean absolute difference between `y` and `model(x)`.
    Gradients with respect to `model` are handed to `optimizer.update`.

    Args:
        model: Callable module mapping `x` to a prediction shaped like `y`.
        optimizer: Object exposing `update(grads)`.
        x: Input batch.
        y: Target batch.

    Returns:
        The scalar loss computed before the update.
    """
    def mean_abs_error(net):
        prediction = net(x)
        return jnp.abs(y - prediction).mean()

    value, gradients = nnx.value_and_grad(mean_abs_error)(model)
    optimizer.update(gradients)
    return value

# Training loop: iterate over the data `num_epochs` times, logging every tenth
# step and the unweighted mean of the batch losses at the end of each epoch.
for epoch in range(num_epochs):
    print(f"\nStarting epoch {epoch + 1} / {num_epochs}")

    # Losses of the current epoch only; the full history lives in metrics_history.
    epoch_loss = []
    batches = create_batches(train_in, train_ref, batch_size)

    for step, (x, y) in enumerate(batches):
        loss = train_step(model, optimizer, x, y)

        epoch_loss.append(loss)
        metrics_history['train_loss'].append(loss)

        # Report on steps 0, 10, 20, ...
        if not step % 10:
            print(f"[train] epoch: {epoch + 1}, step: {step}, batch loss: {loss:.4f}")

    # Epoch summary.
    epoch_avg_loss = jnp.array(epoch_loss).mean()
    print(f"Epoch {epoch + 1} complete. Average loss: {epoch_avg_loss:.4f}")