In [1]:
# imports
import jax
from flax import nnx
import optax
import jax.numpy as jnp
import dataclasses

from typing import List

In [2]:
# Show which devices JAX can see, to confirm the accelerator is picked up.
jax.devices()

[CudaDevice(id=0)]

# Configuration
This model has more parameters, so let's create a dedicated class that holds the configuration parameters.

# Configuration
This model has more parameters, so let's create a dedicated class that holds the configuration parameters.

In [31]:
@dataclasses.dataclass
class Config:
    """Hyper-parameters shared by the model, the dataset pipeline and training.

    Every field has a default, so ``Config()`` yields the setup used throughout
    this notebook.
    """

    # Not referenced by the code visible in this notebook; presumably the
    # embedding width for the planned quantization step -- TODO confirm.
    embedded_size: int = 64

    # --- Image-related parameters ---
    # Channel count of the images (input of the encoder, output of the decoder).
    image_channels: int = 3
    # Spatial size of the images, in pixels.
    image_width: int = 256
    image_height: int = 256

    # --- Architecture-related parameters ---
    # Channels inside a residual block (output of its first convolution).
    num_residual_hiddens: int = 128
    # Channels entering and leaving every residual block.
    num_residual_input: int = 32
    # Number of stacked residual blocks in the encoder / decoder.
    num_residual_layers_encoder: int = 2
    num_residual_layers_decoder: int = 2

    # --- Dataset-related parameters ---
    # Dataset identifier handed to `datasets.load_dataset`.
    dataset_name: str = 'bitmind/ffhq-256_training_faces'
    # Number of images grouped into one batch.
    batch_size: int = 32
    

In [4]:
class ResBlock(nnx.Module):
    """Residual block: 3x3 conv -> ReLU -> 1x1 conv -> ReLU, added to the input.

    The block maps ``config.num_residual_input`` channels to
    ``config.num_residual_hiddens`` and back, so its output has the same
    channel count as its input and can be summed with it.
    """

    def __init__(self, config: Config, rngs: nnx.Rngs):
        """Create the two convolutions.

        Args:
            config: supplies the input and hidden channel counts.
            rngs: random number streams used to initialise the convolutions.
        """
        super().__init__()
        self.config = config
        # Widening 3x3 convolution: num_residual_input -> num_residual_hiddens.
        self.conv1 = nnx.Conv(
            in_features=config.num_residual_input,
            out_features=config.num_residual_hiddens,
            kernel_size=(3, 3),
            strides=1,
            padding='SAME',
            rngs=rngs,
        )
        # 1x1 convolution projecting back to num_residual_input channels.
        self.conv2 = nnx.Conv(
            in_features=config.num_residual_hiddens,
            out_features=config.num_residual_input,
            kernel_size=(1, 1),
            strides=1,
            padding='SAME',
            rngs=rngs,
        )

    def __call__(self, inputs: jax.Array):
        """Apply both convolutions (each followed by ReLU) and add `inputs`."""
        hidden = jax.nn.relu(self.conv1(inputs))
        branch = jax.nn.relu(self.conv2(hidden))
        # Skip connection around the two convolutions.
        return branch + inputs
    

In [33]:
# Smoke test: build one residual block from the default configuration and a
# fixed seed.
config = Config()
rngs = nnx.Rngs(0)
res_block = ResBlock(config=config, rngs=rngs)

In [6]:
class Encoder(nnx.Module):
    """VQVAE encoder: two stride-2 convolutions, residual blocks, then flatten."""

    def __init__(self, config: Config, rngs: nnx.Rngs):
        """Create the residual stack and the two downsampling convolutions.

        Args:
            config: supplies channel counts and the number of residual blocks.
            rngs: random number streams used to initialise all layers.
        """
        super().__init__()
        self.config = config
        # The residual blocks are created first, before the convolutions, so
        # they draw from `rngs` first.
        self.res_layers = [
            ResBlock(config, rngs)
            for _ in range(self.config.num_residual_layers_encoder)
        ]
        # Two stride-2 convolutions; each one halves the spatial resolution.
        # Channels go image_channels -> num_residual_input // 2 -> num_residual_input.
        half_channels = config.num_residual_input // 2
        self.initial_conv1 = nnx.Conv(
            in_features=config.image_channels,
            out_features=half_channels,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding='SAME',
            rngs=rngs,
        )
        self.initial_conv2 = nnx.Conv(
            in_features=half_channels,
            out_features=config.num_residual_input,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding='SAME',
            rngs=rngs,
        )

    def __call__(self, input: jax.Array):
        """Encode `input` and return the result flattened to one dimension.

        The intermediate shapes are printed after every stage.
        """
        print('Input:', input.shape)
        x = self.initial_conv1(input)
        x = jax.nn.relu(x)
        print('E1 ', x.shape)
        x = self.initial_conv2(x)
        x = jax.nn.relu(x)
        print('E2 ', x.shape)
        for index, block in enumerate(self.res_layers):
            x = block(x)
            print('R', index, x.shape)
        # Collapse every axis into a single vector.
        x = x.reshape(-1)
        print('O', x.shape)
        return x

