# How to use JAX tools to accelerate your program

JAX provides many powerful tools that help users optimize their programs: parallelization, JIT compilation, multi-threading on CPU, and GPU/TPU acceleration. Below I introduce some simple cases to show how to use them.

## # JIT: speed up matrix operations

In [1]:
```python
import jax.numpy as jnp
from jax import random, jit
```

In [2]:
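```python
key_x, key_w = random.split(random.PRNGKey(42))  # independent keys for X and W

X = random.normal(key_x, (1000000, 200))
W = random.normal(key_w, (200, 100))

def matmul_sum(x, w):  # named so that the built-in `sum` is not shadowed
    return jnp.sum(x @ w, axis=1)
```

In [3]:

```python
import time

# note: JAX dispatches work asynchronously, so these wall-clock times may
# partly measure dispatch rather than the full computation
s = time.time()
res = matmul_sum(X, W)
print(f'time without jit: {time.time() - s} s')

matmul_sum = jit(matmul_sum)  # JIT

s = time.time()
res_jit_1 = matmul_sum(X, W)  # first call: traces and compiles
print(f'time with jit No.1: {time.time() - s} s')

s = time.time()
res_jit_2 = matmul_sum(X, W)  # later calls reuse the compiled program
print(f'time with jit No.2: {time.time() - s} s')

s = time.time()
res_jit_3 = matmul_sum(X, W)
print(f'time with jit No.3: {time.time() - s} s')

# compare with a tolerance: compiled and eager float32 results may differ slightly
assert all(jnp.allclose(res, r, atol=1e-4) for r in (res_jit_1, res_jit_2, res_jit_3))
```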

```
time without jit: 0.08606386184692383 s
time with jit No.1: 0.12407374382019043 s
time with jit No.2: 0.00022339820861816406 s
time with jit No.3: 0.00013685226440429688 s
```
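JAX dispatches operations asynchronously: a call can return before the device has finished computing, so wrapping it in `time.time()` may mostly measure dispatch, which is a likely reason the later jitted calls above look so fast. A fairer measurement waits for each result with `block_until_ready()` and excludes the first (compiling) call. The sketch below is one way to do that, assuming a smaller input so it runs quickly; `bench`, `X_small` and `W_small` are illustrative names, not part of JAX.

```python
import time

import jax
import jax.numpy as jnp
from jax import random


def matmul_sum(x, w):
    return jnp.sum(x @ w, axis=1)


def bench(fn, *args, repeats=5):
    """Average wall-clock seconds per call, waiting for the result each time."""
    fn(*args).block_until_ready()  # warm-up call: compiles `fn` if it is jitted
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait until the computation has finished
    return (time.perf_counter() - start) / repeats


key_x, key_w = random.split(random.PRNGKey(0))
X_small = random.normal(key_x, (100_000, 200))  # smaller than the notebook's X
W_small = random.normal(key_w, (200, 100))

eager_time = bench(matmul_sum, X_small, W_small)
jit_time = bench(jax.jit(matmul_sum), X_small, W_small)
print(f'eager: {eager_time:.6f} s, jit: {jit_time:.6f} s')
```

How large the gap is depends on the backend; a function this small may gain little from `jit`, since eager mode already runs each primitive as its own compiled kernel.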


## # JIT: keep static arguments to a small set of values

JIT works best when each static parameter of a function takes its values from a small, finite set. JAX compiles the function once for each distinct static value (and each input shape and dtype) and caches the result, so a later call with a value it has already seen reuses the cached program instead of recompiling. For example, take a function `mult(X: jnp.ndarray, Y: jnp.ndarray, name: str)` where `name` is a static parameter, and call it in the following order:

| `name` | Nth call with this `name` |
|  :----:  |  :----:             |
|1|1|
|1|2|
|2|1|
|2|2|
|1|3|
|3|1|
|3|2|
|2|3|
|3|3|

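Let `(name, n)` be the run time of the nth call with a given `name`, and let `>>` mean "much slower than". There are two possible outcomes:

- condition 1: JAX recompiles whenever a previously seen `name` comes back
    - (1, 1) ≈ (1, 3) >> (1, 2)
    - (2, 1) ≈ (2, 3) >> (2, 2)
    - (3, 1) ≈ (3, 3) >> (3, 2)
- condition 2: JAX does not recompile, it reuses the program cached for that `name`
    - (1, 1) >> (1, 3) ≈ (1, 2)
    - (2, 1) >> (2, 3) ≈ (2, 2)
    - (3, 1) >> (3, 3) ≈ (3, 2)

The timings printed below match condition 2.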

In [4]:
```python
def mult(X, Y, name='1'):
    if name == '1':
        return 1 * X @ Y
    elif name == '2':
        return 2 * X @ Y
    elif name == '3':
        return 3 * X @ Y
    else:
        raise ValueError('please input one of "1", "2", "3"')
```

In [5]:
```python
import time

def shiyan2(name, nth_call):
    # shiyan ("experiment"): time one call of `mult` with the given static `name`
    s = time.time()
    _ = mult(X, W, f'{name}')
    print(f'({name}, {nth_call}): {time.time() - s} s')


# note: `s` is set only once here, so the three "without jit" times are cumulative
s = time.time()
res1 = mult(X, W, '1')
print(f'1: time without jit: {time.time() - s} s')
res2 = mult(X, W, '2')
print(f'2: time without jit: {time.time() - s} s')
res3 = mult(X, W, '3')
print(f'3: time without jit: {time.time() - s} s')

mult = jit(mult, static_argnames='name')  # JIT, with `name` as a static argument

shiyan2(1, 1)
shiyan2(1, 2)
shiyan2(2, 1)
shiyan2(2, 2)
shiyan2(1, 3)
shiyan2(3, 1)
shiyan2(3, 2)
shiyan2(2, 3)
shiyan2(3, 3)
```

```
1: time without jit: 0.012926101684570312 s
2: time without jit: 0.013566732406616211 s
3: time without jit: 0.013772249221801758 s
(1, 1): 0.012732982635498047 s
(1, 2): 2.8848648071289062e-05 s
(2, 1): 0.02155590057373047 s
(2, 2): 2.574920654296875e-05 s
(1, 3): 1.7642974853515625e-05 s
(3, 1): 0.018445253372192383 s
(3, 2): 3.409385681152344e-05 s
(2, 3): 1.1920928955078125e-05 s
(3, 3): 6.198883056640625e-06 s
```
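The caching behaviour can also be checked directly, without relying on timings, by counting how often JAX traces the function: a Python side effect such as appending to a list runs only while JAX traces, not when it replays a cached compilation. In this sketch `scaled_matmul` and `traces` are illustrative names. It also shows that a new input shape triggers a new trace, since shapes are part of the cache key as well.

```python
import jax
import jax.numpy as jnp

traces = []  # records every time JAX traces (and therefore compiles) the function


def scaled_matmul(x, y, name='1'):
    traces.append(name)  # Python side effect: runs only during tracing
    factor = {'1': 1.0, '2': 2.0, '3': 3.0}[name]
    return factor * x @ y


scaled_matmul_jit = jax.jit(scaled_matmul, static_argnames='name')

x = jnp.ones((4, 3))
y = jnp.ones((3, 2))

# same call order as in the table above
for name in ['1', '1', '2', '2', '1', '3', '3', '2', '3']:
    scaled_matmul_jit(x, y, name=name)
print(traces)  # ['1', '2', '3']: one trace per distinct static value

# a new input shape is a new cache key, so name '1' is traced again
scaled_matmul_jit(jnp.ones((5, 3)), y, name='1')
print(traces)  # ['1', '2', '3', '1']
```

The flip side is that a static argument with many distinct values (for example a float hyperparameter that changes every step) would be recompiled on almost every call.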


In [6]:
```python
def func1():

    # a new func2 object is created every time func1 is called
    def func2():
        pass

    return func2

for _ in range(5):
    print(f'f1: {func1}, f2: {func1()}')
```

```
f1: <function func1 at 0x70a8500979c0>, f2: <function func1.<locals>.func2 at 0x70a850097d80>
f1: <function func1 at 0x70a8500979c0>, f2: <function func1.<locals>.func2 at 0x70a850097e20>
f1: <function func1 at 0x70a8500979c0>, f2: <function func1.<locals>.func2 at 0x70a850097ce0>
f1: <function func1 at 0x70a8500979c0>, f2: <function func1.<locals>.func2 at 0x70a850097ce0>
f1: <function func1 at 0x70a8500979c0>, f2: <function func1.<locals>.func2 at 0x70a850097ce0>
```
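The output above shows that `func1()` returns a new `func2` object on every call; the last three most likely share an address only because CPython reused the memory of an object that had already been freed. This matters for JIT because JAX caches compiled code per function object, so calling `jit` on a freshly created inner function (or `lambda`) inside a loop retraces and recompiles each time. The sketch below counts traces with a Python side effect; `make_scale` and `trace_count` are illustrative names.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces `scale`


def make_scale():
    # like func1 above: returns a brand-new function object on every call
    def scale(x):
        global trace_count
        trace_count += 1  # Python side effect: runs during tracing only
        return 2.0 * x
    return scale


x = jnp.ones(3)

# jit applied to a fresh inner function on every iteration: traced every time
for _ in range(3):
    jax.jit(make_scale())(x)
fresh_traces = trace_count

# build and jit the inner function once, then reuse it: traced once
trace_count = 0
scale_jit = jax.jit(make_scale())
for _ in range(3):
    scale_jit(x)
reused_traces = trace_count

print(fresh_traces, reused_traces)
```

Creating the jitted function once, outside the loop, is the simple fix.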


## # jax.vmap: vectorize operations

### 1. Batch Level

Map an operation over axis 0 (the batch dimension). This case can be represented as:

Given:   

$$
f: (shape_1), (shape_2), \dots \rightarrow (shape_3)
$$  

What to do:    

$$
f_{new}: (B, shape_1), (B, shape_2), \dots \rightarrow (B, shape_3)
$$

where $B$ is the batch size, shared by every input and by the output.

In [7]:
```python
from jax import vmap
```

In [8]:
```python
X = jnp.ones((100, 4, 3))
Y = jnp.ones((100, 3, 2))

def mult4loop(x, y):
    # reference: an explicit Python loop over the batch axis
    res = jnp.zeros((x.shape[0], x.shape[1], y.shape[2]))
    for i, (xx, yy) in enumerate(zip(x, y)):
        res = res.at[i].set(
            xx @ yy
        )

    return res


def mult(x, y):
    # works on a single (4, 3) @ (3, 2) pair; vmap adds the batch axis
    return x @ y

s = time.time()
res = mult4loop(X, Y)
print(f'time without vmap: {time.time() - s}')

s = time.time()
res1 = vmap(mult, in_axes=(0, 0))(X, Y)  # use vmap
print(f'time with vmap, without JIT: {time.time() - s}')

jit_vmap_mult = jit(vmap(mult, in_axes=(0, 0)))  # use JIT

s = time.time()
res21 = jit_vmap_mult(X, Y)
print(f'time with vmap, with JIT No.1: {time.time() - s}')

s = time.time()
res22 = jit_vmap_mult(X, Y)
print(f'time with vmap, with JIT No.2: {time.time() - s}')

assert (res == res1).all() and (res == res21).all() and (res == res22).all()
```

```
time without vmap: 1.746204137802124
time with vmap, without JIT: 0.023752450942993164
time with vmap, with JIT No.1: 0.01141047477722168
time with vmap, with JIT No.2: 6.389617919921875e-05
```
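The batch axis does not have to be axis 0: `in_axes` can name a different axis for each argument, and `out_axes` chooses where the batch axis goes in the result. In this sketch (the random inputs and the `einsum` cross-check are added for illustration) `X` is batched on axis 0, `Y` on axis 1, and the batch axis is placed last in the output.

```python
import jax
import jax.numpy as jnp
from jax import random


def mult(x, y):
    return x @ y  # (n, k) @ (k, m) -> (n, m)


key_x, key_y = random.split(random.PRNGKey(0))
X = random.normal(key_x, (100, 4, 3))  # batch on axis 0
Y = random.normal(key_y, (3, 100, 2))  # batch on axis 1

# map over axis 0 of X and axis 1 of Y; put the batch axis at position 2 of the result
batched = jax.vmap(mult, in_axes=(0, 1), out_axes=2)
out = batched(X, Y)
print(out.shape)  # (4, 2, 100)

# the same contraction written with einsum
ref = jnp.einsum('bnk,kbm->nmb', X, Y)
```

Passing `None` instead of an axis means the argument is not mapped and is shared by all batch elements, which is what the next two cases rely on.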


### 2. Outer Product

This problem can be represented as:  

Given:   

$$
f: (shape_1), (shape_2), \dots \rightarrow (shape_3)
$$

What to do:  

$$
f_{new}: (B_1, shape_1), (B_2, shape_2), \dots \rightarrow (B_1, B_2, \dots, shape_3)
$$

In [9]:
```python
X = jnp.ones((27, 5, 4))
Y = jnp.ones((28, 4, 3))
Z = jnp.ones((29, 3, 2))
# X, Y, Z -> (27, 28, 29, 5, 2)

def mult4loop(x, y, z):
    res = jnp.zeros((x.shape[0], y.shape[0], z.shape[0], x.shape[1], z.shape[2]))
    for ix, xx in enumerate(x):
        for iy, yy in enumerate(y):
            for iz, zz in enumerate(z):
                res = res.at[ix, iy, iz].set(
                    xx @ yy @ zz
                )

    return res

def mult(x, y, z):
    return x @ y @ z

mult_vmap = vmap(
                vmap(
                    vmap(
                        mult, in_axes=(None, None, 0)
                    ), in_axes=(None, 0, None),
                ), in_axes=(0, None, None)
            )

jit_vmap_mult = jit(mult_vmap)

s = time.time()
res = mult4loop(X, Y, Z)
print(f'time without vmap: {time.time() - s}')

s = time.time()
res1 = mult_vmap(X, Y, Z)  # use vmap
print(f'time with vmap, without JIT: {time.time() - s}')

s = time.time()
res21 = jit_vmap_mult(X, Y, Z)
print(f'time with vmap, with JIT No.1: {time.time() - s}')

s = time.time()
res22 = jit_vmap_mult(X, Y, Z)
print(f'time with vmap, with JIT No.2: {time.time() - s}')

assert (res == res1).all() and (res == res21).all() and (res == res22).all()
```

```
time without vmap: 9.08353042602539
time with vmap, without JIT: 0.038193702697753906
time with vmap, with JIT No.1: 0.01571202278137207
time with vmap, with JIT No.2: 8.0108642578125e-05
```
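A common concrete use of this outer-product pattern is a pairwise distance matrix, the core computation of a k-nearest-neighbour classifier. The sketch below (the `sq_dist` helper and the random inputs are illustrative) writes the distance for a single pair of vectors and lets two nested `vmap`s produce every pair; the broadcasting expression at the end only serves as a cross-check.

```python
import jax
import jax.numpy as jnp
from jax import random


def sq_dist(a, b):
    # squared Euclidean distance between two vectors of shape (d,)
    return jnp.sum((a - b) ** 2)


# inner vmap: one row `a` against every row of B  -> shape (M,)
# outer vmap: repeat for every row of A            -> shape (N, M)
pairwise = jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))

key_a, key_b = random.split(random.PRNGKey(1))
A = random.normal(key_a, (5, 3))  # N = 5 points
B = random.normal(key_b, (7, 3))  # M = 7 points
D = jax.jit(pairwise)(A, B)
print(D.shape)  # (5, 7)

# cross-check with explicit broadcasting
D_ref = jnp.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
```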


### 3. Mixture

This case shows how to convert an arbitrary nest of loops into nested `vmap` calls. The previous two cases suggest a general rule: to iterate over several arrays in lockstep (like `zip`), map them in the same `vmap`, e.g. `in_axes=(0, 0)`; to nest one loop inside another, wrap one `vmap` inside another, using `in_axes=(None, 0)` for the inner loop and `in_axes=(0, None)` for the outer loop.

In [10]:
```python
X = jnp.ones((7, 5, 4))
Y = jnp.ones((8, 4, 3))
Z = jnp.ones((10, 3, 2))
R = jnp.ones((10, 2, 2))
S = jnp.ones((12, 2, 2))
T = jnp.ones((11, 2, 2))
U = jnp.ones((11, 2, 2))
V = jnp.ones((11, 2, 2))

def mult4loop(x, y, z, r, s, t, u, v):
    res = jnp.zeros((
        x.shape[0],  # x 
        y.shape[0],  # y
        z.shape[0],  # z, r
        s.shape[0],  # s
        t.shape[0],  # t, u, v
        x.shape[1], v.shape[2]))

    for ix, xx in enumerate(x):
        for iy, yy in enumerate(y):
            for iz, (zz, rr) in enumerate(zip(z, r)):
                for iS, ss in enumerate(s):
                    for it, (tt, uu, vv) in enumerate(zip(t, u, v)):
                        res = res.at[ix, iy, iz, iS, it].set(
                            xx @ yy @ zz @ rr @ ss @ tt @ uu @ vv
                        )

    return res

def mult(x, y, z, r, s, t, u, v):
    return x @ y @ z @ r @ s @ t @ u @ v

mult_vmap = vmap(
                vmap(
                    vmap(
                        vmap(
                            vmap(
                                mult, in_axes=(None, None, None, None, None,    0,    0,    0)
                            ), in_axes=       (None, None, None, None,    0, None, None, None)
                        ), in_axes=           (None, None,    0,    0, None, None, None, None)
                    ), in_axes=               (None,    0, None, None, None, None, None, None)
                ), in_axes=                   (   0, None, None, None, None, None, None, None)
            )

jit_vmap_mult = jit(mult_vmap)

s = time.time()
res = mult4loop(X, Y, Z, R, S, T, U, V)
print(f'time without vmap: {time.time() - s}')

s = time.time()
res1 = mult_vmap(X, Y, Z, R, S, T, U, V)  # use vmap
print(f'time with vmap, without JIT: {time.time() - s}')

s = time.time()
res21 = jit_vmap_mult(X, Y, Z, R, S, T, U, V)
print(f'time with vmap, with JIT No.1: {time.time() - s}')

s = time.time()
res22 = jit_vmap_mult(X, Y, Z, R, S, T, U, V)
print(f'time with vmap, with JIT No.2: {time.time() - s}')

assert (res == res1).all() and (res == res21).all() and (res == res22).all()
```

```
time without vmap: 55.275911808013916
time with vmap, without JIT: 0.14800119400024414
time with vmap, with JIT No.1: 0.0664219856262207
time with vmap, with JIT No.2: 8.153915405273438e-05
```
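Writing the nested `vmap`s by hand gets error-prone as the number of loops grows. A hypothetical helper such as `nest_vmap` below (not part of JAX) takes one `in_axes` tuple per loop, listed from the outermost loop to the innermost one, and applies `vmap` starting from the innermost loop, so the outermost loop becomes axis 0 of the result, as in the cell above.

```python
import jax
import jax.numpy as jnp


def nest_vmap(fn, axes_per_loop):
    """Hypothetical helper: nest one vmap per loop.

    axes_per_loop lists one in_axes tuple per loop, outermost loop first.
    vmap is applied innermost-first, so the outermost loop ends up as axis 0.
    """
    for in_axes in reversed(axes_per_loop):
        fn = jax.vmap(fn, in_axes=in_axes)
    return fn


def mult(x, y, z):
    return x @ y @ z


# equivalent loop: for x in X: for (y, z) in zip(Y, Z): x @ y @ z
mult_vmap = nest_vmap(mult, [
    (0, None, None),  # outer loop over X
    (None, 0, 0),     # inner loop over Y and Z in lockstep
])

X = jnp.ones((2, 5, 4))
Y = jnp.ones((3, 4, 3))
Z = jnp.ones((3, 3, 2))
out = jax.jit(mult_vmap)(X, Y, Z)
print(out.shape)  # (2, 3, 5, 2)
```

With all-ones inputs every entry of `out` is 4 · 3 = 12, which makes the result easy to eyeball.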


## # jax.tree: a powerful tool for nested dicts (pytrees)

This tool is very useful when initializing parameters in this framework; see [Initer](plugins/minitorch/initer.py).

In [11]:
```python
from jax import tree
```

In [12]:
```python
X = {
    'fc:1': {
        'w': 1,
        'b': 1,
    },
    'fc:2': {
        'w': 1,
        'b': 1,
    }
}

Y = {
    'fc:1': {
        'w': 2,
        'b': 2,
    },
    'fc:2': {
        'w': 2,
        'b': 2,
    }
}

res = tree.map(lambda x, y: 2*x + 3*y, X, Y)
import json
print(json.dumps(res, indent=4))
```

```
{
    "fc:1": {
        "b": 8,
        "w": 8
    },
    "fc:2": {
        "b": 8,
        "w": 8
    }
}
```
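Since parameter initialization is one of the main uses mentioned above, here is a minimal sketch of how `tree` functions could be used for it. It is not the framework's `Initer`; the layer shapes, the `0.01` scale, the fake gradients and the learning rate are made-up values. Shape tuples are themselves pytrees, so `is_leaf` is used to stop `tree.flatten` from descending into them.

```python
import jax.numpy as jnp
from jax import random, tree

# hypothetical parameter shapes, in the same nested-dict layout as X and Y above
shapes = {
    'fc:1': {'w': (4, 8), 'b': (8,)},
    'fc:2': {'w': (8, 2), 'b': (2,)},
}


def is_shape(node):
    return isinstance(node, tuple)  # treat a whole shape tuple as one leaf


# flatten to a list of shapes, draw one independent key per leaf, rebuild the dict
leaves, treedef = tree.flatten(shapes, is_leaf=is_shape)
keys = random.split(random.PRNGKey(0), len(leaves))
params = tree.unflatten(
    treedef, [0.01 * random.normal(k, s) for k, s in zip(keys, leaves)]
)

# a gradient-descent style update applied leaf by leaf (all-ones fake gradients)
grads = tree.map(jnp.ones_like, params)
lr = 0.1
params = tree.map(lambda p, g: p - lr * g, params, grads)

n_params = sum(p.size for p in tree.leaves(params))
print(n_params)  # 58 = 4*8 + 8 + 8*2 + 2
```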


In [13]:
```python
X = {
    'fc:1': {
        'w': 1,
        'b': 1,
    },
    'fc:2': {
        'w': 1,
        'b': 1,
    }
}

Y = {
    'fc:1': {
        'w': 2,
        'b': 2,
    },
    'fc:2': {
        'w': 2,
        'b': 2,
    }
}

def swap(x, y):
    return y, x

res = tree.map(swap, X, Y)
print(json.dumps(res, indent=4))
```

```
{
    "fc:1": {
        "b": [
            2,
            1
        ],
        "w": [
            2,
            1
        ]
    },
    "fc:2": {
        "b": [
            2,
            1
        ],
        "w": [
            2,
            1
        ]
    }
}
```


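If you want to get two separate trees from the last case, have the function return a single array instead of a tuple (a tuple is itself a pytree node, so a later `tree.map` would descend into it instead of treating it as one leaf), and then split the result with one `tree.map` per component: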

In [14]:
```python
import jax.numpy as jnp

def swap(x, y):
    return jnp.array([y, x])

res = tree.map(swap, X, Y)
print(res)

res1 = tree.map(lambda x: x[0], res)
res2 = tree.map(lambda x: x[1], res)

print(f'res1 is: {res1}')
print(f'res2 is: {res2}')
```

```
{'fc:1': {'b': Array([2, 1], dtype=int32), 'w': Array([2, 1], dtype=int32)}, 'fc:2': {'b': Array([2, 1], dtype=int32), 'w': Array([2, 1], dtype=int32)}}
res1 is: {'fc:1': {'b': Array(2, dtype=int32), 'w': Array(2, dtype=int32)}, 'fc:2': {'b': Array(2, dtype=int32), 'w': Array(2, dtype=int32)}}
res2 is: {'fc:1': {'b': Array(1, dtype=int32), 'w': Array(1, dtype=int32)}, 'fc:2': {'b': Array(1, dtype=int32), 'w': Array(1, dtype=int32)}}
```
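Stacking both values into one array works when they can be stacked. Another option, which keeps the original leaf types and would also work for leaves of different shapes, is to keep `swap` returning a tuple and transpose the resulting tree of tuples into a tuple of trees with `jax.tree_util.tree_transpose`. A minimal sketch:

```python
from jax import tree, tree_util

X = {'fc:1': {'w': 1, 'b': 1}, 'fc:2': {'w': 1, 'b': 1}}
Y = {'fc:1': {'w': 2, 'b': 2}, 'fc:2': {'w': 2, 'b': 2}}


def swap(x, y):
    return y, x


pairs = tree.map(swap, X, Y)  # every leaf becomes a (y, x) tuple

outer = tree_util.tree_structure(X)       # the dict-of-dicts layout
inner = tree_util.tree_structure((0, 0))  # a 2-tuple at every leaf
res1, res2 = tree_util.tree_transpose(outer, inner, pairs)

print(f'res1 is: {res1}')  # every leaf is 2
print(f'res2 is: {res2}')  # every leaf is 1
```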


## # jax.lax.scan: functional loops

See [knn_on_cifar10](https://github.com/HugoPhi/jaxdls/blob/main/knn_cifar10.ipynb) and the [lstm cell](https://github.com/HugoPhi/jaxdls/blob/main/plugins/minitorch/nn/JaxOptimized/rnncell.py) for examples of `jax.lax.scan` in this project.
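`jax.lax.scan(f, init, xs)` calls `f(carry, x)` once per slice along the leading axis of `xs`; `f` returns `(new_carry, y)`, and `scan` returns the final carry together with all the `y`s stacked. Unlike a Python `for` loop under `jit`, which is unrolled, the scan body is traced only once, so compile time should not grow with the sequence length. The sketch below uses a plain tanh RNN rather than the linked LSTM cell; `rnn_step`, `rnn_scan`, `rnn_loop`, the sizes and the weights are made up for illustration, and the loop version is only a reference.

```python
import jax
import jax.numpy as jnp
from jax import random


def rnn_step(h, x, W, U):
    # one step of a plain tanh RNN: h_t = tanh(W @ h_{t-1} + U @ x_t)
    return jnp.tanh(W @ h + U @ x)


def rnn_scan(h0, xs, W, U):
    def step(h, x):
        h_new = rnn_step(h, x, W, U)
        return h_new, h_new  # (carry for the next step, output to stack)
    h_last, hs = jax.lax.scan(step, h0, xs)  # loops over the leading axis of xs
    return h_last, hs


def rnn_loop(h0, xs, W, U):
    # reference: the same recurrence as an explicit Python loop
    h, hs = h0, []
    for x in xs:
        h = rnn_step(h, x, W, U)
        hs.append(h)
    return h, jnp.stack(hs)


key_x, key_w, key_u = random.split(random.PRNGKey(0), 3)
T, hidden, features = 20, 8, 4  # made-up sizes
xs = random.normal(key_x, (T, features))
W = 0.1 * random.normal(key_w, (hidden, hidden))
U = 0.1 * random.normal(key_u, (hidden, features))
h0 = jnp.zeros(hidden)

h_last, hs = jax.jit(rnn_scan)(h0, xs, W, U)
h_ref, hs_ref = rnn_loop(h0, xs, W, U)
print(hs.shape)  # (20, 8): one hidden state per time step
```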