# Training a simple MLP with Flax

In [2]:
# install flax if haven't done so
!pip install -q flax

In [9]:
# a toy example using flax
import numpy as np

import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import optax # optimization library for JAX from deepmind

In [21]:
# Create a single fully connected (dense) layer and initialize its parameters.
# Named constants replace magic numbers; these values reproduce the outputs below.
SEED = 0          # PRNG seed for reproducible input data and initialization
IN_FEATURES = 10  # length of the toy input vector
OUT_FEATURES = 5  # number of output units of the dense layer

model = nn.Dense(features=OUT_FEATURES)

# Pseudo-random number generation: JAX uses explicit PRNG keys, so split the
# root key into two sub-keys, one for the input data and one for parameter init.
key = jax.random.PRNGKey(SEED)
key1, key2 = jax.random.split(key)
print(key, key1, key2, '\n')

# Initialize a random input; `init` uses it to determine the kernel's input
# dimension (see the (IN_FEATURES, OUT_FEATURES) kernel shape in the output).
x = jax.random.normal(key1, (IN_FEATURES,))
params = model.init(key2, x)
print("x: {}\n\nparams: {}\n".format(x, params.keys()))

# Display the shape of every parameter leaf. The lambda argument is named
# `leaf` so it does not shadow the input `x` defined above.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

[0 0] [4146024105  967050713] [2718843009 1272950319] 

x: [-2.6105583   0.03385283  1.0863333  -1.4802988   0.48895672  1.062516
  0.54174834  0.0170228   0.2722685   0.30522448]

params: frozen_dict_keys(['params'])



FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [24]:
# Forward pass: apply the dense layer to `x` using the parameters from `init`.
y = model.apply(params, x)

# Sanity check: Dense replaces the last input axis with `features` outputs.
assert y.shape == x.shape[:-1] + (model.features,), y.shape

# Bare last expression -> Jupyter rich display instead of print().
y