In [1]:
import jax
import jax.numpy as jnp
import jax.random as random

from aevb.config import Config
from aevb.nets_flax import FlaxMLPEncoder, FlaxMLPDecoder
from aevb.nets_eqx import EqxMLPEncoder, EqxMLPDecoder

In [2]:
# Load the base YAML config, then apply the named override presets.
# NOTE(review): precedence among overrides is defined in aevb.config — confirm.
config_overrides = ["mnist", "small"]
config = Config.from_yaml('./configs.yaml', override=config_overrides)
config.prettyprint()


Config:
seed:            0            (int)
dataset:                      (str)
nnlib:           flax         (str)
jax.gpus:        [0]          (ints)
run.steps:       10000        (int)
run.eval_every:  100          (int)
data_shape:      [28, 28]     (ints)
latent_dim:      4            (int)
gen_hidden:      [128, 784]   (ints)
rec_hidden:      [128, 64]    (ints)
act:             relu         (str)
init.init:       normal       (str)
init.params:     [0.0, 0.01]  (floats)
opt.opt:         adagrad      (str)
opt.lr:          0.1          (float)


In [4]:
from typing import Iterable
from math import prod


# Maps the `act` string from the config to the elementwise activation function
# passed to the MLP constructors. Indexing with a name that is not a key here
# raises KeyError.
activations = {
    "relu": jax.nn.relu,
    "elu": jax.nn.elu,
    "gelu": jax.nn.gelu,
    "silu": jax.nn.silu,
    "sigmoid": jax.nn.sigmoid,
    "tanh": jnp.tanh,
}


# FLAX

def gen_rec_flax_mlps(config):
    """Build the Flax generative (decoder) and recognition (encoder) MLPs.

    Args:
        config: Config object read with ``config[key]``. Keys used:
            ``'data_shape'`` (an int, or an iterable of ints that is flattened
            by taking its product), ``'latent_dim'``, ``'act'`` (a key of the
            module-level ``activations`` table), ``'gen_hidden'`` and
            ``'rec_hidden'``.

    Returns:
        ``(gen_model, rec_model)``. The ``FlaxMLPDecoder`` is built with the
        flattened data size as ``out_dim``. The ``FlaxMLPEncoder`` is built
        with ``latent_dim``. No parameters are created here.

    Raises:
        KeyError: if ``config['act']`` is not a key of ``activations`` (the
            message lists the supported names).
    """
    data_shape = config['data_shape']
    # Flatten a multi-dimensional shape such as [28, 28] to one size; a scalar
    # shape is used unchanged.
    out_dim = prod(data_shape) if isinstance(data_shape, Iterable) else data_shape

    latent_dim = config['latent_dim']
    act_name = config['act']
    try:
        activation = activations[act_name]
    except KeyError:
        # Re-raise with the valid choices instead of just echoing the bad key.
        raise KeyError(
            f"unsupported activation {act_name!r}; "
            f"expected one of {sorted(activations)}"
        ) from None

    gen_hidden = config['gen_hidden']
    rec_hidden = config['rec_hidden']

    gen_model = FlaxMLPDecoder(out_dim, gen_hidden, activation)
    rec_model = FlaxMLPEncoder(latent_dim, rec_hidden, activation)
    return gen_model, rec_model

# Smoke test: build the Flax pair from `config` and run init/apply on a dummy
# batch of one. The apply results are not assigned; this only checks that the
# calls complete without raising.
gen_model, rec_model = gen_rec_flax_mlps(config)

# Size the dummy input from the config rather than hard-coding 784, so the
# check stays valid when `data_shape` is overridden.
data_shape = config['data_shape']
data_dim = prod(data_shape) if isinstance(data_shape, Iterable) else data_shape

x = jnp.ones((1, data_dim))
rec_params = rec_model.init(random.key(1), x)
rec_model.apply(rec_params, x)

z = jnp.ones((1, config['latent_dim']))
gen_params = gen_model.init(random.key(1), z)
gen_model.apply(gen_params, z)


# EQUINOX

def gen_rec_eqx_mlps(config):
    """Build the Equinox generative (decoder) and recognition (encoder) MLPs.

    Unlike the Flax builder, both Equinox constructors are given their input
    and output sizes explicitly.

    Args:
        config: Config object read with ``config[key]``. Keys used:
            ``'data_shape'`` (an int, or an iterable of ints that is flattened
            by taking its product), ``'latent_dim'``, ``'act'`` (a key of the
            module-level ``activations`` table), ``'gen_hidden'`` and
            ``'rec_hidden'``.

    Returns:
        ``(gen_model, rec_model)``:
        ``EqxMLPDecoder(out_dim, latent_dim, gen_hidden, activation)`` and
        ``EqxMLPEncoder(in_dim, latent_dim, rec_hidden, activation)``, where
        ``in_dim`` and ``out_dim`` are both the flattened data size.

    Raises:
        KeyError: if ``config['act']`` is not a key of ``activations`` (the
            message lists the supported names).
    """
    data_shape = config['data_shape']
    # Encoder input and decoder output share the flattened data size.
    data_dim = prod(data_shape) if isinstance(data_shape, Iterable) else data_shape
    in_dim = out_dim = data_dim

    latent_dim = config['latent_dim']
    act_name = config['act']
    try:
        activation = activations[act_name]
    except KeyError:
        # Re-raise with the valid choices instead of just echoing the bad key.
        raise KeyError(
            f"unsupported activation {act_name!r}; "
            f"expected one of {sorted(activations)}"
        ) from None

    gen_hidden = config['gen_hidden']
    rec_hidden = config['rec_hidden']

    gen_model = EqxMLPDecoder(out_dim, latent_dim, gen_hidden, activation)
    rec_model = EqxMLPEncoder(in_dim, latent_dim, rec_hidden, activation)
    return gen_model, rec_model


# Smoke test for the Equinox pair: call each model once with a fresh key and
# the dummy inputs `x` / `z` defined above in this cell (encoder first, then
# decoder). Outputs are discarded; reaching the print means no call raised.
gen_model, rec_model = gen_rec_eqx_mlps(config)
for model, dummy_input in ((rec_model, x), (gen_model, z)):
    model(random.key(1), dummy_input)
print('passed')


passed
