<a href="https://colab.research.google.com/github/QuantEcon/lecture-python.myst/blob/update_markov_asset/lectures/markov_asset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import matplotlib.pyplot as plt
import quantecon as qe
import jax
import jax.numpy as jnp
from jax.numpy.linalg import eigvals, solve
from jax.experimental import checkify
from typing import NamedTuple

In [3]:
class MarkovChain(NamedTuple):
    """
    Primitives of a finite-state Markov chain.

    Parameters
    ----------
    P : jax.Array
        Transition matrix (presumably row-stochastic: row i holds the
        next-state distribution from state i).
    state_values : jax.Array
        Value attached to each state of the chain.
    """
    # Field order defines the positional constructor; keep it fixed.
    P: jax.Array
    state_values: jax.Array


class AssetPriceModel(NamedTuple):
    """
    Primitives of the Markov asset pricing model.

    Parameters
    ----------
    mc : MarkovChain
        Underlying chain: transition matrix plus state values.
    G : jax.Array
        Growth rate in each state, i.e. the growth function g evaluated
        on the chain's state values.
    β : float
        Discount factor.
    γ : float
        Coefficient of risk aversion.
    """
    mc: MarkovChain  # underlying Markov chain
    G: jax.Array     # growth rates, one entry per state
    β: float         # discount factor
    γ: float         # coefficient of risk aversion


def create_ap_model(g=jnp.exp, β=0.96, γ=2.0, n=25, ρ=0.9, σ=0.02):
    """
    Create an AssetPriceModel instance driven by a Tauchen-discretized chain.

    Parameters
    ----------
    g : callable, optional(default=jnp.exp)
        Applied to the chain's state values to produce the growth vector G.
    β : float, optional(default=0.96)
        Discount factor.
    γ : float, optional(default=2.0)
        Coefficient of risk aversion.
    n : int, optional(default=25)
        Number of states, passed to ``qe.tauchen``.
    ρ : float, optional(default=0.9)
        Autocorrelation coefficient, passed to ``qe.tauchen``.
    σ : float, optional(default=0.02)
        Shock standard deviation, passed to ``qe.tauchen``.

    Returns
    -------
    AssetPriceModel
        Model built on the discretized chain; with the default arguments it
        matches the previous hard-coded (n=25, ρ=0.9, σ=0.02) configuration.
    """
    qe_mc = qe.tauchen(n, ρ, σ)
    # Convert QuantEcon's NumPy arrays to JAX arrays for the jitted solvers.
    mc = MarkovChain(P=jnp.array(qe_mc.P),
                     state_values=jnp.array(qe_mc.state_values))

    # Reuse the generic constructor so G is built in exactly one place.
    return create_customized_ap_model(mc, g=g, β=β, γ=γ)


def create_customized_ap_model(mc: MarkovChain, g=jnp.exp, β=0.96, γ=2.0):
    """
    Build an AssetPriceModel instance on top of a user-supplied Markov chain.

    Parameters
    ----------
    mc : MarkovChain
        Chain providing the transition matrix and the state values.
    g : callable, optional(default=jnp.exp)
        Applied to ``mc.state_values`` to produce the growth vector G.
    β : float, optional(default=0.96)
        Discount factor.
    γ : float, optional(default=2.0)
        Coefficient of risk aversion.

    Returns
    -------
    AssetPriceModel
    """
    growth = g(mc.state_values)
    return AssetPriceModel(mc=mc, G=growth, β=β, γ=γ)


def test_stability(Q, β):
    """
    Check that the spectral radius of Q is strictly below 1 / β.

    The pricing functions call this before solving (I - β Q) x = b to make
    sure a unique solution exists.

    Parameters
    ----------
    Q : jax.Array
        Square matrix whose spectral radius is tested.
    β : float
        Discount factor.

    Returns
    -------
    sr : jax.Array
        Spectral radius of Q, max |eigenvalue| (a scalar).

    Notes
    -----
    The condition is registered with ``checkify.check``. Callers that are
    jitted must be wrapped with ``checkify.checkify`` so the failure is
    surfaced through the returned error object.
    """
    sr = jnp.max(jnp.abs(eigvals(Q)))
    checkify.check(
        sr < 1 / β,
        "Spectral radius condition failed with radius = {sr}", sr=sr
        )
    return sr


# The helper's name matches pytest's default ``test_*`` discovery pattern.
# Opt out so importing this module into a test file does not turn it into
# a test that fails with "fixture 'Q' not found".
test_stability.__test__ = False


