In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import optax

class Encoder(nn.Module):
 """Map an input batch to the parameters of a diagonal Gaussian posterior.

 A single ReLU hidden layer of width 500 feeds two independent linear heads,
 one producing the posterior mean and one the log-variance.
 """
 # Dimensionality of the latent code; sets the width of both output heads.
 latents: int

 @nn.compact
 def __call__(self, x):
   """Return ``(mean, logvar)``, each with ``self.latents`` features in the last axis.

   The notebook feeds flattened ``(batch, 784)`` arrays; the layers themselves
   only act on the trailing axis.
   """
   hidden = nn.relu(nn.Dense(500, name='fc1')(x))
   # Explicit layer names are kept so the parameter tree layout stays the same.
   mean_head = nn.Dense(self.latents, name='fc2_mean')
   logvar_head = nn.Dense(self.latents, name='fc2_logvar')
   return mean_head(hidden), logvar_head(hidden)

In [None]:
class Decoder(nn.Module):
 """Map a latent code back to 784 unnormalised outputs (one per input feature).

 The result is consumed as logits by ``binary_cross_entropy_with_logits`` in
 the training step, so no sigmoid is applied here.
 """

 @nn.compact
 def __call__(self, z):
   """Return the reconstruction logits for latent batch ``z``."""
   hidden = nn.relu(nn.Dense(500, name='fc1')(z))
   return nn.Dense(784, name='fc2')(hidden)

In [None]:
class VAE(nn.Module):
 """Variational autoencoder: Encoder -> reparameterised sample -> Decoder."""
 # Size of the latent code handed to the encoder (defaults to 20).
 latents: int = 20

 def setup(self):
   """Create the encoder and decoder submodules."""
   self.encoder = Encoder(self.latents)
   self.decoder = Decoder()

 def __call__(self, x, z_rng):
   """Encode ``x``, draw one latent sample with ``z_rng`` and decode it.

   Returns:
     ``(recon_x, mean, logvar)``: decoder output for the sampled latent,
     followed by the posterior mean and log-variance from the encoder.
   """
   posterior_mean, posterior_logvar = self.encoder(x)
   latent_sample = reparameterize(z_rng, posterior_mean, posterior_logvar)
   return self.decoder(latent_sample), posterior_mean, posterior_logvar

def reparameterize(rng, mean, logvar):
 """Draw ``mean + eps * std`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar / 2)``.

 Args:
   rng: JAX PRNG key used to draw the noise.
   mean: mean of the Gaussian to sample from.
   logvar: log-variance of that Gaussian; its shape decides the noise shape.

 Returns:
   One sample, written as a deterministic function of ``mean`` and ``logvar``
   plus external noise so that gradients can flow through both.
 """
 noise = random.normal(rng, logvar.shape)
 scale = jnp.exp(logvar * 0.5)
 return mean + noise * scale

def model():
 """Build a VAE whose latent size is the notebook-level ``LATENTS`` constant.

 ``LATENTS`` is looked up when the function is called, not when it is defined.
 """
 vae = VAE(latents=LATENTS)
 return vae

In [None]:
@jax.vmap
def kl_divergence(mean, logvar):
 """KL divergence from N(mean, exp(logvar)) to the standard normal N(0, I).

 Because of ``jax.vmap`` the body sees one example at a time; the public
 function therefore takes batched inputs and returns one value per row.
 """
 per_dimension = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
 return -0.5 * jnp.sum(per_dimension)

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
 """Summed binary cross-entropy between ``sigmoid(logits)`` and ``labels``.

 Because of ``jax.vmap`` the body sees one example at a time; the public
 function takes batched inputs and returns one summed loss per row.

 Args:
   logits: unnormalised scores.
   labels: targets in [0, 1], same shape as ``logits``.
 """
 log_p = jax.nn.log_sigmoid(logits)
 # log(1 - sigmoid(x)) == log_sigmoid(-x). The previous form,
 # log(-expm1(log_sigmoid(x))), hits log(0) = -inf once log_sigmoid(x) rounds
 # to 0 for large positive logits, which produced NaN (label 1) or inf (label 0).
 log_not_p = jax.nn.log_sigmoid(-logits)
 return -jnp.sum(labels * log_p + (1. - labels) * log_not_p)


@jax.jit
def train_step(params, opt_state, batch, rng):
    """Run one jitted gradient update of the VAE on ``batch``.

    Args:
        params: current model parameters (the ``'params'`` collection).
        opt_state: current state of the module-level ``optimizer``.
        batch: input batch; it is also used as the reconstruction target.
        rng: PRNG key used for the latent sample in this step.

    Returns:
        ``(new_params, new_opt_state)``. The loss value itself is not returned.
    """
    def negative_elbo(candidate_params):
        # Forward pass, then reconstruction and KL terms averaged over the batch.
        recon_logits, post_mean, post_logvar = model().apply(
            {'params': candidate_params}, batch, rng)
        reconstruction_term = binary_cross_entropy_with_logits(recon_logits, batch).mean()
        kl_term = kl_divergence(post_mean, post_logvar).mean()
        return reconstruction_term + kl_term

    gradients = jax.grad(negative_elbo)(params)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state

