Notebook overview:
- Demonstrates 1D convolution in JAX.
- Implements a basic loop-based convolution for a single vector.
- Extends it to batched inputs by looping over the batch.
- Implements manual vectorization across the batch dimension.
- Compares the manual approaches with automatic batching via `jax.vmap`.
- Compiles the batched function with `jax.jit`.
- Uses JAX's NumPy-compatible API (`jax.numpy`) for array operations.

Reference: https://docs.jax.dev/en/latest/automatic-vectorization.html

#### Manual convolution

In [None]:
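import jax
import jax.numpy as jnp

x = jnp.arange(5)              # input signal: [0, 1, 2, 3, 4]
w = jnp.array([2., 3., 4.])    # length-3 kernel

def convolve(x, w):
  # "Valid" sliding dot product: one output per position where the
  # 3-element window fits entirely inside x (assumes len(w) == 3).
  # The kernel is not flipped.
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)  # Array([11., 20., 29.], dtype=float32)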
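The loop in `convolve` slides `w` across `x` without reversing the kernel, so in NumPy terms it computes a "valid"-mode correlation. As a quick cross-check, the sketch below (which redefines the inputs as floats so it runs on its own) compares the loop against `jnp.correlate` and against `jnp.convolve` with the kernel reversed:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)            # float copy of the input signal
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  # Same loop as above: valid-mode sliding dot product, kernel not flipped.
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

loop_result = convolve(x, w)
# jnp.correlate slides w over x without flipping it.
corr_result = jnp.correlate(x, w, mode='valid')
# jnp.convolve flips its second argument, so pass the reversed kernel.
conv_result = jnp.convolve(x, w[::-1], mode='valid')

print(loop_result)  # [11. 20. 29.]
```

Reversing `w` matters because the kernel is asymmetric: `jnp.convolve(x, w, mode='valid')` without the flip would return different values.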

#### Batched convolution in a loop

In [None]:
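xs = jnp.stack([x, x])   # batch of 2 signals, shape (2, 5)
ws = jnp.stack([w, w])   # batch of 2 kernels, shape (2, 3)

def manually_batched_convolve(xs, ws):
  # Call the single-example convolve once per batch element, then stack.
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)  # shape (2, 3); each row is [11., 20., 29.]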
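The overview also lists manual vectorization across the batch dimension. A minimal sketch of that idea, modeled on the approach in the JAX automatic-vectorization guide, keeps the loop over output positions but handles every batch element at once through slicing and broadcasting. `manually_vectorized_convolve` is a name chosen here for illustration, and the inputs are redefined so the block runs on its own:

```python
import jax.numpy as jnp

x = jnp.arange(5.0)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])   # shape (2, 5)
ws = jnp.stack([w, w])   # shape (2, 3)

def manually_vectorized_convolve(xs, ws):
  # Loop only over output positions; each step processes the whole batch.
  output = []
  for i in range(1, xs.shape[-1] - 1):
    window = xs[:, i-1:i+2]                       # shape (batch, 3)
    # Multiply by each example's kernel and sum over the window axis.
    output.append(jnp.sum(window * ws, axis=1))   # shape (batch,)
  # Stack along axis 1 so the result is (batch, output_length).
  return jnp.stack(output, axis=1)

result = manually_vectorized_convolve(xs, ws)
print(result.shape)  # (2, 3)
```

Compared with the per-example loop, this version needs its indexing and reduction axes rewritten by hand, which is the bookkeeping that `jax.vmap` automates.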

#### Automatic batching with `jax.vmap`

In [None]:
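# vmap maps convolve over axis 0 of every argument (the default in_axes=0)
# and stacks the per-example results along axis 0 of the output.
auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)  # same result as manually_batched_convolve(xs, ws)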
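`jax.vmap` also lets you choose which axis of each argument is batched. The sketch below, which redefines the inputs so it is self-contained, shows two variations: sharing one kernel across the whole batch with `in_axes=(0, None)`, and placing the batch dimension on axis 1 of both inputs and outputs with `in_axes=1, out_axes=1`. The names `shared_kernel`, `xst`, `wst`, and `transposed` are illustrative only:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])   # shape (2, 5): batch on axis 0
ws = jnp.stack([w, w])   # shape (2, 3)

def convolve(x, w):
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

# Batch xs along axis 0 but reuse the single kernel w (None = not mapped).
shared_kernel = jax.vmap(convolve, in_axes=(0, None))(xs, w)

# Batch dimension on axis 1 for both inputs and the output.
xst = jnp.transpose(xs)   # shape (5, 2)
wst = jnp.transpose(ws)   # shape (3, 2)
transposed = jax.vmap(convolve, in_axes=1, out_axes=1)(xst, wst)

print(shared_kernel.shape, transposed.shape)  # (2, 3) (3, 2)
```

In both cases `convolve` itself is unchanged; only the axis bookkeeping passed to `vmap` differs.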

#### JIT compilation

In [None]:
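# jax.jit traces the vectorized function for each new input shape/dtype and
# compiles it with XLA; later calls with the same shapes reuse the compiled code.
jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)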
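`jax.jit` and `jax.vmap` are composable transformations, so the order in which they are applied can be swapped. A minimal sketch, assuming the same `convolve` as above (redefined so the block runs on its own), checks that both orders give the same result:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)
w = jnp.array([2., 3., 4.])
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

def convolve(x, w):
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

# Compile the batched function as a whole ...
jit_of_vmap = jax.jit(jax.vmap(convolve))
# ... or vectorize an already-jitted single-example function.
vmap_of_jit = jax.vmap(jax.jit(convolve))

a = jit_of_vmap(xs, ws)
b = vmap_of_jit(xs, ws)
# JAX dispatches work asynchronously; wait for results (important when timing).
a.block_until_ready()
b.block_until_ready()
print(jnp.allclose(a, b))  # True
```

Because `convolve` uses a Python `for` loop bounded by `len(x)`, tracing unrolls it into one operation per output position. This works under `jit` because array shapes are static during tracing, but the traced program grows with the input length.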