In [1]:
import jax.numpy as jnp


In [2]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    selu(x) = lmbda * x                      if x > 0
              lmbda * alpha * (exp(x) - 1)   otherwise

    The defaults are rounded versions of the canonical SELU constants
    (alpha ~= 1.6733, lambda ~= 1.0507; Klambauer et al., 2017).

    Args:
        x: Scalar or array input.
        alpha: Scale of the non-positive (exponential) branch.
        lmbda: Overall output scale (spelled ``lmbda`` because ``lambda``
            is a Python keyword).

    Returns:
        A JAX array of the same shape as ``x`` (for scalar ``alpha`` and
        ``lmbda``).
    """
    # jnp.where evaluates BOTH branches for every element. Exponentiating the
    # raw x would overflow to inf for large positive x; where() discards that
    # value in the forward pass, but the backward pass multiplies the zero
    # cotangent by inf and produces NaN gradients. Replace the positive
    # entries by 0 before exponentiating (the "double where" trick). Using
    # zeros_like keeps x's dtype so type promotion matches the naive form.
    x_nonpos = jnp.where(x > 0, jnp.zeros_like(x), x)
    # expm1 computes exp(x) - 1 accurately for x near 0, where the naive
    # subtraction loses precision to cancellation.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x_nonpos))


# Sample points 0.0, 1.0, 2.0, 3.0, 4.0. Except for x = 0 (which yields 0),
# every point takes the positive, linear branch, so the output is lmbda * x.
x = jnp.arange(0.0, 5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [3]:
from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

7.33 ms ± 504 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()

1.99 ms ± 415 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
