# Using `nnx.vmap` to create multiple models

With `linen`, you can initialize several instances of a module like this:
```python
def make(rng):
  """Initialize one instance from a single PRNG key (illustrative sketch)."""
  # NOTE: `my_module` and `dummy_input` are placeholders that this snippet
  # does not define; the return value is intentionally elided.
  m = my_module.init(rng, dummy_input)
  return ...

# One key per instance: split a single root key into 5 keys ...
rngs = jax.random.split(jax.random.PRNGKey(0), num=5)
# ... and map `make` over the leading axis of the keys.
models = jax.vmap(make)(rngs)
```

With `nnx`, you can use `nnx.split_rngs` to automatically split the `Rngs` before they are passed into `nnx.vmap`.

```python
# Decorators apply bottom-up: `nnx.vmap` wraps the function first, and
# `nnx.split_rngs` then splits the incoming `Rngs` 5 ways before the
# vmapped function is called.
@nnx.split_rngs(splits=5)
@nnx.vmap
def make_model(rngs):
  """Create a single `nnx.Linear(2, 3)` layer from the given `rngs`."""
  return nnx.Linear(2, 3, rngs=rngs)

# A single `Rngs(0)` is passed in; the splitting is done by the decorator.
model = make_model(nnx.Rngs(0))

print(model)
```

In [44]:
import jax
from flax import nnx
import jax.numpy as jnp

In [45]:
@nnx.split_rngs(splits=5)
@nnx.vmap
def make_model(rngs):
  """Build one `nnx.Linear` layer mapping 2 input features to 3 outputs.

  The body describes a single layer. The surrounding decorators split the
  incoming `Rngs` 5 ways and map this constructor over the split, so the
  returned module carries a leading axis of size 5 on its parameters
  (kernel `(5, 2, 3)`, bias `(5, 3)`).
  """
  layer = nnx.Linear(in_features=2, out_features=3, rngs=rngs)
  return layer


# One seed is enough here: the 5-way split happens inside the decorator.
model = make_model(nnx.Rngs(0))

print(model)

