In [None]:
import jax.numpy as jnp
from flax import linen as nn
from jax import vmap, random, pmap, grad
from jax.lax import pmean

from typing import Any, Callable, Sequence, Optional, Union, Dict
from flax.training import train_state
from flax import struct
import optax

from functools import partial
from flax import jax_utils

In [10]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with the model's encode/decode functions.

    Both extra fields are declared with ``pytree_node=False``, so they are
    static metadata rather than pytree leaves: they are not traced or
    differentiated, and they ride along unchanged when the state is
    transformed (e.g. replicated across devices).
    """
    # Presumably (params, x, eps) -> (z, kl_loss), vmapped over a mini-batch.
    encode_fn: Callable[..., Any] = struct.field(pytree_node=False)
    # Presumably (params, z) -> reconstruction, vmapped over a mini-batch.
    decode_fn: Callable[..., Any] = struct.field(pytree_node=False)

In [4]:
class GaussianMLP(nn.Module):
    """MLP trunk with two linear output heads, returning ``(mu, logsigma)``.

    The trunk applies ``num_layers`` times: Dense(hidden_dim) then
    ``activation``. Each head is an independent Dense(output_dim) on the trunk
    output. How the second head is interpreted (log-std vs. log-variance) is
    left to the caller.

    NOTE(review): an identical ``GaussianMLP`` is defined again in a later
    cell. Keep a single definition so the two copies cannot drift apart.
    """
    num_layers: int=2
    hidden_dim: int=64
    output_dim: int=1
    activation: Callable=nn.gelu

    @nn.compact
    def __call__(self, x):
        h = x
        for _layer in range(self.num_layers):
            h = self.activation(nn.Dense(self.hidden_dim)(h))
        # The mean head is created before the log-scale head, which keeps the
        # auto-generated submodule names (Dense_0, Dense_1, ...) stable.
        mean = nn.Dense(self.output_dim)(h)
        log_scale = nn.Dense(self.output_dim)(h)
        return mean, log_scale

In [None]:
class MLP(nn.Module):
    """Fully connected network with a linear read-out.

    Applies ``num_layers`` hidden layers (Dense(hidden_dim) followed by
    ``activation``), then a final Dense(output_dim) with no activation.
    """
    num_layers: int=2
    hidden_dim: int=64
    output_dim: int=1
    activation: Callable=nn.gelu

    @nn.compact
    def __call__(self, x):
        hidden_widths = [self.hidden_dim] * self.num_layers
        for width in hidden_widths:
            x = self.activation(nn.Dense(width)(x))
        # Linear output layer: no activation on the final projection.
        return nn.Dense(self.output_dim)(x)
    
class GaussianMLP(nn.Module):
    """Shared MLP trunk feeding two Dense(output_dim) heads: ``(mu, logsigma)``.

    ``MlpEncoder`` consumes the second head as a log-variance. Other callers
    must decide its meaning themselves.

    NOTE(review): this duplicates an earlier ``GaussianMLP`` cell and, on
    Run All, silently replaces it. Keep only one definition.
    """
    num_layers: int=2
    hidden_dim: int=64
    output_dim: int=1
    activation: Callable=nn.gelu

    @nn.compact
    def __call__(self, x):
        hidden = x
        for _ in range(self.num_layers):
            hidden = nn.Dense(self.hidden_dim)(hidden)
            hidden = self.activation(hidden)
        # Two heads are built and applied in order: first mu, then logsigma.
        return tuple(nn.Dense(self.output_dim)(hidden) for _ in range(2))
    
class MlpEncoder(nn.Module):
    """Gaussian encoder with reparameterized sampling and analytic KL.

    A ``GaussianMLP`` of width ``latent_dim`` produces ``(mu, logsigma)``.
    Despite its name, ``logsigma`` is used as a *log-variance* here, because
    the standard deviation is taken as ``sqrt(exp(logsigma))``.

    Args (to ``__call__``):
        x: encoder input.
        eps: noise multiplied elementwise with the standard deviation, so it
            must broadcast against ``mu``. Presumably standard normal; the
            caller supplies it.

    Returns:
        ``(z, kl_loss)`` where ``z = mu + eps * std`` and ``kl_loss`` is
        ``KL(N(mu, var) || N(0, 1))`` summed over the last axis.
    """
    latent_dim: int=8
    num_layers: int=2
    hidden_dim: int=64
    activation: Callable=nn.gelu

    @nn.compact
    def __call__(self, x, eps):
        gaussian_head = GaussianMLP(self.num_layers,
                                    self.hidden_dim,
                                    self.latent_dim,
                                    self.activation)
        mu, logsigma = gaussian_head(x)
        variance = jnp.exp(logsigma)
        # Reparameterization trick: sample stays differentiable w.r.t. mu/logsigma.
        z = mu + eps*jnp.sqrt(variance)
        # Closed-form KL against a standard normal prior.
        kl_loss = 0.5*jnp.sum(variance + mu**2 - 1.0 - logsigma, axis=-1)
        return z, kl_loss

In [8]:
class VAE(nn.Module):
    """Encoder/decoder pair.

    ``encoder`` is called as ``encoder(x, eps)`` and must return a pair
    ``(z, kl_loss)``. ``decoder`` maps ``z`` to the reconstruction.
    ``_encode`` and ``_decode`` expose the two halves separately for use with
    ``Module.apply(..., method=...)``.
    """
    encoder: nn.Module
    decoder: nn.Module

    @nn.compact
    def __call__(self, x, eps):
        # Full x -> z -> reconstruction pass; the KL term is dropped here.
        latent, _ = self.encoder(x, eps)
        return self.decoder(latent)

    def _encode(self, x, eps):
        """Return ``(z, kl_loss)`` produced by the encoder."""
        latent, kl = self.encoder(x, eps)
        return latent, kl

    def _decode(self, z):
        """Map latent ``z`` through the decoder."""
        return self.decoder(z)

In [None]:
class VariationalAutoencoder:
    """Multi-device trainer for an MLP variational autoencoder.

    Builds a ``VAE`` from an ``MlpEncoder`` and an ``MLP`` decoder,
    initializes its parameters, and stores a ``TrainState`` replicated across
    local devices in ``self.state``. ``step`` and ``eval_losses`` are
    ``pmap``-ed over the ``'num_devices'`` axis, so their ``state``/``params``
    and ``batch`` arguments need a leading device axis.

    Args:
        config: configuration object. Fields read:
            ``encoder`` / ``decoder`` (``num_layers``, ``hidden_dim``,
            ``output_dim``, ``activation``), ``input_dim``, ``eps_dim``,
            ``seed``, ``optimizer`` (``learning_rate``, ``decay_steps``,
            ``decay_rate``, ``beta1``, ``beta2``, ``eps``) and ``beta``
            (weight of the KL term in ``loss``).
    """

    def __init__(self, config):
        # Define architecture.
        # ``VAE`` calls ``encoder(x, eps)`` and expects ``(z, kl_loss)`` back.
        # A bare ``GaussianMLP`` only accepts ``x`` and returns
        # ``(mu, logsigma)``, which made ``arch.init`` fail. ``MlpEncoder``
        # wraps it and adds the reparameterized sample and the KL term.
        # ``config.encoder.output_dim`` is kept as the latent width so existing
        # configs still work.
        encoder = MlpEncoder(
            latent_dim=config.encoder.output_dim,
            num_layers=config.encoder.num_layers,
            hidden_dim=config.encoder.hidden_dim,
            activation=config.encoder.activation,
        )
        decoder = MLP(
            num_layers=config.decoder.num_layers,
            hidden_dim=config.decoder.hidden_dim,
            output_dim=config.decoder.output_dim,
            activation=config.decoder.activation,
        )
        arch = VAE(encoder, decoder)

        # Initialize params from a single unbatched example.
        # NOTE(review): eps is multiplied elementwise with the encoder's std,
        # so config.eps_dim should match config.encoder.output_dim.
        x = jnp.ones(config.input_dim)
        eps = jnp.ones(config.eps_dim)
        key = random.PRNGKey(config.seed)
        params = arch.init(key, x, eps)

        # Vectorized functions across a mini-batch: params are shared
        # (in_axes=None) and data is mapped over axis 0.
        apply_fn = vmap(arch.apply, in_axes=(None,0,0))
        encode_fn = vmap(lambda params, x, eps: arch.apply(params, x, eps, method=arch._encode), in_axes=(None,0,0))
        decode_fn = vmap(lambda params, z: arch.apply(params, z, method=arch._decode), in_axes=(None,0))

        # Optimizer: Adam with an exponentially decaying learning rate.
        lr = optax.exponential_decay(
            init_value=config.optimizer.learning_rate,
            transition_steps=config.optimizer.decay_steps,
            decay_rate=config.optimizer.decay_rate
        )
        tx = optax.adam(
            learning_rate=lr,
            b1=config.optimizer.beta1,
            b2=config.optimizer.beta2,
            eps=config.optimizer.eps
        )

        # Create state
        state = TrainState.create(
            apply_fn=apply_fn,
            params=params,
            tx=tx,
            encode_fn=encode_fn,
            decode_fn=decode_fn
        )

        # Replicate state across devices (one copy per local device for pmap).
        self.state = jax_utils.replicate(state)
        self.beta = config.beta

    # Computes KL loss across a mini-batch
    def kl_loss(self, params, x, eps):
        """Mean KL term over a mini-batch for a single MC noise sample."""
        _, loss = self.state.encode_fn(params, x, eps)
        return jnp.mean(loss)

    # Computes reconstruction loss across a mini-batch for a single MC sample
    def recon_loss(self, params, x, eps):
        """Mean squared reconstruction error over a mini-batch for one MC sample."""
        outputs = self.state.apply_fn(params, x, eps)
        loss = jnp.mean((x-outputs)**2)
        return loss

    def _compute_losses(self, params, batch):
        """Return ``(kl_loss, recon_loss)`` averaged over MC samples and the batch.

        ``batch`` is ``(x, eps)``. ``eps`` carries a leading Monte Carlo sample
        axis that is vmapped over while ``x`` is shared across samples. Each
        sample is in turn vmapped over the mini-batch axis. Both ``loss`` and
        ``eval_losses`` use this helper so they cannot diverge.
        """
        x, eps = batch
        kl_loss = vmap(self.kl_loss, in_axes=(None,None,0))(params, x, eps)
        recon_loss = vmap(self.recon_loss, in_axes=(None,None,0))(params, x, eps)
        return jnp.mean(kl_loss), jnp.mean(recon_loss)

    # Computes total loss across a mini-batch for multiple MC samples
    def loss(self, params, batch):
        """Total objective ``beta * KL + reconstruction`` for one batch."""
        kl_loss, recon_loss = self._compute_losses(params, batch)
        return self.beta*kl_loss + recon_loss

    @partial(pmap, axis_name='num_devices', static_broadcasted_argnums=(0,))
    def eval_losses(self, params, batch):
        """Per-device ``(kl_loss, recon_loss)``.

        No cross-device reduction is applied, so each returned array has a
        leading device axis (one value per device's batch shard).
        """
        return self._compute_losses(params, batch)

    # Define a compiled update step
    @partial(pmap, axis_name='num_devices', static_broadcasted_argnums=(0,))
    def step(self, state, batch):
        """One optimizer step on replicated ``state`` and a sharded ``batch``.

        Gradients are averaged across devices with ``pmean`` before being
        applied, so every replica receives the same update.
        """
        grads = grad(self.loss)(state.params, batch)
        grads = pmean(grads, 'num_devices')
        state = state.apply_gradients(grads=grads)
        return state