In [7]:
# Smoke test: build an encoder from the default configuration and a fixed seed.
config = Config()
rngs = nnx.Rngs(0)
encoder = Encoder(config=config, rngs=rngs)

In [8]:
class Decoder(nnx.Module):
    """VQVAE decoder: unflatten, residual blocks, two stride-2 transposed convs.

    It takes the flat vector produced by the encoder in this notebook, restores
    the latent feature map, and upsamples it back to
    ``config.image_channels`` channels.
    """

    def __init__(self, config: Config,  rngs: nnx.Rngs):
        """Create the residual stack and the two upsampling layers.

        Args:
            config: supplies channel counts, image size and the number of
                residual blocks.
            rngs: random number streams used to initialise all layers.
        """
        super().__init__()
        self.config = config
        self.res_layers = [ResBlock(config, rngs) for _ in range(self.config.num_residual_layers_decoder)]
        # Channels go num_residual_input -> num_residual_input // 2 -> image_channels.
        self.conv_transpose1 = nnx.ConvTranspose(in_features=config.num_residual_input,out_features=config.num_residual_input//2, kernel_size=(4, 4), strides=2, padding="SAME", rngs=rngs)
        self.conv_transpose2 = nnx.ConvTranspose(in_features=config.num_residual_input//2, kernel_size=(4, 4),out_features=config.image_channels, strides=2, padding="SAME", rngs=rngs)

    def __call__(self, x: jax.Array):
        """Decode one flattened latent vector.

        Args:
            x: flat latent for a single image (no batch axis).

        Returns:
            The output of the second transposed convolution.
        """
        # Unflatten the input. The encoder applies two stride-2 'SAME'
        # convolutions, so each spatial side of the latent grid is
        # ceil(side / 4); derive it from the config instead of hard-coding the
        # 64x64 grid that is only valid for 256x256 images.
        # Assumes images are laid out (height, width, channels) -- TODO confirm
        # if image_height != image_width is ever used.
        latent_height = (self.config.image_height + 3) // 4
        latent_width = (self.config.image_width + 3) // 4
        x = x.reshape((latent_height, latent_width, -1))
        for l in self.res_layers:
            x = l(x)
        # Shape prints kept for parity with the encoder's shape logging.
        print('1:', x.shape)
        x = self.conv_transpose1(x)
        print('2:', x.shape)
        x = jax.nn.relu(x)
        x = self.conv_transpose2(x)
        return x

In [9]:
# Smoke test: build a decoder from the default configuration and a fixed seed.
config = Config()
rngs = nnx.Rngs(0)
decoder = Decoder(config=config, rngs=rngs)

# Datasets

In [10]:
# Dataset-loading and image-handling dependencies.
import datasets
from PIL import Image, ImageFile
# NOTE(review): presumably makes Pillow decode truncated image files instead of
# raising -- confirm against the Pillow documentation.
ImageFile.LOAD_TRUNCATED_IMAGES = True
from IPython.display import display

In [12]:
# Open the dataset named in the config in streaming mode.
dataset = datasets.load_dataset(
    path=config.dataset_name,
    name='base_transforms',
    streaming=True,
)


In [13]:
# Display the available splits and their features.
dataset

IterableDatasetDict({
    train: IterableDataset({
        features: ['image', 'original_index', 'landmark', 'mask'],
        n_shards: 14
    })
})

In [36]:
# Group the training stream into batches, then replace every original column
# with a single "x" entry holding the images converted to a JAX array.
batched_data = (
    dataset['train']
    .batch(config.batch_size)
    .map(
        lambda examples: {
            "x": jnp.array([jnp.array(picture) for picture in examples["image"]])
        },
        remove_columns=dataset["train"].column_names,
        batched=True,
        batch_size=32,
    )
)

In [39]:
# Pull only the first batch from the stream to use as a sample input.
image = next(iter(batched_data))['x']
image

(32, 256, 256, 3)

In [40]:
class EncoderDecoder(nnx.Module):
    """Auto-encoder that chains the notebook's Encoder and Decoder.

    Quantization between the two stages is not implemented yet (see the TODO
    in `__call__`).
    """

    def __init__(self, config:Config,  rngs: nnx.Rngs):
        """Build the encoder first, then the decoder, from the same `rngs`.

        Args:
            config: configuration passed to both sub-modules.
            rngs: random number streams used to initialise both sub-modules.
        """
        super().__init__()
        self.config = config
        self.encoder = Encoder(config, rngs)
        self.decoder = Decoder(config, rngs)

    def __call__(self, in_image: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Encode `in_image` and decode it again.

        Args:
            in_image: input passed straight to the encoder.

        Returns:
            A ``(decoded, encoded)`` tuple: the decoder output followed by the
            encoder output it was decoded from.
        """
        encoded = self.encoder(in_image)
        # TODO: add quantization
        decoded = self.decoder(encoded)
        # The method returns a 2-tuple, so the annotation is `tuple[...]`
        # rather than the previous `List[...]`.
        return (decoded, encoded)
        

In [41]:
# Build the full auto-encoder from the current `config` and `rngs`.
model = EncoderDecoder(config=config, rngs=rngs)

In [43]:
# The model processes a single image (the decoder reshapes its input without a
# batch axis), so vmap is used to run it over the sample batch.
jax.vmap(model)(image)

Input: (256, 256, 3)
E1  (128, 128, 16)
E2  (64, 64, 32)
R 0 (64, 64, 32)
R 1 (64, 64, 32)
O (131072,)
1: (64, 64, 32)
2: (128, 128, 16)


(Array([[[[-3.56530094e+00, -3.03041697e-01,  8.68887901e-01],
          [ 5.74235773e+00,  2.10232735e+00, -6.66455555e+00],
          [ 7.78432846e+00, -2.96319580e+00,  2.71894455e+00],
          ...,
          [ 4.14463997e+00, -1.26242695e+01, -8.13631248e+00],
          [ 7.70458317e+00, -1.51355820e+01,  1.22178793e+01],
          [ 8.67223644e+00,  1.83337259e+00, -1.94448161e+00]],
 
         [[ 1.43042631e+01,  1.92804050e+00,  2.29953337e+00],
          [-9.73361206e+00,  1.16252680e+01, -1.11803761e+01],
          [ 3.14358845e+01,  7.59307146e-01, -1.25363064e+01],
          ...,
          [-2.16495895e+01,  6.62523746e+00, -1.47581291e+00],
          [ 2.95003510e+01,  1.37895527e+01, -6.79598856e+00],
          [-2.03991394e+01,  2.95774937e+00, -1.13983774e+01]],
 
         [[-6.03686142e+00, -3.52874017e+00,  1.42041245e+01],
          [-1.10348177e+01, -4.97590542e+00,  3.22436738e+00],
          [ 5.96610451e+00, -2.07357788e+01, -1.20878220e-01],
          ...,
    

# Training

In [45]:
def image_loss(orig, restored):
    """Return the mean squared difference between `orig` and `restored`."""
    difference = orig - restored
    return jnp.mean(jnp.square(difference))

In [46]:
# Adam with a fixed learning rate of 0.001, wrapped so it updates `model`.
tx = optax.adam(learning_rate=0.001)
optimizer = nnx.Optimizer(model, tx)

In [55]:
def loss_fn(model, data):
    """Reconstruction loss of `model` on a batch.

    The model is vmapped over `data`; only the first of its two outputs (the
    reconstruction) enters the loss.
    """
    reconstruction, _unused = jax.vmap(model)(data)
    return image_loss(data, reconstruction)

In [56]:
@nnx.jit
def train_step(model: EncoderDecoder, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, data):
    """Run one optimisation step on `data` and record the loss in `metrics`.

    Nothing is returned; `optimizer` and `metrics` are updated in place.
    """
    # Differentiate loss_fn with respect to its first argument (the model).
    loss, grads = nnx.value_and_grad(loss_fn)(model, data)
    optimizer.update(grads)
    metrics.update(loss=loss)

In [57]:
# Tracks the average of the `loss` values reported by train_step.
metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))

In [None]:
# Train on the streamed batches. `metrics` is never reset, so each report is the
# average loss over all steps so far.
for step, batch in enumerate(batched_data):
    data = batch['x']
    train_step(model, optimizer, metrics, data)
    if step % 100 != 0:
        continue
    # Report every 100th step.
    print(list(metrics.compute().items()))

[('loss', Array(15087.221, dtype=float32))]
[('loss', Array(1794.9474, dtype=float32))]
[('loss', Array(1019.8223, dtype=float32))]
[('loss', Array(731.798, dtype=float32))]
[('loss', Array(578.14636, dtype=float32))]
[('loss', Array(481.87878, dtype=float32))]
[('loss', Array(414.52884, dtype=float32))]
