In [6]:
import jax
import jax.numpy as jnp
import jax.lax

# Define input data
x = jnp.array(2.0)

# Define your different functions as a JAX-compatible tuple
functions_tuple = (
    lambda x: x ** 2,
    lambda x: jnp.sin(x),
    lambda x: jnp.log(x + 1)
)

# Use jax.lax.switch inside lax.map to apply the correct function
def apply_function_with_switch(i, x):
    return jax.lax.switch(i, functions_tuple, x)

# Use lax.map to apply functions in parallel over their indices
results = jax.lax.map(lambda i: apply_function_with_switch(i, x), jnp.arange(len(functions_tuple)))

# JIT the function
jit_apply_function_with_switch = jax.jit(apply_function_with_switch)

# Evaluate with JIT
print(f"Results of JIT-compiled lax.map with switch: {results}")

Results of JIT-compiled lax.map with switch: [4.        0.9092974 1.0986123]


In [10]:
import jax
import jax.numpy as jnp
import jax.lax
import timeit

# Define different functions
def func1(x):
    return x ** 2

def func2(x):
    return jnp.sin(x)

def func3(x):
    return jnp.log(x + 1)

# Functions in tuple form for JAX
functions_tuple = (
    lambda x: x ** 2,
    lambda x: jnp.sin(x),
    lambda x: jnp.log(x + 1)
)

# Input data
x = jnp.array(2.0)

# Manual loop in Python
def manual_loop(functions, x):
    results = []
    for f in functions:
        results.append(f(x))
    return results

# Use jax.lax.switch inside lax.map to apply the correct function
def apply_function_with_switch(i, x):
    return jax.lax.switch(i, functions_tuple, x)

# Use lax.map to apply functions in parallel over their indices
results = jax.lax.map(lambda i: apply_function_with_switch(i, x), jnp.arange(len(functions_tuple)))

# JIT the function
jit_apply_function_with_switch = jax.jit(apply_function_with_switch)

# Timeit setup
%timeit manual_loop([func1, func2, func3], x)
%timeit jax.lax.map(lambda i: jit_apply_function_with_switch(i, x), jnp.arange(len(functions_tuple)))

34.2 µs ± 207 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
17.6 ms ± 278 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
import jax
import jax.numpy as jnp
import jax.lax
import timeit

# Define different functions
def func1(x):
    return x ** 2

def func2(x):
    return jnp.sin(x)

def func3(x):
    return jnp.log(x + 1)

# Functions in tuple form for JAX
functions_tuple = (
    lambda x: x ** 2,
    lambda x: jnp.sin(x),
    lambda x: jnp.log(x + 1)
)

# Input data
x = jnp.array(2.0)

# Manual loop in Python
def manual_loop(functions, x):
    results = []
    for f in functions:
        results.append(f(x))
    return results

# JAX approach with lax.map and switch
def jax_map_switch(functions_tuple, x):
    def apply_function_with_switch(i, x):
        return jax.lax.switch(i, functions_tuple, x)
    return jax.lax.map(lambda i: apply_function_with_switch(i, x), jnp.arange(len(functions_tuple)))

# JIT the jax function
jit_jax_map_switch = jax.jit(jax_map_switch)

# Timeit setup
manual_loop_time = timeit.timeit(lambda: manual_loop([func1, func2, func3], x), number=1000)
jax_map_switch_time = timeit.timeit(lambda: jit_jax_map_switch(functions_tuple, x), number=1000)

# Print the results
print(f"Manual loop time: {manual_loop_time}")
print(f"JAX map+switch with JIT time: {jax_map_switch_time}")

TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute

In [12]:
import jax
import jax.numpy as jnp
import jax.lax
import timeit

# Define different functions
def func1(x):
    return x ** 2

def func2(x):
    return jnp.sin(x)

def func3(x):
    return jnp.log(x + 1)

# Input data
x = jnp.array(2.0)

# Manual loop in Python
def manual_loop(functions, x):
    results = []
    for f in functions:
        results.append(f(x))
    return results

# JAX approach with lax.map and switch
def jax_map_switch(x):
    # Define the functions inside the JAX context
    def apply_function_with_switch(i, x):
        return jax.lax.switch(i, [lambda x: x ** 2, lambda x: jnp.sin(x), lambda x: jnp.log(x + 1)], x)
    
    # Use lax.map to apply the function indices (0, 1, 2)
    return jax.lax.map(lambda i: apply_function_with_switch(i, x), jnp.arange(3))

# JIT the JAX function
jit_jax_map_switch = jax.jit(jax_map_switch)

# Timeit setup
manual_loop_time = timeit.timeit(lambda: manual_loop([func1, func2, func3], x), number=1000)
jax_map_switch_time = timeit.timeit(lambda: jit_jax_map_switch(x), number=1000)

# Print the results
print(f"Manual loop time: {manual_loop_time}")
print(f"JAX map+switch with JIT time: {jax_map_switch_time}")

Manual loop time: 0.07952899998053908
JAX map+switch with JIT time: 0.02338075003353879


In [18]:
import jax
import jax.numpy as jnp
import jax.lax
import timeit
from functools import partial
from jax import jit

# Define different functions
def func1(x):
    return x ** 2

def func2(x):
    return jnp.sin(x)

def func3(x):
    return jnp.log(x + 1)

# Input data
x = jnp.array(2.0)

# Manual loop in Python
def manual_loop(functions, x):
    results = []
    for f in functions:
        results.append(f(x))
    return results

# JAX approach with lax.map and switch
def jax_map_switch(x):
    # Define the functions inside the JAX context
    def apply_function_with_switch(i, x):
        return jax.lax.switch(i, [lambda x: x ** 2, lambda x: jnp.sin(x), lambda x: jnp.log(x + 1)], x)
    
    # Use lax.map to apply the function indices (0, 1, 2)
    return jax.lax.map(lambda i: apply_function_with_switch(i, x), jnp.arange(3))

# JIT the JAX function
jit_jax_map_switch = jax.jit(jax_map_switch)

# JIT the manual loop
@partial(jit, static_argnums=(0,))
def jit_manual_loop(functions, x):
    results = []
    for f in functions:
        results.append(f(x))
    return results

# Timeit setup for both JIT versions
jit_manual_loop_time = timeit.timeit(lambda: jit_manual_loop((func1, func2, func3), x), number=1000)
jax_map_switch_time = timeit.timeit(lambda: jit_jax_map_switch(x), number=1000)

# Print the results
print(f"Manual loop time (JIT): {jit_manual_loop_time}")
print(f"JAX map+switch with JIT time: {jax_map_switch_time}")

Manual loop time (JIT): 0.025044000009074807
JAX map+switch with JIT time: 0.05962416698457673
