<a href="https://colab.research.google.com/github/Hadrien-Cornier/cool-nn-stuff/blob/main/jax_vmap.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


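## What does `in_axes` do in `vmap`?

`vmap` is a JAX transformation that automatically vectorizes a function: it maps the function over a batch axis of its inputs, replacing an explicit Python loop with batched array operations at very little overhead.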


In [5]:
import jax
import jax.numpy as jnp
import time

# Define a simple function that adds two numbers
def add(x, y):
    return x + y

# Function to generate arrays of a given size
def generate_arrays(size):
    x = jnp.arange(size)
    y = jnp.arange(size, size * 2)
    return x, y

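# Function to measure wall-clock execution time
def measure_time(func, *args):
    start = time.perf_counter()
    # JAX dispatches work asynchronously, so wait for the result before stopping the clock
    result = jax.block_until_ready(func(*args))
    end = time.perf_counter()
    return result, end - start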

# Main function to compare execution times
def compare_execution_times(size):
    x, y = generate_arrays(size)

    # Without vmap
    def add_without_vmap(x, y):
        return jnp.array([add(xi, yi) for xi, yi in zip(x, y)])

    result_no_vmap, time_no_vmap = measure_time(add_without_vmap, x, y)
    print(f"{result_no_vmap=}")
    print(f"{time_no_vmap=} seconds")

    # With vmap
    vmap_add = jax.vmap(add, in_axes=(0, 0))
    result_vmap, time_vmap = measure_time(vmap_add, x, y)
    print(f"{result_vmap=}")
    print(f"{time_vmap=} seconds")

# Example usage
size = 10000
compare_execution_times(size)

result_no_vmap=Array([10000, 10002, 10004, ..., 29994, 29996, 29998], dtype=int32)
time_no_vmap=1.3610410690307617 seconds
result_vmap=Array([10000, 10002, 10004, ..., 29994, 29996, 29998], dtype=int32)
time_vmap=0.002285480499267578 seconds
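
A single timed call mixes one-off tracing cost with steady-state cost. The sketch below is one possible way to make the measurement steadier, assuming a warm-up call, averaging over a few repeats, and an explicit `jax.block_until_ready`. The helper name `timed` and the `repeats` parameter are illustrative, not part of JAX, and the numbers will vary by machine.

```python
import time

import jax
import jax.numpy as jnp


def add(x, y):
    return x + y


def timed(func, *args, repeats=10):
    """Hypothetical helper: average wall-clock time per call after one warm-up."""
    jax.block_until_ready(func(*args))  # warm-up call, excluded from the timing
    start = time.perf_counter()
    for _ in range(repeats):
        # Block so that asynchronous dispatch does not hide the real cost
        result = jax.block_until_ready(func(*args))
    return result, (time.perf_counter() - start) / repeats


x = jnp.arange(10_000)
y = jnp.arange(10_000, 20_000)

vmap_add = jax.vmap(add, in_axes=(0, 0))
jit_vmap_add = jax.jit(vmap_add)  # vmap composes with jit

result_vmap, t_vmap = timed(vmap_add, x, y)
result_jit, t_jit = timed(jit_vmap_add, x, y)
print(f"vmap only:  {t_vmap:.6f} s per call")
print(f"jit + vmap: {t_jit:.6f} s per call")
```

Both variants should return the same array as the loop version. Only the timings differ.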



`vmap` has a parameter called `in_axes`.


### GPT-4o
In JAX, the `in_axes` parameter in the `vmap` function specifies which axes of the input arguments should be mapped over. This allows you to vectorize a function by applying it across slices of input arrays along specified axes. Here's how it works:

- **`in_axes` Parameter**: This parameter determines which axis of each input argument should be mapped over. It can be an integer, `None`, or a tuple with one entry per positional argument; each entry can itself be a container (for example a dict or tuple) that mirrors the structure of that argument.

- **Usage**: If `in_axes` is a single integer, every argument is mapped over that axis. An entry of `None` means the corresponding input is broadcast (i.e., not mapped) and passed unchanged to every call. For example, `in_axes=(0, None)` means the first argument will be mapped over its first axis, while the second argument will not be mapped.

- **Example**: Consider a function `f(x, y)` that you want to apply to a batch of inputs. Using `vmap(f, in_axes=(0, 0))` will map the function over the first axis of both `x` and `y`, applying `f` to each pair of slices `(x[i], y[i])`.

This parameter is crucial for enabling parallel computation over batches of data, making it highly efficient for operations on large datasets, especially when using accelerators like GPUs and TPUs[2][4][5].
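
As a small sketch of the container form of `in_axes` described above: when an argument is a dict, the matching `in_axes` entry can be a dict with the same keys. The function `affine` and the parameter names `w` and `b` are made up for illustration.

```python
import jax
import jax.numpy as jnp


def affine(params, x):
    # Inside vmap, params["w"] and x are scalars; params["b"] is the shared scalar
    return params["w"] * x + params["b"]


params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(0.5)}
x = jnp.array([10.0, 20.0, 30.0])

# Map over axis 0 of params["w"] and of x; broadcast params["b"] to every call
batched_affine = jax.vmap(affine, in_axes=({"w": 0, "b": None}, 0))
out = batched_affine(params, x)
print(out)  # [10.5 40.5 90.5]
```

Each call sees one element of `w` and `x` together with the same `b`, so the result has one entry per mapped element.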

Citations:
[1] https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html
[2] https://flax.readthedocs.io/en/v0.5.3/_autosummary/flax.linen.vmap.html
[3] https://pyimagesearch.com/2023/02/27/learning-jax-in-2023-part-2-jaxs-power-tools-grad-jit-vmap-and-pmap/
[4] https://stackoverflow.com/questions/70564419/in-axes-keyword-in-jaxs-vmap
[5] https://discuss.pennylane.ai/t/facing-issues-with-jax-jitting-the-optimization-loop/4274
[6] https://github.com/google/jax/issues/7465
[7] https://github.com/google/jax/issues/15041
[8] https://ppl-ai-file-upload.s3.amazonaws.com/web/direct-files/12491979/c2f73992-0c38-45d2-8e57-4741083de218/product_of_betas-2-.ipynb

In [20]:
import jax
import jax.numpy as jnp

# Define a simple function that adds two numbers
def add(x, y):
    return x + y

# Example 1: Vectorizing addition over a single axis
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

print("Example 1: Vectorizing addition over a single axis")
print("Input arrays:")
print("x:", x)
print("y:", y)

# Vectorize the add function to apply it across the first axis of x and y
vmap_add = jax.vmap(add, in_axes=(0, 0))
result1 = vmap_add(x, y)
print("Called function: jax.vmap(add, in_axes=(0, 0))(x, y)")
print("in_axes: (0, 0) - Mapping over first axis of both x and y")
print("Result 1:", result1)  # Output: [5, 7, 9]
print()

# Example 2: Vectorizing with broadcasting
z = 10

print("Example 2: Vectorizing with broadcasting")
print("Input array:")
print("x:", x)
print("Broadcast scalar z:", z)

# Vectorize the add function to map over the first axis of x and broadcast z
vmap_add_broadcast = jax.vmap(add, in_axes=(0, None))
result2 = vmap_add_broadcast(x, z)
print("Called function: jax.vmap(add, in_axes=(0, None))(x, z)")
print("in_axes: (0, None) - Mapping over first axis of x and broadcasting z")
print("Result 2:", result2)  # Output: [11, 12, 13]
print()

# Example 3: Multi-dimensional mapping
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

print("Example 3: Multi-dimensional mapping")
print("Input arrays:")
print("a:", a)
print("b:", b)

# Vectorize the add function to map over the second axis (columns) of a and b
vmap_add_multi = jax.vmap(add, in_axes=(1, 1))
result3 = vmap_add_multi(a, b)
print("Called function: jax.vmap(add, in_axes=(1, 1))(a, b)")
print("in_axes: (1, 1) - Mapping over second axis (columns) of both a and b")
print("Result 3:", result3)  # Output: [[ 6, 10], [ 8, 12]]

Example 1: Vectorizing addition over a single axis
Input arrays:
x: [1 2 3]
y: [4 5 6]
Called function: jax.vmap(add, in_axes=(0, 0))(x, y)
in_axes: (0, 0) - Mapping over first axis of both x and y
Result 1: [5 7 9]

Example 2: Vectorizing with broadcasting
Input array:
x: [1 2 3]
Broadcast scalar z: 10
Called function: jax.vmap(add, in_axes=(0, None))(x, z)
in_axes: (0, None) - Mapping over first axis of x and broadcasting z
Result 2: [11 12 13]

Example 3: Multi-dimensional mapping
Input arrays:
a: [[1 2]
 [3 4]]
b: [[5 6]
 [7 8]]
Called function: jax.vmap(add, in_axes=(1, 1))(a, b)
in_axes: (1, 1) - Mapping over second axis (columns) of both a and b
Result 3: [[ 6 10]
 [ 8 12]]


In [17]:
print(f"{jax.vmap(add, in_axes=(0, 0))(a, b) =}")

jax.vmap(add, in_axes=(0, 0))(a, b) =Array([[ 6,  8],
       [10, 12]], dtype=int32)


In [18]:
print(f"{jax.vmap(add, in_axes=(1, 1))(a, b) =}")

jax.vmap(add, in_axes=(1, 1))(a, b) =Array([[ 6, 10],
       [ 8, 12]], dtype=int32)
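
The two results above are transposes of each other. A likely explanation, sketched below, is that `in_axes=(1, 1)` feeds the function one column at a time, while the default `out_axes=0` stacks the per-column results along the first axis of the output. The variable names `cols_default`, `cols_out1`, and `manual` are illustrative.

```python
import jax
import jax.numpy as jnp


def add(x, y):
    return x + y


a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

# Map over columns; results are stacked along axis 0 (the default out_axes)
cols_default = jax.vmap(add, in_axes=(1, 1))(a, b)

# Same mapping, but place the mapped axis back at position 1 in the output
cols_out1 = jax.vmap(add, in_axes=(1, 1), out_axes=1)(a, b)

# Equivalent explicit loop over columns, stacked along axis 0
manual = jnp.stack([add(a[:, j], b[:, j]) for j in range(a.shape[1])], axis=0)

print(cols_default)  # [[ 6 10]
                     #  [ 8 12]]
print(cols_out1)     # [[ 6  8]
                     #  [10 12]]
```

With `out_axes=1` the column-wise mapping reproduces the plain `a + b` layout, which is also what `in_axes=(0, 0)` gives.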
