In [2]:
# ruff: noqa
from flax import linen as nn  # Linen API
import flax
import jax
import jax.numpy as jnp
from jax import random

# We create one dense layer(model) instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

In [17]:
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))  # Dummy input data
params = model.init(key2, x)  # Initialization call
jax.tree_util.tree_map(lambda x: x.shape, params)  # Checking output shapes
# 10 rows (training examples) and 5 columns (features)

{'params': {'bias': (5,), 'kernel': (10, 5)}}

To conduct a forward pass with the model with a given set of parameters

In [18]:
model.apply(params, x)

Array([-1.471482  , -0.16962777,  0.12994334,  0.5224348 , -1.0919424 ],      dtype=float32)

In [19]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = (
    jnp.dot(x_samples, W)
    + b
    + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [20]:
# Same as JAX version but using model.apply().
@jax.jit
def mse(params, x_batched, y_batched):
    # Define the squared loss for a single pair (x,y)
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2.0

    # Vectorize the previous to compute the average of the loss on all samples.
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

In [21]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)


@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return params


for i in range(101):
    # Perform one gradient update.
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)
    if i % 10 == 0:
        print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.023887426
Loss step 0:  29.167133
Loss step 10:  0.6585776
Loss step 20:  0.2847663
Loss step 30:  0.15943417
Loss step 40:  0.097042
Loss step 50:  0.06356495
Loss step 60:  0.04459983
Loss step 70:  0.03334882
Loss step 80:  0.026425159
Loss step 90:  0.022047115
Loss step 100:  0.019225018


### Using Optax:

In [22]:
import optax

tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

In [23]:
for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print('Loss step {}: '.format(i), loss_val)

Loss step 0:  0.019003797
Loss step 10:  0.20683758
Loss step 20:  0.05855435
Loss step 30:  0.035229336
Loss step 40:  0.02006433
Loss step 50:  0.016343607
Loss step 60:  0.014753598
Loss step 70:  0.014125167
Loss step 80:  0.013897793
Loss step 90:  0.013803817
Loss step 100:  0.013771151


If you want to save your optimized parameters

In [24]:
from flax import serialization

bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)
print('Dict output')
print(dict_output)
print('Bytes output')
print(bytes_output)

Dict output
{'params': {'bias': Array([-2.4378996 , -2.07345   ,  0.18866651, -0.34174547, -0.77294886],      dtype=float32), 'kernel': Array([[ 0.99648714, -0.84069055, -0.7420563 , -1.1751673 , -0.88212633],
       [ 0.5849409 ,  0.6981158 , -1.0097519 ,  1.6940992 , -1.9293189 ],
       [-1.2534724 ,  0.19594063, -1.1317803 ,  0.23212159,  1.7384187 ],
       [ 0.54630667,  0.62694055, -0.4886673 ,  0.9249355 ,  0.3472473 ],
       [-0.6366119 , -0.96825767,  0.77167565,  1.0612171 ,  0.9169658 ],
       [-0.46407035,  0.8665426 ,  1.6763489 , -2.526709  ,  0.4433026 ],
       [ 1.8060067 , -1.2587421 , -0.57777286,  2.2216108 ,  0.76148087],
       [-0.21333523, -1.6236352 , -0.8071298 , -2.3925495 ,  1.5795202 ],
       [ 1.4050099 ,  0.34845546,  0.02730806,  1.1249974 , -0.21001515],
       [-0.30623946,  0.7411661 , -0.3360593 ,  0.46161816,  0.19004533]],      dtype=float32)}}
Bytes output
b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14\x8c\x06\x1c\xc0h\xb

Retrieving optimized parameters

In [25]:
serialization.from_bytes(params, bytes_output)

{'params': {'bias': array([-2.4378996 , -2.07345   ,  0.18866651, -0.34174547, -0.77294886],
        dtype=float32),
  'kernel': array([[ 0.99648714, -0.84069055, -0.7420563 , -1.1751673 , -0.88212633],
         [ 0.5849409 ,  0.6981158 , -1.0097519 ,  1.6940992 , -1.9293189 ],
         [-1.2534724 ,  0.19594063, -1.1317803 ,  0.23212159,  1.7384187 ],
         [ 0.54630667,  0.62694055, -0.4886673 ,  0.9249355 ,  0.3472473 ],
         [-0.6366119 , -0.96825767,  0.77167565,  1.0612171 ,  0.9169658 ],
         [-0.46407035,  0.8665426 ,  1.6763489 , -2.526709  ,  0.4433026 ],
         [ 1.8060067 , -1.2587421 , -0.57777286,  2.2216108 ,  0.76148087],
         [-0.21333523, -1.6236352 , -0.8071298 , -2.3925495 ,  1.5795202 ],
         [ 1.4050099 ,  0.34845546,  0.02730806,  1.1249974 , -0.21001515],
         [-0.30623946,  0.7411661 , -0.3360593 ,  0.46161816,  0.19004533]],
        dtype=float32)}}