def tree_price(ap):
    """
    Price-dividend ratio of the Lucas tree.

    With J[i, j] = P[i, j] * G[j]**(1 - γ), the ratio v satisfies
    v = β J (1 + v). Equivalently, it is the solution of the linear system
    (I - β J) v = β J 1.

    Parameters
    ----------
    ap : AssetPriceModel
        Model primitives (transition matrix, growth vector G, β and γ).

    Returns
    -------
    v : array_like(float)
        Lucas tree price-dividend ratio in each state.
    """
    discount, P, G = ap.β, ap.mc.P, ap.G
    # Transition matrix with column j scaled by G[j]**(1 - γ)
    J = P * G ** (1 - ap.γ)

    # Flag (via checkify) a spectral radius that rules out a unique solution
    test_stability(J, discount)

    size = J.shape[0]
    lhs = jnp.identity(size) - discount * J
    rhs = discount * J @ jnp.ones(size)
    return solve(lhs, rhs)


# jax.jit cannot stage a bare checkify.check, so functionalize the checks
# first; the jitted function returns an (error, value) pair.
tree_price_jit = jax.jit(checkify.checkify(tree_price))

In [4]:
def consol_price(ap, ζ):
    """
    Price a consol bond that pays the coupon ζ in every period.

    With M[i, j] = P[i, j] * G[j]**(-γ), the price vector p satisfies
    p = β M (ζ 1 + p). Equivalently, it is the solution of
    (I - β M) p = β ζ M 1.

    Parameters
    ----------
    ap : AssetPriceModel
        Model primitives (transition matrix, growth vector G, β and γ).
    ζ : scalar(float)
        Per-period coupon of the consol.

    Returns
    -------
    p : array_like(float)
        Consol price in each state.
    """
    discount, P, G = ap.β, ap.mc.P, ap.G
    # Transition matrix with column j scaled by G[j]**(-γ)
    M = P * G ** (-ap.γ)

    # Flag (via checkify) a spectral radius that rules out a unique solution
    test_stability(M, discount)

    size = M.shape[0]
    lhs = jnp.identity(size) - discount * M
    rhs = discount * ζ * M @ jnp.ones(size)
    return solve(lhs, rhs)


# jax.jit cannot stage a bare checkify.check, so functionalize the checks
# first; the jitted function returns an (error, value) pair.
consol_price_jit = jax.jit(checkify.checkify(consol_price))

In [5]:
def call_option(ap, ζ, p_s, ϵ=1e-7):
    """
    Computes price of a call option on a consol bond.

    The option value w is the fixed point of the elementwise map

        w = max(β M w, p - p_s)

    where M[i, j] = P[i, j] * G[j]**(-γ) and p is the consol price returned
    by ``consol_price``. It is found by iterating this map from w = 0 until
    the largest componentwise change is at most ϵ.

    Parameters
    ----------
    ap: AssetPriceModel
        An instance of AssetPriceModel containing primitives

    ζ : scalar(float)
        Coupon of the consol

    p_s : scalar(float)
        Strike price

    ϵ : scalar(float), optional(default=1e-7)
        Tolerance for infinite horizon problem

    Returns
    -------
    w : array_like(float)
        Infinite horizon call option prices

    Notes
    -----
    Stability failures are reported through ``checkify.check``. Wrap the
    function with ``checkify.checkify`` (as ``call_option_jit`` does) and
    inspect the returned error.

    The loop has no iteration cap. NOTE(review): under JAX's default float32,
    a very small ϵ may be hard to reach when prices are large. Confirm that
    ϵ suits the precision in use.
    """
    # Simplify names, set up matrices
    β, γ, P, G = ap.β, ap.γ, ap.mc.P, ap.G
    M = P * G**(- γ)

    # Make sure that a unique consol price exists
    test_stability(M, β)

    # Price of the underlying consol. Its own stability check is collected
    # by the enclosing checkify transformation. The function must not call
    # ``throw()`` on some unrelated (possibly undefined) global ``err`` here.
    p = consol_price(ap, ζ)
    n = M.shape[0]
    w = jnp.zeros(n)
    error = ϵ + 1

    def step(state):
        w, _ = state
        # Either hold the option (discounted continuation) or exercise now
        w_new = jnp.maximum(β * M @ w, p - p_s)
        # Sup-norm distance between successive iterates
        error_new = jnp.amax(jnp.abs(w - w_new))
        return (w_new, error_new)

    # Keep iterating while the last update moved some component by more than ϵ
    def cond(state):
        _, error = state
        return error > ϵ

    final_w, _ = jax.lax.while_loop(cond, step, (w, error))

    return final_w


# Functionalize the checkify checks so the function can be jitted; the
# jitted function returns an (error, value) pair.
call_option_jit = jax.jit(checkify.checkify(call_option))

In [6]:
# Example run: consol and call-option prices for a model with β = 0.9.
ap = create_ap_model(β=0.9)
ζ = 1.0               # consol coupon
strike_price = 40     # strike price of the call option on the consol

# State grid of the chain (presumably used for plotting later — confirm)
x = ap.mc.state_values
# The *_jit functions return (error, result); throw() raises if any
# checkify check (e.g. the spectral-radius condition) failed.
err, p = consol_price_jit(ap, ζ)
err.throw()
err, w = call_option_jit(ap, ζ, strike_price)
err.throw()
