In [3]:
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn

In [2]:
model = nn.Dense(features=5)

print(model)

Dense(
    # attributes
    features = 5
    use_bias = True
    dtype = None
    param_dtype = float32
    precision = None
    kernel_init = init
    bias_init = zeros
    dot_general = None
    dot_general_cls = None
)


In [5]:
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,))  # dummy input
params = model.init(key2, x)

jax.tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (5,), 'kernel': (10, 5)}}

In [7]:
result = model.apply(params, x)

print(type(result))
result

<class 'jaxlib.xla_extension.ArrayImpl'>


Array([-1.3721199 ,  0.611315  ,  0.64428365,  2.2192967 , -1.1271119 ],      dtype=float32)

In [8]:
from icecream import ic

# set problem dimensions
n_samples = 20
x_dim = 10
y_dim = 5

key = random.key(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))

true_params = flax.core.freeze({"params": {"bias": b, "kernel": W}})


# Generate Samples with noise

key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = (
    jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise, (n_samples, y_dim))
)

ic(x_samples.shape)
ic(y_samples.shape)

ic| x_samples.shape: (20, 10)
ic| y_samples.shape: (20, 5)


(20, 5)

In [9]:
@jax.jit
def mse(params, x_batched, y_batched):
    # define the squared error loss
    def squared_error(x, y):
        pred = model.apply(params, x)
        return jnp.inner(y - pred, y - pred) / 2

    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

In [10]:
learning_rate = 0.3
mse(true_params, x_samples, y_samples)

Array(0.0236398, dtype=float32)

In [11]:
loss_grad_fn = jax.value_and_grad(mse)


@jax.jit
def update_params(params, learning_rate, grads):
    params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return params


for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    params = update_params(params, learning_rate, grads)  # type: ignore
    if i % 10 == 0:
        print(f"Step {i+1}, Loss {loss_val}")

Step 1, Loss 35.343875885009766
Step 11, Loss 0.5143467783927917
Step 21, Loss 0.11384160816669464
Step 31, Loss 0.039326731115579605
Step 41, Loss 0.01991620473563671
Step 51, Loss 0.014209133572876453
Step 61, Loss 0.012425646185874939
Step 71, Loss 0.01185037661343813
Step 81, Loss 0.011661776341497898
Step 91, Loss 0.011599412187933922
Step 101, Loss 0.01157870702445507


In [12]:
import optax

tx = optax.adam(learning_rate=learning_rate)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)

for i in range(101):
    loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    if i % 10 == 0:
        print(f"Step {i+1}, Loss {loss_val}")

Step 1, Loss 0.011577614583075047
Step 11, Loss 0.26143109798431396
Step 21, Loss 0.07674869149923325
Step 31, Loss 0.03644001856446266
Step 41, Loss 0.022012466564774513
Step 51, Loss 0.016178492456674576
Step 61, Loss 0.013002862222492695
Step 71, Loss 0.01202611718326807
Step 81, Loss 0.011764513328671455
Step 91, Loss 0.011646024882793427
Step 101, Loss 0.011585516855120659


In [13]:
from flax import serialization

bytes_output = serialization.to_bytes(params)
dict_output = serialization.to_state_dict(params)

ic(dict_output)
ic(bytes_output)