In [39]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax  # for optimizers
from flax.training import train_state
import numpy as np

In [None]:
class GRUModel(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        # x: [batch_size, time_steps, input_size]
        batch_size, time_steps, _ = x.shape

        # Specify the number of features for the GRU cell
        gru_cell = nn.GRUCell(features=self.hidden_size)

        # Initial hidden state
        h = jnp.zeros((batch_size, self.hidden_size))

        outputs = []

        for t in range(time_steps):
            h, _ = gru_cell(h, x[:, t])  # Apply GRU cell at each time step
            outputs.append(h)

        # Stack outputs: [time_steps, batch_size, hidden_size] → [batch_size, time_steps, hidden_size]
        outputs = jnp.stack(outputs, axis=1)

        # Final projection (e.g., for classification or regression)
        logits = nn.Dense(self.output_size)(outputs)

        return logits

In [47]:
class TrainState(train_state.TrainState):
    pass

In [48]:
def cross_entropy_loss(logits, labels):
    one_hot = jax.nn.one_hot(labels, logits.shape[-1])
    return optax.softmax_cross_entropy(logits, one_hot).mean()


@jax.jit
def train_step(state, batch):
    x, y = batch

    def loss_fn(params):
        logits = GRUModel(hidden_size=64, output_size=10).apply(
            {'params': params}, x
        )
        loss = cross_entropy_loss(logits[:, -1], y)  # last time step
        return loss

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

In [49]:
# Dummy data
x = np.random.randn(32, 10, 16).astype(
    jnp.float32
)  # [batch, time, input_size]
y = np.random.randint(0, 10, (32,))

model = GRUModel(hidden_size=64, output_size=10)
params = model.init(jax.random.PRNGKey(0), x)['params']

state = TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
)

# Train step
state = train_step(state, (x, y))

# Part 2 Creating Custom Models

In [4]:
from typing import Sequence
class Brolos(nn.Module):
    num_neurons_per_layer: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(n) for n in self.num_neurons_per_layer]
    
    def __call__(self, x):
        activation = x
        for i, layer in enumerate(self.layers):
            activation = layer(activation)
            if i < len(self.layers) - 1:
                activation = nn.relu(activation)
        return activation
x_key, init_key = random.split(random.key(0))

model = Brolos(num_neurons_per_layer=[16, 8, 1])
x = random.uniform(x_key, (4, 4))
params = model.init(init_key, x)
y = model.apply(params, x)
print(jax.tree_util.tree_map(lambda x: x.shape, params))
print(f"Output: {y}")

{'params': {'layers_0': {'bias': (16,), 'kernel': (4, 16)}, 'layers_1': {'bias': (8,), 'kernel': (16, 8)}, 'layers_2': {'bias': (1,), 'kernel': (8, 1)}}}
Output: [[-0.12992406]
 [-0.02715186]
 [-0.10914332]
 [-0.13801154]]


In [6]:
# instead of using setup, we can use the @nn.compact decorator
class BrolosCompact(nn.Module):
    num_neurons_per_layer: Sequence[int]

    @nn.compact
    def __call__(self, x):
        activation = x
        for i, n in enumerate(self.num_neurons_per_layer):
            activation = nn.Dense(n)(activation)
            if i < len(self.num_neurons_per_layer) - 1:
                activation = nn.relu()
        return activation
    
model = Brolos(num_neurons_per_layer=[16, 8, 1])
x = random.uniform(x_key, (4, 4))
params = model.init(init_key, x)
y = model.apply(params, x)
print(jax.tree_util.tree_map(lambda x: x.shape, params))
print(f"Output: {y}")

{'params': {'layers_0': {'bias': (16,), 'kernel': (4, 16)}, 'layers_1': {'bias': (8,), 'kernel': (16, 8)}, 'layers_2': {'bias': (1,), 'kernel': (8, 1)}}}
Output: [[-0.12992406]
 [-0.02715186]
 [-0.10914332]
 [-0.13801154]]


In [None]:
from typing import Callable
from pprint import pprint

class MyDenseImp(nn.Module):
    num_neurons: int
    weight_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros

    @nn.compact
    def __call__(self, x):
        # name that appears in the dict
        weight = self.param('weight', self.weight_init, # initialization function
                             (x.shape[-1], self.num_neurons))
        bias = self.param('bias', self.bias_init, (self.num_neurons,))
        return jnp.dot(x, weight) + bias
    
x_key, init_key = random.split(random.key(0))

model = MyDenseImp(num_neurons=3)
x = random.uniform(x_key, (4, 4))
params = model.init(init_key, x)
y = model.apply(params, x)
print(jax.tree_util.tree_map(lambda x: x.shape, params))
pprint(f"Output: {y}")

{'params': {'bias': (3,), 'weight': (4, 3)}}


AttributeError: 'function' object has no attribute 'pprint'