- Create and train an ensemble of neural networks efficiently using `jax.vmap` and PyTorch's functional transforms (`torch.func`)
- Compare vectorized ensemble training in JAX/Flax and in PyTorch's functional API
- Use a simple linear regression task with synthetic data
- Build minimal models in both Flax (JAX) and PyTorch (a single dense layer each, despite the `MLP` class names)
- Show how to initialize, structure, and parallelize computation across multiple models
- Highlight similarities and differences between the JAX and PyTorch functional workflows


In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import torch
import torch.nn as tnn
import torch.func as tfunc

# Data (same for both): y = 2x + 1 plus Gaussian noise, seeded for reproducibility
rng = np.random.default_rng(0)
x_in = np.linspace(-1, 1, 100)[:, None].astype(np.float32)
y_true = 2 * x_in + 1 + rng.standard_normal((100, 1)).astype(np.float32) * 0.1

x_jax, y_jax = jnp.array(x_in), jnp.array(y_true)
x_torch, y_torch = torch.from_numpy(x_in), torch.from_numpy(y_true)

#### Define the model in Flax

In [None]:
# 1. Model
class FlaxMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

# 2. State Creator (Single Model)
def create_state(rng):
    model = FlaxMLP()
    params = model.init(rng, jnp.ones((1, 1)))['params']
    tx = optax.adam(0.01)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

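# 3. Vectorized Creation (5 Models)
# Each member gets its own PRNG key, so initial weights differ; vmap stacks the
# resulting params and optimizer state along a new leading (ensemble) axis.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
ensemble_state = jax.vmap(create_state)(keys)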

print(f"Flax Params Shape: {ensemble_state.params['Dense_0']['kernel'].shape}")
# Output: (5, 1, 1) -> (Ensemble, Input, Output)

#### Define the training step in Flax

In [None]:
# 1. Single Step Logic
def train_step(state, x, y):
    def loss_fn(params):
        pred = state.apply_fn({'params': params}, x)
        return jnp.mean((pred - y) ** 2)
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss

# 2. Vectorize it
# `in_axes` tells vmap which axis of each argument to map over:
#   - state: 0, because every leaf has a leading ensemble axis (size 5)
#   - x, y:  None, so the same batch is broadcast to every model
ensemble_step = jax.jit(jax.vmap(train_step, in_axes=(0, None, None)))

# 3. Run
for _ in range(100):
    ensemble_state, losses = ensemble_step(ensemble_state, x_jax, y_jax)

print("Per-model losses:", losses)  # shape (5,): one loss per ensemble member
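To make `in_axes` concrete, here is a small sketch with a hypothetical `predict` function and toy weights (neither is part of the notebook): vmapping with `in_axes=(0, None)` slices the stacked weights along axis 0, passes the same batch to every call, and matches an explicit Python loop over the ensemble axis.

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # One "model": a (2, 1) weight matrix applied to a batch x of shape (n, 2)
    return x @ w

ws = jnp.arange(6.0).reshape(3, 2, 1)  # 3 stacked models (leading axis = ensemble)
x = jnp.ones((4, 2))                   # one shared batch

# in_axes=(0, None): slice ws along axis 0, reuse the same x for every model
batched = jax.vmap(predict, in_axes=(0, None))(ws, x)

# Equivalent explicit loop over the ensemble axis
looped = jnp.stack([predict(w, x) for w in ws])

print(batched.shape)                  # (3, 4, 1)
print(jnp.allclose(batched, looped))  # True
```

The `jit(vmap(...))` version in the notebook does the same thing for the whole `TrainState`, but as a single compiled computation instead of a Python loop.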

#### Define the model in functional PyTorch

PyTorch modules are stateful: each `nn.Module` (such as `self.layer`) owns its parameters as attributes.

To `vmap` over an ensemble, we extract the parameters and turn the model into a pure function of them using `torch.func.functional_call`.
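A minimal sketch of what `functional_call` does on a single `nn.Linear`: the forward pass uses the parameters you pass in (here a hypothetical `swapped` dict holding the line w=2, b=1 that generated the data), while the module's own parameters are left unchanged.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

layer = nn.Linear(1, 1)
original_weight = layer.weight.detach().clone()

# Hypothetical substitute parameters; keys must match layer.named_parameters()
swapped = {"weight": torch.tensor([[2.0]]), "bias": torch.tensor([1.0])}

x = torch.tensor([[0.5]])
out = functional_call(layer, swapped, (x,))  # forward pass with `swapped`

print(out)  # tensor([[2.]])  (2 * 0.5 + 1)
# The module's own parameters are untouched
print(torch.equal(layer.weight.detach(), original_weight))  # True
```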

In [None]:
# 1. Define Model
class TorchMLP(tnn.Module):
    def __init__(self):
        super().__init__()
        self.layer = tnn.Linear(1, 1)
    def forward(self, x):
        return self.layer(x)

# 2. Instantiate 5 Models (List of objects)
models = [TorchMLP() for _ in range(5)]

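# 3. Stack Parameters
# stack_module_state extracts each model's parameters and buffers and stacks them
# along a new leading (ensemble) dimension. The stacked tensors are copies:
# updating them does not modify the original modules.
params, buffers = tfunc.stack_module_state(models)

# 'params' is a dictionary of stacked tensors, e.g.
# params['layer.weight'].shape -> torch.Size([5, 1, 1])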
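The following sketch checks two properties of `stack_module_state`, using bare `nn.Linear` models (so the keys are `weight`/`bias` rather than `layer.weight`/`layer.bias`): slice `i` holds model `i`'s values, and the stacked tensors are separate copies, so training them does not update the original modules.

```python
import torch
import torch.nn as nn
from torch.func import stack_module_state

models = [nn.Linear(1, 1) for _ in range(5)]
params, buffers = stack_module_state(models)

print(params["weight"].shape)  # torch.Size([5, 1, 1])
print(params["bias"].shape)    # torch.Size([5, 1])
print(buffers)                 # {} (nn.Linear has no buffers)

# Slice i of each stacked tensor holds model i's current values
for i, m in enumerate(models):
    assert torch.equal(params["weight"][i], m.weight)

# The stacked tensors are separate copies: zeroing them leaves the modules intact
before = models[0].weight.detach().clone()
with torch.no_grad():
    params["weight"].zero_()
assert torch.equal(models[0].weight, before)
```

If you later need the trained weights back in ordinary modules, you would have to copy the slices into them yourself.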

#### Define the training step in functional Pytorch

In [None]:
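# 1. Define Pure Function
# We need a function: f(params, buffers, x, y) -> loss
def compute_loss(params, buffers, x, y):
    # functional_call runs models[0]'s forward pass with the given params and buffers
    # in place of its own (under vmap, each call receives one model's slice);
    # models[0] only supplies the architecture.
    pred = tfunc.functional_call(models[0], (params, buffers), (x,))
    return torch.mean((pred - y) ** 2)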

# 2. Vectorize the Function
# PyTorch uses 'in_dims', not 'in_axes'
# Params (0), Buffers (0), X (None/Broadcast), Y (None/Broadcast)
vmap_loss = torch.vmap(compute_loss, in_dims=(0, 0, None, None))
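As a sanity check (a sketch with freshly created models and noise-free targets, so the numbers differ from the notebook's), `vmap_loss` should return one loss per model and agree with a plain Python loop over the modules:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state

torch.manual_seed(0)
models = [nn.Linear(1, 1) for _ in range(5)]
params, buffers = stack_module_state(models)

x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 2 * x + 1

def compute_loss(params, buffers, x, y):
    pred = functional_call(models[0], (params, buffers), (x,))
    return torch.mean((pred - y) ** 2)

vmap_loss = torch.vmap(compute_loss, in_dims=(0, 0, None, None))

with torch.no_grad():
    batched = vmap_loss(params, buffers, x, y)  # one loss per model
    looped = torch.stack([torch.mean((m(x) - y) ** 2) for m in models])

print(batched.shape)                    # torch.Size([5])
print(torch.allclose(batched, looped))  # True
```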

In [None]:
# Simple SGD logic (the Flax version above uses Adam, so losses are not directly comparable)
learning_rate = 0.01

# 1. Create a vectorized gradient function (once, outside the loop)
# argnums=0 because compute_loss takes (params, buffers, x, y)
# and we want gradients with respect to 'params'.
grad_fn = tfunc.grad(compute_loss, argnums=0)
# Same mapping as vmap_loss: in_dims=(0, 0, None, None)
vectorized_grad_fn = torch.vmap(grad_fn, in_dims=(0, 0, None, None))

for _ in range(100):
    # 2. Compute per-model gradients (a dict with the same keys/shapes as params)
    grads = vectorized_grad_fn(params, buffers, x_torch, y_torch)

    # 3. Manual update (SGD), in place on the stacked tensors.
    # no_grad is required: the stacked params are leaf tensors with requires_grad=True.
    with torch.no_grad():
        for key in params:
            params[key] -= learning_rate * grads[key]

print("PyTorch ensemble trained.")
with torch.no_grad():
    print("Per-model losses:", vmap_loss(params, buffers, x_torch, y_torch))

# Verify shapes
print(f"Ensemble Param Shape: {params['layer.weight'].shape}")
# Expected: torch.Size([5, 1, 1])
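A variant of the loop above, sketched under the assumption that you also want the per-step losses: `torch.func.grad_and_value` returns the gradients and the loss in one call, and the update can be written out of place, closer in spirit to Flax's `apply_gradients` returning a new state. The learning rate (0.1), the 200 steps, and the noise-free targets are illustrative choices that make the result easy to check.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad_and_value, stack_module_state

torch.manual_seed(0)
models = [nn.Linear(1, 1) for _ in range(5)]
params, buffers = stack_module_state(models)
# Detach so the loop is purely functional (no outer autograd graph, no no_grad needed)
params = {k: v.detach() for k, v in params.items()}

x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = 2 * x + 1  # noise-free targets for a deterministic check

def compute_loss(params, buffers, x, y):
    pred = functional_call(models[0], (params, buffers), (x,))
    return torch.mean((pred - y) ** 2)

# grad_and_value returns (grads, loss): the reverse of jax.value_and_grad's (loss, grads)
step = torch.vmap(grad_and_value(compute_loss), in_dims=(0, 0, None, None))

lr = 0.1
for _ in range(200):
    grads, losses = step(params, buffers, x, y)
    # Out-of-place SGD update: build a new params dict instead of mutating tensors
    params = {k: p - lr * grads[k] for k, p in params.items()}

print(losses)  # per-model losses from the last step, all close to 0
```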

# ⚡ Cheat Sheet: JAX vs. PyTorch Functional API

PyTorch 2.0 introduced `torch.func`, a set of functional transforms modeled on JAX's, but the API differs in several details.

| Concept | **JAX** (`jax`) | **PyTorch** (`torch.func`) |
| :--- | :--- | :--- |
| **Vectorization** | `jax.vmap(func)` | `torch.vmap(func)` (also available as `torch.func.vmap`) |
| **Input Axes Argument** | `in_axes=(0, None)` | `in_dims=(0, None)` |
| **Gradient Transform** | `jax.grad(func)` | `torch.func.grad(func)` |
| **Value & Gradient** | `jax.value_and_grad(func)`, returns `(value, grad)` | `torch.func.grad_and_value(func)`, returns `(grad, value)` (note the reversed order) |
| **Extracting Params** | `params = model.init(key, x)['params']` (`jax.vmap` over keys for an ensemble) | `params, buffers = torch.func.stack_module_state(models)` |
| **Running the Model** | `model.apply({'params': params}, x)` | `torch.func.functional_call(model, (params, buffers), (x,))` |
| **Randomness** | Explicit PRNG `key` passed to the function | Implicit global RNG state; inside `torch.vmap`, the `randomness` argument (`'error'` by default, `'different'`, or `'same'`) controls random ops |
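A small sketch of the `randomness` argument to `torch.vmap` (the `add_noise` function is illustrative): with `'different'` each mapped element draws its own sample, with `'same'` a single draw is shared, and the default `'error'` raises when a random op is called inside `vmap`.

```python
import torch

def add_noise(x):
    # A random op inside the vmapped function: behavior depends on `randomness`
    return x + torch.randn(())

x = torch.zeros(4)

different = torch.vmap(add_noise, randomness="different")(x)  # 4 independent draws
same = torch.vmap(add_noise, randomness="same")(x)            # 1 draw, shared

try:
    torch.vmap(add_noise)(x)  # default randomness="error"
    raised = False
except RuntimeError:
    raised = True

print(different)
print(same)
print(raised)  # True
```

For an ensemble with dropout, `'different'` is the setting that gives each member its own mask.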

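**Key Takeaway:**
* **JAX** transforms operate on pure functions, and Flax keeps parameters outside the module (they are passed to `apply`), so `vmap` and `grad` apply directly.
* **PyTorch** modules store their parameters as attributes, so we use `stack_module_state` and `functional_call` to make them behave like pure functions of their parameters for `vmap` and `grad`.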