In [1]:
# imports
import jax
from flax import nnx
import optax
import jax.numpy as jnp
import dataclasses

from typing import List

In [2]:
jax.devices()

[CudaDevice(id=0)]

# Configuration
This model has more parameters, let's create a special class with configuration parameters

# Configuration
This model has more parameters, let's create a special class with configuration parameters

In [65]:
@dataclasses.dataclass
class Config:
    embedded_size:int = 64

    # Image-related parameters
    image_channels:int = 3
    image_width:int = 256
    image_height:int = 256

    # Architecture related parameters
    num_residual_hiddens:int = 128
    num_residual_input:int = 32
    num_residual_layers_encoder:int = 2

    num_residual_layers_decoder:int = 2
    # Dataset-related parameters
    dataset_name:str = 'bitmind/ffhq-256_training_faces'
    

In [66]:
class ResBlock(nnx.Module):
    """One block of the encoder with a residual connection"""
    def __init__(self, config:Config, rngs: nnx.Rngs):
        super().__init__()
        self.config= config
        self.conv1 = nnx.Conv(in_features=config.num_residual_input, out_features=config.num_residual_hiddens, kernel_size=(3,3), strides=1, padding='SAME', rngs=rngs)
        self.conv2 = nnx.Conv(in_features=config.num_residual_hiddens, out_features=config.num_residual_input, kernel_size=(1,1), strides=1, padding='SAME', rngs=rngs)
    def __call__(self, inputs:jax.Array):
        out = self.conv1(inputs)
        out = jax.nn.relu(out)
        out = self.conv2(out)
        out = jax.nn.relu(out)
        return out+inputs
    

In [13]:
rngs = nnx.Rngs(0)
config = Config()
res_block = ResBlock(config, rngs)

In [29]:
test_image=jnp.ones((config.image_height, config.image_width, 3))
test_image.shape, res_block(test_image).shape            

((256, 256, 3), (256, 256, 3))

In [79]:
class Encoder(nnx.Module):
    """VQVAE encoder"""
    def __init__(self, config: Config, rngs: nnx.Rngs):
        super().__init__()
        self.config = config
        self.res_layers = [ResBlock(config, rngs) for _ in range(self.config.num_residual_layers_encoder)]
        # Three convolutions that reduce resolution by 2 and increase internal 
        self.initial_conv1 = nnx.Conv(in_features=config.image_channels, out_features=config.num_residual_input//2, kernel_size=(4,4), strides=(2, 2), padding='SAME', rngs=rngs)
        self.initial_conv2 = nnx.Conv(in_features=config.num_residual_input//2, out_features=config.num_residual_input, kernel_size=(4,4), strides=(2, 2), padding='SAME', rngs=rngs)
        
        
    def __call__(self, input: jax.Array):
        print('Input:', input.shape)
        x = jax.nn.relu(self.initial_conv1(input))
        print('E1 ', x.shape)
        x = jax.nn.relu(self.initial_conv2(x))
        print('E2 ', x.shape)
        for i, l in enumerate(self.res_layers):
            x = l(x)
            print('R', i, x.shape)
        # Flatten output into one vector
        x = x.flatten()
        print('O', x.shape)
        return x

In [80]:
rngs = nnx.Rngs(0)
config = Config()
encoder = Encoder(config, rngs)

In [77]:
class Decoder(nnx.Module):
    """ ... """
    def __init__(self, config: Config,  rngs: nnx.Rngs):
        super().__init__()
        self.config = config
        self.conv1 = nnx.Conv(in_features=1, out_features=config.num_residual_input, kernel_size=(3,3), strides=1, padding='SAME', rngs=rngs)
        self.res_layers = [ResBlock(config, rngs) for _ in range(self.config.num_residual_layers_decoder)]
        self.conv_transpose1 = nnx.ConvTranspose(in_features=config.num_residual_input,out_features=config.num_residual_input//2, kernel_size=(4, 4), strides=2, padding=1, rngs=rngs)
        self.conv_transpose2 = nnx.ConvTranspose(in_features=config.num_residual_input//2, kernel_size=(4, 4),out_features=config.image_channels, strides=2, padding=1, rngs=rngs)
    def __call__(self, x: jax.Array):
        # 
        x = self.conv1(x)
        for l in self.res_layers:
            x = l(x)
        print('1:', x.shape)
        x = self.conv_transpose1(x)
        print('2:', x.shape)
        x = jax.nn.relu(x)
        x = self.conv_transpose2(x)
        return x

SyntaxError: expected argument value expression (3664749262.py, line 9)

In [41]:
rngs = nnx.Rngs(0)
config = Config()
decoder = Decoder(config, rngs)

# Datasets

In [20]:
import datasets
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from IPython.display import display

In [26]:
dataset = datasets.load_dataset(config.dataset_name, 'base_transforms')
images = dataset['train']['image']

In [27]:
for d in images:
    print(jnp.array(d).shape)
    break

(256, 256, 3)


In [81]:
encoded=encoder(jnp.array(images[0]))

Input: (256, 256, 3)
E1  (128, 128, 16)
E2  (64, 64, 32)
R 0 (64, 64, 32)
R 1 (64, 64, 32)
O (131072,)


In [42]:
decoder(encoded).shape

1: (256, 256, 3)
2: (510, 510, 1)


(1018, 1018, 3)

# Training

In [None]:
def image_loss(orig, restored):
    return jax.numpy.mean((org-restored)**2)