In [None]:
# Key creation for random number generation
rng = random.PRNGKey(0)
rng, key = random.split(rng)
# Dedicated key for the latent sample drawn during init, so that `rng` is never
# consumed directly and then split again (JAX keys must not be reused).
rng, init_z_rng = random.split(rng)

# Constants
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
LATENTS = 128
# 50000 matches the number of training examples reported by the dataset
# download output and the shuffle buffer used in the data-loading cell.
STEPS_PER_EPOCH = 50000 // BATCH_SIZE

# Dummy batch: only its shape matters, it is used to create the parameters.
init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
init_params = model().init(key, init_data, init_z_rng)['params']

optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(init_params)

# NOTE: `train_ds` (an iterator of batches) is created by the data-loading cell
# further down in this notebook, and that cell in turn reads BATCH_SIZE from
# here. On a fresh kernel, run the constants and the data-loading cell before
# this loop.

for epoch in range(NUM_EPOCHS):
    for _ in range(STEPS_PER_EPOCH):
        batch = next(train_ds)
        rng, z_rng = random.split(rng)
        # `init_params` holds the current (trained) parameters after the first step.
        init_params, opt_state = train_step(init_params, opt_state, batch, z_rng)

    # Report a scalar loss on the last batch of the epoch. Printing `opt_state`
    # dumped every Adam moment array and flooded the cell output.
    rng, eval_rng = random.split(rng)
    recon_x, mean, logvar = model().apply({'params': init_params}, batch, eval_rng)
    epoch_loss = (binary_cross_entropy_with_logits(recon_x, batch).mean()
                  + kl_divergence(mean, logvar).mean())
    print(f'Epoch {epoch}, Loss: {float(epoch_loss):.4f}')


[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
        1.66956976e-03,  2.97446605e-02, -2.86155821e-17, -4.75970469e-03,
        2.77911010e-03, -3.08944117e-02,  0.00000000e+00,  0.00000000e+00,
        2.31014448e-03,  5.36582358e-02, -2.89061330e-02,  7.11585744e-05,
       -4.26229509e-03,  0.00000000e+00, -1.10686086e-02,  0.00000000e+00,
        7.87703693e-02, -1.95624884e-02,  0.00000000e+00, -1.05861314e-02,
       -8.19674805e-02, -4.96000126e-02,  2.53555889e-04,  0.00000000e+00,
        1.43997753e-02, -5.89821208e-03,  3.10101779e-03,  7.92386010e-03,
        1.79021936e-02, -1.06453821e-02, -4.26795147e-02, -7.04591069e-03,
       -1.06925487e-04, -1.67909043e-03, -4.25595877e-04,  1.78391906e-03,
        0.00000000e+00,  1.04138628e-02,  3.94990221e-02, -3.33353411e-03,
        6.78041996e-03, -5.18530898e-04, -3.56943440e-03,  2.29178858e-03,
        0.00000000e+00,  5.29542789e-02,  0.00000000e+00, -1.39920665e-02,
        

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

# Hide all GPUs from TensorFlow (empty device list for the 'GPU' type).
# NOTE(review): presumably so that TF, used here only for the input pipeline,
# does not claim accelerator memory that JAX needs — confirm.
tf.config.experimental.set_visible_devices([], 'GPU')

def prepare_image(x):
 """Turn one dataset record into a flat float32 vector.

 Reads the ``'image'`` entry of ``x``, casts it to ``tf.float32`` and flattens
 it to one dimension.
 """
 pixels = tf.cast(x['image'], tf.float32)
 return tf.reshape(pixels, (-1,))

# Download (or reuse) the binarized MNIST dataset through TFDS.
ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()

# Training stream: flatten each image, cache, repeat indefinitely, shuffle with
# a 50000-element buffer, batch, and expose the result as an iterator of NumPy
# arrays.
# NOTE(review): BATCH_SIZE is defined in the training cell earlier in the
# notebook, which itself consumes `train_ds` — the constants must exist before
# this cell runs.
train_ds = iter(tfds.as_numpy(
    ds_builder.as_dataset(split=tfds.Split.TRAIN)
    .map(prepare_image)
    .cache()
    .repeat()
    .shuffle(50000)
    .batch(BATCH_SIZE)
))

# Test set: batch up to 10000 flattened images together and keep the first
# batch as a single NumPy array.
test_ds = (
    ds_builder.as_dataset(split=tfds.Split.TEST)
    .map(prepare_image)
    .batch(10000)
)
test_ds = np.array(list(test_ds)[0])

Downloading and preparing dataset 104.68 MiB (download: 104.68 MiB, generated: Unknown size, total: 104.68 MiB) to /root/tensorflow_datasets/binarized_mnist/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/binarized_mnist/1.0.0.incompleteXLQPJ9/binarized_mnist-train.tfrecord*...:…

Generating validation examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/binarized_mnist/1.0.0.incompleteXLQPJ9/binarized_mnist-validation.tfrecord…

Generating test examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/binarized_mnist/1.0.0.incompleteXLQPJ9/binarized_mnist-test.tfrecord*...: …

Dataset binarized_mnist downloaded and prepared to /root/tensorflow_datasets/binarized_mnist/1.0.0. Subsequent calls will reuse this data.
