In [None]:
pip install optax dm-haiku

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 29.0 MB/s 
[?25hCollecting dm-haiku
  Downloading dm_haiku-0.0.8-py3-none-any.whl (350 kB)
[K     |████████████████████████████████| 350 kB 58.2 MB/s 
Collecting chex>=0.0.4
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[K     |████████████████████████████████| 85 kB 4.0 MB/s 
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Installing collected packages: jmp, chex, optax, dm-haiku
Successfully installed chex-0.1.5 dm-haiku-0.0.8 jmp-0.0.2 optax-0.1.3


Due date: October 21 2022 

# Description
  In this problem we will apply the Least Squares Monte Carlo (LSMC) method to price American put options. Specifically, we will replicate the result in the first row, 6th column of Table 1 in [Longstaff and Schwartz 2001](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).

  

*  Read the introduction of the [paper](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).
*   We will price an American put option as described on page 126 of the aforementioned article. Read paragraphs 1 and 2 of page 126.
* As we saw in class, one of the ways we can use linear regression to fit nonlinear functions is to use polynomial features. A common choice in many applications is to use the so-called "Chebyshev polynomials". Chebyshev polynomials are defined recursively by:

\begin{equation}
\begin{aligned}
T_0(x) &= 1\\
T_1(x) &= x\\
T_{n + 1}(x) &= 2 x T_n(x) - T_{n - 1}(x)
\end{aligned}
\end{equation}


# Part 1
The code below simulates the evolution of a stock price that follows a geometric Brownian motion. Write a JAX version of that code. You are not allowed to use functions from other libraries. For this part, the "simulate"
function does not need to be jit compiled. As we will see, jit compiling a function with for loops may introduce some complications.
 

In [None]:
import jax
import jax.numpy as jnp
import optax
from jax import random
from jax import device_put

# Data
σ = 0.04  # volatility multiplying the random shock in each step
r = 0.01  # drift rate multiplying dt in each step
K = 35    # not used by simulate()

# Design choice
dt = 0.01  # length of one time step
m = 100    # default number of time steps


def simulate(key=None, n_paths=20000, n_steps=None):
  """Simulate stock-price paths with an Euler scheme for geometric Brownian motion.

  Every path starts at 1 and is advanced with
  S <- S + r * S * dt + σ * S * dZ, where dZ ~ Normal(0, dt).

  Args:
    key: JAX PRNG key. Defaults to `random.PRNGKey(0)`.
    n_paths: number of independent paths to simulate.
    n_steps: number of time steps. Defaults to the module-level `m`.

  Returns:
    Array of shape (n_steps, n_paths); row t holds the prices after t + 1
    steps (the initial prices are not included).
  """
  if key is None:
    key = random.PRNGKey(0)
  if n_steps is None:
    n_steps = m

  sqrt_dt = jnp.sqrt(dt)

  def step(S, step_key):
    # JAX random functions are pure: the same key always yields the same
    # numbers, so each step must receive its own key.
    dZ = random.normal(step_key, S.shape) * sqrt_dt
    return S + r * S * dt + σ * S * dZ

  S = jnp.ones(n_paths)
  S_list = []

  # One independent sub-key per time step.
  for step_key in random.split(key, n_steps):
    S = step(S, step_key)
    S_list.append(S)

  return jnp.stack(S_list)




# Part 2
Write a jit compiled version of the simulate function. You may want to check out the function jax.lax.scan.


In [None]:
import jax
import jax.numpy as jnp
from jax import random
from jax import jit

# Data
σ = 0.04  # volatility multiplying the random shock in each step
r = 0.01  # drift rate multiplying dt in each step
K = 35    # not used in this cell

# Design choice
dt = 0.01  # length of one time step
m = 100    # number of time steps taken by simulate_paths


def simulate(S, key=None):
  """Advance the prices in S by one Euler step of geometric Brownian motion.

  Args:
    S: array of current prices (any shape).
    key: JAX PRNG key for this step's shocks. Defaults to
      `random.PRNGKey(0)`; pass a different key for every step, otherwise
      each call draws exactly the same shocks.

  Returns:
    Array with the same shape as S holding the prices one step later.
  """
  if key is None:
    key = random.PRNGKey(0)
  # Draw with S.shape (not (S.size,)) so the shocks broadcast against S
  # whatever its rank.
  dZ = random.normal(key, S.shape) * jnp.sqrt(dt)
  return S + r * S * dt + σ * S * dZ


@jax.jit
def simulate_paths(S0, key):
  """Jit-compiled simulation of `m` consecutive steps using jax.lax.scan.

  Args:
    S0: array of initial prices.
    key: JAX PRNG key; it is split into one sub-key per step.

  Returns:
    Array of shape (m, *S0.shape); row t holds the prices after t + 1 steps.
  """
  def body(S, step_key):
    S_next = simulate(S, step_key)
    # scan expects (new carry, per-step output).
    return S_next, S_next

  _, paths = jax.lax.scan(body, S0, random.split(key, m))
  return paths


key = random.PRNGKey(0)
# One entry per simulated path, all starting at 1.
S0 = jnp.ones(20000)
S = S0
fast_s = jax.jit(simulate)


# Part 3
The code below computes the price of an American Put option using Least Squares Monte Carlo (LSMC). Write a JAX version of that code. You are not allowed to use functions from other libraries. Your "compute_price" function must be jit compiled.

In [None]:
import jax
import jax.numpy as jnp
from jax import random
from jax import jit
from jax import device_put

# Parameters of the American put pricing problem and of the LSMC scheme.
Spot = 36   # initial stock price (every simulated path starts here)
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simulated paths
m = 50      # number of exercise dates
T = 1       # maturity
order = 12   # Polynomial order: number of Chebyshev basis columns used in the regression
Δt = T / m  # interval between two exercise dates


# Construct the first k Chebyshev polynomial features (degrees 0 .. k - 1)
# using the recursive formulation
def chebyshev_basis(x, k):
    """Evaluate the Chebyshev polynomials T_0 .. T_{k-1} at every entry of x.

    Uses the recurrence T_0 = 1, T_1 = x, T_{j+1} = 2 x T_j - T_{j-1}.

    Args:
        x: 1-D array of evaluation points.
        k: number of basis functions (columns) to build; a Python int >= 1.

    Returns:
        Array of shape (len(x), k); column j holds T_j(x).

    Raises:
        ValueError: if k is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    B = [jnp.ones(len(x)), x]
    # `degree` rather than `n`, so the loop does not shadow the module-level
    # path count `n`.
    for degree in range(2, k):
        B.append(2 * x * B[degree - 1] - B[degree - 2])

    # B always starts with T_0 and T_1; trim so that k == 1 yields one column.
    return jnp.column_stack(B[:k])


# Affinely rescales x so that its smallest entry maps to -1 and its largest to 1
def scale(x):
    """Map x linearly onto the interval [-1, 1].

    Args:
        x: array of values; its minimum and maximum must differ, otherwise
            the slope below divides by zero.

    Returns:
        Array shaped like x where x.min() maps to -1 and x.max() maps to 1.
    """
    lo, hi = x.min(), x.max()
    slope = 2 / (hi - lo)
    # Intercept chosen so that the maximum lands exactly on 1.
    intercept = 1 - slope * hi
    return slope * x + intercept


# simulates one step of the stock price evolution
def step(S, key=None):
    """Advance every simulated price by one Euler step of length Δt.

    Args:
        S: array of current stock prices.
        key: JAX PRNG key used to draw this step's shocks. Defaults to
            `random.PRNGKey(0)`. JAX random functions are pure, so the same
            key always produces the same shocks: callers that take several
            steps must pass a different key for each one.

    Returns:
        Array shaped like S with the prices one step later.
    """
    if key is None:
        key = random.PRNGKey(0)
    dB = jnp.sqrt(Δt) * random.normal(key, S.shape)
    S_tp1 = S + r * S * Δt + σ * S * dB
    return S_tp1


def payoff_put(S):
    """Payoff of exercising the put at prices S: max(K - S, 0), elementwise."""
    return jnp.maximum(K - S, 0.)


# LSMC algorithm
def compute_price(key=None):
    """Price the American put with Least Squares Monte Carlo.

    Simulates `n` paths over `m` exercise dates, then walks backwards from
    maturity. At each date the discounted future cashflows are regressed on
    Chebyshev features of the rescaled stock price (using all paths), and a
    path is exercised when its immediate payoff is at least the fitted
    continuation value.

    Args:
        key: JAX PRNG key for the path simulation. Defaults to
            `random.PRNGKey(0)`.

    Returns:
        Scalar estimate of the option value at time 0.
    """
    if key is None:
        key = random.PRNGKey(0)

    S0 = Spot * jnp.ones(n)
    S = [S0]

    # One independent sub-key per step; reusing a single key would apply the
    # same shock to a path at every date.
    for step_key in random.split(key, m):
        S.append(step(S[-1], step_key))

    discount = jnp.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    # Proceed recursively over dates m - 1, ..., 1 (S[0] is the known spot)
    for S_t in reversed(S[1:-1]):
        X = chebyshev_basis(scale(S_t), order)
        Y = discounted_future_cashflows

        # Least-squares fit through the normal equations.
        Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S_t)
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)

    return discounted_future_cashflows.mean()


print(compute_price())


import jax
from jax import random
import jax.numpy as jnp

import numpy as np


# Parameters of the American put pricing problem and of the LSMC scheme
# (same values as the earlier definition of these names in this file).
Spot = 36   # initial stock price (every simulated path starts here)
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simulated paths
m = 50      # number of exercise dates
T = 1       # maturity
order = 12   # Polynomial order: number of Chebyshev basis columns used in the regression
Δt = T / m  # interval between two exercise dates


# Construct the first k Chebyshev polynomial features (degrees 0 .. k - 1)
# using the recursive formulation, expressed as a jax.lax.scan
def chebyshev_basis(x, k=None):
  """Evaluate the Chebyshev polynomials T_0 .. T_{k-1} at every entry of x.

  Uses the recurrence T_0 = 1, T_1 = x, T_{j+1} = 2 x T_j - T_{j-1}.

  Args:
    x: 1-D array of evaluation points.
    k: number of basis functions (columns); a Python int >= 1. Defaults to
      the module-level `order`.

  Returns:
    Array of shape (len(x), k); column j holds T_j(x).

  Raises:
    ValueError: if k is smaller than 1.
  """
  if k is None:
    k = order
  if k < 1:
    raise ValueError("k must be at least 1")

  # Rows are T_0 and T_1; stacking also promotes x to a common float dtype.
  B0 = jnp.array([jnp.ones(len(x)), x])
  x = B0[1]

  def cb_step(carry, _):
    T_prev, T_cur = carry
    T_next = 2 * x * T_cur - T_prev
    return (T_cur, T_next), T_next

  # x is closed over, so the scan needs only a length and no per-step input;
  # this avoids materialising an (k - 2, len(x)) copy of x.
  _, higher = jax.lax.scan(cb_step, (B0[0], B0[1]), None, length=max(k - 2, 0))

  return jnp.transpose(jnp.vstack((B0, higher))[:k])


# Linear rescaling of x onto [-1, 1] (minimum -> -1, maximum -> 1)
def scale(x):
  """Rescale x linearly so that its extremes land on -1 and 1.

  Args:
    x: array of values; its minimum and maximum must differ, otherwise the
      gain below divides by zero.

  Returns:
    Array shaped like x with x.min() mapped to -1 and x.max() mapped to 1.
  """
  lower = x.min()
  upper = x.max()
  gain = 2 / (upper - lower)
  # The offset pins the maximum to exactly 1.
  return gain * x + (1 - gain * upper)


# One Euler step of the stock price, written as a jax.lax.scan body
def step(S, rd):
  """Advance the prices S by one step of length Δt using the draws rd.

  Args:
    S: array of current stock prices (the scan carry).
    rd: array of standard-normal draws for this step, broadcastable with S.

  Returns:
    A pair (new carry, per-step output); both are the updated prices.
  """
  drift = r * S * Δt
  diffusion = σ * S * (jnp.sqrt(Δt) * rd)
  S_next = S + drift + diffusion
  return S_next, S_next


# Immediate exercise value of the put
def payoff_put(S):
  """Return max(K - S, 0) elementwise: the payoff of exercising at prices S."""
  intrinsic = K - S
  return jnp.maximum(intrinsic, 0.)


# LSMC algorithm
@jax.jit
def compute_price(key=None):
  """Price the American put with Least Squares Monte Carlo (jit compiled).

  Simulates `n` paths over `m` exercise dates with one scan, then runs the
  backward induction as a second, reversed scan. At each date the discounted
  future cashflows are regressed on Chebyshev features of the rescaled stock
  price (using all paths), and a path is exercised when its immediate payoff
  is at least the fitted continuation value.

  Args:
    key: JAX PRNG key for the path simulation. Defaults to
      `random.PRNGKey(2)`.

  Returns:
    Scalar estimate of the option value at time 0.
  """
  if key is None:
    key = random.PRNGKey(2)

  S0 = Spot * jnp.ones(n)
  rd = random.normal(key, shape=(m, n))
  # Row t of S holds the prices at exercise date t + 1, so S[-1] is maturity.
  _, S = jax.lax.scan(step, S0, rd)

  discount = jnp.exp(-r * Δt)

  # Very last date
  value_if_exercise = payoff_put(S[-1])
  discounted_future_cashflows = value_if_exercise * discount

  # One backward-induction step at an earlier exercise date
  def proceed(dfc, s):
    X = chebyshev_basis(scale(s))
    Y = dfc

    # Least-squares fit through the normal equations.
    Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)
    value_if_wait = X @ Θ
    value_if_exercise = payoff_put(s)
    exercise = value_if_exercise >= value_if_wait
    dfc = discount * jnp.where(
        exercise,
        value_if_exercise,
        dfc)
    # Only the final carry is needed, so nothing is stacked per step.
    return dfc, None

  # Visit dates m - 1, ..., 1 (every row except maturity), latest first.
  P, _ = jax.lax.scan(proceed, discounted_future_cashflows, S[:-1], reverse=True)

  return P.mean()


print(compute_price())


10.791309
