In [1]:
import jax

In [2]:
jax.devices()

[CudaDevice(id=0)]

In [3]:
from flax import nnx
import optax
import jax.numpy as jnp


In [4]:
class Encoder(nnx.Module):
    """VAE Encoder"""
    def __init__(self, latent_dim: int, image_flatten_dim: int, internal_dim:int, rngs: nnx.Rngs):
        self.dense_layer = nnx.Linear(in_features=image_flatten_dim, out_features=internal_dim, rngs=rngs)
        self.mean_layer = nnx.Linear(in_features=internal_dim, out_features=latent_dim, rngs=rngs)
        self.logvar_layer = nnx.Linear(in_features=internal_dim, out_features=latent_dim, rngs=rngs)
    def __call__(self, x: jax.Array):
        x = self.dense_layer(x)
        x = jax.nn.relu(x)
        means = self.mean_layer(x)
        logvars = self.logvar_layer(x)
        return (means, logvars)
        

In [5]:
class Decoder(nnx.Module):
    """VAE Decoder"""
    def __init__(self, latent_dim: int, image_flatten_dim: int, internal_dim:int, rngs: nnx.Rngs):
        self.dense_layer1 = nnx.Linear(in_features=latent_dim, out_features=internal_dim, rngs=rngs)
        self.dense_layer2 = nnx.Linear(in_features=internal_dim, out_features=image_flatten_dim, rngs=rngs)
    def __call__(self, x: jax.Array):
        x = self.dense_layer1(x)
        x = jax.nn.relu(x)
        x = self.dense_layer2(x)
        return x

In [6]:
class VAE(nnx.Module):
    """VAE together"""
    def __init__(self, latent_dim:int, image_flatten_dim:int, internal_dim:int, rngs: nnx.Rngs):
        self.encoder = Encoder(latent_dim=latent_dim, image_flatten_dim=image_flatten_dim, internal_dim=internal_dim, rngs=rngs)
        self.decoder = Decoder(latent_dim=latent_dim, image_flatten_dim=image_flatten_dim, internal_dim=internal_dim, rngs=rngs)
        self.random_for_reparam = rngs.reparam()
    def __call__(self, x: jax.Array):
        mean, logvar = self.encoder(x)
        x = self.reparam(mean, logvar)
        x = self.decoder(x)
        return x, mean, logvar

    def reparam(self, mean: jax.Array, logvars: jax.Array):
        std = jnp.exp(logvars/2.0)
        eps = jax.random.normal(self.random_for_reparam, logvars.shape)
        return mean + eps*std
        

In [7]:
vae = VAE(16, 32*32, 512, nnx.Rngs(0))
y = vae(jnp.ones((32*32)))
y

2024-10-07 14:41:20.767516: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


(Array([-0.45001116,  0.34558707,  0.28287855, ..., -1.2844074 ,
         0.22410434, -0.27409077], dtype=float32),
 Array([ 0.5681723 , -0.06791615, -0.24735603, -0.39458978, -0.9860513 ,
         0.74274313, -0.13057835, -0.5456586 , -1.2630452 ,  0.4183552 ,
        -0.09809393, -1.2089942 , -1.4128064 , -0.9893104 ,  0.6430859 ,
        -0.7321053 ], dtype=float32),
 Array([ 0.6757376 , -1.2812955 ,  0.6903671 , -0.23399292,  0.22029123,
        -1.1534731 , -0.55614877, -1.0910658 , -0.0902662 ,  0.5463139 ,
        -0.0121367 , -1.4974947 ,  0.03824741, -0.61413294, -1.1627892 ,
        -0.24757725], dtype=float32))

In [14]:
import datasets
from PIL import Image
from IPython.display import display

In [10]:
dataset = datasets.load_dataset("ylecun/mnist")

README.md:   0%|          | 0.00/6.97k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/15.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.60M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [68]:
fl = dataset["train"]["image"][0]

In [74]:
jnp.array([jnp.array([1,2]), jnp.array([3,4])])

Array([[1, 2],
       [3, 4]], dtype=int32)

In [79]:
iter_data = dataset["train"].map(lambda x: {"x": jnp.array([jnp.array(i) for i in x["image"]])}, remove_columns=dataset["train"].column_names, batched=True, batch_size=32)

Map:   0%|          | 0/60000 [00:00<?, ? examples/s]

# Training

In [18]:
def kl_divergence(mean, logvar):
  return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))

In [42]:
def reconstruct_loss(logits, labels):
    logits = jax.nn.log_sigmoid(logits)
    return -jnp.sum(
        labels * logits + (1.0 - labels) * jnp.log(-jnp.expm1(logits))
    )

In [91]:
model = VAE(10, 28*28, 256, nnx.Rngs(0))
tx = optax.adam(0.001)
state = nnx.Optimizer(model, tx)

In [104]:
def build_loss_fn(model):
    def loss_fn(data):
        logits, mean, logvar = model(data)
        divergence = kl_divergence(mean, logvar)
        reconstruct = reconstruct_loss(logits, data)
        print('d', divergence, 'r', reconstruct)
        return divergence + reconstruct
    return loss_fn

In [109]:
loss_fn = build_loss_fn(model)
loss_fn(jnp.array(jnp.array(iter_data['x'][0])!=0, dtype=jnp.float32).reshape(-1, 28*28))

d 0.89737016 r 573.34796


Array(574.24536, dtype=float32)

In [None]:
grad = nnx.grad(loss_fn)(jnp.array(jnp.array(iter_data['x'][0])!=0, dtype=jnp.float32).reshape(-1, 28*28))
grad

In [111]:
state.update(grad)

TypeError: unsupported operand type(s) for *: 'float' and 'State'