首先安装 JAX。我使用的安装命令如下（其中 `cuda-nvcc` 用于 GPU/CUDA 支持）：`conda install jax cuda-nvcc -c conda-forge -c nvidia`

In [3]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
#%timeit 
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]
11 ms ± 443 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:


def square(x):
    """Return the element-wise square of ``x`` (thin wrapper over ``jnp.square``)."""
    return jnp.square(x)

x = jnp.array([1.0, 2.0, 3.0])

# Forward-mode Jacobian of `square` via jax.jacfwd. For an element-wise function
# the Jacobian is diagonal: d(x_i^2)/dx_j = 2*x_i if i == j else 0.
# NOTE: results get their own names -- assigning them to `grad` would shadow the
# `jax.grad` transform imported earlier in the notebook.
grad_fwd = jax.jacfwd(square)
jac_fwd = grad_fwd(x)
print("Forward gradient:", jac_fwd)

# Reverse-mode Jacobian via jax.jacrev; it must agree with the forward-mode one.
grad_rev = jax.jacrev(square)
jac_rev = grad_rev(x)
print("Reverse gradient:", jac_rev)

assert jnp.allclose(jac_fwd, jac_rev), "forward- and reverse-mode Jacobians disagree"


Forward gradient: [[2. 0. 0.]
 [0. 4. 0.]
 [0. 0. 6.]]
Reverse gradient: [[2. 0. 0.]
 [0. 4. 0.]
 [0. 0. 6.]]
