diff --git a/lectures/inventory_ssd.md b/lectures/inventory_ssd.md index 7ad465e..985c99c 100644 --- a/lectures/inventory_ssd.md +++ b/lectures/inventory_ssd.md @@ -4,7 +4,8 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.17.1 + formats: md:myst,ipynb kernelspec: display_name: Python 3 (ipykernel) language: python @@ -145,7 +146,7 @@ where \sum_{d, \, z'} v(f(y, a, d), z') \phi(d) Q(z, z'). ``` -We set +We set $\beta(z) := z$ and $$ R(y, a, y') @@ -172,10 +173,9 @@ import jax import jax.numpy as jnp import numpy as np import matplotlib.pyplot as plt -from collections import namedtuple -import numba -from numba import prange from time import time +from functools import partial +from typing import NamedTuple ``` Let's check the GPU we are running @@ -194,77 +194,80 @@ Let's define a model to represent the inventory management. ```{code-cell} ipython3 # NamedTuple Model -Model = namedtuple("Model", ("c", "κ", "p", "z_vals", "Q")) +class Model(NamedTuple): + z_values: jnp.ndarray # Exogenous shock values + Q: jnp.ndarray # Exogenous shock probabilities + x_values: jnp.ndarray # Inventory values + d_values: jnp.ndarray # Demand values for summation + ϕ_values: jnp.ndarray # Demand probabilities + p: float # Demand parameter + c: float = 0.2 # Unit cost + κ: float = 0.8 # Fixed cost ``` -We need the following successive approximation function. - -```{code-cell} ipython3 -def successive_approx(T, # Operator (callable) - x_0, # Initial condition - tolerance=1e-6, # Error tolerance - max_iter=10_000, # Max iteration bound - print_step=25, # Print at multiples - verbose=False): - x = x_0 - error = tolerance + 1 - k = 1 - while error > tolerance and k <= max_iter: - x_new = T(x) - error = jnp.max(jnp.abs(x_new - x)) - if verbose and k % print_step == 0: - print(f"Completed iteration {k} with error {error}.") - x = x_new - k += 1 - if error > tolerance: - print(f"Warning: Iteration hit upper bound {max_iter}.") - elif verbose: - print(f"Terminated successfully in {k} iterations.") - return x -``` - -```{code-cell} ipython3 -@jax.jit -def demand_pdf(p, d): - return (1 - p)**d * p -``` - -```{code-cell} ipython3 -K = 100 -D_MAX = 101 -``` - -Let's define a function to create an inventory model using the given parameters. - ```{code-cell} ipython3 def create_sdd_inventory_model( - ρ=0.98, ν=0.002, n_z=100, b=0.97, # Z state parameters - c=0.2, κ=0.8, p=0.6, # firm and demand parameters - use_jax=True): + ρ: float = 0.98, # Exogenous state autocorrelation parameter + ν: float = 0.002, # Exogenous state volatility parameter + n_z: int = 10, # Exogenous state discretization size + b: float = 0.97, # Exogenous state offset + K: int = 100, # Max inventory + D_MAX: int = 101, # Demand upper bound for summation + p: float = 0.6 + ) -> Model: + + # Demand + def demand_pdf(p, d): + return (1 - p)**d * p + + d_values = jnp.arange(D_MAX) + ϕ_values = demand_pdf(p, d_values) + + # Exogenous state process mc = qe.tauchen(n_z, ρ, ν) - z_vals, Q = mc.state_values + b, mc.P - if use_jax: - z_vals, Q = map(jnp.array, (z_vals, Q)) - return Model(c=c, κ=κ, p=p, z_vals=z_vals, Q=Q) + z_values, Q = map(jnp.array, (mc.state_values + b, mc.P)) + + # Endogenous state + x_values = jnp.arange(K + 1) # 0, 1, ..., K + + return Model( + z_values=z_values, Q=Q, + x_values=x_values, d_values=d_values, ϕ_values=ϕ_values, + p=p + ) ``` Here's the function `B` on the right-hand side of the Bellman equation. ```{code-cell} ipython3 @jax.jit -def B(x, i_z, a, v, model): +def B(x, z_idx, v, model): """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). + Take z_idx and convert it to z. Then compute + + B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′) + + for all possible choices of a. """ - c, κ, p, z_vals, Q = model - z = z_vals[i_z] - d_vals = jnp.arange(D_MAX) - ϕ_vals = demand_pdf(p, d_vals) - revenue = jnp.sum(jnp.minimum(x, d_vals)*ϕ_vals) - profit = revenue - c * a - κ * (a > 0) - v_R = jnp.sum(v[jnp.maximum(x - d_vals, 0) + a].T * ϕ_vals, axis=1) - cv = jnp.sum(v_R*Q[i_z]) - return profit + z * cv + + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z = z_values[z_idx] + + def _B(a): + """ + Returns r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′) for each a. + """ + revenue = jnp.sum(jnp.minimum(x, d_values) * ϕ_values) + profit = revenue - c * a - κ * (a > 0) + v_R = jnp.sum(v[jnp.maximum(x - d_values, 0) + a].T * ϕ_values, axis=1) + cv = jnp.sum(v_R * Q[z_idx]) + return profit + z * cv + + a_values = x_values # Set of possible order sizes + B_values = jax.vmap(_B)(a_values) + max_x = len(x_values) - 1 + + return jnp.where(a_values <= max_x - x, B_values, -jnp.inf) ``` We need to vectorize this function so that we can use it efficiently in JAX. @@ -273,24 +276,8 @@ We apply a sequence of `vmap` operations to vectorize appropriately in each argument. ```{code-cell} ipython3 -B_vec_a = jax.vmap(B, in_axes=(None, None, 0, None, None)) -``` - -```{code-cell} ipython3 -@jax.jit -def B2(x, i_z, v, model): - """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). - """ - c, κ, p, z_vals, Q = model - a_vals = jnp.arange(K) - res = B_vec_a(x, i_z, a_vals, v, model) - return jnp.where(a_vals < K - x + 1, res, -jnp.inf) -``` - -```{code-cell} ipython3 -B2_vec_z = jax.vmap(B2, in_axes=(None, 0, None, None)) -B2_vec_z_x = jax.vmap(B2_vec_z, in_axes=(0, None, None, None)) +B = jax.vmap(B, in_axes=(None, 0, None, None)) +B = jax.vmap(B, in_axes=(0, None, None, None)) ``` Next we define the Bellman operator. @@ -299,10 +286,9 @@ Next we define the Bellman operator. @jax.jit def T(v, model): """The Bellman operator.""" - c, κ, p, z_vals, Q = model - i_z_range = jnp.arange(len(z_vals)) - x_range = jnp.arange(K + 1) - res = B2_vec_z_x(x_range, i_z_range, v, model) + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z_indices = jnp.arange(len(z_values)) + res = B(x_values, z_indices, v, model) return jnp.max(res, axis=2) ``` @@ -312,19 +298,34 @@ The following function computes a v-greedy policy. @jax.jit def get_greedy(v, model): """Get a v-greedy policy. Returns a zero-based array.""" - c, κ, p, z_vals, Q = model - i_z_range = jnp.arange(len(z_vals)) - x_range = jnp.arange(K + 1) - res = B2_vec_z_x(x_range, i_z_range, v, model) + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z_indices = jnp.arange(len(z_values)) + res = B(x_values, z_indices, v, model) return jnp.argmax(res, axis=2) ``` Here's code to solve the model using value function iteration. ```{code-cell} ipython3 -def solve_inventory_model(v_init, model): +@jax.jit +def solve_inventory_model(v_init, model, max_iter=10_000, tol=1e-6): """Use successive_approx to get v_star and then compute greedy.""" - v_star = successive_approx(lambda v: T(v, model), v_init, verbose=True) + + def update(state): + error, i, v = state + new_v = T(v, model) + new_error = jnp.max(jnp.abs(new_v - v)) + new_i = i + 1 + return new_error, new_i, new_v + + def test(state): + error, i, v = state + return (i < max_iter) & (error > tol) + + i, error = 0, tol + 1 + initial_state = error, i, v_init + final_state = jax.lax.while_loop(test, update, initial_state) + error, i, v_star = final_state σ_star = get_greedy(v_star, model) return v_star, σ_star ``` @@ -333,16 +334,21 @@ Now let's create an instance and solve it. ```{code-cell} ipython3 model = create_sdd_inventory_model() -c, κ, p, z_vals, Q = model -n_z = len(z_vals) -v_init = jnp.zeros((K + 1, n_z), dtype=float) +z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model +n_z = len(z_values) +n_x = len(x_values) +v_init = jnp.zeros((n_x, n_z), dtype=float) ``` ```{code-cell} ipython3 start = time() v_star, σ_star = solve_inventory_model(v_init, model) + +# Pause until execution finishes +jax.tree_util.tree_map(lambda x: x.block_until_ready(), (v_star, σ_star)) + jax_time_with_compile = time() - start -print("Jax compile plus execution time = ", jax_time_with_compile) +print(f"compile plus execution time = {jax_time_with_compile * 1000:.6f} ms") ``` Let's run again to get rid of the compile time. @@ -350,27 +356,44 @@ Let's run again to get rid of the compile time. ```{code-cell} ipython3 start = time() v_star, σ_star = solve_inventory_model(v_init, model) + +# Pause until execution finishes +jax.tree_util.tree_map(lambda x: x.block_until_ready(), (v_star, σ_star)) + jax_time_without_compile = time() - start -print("Jax execution time = ", jax_time_without_compile) +print(f"execution time = {jax_time_without_compile * 1000:.6f} ms") ``` +Now let's do a simulation. + +We'll begin by converting back to NumPy arrays for convenience + ```{code-cell} ipython3 -z_mc = qe.MarkovChain(Q, z_vals) +Q = np.array(Q) +z_values = np.array(z_values) +z_mc = qe.MarkovChain(Q, z_values) ``` +Here's code to simulate inventories + ```{code-cell} ipython3 def sim_inventories(ts_length, X_init=0): """Simulate given the optimal policy.""" global p, z_mc - i_z = z_mc.simulate_indices(ts_length, init=1) + + z_idx = z_mc.simulate_indices(ts_length, init=1) X = np.zeros(ts_length, dtype=np.int32) X[0] = X_init rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1 + for t in range(ts_length-1): - X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], i_z[t]] - return X, z_vals[i_z] + X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], z_idx[t]] + + return X, z_values[z_idx] ``` +Here's code to generate a plot. + ```{code-cell} ipython3 def plot_ts(ts_length=400, fontsize=10): X, Z = sim_inventories(ts_length) @@ -396,104 +419,8 @@ def plot_ts(ts_length=400, fontsize=10): plt.show() ``` -```{code-cell} ipython3 -plot_ts() -``` - -## Numba implementation - - -Let's try the same operations in Numba in order to compare the speed. - -```{code-cell} ipython3 -@numba.njit -def demand_pdf_numba(p, d): - return (1 - p)**d * p - -@numba.njit -def B_numba(x, i_z, a, v, model): - """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). - """ - c, κ, p, z_vals, Q = model - z = z_vals[i_z] - d_vals = np.arange(D_MAX) - ϕ_vals = demand_pdf_numba(p, d_vals) - revenue = np.sum(np.minimum(x, d_vals)*ϕ_vals) - profit = revenue - c * a - κ * (a > 0) - v_R = np.sum(v[np.maximum(x - d_vals, 0) + a].T * ϕ_vals, axis=1) - cv = np.sum(v_R*Q[i_z]) - return profit + z * cv - - -@numba.njit(parallel=True) -def T_numba(v, model): - """The Bellman operator.""" - c, κ, p, z_vals, Q = model - new_v = np.empty_like(v) - for i_z in prange(len(z_vals)): - for x in prange(K+1): - v_1 = np.array([B_numba(x, i_z, a, v, model) - for a in range(K-x+1)]) - new_v[x, i_z] = np.max(v_1) - return new_v - - -@numba.njit(parallel=True) -def get_greedy_numba(v, model): - """Get a v-greedy policy. Returns a zero-based array.""" - c, κ, p, z_vals, Q = model - n_z = len(z_vals) - σ_star = np.zeros((K+1, n_z), dtype=np.int32) - for i_z in prange(n_z): - for x in range(K+1): - v_1 = np.array([B_numba(x, i_z, a, v, model) - for a in range(K-x+1)]) - σ_star[x, i_z] = np.argmax(v_1) - return σ_star - - - -def solve_inventory_model_numba(v_init, model): - """Use successive_approx to get v_star and then compute greedy.""" - v_star = successive_approx(lambda v: T_numba(v, model), v_init, verbose=True) - σ_star = get_greedy_numba(v_star, model) - return v_star, σ_star -``` +Let's take a look. ```{code-cell} ipython3 -model = create_sdd_inventory_model(use_jax=False) -c, κ, p, z_vals, Q = model -n_z = len(z_vals) -v_init = np.zeros((K + 1, n_z), dtype=float) -``` - -```{code-cell} ipython3 -start = time() -v_star_numba, σ_star_numba = solve_inventory_model_numba(v_init, model) -numba_time_with_compile = time() - start -print("Numba compile plus execution time = ", numba_time_with_compile) -``` - -Let's run again to eliminate the compile time. - -```{code-cell} ipython3 -start = time() -v_star_numba, σ_star_numba = solve_inventory_model_numba(v_init, model) -numba_time_without_compile = time() - start -print("Numba execution time = ", numba_time_without_compile) -``` - -Let's verify that the Numba and JAX implementations converge to the same solution. - -```{code-cell} ipython3 -np.allclose(v_star_numba, v_star) -``` - -Here's the speed comparison. - -```{code-cell} ipython3 -print("JAX vectorized implementation is " - f"{numba_time_without_compile/jax_time_without_compile} faster " - "than Numba's parallel implementation") +plot_ts() ```