In [275]:
import jax
import time
import jax.numpy as jnp
import matplotlib.pyplot as plt
import quantecon as qe

In [276]:
%config InlineBackend.figure_format = 'svg'

In [277]:
jax.config.update("jax_enable_x64", True)

In [278]:
def create_model_jax(
    R=1.01,
    beta=0.98,
    gamma=2,
    w_min=0.01,
    w_max=5,
    w_size=150,
    rho=0.9,
    nu=0.1,
    y_size=100,
):
    w_grid = jnp.linspace(w_min, w_max, w_size)
    mc = qe.tauchen(n=y_size, rho=rho, sigma=nu)
    y_grid, Q = jnp.exp(mc.state_values), jnp.array(mc.P)
    params = (beta, R, gamma)
    sizes = (w_size, y_size)
    arrays = (w_grid, y_grid, Q)
    return params, sizes, arrays

In [279]:
def u(c, params):
    beta, R, gamma = params
    return c ** (1 - gamma) / (1 - gamma)


u = jax.jit(u)

In [280]:
def B_vmap(v, params, sizes, arrays, i, j, ip):
    """
    The right-hand side of the Bellman equation before maximization, which takes
    the form

        B(w, y, w′) = u(Rw + y - w′) + β Σ_y′ v(w′, y′) Q(y, y′)

    The indices are (i, j, ip) -> (w, y, w′).
    """
    beta, R, gamma = params
    w_grid, y_grid, Q = arrays
    w, y, wp = w_grid[i], y_grid[j], w_grid[ip]
    c = R * w + y - wp
    EV = jnp.sum(v[ip, :] * Q[j, :])
    ans = jnp.where(c > 0, u(c, params) + beta * EV, -jnp.inf)
    return ans

In [281]:
B_1 = jax.vmap(B_vmap, in_axes=(None, None, None, None, None, None, 0))
B_2 = jax.vmap(B_1, in_axes=(None, None, None, None, None, 0, None))
B_vmap = jax.vmap(B_2, in_axes=(None, None, None, None, 0, None, None))

JAX 的 `vmap` 默认行为是：**每次 vmap 都会在结果的最前面（axis 0）插入一个新的维度。**

我们来倒推一下它是怎么形成 $(i, j, ip)$ 的：

### 1. 核心规则：后进先出（Last in, First out）
*   **最内层的 vmap**（处理谁）：对应的维度会变成**最里面（最后面）的维度**。
*   **最外层的 vmap**（处理谁）：对应的维度会变成**最外面（最前面）的维度**。

---

### 2. 这里的演变过程

#### 第一步：最内层 `B_1` (处理 `ip`)
*   **任务**：针对 `ip` 向量化。
*   **动作**：算出一排结果。
*   **此时形状**：`(N_ip,)`
*   **对应维度**：因为只有一层，它既是第 0 维也是最后一维。

#### 第二步：中间层 `B_2` (处理 `j`)
*   **任务**：针对 `j` 向量化。
*   **动作**：JAX 会说：“好，我针对每一个 `j` 都运行一遍 B_1（得到一个 `N_ip` 的条），然后把这些条**堆叠**起来。”
*   **堆叠位置**：最前面 (Axis 0)。
*   **此时形状**：`(N_j, N_ip)`
    *   `N_j` 占据了 Axis 0。
    *   原来的 `N_ip` 被挤到了 Axis 1。

#### 第三步：最外层 `B_vmap` (处理 `i`)
*   **任务**：针对 `i` 向量化。
*   **动作**：JAX 会说：“好，我针对每一个 `i` 都运行一遍 B_2（得到一个 `N_j x N_ip` 的矩阵），然后把这些矩阵**堆叠**起来。”
*   **堆叠位置**：最前面 (Axis 0)。
*   **此时形状**：`(N_i, N_j, N_ip)`
    *   `N_i` 占据了 Axis 0。
    *   原来的 `N_j` 被挤到了 Axis 1。
    *   原来的 `N_ip` 被挤到了 Axis 2。

---

### 3. 如果搞反了会怎样？

假设你想保持 `(i, j, ip)` 的输出形状，但你把 `vmap` 顺序写反了，比如最外层处理 `ip`，最内层处理 `i`：

1.  内层处理 `i` $\rightarrow$ 输出 `(N_i)`
2.  中层处理 `j` $\rightarrow$ 输出 `(N_j, N_i)`
3.  外层处理 `ip` $\rightarrow$ 输出 `(N_ip, N_j, N_i)`

**结果**：你的输出形状变成了 `(ip, j, i)`。
**后果**：当你后面调用 `max(axis=2)` 时，你消掉的就不是“下一期资产选择 `ip`”了，而是“当前资产 `i`”，这完全破坏了 Bellman 方程的物理意义。

### 总结

