# <center> JAX </center>

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit, vmap
from jax import random

In [3]:
# Seed an explicit PRNG key, then draw ten standard-normal samples from it.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [11]:
# Benchmark a (size x size) float32 matrix product on JAX's default device.
# block_until_ready() makes the timing wait for the finished result, since
# JAX dispatches work asynchronously.
size = 4096
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU only if JAX detected one - TODO confirm backend

1.16 s ± 76.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%timeit x@x.T

355 ms ± 5.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%timeit jnp.matmul(x, x.T)

356 ms ± 7.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

368 ms ± 6.14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
x = np.random.normal(size=(size, size)).astype(np.float32)
x = jax.device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

346 ms ± 7.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## JIT

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Returns ``lmbda * x`` where ``x > 0`` and
  ``lmbda * (alpha * exp(x) - alpha)`` everywhere else.

  Args:
    x: Input scalar or array.
    alpha: Scale of the exponential branch.
    lmbda: Overall output scale.

  Returns:
    The activated values, with the same shape as ``x``.
  """
  positive = x > 0
  # Feed exp() a harmless 0 wherever the linear branch is selected. jnp.where
  # evaluates both branches, so exp(x) overflowing to inf for large positive x
  # would otherwise turn the gradient into NaN (0 * inf) even though the
  # forward value is unaffected.
  safe_x = jnp.where(positive, 0.0, x)
  return lmbda * jnp.where(positive, x, alpha * jnp.exp(safe_x) - alpha)

# One million standard-normal samples as input for timing the un-jitted selu.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.26 ms ± 91.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Wrap selu with jit and time the compiled version on the same input x.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

428 µs ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## GRAD

In [None]:
def sum_logistic(x):
  """Return the sum of the logistic function 1 / (1 + exp(-x)) over ``x``."""
  decay = jnp.exp(-x)
  logistic = 1.0 / (1.0 + decay)
  return jnp.sum(logistic)

# Build the gradient function first, then evaluate it at the points [0., 1., 2.].
derivative_fn = grad(sum_logistic)
x_small = jnp.arange(3.)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [None]:
# grad and jit compose freely: three nested grad calls give the third
# derivative of sum_logistic, evaluated here at x = 1.0.
print(
    grad(jit(grad(jit(grad(sum_logistic)))))(1.0)
)

-0.035325598


## VMAP
`vmap` vectorizes a function over a batch axis: it applies the function to every element of a batch of inputs without writing an explicit Python `for` loop.

In [None]:
# Draw the matrix and the batch from independent subkeys. Passing the same key
# to two random.normal calls makes the two draws statistically dependent,
# which matters here because both arrays are used together below.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
  """Return the product of the module-level matrix ``mat`` (150 x 100) with ``v``."""
  return jnp.dot(mat, v)

In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Apply ``apply_matrix`` to every entry along the leading axis of ``v_batched``."""
  batched_apply = vmap(apply_matrix)
  return batched_apply(v_batched)

print('Auto-vectorized with vmap')
# Time the jitted, vmapped function over the whole batch; block_until_ready()
# makes the measurement wait for the finished result.
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
21.8 µs ± 192 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
# Keep the batched result so its rows can be compared with apply_matrix.
out = vmap_batched_apply_matrix(batched_x)

In [None]:
# Spot-check: the first row of the batched output should match applying the
# un-batched function to the first input vector (only row 0 is compared).
jnp.allclose(
    apply_matrix(batched_x[0]),
    out[0],
    atol=1e-5,
)

DeviceArray(True, dtype=bool)

In [3]:
class A:
    """Root of the demo hierarchy; reports its construction and method calls."""

    def __init__(self):
        print('Initializing: class A')

    def sub_method(self, b):
        """Print ``b`` tagged with this class's name."""
        print('Printing from class A:', b)


class B(A):
    """Subclass of ``A`` that logs first and then delegates via ``super()``."""

    def __init__(self):
        # Announce before chaining up, so this message precedes the parent's.
        print('Initializing: class B')
        super().__init__()

    def sub_method(self, b):
        """Print ``b``, then forward ``b + 1`` to the parent implementation."""
        print('Printing from class B:', b)
        forwarded = b + 1
        super().sub_method(forwarded)


class C(B):
    """Subclass of ``B`` that initializes its parents before announcing itself."""

    def __init__(self):
        # Chain up first, so the parents' messages precede this one.
        super().__init__()
        print('Initializing: class C')

    def sub_method(self, b):
        """Print ``b``, then forward ``b + 1`` to the parent implementation."""
        print('Printing from class C:', b)
        forwarded = b + 1
        super().sub_method(forwarded)

In [4]:
# Prints B's, then A's, then C's message: C.__init__ chains to super() before
# printing, whereas B.__init__ prints before chaining to A.
c = C()

Initializing: class B
Initializing: class A
Initializing: class C
