In [1]:
import jax
import jax.numpy as jnp


## Jax 基本操作

```
pip install -U jax
pip install -U flax
```

jnp 的操作类似 np，PyTree 类似 dict；在此基础上，JAX 还提供了自动微分和 JIT 编译功能（注意：@jit 只能用于静态的 func，否则请参考“有状态计算”）。

与TF基本通用(Deepmind)

###  1. 求导 & 求偏导 ----- jax.grad 获取func()对应的偏导数方程
对 dict 数据求微分（fn 必须返回标量；得到的梯度与 dictA 的结构相同）：
```
def fn(dictA):
    X = dictA['X']
    return jnp.sum(X ** 2)   # 示例：任意标量输出

jax.grad(fn)(dictA)
```

https://jax.ac.cn/en/latest/automatic-differentiation.html#differentiating-with-respect-to-nested-lists-tuples-and-dicts

In [2]:
## Evaluation point. Floats on purpose: the note at the end of this cell says
## that passing 2 instead of 2.0 raises an error.
x = 2.0  
y = 3.0

def func(x, y):
    """Evaluate f(x, y) = x**2 + y**3."""
    squared = x**2
    cubed = y**3
    return squared + cubed

df_dx = jax.grad(func, argnums = 0)  ## partial derivative w.r.t. the argument at index 0: d_func/d_x = 2*x       (default: argnums = 0)
df_dy = jax.grad(func, argnums = 1)  ## partial derivative w.r.t. the argument at index 1: d_func/d_y = 3*y**2

func(x, y), df_dx(x,y) , df_dy(x,y)    ## the input dtype matters; see the error raised when passing 2 instead of 2.0

(31.0,
 Array(4., dtype=float32, weak_type=True),
 Array(27., dtype=float32, weak_type=True))

In [3]:
jax.grad(func, argnums = (0,1))(x,y)    ## partial derivatives w.r.t. several arguments at once: argnums = (0,1,...)

(Array(4., dtype=float32, weak_type=True),
 Array(27., dtype=float32, weak_type=True))

In [4]:
jax.value_and_grad(func, argnums = (0,1))(x,y)     ## function value and gradients from a single call

(Array(31., dtype=float32, weak_type=True),
 (Array(4., dtype=float32, weak_type=True),
  Array(27., dtype=float32, weak_type=True)))

### 2. 批量输入 ----- jax.vmap 向量化
多设备并行：pmap

In [5]:
## x stays a plain scalar; y is a length-5 array that the vmap below maps over.
x = 2.0
y = jnp.array([3.0] * 5)

def func(x, y):
    """Evaluate f(x, y) = x**2 + y**3 (same formula as in the gradient section)."""
    x_term = x**2
    y_term = y**3
    return x_term + y_term

vfunc = jax.vmap(func, in_axes=(None,0))           ## default: vectorize over axis 0 of the inputs; use None for an input that is a scalar (not batched)

vfunc(x,y)

Array([31., 31., 31., 31., 31.], dtype=float32)

In [6]:
## Batched gradients: jax.grad(vfunc, argnums = (0,1))(x,y) cannot be used directly, because "Gradient only defined for scalar-output functions"

x = 2.0
y = jnp.array([3.0] * 5)

@jax.jit
def func(x, y):
    """JIT-compiled f(x, y) = x**2 + y**3."""
    squared = x**2
    cubed = y**3
    return squared + cubed

grad_f = jax.grad(func, argnums = (0,1))        ## gradient of the scalar-output func w.r.t. both arguments
vgrad_f = jax.vmap(grad_f, in_axes=(None,0))    ## map that gradient over axis 0 of y; x is passed unbatched (None)

vgrad_f(x,y)

(Array([4., 4., 4., 4., 4.], dtype=float32, weak_type=True),
 Array([27., 27., 27., 27., 27.], dtype=float32))

### 3. 获取伪随机数 --- 重用key会导致相同的输出，因此是pseudo
https://jax.ac.cn/en/latest/key-concepts.html#pseudorandom-numbers

In [7]:
key = jax.random.PRNGKey(0)
subkeyA, subkeyB, subkeyC = jax.random.split(key, 3)   ##  split the PRNG key into 3 subkeys --- a deterministic function, see the cell below

for i in range(2):                                     ## reusing a key produces the same output!!!! 
    print(jax.random.normal(subkeyA))

