In [4]:
import jax
import jax.numpy as jnp
from flax import nnx
import optax

In [5]:
class Model(nnx.Module):
  """Two-layer MLP: Linear(din -> hidden) -> ReLU -> Linear(hidden -> dout).

  Args:
    din: Number of input features (size of the last axis of the input).
    dout: Number of output features.
    rngs: ``nnx.Rngs`` used to initialise the parameters of both
      ``nnx.Linear`` layers.
    hidden_features: Width of the hidden layer (keyword-only). Defaults to
      64, the width that was previously hard-coded, so existing callers
      get the same architecture.
  """

  def __init__(self, din, dout, rngs: nnx.Rngs, *, hidden_features: int = 64):
    self.linear1 = nnx.Linear(din, hidden_features, rngs=rngs)
    self.linear2 = nnx.Linear(hidden_features, dout, rngs=rngs)

  def __call__(self, x):
    """Apply linear1 -> ReLU -> linear2 to ``x`` and return the result.

    ``nnx.Linear`` contracts only the last axis, so ``x`` may carry leading
    batch axes: shape ``(..., din)`` maps to ``(..., dout)``.
    """
    x = self.linear1(x)
    x = nnx.relu(x)
    return self.linear2(x)

In [6]:
observationSize = 168  # input width: last axis of each observation / model `din`
stackSize = 4          # number of observations stacked along the first axis
denseSize = 16         # output width produced by `model` (its `dout`)

# Fixed seed, split into independent subkeys for params and data, so the
# printed output below is reproducible across Restart & Run All.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
rngs = nnx.Rngs(subkey)
model = Model(observationSize, denseSize, rngs)

key, subkey = jax.random.split(key)
observationStack = jax.random.normal(subkey, (stackSize, observationSize))

# nnx.Linear already broadcasts over leading axes, so the whole stack is fed
# in one call. Wrapping the module in plain jax.vmap was redundant, and it
# could leak tracers if the model ever updated NNX state inside __call__
# (use nnx.vmap for that case).
denseObservationStack = model(observationStack)
assert denseObservationStack.shape == (stackSize, denseSize), denseObservationStack.shape

print(denseObservationStack)








[[-0.6652626  -0.53902525 -0.09297132  0.49937177 -0.6603005   1.0518895
  -0.13109544 -0.6628766  -1.1291803  -0.11417593  1.3707572   1.0709394
   0.6735817   1.64906     0.34760642 -1.6613114 ]
 [ 0.99269426 -0.7585932   1.7138944   0.24401504  0.07184923 -0.1491603
   0.33781812  0.30689996  0.08213931 -0.40384504  0.03291181  0.40766206
  -0.5710783  -0.12206234  0.00263701  0.3509707 ]
 [ 0.00677812 -1.365413    0.56654704 -0.5230862   0.09809178  0.488294
  -0.0625405  -0.6276747   0.7260537   1.1030865  -0.08457478  0.572353
   0.46608803  0.6361121   0.9714385  -0.8715981 ]
 [-1.7621522  -0.19059354  0.29289463 -0.9432181  -0.91601753  0.7529558
   0.3522218   1.2541811  -0.38666964  0.7492006   0.7282804   0.17798842
  -0.48238662 -1.0095683   0.72713643 -0.8324332 ]]
