From a3419397f9d3f4fe3f273496dfcd6747af4a4792 Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Sun, 3 Aug 2025 15:56:59 +0900 Subject: [PATCH 1/4] misc --- lectures/inventory_ssd.py | 404 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 404 insertions(+) create mode 100644 lectures/inventory_ssd.py diff --git a/lectures/inventory_ssd.py b/lectures/inventory_ssd.py new file mode 100644 index 00000000..a45d737c --- /dev/null +++ b/lectures/inventory_ssd.py @@ -0,0 +1,404 @@ +# --- +# jupyter: +# jupytext: +# default_lexer: ipython3 +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.17.2 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Inventory Management Model +# +# ```{include} _admonition/gpu.md +# ``` +# +# +# This lecture provides a JAX implementation of a model in [Dynamic Programming](https://dp.quantecon.org/). +# +# In addition to JAX and Anaconda, this lecture will need the following libraries: + +# %% tags=["hide-output"] +# !pip install --upgrade quantecon + +# %% [markdown] +# ## A model with constant discounting +# +# +# We study a firm where a manager tries to maximize shareholder value. +# +# To simplify the problem, we assume that the firm only sells one product. +# +# Letting $\pi_t$ be profits at time $t$ and $r > 0$ be the interest rate, the value of the firm is +# +# $$ +# V_0 = \sum_{t \geq 0} \beta^t \pi_t +# \qquad +# \text{ where } +# \quad \beta := \frac{1}{1+r}. +# $$ +# +# Suppose the firm faces exogenous demand process $(D_t)_{t \geq 0}$. +# +# We assume $(D_t)_{t \geq 0}$ is IID with common distribution $\phi \in (Z_+)$. +# +# Inventory $(X_t)_{t \geq 0}$ of the product obeys +# +# $$ +# X_{t+1} = f(X_t, D_{t+1}, A_t) +# \qquad +# \text{where} +# \quad +# f(x,a,d) := (x - d)\vee 0 + a. +# $$ +# +# The term $A_t$ is units of stock ordered this period, which take one period to +# arrive. +# +# We assume that the firm can store at most $K$ items at one time. +# +# Profits are given by +# +# $$ +# \pi_t := X_t \wedge D_{t+1} - c A_t - \kappa 1\{A_t > 0\}. +# $$ +# +# We take the minimum of current stock and demand because orders in excess of +# inventory are assumed to be lost rather than back-filled. +# +# Here $c$ is unit product cost and $\kappa$ is a fixed cost of ordering inventory. +# +# +# We can map our inventory problem into a dynamic program with state space +# $X := \{0, \ldots, K\}$ and action space $A := X$. +# +# The feasible correspondence $\Gamma$ is +# +# $$ +# \Gamma(x) := \{0, \ldots, K - x\}, +# $$ +# +# which represents the set of feasible orders when the current inventory +# state is $x$. +# +# The reward function is expected current profits, or +# +# $$ +# r(x, a) := \sum_{d \geq 0} (x \wedge d) \phi(d) +# - c a - \kappa 1\{a > 0\}. +# $$ +# +# The stochastic kernel (i.e., state-transition probabilities) from the set of feasible state-action pairs is +# +# $$ +# P(x, a, x') := P\{ f(x, a, D) = x' \} +# \qquad \text{when} \quad +# D \sim \phi. +# $$ +# +# When discounting is constant, the Bellman equation takes the form +# +# ```{math} +# :label: inventory_ssd_v1 +# v(x) +# = \max_{a \in \Gamma(x)} \left\{ +# r(x, a) +# + \beta +# \sum_{d \geq 0} v(f(x, a, d)) \phi(d) +# \right\} +# ``` +# +# ## Time varing discount rates +# +# We wish to consider a more sophisticated model with time-varying discounting. +# +# This time variation accommodates non-constant interest rates. +# +# To this end, we replace the constant $\beta$ in +# {eq}`inventory_ssd_v1` with a stochastic process $(\beta_t)$ where +# +# * $\beta_t = 1/(1+r_t)$ and +# * $r_t$ is the interest rate at time $t$ +# +# We suppose that the dynamics can be expressed as $\beta_t = \beta(Z_t)$, where the exogenous process $(Z_t)_{t \geq 0}$ is a Markov chain +# on $Z$ with Markov matrix $Q$. +# +# After relabeling inventory $X_t$ as $Y_t$ and $x$ as $y$, the Bellman equation becomes +# +# $$ +# v(y, z) = \max_{a \in \Gamma(x)} B((y, z), a, v) +# $$ +# +# where +# +# ```{math} +# :label: inventory_ssd_b1 +# B((y, z), a, v) +# = +# r(y, a) +# + \beta(z) +# \sum_{d, \, z'} v(f(y, a, d), z') \phi(d) Q(z, z'). +# ``` +# +# We set +# +# $$ +# R(y, a, y') +# := P\{f(y, a, d) = y'\} \quad \text{when} \quad D \sim \phi, +# $$ +# +# Now $R(y, a, y')$ is the probability of realizing next period inventory level +# $y'$ when the current level is $y$ and the action is $a$. +# +# Hence we can rewrite {eq}`inventory_ssd_b1` as +# +# $$ +# B((y, z), a, v) +# = r(y, a) +# + \beta(z) +# \sum_{y', z'} v(y', z') Q(z, z') R(y, a, y') . +# $$ +# +# Let's begin with the following imports + +# %% +import quantecon as qe +import jax +import jax.numpy as jnp +import numpy as np +import matplotlib.pyplot as plt +from time import time +from functools import partial +from typing import NamedTuple + +# %% [markdown] +# Let's check the GPU we are running + +# %% +# !nvidia-smi + +# %% [markdown] +# We will use 64 bit floats with JAX in order to increase the precision. + +# %% +jax.config.update("jax_enable_x64", True) + +# %% [markdown] +# Let's define a model to represent the inventory management. + +# %% +# NamedTuple Model +class Model(NamedTuple): + z_values: jnp.ndarray # Exogenous shock values + Q: jnp.ndarray # Exogenous shock probabilities + x_values: jnp.ndarray # Inventory values + d_values: jnp.ndarray # Demand values for summation + ϕ_values: jnp.ndarray # Demand probabilities + p: float # Demand parameter + c: float = 0.2 # Unit cost + κ: float = 0.8 # Fixed cost + +# %% +def create_sdd_inventory_model( + ρ: float = 0.98, # Exogenous state autocorrelation parameter + ν: float = 0.002, # Exogenous state volatility parameter + n_z: int = 100, # Exogenous state discretization size + b: float = 0.97, # Exogenous state offset + K: int = 100, # Max inventory + D_MAX: int = 101, # Demand upper bound for summation + p: float = 0.6 + ) -> Model: + # Demand + def demand_pdf(p, d): + return (1 - p)**d * p + d_values = jnp.arange(D_MAX) + ϕ_values = demand_pdf(p, d_values) + # Exogenous state process + mc = qe.tauchen(n_z, ρ, ν) + z_values, Q = map(jnp.array, (mc.state_values + b, mc.P)) + # Endogenous state + x_values = jnp.arange(K + 1) # 0, 1, ..., K + return Model( + z_values=z_values, Q=Q, + x_values=x_values, d_values=d_values, ϕ_values=ϕ_values, + p=p + ) + + +# %% [markdown] +# Here's the function `B` on the right-hand side of the Bellman equation. + +# %% +@jax.jit +def B(x, z_idx, a, v, model): + """ + The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). + """ + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z = z_values[z_idx] + revenue = jnp.sum(jnp.minimum(x, d_values) * ϕ_values) + profit = revenue - c * a - κ * (a > 0) + v_R = jnp.sum(v[jnp.maximum(x - d_values, 0) + a].T * ϕ_values, axis=1) + cv = jnp.sum(v_R * Q[z_idx]) + return profit + z * cv + + +# %% [markdown] +# We need to vectorize this function so that we can use it efficiently in JAX. +# +# We apply a sequence of `vmap` operations to vectorize appropriately in each +# argument. + +# %% +B_vec_a = jax.vmap(B, in_axes=(None, None, 0, None, None)) + + +# %% +@jax.jit +def B2(x, z_idx, v, model): + """ + The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). + """ + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + a_values = x_values # Set of possible order sizes + max_x = len(x_values) - 1 + res = B_vec_a(x, z_idx, a_values, v, model) + return jnp.where(a_values <= max_x - x, res, -jnp.inf) + + +# %% +B2_vec_z = jax.vmap(B2, in_axes=(None, 0, None, None)) +B2_vec_z_x = jax.vmap(B2_vec_z, in_axes=(0, None, None, None)) + + +# %% [markdown] +# Next we define the Bellman operator. + +# %% +@jax.jit +def T(v, model): + """The Bellman operator.""" + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z_indices = jnp.arange(len(z_values)) + res = B2_vec_z_x(x_values, z_indices, v, model) + return jnp.max(res, axis=2) + + +# %% [markdown] +# The following function computes a v-greedy policy. + +# %% +@jax.jit +def get_greedy(v, model): + """Get a v-greedy policy. Returns a zero-based array.""" + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z_indices = jnp.arange(len(z_values)) + res = B2_vec_z_x(x_values, z_indices, v, model) + return jnp.argmax(res, axis=2) + + +# %% [markdown] +# Here's code to solve the model using value function iteration. + +# %% +def solve_inventory_model(v_init, model, max_iter=10_000, tol=1e-6): + """Use successive_approx to get v_star and then compute greedy.""" + + def update(state): + error, i, v = state + new_v = T(v, model) + new_error = jnp.max(jnp.abs(new_v - v)) + new_i = i + 1 + return new_error, new_i, new_v + + def test(state): + error, i, v = state + return (i < max_iter) & (error > tol) + + i, error = 0, tol + 1 + initial_state = error, i, v_init + final_state = jax.lax.while_loop(test, update, initial_state) + error, i, v_star = final_state + σ_star = get_greedy(v_star, model) + return v_star, σ_star + + +# %% [markdown] +# Now let's create an instance and solve it. + +# %% +model = create_sdd_inventory_model() +z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model +n_z = len(z_values) +n_x = len(x_values) +v_init = jnp.zeros((n_x, n_z), dtype=float) + +# %% +start = time() +v_star, σ_star = solve_inventory_model(v_init, model) +jax_time_with_compile = time() - start +print("Jax compile plus execution time = ", jax_time_with_compile) + +# %% [markdown] +# Let's run again to get rid of the compile time. + +# %% +start = time() +v_star, σ_star = solve_inventory_model(v_init, model) +jax_time_without_compile = time() - start +print("Jax execution time = ", jax_time_without_compile) + +# %% +Q = np.array(Q) +z_values = np.array(z_values) +z_mc = qe.MarkovChain(Q, z_values) + + +# %% +def sim_inventories(ts_length, X_init=0): + """Simulate given the optimal policy.""" + global p, z_mc + z_idx = z_mc.simulate_indices(ts_length, init=1) + X = np.zeros(ts_length, dtype=np.int32) + X[0] = X_init + rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1 + for t in range(ts_length-1): + X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], z_idx[t]] + return X, z_values[z_idx] + + +# %% +def plot_ts(ts_length=400, fontsize=10): + X, Z = sim_inventories(ts_length) + fig, axes = plt.subplots(2, 1, figsize=(9, 5.5)) + + ax = axes[0] + ax.plot(X, label=r"$X_t$", alpha=0.7) + ax.set_xlabel(r"$t$", fontsize=fontsize) + ax.set_ylabel("inventory", fontsize=fontsize) + ax.legend(fontsize=fontsize, frameon=False) + ax.set_ylim(0, np.max(X)+3) + + # calculate interest rate from discount factors + r = (1 / Z) - 1 + + ax = axes[1] + ax.plot(r, label=r"$r_t$", alpha=0.7) + ax.set_xlabel(r"$t$", fontsize=fontsize) + ax.set_ylabel("interest rate", fontsize=fontsize) + ax.legend(fontsize=fontsize, frameon=False) + + plt.tight_layout() + plt.show() + + +# %% +# plot_ts() + + From 9c78fc732b84346ebfec5a466aa4a3818de606be Mon Sep 17 00:00:00 2001 From: John Stachurski Date: Wed, 6 Aug 2025 05:57:58 +0900 Subject: [PATCH 2/4] misc --- lectures/inventory_ssd.md | 298 ++++++++++------------------ lectures/inventory_ssd.py | 404 -------------------------------------- 2 files changed, 107 insertions(+), 595 deletions(-) delete mode 100644 lectures/inventory_ssd.py diff --git a/lectures/inventory_ssd.md b/lectures/inventory_ssd.md index 7ad465e9..93dabf09 100644 --- a/lectures/inventory_ssd.md +++ b/lectures/inventory_ssd.md @@ -4,7 +4,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.17.2 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -172,10 +172,9 @@ import jax import jax.numpy as jnp import numpy as np import matplotlib.pyplot as plt -from collections import namedtuple -import numba -from numba import prange from time import time +from functools import partial +from typing import NamedTuple ``` Let's check the GPU we are running @@ -194,77 +193,75 @@ Let's define a model to represent the inventory management. ```{code-cell} ipython3 # NamedTuple Model -Model = namedtuple("Model", ("c", "κ", "p", "z_vals", "Q")) +class Model(NamedTuple): + z_values: jnp.ndarray # Exogenous shock values + Q: jnp.ndarray # Exogenous shock probabilities + x_values: jnp.ndarray # Inventory values + d_values: jnp.ndarray # Demand values for summation + ϕ_values: jnp.ndarray # Demand probabilities + p: float # Demand parameter + c: float = 0.2 # Unit cost + κ: float = 0.8 # Fixed cost ``` -We need the following successive approximation function. - -```{code-cell} ipython3 -def successive_approx(T, # Operator (callable) - x_0, # Initial condition - tolerance=1e-6, # Error tolerance - max_iter=10_000, # Max iteration bound - print_step=25, # Print at multiples - verbose=False): - x = x_0 - error = tolerance + 1 - k = 1 - while error > tolerance and k <= max_iter: - x_new = T(x) - error = jnp.max(jnp.abs(x_new - x)) - if verbose and k % print_step == 0: - print(f"Completed iteration {k} with error {error}.") - x = x_new - k += 1 - if error > tolerance: - print(f"Warning: Iteration hit upper bound {max_iter}.") - elif verbose: - print(f"Terminated successfully in {k} iterations.") - return x -``` - -```{code-cell} ipython3 -@jax.jit -def demand_pdf(p, d): - return (1 - p)**d * p -``` - -```{code-cell} ipython3 -K = 100 -D_MAX = 101 -``` - -Let's define a function to create an inventory model using the given parameters. - ```{code-cell} ipython3 def create_sdd_inventory_model( - ρ=0.98, ν=0.002, n_z=100, b=0.97, # Z state parameters - c=0.2, κ=0.8, p=0.6, # firm and demand parameters - use_jax=True): + ρ: float = 0.98, # Exogenous state autocorrelation parameter + ν: float = 0.002, # Exogenous state volatility parameter + n_z: int = 10, # Exogenous state discretization size + b: float = 0.97, # Exogenous state offset + K: int = 100, # Max inventory + D_MAX: int = 101, # Demand upper bound for summation + p: float = 0.6 + ) -> Model: + # Demand + def demand_pdf(p, d): + return (1 - p)**d * p + d_values = jnp.arange(D_MAX) + ϕ_values = demand_pdf(p, d_values) + # Exogenous state process mc = qe.tauchen(n_z, ρ, ν) - z_vals, Q = mc.state_values + b, mc.P - if use_jax: - z_vals, Q = map(jnp.array, (z_vals, Q)) - return Model(c=c, κ=κ, p=p, z_vals=z_vals, Q=Q) + z_values, Q = map(jnp.array, (mc.state_values + b, mc.P)) + # Endogenous state + x_values = jnp.arange(K + 1) # 0, 1, ..., K + return Model( + z_values=z_values, Q=Q, + x_values=x_values, d_values=d_values, ϕ_values=ϕ_values, + p=p + ) ``` Here's the function `B` on the right-hand side of the Bellman equation. ```{code-cell} ipython3 @jax.jit -def B(x, i_z, a, v, model): +def B(x, z_idx, v, model): """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). + Take z_idx and convert it to z. Then compute + + B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′) + + for all possible choices of a. + """ - c, κ, p, z_vals, Q = model - z = z_vals[i_z] - d_vals = jnp.arange(D_MAX) - ϕ_vals = demand_pdf(p, d_vals) - revenue = jnp.sum(jnp.minimum(x, d_vals)*ϕ_vals) - profit = revenue - c * a - κ * (a > 0) - v_R = jnp.sum(v[jnp.maximum(x - d_vals, 0) + a].T * ϕ_vals, axis=1) - cv = jnp.sum(v_R*Q[i_z]) - return profit + z * cv + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z = z_values[z_idx] + + def _B(a): + """ + Returns r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′) for each a. + + """ + revenue = jnp.sum(jnp.minimum(x, d_values) * ϕ_values) + profit = revenue - c * a - κ * (a > 0) + v_R = jnp.sum(v[jnp.maximum(x - d_values, 0) + a].T * ϕ_values, axis=1) + cv = jnp.sum(v_R * Q[z_idx]) + return profit + z * cv + + a_values = x_values # Set of possible order sizes + B_values = jax.vmap(_B)(a_values) + max_x = len(x_values) - 1 + return jnp.where(a_values <= max_x - x, B_values, -jnp.inf) ``` We need to vectorize this function so that we can use it efficiently in JAX. @@ -273,24 +270,12 @@ We apply a sequence of `vmap` operations to vectorize appropriately in each argument. ```{code-cell} ipython3 -B_vec_a = jax.vmap(B, in_axes=(None, None, 0, None, None)) -``` -```{code-cell} ipython3 -@jax.jit -def B2(x, i_z, v, model): - """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). - """ - c, κ, p, z_vals, Q = model - a_vals = jnp.arange(K) - res = B_vec_a(x, i_z, a_vals, v, model) - return jnp.where(a_vals < K - x + 1, res, -jnp.inf) ``` ```{code-cell} ipython3 -B2_vec_z = jax.vmap(B2, in_axes=(None, 0, None, None)) -B2_vec_z_x = jax.vmap(B2_vec_z, in_axes=(0, None, None, None)) +B = jax.vmap(B, in_axes=(None, 0, None, None)) +B = jax.vmap(B, in_axes=(0, None, None, None)) ``` Next we define the Bellman operator. @@ -299,10 +284,9 @@ Next we define the Bellman operator. @jax.jit def T(v, model): """The Bellman operator.""" - c, κ, p, z_vals, Q = model - i_z_range = jnp.arange(len(z_vals)) - x_range = jnp.arange(K + 1) - res = B2_vec_z_x(x_range, i_z_range, v, model) + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z_indices = jnp.arange(len(z_values)) + res = B(x_values, z_indices, v, model) return jnp.max(res, axis=2) ``` @@ -312,19 +296,34 @@ The following function computes a v-greedy policy. @jax.jit def get_greedy(v, model): """Get a v-greedy policy. Returns a zero-based array.""" - c, κ, p, z_vals, Q = model - i_z_range = jnp.arange(len(z_vals)) - x_range = jnp.arange(K + 1) - res = B2_vec_z_x(x_range, i_z_range, v, model) + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model + z_indices = jnp.arange(len(z_values)) + res = B(x_values, z_indices, v, model) return jnp.argmax(res, axis=2) ``` Here's code to solve the model using value function iteration. ```{code-cell} ipython3 -def solve_inventory_model(v_init, model): +@jax.jit +def solve_inventory_model(v_init, model, max_iter=10_000, tol=1e-6): """Use successive_approx to get v_star and then compute greedy.""" - v_star = successive_approx(lambda v: T(v, model), v_init, verbose=True) + + def update(state): + error, i, v = state + new_v = T(v, model) + new_error = jnp.max(jnp.abs(new_v - v)) + new_i = i + 1 + return new_error, new_i, new_v + + def test(state): + error, i, v = state + return (i < max_iter) & (error > tol) + + i, error = 0, tol + 1 + initial_state = error, i, v_init + final_state = jax.lax.while_loop(test, update, initial_state) + error, i, v_star = final_state σ_star = get_greedy(v_star, model) return v_star, σ_star ``` @@ -333,14 +332,16 @@ Now let's create an instance and solve it. ```{code-cell} ipython3 model = create_sdd_inventory_model() -c, κ, p, z_vals, Q = model -n_z = len(z_vals) -v_init = jnp.zeros((K + 1, n_z), dtype=float) +z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model +n_z = len(z_values) +n_x = len(x_values) +v_init = jnp.zeros((n_x, n_z), dtype=float) ``` ```{code-cell} ipython3 start = time() v_star, σ_star = solve_inventory_model(v_init, model) +v_star.block_until_ready() # Pause until execution finishes jax_time_with_compile = time() - start print("Jax compile plus execution time = ", jax_time_with_compile) ``` @@ -350,27 +351,38 @@ Let's run again to get rid of the compile time. ```{code-cell} ipython3 start = time() v_star, σ_star = solve_inventory_model(v_init, model) +v_star.block_until_ready() # Pause until execution finishes jax_time_without_compile = time() - start print("Jax execution time = ", jax_time_without_compile) ``` +Now let's do a simulation. + +We'll begin by converting back to NumPy arrays for convenience + ```{code-cell} ipython3 -z_mc = qe.MarkovChain(Q, z_vals) +Q = np.array(Q) +z_values = np.array(z_values) +z_mc = qe.MarkovChain(Q, z_values) ``` +Here's code to simulate inventories + ```{code-cell} ipython3 def sim_inventories(ts_length, X_init=0): """Simulate given the optimal policy.""" global p, z_mc - i_z = z_mc.simulate_indices(ts_length, init=1) + z_idx = z_mc.simulate_indices(ts_length, init=1) X = np.zeros(ts_length, dtype=np.int32) X[0] = X_init rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1 for t in range(ts_length-1): - X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], i_z[t]] - return X, z_vals[i_z] + X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], z_idx[t]] + return X, z_values[z_idx] ``` +Here's code to generate a plot. + ```{code-cell} ipython3 def plot_ts(ts_length=400, fontsize=10): X, Z = sim_inventories(ts_length) @@ -396,104 +408,8 @@ def plot_ts(ts_length=400, fontsize=10): plt.show() ``` -```{code-cell} ipython3 -plot_ts() -``` - -## Numba implementation - - -Let's try the same operations in Numba in order to compare the speed. +Let's take a look. ```{code-cell} ipython3 -@numba.njit -def demand_pdf_numba(p, d): - return (1 - p)**d * p - -@numba.njit -def B_numba(x, i_z, a, v, model): - """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). - """ - c, κ, p, z_vals, Q = model - z = z_vals[i_z] - d_vals = np.arange(D_MAX) - ϕ_vals = demand_pdf_numba(p, d_vals) - revenue = np.sum(np.minimum(x, d_vals)*ϕ_vals) - profit = revenue - c * a - κ * (a > 0) - v_R = np.sum(v[np.maximum(x - d_vals, 0) + a].T * ϕ_vals, axis=1) - cv = np.sum(v_R*Q[i_z]) - return profit + z * cv - - -@numba.njit(parallel=True) -def T_numba(v, model): - """The Bellman operator.""" - c, κ, p, z_vals, Q = model - new_v = np.empty_like(v) - for i_z in prange(len(z_vals)): - for x in prange(K+1): - v_1 = np.array([B_numba(x, i_z, a, v, model) - for a in range(K-x+1)]) - new_v[x, i_z] = np.max(v_1) - return new_v - - -@numba.njit(parallel=True) -def get_greedy_numba(v, model): - """Get a v-greedy policy. Returns a zero-based array.""" - c, κ, p, z_vals, Q = model - n_z = len(z_vals) - σ_star = np.zeros((K+1, n_z), dtype=np.int32) - for i_z in prange(n_z): - for x in range(K+1): - v_1 = np.array([B_numba(x, i_z, a, v, model) - for a in range(K-x+1)]) - σ_star[x, i_z] = np.argmax(v_1) - return σ_star - - - -def solve_inventory_model_numba(v_init, model): - """Use successive_approx to get v_star and then compute greedy.""" - v_star = successive_approx(lambda v: T_numba(v, model), v_init, verbose=True) - σ_star = get_greedy_numba(v_star, model) - return v_star, σ_star -``` - -```{code-cell} ipython3 -model = create_sdd_inventory_model(use_jax=False) -c, κ, p, z_vals, Q = model -n_z = len(z_vals) -v_init = np.zeros((K + 1, n_z), dtype=float) -``` - -```{code-cell} ipython3 -start = time() -v_star_numba, σ_star_numba = solve_inventory_model_numba(v_init, model) -numba_time_with_compile = time() - start -print("Numba compile plus execution time = ", numba_time_with_compile) -``` - -Let's run again to eliminate the compile time. - -```{code-cell} ipython3 -start = time() -v_star_numba, σ_star_numba = solve_inventory_model_numba(v_init, model) -numba_time_without_compile = time() - start -print("Numba execution time = ", numba_time_without_compile) -``` - -Let's verify that the Numba and JAX implementations converge to the same solution. - -```{code-cell} ipython3 -np.allclose(v_star_numba, v_star) -``` - -Here's the speed comparison. - -```{code-cell} ipython3 -print("JAX vectorized implementation is " - f"{numba_time_without_compile/jax_time_without_compile} faster " - "than Numba's parallel implementation") +plot_ts() ``` diff --git a/lectures/inventory_ssd.py b/lectures/inventory_ssd.py deleted file mode 100644 index a45d737c..00000000 --- a/lectures/inventory_ssd.py +++ /dev/null @@ -1,404 +0,0 @@ -# --- -# jupyter: -# jupytext: -# default_lexer: ipython3 -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.17.2 -# kernelspec: -# display_name: Python 3 (ipykernel) -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Inventory Management Model -# -# ```{include} _admonition/gpu.md -# ``` -# -# -# This lecture provides a JAX implementation of a model in [Dynamic Programming](https://dp.quantecon.org/). -# -# In addition to JAX and Anaconda, this lecture will need the following libraries: - -# %% tags=["hide-output"] -# !pip install --upgrade quantecon - -# %% [markdown] -# ## A model with constant discounting -# -# -# We study a firm where a manager tries to maximize shareholder value. -# -# To simplify the problem, we assume that the firm only sells one product. -# -# Letting $\pi_t$ be profits at time $t$ and $r > 0$ be the interest rate, the value of the firm is -# -# $$ -# V_0 = \sum_{t \geq 0} \beta^t \pi_t -# \qquad -# \text{ where } -# \quad \beta := \frac{1}{1+r}. -# $$ -# -# Suppose the firm faces exogenous demand process $(D_t)_{t \geq 0}$. -# -# We assume $(D_t)_{t \geq 0}$ is IID with common distribution $\phi \in (Z_+)$. -# -# Inventory $(X_t)_{t \geq 0}$ of the product obeys -# -# $$ -# X_{t+1} = f(X_t, D_{t+1}, A_t) -# \qquad -# \text{where} -# \quad -# f(x,a,d) := (x - d)\vee 0 + a. -# $$ -# -# The term $A_t$ is units of stock ordered this period, which take one period to -# arrive. -# -# We assume that the firm can store at most $K$ items at one time. -# -# Profits are given by -# -# $$ -# \pi_t := X_t \wedge D_{t+1} - c A_t - \kappa 1\{A_t > 0\}. -# $$ -# -# We take the minimum of current stock and demand because orders in excess of -# inventory are assumed to be lost rather than back-filled. -# -# Here $c$ is unit product cost and $\kappa$ is a fixed cost of ordering inventory. -# -# -# We can map our inventory problem into a dynamic program with state space -# $X := \{0, \ldots, K\}$ and action space $A := X$. -# -# The feasible correspondence $\Gamma$ is -# -# $$ -# \Gamma(x) := \{0, \ldots, K - x\}, -# $$ -# -# which represents the set of feasible orders when the current inventory -# state is $x$. -# -# The reward function is expected current profits, or -# -# $$ -# r(x, a) := \sum_{d \geq 0} (x \wedge d) \phi(d) -# - c a - \kappa 1\{a > 0\}. -# $$ -# -# The stochastic kernel (i.e., state-transition probabilities) from the set of feasible state-action pairs is -# -# $$ -# P(x, a, x') := P\{ f(x, a, D) = x' \} -# \qquad \text{when} \quad -# D \sim \phi. -# $$ -# -# When discounting is constant, the Bellman equation takes the form -# -# ```{math} -# :label: inventory_ssd_v1 -# v(x) -# = \max_{a \in \Gamma(x)} \left\{ -# r(x, a) -# + \beta -# \sum_{d \geq 0} v(f(x, a, d)) \phi(d) -# \right\} -# ``` -# -# ## Time varing discount rates -# -# We wish to consider a more sophisticated model with time-varying discounting. -# -# This time variation accommodates non-constant interest rates. -# -# To this end, we replace the constant $\beta$ in -# {eq}`inventory_ssd_v1` with a stochastic process $(\beta_t)$ where -# -# * $\beta_t = 1/(1+r_t)$ and -# * $r_t$ is the interest rate at time $t$ -# -# We suppose that the dynamics can be expressed as $\beta_t = \beta(Z_t)$, where the exogenous process $(Z_t)_{t \geq 0}$ is a Markov chain -# on $Z$ with Markov matrix $Q$. -# -# After relabeling inventory $X_t$ as $Y_t$ and $x$ as $y$, the Bellman equation becomes -# -# $$ -# v(y, z) = \max_{a \in \Gamma(x)} B((y, z), a, v) -# $$ -# -# where -# -# ```{math} -# :label: inventory_ssd_b1 -# B((y, z), a, v) -# = -# r(y, a) -# + \beta(z) -# \sum_{d, \, z'} v(f(y, a, d), z') \phi(d) Q(z, z'). -# ``` -# -# We set -# -# $$ -# R(y, a, y') -# := P\{f(y, a, d) = y'\} \quad \text{when} \quad D \sim \phi, -# $$ -# -# Now $R(y, a, y')$ is the probability of realizing next period inventory level -# $y'$ when the current level is $y$ and the action is $a$. -# -# Hence we can rewrite {eq}`inventory_ssd_b1` as -# -# $$ -# B((y, z), a, v) -# = r(y, a) -# + \beta(z) -# \sum_{y', z'} v(y', z') Q(z, z') R(y, a, y') . -# $$ -# -# Let's begin with the following imports - -# %% -import quantecon as qe -import jax -import jax.numpy as jnp -import numpy as np -import matplotlib.pyplot as plt -from time import time -from functools import partial -from typing import NamedTuple - -# %% [markdown] -# Let's check the GPU we are running - -# %% -# !nvidia-smi - -# %% [markdown] -# We will use 64 bit floats with JAX in order to increase the precision. - -# %% -jax.config.update("jax_enable_x64", True) - -# %% [markdown] -# Let's define a model to represent the inventory management. - -# %% -# NamedTuple Model -class Model(NamedTuple): - z_values: jnp.ndarray # Exogenous shock values - Q: jnp.ndarray # Exogenous shock probabilities - x_values: jnp.ndarray # Inventory values - d_values: jnp.ndarray # Demand values for summation - ϕ_values: jnp.ndarray # Demand probabilities - p: float # Demand parameter - c: float = 0.2 # Unit cost - κ: float = 0.8 # Fixed cost - -# %% -def create_sdd_inventory_model( - ρ: float = 0.98, # Exogenous state autocorrelation parameter - ν: float = 0.002, # Exogenous state volatility parameter - n_z: int = 100, # Exogenous state discretization size - b: float = 0.97, # Exogenous state offset - K: int = 100, # Max inventory - D_MAX: int = 101, # Demand upper bound for summation - p: float = 0.6 - ) -> Model: - # Demand - def demand_pdf(p, d): - return (1 - p)**d * p - d_values = jnp.arange(D_MAX) - ϕ_values = demand_pdf(p, d_values) - # Exogenous state process - mc = qe.tauchen(n_z, ρ, ν) - z_values, Q = map(jnp.array, (mc.state_values + b, mc.P)) - # Endogenous state - x_values = jnp.arange(K + 1) # 0, 1, ..., K - return Model( - z_values=z_values, Q=Q, - x_values=x_values, d_values=d_values, ϕ_values=ϕ_values, - p=p - ) - - -# %% [markdown] -# Here's the function `B` on the right-hand side of the Bellman equation. - -# %% -@jax.jit -def B(x, z_idx, a, v, model): - """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). - """ - z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model - z = z_values[z_idx] - revenue = jnp.sum(jnp.minimum(x, d_values) * ϕ_values) - profit = revenue - c * a - κ * (a > 0) - v_R = jnp.sum(v[jnp.maximum(x - d_values, 0) + a].T * ϕ_values, axis=1) - cv = jnp.sum(v_R * Q[z_idx]) - return profit + z * cv - - -# %% [markdown] -# We need to vectorize this function so that we can use it efficiently in JAX. -# -# We apply a sequence of `vmap` operations to vectorize appropriately in each -# argument. - -# %% -B_vec_a = jax.vmap(B, in_axes=(None, None, 0, None, None)) - - -# %% -@jax.jit -def B2(x, z_idx, v, model): - """ - The function B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′). - """ - z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model - a_values = x_values # Set of possible order sizes - max_x = len(x_values) - 1 - res = B_vec_a(x, z_idx, a_values, v, model) - return jnp.where(a_values <= max_x - x, res, -jnp.inf) - - -# %% -B2_vec_z = jax.vmap(B2, in_axes=(None, 0, None, None)) -B2_vec_z_x = jax.vmap(B2_vec_z, in_axes=(0, None, None, None)) - - -# %% [markdown] -# Next we define the Bellman operator. - -# %% -@jax.jit -def T(v, model): - """The Bellman operator.""" - z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model - z_indices = jnp.arange(len(z_values)) - res = B2_vec_z_x(x_values, z_indices, v, model) - return jnp.max(res, axis=2) - - -# %% [markdown] -# The following function computes a v-greedy policy. - -# %% -@jax.jit -def get_greedy(v, model): - """Get a v-greedy policy. Returns a zero-based array.""" - z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model - z_indices = jnp.arange(len(z_values)) - res = B2_vec_z_x(x_values, z_indices, v, model) - return jnp.argmax(res, axis=2) - - -# %% [markdown] -# Here's code to solve the model using value function iteration. - -# %% -def solve_inventory_model(v_init, model, max_iter=10_000, tol=1e-6): - """Use successive_approx to get v_star and then compute greedy.""" - - def update(state): - error, i, v = state - new_v = T(v, model) - new_error = jnp.max(jnp.abs(new_v - v)) - new_i = i + 1 - return new_error, new_i, new_v - - def test(state): - error, i, v = state - return (i < max_iter) & (error > tol) - - i, error = 0, tol + 1 - initial_state = error, i, v_init - final_state = jax.lax.while_loop(test, update, initial_state) - error, i, v_star = final_state - σ_star = get_greedy(v_star, model) - return v_star, σ_star - - -# %% [markdown] -# Now let's create an instance and solve it. - -# %% -model = create_sdd_inventory_model() -z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model -n_z = len(z_values) -n_x = len(x_values) -v_init = jnp.zeros((n_x, n_z), dtype=float) - -# %% -start = time() -v_star, σ_star = solve_inventory_model(v_init, model) -jax_time_with_compile = time() - start -print("Jax compile plus execution time = ", jax_time_with_compile) - -# %% [markdown] -# Let's run again to get rid of the compile time. - -# %% -start = time() -v_star, σ_star = solve_inventory_model(v_init, model) -jax_time_without_compile = time() - start -print("Jax execution time = ", jax_time_without_compile) - -# %% -Q = np.array(Q) -z_values = np.array(z_values) -z_mc = qe.MarkovChain(Q, z_values) - - -# %% -def sim_inventories(ts_length, X_init=0): - """Simulate given the optimal policy.""" - global p, z_mc - z_idx = z_mc.simulate_indices(ts_length, init=1) - X = np.zeros(ts_length, dtype=np.int32) - X[0] = X_init - rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1 - for t in range(ts_length-1): - X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], z_idx[t]] - return X, z_values[z_idx] - - -# %% -def plot_ts(ts_length=400, fontsize=10): - X, Z = sim_inventories(ts_length) - fig, axes = plt.subplots(2, 1, figsize=(9, 5.5)) - - ax = axes[0] - ax.plot(X, label=r"$X_t$", alpha=0.7) - ax.set_xlabel(r"$t$", fontsize=fontsize) - ax.set_ylabel("inventory", fontsize=fontsize) - ax.legend(fontsize=fontsize, frameon=False) - ax.set_ylim(0, np.max(X)+3) - - # calculate interest rate from discount factors - r = (1 / Z) - 1 - - ax = axes[1] - ax.plot(r, label=r"$r_t$", alpha=0.7) - ax.set_xlabel(r"$t$", fontsize=fontsize) - ax.set_ylabel("interest rate", fontsize=fontsize) - ax.legend(fontsize=fontsize, frameon=False) - - plt.tight_layout() - plt.show() - - -# %% -# plot_ts() - - From d1dd920dc7a4b59f06e5290806184f0e9aa09fcb Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Fri, 8 Aug 2025 13:53:42 +1000 Subject: [PATCH 3/4] minor updates --- lectures/inventory_ssd.md | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/lectures/inventory_ssd.md b/lectures/inventory_ssd.md index 93dabf09..e984a0e0 100644 --- a/lectures/inventory_ssd.md +++ b/lectures/inventory_ssd.md @@ -4,7 +4,8 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.17.2 + jupytext_version: 1.17.1 + formats: md:myst,ipynb kernelspec: display_name: Python 3 (ipykernel) language: python @@ -208,22 +209,27 @@ class Model(NamedTuple): def create_sdd_inventory_model( ρ: float = 0.98, # Exogenous state autocorrelation parameter ν: float = 0.002, # Exogenous state volatility parameter - n_z: int = 10, # Exogenous state discretization size + n_z: int = 10, # Exogenous state discretization size b: float = 0.97, # Exogenous state offset K: int = 100, # Max inventory D_MAX: int = 101, # Demand upper bound for summation p: float = 0.6 ) -> Model: + # Demand def demand_pdf(p, d): return (1 - p)**d * p + d_values = jnp.arange(D_MAX) ϕ_values = demand_pdf(p, d_values) + # Exogenous state process mc = qe.tauchen(n_z, ρ, ν) z_values, Q = map(jnp.array, (mc.state_values + b, mc.P)) + # Endogenous state x_values = jnp.arange(K + 1) # 0, 1, ..., K + return Model( z_values=z_values, Q=Q, x_values=x_values, d_values=d_values, ϕ_values=ϕ_values, @@ -242,15 +248,14 @@ def B(x, z_idx, v, model): B(x, z, a, v) = r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′) for all possible choices of a. - """ + z_values, Q, x_values, d_values, ϕ_values, p, c, κ = model z = z_values[z_idx] def _B(a): """ Returns r(x, a) + β(z) Σ_x′ v(x′) P(x, a, x′) for each a. - """ revenue = jnp.sum(jnp.minimum(x, d_values) * ϕ_values) profit = revenue - c * a - κ * (a > 0) @@ -261,6 +266,7 @@ def B(x, z_idx, v, model): a_values = x_values # Set of possible order sizes B_values = jax.vmap(_B)(a_values) max_x = len(x_values) - 1 + return jnp.where(a_values <= max_x - x, B_values, -jnp.inf) ``` @@ -269,10 +275,6 @@ We need to vectorize this function so that we can use it efficiently in JAX. We apply a sequence of `vmap` operations to vectorize appropriately in each argument. -```{code-cell} ipython3 - -``` - ```{code-cell} ipython3 B = jax.vmap(B, in_axes=(None, 0, None, None)) B = jax.vmap(B, in_axes=(0, None, None, None)) @@ -341,9 +343,12 @@ v_init = jnp.zeros((n_x, n_z), dtype=float) ```{code-cell} ipython3 start = time() v_star, σ_star = solve_inventory_model(v_init, model) -v_star.block_until_ready() # Pause until execution finishes + +# Pause until execution finishes +jax.tree_util.tree_map(lambda x: x.block_until_ready(), (v_star, σ_star)) + jax_time_with_compile = time() - start -print("Jax compile plus execution time = ", jax_time_with_compile) +print(f"compile plus execution time = {jax_time_with_compile * 1000:.6f} ms") ``` Let's run again to get rid of the compile time. @@ -351,9 +356,12 @@ Let's run again to get rid of the compile time. ```{code-cell} ipython3 start = time() v_star, σ_star = solve_inventory_model(v_init, model) -v_star.block_until_ready() # Pause until execution finishes + +# Pause until execution finishes +jax.tree_util.tree_map(lambda x: x.block_until_ready(), (v_star, σ_star)) + jax_time_without_compile = time() - start -print("Jax execution time = ", jax_time_without_compile) +print(f"execution time = {jax_time_without_compile * 1000:.6f} ms") ``` Now let's do a simulation. @@ -372,12 +380,15 @@ Here's code to simulate inventories def sim_inventories(ts_length, X_init=0): """Simulate given the optimal policy.""" global p, z_mc + z_idx = z_mc.simulate_indices(ts_length, init=1) X = np.zeros(ts_length, dtype=np.int32) X[0] = X_init rand = np.random.default_rng().geometric(p=p, size=ts_length-1) - 1 + for t in range(ts_length-1): X[t+1] = np.maximum(X[t] - rand[t], 0) + σ_star[X[t], z_idx[t]] + return X, z_values[z_idx] ``` From 292b69c094face0e1b7e66acc4eb33204b299cee Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Mon, 11 Aug 2025 11:53:06 +1000 Subject: [PATCH 4/4] minor update --- lectures/inventory_ssd.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lectures/inventory_ssd.md b/lectures/inventory_ssd.md index e984a0e0..985c99cc 100644 --- a/lectures/inventory_ssd.md +++ b/lectures/inventory_ssd.md @@ -146,7 +146,7 @@ where \sum_{d, \, z'} v(f(y, a, d), z') \phi(d) Q(z, z'). ``` -We set +We set $\beta(z) := z$ and $$ R(y, a, y')