In [1]:
import os
import jax 
import numpy as np
from jax import random
import jax.numpy as jnp
from jax import make_jaxpr
from functools import partial
from jax import grad, jit, vmap
import matplotlib.pyplot as plt
from jax.config import config
from scipy.integrate import odeint
# Enable 64-bit precision (array outputs in this notebook show dtype=int64 / float64).
config.update("jax_enable_x64", True)
# Pin the default platform to CPU. NOTE(review): to run on an accelerator,
# presumably comment this line out or pass "gpu" instead - confirm.
config.update("jax_platform_name", "cpu")
# Ask XLA to expose 8 host (CPU) devices; the jax.devices() cell further down lists 8 CpuDevice entries.
# NOTE(review): this flag is presumably only honoured if it is set before JAX
# initialises its backend (safest: before `import jax`) - confirm.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' 

In [2]:
# Switch for the Colab TPU runtime; the setup below only runs when this is True.
USE_TPU = False
if USE_TPU:
    # Imported lazily: the helper is only needed when the TPU path is taken.
    from jax.tools import colab_tpu
    colab_tpu.setup_tpu()

In [3]:
# Notebook-wide matplotlib defaults: label/tick font sizes and no horizontal axis margin.
plt.rcParams.update({
    'axes.labelsize': 16,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'axes.xmargin': 0,
})


In [4]:
def whereis(x):
    """Print the device(s) that hold the array ``x``.

    Uses ``x.devices()`` when the array provides it, and falls back to the
    legacy ``x.device_buffer.device()`` accessor otherwise, so the helper
    works with both newer and older array objects.

    Args:
        x: An array-like object exposing either ``devices()`` or
            ``device_buffer.device()``.

    Prints one device per line; for an array living on a single device the
    printed text is the same single line the legacy accessor produced.
    """
    if hasattr(x, "devices"):
        # devices() yields an unordered collection; sort by device id so the
        # printed order is stable for arrays spread over several devices.
        placed_on = sorted(x.devices(), key=lambda dev: getattr(dev, "id", 0))
    else:
        placed_on = [x.device_buffer.device()]
    print(*placed_on, sep="\n")

In [5]:
# Create a small JAX array and print the device it was placed on.
x = jnp.array([1, 2, 3])
whereis(x)

TFRT_CPU_0


### [VMAP](https://jiayiwu.me/blog/2021/04/05/learning-about-jax-axes-in-vmap.html)

