In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
import numpy as np
import optax
from datasets import load_dataset
from jax.random import PRNGKey, split

from aevb.core import AEVB


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Display the Equinox base class that RecModel (below) subclasses.
eqx.Module

In [11]:
@eqx.nn.make_with_state
class RecModel(eqx.Module):
    """MLP mapping a flattened input to latent parameters ``(mu, sigma)``.

    Presumably the recognition (encoder) model for the AEVB setup imported
    above -- confirm against how it is used later in the notebook.

    Because the class is wrapped by ``eqx.nn.make_with_state``, the name
    ``RecModel`` is bound to a factory: ``RecModel(key, latent_dim)`` returns
    a ``(model, state)`` tuple, where ``state`` carries the BatchNorm layer's
    state and must be threaded through every forward call.

    Architecture: 784 -> 512 -> BatchNorm -> relu -> 256 -> relu -> 128
    -> relu -> 64, followed by two independent linear heads (64 -> latent_dim)
    producing ``mu`` and ``logvar``.
    """

    latent_dim: int
    layers: list
    projection_layers: list

    def __init__(self, key: jax.Array, latent_dim: int):
        """Initialise all layers.

        Args:
            key: JAX PRNG key; split into six sub-keys, one per Linear layer.
            latent_dim: Size of the latent space (output size of both heads).
        """
        keys = random.split(key, 6)
        self.latent_dim = latent_dim

        self.layers = [
            eqx.nn.Linear(in_features=784, out_features=512, key=keys[0]),
            # BatchNorm statistics are computed across the named axis "batch",
            # so the forward pass must run under a transform that defines it,
            # e.g. jax.vmap(..., axis_name="batch").
            eqx.nn.BatchNorm(input_size=512, axis_name="batch"),
            jax.nn.relu,
            eqx.nn.Linear(in_features=512, out_features=256, key=keys[1]),
            jax.nn.relu,
            eqx.nn.Linear(in_features=256, out_features=128, key=keys[2]),
            jax.nn.relu,
            eqx.nn.Linear(in_features=128, out_features=64, key=keys[3]),
        ]
        # Two separate heads on the shared 64-d feature: [0] -> mu, [1] -> logvar.
        self.projection_layers = [
            eqx.nn.Linear(in_features=64, out_features=self.latent_dim, key=keys[4]),
            eqx.nn.Linear(in_features=64, out_features=self.latent_dim, key=keys[5]),
        ]

    def __call__(self, x, state):
        """Run the network on a single, un-batched input.

        Args:
            x: Input vector with 784 features (the first Linear's in_features).
            state: ``eqx.nn.State`` threaded through the BatchNorm layer.

        Returns:
            ``((mu, sigma), state)`` where ``sigma = exp(0.5 * logvar)`` and
            ``state`` is the state returned by the BatchNorm layer.
        """
        for layer in self.layers:
            # Stateful layers take and return the state; match on the public
            # class instead of the private `eqx.nn._batch_norm` module path.
            if isinstance(layer, eqx.nn.BatchNorm):
                x, state = layer(x, state)
            else:
                x = layer(x)

        mu = self.projection_layers[0](x)
        logvar = self.projection_layers[1](x)
        # Convert log-variance to a standard deviation: exp(logvar / 2).
        sigma = jnp.exp(logvar * 0.5)
        return (mu, sigma), state

In [12]:
# The make_with_state-wrapped constructor returns a (module, state) pair.
model = RecModel(random.key(0), 4)

# Unpack into named variables: binding `_` would shadow IPython's last-output
# variable and stop IPython from updating `_`/`__`/`___` for the session.
_model, _state = model
_model.__annotations__.keys()

dict_keys(['latent_dim', 'layers', 'projection_layers'])

In [13]:
# Show the Python type of each of the two elements returned by the constructor.
tuple(type(model[idx]) for idx in (0, 1))

(__main__.RecModel, equinox.nn._stateful.State)