*   你想让哪个变量出现在 Tensor 的**最后面**（最内层维度），你就必须在代码的最里层（**第一个 vmap**）去处理它。
*   你想让哪个变量出现在 Tensor 的**最前面**（最外层维度），你就必须在代码的最外层（**最后一个 vmap**）去处理它。


In [282]:
def B(v, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return B_vmap(v, params, sizes, arrays, w_indices, y_indices, w_indices)


B = jax.jit(B, static_argnums=(2,))

In [283]:
def T(v, params, sizes, arrays):
    return jnp.max(B(v, params, sizes, arrays), axis=-1)


T = jax.jit(T, static_argnums=(2,))

In [284]:
def get_greedy(v, params, sizes, arrays):
    return jnp.argmax(B(v, params, sizes, arrays), axis=-1)


get_greedy = jax.jit(get_greedy, static_argnums=(2,))

Given current policy $\sigma$, we define reward $r_\sigma$ as:
$$
r_\sigma := r(w,y,\sigma(w,y))
$$
and in this case, it's the utility from consumption given this policy $\sigma$.

In [285]:
def _get_r_sigma(sigma, params, sizes, arrays, i, j):
    beta, R, gamma = params
    w_grid, y_grid, Q = arrays

    # compute r_sigma [i, j]
    w, y, wp = w_grid[i], y_grid[j], w_grid[sigma[i, j]]
    c = R * w + y - wp
    r_sigma = u(c, params)
    return r_sigma


r_1 = jax.vmap(
    _get_r_sigma,
    in_axes=(None, None, None, None, None, 0),
)
r_sigma_vmap = jax.vmap(
    r_1,
    in_axes=(None, None, None, None, 0, None),
)

In [286]:
def r_sigma(sigma, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return r_sigma_vmap(sigma, params, sizes, arrays, w_indices, y_indices)


r_sigma = jax.jit(r_sigma, static_argnums=(2,))

In [287]:
def _T_sigma(v, sigma, params, sizes, arrays, i, j):
    beta, R, gamma = params
    w_grid, y_grid, Q = arrays

    r_sigma = _get_r_sigma(sigma, params, sizes, arrays, i, j)
    EV = jnp.sum(v[sigma[i, j], :] * Q[j, :])
    return r_sigma + beta * EV


T_1 = jax.vmap(
    _T_sigma,
    in_axes=(None, None, None, None, None, None, 0),
)

T_sigma_vmap = jax.vmap(
    T_1,
    in_axes=(None, None, None, None, None, 0, None),
)

In [288]:
def T_sigma(v, sigma, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return T_sigma_vmap(v, sigma, params, sizes, arrays, w_indices, y_indices)


T_sigma = jax.jit(T_sigma, static_argnums=(3,))

Now we need some extra math:
- we want the value $v_\sigma$ given $\sigma$
- we know it satisfies the Bellman equation:
$$
v_\sigma (w,y) = r_\sigma (w,y) + \beta \mathbb{E}_{y'|y}[v_\sigma(\sigma(w,y),y')]
$$

to solve this:
+ we define $L_\sigma$ as:
$$
(L_\sigma v)(w,y) = v_\sigma (w,y) -\beta \mathbb{E}_{y'|y}[v_\sigma(\sigma(w,y),y')]
$$
+ so it satisfies:
$$
(L_\sigma v)(w,y) = r_\sigma (w,y) 
$$
+ then we have $L_\sigma v_\sigma = r_\sigma$, that is, 
$$
v_\sigma = L_\sigma^{-1} r_\sigma
$$

Another note on this PI:
+ given this fixed policy $\sigma$, there's no $\max$ operator here.
+ this is purely a linear equation

define $v \in \mathbb{R}^{n\times 1}$, where $n$ is the number of the states in total. Then we can rewrite in matrix form:
$$
v = r + \beta P_\sigma v
$$
so we basically have:
$$
(I -\beta P_\sigma )v = r
$$
that is $L_\sigma = I - \beta P_\sigma$.

In [289]:
def _L_sigma(v, sigma, params, sizes, arrays, i, j):
    beta, R, gamma = params
    w_grid, y_grid, Q = arrays

    ans = v[i, j] - beta * jnp.sum(v[sigma[i, j], :] * Q[j, :])
    return ans


L_1 = jax.vmap(_L_sigma, in_axes=(None, None, None, None, None, None, 0))
L_sigma_vmap = jax.vmap(
    L_1,
    in_axes=(None, None, None, None, None, 0, None),
)

In [290]:
def L_sigma(v, sigma, params, sizes, arrays):
    w_size, y_size = sizes
    w_indices, y_indices = jnp.arange(w_size), jnp.arange(y_size)
    return L_sigma_vmap(v, sigma, params, sizes, arrays, w_indices, y_indices)


L_sigma = jax.jit(L_sigma, static_argnums=(3,))

In [291]:
def get_v_sigma(sigma, params, sizes, arrays):
    r_sigma_val = r_sigma(sigma, params, sizes, arrays)
    _L_sigma = lambda v: L_sigma(v, sigma, params, sizes, arrays)
    ans = jax.scipy.sparse.linalg.bicgstab(_L_sigma, r_sigma_val)
    return ans[0]


get_v_sigma = jax.jit(get_v_sigma, static_argnums=(2,))

In [292]:
def sucessive_approx_jax(T, v_0, tol=1e-8, max_iter=10000):
    def body_fun(k_v_err):
        k, v, err = k_v_err
        v_new = T(v)
        err = jnp.max(jnp.abs(v_new - v))
        return k + 1, v_new, err

    def cond_fun(k_v_err):
        k, v, err = k_v_err
        return jnp.logical_and(k < max_iter, err > tol)

    k, v, err = jax.lax.while_loop(cond_fun, body_fun, (0, v_0, jnp.inf))
    return v


sucessive_approx_jax = jax.jit(sucessive_approx_jax, static_argnums=(0,))

In [293]:
def OPI_update(sigma, v, m, params, sizes, arrays):
    def update(i, v):
        v = T_sigma(v, sigma, params, sizes, arrays)
        return v

    v = jax.lax.fori_loop(0, m, update, v)
    return v


OPI_update = jax.jit(OPI_update, static_argnums=(4,))

In [294]:
def OPI_loop(
    params,
    sizes,
    arrays,
    m=50,
    tol=1e-8,
    max_iter=10_000,
):
    v_0 = jnp.zeros(sizes)

    def cond_fun(k_v_err):
        k, v, err = k_v_err
        return jnp.logical_and(k < max_iter, err > tol)

    def body_fun(k_v_err):
        k, v, err = k_v_err
        sigma = get_greedy(v, params, sizes, arrays)
        v_new = OPI_update(sigma, v, m, params, sizes, arrays)
        err = jnp.max(jnp.abs(v_new - v))
        return k + 1, v_new, err

    k, v, err = jax.lax.while_loop(
        cond_fun,
        body_fun,
        (0, v_0, jnp.inf),
    )
    return v


OPI_loop = jax.jit(OPI_loop, static_argnums=(1,))

In [295]:
def OPI(params, sizes, arrays, m=20, tol=1e-8, max_iter=10_000):
    sigma = OPI_loop(params, sizes, arrays, m, tol, max_iter)
    return sigma

In [296]:
params, sizes, arrays = create_model_jax()
sigma = OPI(params, sizes, arrays).block_until_ready()

In [297]:
start_time = time.time()
sigma = OPI(params, sizes, arrays).block_until_ready()
end_time = time.time()
OPI_time = end_time - start_time
print(f"OPI time: {OPI_time} seconds")

OPI time: 0.2265169620513916 seconds


In [298]:
def howard_policy_iteration(model, maxiter=250):
    """
    Implements Howard policy iteration (see dp.quantecon.org)
    """
    params, sizes, arrays = model
    σ = jnp.zeros(sizes, dtype=int)
    i, error = 0, 1.0
    while error > 0 and i < maxiter:
        v_σ = get_v_sigma(σ, params, sizes, arrays)
        σ_new = get_greedy(v_σ, params, sizes, arrays)
        error = jnp.max(jnp.abs(σ_new - σ))
        σ = σ_new
        i = i + 1
        if i % 20 == 0:
            print(f"Concluded loop {i} with error {error}.")
    return σ

In [299]:
sigma = howard_policy_iteration(model=(params, sizes, arrays)).block_until_ready()

Concluded loop 20 with error 1.
Concluded loop 40 with error 1.
Concluded loop 60 with error 1.
Concluded loop 80 with error 1.
Concluded loop 100 with error 1.
Concluded loop 120 with error 1.
Concluded loop 140 with error 1.
Concluded loop 160 with error 1.
Concluded loop 180 with error 1.
Concluded loop 200 with error 1.
Concluded loop 220 with error 1.
Concluded loop 240 with error 1.


In [300]:
start_time = time.time()
sigma = howard_policy_iteration(model=(params, sizes, arrays)).block_until_ready()
end_time = time.time()
OPI_time = end_time - start_time
print(f"HPI time: {OPI_time} seconds")

Concluded loop 20 with error 1.
Concluded loop 40 with error 1.
Concluded loop 60 with error 1.
Concluded loop 80 with error 1.
Concluded loop 100 with error 1.
Concluded loop 120 with error 1.
Concluded loop 140 with error 1.
Concluded loop 160 with error 1.
Concluded loop 180 with error 1.
Concluded loop 200 with error 1.
Concluded loop 220 with error 1.
Concluded loop 240 with error 1.
HPI time: 3.2087228298187256 seconds
