In [1]:
from typing import Callable

import flax.linen as nn
import jax.numpy as jnp
import jax.random as random
import optax

from aevb.core import AEVB

In [2]:
# Generative Model and Recognition Feature Extractor --------------------
class GenModel(nn.Module):
    """Generative model: an MLP mapping its input to an ``output_dim``-wide output.

    Architecture: Dense(128) -> BatchNorm -> ReLU -> Dense(256) -> ReLU
    -> Dense(output_dim). The last layer has no activation, so outputs are
    unconstrained real values.

    Attributes:
        output_dim: Width of the final layer. Defaults to 784, the value this
            notebook originally hard-coded (NOTE(review): presumably a flattened
            28x28 image — confirm against the dataset used later).
    """

    output_dim: int = 784

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Apply the network to ``x``.

        Args:
            x: Input batch; presumably latent samples of shape
                (batch, latent_dim) — TODO confirm against aevb.core.
            train: When True, BatchNorm normalizes with batch statistics;
                when False it uses its stored running averages.

        Returns:
            Array whose last dimension is ``output_dim``.
        """
        x = nn.Dense(features=128)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim)(x)
        return x


class RecModel(nn.Module):
    """Recognition model: maps an input to a per-dimension mean and scale.

    Architecture: Dense(512) -> BatchNorm -> ReLU -> Dense(256) -> ReLU
    -> Dense(128) -> ReLU -> Dense(64) -> ReLU, followed by two linear heads
    of width ``latent_dim`` producing ``mu`` and a log-variance.
    (NOTE(review): presumably these parameterize a diagonal Gaussian q(z|x)
    inside aevb.core — confirm.)

    Attributes:
        latent_dim: Width of each output head (dimensionality of the latent space).
    """

    latent_dim: int

    @nn.compact
    def __call__(self, x, train: bool = False):
        """Apply the network to ``x``.

        Args:
            x: Input batch; presumably flattened observations — TODO confirm shape.
            train: When True, BatchNorm normalizes with batch statistics;
                when False it uses its stored running averages.

        Returns:
            ``(mu, sigma)``: each with last dimension ``latent_dim``, where
            ``sigma = exp(0.5 * logvar)``.
        """
        x = nn.Dense(features=512)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=64)(x)
        # Without a nonlinearity here, Dense(64) composes with the linear heads
        # below into a single affine map and adds no representational capacity.
        # relu has no parameters, so the parameter tree is unchanged.
        x = nn.relu(x)
        mu = nn.Dense(features=self.latent_dim)(x)
        # Predict an unconstrained log-variance, then exponentiate so the
        # returned scale is positive.
        logvar = nn.Dense(features=self.latent_dim)(x)
        sigma = jnp.exp(logvar * 0.5)
        return mu, sigma

In [3]:
# Create AEVB inference engine.
# Hyperparameters are named here so they are easy to find and tune.
latent_dim = 4        # shared by RecModel and AEVB so the two always agree
learning_rate = 1e-3  # Adam step size
n_samples = 15        # forwarded to AEVB(n_samples=...); presumably Monte Carlo samples per step — confirm in aevb.core

gen_model = GenModel()
rec_model = RecModel(latent_dim=latent_dim)
optimizer = optax.adam(learning_rate)

engine = AEVB(
    latent_dim=latent_dim,
    generative_model=gen_model,
    recognition_model=rec_model,
    optimizer=optimizer,
    n_samples=n_samples,
    nn_lib="flax",
)