In [17]:
"""
A basic Flax (linen) example: define a multilayer perceptron (MLP), then stack two of them into a larger model.
"""
from flax import linen as nn
import optax
import jax
import jax.numpy as jnp

class MLP(nn.Module):        # Create a Flax Module dataclass
    """Two-layer perceptron: flatten -> Dense(hidden_dims) -> ReLU -> Dense(out_dims).

    Attributes:
        out_dims: number of output features of the final Dense layer.
        hidden_dims: width of the hidden Dense layer. Defaults to 128, the
            width that was previously hard-coded, so existing callers are
            unaffected.
    """
    out_dims: int
    hidden_dims: int = 128

    @nn.compact
    def __call__(self, x):
        """Apply the MLP to ``x``.

        Args:
            x: array whose leading axis is treated as the batch axis; all
                remaining axes are flattened into a single feature axis.

        Returns:
            Array of shape ``(x.shape[0], out_dims)``. No activation is applied
            to the final layer.
        """
        # Flatten everything after the batch axis so Dense sees (batch, features).
        x = x.reshape((x.shape[0], -1))
        # Submodules are declared inline (@nn.compact); Dense infers its input
        # feature size from `x` at init time.
        x = nn.Dense(self.hidden_dims)(x)
        x = nn.relu(x)
        x = nn.Dense(self.out_dims)(x)
        return x
    
model = MLP(out_dims=10)                 # instantiate the MLP model (10 output features)
# Dummy input: a batch of 4 all-ones arrays of shape (28, 28, 1).
# This is deterministic placeholder data, NOT random data.
x = jnp.ones((4, 28, 28, 1))
# Initialize the weights from a fixed PRNG key so re-running gives the same parameters.
variables = model.init(jax.random.PRNGKey(42), x)
y = model.apply(variables, x)            # make a forward pass
# Sanity check: batch axis preserved, feature axis equals out_dims.
assert y.shape == (4, 10), y.shape

In [20]:
# Show the MLP output `y` (unnormalized: the last layer is a plain Dense, no activation).
# All rows are identical because every row of the input batch was jnp.ones and the
# network processes each batch row independently.
print(y)

[[-0.05794942 -0.15916863 -0.9622741  -0.57689476 -0.21577388 -0.34986165
   0.24366494  1.4783851  -0.4489202   0.03279138]
 [-0.05794942 -0.15916863 -0.9622741  -0.57689476 -0.21577388 -0.34986165
   0.24366494  1.4783851  -0.4489202   0.03279138]
 [-0.05794942 -0.15916863 -0.9622741  -0.57689476 -0.21577388 -0.34986165
   0.24366494  1.4783851  -0.4489202   0.03279138]
 [-0.05794942 -0.15916863 -0.9622741  -0.57689476 -0.21577388 -0.34986165
   0.24366494  1.4783851  -0.4489202   0.03279138]]


In [27]:
class TwoMLP(nn.Module):
    """Two ``MLP`` blocks applied one after the other.

    The first ``MLP`` maps the flattened input to 20 features; the second maps
    those 20 features to ``out_dims`` outputs.

    Attributes:
        out_dims: number of output features produced by the second ``MLP``.
    """
    out_dims: int

    @nn.compact
    def __call__(self, x):
        """Return an array of shape ``(x.shape[0], out_dims)``."""
        batch_size = x.shape[0]
        flat = x.reshape((batch_size, -1))
        # Construct-and-call in this order so the auto-assigned submodule
        # names stay MLP_0 (first) and MLP_1 (second).
        hidden = MLP(out_dims=20)(flat)
        out = MLP(out_dims=self.out_dims)(hidden)
        return out

In [28]:
model = TwoMLP(out_dims=6)
# Dummy input: a batch of 4 all-ones arrays of shape (28, 28, 1).
# This is deterministic placeholder data, NOT random data.
x = jnp.ones((4, 28, 28, 1))
# Fixed PRNG key so re-running gives the same parameters.
variables = model.init(jax.random.PRNGKey(42), x)
y = model.apply(variables, x)            # forward pass through both MLPs
assert y.shape == (4, 6), y.shape        # (batch, out_dims)
y

Array([[-0.00130665, -0.19587332,  0.11442596, -0.10978554, -0.08749157,
         0.07744953],
       [-0.00130665, -0.19587332,  0.11442596, -0.10978554, -0.08749157,
         0.07744953],
       [-0.00130665, -0.19587332,  0.11442596, -0.10978554, -0.08749157,
         0.07744953],
       [-0.00130665, -0.19587332,  0.11442596, -0.10978554, -0.08749157,
         0.07744953]], dtype=float32)