[38;2;79;201;177mLinear[0m[38;2;255;213;3m([0m[38;2;105;105;105m # Param: 45 (180 B)[0m
  [38;2;156;220;254mbias[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 15 (60 B)[0m
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0m[38;2;79;201;177mArray[0m[38;2;255;213;3m([0m[38;2;156;220;254mshape[0m[38;2;212;212;212m=[0m[38;2;255;213;3m([0m[38;2;182;207;169m5[0m, [38;2;182;207;169m3[0m[38;2;255;213;3m)[0m, [38;2;156;220;254mdtype[0m[38;2;212;212;212m=[0mdtype('float32')[38;2;255;213;3m)[0m
  [38;2;255;213;3m)[0m,
  [38;2;156;220;254mbias_init[0m[38;2;212;212;212m=[0m<function zeros at 0x7b4862d41080>,
  [38;2;156;220;254mdot_general[0m[38;2;212;212;212m=[0m<function dot_general at 0x7b486374cd60>,
  [38;2;156;220;254mdtype[0m[38;2;212;212;212m=[0m[38;2;86;156;214mNone[0m,
  [38;2;156;220;254min_features[0m[38;2;212;212;212m=[0m[38;2;182;207;169m2[0m,
  [38;2;156;220;254mkernel[0m[3

In [47]:
# Inspect the parameters of the vmapped model created in the previous cell.
# The leading axis of size 5 in the printed shapes comes from the 5-way
# `nnx.split_rngs` + `nnx.vmap` used to build it.
print("Model parameter shapes:")
print(f"Kernel shape: {model.kernel.value.shape}")
print(f"Bias shape: {model.bias.value.shape}\n")

Model parameter shapes:
Kernel shape: (5, 2, 3)
Bias shape: (5, 3)



In [56]:
# Sizes of the three mapped axes and of the underlying Linear layer.
num_devices, per_device_batch, num_heads = 1, 4, 2
din, dout = 16, 32


@nnx.split_rngs(splits=num_devices)
@nnx.pmap(in_axes=0, out_axes=0)  # outermost axis: devices
@nnx.split_rngs(splits=per_device_batch)
@nnx.vmap(in_axes=0, out_axes=0)  # middle axis: per-device batch
@nnx.split_rngs(splits=num_heads)
@nnx.vmap(in_axes=0, out_axes=1)  # innermost axis: heads, placed at output axis 1
def make_model(rngs):
  """Build one `nnx.Linear(din, dout)` layer; the decorators replicate it.

  Each `split_rngs` / map pair adds one mapped axis to every parameter. The
  innermost `vmap` uses `out_axes=1`, so the head axis is inserted at
  position 1 of each per-head parameter instead of in front. The resulting
  shapes are kernel `(num_devices, per_device_batch, din, num_heads, dout)`
  and bias `(num_devices, per_device_batch, dout, num_heads)`.
  """
  layer = nnx.Linear(in_features=din, out_features=dout, rngs=rngs)
  return layer


model = make_model(nnx.Rngs(0))
print(model)

[38;2;79;201;177mLinear[0m[38;2;255;213;3m([0m[38;2;105;105;105m # Param: 4,352 (17.4 KB)[0m
  [38;2;156;220;254mbias[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 256 (1.0 KB)[0m
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0m[38;2;79;201;177mArray[0m[38;2;255;213;3m([0m[38;2;156;220;254mshape[0m[38;2;212;212;212m=[0m[38;2;255;213;3m([0m[38;2;182;207;169m1[0m, [38;2;182;207;169m4[0m, [38;2;182;207;169m32[0m, [38;2;182;207;169m2[0m[38;2;255;213;3m)[0m, [38;2;156;220;254mdtype[0m[38;2;212;212;212m=[0mdtype('float32')[38;2;255;213;3m)[0m
  [38;2;255;213;3m)[0m,
  [38;2;156;220;254mbias_init[0m[38;2;212;212;212m=[0m<function zeros at 0x7b4862d41080>,
  [38;2;156;220;254mdot_general[0m[38;2;212;212;212m=[0m<function dot_general at 0x7b486374cd60>,
  [38;2;156;220;254mdtype[0m[38;2;212;212;212m=[0m[38;2;86;156;214mNone[0m,
  [38;2;156;220;254min_features[0m[38;2;212;212;212m=[0

In [57]:
# Inspect the parameters of the pmap/vmap model created in the previous cell.
# Because the head-level `vmap` used `out_axes=1`, the head axis (size 2)
# does not lead: it sits just before `dout` in the kernel and last in the bias.
print("Model parameter shapes:")
print(f"Kernel shape: {model.kernel.value.shape}")
print(f"Bias shape: {model.bias.value.shape}\n")

Model parameter shapes:
Kernel shape: (1, 4, 16, 2, 32)
Bias shape: (1, 4, 32, 2)



Notice the difference in the number of vectorized parameters between the two models above: 45 vs. 4,352. It is important to decide exactly what you want to vectorize; maybe you only need parameters that are shared on each GPU.

In [50]:
# Configuration
batch = 2
num_heads = 3
head_dim = 4  # input features per head
dout = 5      # output features per head
seq_len = 6


# Initialize the vectorized model: one Linear per (batch, head) pair.
@nnx.split_rngs(splits=batch)
@nnx.vmap(in_axes=0)  # outer mapped axis: batch
@nnx.split_rngs(splits=num_heads)
@nnx.vmap(in_axes=0)  # inner mapped axis: heads
def make_model(rngs):
  """Build a single `nnx.Linear(head_dim, dout)` layer.

  The stacked `split_rngs` / `vmap` decorators add two leading axes, so the
  returned module has kernel `(batch, num_heads, head_dim, dout)` and bias
  `(batch, num_heads, dout)`.
  """
  layer = nnx.Linear(in_features=head_dim, out_features=dout, rngs=rngs)
  return layer


model = make_model(nnx.Rngs(0))

# Sample input laid out as (batch, num_heads, seq_len, head_dim).
x = jnp.ones(shape=(batch, num_heads, seq_len, head_dim))


def forward(model, x):
  """Apply the stacked layers to `x`, mapping over batch and then heads.

  Both `model` and `x` are mapped along their leading axis at each level,
  so the innermost call sees one un-stacked layer and one
  `(seq_len, head_dim)` slice of the input.
  """

  def apply_layer(layer, inputs):
    # inputs: (seq_len, head_dim); layer.kernel: (head_dim, dout)
    return layer(inputs)

  over_heads = nnx.vmap(apply_layer, in_axes=(0, 0))  # strips the head axis
  over_batch = nnx.vmap(over_heads, in_axes=(0, 0))   # strips the batch axis
  return over_batch(model, x)


# Report parameter shapes, then input/output shapes.
print("Model parameter shapes:")
print(f"Kernel shape: {model.kernel.value.shape}")  # expected (batch, num_heads, head_dim, dout)
print(f"Bias shape: {model.bias.value.shape}\n")   # expected (batch, num_heads, dout)

print("Input/Output shapes:")
print(f"Input shape: {x.shape}")
output = forward(model, x)
print(f"Output shape: {output.shape}")

Model parameter shapes:
Kernel shape: (2, 3, 4, 5)
Bias shape: (2, 3, 5)

Input/Output shapes:
Input shape: (2, 3, 6, 4)
Output shape: (2, 3, 6, 5)