ic| dict_output: {'params': {'bias': Array([-1.4555763 , -2.027799  ,  2.0790977 ,  1.2186146 , -0.99809825],      dtype=float32),
                             'kernel': Array([[ 1.0098811 ,  0.18934365,  0.04455001, -0.92802244,  0.34784022],
                        [ 1.7298453 ,  0.987937  ,  1.1640465 ,  1.1006079 , -0.10653906],
                        [-1.2029463 ,  0.28635174,  1.415598  ,  0.11870932, -1.3141485 ],
                        [-1.1941485 , -0.189585  ,  0.03413848,  1.3169427 ,  0.08060391],
                        [ 0.13852438,  1.3713042 , -1.3187189 ,  0.53152686, -2.2404995 ],
                        [ 0.56294   ,  0.81223136,  0.31752002,  0.5345511 ,  0.90500396],
                        [-0.37926036,  1.7410394 ,  1.079029  , -0.5039834 ,  0.9283064 ],
                        [ 0.97064894, -1.3153405 ,  0.33681473,  0.8099342 , -1.2018456 ],
                        [ 1.019431  , -0.6202477 ,  1.0818828 , -1.8389741 , -0.45804927],
                        [-0.

b'\x81\xa6params\x82\xa4bias\xc7!\x01\x93\x91\x05\xa7float32\xc4\x14SP\xba\xbfu\xc7\x01\xc0\xf0\x0f\x05@\x90\xfb\x9b?^\x83\x7f\xbf\xa6kernel\xc7\xd6\x01\x93\x92\n\x05\xa7float32\xc4\xc8\xc9C\x81?M\xe3A>\x11z6=\xe1\x92m\xbf\x1d\x18\xb2>\x92k\xdd?p\xe9|?z\xff\x94?\xb8\xe0\x8c?&1\xda\xbd%\xfa\x99\xbf\xb2\x9c\x92>Q2\xb5?\xdf\x1d\xf3=\x056\xa8\xbf\xdc\xd9\x98\xbf\x92"B\xbe\xcb\xd4\x0b=\x94\x91\xa8?\xaa\x13\xa5=V\xd9\r>\xe5\x86\xaf?\xc8\xcb\xa8\xbf%\x12\x08?Xd\x0f\xc0\xd6\x1c\x10?e\xeeO?\xfc\x91\xa2>W\xd8\x08?W\xaeg?j.\xc2\xbea\xda\xde?\x9f\x1d\x8a?\x0e\x05\x01\xbf}\xa5m?s|x?\x14]\xa8\xbf\xfbr\xac>\xd9WO?\x14\xd6\x99\xbf\xb7|\x82?\x8e\xc8\x1e\xbf#{\x8a?\x81c\xeb\xbfo\x85\xea\xbe{\xc6$\xbfH\xd0\xe9>P\x03\x91\xbf|u/\xbf/T,>'

In [14]:
serialization.from_bytes(params, bytes_output)

{'params': {'bias': array([-1.4555763 , -2.027799  ,  2.0790977 ,  1.2186146 , -0.99809825],
        dtype=float32),
  'kernel': array([[ 1.0098811 ,  0.18934365,  0.04455001, -0.92802244,  0.34784022],
         [ 1.7298453 ,  0.987937  ,  1.1640465 ,  1.1006079 , -0.10653906],
         [-1.2029463 ,  0.28635174,  1.415598  ,  0.11870932, -1.3141485 ],
         [-1.1941485 , -0.189585  ,  0.03413848,  1.3169427 ,  0.08060391],
         [ 0.13852438,  1.3713042 , -1.3187189 ,  0.53152686, -2.2404995 ],
         [ 0.56294   ,  0.81223136,  0.31752002,  0.5345511 ,  0.90500396],
         [-0.37926036,  1.7410394 ,  1.079029  , -0.5039834 ,  0.9283064 ],
         [ 0.97064894, -1.3153405 ,  0.33681473,  0.8099342 , -1.2018456 ],
         [ 1.019431  , -0.6202477 ,  1.0818828 , -1.8389741 , -0.45804927],
         [-0.6436536 ,  0.45666718, -1.1329136 , -0.6853864 ,  0.16828988]],
        dtype=float32)}}

In [16]:
from typing import Any, Callable, Sequence


class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        # we automatically know what to do with lists, dicts of submodules
        self.layers = [nn.Dense(feat) for feat in self.features]
        # for single submodules, we would just write:
        # self.layer1 = nn.Dense(feat1)

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x


key1, key2 = random.split(random.key(0), 2)
x = random.uniform(key1, (4, 4))

model = ExplicitMLP(features=[3, 4, 5])
params = model.init(key2, x)
y = model.apply(params, x)

print(
    "initialized parameter shapes:\n",
    jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)),
)
print("output:\n", y)

initialized parameter shapes:
 {'params': {'layers_0': {'bias': (3,), 'kernel': (4, 3)}, 'layers_1': {'bias': (4,), 'kernel': (3, 4)}, 'layers_2': {'bias': (5,), 'kernel': (4, 5)}}}
output:
 [[ 0.          0.          0.          0.          0.        ]
 [ 0.00723789 -0.00810346 -0.02550935  0.02151712 -0.01261239]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]


For demonstration purposes, we’ll implement a simplified but similar mechanism to batch normalization: we’ll store running averages and subtract those to the input at training time. For proper batchnorm, you should use (and look at) the implementation [here](https://github.com/google/flax/blob/main/flax/linen/normalization.py).

In [17]:
class BiasAdderWithRunningMean(nn.Module):
    decay: float = 0.99

    @nn.compact
    def __call__(self, x):
        # easy pattern to detect if we're initializing via empty variable tree
        is_initialized = self.has_variable("batch_stats", "mean")
        ra_mean = self.variable(
            "batch_stats", "mean", lambda s: jnp.zeros(s), x.shape[1:]
        )
        bias = self.param("bias", lambda rng, shape: jnp.zeros(shape), x.shape[1:])
        if is_initialized:
            ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(
                x, axis=0, keepdims=True
            )

        return x - ra_mean.value + bias


key1, key2 = random.split(random.key(0), 2)
x = jnp.ones((10, 5))
model = BiasAdderWithRunningMean()

variables = model.init(key1, x)
print("initialized variables:\n", variables)

y, updated_state = model.apply(variables, x, mutable=["batch_stats"])
print("updated state:\n", updated_state)

initialized variables:
 {'batch_stats': {'mean': Array([0., 0., 0., 0., 0.], dtype=float32)}, 'params': {'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}


In [18]:
for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=["batch_stats"])
    old_state, params = flax.core.pop(variables, "params")
    variables = flax.core.freeze({"params": params, **updated_state})
    print("updated state:\n", updated_state)  # Shows only the mutable part

updated state:
 {'batch_stats': {'mean': Array([[0.01, 0.01, 0.01, 0.01, 0.01]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.0299, 0.0299, 0.0299, 0.0299, 0.0299]], dtype=float32)}}
updated state:
 {'batch_stats': {'mean': Array([[0.059601, 0.059601, 0.059601, 0.059601, 0.059601]], dtype=float32)}}
