# Description
  In this problem we apply the Least Squares Monte Carlo (LSMC) method to price American put options.

  
One way to use linear regression to fit nonlinear functions is to regress on polynomial features of the input. A common choice in many applications is the so-called Chebyshev polynomials, which are defined recursively by:

\begin{equation}
\begin{aligned}
T_0(x) &= 1,\\
T_1(x) &= x,\\
T_{n+1}(x) &= 2x\,T_n(x) - T_{n-1}(x), \qquad n \ge 1.
\end{aligned}
\end{equation}


In [None]:
import jax.numpy as jnp
import jax.random as random

# Model parameters: diffusion coefficient, drift rate and strike.
σ, r, K = 0.04, 0.01, 35

# Discretisation: size of each time step and number of steps simulated.
dt, m = 0.01, 100


def simulate():
    """Simulate 20,000 price series step by step with an Euler scheme.

    Every series starts at 1.0 and is advanced ``m`` times by
    ``dS = r*S*dt + σ*S*dZ`` where ``dZ`` is a normal draw scaled by
    ``sqrt(dt)``. A fresh subkey is split off the PRNG key at each step, and
    the key is seeded with 0, so the result is reproducible.

    Returns:
        Array of shape ``(m, 20000)``; row ``t`` holds the prices after
        ``t + 1`` steps (the initial prices are not included).
    """
    key = random.PRNGKey(0)

    def step(S, key):
        # Split so the key carried forward is never reused for sampling.
        key, subkey = random.split(key)
        dZ = random.normal(subkey, shape=S.shape) * jnp.sqrt(dt)
        dS = r * S * dt + σ * S * dZ
        return S + dS, key

    # Initial prices: 20,000 independent series, all at 1.0.
    S = jnp.ones(20000)
    S_list = []
    for _ in range(m):
        S, key = step(S, key)
        S_list.append(S)

    return jnp.stack(S_list)


### JIT-compiled version of the `simulate` function

In [None]:
import jax
import jax.numpy as jnp
import jax.random as random

# Model parameters
σ = 0.04   # diffusion coefficient (multiplies S*dZ)
r = 0.01   # drift coefficient (multiplies S*dt)
K = 35     # strike; not referenced by `simulate`

# Discretisation
dt = 0.01  # size of each time step
m = 100    # number of time steps simulated

@jax.jit
def simulate():
    """JIT-compiled simulation of 20,000 price series using ``lax.scan``.

    Uses the same per-step key-splitting sequence as the eager ``simulate``:
    start from ``PRNGKey(0)``, and at each of the ``m`` steps split off a
    subkey, draw ``dZ = N(0, 1) * sqrt(dt)`` and apply
    ``dS = r*S*dt + σ*S*dZ``.

    Returns:
        Array of shape ``(m, 20000)``; row ``t`` holds the prices after
        ``t + 1`` steps (the initial prices are not included).
    """
    def step(S, key):
        key, subkey = random.split(key)
        dZ = random.normal(subkey, shape=S.shape) * jnp.sqrt(dt)
        dS = r * S * dt + σ * S * dZ
        return S + dS, key

    def scan_fn(carry, _):
        # The PRNG key must travel in the carry: the scanned-over values are
        # not keys, and reusing one key every step would repeat the draws.
        S, key = step(*carry)
        return (S, key), S

    # Initial prices: 20,000 independent series, all at 1.0.
    S0 = jnp.ones(20000)

    # lax.scan returns (final_carry, stacked_outputs); the path is the latter.
    _, S_array = jax.lax.scan(scan_fn, (S0, random.PRNGKey(0)), None, length=m)
    return S_array


### The code below prices an American put option using Least Squares Monte Carlo (LSMC).

In [None]:
import jax
import jax.numpy as jnp
from jax import jit, vmap
from functools import partial

Spot = 36      # initial price of every simulated path
σ = 0.2        # volatility (scales the Brownian increment in `step`)
K = 40         # strike of the put
r = 0.06       # interest rate (drift in `step`, discount factor in `compute_price`)
n = 100_000    # number of simulated paths
m = 50         # number of time steps
T = 1          # maturity
order = 12     # number of Chebyshev basis functions in each regression
Δt = T / m     # size of one time step

