Due date: October 9, 2023

# Description
In this problem we will apply the Least Squares Monte Carlo (LSMC) method to price American put options. Specifically, we will replicate the result in the first row, sixth column of Table 1 in [Longstaff and Schwartz 2001](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).

  

* Read the introduction of the [paper](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).
* We will price an American put option as described on page 126 of the aforementioned article. Read paragraphs 1 and 2 of page 126.
* As we saw in class, one way to use linear regression to fit nonlinear functions is to use polynomial features. A common choice in many applications is the Chebyshev polynomials, which are defined recursively by:

\begin{equation}
\begin{aligned}
T_0(x) &= 1\\
T_1(x) &= x\\
T_{n+1}(x) &= 2x\,T_n(x) - T_{n-1}(x)
\end{aligned}
\end{equation}
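A minimal sketch of the recurrence in JAX, using `jax.lax.scan` to carry the pair $(T_{n-1}, T_n)$ from one degree to the next; the function name `chebyshev` is illustrative only. As a check, on $[-1, 1]$ the recurrence should agree with the closed form $T_n(x) = \cos(n \arccos x)$.

```python
import jax
import jax.numpy as jnp

def chebyshev(x, k):
    """Return T_0(x), ..., T_{k-1}(x) stacked along axis 0 (assumes k >= 2)."""
    def body(carry, _):
        T_prev, T_curr = carry
        T_next = 2 * x * T_curr - T_prev  # T_{n+1} = 2x T_n - T_{n-1}
        return (T_curr, T_next), T_next
    T0, T1 = jnp.ones_like(x), x
    # higher holds T_2, ..., T_{k-1}, one row per degree
    _, higher = jax.lax.scan(body, (T0, T1), xs=None, length=k - 2)
    return jnp.concatenate([T0[None], T1[None], higher], axis=0)

# Compare the recurrence with the closed form on a grid in [-1, 1]
x = jnp.linspace(-1.0, 1.0, 7)
T = chebyshev(x, 6)                       # shape (6, 7)
n = jnp.arange(6)[:, None]
closed_form = jnp.cos(n * jnp.arccos(x))
print(jnp.max(jnp.abs(T - closed_form)))  # float32 round-off level
```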


# Part 1
The code below simulates the evolution of a stock price that follows a geometric Brownian motion. Write a JAX version of that code. You are not allowed to use functions from other libraries. For this part, the "simulate" function does not need to be JIT-compiled; as we will see, JIT-compiling a function that contains Python for loops may introduce some complications.
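Two of those complications are easy to see in isolation. The sketch below (the names `double`, `unrolled` and `scanned` are illustrative only) shows that Python side effects inside a jitted function, such as incrementing a counter or reassigning a `nonlocal` variable, run only while JAX traces the function, and that a Python `for` loop is unrolled into one copy of its body per iteration, whereas `jax.lax.scan` stays a single loop primitive however many steps it runs.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces `double`
    return 2 * x

for _ in range(5):
    double(jnp.ones(3))  # same shape and dtype each call, so the cached trace is reused
print(trace_count)  # 1

def unrolled(x, steps):
    for _ in range(steps):  # each iteration adds its own operations to the trace
        x = x * 1.01
    return x

def scanned(x, steps):
    body = lambda c, _: (c * 1.01, None)
    return lax.scan(body, x, None, length=steps)[0]  # a single scan primitive

x = jnp.ones(3)
counts = {}
for steps in (10, 100):
    # number of top-level operations in each traced program
    counts[steps] = (
        len(jax.make_jaxpr(lambda x: unrolled(x, steps))(x).jaxpr.eqns),
        len(jax.make_jaxpr(lambda x: scanned(x, steps))(x).jaxpr.eqns),
    )
print(counts)
```

The first point is why a jitted step that updates a `nonlocal` PRNG key keeps reusing the same random draws after its first call; the second is why long Python loops make compilation slow.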





In [None]:
import jax.numpy as jnp
import jax
from jax import lax

In [None]:
%%time


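# Data
σ = 0.04   # volatility
r = 0.01   # risk-free rate
K = 35     # strike (unused in this part)

# Design choice
dt = 0.01  # time step
m = 100    # number of time steps

def simulate():
  key = jax.random.PRNGKey(seed=0)

  # The key is passed in and returned explicitly. A jitted function that
  # reassigns a nonlocal key only does so while it is being traced, so every
  # later call would reuse the same random draws.
  @jax.jit
  def step(S, key):
    key, subkey = jax.random.split(key)
    dZ = jax.random.normal(key=subkey, shape=S.shape) * jnp.sqrt(dt)
    dS = r * S * dt + σ * S * dZ   # Euler step of dS = r S dt + σ S dZ
    return S + dS, key

  S = jnp.ones(200000)  # S0 = 1 on every path
  S_list = []
  for t in range(m):
    S, key = step(S, key)
    S_list.append(S)

  S_array = jnp.stack(S_list)  # shape (m, number of paths)
  return S_array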


CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 8.58 µs


In [None]:
%%time
simulated_paths = simulate()

CPU times: user 198 ms, sys: 8.13 ms, total: 206 ms
Wall time: 195 ms
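The Euler update above is an approximation. Because $\log S$ follows an arithmetic Brownian motion, each step can instead multiply $S$ by $\exp((r - \sigma^2/2)\,dt + \sigma\,dZ)$, which is exact at the grid points and lets the whole path be built with one `cumsum`, without any loop. A minimal sketch under that swapped-in scheme; `simulate_exact` is a hypothetical name and its default parameters copy the Part 2 values.

```python
import jax
import jax.numpy as jnp

def simulate_exact(key, S0=1.0, r=0.01, sigma=0.04, dt=0.01, m=100, n_paths=20000):
    # Log-increments of geometric Brownian motion, dZ ~ N(0, dt)
    dZ = jax.random.normal(key, shape=(m, n_paths)) * jnp.sqrt(dt)
    log_increments = (r - 0.5 * sigma**2) * dt + sigma * dZ
    # Cumulative sum over time gives log(S_t / S0) on every path
    return S0 * jnp.exp(jnp.cumsum(log_increments, axis=0))  # shape (m, n_paths)

# m and n_paths set array shapes, so they must be static under jit
fast_simulate = jax.jit(simulate_exact, static_argnames=("m", "n_paths"))
paths = fast_simulate(jax.random.PRNGKey(0))

# Sanity check: E[S_T] = S0 * exp(r T) with T = m * dt = 1
print(paths[-1].mean(), jnp.exp(0.01))
```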


# Part 2
Write a JIT-compiled version of the "simulate" function. You may want to look at the function `jax.lax.scan`.
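`jax.lax.scan(f, init, xs)` calls `f(carry, x)` once per leading-axis slice of `xs`; `f` returns `(new_carry, y)`, and `scan` returns the final carry together with all the `y`s stacked along a new leading axis. A minimal sketch with a running sum (the names are illustrative); in the simulation, the carry would be the current price vector, `xs` the matrix of Brownian increments, and the stacked outputs the simulated paths.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x          # state threaded from one iteration to the next
    return new_carry, new_carry    # (next carry, per-step output to be stacked)

xs = jnp.arange(1.0, 6.0)          # [1, 2, 3, 4, 5]
init = jnp.zeros((), dtype=jnp.float32)
final, ys = jax.lax.scan(step, init, xs)
print(final)  # 15.0
print(ys)     # same values as jnp.cumsum(xs)
```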

In [None]:
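# Data
σ = 0.04   # volatility
r = 0.01   # risk-free rate
K = 35     # strike (unused in this part)

# Design choice
dt = 0.01  # time step
m = 100    # number of time steps

@jax.jit
def simulate():
  key = jax.random.PRNGKey(seed=0)
  key, subkey = jax.random.split(key)

  # Scan body: the carry is the current price vector; each iteration consumes
  # one row of Brownian increments and emits the updated prices.
  def step(S, dZ):
    dS = r * S * dt + σ * S * dZ
    S = S + dS
    return S, S

  S0 = jnp.ones(20000)
  # Row t holds the increments dZ for time step t, each with variance dt
  RD_normal = jax.random.normal(key=subkey, shape=(m, S0.shape[0])) * jnp.sqrt(dt)

  _, result = jax.lax.scan(step, S0, RD_normal)
  return result  # shape (m, number of paths)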


In [None]:
%%time
simulated_paths = simulate()

CPU times: user 101 ms, sys: 2.82 ms, total: 104 ms
Wall time: 98.2 ms


# Part 3
The code below computes the price of an American put option using Least Squares Monte Carlo (LSMC). Write a JAX version of that code. You are not allowed to use functions from other libraries. Your "compute_price" function must be JIT-compiled.

In [None]:
Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # number of simulated paths
m = 50      # number of exercise dates
T = 1       # maturity
order = 12  # number of Chebyshev basis functions (T_0, ..., T_11)
Δt = T / m  # interval between two exercise dates


# Construct the Chebyshev features T_0(x), ..., T_{k-1}(x) (k columns)
# using the recursive formulation; assumes k >= 2

def chebyshev_basis(x, k):
    T0, T1 = jnp.ones_like(x), x
    def B(carry, _):
      b1, b2 = carry              # T_{n-1}(x), T_n(x)
      bn = 2 * x * b2 - b1        # T_{n+1}(x)
      return (b2, bn), bn
    # computed_B has shape (k - 2, len(x)) and holds T_2, ..., T_{k-1}
    _, computed_B = jax.lax.scan(B, (T0, T1), xs=None, length=k - 2)
    computed_B = jnp.concatenate([T0[None, :], T1[None, :], computed_B], axis=0)
    return computed_B.T           # shape (len(x), k)

# affinely maps x onto the interval [-1, 1] (x.min() -> -1, x.max() -> 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b


# simulates one step of the stock price evolution
def step(S, dZ):
  dS = r * S  * Δt + σ  * S  * dZ
  S = S + dS
  return S, S


def payoff_put(S):
    return jnp.maximum(K - S, 0.)


# LSMC algorithm
@jax.jit
def compute_price():
    key = jax.random.PRNGKey(seed=1)
    key, subkey = jax.random.split(key)

    # Simulate the n paths at the m exercise dates; S has shape (m, n)
    S0 = jnp.ones(n) * Spot
    RD_normal = jax.random.normal(key=subkey, shape=(m, n)) * jnp.sqrt(Δt)
    _, S = jax.lax.scan(step, S0, RD_normal)

    discount = jnp.exp(-r * Δt)  # one-period discount factor

    # Very last date: exercise whenever in the money, then discount one period
    Y = payoff_put(S[-1]) * discount

    # One backward step: regress the discounted future cash flow Y on
    # Chebyshev features of the current price, then exercise wherever the
    # immediate payoff is at least the estimated continuation value.
    def iteration_replacement(Y, S_t):
      X = chebyshev_basis(scale(S_t), order)
      Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)   # normal equations
      value_if_wait = X @ Θ
      value_if_exercise = payoff_put(S_t)
      exercise = value_if_exercise >= value_if_wait
      new_Y = discount * jnp.where(exercise, value_if_exercise, Y)
      return new_Y, None   # only the carry is needed

    # Walk backwards over exercise dates m - 1, ..., 1
    Y, _ = jax.lax.scan(iteration_replacement, Y, S[:-1], reverse=True)

    return Y.mean()
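A hedged sketch of a backward step closer to the paper's procedure: Longstaff and Schwartz use only in-the-money paths in the regression, while `compute_price` regresses on all `n` paths and, because it tests `value_if_exercise >= value_if_wait`, can "exercise" a path whose payoff is 0 whenever the fitted continuation value is negative, which discards that path's future cash flow. Boolean indexing such as `Y[itm]` has a data-dependent shape and is not allowed under `jit`, so the sketch zeros the out-of-the-money rows instead; zero rows drop out of the least-squares objective. It also swaps the normal equations for `jnp.linalg.lstsq`, which avoids forming $X^\top X$ and its squared condition number. The helper names, the 6-function basis and the synthetic data are assumptions for illustration, and the effect on the price above is not shown here.

```python
import jax
import jax.numpy as jnp

N_BASIS = 6  # hypothetical basis size for this sketch

def chebyshev_features(x, k=N_BASIS):
    # Columns T_0(x), ..., T_{k-1}(x) from the three-term recurrence (k >= 2)
    cols = [jnp.ones_like(x), x]
    for _ in range(k - 2):
        cols.append(2 * x * cols[-1] - cols[-2])
    return jnp.stack(cols, axis=1)

def scale_to_unit(x):
    # Affine map sending x.min() to -1 and x.max() to 1
    return 2 * (x - x.min()) / (x.max() - x.min()) - 1

@jax.jit
def backward_step_itm(Y, S_t, K, discount):
    """Hypothetical LSMC backward step: fit the continuation value on
    in-the-money paths only and never exercise a worthless option."""
    payoff = jnp.maximum(K - S_t, 0.0)
    itm = payoff > 0
    X = chebyshev_features(scale_to_unit(S_t))
    w = itm.astype(X.dtype)  # 1 for in-the-money rows, 0 otherwise
    theta = jnp.linalg.lstsq(X * w[:, None], Y * w)[0]
    continuation = X @ theta
    exercise = itm & (payoff >= continuation)
    return discount * jnp.where(exercise, payoff, Y)

# Usage on synthetic data (made-up numbers, not the assignment's paths)
key = jax.random.PRNGKey(0)
S_t = 40.0 * jnp.exp(0.2 * jax.random.normal(key, (1000,)))
Y = jnp.maximum(40.0 - 1.02 * S_t, 0.0) + 0.1  # stand-in discounted cash flows
new_Y = backward_step_itm(Y, S_t, 40.0, jnp.exp(-0.06 / 50))
```

Inside `compute_price`, a step like this would replace the body passed to `jax.lax.scan`, with `K`, `order` and `discount` taken from the surrounding definitions.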

In [None]:
%%time
print(compute_price())

4.465737
CPU times: user 30.9 ms, sys: 971 µs, total: 31.9 ms
Wall time: 33.5 ms
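An American put is worth at least the otherwise identical European put, which has a closed-form Black-Scholes price, so the European value gives a quick lower-bound sanity check on the LSMC estimate. A minimal sketch using the Part 3 parameters (`Spot = 36`, `K = 40`, `r = 0.06`, `σ = 0.2`, `T = 1`); the function name `european_put_bs` is illustrative, and `jax.scipy.stats.norm.cdf` supplies the normal CDF.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def european_put_bs(S, K, r, sigma, T):
    # Black-Scholes price of a European put on a non-dividend-paying stock
    d1 = (jnp.log(S / K) + (r + 0.5 * sigma**2) * T) / (sigma * jnp.sqrt(T))
    d2 = d1 - sigma * jnp.sqrt(T)
    return K * jnp.exp(-r * T) * norm.cdf(-d2) - S * norm.cdf(-d1)

price = european_put_bs(36.0, 40.0, 0.06, 0.2, 1.0)
print(price)  # about 3.844, below the LSMC American estimate printed above
```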
