**S01P01_quickstart.ipynb**

Arz

2024 APR 03 (WED)

reference:
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

# multiplying matrices

In [2]:
# Create a PRNG key from seed 0; JAX random functions take the key explicitly.
key = jax.random.key(0)
# Draw a length-5 vector of normally distributed samples from that key.
x = jax.random.normal(key, (5,))

In [3]:
# Show the sampled length-5 vector.
print(x)

[ 0.18784384 -1.2833426  -0.2710917   1.2490593   0.24447003]


In [4]:
# Overwrite x with a 3x2 sample drawn from the same key as before.
x = jax.random.normal(key, (3, 2))

In [5]:
# Show the 3x2 sample.
print(x)

[[ 0.18784384 -1.2833426 ]
 [ 0.6494181   1.2490593 ]
 [ 0.24447003 -0.11744965]]


## test on JAX array ##

In [6]:
size = 3000
# Square float32 matrix generated by JAX.
x = jax.random.normal(key, (size, size), dtype=jnp.float32)

# block_until_ready() waits for the result, because JAX dispatches work
# asynchronously; without it only the dispatch would be timed.
%timeit jnp.dot(x, x.T).block_until_ready()

13 ms ± 152 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## test on NumPy array ##

In [7]:
import numpy as np

In [8]:
size = 3000
# Same shape and dtype as the JAX benchmark, but x is a NumPy array here.
x = np.random.normal(size=(size, size)).astype(np.float32)

# jnp.dot is handed the NumPy array directly on every timed call.
%timeit jnp.dot(x, x.T).block_until_ready()

46 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


This is slower than with a JAX array because the NumPy data has to be transferred to the GPU on every call.

To avoid the repeated transfer, one can move the data to the device once, up front, with JAX's `device_put`.

In [9]:
from jax import device_put

In [10]:
size = 3000
x = np.random.normal(size=(size, size)).astype(np.float32)
# Hand the NumPy data to JAX once, before timing, so the timed call below
# operates on the value returned by device_put rather than on the NumPy array.
x = device_put(x)

%timeit jnp.dot(x, x.T).block_until_ready()

13.9 ms ± 285 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# program transformations

## jit()

In [11]:
def selu(x, alpha=1.07, lamda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lamda * x`` where ``x > 0`` and
    ``lamda * (alpha*exp(x) - alpha)`` elsewhere.

    Args:
        x: input array.
        alpha: scale of the exponential branch used for non-positive inputs.
        lamda: overall output scale.

    NOTE(review): the usual SELU constant is alpha ~= 1.67; 1.07 here may be
    a typo -- confirm before changing the default.
    """
    exponential_branch = alpha*jnp.exp(x) - alpha
    is_positive = x > 0
    return lamda*jnp.where(is_positive, x, exponential_branch)

In [12]:
# One million samples: the input for the selu timings below.
x = jax.random.normal(key, (1000000,))

### without jit()

In [13]:
# Baseline timing: selu called directly, without jit.
%timeit selu(x).block_until_ready()

1.6 ms ± 77.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### with jit()

In [14]:
selu_jit = jit(selu)

%timeit selu_jit(x).block_until_ready()

363 µs ± 60.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## grad()

In [15]:
def sum_logistic(x):
    """Return the sum over all elements of the logistic function 1/(1 + exp(-x))."""
    logistic = 1.0/(1.0 + jnp.exp(-x))
    return jnp.sum(logistic)

In [16]:
def sum_sin(x):
    """Return the sum over all elements of sin(x)."""
    sines = jnp.sin(x)
    return jnp.sum(sines)

In [17]:
# Select the scalar-valued function that the differentiation cells below operate on.
function = sum_sin

In [18]:
# arange with stop=pi (a float) yields the four points [0., 1., 2., 3.].
x = jnp.arange(jnp.pi)

### using grad()

In [19]:
# grad() returns a new function that evaluates the gradient of `function`.
derivative_function = grad(function)

In [20]:
# With function = sum_sin the gradient at x is cos(x), elementwise.
print(derivative_function(x))

[ 1.         0.5403023 -0.4161468 -0.9899925]


### using manual finite difference

In [21]:
# The rows of this identity matrix are the unit directions along which the
# finite-difference helper below perturbs x, one coordinate at a time.
print(jnp.eye(len(x)))

[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]


In [22]:
def finite_difference_central_ord1(f, x, h=1E-3):
    """Estimate the gradient of a scalar function with central differences.

    For each coordinate i the estimate is
    ``(f(x + h*e_i) - f(x - h*e_i)) / (2*h)``, where ``e_i`` is the i-th row
    of the identity matrix of size ``len(x)``.

    Args:
        f: function taking an array shaped like ``x`` and returning a scalar.
        x: point at which to estimate the gradient (intended for 1-D arrays;
            ``len(x)`` determines the number of perturbation directions).
        h: perturbation step size. Defaults to 1E-3, the value that used to
            be hard-coded, so existing two-argument calls behave as before.

    Returns:
        Array holding one first-derivative estimate per element of ``x``.
    """
    unit_directions = jnp.eye(len(x))
    return jnp.array([(f(x + h*v) - f(x - h*v))/(2*h)
                      for v in unit_directions])

In [23]:
# Numerical estimate of the same gradient, for comparison with the grad() result.
print(finite_difference_central_ord1(function, x))

[ 1.0000466   0.54031605 -0.4161596  -0.9899139 ]


### composing grad() and jit()

In [24]:
# Second derivative at pi: grad and jit compose freely.
y = grad(jit(grad(function)))(jnp.pi)

print(y)  # case for sum_sin: sin -> cos -> -sin, so we get -sin(pi) == 0 (up to floating-point rounding)

8.742278e-08


In [25]:
# Third derivative at pi, again mixing grad and jit.
y = grad(jit(grad(jit(grad(function)))))(jnp.pi)

print(y)  # case for sum_sin: sin -> cos -> -sin -> -cos, so we get -cos(pi) == 1

1.0


### advanced

- jax.vjp(): reverse-mode vector-Jacobian products
- jax.jvp(): forward-mode Jacobian-vector products

In [26]:
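# Rejected attempt, kept for reference. jax.vjp and jax.jvp are not
# function-to-function transforms like grad(): they evaluate f at given inputs
# (jax.vjp(f, *primals), jax.jvp(f, primals, tangents)), so they cannot be
# chained like this. The hessian() defined further down uses jacfwd/jacrev.
# def hessian_(f):
#     return jit(jax.jvp(jax.vjp(f)))  # wrong: vjp/jvp need input values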

In [27]:
from jax import jacfwd, jacrev

In [28]:
# Forward-mode (jacfwd) and reverse-mode (jacrev) Jacobians of the same
# scalar-valued function; both print the same values.
print(jit(jacfwd(function))(x))
print(jit(jacrev(function))(x))

[ 1.         0.5403023 -0.4161468 -0.9899925]
[ 1.         0.5403023 -0.4161468 -0.9899925]


In [29]:
def hessian(f):
    """Return a jit-compiled function that computes the Hessian of ``f``.

    The Hessian is built as the forward-mode Jacobian (jacfwd) of the
    reverse-mode Jacobian (jacrev) of ``f``.
    """
    first_derivative = jacrev(f)
    second_derivative = jacfwd(first_derivative)
    return jit(second_derivative)

In [30]:
# Hessian of `function` at x. With function = sum_sin it is diagonal,
# with entries -sin(x_i).
y = hessian(function)(x)

print(y)

[[-0.         -0.         -0.         -0.        ]
 [-0.         -0.841471   -0.         -0.        ]
 [-0.         -0.         -0.90929747 -0.        ]
 [-0.         -0.         -0.         -0.14112   ]]


## vmap()

In [31]:
# A small batch: three row vectors of length 2.
v_batch = jax.random.normal(key, (3, 2))
print(v_batch)

# Iterating over the 2-D array yields its rows one at a time.
for v in v_batch:
    print(v)

[[ 0.18784384 -1.2833426 ]
 [ 0.6494181   1.2490593 ]
 [ 0.24447003 -0.11744965]]
[ 0.18784384 -1.2833426 ]
[0.6494181 1.2490593]
[ 0.24447003 -0.11744965]


In [32]:
def left_A_multiply(v):
    """Return the product of the global matrix ``A`` with ``v`` (A on the left).

    ``A`` is looked up in the enclosing (notebook) namespace at call time, so
    it must be defined before this function is called.
    """
    product = jnp.dot(A, v)
    return product

In [33]:
# The 150x100 matrix that left_A_multiply reads as a global.
A = jax.random.normal(key, (150, 100))

In [34]:
# A was sampled from `key` itself. JAX keys should not be reused for separate
# draws, so derive a distinct key for the batch (10 vectors of length 100).
v_batch_key = jax.random.fold_in(key, 1)
v_batch = jax.random.normal(v_batch_key, (10, 100))

### naive multiplication

for loop

In [35]:
def naive_batch_left_A_multiply(v_batch):
    """Apply left_A_multiply to each row of ``v_batch`` in a Python loop.

    The per-row results are stacked along a new leading axis, so row i of
    the output is ``left_A_multiply(v_batch[i])``.
    """
    row_products = []
    for v in v_batch:
        row_products.append(left_A_multiply(v))
    return jnp.stack(row_products)

In [36]:
# Timing of the Python-loop version.
%timeit naive_batch_left_A_multiply(v_batch).block_until_ready()

4.01 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [37]:
@jit
def naive_batch_left_A_multiply(v_batch):
    """Jit-compiled Python-loop batching of left_A_multiply.

    Same body logic as the un-jitted definition of this name earlier in the
    notebook, which this definition replaces: each row of ``v_batch`` is
    multiplied separately and the results are stacked.
    """
    row_products = list(map(left_A_multiply, v_batch))
    return jnp.stack(row_products)

In [38]:
%timeit naive_batch_left_A_multiply(v_batch).block_until_ready()

314 µs ± 29 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### multiplication using linear algebra

matrix multiplication

In [39]:
def mathy_batch_left_A_multiply(v_batch):
    """Multiply every row of ``v_batch`` by the global ``A`` in one matrix product.

    Computes ``v_batch . A^T``, so row i of the result is ``A . v_batch[i]``.
    """
    A_transposed = A.T
    return jnp.dot(v_batch, A_transposed)

In [40]:
# Timing of the single-matrix-product version.
%timeit mathy_batch_left_A_multiply(v_batch).block_until_ready()

490 µs ± 90.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [41]:
@jit
def mathy_batch_left_A_multiply(v_batch):
    """Jit-compiled batched product ``v_batch . A^T`` using the global ``A``.

    Replaces the un-jitted definition of the same name earlier in the
    notebook; row i of the result is ``A . v_batch[i]``.
    """
    return jnp.dot(v_batch, jnp.transpose(A))

In [42]:
%timeit mathy_batch_left_A_multiply(v_batch).block_until_ready()

96.6 µs ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### multiplication using vectorizing map

Let JAX vectorize the single-vector function over the batch automatically with `vmap`.

In [43]:
def vmap_batch_left_A_multiply(v_batch):
    """Apply left_A_multiply to every row of ``v_batch`` via ``vmap``.

    ``vmap`` lifts the single-vector function to one that maps over the
    leading axis of its argument, replacing an explicit Python loop.
    """
    batched_multiply = vmap(left_A_multiply)
    return batched_multiply(v_batch)

In [44]:
# Timing of the vmap version.
%timeit vmap_batch_left_A_multiply(v_batch).block_until_ready()

976 µs ± 412 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [45]:
@jit
def vmap_batch_left_A_multiply(v_batch):
    """Jit-compiled ``vmap`` batching of left_A_multiply.

    Replaces the un-jitted definition of the same name earlier in the
    notebook; maps left_A_multiply over the leading axis of ``v_batch``.
    """
    per_row_multiply = vmap(left_A_multiply, in_axes=0)
    return per_row_multiply(v_batch)

In [46]:
%timeit vmap_batch_left_A_multiply(v_batch).block_until_ready()

195 µs ± 35.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
