In [23]:
import jax
import flax.linen as nn
from jax import random
import jax.numpy as jnp
import functools

In [24]:
# Build a Dense layer with a single output feature and initialise its
# parameters from a fixed PRNG key so the values are reproducible.
dense1 = nn.Dense(features=1)
# The dummy input only supplies a shape for `init`. `(1)` is the int 1, not a
# tuple, so the shape is written as an explicit 1-tuple; the resulting array
# is the same one-element vector.
dense1_params = dense1.init(random.PRNGKey(0), jnp.ones((1,)))

In [25]:
# Display the freshly initialised parameter tree (nested dict of arrays).
dense1_params

FrozenDict({
    params: {
        kernel: DeviceArray([[-1.3452858]], dtype=float32),
        bias: DeviceArray([0.], dtype=float32),
    },
})

In [33]:
# Index through the nested mapping to look at the kernel leaf on its own.
dense1_params["params"]["kernel"]

DeviceArray([[-1.3452858]], dtype=float32)

In [26]:
@functools.partial(jax.vmap, in_axes=(None, 0), out_axes=0)
def params_identity(params, batch):
    """Return ``params`` unchanged from inside a ``jax.vmap``.

    ``params`` is not mapped (``in_axes`` entry ``None``) while ``batch`` is
    mapped over its leading axis. Because the output is stacked along axis 0,
    every leaf of ``params`` comes back with a new leading axis whose length
    equals the leading dimension of ``batch``.

    Args:
        params: Pytree of arrays, shared across the mapped axis.
        batch: Array (or pytree of arrays); only its leading axis is used.

    Returns:
        ``params`` with each leaf repeated along a new leading axis.
    """
    # The per-example slice is irrelevant; `batch` only fixes the mapped size.
    del batch
    return params

In [27]:
# Dummy batch of four single-element integer rows; only its leading
# dimension (4) matters to the vmapped function that consumes it.
batch = jnp.array([[0] for _ in range(4)])
batch

DeviceArray([[0],
             [0],
             [0],
             [0]], dtype=int32)

In [37]:
# Run the vmapped identity: params are unmapped, `batch` is mapped over its
# leading axis, so every parameter leaf is returned stacked along a new
# leading axis of the same length as `batch`.
batched_params = params_identity(dense1_params, batch)
# Look at the kernel leaf to see the added leading axis.
batched_params["params"]["kernel"]

DeviceArray([[[-1.3452858]],

             [[-1.3452858]],

             [[-1.3452858]],

             [[-1.3452858]]], dtype=float32)

In [51]:
# Display the whole batched tree; both bias and kernel carry the new axis.
batched_params

FrozenDict({
    params: {
        bias: DeviceArray([[0.],
                     [0.],
                     [0.],
                     [0.]], dtype=float32),
        kernel: DeviceArray([[[-1.3452858]],
        
                     [[-1.3452858]],
        
                     [[-1.3452858]],
        
                     [[-1.3452858]]], dtype=float32),
    },
})

In [43]:
def print_params_leaves(batched_params):
    """Print the Python type of every leaf in a parameter pytree.

    Args:
        batched_params: Any pytree (e.g. a nested dict of arrays).

    Returns:
        A pytree with the same structure as ``batched_params`` in which every
        leaf has been replaced by ``None``, because ``print`` returns ``None``.
    """
    # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(
        lambda leaf: print(f'leaf: {type(leaf)}'),
        batched_params,
    )

In [44]:
# Printing happens as a side effect, once per leaf; keep the returned tree so
# it can be inspected separately.
result = print_params_leaves(batched_params)

leaf: <class 'jaxlib.xla_extension.DeviceArray'>
leaf: <class 'jaxlib.xla_extension.DeviceArray'>


In [50]:
# NOTE: only the last expression of a cell is echoed, so `type(result)` is
# evaluated but its value is never shown; just `result` is displayed.
type(result)
# Each leaf is None because the mapped function was `print`.
result

FrozenDict({
    params: {
        bias: None,
        kernel: None,
    },
})

In [45]:
# Sanity check: `.mean()` with no axis argument averages over all elements
# and yields a single scalar value.
array = jnp.array([1,2,3])
array.mean()


DeviceArray(2., dtype=float32)

In [47]:
def params_mean(batched_params, axis=None):
    """Replace every leaf of a parameter pytree with its mean.

    Args:
        batched_params: Pytree whose leaves are arrays exposing ``.mean``.
        axis: Axis or axes to reduce over, forwarded to ``leaf.mean``. The
            default ``None`` reduces over all elements, so each leaf becomes
            a scalar. Pass ``axis=0`` to average over the leading axis only
            and keep the remaining dimensions of each leaf.

    Returns:
        A pytree with the same structure as ``batched_params`` whose leaves
        are the per-leaf means.
    """
    # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
    # `jax.tree_util.tree_map` is the stable spelling.
    return jax.tree_util.tree_map(
        lambda leaf: leaf.mean(axis=axis),
        batched_params,
    )

In [48]:
# Collapse each leaf of the batched tree to the mean over all of its elements.
params_mean(batched_params)

FrozenDict({
    params: {
        bias: DeviceArray(0., dtype=float32),
        kernel: DeviceArray(-1.3452858, dtype=float32),
    },
})