**S01P06_non_array_inputs.ipynb**

Arz

2024 APR 05 (FRI)

reference:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# non-array inputs

## NumPy: permitted

In [4]:
np.sum([1, 2, 3])

6

## JAX: not permitted

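JAX rejects Python lists and tuples as inputs to `jnp` functions. if they were accepted, passing them to traced functions would trace every element as a separate value, causing silent performance degradation that is difficult to detect.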

In [5]:
# jnp.sum([1, 2, 3])  # forbidden: raises TypeError
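a minimal sketch of what running the commented-out line does, assuming a JAX version in which `jnp.sum` rejects list input with a `TypeError`. the names `raised` and `message` are illustrative only, and the exact wording of the message may differ between versions.

```python
import jax.numpy as jnp

raised = False
message = ""
try:
    # a plain Python list is not an array, so JAX refuses it
    jnp.sum([1, 2, 3])
except TypeError as err:
    raised = True
    message = str(err)

print(raised)
print(message)
```

catching the exception keeps the notebook runnable from top to bottom while still showing the failure.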

a permissive version of this function can be written by converting the input with `jnp.array` first, but as shown below its performance is degraded.

In [6]:
def permissive_sum(x):
    return jnp.sum(jnp.array(x))

In [7]:
permissive_sum([1, 2, 3])

Array(6, dtype=int32)

the result is correct, but inspecting the jaxpr reveals the cost:

In [8]:
from jax import make_jaxpr

In [9]:
make_jaxpr(permissive_sum)([1, 2, 3])

{ lambda ; a:i32[] b:i32[] c:i32[]. let
    d:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    e:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    g:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] d
    h:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] e
    i:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] f
    j:i32[3] = concatenate[dimension=0] g h i
    k:i32[] = reduce_sum[axes=(0,)] j
  in (k,) }

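each list element is traced as a separate scalar input (`a`, `b`, `c`), then converted, broadcast, and concatenated before the sum. the size of the trace therefore grows with the length of the list, which degrades performance.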
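one way to make this growth visible is to count the inputs and equations in the jaxpr. the following is a rough sketch: `count_trace` is a hypothetical helper written for this note, not part of JAX, and the equation counts may vary between JAX versions, so only the input counts should be read as stable.

```python
import jax
import jax.numpy as jnp

def permissive_sum(x):
    return jnp.sum(jnp.array(x))

def count_trace(fn, arg):
    """Return (number of traced inputs, number of equations) for fn(arg)."""
    closed = jax.make_jaxpr(fn)(arg)
    return len(closed.jaxpr.invars), len(closed.jaxpr.eqns)

# list input: one traced input per element
list_inputs, list_eqns = count_trace(permissive_sum, [1, 2, 3])
long_inputs, long_eqns = count_trace(permissive_sum, list(range(10)))

# array input: a single traced input regardless of length
array_inputs, array_eqns = count_trace(jnp.sum, jnp.array([1, 2, 3]))

print("list of 3  :", list_inputs, "inputs,", list_eqns, "equations")
print("list of 10 :", long_inputs, "inputs,", long_eqns, "equations")
print("array of 3 :", array_inputs, "inputs,", array_eqns, "equations")
```

the list versions take as many traced inputs as there are elements, while the array version takes exactly one.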

### solution: pass a JAX array directly

In [10]:
jnp.sum(jnp.array([1, 2, 3]))

Array(6, dtype=int32)

In [11]:
make_jaxpr(jnp.sum)(jnp.array([1, 2, 3]))

{ lambda ; a:i32[3]. let b:i32[] = reduce_sum[axes=(0,)] a in (b,) }