- [YouTube](https://www.youtube.com/watch?v=W1vfBDFLm7Q)

In [6]:
# Two 2x2 integer matrices used by the vmap axis examples that follow.
a = jnp.array([[1, 3],
               [2, 4]])
b = jnp.array([[11, 3],
               [12, 40]])
a, b

(Array([[1, 3],
        [2, 4]], dtype=int64),
 Array([[11,  3],
        [12, 40]], dtype=int64))

In [7]:
# Plain element-wise addition, as a baseline for the vmap variants below.
jnp.add(a, b)

Array([[12,  6],
       [14, 44]], dtype=int64)

In [8]:
vmap(jnp.add, in_axes=(0,0), out_axes=0)(a, b)
# in_axes=(0, 0): row i of a + row i of b
# out_axes=0: stack the per-row results as rows
# same as jnp.add(a, b)

Array([[12,  6],
       [14, 44]], dtype=int64)

In [9]:
vmap(jnp.add, in_axes=(0,0), out_axes=1)(a, b)
# in_axes=(0, 0): row i of a + row i of b
# out_axes=1: stack the per-row results as columns
# same as jnp.add(a, b).T

Array([[12, 14],
       [ 6, 44]], dtype=int64)

In [10]:
vmap(jnp.add, in_axes=(0,1), out_axes=0)(a, b)
# in_axes=(0, 1): row i of a + column i of b
# out_axes=0: stack the results as rows
# same as jnp.add(a, b.T)

Array([[12, 15],
       [ 5, 44]], dtype=int64)

In [11]:
vmap(jnp.add, in_axes=(0,1), out_axes=1)(a, b)
# in_axes=(0, 1): row i of a + column i of b
# out_axes=1: stack the results as columns
# same as jnp.add(a, b.T).T, i.e. jnp.add(a.T, b)

Array([[12,  5],
       [15, 44]], dtype=int64)

In [12]:
vmap(jnp.add, in_axes=(1,0), out_axes=0)(a, b)
# in_axes=(1, 0): column i of a + row i of b
# out_axes=0: stack the results as rows
# same as jnp.add(a.T, b)

Array([[12,  5],
       [15, 44]], dtype=int64)

In [13]:
# Inputs for the in_axes=None example: two 2x3 integer grids and a scalar.
a = np.arange(10, 16).reshape(2, 3)
b = 2
c = np.arange(20, 26).reshape(2, 3)

def f(a, b, c):
    """Return the sum of the three arguments, added left to right."""
    partial_sum = a + b
    return partial_sum + c

# Show the three inputs, one per line.
print('\n'.join(str(arg) for arg in (a, b, c)))

# Batch f over axis 0 of a and c; b is not mapped (in_axes=None) and is passed
# unchanged to every call. The batched function is then wrapped in jit.
v_f = jit(vmap(f, in_axes=(0, None, 0), out_axes=0))

# Display the jaxpr traced for these inputs.
make_jaxpr(v_f)(a, b, c)

[[10 11 12]
 [13 14 15]]
2
[[20 21 22]
 [23 24 25]]


{ [34m[22m[1mlambda [39m[22m[22m; a[35m:i64[2,3][39m b[35m:i64[][39m c[35m:i64[2,3][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:i64[2,3][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; e[35m:i64[2,3][39m f[35m:i64[][39m g[35m:i64[2,3][39m. [34m[22m[1mlet
          [39m[22m[22mh[35m:i64[][39m = convert_element_type[new_dtype=int64 weak_type=False] f
          i[35m:i64[2,3][39m = add e h
          j[35m:i64[2,3][39m = add i g
        [34m[22m[1min [39m[22m[22m(j,) }
      name=f
    ] a b c
  [34m[22m[1min [39m[22m[22m(d,) }

In [14]:
# List the devices visible to JAX (eight CpuDevice entries in the output below).
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [15]:
from jax import pmap 

# Parallel map of f: axis 0 of a and c is split across devices (one slice per
# device), while b (in_axes=None) is passed unchanged to every device.
p_f = pmap(f, in_axes=(0, None, 0))

# pmap's mapped axis cannot be larger than the number of available devices,
# so size the leading axis from the runtime instead of hard-coding 8.
n_devices = jax.local_device_count()
a = np.random.random(size=(n_devices, 3))
b = 4
c = np.random.random(size=(n_devices, 3))

p_f(a, b, c)

Array([[5.27834241, 5.42649338, 4.53975413],
       [5.12853116, 5.2062678 , 4.79251556],
       [5.04806234, 5.3456616 , 4.34332798],
       [5.22036097, 4.24449012, 5.0497359 ],
       [4.89529815, 5.02340113, 5.30752538],
       [5.11011971, 4.65142249, 4.76652823],
       [4.7130902 , 5.70628288, 4.85989081],
       [4.77093606, 5.79740456, 5.79388623]], dtype=float64)

- composition

In [16]:
# Note: this rebinds the name `f`, which earlier in the notebook was the three-way sum.
def f(a, b, c):
    """Return the matrix product of ``a`` and ``b`` with ``c`` added to it."""
    product = a @ b
    return product + c

# Un-batched call: (2, 3) @ (3, 4) + (2, 4).
a = np.random.rand(2, 3)
b = np.random.rand(3, 4)
c = np.random.rand(2, 4)

f(a, b, c).shape

(2, 4)

In [17]:
# Batch f over the leading axis of every argument (vmap's default axes, spelled out).
v_f = vmap(f, in_axes=0, out_axes=0)

# Same shapes as the un-batched call, with an extra leading batch axis of size 5.
a = np.random.rand(5, 2, 3)
b = np.random.rand(5, 3, 4)
c = np.random.rand(5, 2, 4)

v_f(a, b, c).shape

(5, 2, 4)

In [18]:
# Display the jaxpr of the vmapped function; the output below shows a single
# dot_general with batch dimension 0 followed by an add.
make_jaxpr(v_f)(a, b, c)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[5,2,3][39m b[35m:f64[5,3,4][39m c[35m:f64[5,2,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f64[5,2,4][39m = dot_general[
      dimension_numbers=(((2,), (1,)), ((0,), (0,)))
      precision=None
      preferred_element_type=None
    ] a b
    e[35m:f64[5,2,4][39m = add d c
  [34m[22m[1min [39m[22m[22m(e,) }

In [19]:
# Compose the transformations: batch f with vmap, then wrap the result in jit.
jv_f = jit(vmap(f))

a = np.random.random(size=(5, 2, 3))
b = np.random.random(size=(5, 3, 4))
c = np.random.random(size=(5, 2, 4))

# Call the jitted function itself (not the un-jitted v_f) so that this cell
# actually runs jv_f before its jaxpr is displayed.
print(jv_f(a, b, c).shape)
make_jaxpr(jv_f)(a, b, c)

(5, 2, 4)


{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[5,2,3][39m b[35m:f64[5,3,4][39m c[35m:f64[5,2,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f64[5,2,4][39m = xla_call[
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; e[35m:f64[5,2,3][39m f[35m:f64[5,3,4][39m g[35m:f64[5,2,4][39m. [34m[22m[1mlet
          [39m[22m[22mh[35m:f64[5,2,4][39m = dot_general[
            dimension_numbers=(((2,), (1,)), ((0,), (0,)))
            precision=None
            preferred_element_type=None
          ] e f
          i[35m:f64[5,2,4][39m = add h g
        [34m[22m[1min [39m[22m[22m(i,) }
      name=f
    ] a b c
  [34m[22m[1min [39m[22m[22m(d,) }

In [20]:
# Compose the transformations: pmap splits the leading axis across devices and
# each device runs the vmapped function on its (5, ...) slice.
pv_f = pmap(v_f)

# pmap's mapped axis cannot be larger than the number of available devices,
# so size the leading axis from the runtime instead of hard-coding 8.
n_devices = jax.local_device_count()
a = np.random.random(size=(n_devices, 5, 2, 3))
b = np.random.random(size=(n_devices, 5, 3, 4))
c = np.random.random(size=(n_devices, 5, 2, 4))

# Call the pmapped function itself (not v_f) so that this cell actually runs
# the parallel version before its jaxpr is displayed.
print(pv_f(a, b, c).shape)
make_jaxpr(pv_f)(a, b, c)

(8, 5, 2, 4)


{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f64[8,5,2,3][39m b[35m:f64[8,5,3,4][39m c[35m:f64[8,5,2,4][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f64[8,5,2,4][39m = xla_pmap[
      axis_name=<axis 0x7fefb028adc0>
      axis_size=8
      backend=None
      call_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; e[35m:f64[5,2,3][39m f[35m:f64[5,3,4][39m g[35m:f64[5,2,4][39m. [34m[22m[1mlet
          [39m[22m[22mh[35m:f64[5,2,4][39m = dot_general[
            dimension_numbers=(((2,), (1,)), ((0,), (0,)))
            precision=None
            preferred_element_type=None
          ] e f
          i[35m:f64[5,2,4][39m = add h g
        [34m[22m[1min [39m[22m[22m(i,) }
      devices=None
      donated_invars=(False, False, False)
      global_arg_shapes=(None, None, None)
      global_axis_size=None
      in_axes=(0, 0, 0)
      name=f
      out_axes=(0,)
    ] a b c
  [34m[22m[1min [39m[22m[22m(d,) }

In [21]:
# JAX random functions are deterministic in the key: drawing twice with the
# same key prints the same three numbers twice.
key = jax.random.PRNGKey(0)
print(jax.random.uniform(key, shape=(3,)))
print(jax.random.uniform(key, shape=(3,)))

[0.57450053 0.09968609 0.74196595]
[0.57450053 0.09968609 0.74196595]


- RANDOM

In [22]:
# To get a different number on every draw, split the key each time: keep the
# first half as the carried key and sample with the second half (subkey).
key = jax.random.PRNGKey(124)
for _ in range(5):
    # Rebinding `key` here is what makes the next iteration's split differ.
    key, subkey = jax.random.split(key)
    print(jax.random.uniform(subkey))


0.6773641507104642
0.0482590299181076
0.42465834799117474
0.0919238766178283
0.7752573371832019
