<a href="https://colab.research.google.com/github/afairley/ColaboratoryNotebooks/blob/main/ManagingParametersAndState.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#From https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/state_params.html
# and also https://github.com/google/flax/blob/main/docs/guides/flax_fundamentals/state_params.rst?plain=1#L59
import flax
from flax import linen as nn
from jax import random
import jax.numpy as jnp
import jax
import optax

# Smoke-test configuration: a single pass over one constant, all-ones batch
# of 3 examples with 4 features each.
num_epochs = 1
dummy_input = jnp.ones(shape=(3, 4))

class BiasAdderWithRunningMean(nn.Module):
  """Toy module combining a trainable bias with a running mean of its inputs.

  The module keeps two pieces of data, both shaped like one example
  (`x.shape[1:]`):

  * `bias` in the 'params' collection, initialised to zeros;
  * `mean` in the 'batch_stats' collection, initialised to zeros and updated
    as an exponential moving average of the per-batch mean of `x`.

  The output is `mean + bias`; `x` contributes only through the running mean.
  """
  # Weight given to the previous running mean on each update; the current
  # batch mean gets weight (1 - momentum).
  momentum: float = 0.9

  @nn.compact
  def __call__(self, x):
    """Update the running mean (when it already exists) and return mean + bias.

    Args:
      x: input batch; axis 0 is reduced over when computing the batch mean.

    Returns:
      An array of shape `x.shape[1:]` holding the running mean plus the bias.
    """
    # Query before `self.variable` declares the entry below, so the flag
    # reflects whether 'mean' was supplied by the caller rather than being
    # created by this very call.
    is_initialized = self.has_variable('batch_stats', 'mean')
    mean = self.variable('batch_stats', 'mean', jnp.zeros, x.shape[1:])
    bias = self.param('bias', lambda rng, shape: jnp.zeros(shape),
                      x.shape[1:])
    if is_initialized:
      # Reduce without keepdims so the batch mean has shape x.shape[1:], the
      # same shape the variable was initialised with. With keepdims=True the
      # first update would broadcast the stored state to (1, *x.shape[1:]).
      batch_mean = jnp.mean(x, axis=0)
      mean.value = (self.momentum * mean.value
                    + (1.0 - self.momentum) * batch_mean)
    return mean.value + bias
def update_step(apply_fn, x, opt_state, params, state):
    """Run one gradient step and return the refreshed training state.

    The loss is the summed squared difference between `x` and the model
    output. The optimizer is read from the module-level name `tx`, which must
    be bound before this function is called.

    Args:
      apply_fn: callable invoked as `apply_fn(variables, x, mutable=...)` and
        returning `(output, updated_state)`.
      x: input batch, also used as the regression target.
      opt_state: current optimizer state for `tx`.
      params: current trainable parameters (the 'params' collection).
      state: mapping of the remaining variable collections; every collection
        in it is marked mutable for the forward pass.

    Returns:
      A tuple `(opt_state, params, updated_state)` after one update.
    """
    def loss(params):
        # Re-assemble the full variable dict and let every non-param
        # collection be mutated so its new value is returned alongside y.
        y, updated_state = apply_fn(
            {'params': params, **state}, x, mutable=list(state))
        loss_value = ((x - y) ** 2).sum()
        return loss_value, updated_state

    # has_aux=True: only the first element of loss()'s result is
    # differentiated; the updated state is carried through as auxiliary data.
    (_, updated_state), grads = jax.value_and_grad(
        loss, has_aux=True)(params)
    # Pass params as well: transformations that read the current parameters
    # (e.g. weight decay) need them, and plain SGD simply ignores them.
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return opt_state, params, updated_state

# Instantiate the module, initialise its variables from the dummy batch, and
# split them right away: 'params' goes to the optimizer, every other
# collection is carried along as `state`.
model = BiasAdderWithRunningMean()
state, params = flax.core.pop(
    model.init(random.key(0), dummy_input), 'params')

# Plain SGD over the trainable parameters only.
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

# Each iteration performs a single update on the same dummy batch.
for _ in range(num_epochs):
  opt_state, params, state = update_step(
      apply_fn=model.apply,
      x=dummy_input,
      opt_state=opt_state,
      params=params,
      state=state,
  )
