In [1]:
!pip install jax

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting jax
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/fd/f2/9dbb75de3058acfd1600cf0839bcce7ea391148c9d2b4fa5f5666e66f09e/jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m588.0 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting jaxlib<=0.4.30,>=0.4.27 (from jax)
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/f3/1d/2d417a1445d5e696bb44d564c7519d4a6761db4d3e31712620c510ed0127/jaxlib-0.4.30-cp39-cp39-macosx_11_0_arm64.whl (66.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:02[0mm
Installing collected packages: jaxlib, jax
Successfully installed jax-0.4.30 jaxlib-0.4.30

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mno

In [2]:
import jax.numpy as jnp

In [3]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Entries of ``x`` that are greater than zero pass through unchanged; all
  other entries are replaced by ``alpha * exp(x) - alpha``. The selected
  values are then multiplied by ``lmbda``.

  Args:
    x: input array (or scalar) accepted by ``jax.numpy``.
    alpha: scale of the exponential branch.
    lmbda: overall output scale.

  Returns:
    An array with the same shape as ``x``.
  """
  negative_branch = alpha * jnp.exp(x) - alpha
  selected = jnp.where(x > 0, x, negative_branch)
  return lmbda * selected

# Evaluate selu on the five values 0.0 .. 4.0 and print input and output.
x = jnp.arange(5.0)
print(x)
print(selu(x))

[0. 1. 2. 3. 4.]
[0.        1.05      2.1       3.1499999 4.2      ]


# 🔪 Pure functions

In [4]:
import jax

def f(x):
    """Return ``x`` squared, printing ``x`` first.

    The ``print`` is a deliberate side effect used to demonstrate impurity:
    when the function is wrapped in ``jax.jit`` the print may not run on
    every call, or may show something other than the concrete value.
    """
    print(x)   # ❌ side effect: under JIT it may not execute, or may print repeatedly
    squared = x ** 2
    return squared

# NOTE: despite its name, `jitted_f` holds the value returned by calling the
# jit-compiled function with 3, not the compiled function itself.
jitted_f = jax.jit(f)(3)
print(jitted_f)

Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
9


In [5]:
import numpy as np
from jax import jit
from jax import lax
from jax import random
import jax

In [6]:
def impure_print_side_effect(x):
    """Return ``x`` unchanged after printing a fixed message.

    The printed message is a side effect; it is used to show when the Python
    body of a jit-wrapped function is actually executed.
    """
    print("Executing function")  # This is a side-effect
    result = x
    return result

# The side-effects appear during the first run
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
# On the second call the already-compiled function is invoked directly, so a call made
# the same way (same argument types) may not re-execute the Python body
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [7]:
g = 0.0  # module-level value read by impure_uses_globals


def impure_uses_globals(x):
    """Return ``x`` plus the current value of the module-level ``g``.

    Reading a global makes the function impure: its result depends on state
    that is not passed in as an argument.
    """
    total = x + g
    return total

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
# Passing a new argument type recompiles the function, which then reads the updated global
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


# Just-in-time compilation with jax.jit()

In [8]:
from jax import random

# Draw one million normal samples from a fixed PRNG key, then time the
# un-jitted selu. Note that this rebinds `x`, which earlier held jnp.arange(5.0).
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

753 µs ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
from jax import jit

# jit() wraps the function so that it is compiled for faster execution
selu_jit = jit(selu)
_ = selu_jit(x)  # compiled on the first call
%timeit selu_jit(x).block_until_ready() # later calls are faster [they reuse the compiled function]

231 µs ± 7.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


# Taking derivatives with jax.grad()

In [10]:
from jax import grad # 自动求导

def sum_logistic(x):
    """Sum of the logistic function ``1 / (1 + exp(-x))`` over all entries of ``x``.

    Args:
        x: input array (or scalar) accepted by ``jax.numpy``.

    Returns:
        A scalar array holding the sum.
    """
    logistic = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(logistic)

# Build the gradient function of sum_logistic with grad() and evaluate it
# at the three points 0., 1., 2.
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))


[0.25       0.19661197 0.10499357]


In [11]:
# verify result
def first_finite_differences(f, x, eps=1E-3):
    """Estimate the gradient of ``f`` at ``x`` with central differences.

    This is the definition of the derivative applied numerically: for each
    row ``v`` of the identity matrix of size ``len(x)``, it evaluates
    ``(f(x + eps * v) - f(x - eps * v)) / (2 * eps)``.

    Args:
        f: function taking an array like ``x`` and returning a scalar.
        x: point at which to estimate the gradient; must support ``len``.
        eps: half-width of the difference step.

    Returns:
        An array with one estimate per entry of ``x``.
    """
    estimates = []
    for v in jnp.eye(len(x)):
        forward = f(x + eps * v)
        backward = f(x - eps * v)
        estimates.append((forward - backward) / (2 * eps))
    return jnp.array(estimates)

# Numerical check: the printed values are expected to be close to derivative_fn(x_small).
print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]


In [12]:
# grad() and jit() can be mixed and nested freely
# Compute the third derivative
print(jit(grad(jit(grad(jit(grad(sum_logistic))))))(1.0))

-0.035325598


In [13]:
from jax import jacobian
# Jacobian of jnp.exp evaluated at x_small.
print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


In [14]:
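# jax.vjp(f, x): reverse mode. Given an input x, returns f(x) together with a backward function (vector-Jacobian product)
# jax.jvp(f, x, v): forward mode. Given an input x and a direction vector v, computes the directional derivative
# jax.linearize(f, x): forward mode. Returns (f(x), linear_fn); linear_fn can be reused to compute many directional derivatives efficiently
# jax.jacrev(fun): uses reverse mode to obtain the first derivative (gradient)
# jax.jacfwd(...): then applies forward mode to that first derivative, giving the second derivative (i.e. the Hessian)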

from jax import jacfwd, jacrev

def hessian(fun):
    """Build a jit-compiled function that computes the Hessian of ``fun``.

    The first derivative is taken with ``jacrev`` and differentiated again
    with ``jacfwd``. ``jacrev(jacrev(fun))`` is a possible alternative.

    Args:
        fun: function to differentiate twice.

    Returns:
        A jit-wrapped callable evaluating the second derivative of ``fun``.
    """
    first_derivative = jacrev(fun)
    second_derivative = jacfwd(first_derivative)
    return jit(second_derivative)
# Hessian of sum_logistic evaluated at x_small.
print(hessian(sum_logistic)(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085775 -0.        ]
 [-0.         -0.         -0.07996249]]


# Auto-vectorization with jax.vmap()

In [22]:
# Derive two subkeys from `key`: one for a (150, 100) matrix and one for a
# batch of 10 vectors of length 100.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
    """Return ``jnp.dot(mat, x)`` using the module-level matrix ``mat``.

    Args:
        x: array to multiply by ``mat``.

    Returns:
        The product of ``mat`` and ``x``.
    """
    product = jnp.dot(mat, x)
    return product

In [23]:
def naively_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to each item of ``v_batched`` in a Python loop.

    Args:
        v_batched: iterable of inputs, iterated along its leading axis.

    Returns:
        The per-item results stacked with ``jnp.stack``.
    """
    per_row_results = []
    for v in v_batched:
        per_row_results.append(apply_matrix(v))
    return jnp.stack(per_row_results)

# Time the Python-loop version of the batched product.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
229 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
import numpy as np

@jit
def batched_apply_matrix(batched_x):
    """Apply ``mat`` to every row of ``batched_x`` with a single product.

    The per-row ``jnp.dot(mat, v)`` is rewritten as one
    ``jnp.dot(batched_x, mat.T)`` over the whole batch.

    Args:
        batched_x: array whose rows are the individual inputs.

    Returns:
        An array with one result row per row of ``batched_x``.
    """
    mat_transposed = mat.T
    return jnp.dot(batched_x, mat_transposed)

# Assertion helper that checks whether two arrays are "almost equal"
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
7.88 µs ± 180 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [None]:
from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
    """Apply ``apply_matrix`` across the batch using ``vmap``.

    Args:
        batched_x: array whose leading axis indexes the individual inputs.

    Returns:
        The result of the vmapped ``apply_matrix`` on ``batched_x``.
    """
    apply_to_each_row = vmap(apply_matrix)
    return apply_to_each_row(batched_x)

# Check that the vmap version agrees with the Python-loop version (within
# the given tolerances), then time it.
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x), vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
11.1 µs ± 143 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