@partial(jit, static_argnums=(1,))
def chebyshev_basis(x, k):
    """Evaluate the first ``k`` Chebyshev polynomials ``T_0 .. T_{k-1}`` at ``x``.

    Uses the recurrence ``T_{j+1}(x) = 2 x T_j(x) - T_{j-1}(x)``.

    Args:
        x: 1-D array of evaluation points.
        k: Number of basis functions (static under ``jit``); must be >= 1.

    Returns:
        Array of shape ``(len(x), k)`` whose column ``j`` is ``T_j(x)``.

    Raises:
        ValueError: If ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    # Truncate the seed list so that k == 1 yields exactly one column.
    basis = [jnp.ones_like(x), x][:k]
    for _ in range(2, k):
        basis.append(2 * x * basis[-1] - basis[-2])
    return jnp.column_stack(basis)

@jit
def scale(x):
    """Affinely map ``x`` onto ``[-1, 1]`` (minimum -> -1, maximum -> 1).

    Used to rescale prices before building the Chebyshev regression basis.

    When every element of ``x`` is equal the range is zero and the map is
    undefined (the division would produce inf and then NaN). In that case an
    array of zeros, the centre of the interval, is returned instead.
    """
    xmin, xmax = x.min(), x.max()
    width = xmax - xmin
    degenerate = width <= 0
    # Use a dummy width for the degenerate case so no inf/NaN is ever formed;
    # that branch's result is replaced by zeros below.
    a = 2 / jnp.where(degenerate, 1, width)
    scaled = a * x + (1 - a * xmax)
    return jnp.where(degenerate, jnp.zeros_like(scaled), scaled)

@jit
def step(S, key):
    """Advance every path by one Euler step of size ``Δt``.

    Computes ``S * (1 + r*Δt + σ*sqrt(Δt)*Z)``, where ``Z`` holds one
    standard-normal draw per element of ``S``, generated from ``key``.
    """
    shock = jax.random.normal(key, S.shape) * jnp.sqrt(Δt)
    growth = 1 + r * Δt + σ * shock
    return growth * S

@jit
def payoff_put(S):
    """Return the put's intrinsic value ``max(K - S, 0)`` elementwise."""
    intrinsic = K - S
    return jnp.maximum(0., intrinsic)

@jit
def compute_price():
    """Price an American put by Least Squares Monte Carlo (LSMC).

    Simulates ``n`` paths from ``Spot`` over ``m`` steps of size ``Δt``, so the
    last simulated state is at maturity ``T``. It then walks backwards through
    the exercise dates ``(m-1)Δt, ..., Δt``. At each date the path-wise cash
    flow (already discounted to that date) is regressed on ``order`` Chebyshev
    basis functions of the rescaled price to estimate the continuation value.
    In-the-money paths whose immediate payoff is at least that estimate are
    exercised.

    Returns:
        Scalar estimate of the option value at time 0: the larger of the
        immediate payoff at ``Spot`` and the Monte Carlo continuation value.
    """
    key = jax.random.PRNGKey(0)
    S0 = Spot * jnp.ones(n)

    # Emit the state *after* each step, so S[j] is the price at time
    # (j + 1)*Δt and S[-1] is the price at maturity T. Emitting the pre-step
    # state would drop the final step and include the constant t=0 slice.
    keys = jax.random.split(key, m)

    def forward_step(s, k):
        s_next = step(s, k)
        return s_next, s_next

    _, S = jax.lax.scan(forward_step, S0, keys)

    discount = jnp.exp(-r * Δt)
    # Terminal payoff, discounted one step back to time T - Δt.
    discounted_future_cashflows = payoff_put(S[-1]) * discount

    def backward_step(carry, S_t):
        # `carry` holds each path's future cash flow valued at this date.
        X = chebyshev_basis(scale(S_t), order)
        Θ = jnp.linalg.solve(X.T @ X, X.T @ carry)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S_t)
        # Only in-the-money paths may exercise. Otherwise a zero payoff would
        # "beat" a slightly negative regression estimate and wipe out the
        # path's real future cash flow.
        exercise = (value_if_exercise > 0) & (value_if_exercise >= value_if_wait)
        return jnp.where(exercise, value_if_exercise, carry) * discount, None

    # Exercise dates (m-1)*Δt down to Δt. After the last one the cash flows
    # are valued at time 0.
    discounted_future_cashflows, _ = jax.lax.scan(
        backward_step, discounted_future_cashflows, S[-2::-1])

    continuation_value = discounted_future_cashflows.mean()
    # An American option may also be exercised immediately at time 0.
    return jnp.maximum(payoff_put(Spot), continuation_value)


print(compute_price())

4.463316
