In [1]:
# Optional: hide every CUDA device from this process so JAX cannot use a GPU.
# Skip this cell if you want to run on the GPU. For it to have any effect it
# must execute before JAX initialises its backends (i.e. before the first JAX
# computation in this kernel).
import os

os.environ.update(CUDA_VISIBLE_DEVICES="-1")


In [2]:
import jax
import jax.numpy as jnp

In [4]:
def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled exponential linear unit (SELU), applied element-wise.

    selu(x) = lambda_ * x                       for x > 0
    selu(x) = lambda_ * alpha * (exp(x) - 1)    for x <= 0

    NOTE(review): the defaults look like rounded versions of the SELU-paper
    constants (alpha ~ 1.6733, lambda ~ 1.0507); they are not the exact
    self-normalising values.

    Args:
        x: array-like input of any shape; the function is element-wise.
        alpha: scale of the negative (exponential) branch.
        lambda_: overall output scale (trailing underscore because
            ``lambda`` is a Python keyword).

    Returns:
        An array with the same shape as ``x`` (floating point, since the
        exponential branch is floating point).
    """
    # jnp.where evaluates BOTH branches. exp() of a large positive x overflows
    # to inf; that value is masked out in the forward pass but becomes
    # 0 * inf = NaN under jax.grad. Only feed the exponential branch inputs
    # from the region where that branch is actually selected.
    x_nonpos = jnp.where(x > 0, 0.0, x)
    # expm1(v) == exp(v) - 1, but without cancellation error for v near 0.
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.expm1(x_nonpos))

In [5]:
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

2022-09-11 21:50:00.377117: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


2.38 ms ± 209 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


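We call `block_until_ready()` because JAX dispatches work asynchronously. A JAX call returns as soon as the computation is queued, so without it `%timeit` would measure only the dispatch, not the actual computation. See https://jax.readthedocs.io/en/latest/async_dispatch.html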

In [6]:
@jax.jit
def selu_jit(x, alpha=1.67, lambda_=1.05):
    """SELU, compiled with the ``@jax.jit`` decorator.

    Same math as the un-jitted ``selu``:
        lambda_ * x                       for x > 0
        lambda_ * alpha * (exp(x) - 1)    for x <= 0

    The function is traced and compiled by XLA on its first call for a given
    input shape/dtype; later calls with the same shape/dtype reuse the
    compiled version.

    Args:
        x: array-like input of any shape; the function is element-wise.
        alpha: scale of the negative (exponential) branch.
        lambda_: overall output scale.

    Returns:
        An array with the same shape as ``x``.
    """
    # Keep exp() away from large positive inputs: jnp.where evaluates both
    # branches, and an overflowing exp in the unselected branch turns into
    # NaN gradients under jax.grad.
    x_nonpos = jnp.where(x > 0, 0.0, x)
    # expm1 is the numerically accurate form of exp(v) - 1 near v = 0.
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.expm1(x_nonpos))

In [7]:
y = jnp.arange(1000000)
%timeit selu_jit(y).block_until_ready()

488 µs ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [8]:
# Equivalent alternative to the @jax.jit decorator: call jax.jit directly on an
# already-defined function. This returns a new jitted callable and leaves the
# original `selu` un-jitted (it is still used eagerly in the baseline timing).
# NOTE(review): despite the "new_way" name, this is not a newer API; both forms
# produce the same kind of jitted function.

selu_jit_new_way = jax.jit(selu)

In [9]:
z = jnp.arange(1000000)

%timeit selu_jit_new_way(z).block_until_ready()

496 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


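You can see the performance gain from JIT compilation here. Both jitted versions run roughly 5× faster than the plain function (about 0.49 ms vs 2.38 ms per call, measured with the GPU hidden). The decorator form and the `jax.jit(selu)` form perform the same within noise.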