1.1188384
1.1188384


In [8]:
for i in range(2):                                     ## reusing a key produces the same output!!!! (here: the same subkeys on every split)
    print(  jax.random.split(key, 3)  )

[[2467461003  428148500]
 [3186719485 3840466878]
 [2562233961 1946702221]]
[[2467461003  428148500]
 [3186719485 3840466878]
 [2562233961 1946702221]]


In [9]:
## Correct usage: discard old_key once it has been split, and obtain every
## later key from a fresh jax.random.split.

def get_pseudo_random(old_key):
    """Draw one sample from jax.random.normal and hand back a key for the next draw.

    Args:
        old_key: PRNG key to consume; the caller should not use it again.

    Returns:
        (new_key, val): the key to pass to the next call, and the sampled value.
    """
    next_key, sample_key = jax.random.split(old_key)
    return next_key, jax.random.normal(sample_key)

## Both runs restart from the same seed and therefore print identical values:
## the sequence is fully determined by the key, which is why it is called
## "pseudo" random.
for banner in ('Test 1 ----------', 'Test 2 ----------'):
    print(banner)
    key = jax.random.PRNGKey(0)
    for i in range(3):
        key, val = get_pseudo_random(key)
        print(val)

Test 1 ----------
-1.2515389
-0.58665055
0.48648307
Test 2 ----------
-1.2515389
-0.58665055
0.48648307


### 4. 设备控制
https://jax.ac.cn/en/latest/sharded-computation.html#sharded-computation

**TODO：并行编程入门**  https://jax.ac.cn/en/latest/sharded-computation.html

In [10]:
device = jax.devices()[0]  ## first device in JAX's device list; original note: "or 'cpu' 'gpu'" (presumably backend names -- TODO confirm intended usage)
device

CpuDevice(id=0)

In [11]:
x = jnp.array([3.0] * 5)
x_on_device = jax.device_put(x, device=device)   ## put x on the chosen device (this may be a GPU)
x_get_from_d = jax.device_get(x_on_device)       ## get the value back from that device

In [12]:
x.devices(), x.sharding    ## show the set of devices holding x and its sharding

({CpuDevice(id=0)},
 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host))

In [13]:
with jax.default_device(device):
    pass                               ## actions/data inside this block happen on that device

### 5. 仅当函数是静态的时，才可对 func 使用 @jit 或 jit(f, static_argnames='x')
不能使用的情况：控制流或输出的 dim 依赖 input 的具体数值，使用/更新全局参数（例如，optimizer 就不可以加 @jit）

此时可用 **lax**.cond/scan/fori_loop/...，关于什么情况下可微分：https://jax.ac.cn/en/latest/control-flow.html


In [14]:
def func(carry, x):
    """Scan step: add x to the running total.

    Returns:
        (new_carry, (x, new_carry)): the updated total, followed by the
        per-step record made of the consumed element and the total after it.
    """
    running_total = carry + x
    return running_total, (x, running_total)

final_carry, (x_history, carry_history) = jax.lax.scan(func, 0, jnp.array([1,2,3,4]))   ## inputs of func: carry starts at 0, x is taken from the array [..]
final_carry, (x_history, carry_history)

(Array(10, dtype=int32),
 (Array([1, 2, 3, 4], dtype=int32), Array([ 1,  3,  6, 10], dtype=int32)))

In [15]:
def func(x):
    """Return 1 when x > 0 and -1 otherwise, branching with jax.lax.cond."""
    is_positive = x > 0
    return jax.lax.cond(is_positive, lambda _: 1, lambda _: -1, x)

func(9),print(), func(-9)    ## print() only emits a blank line; it appears as None in the displayed tuple




(Array(1, dtype=int32, weak_type=True),
 None,
 Array(-1, dtype=int32, weak_type=True))

### 6. 其它
1. 当使用 jax.jit() 转换函数时，print() 函数只会打印抽象追踪器（tracer）值，需使用 jax.debug.print("{}", x) 打印实际的运行时值
2. 可用 jax.debug.breakpoint() 加入断点
3. jax.jacfwd/jacrev [计算Jacobian/Hessian矩阵](https://jax.ac.cn/en/latest/advanced-autodiff.html#jacobians-and-hessians-using-jax-jacfwd-and-jax-jacrev)