<a href="https://colab.research.google.com/github/Leonaxi/Machine-Learning-in-Finance/blob/main/553_project_2_lsmc_with_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Due date: October 21 2022 

# Description
  In this problem we will apply the Least Squares Monte Carlo (LSMC) method to price American put options. Specifically, we will replicate the result in the first row, 6th column of Table 1 in [Longstaff and Schwartz 2001](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).

  

*  Read the introduction of the [paper](https://www.anderson.ucla.edu/documents/areas/fac/finance/least_squares.pdf).
*   We will price an American put option as described on page 126 of the aforementioned article. Read paragraphs 1 and 2 of page 126.
* As we saw in class, one way to fit nonlinear functions with linear regression is to use polynomial features. A common choice in many applications is the so-called *Chebyshev polynomials*, which are defined recursively by:

\begin{equation}
\begin{aligned}
T_0(x) &= 1\\
T_1(x) &= x\\
T_{n+1}(x) &= 2x\,T_n(x) - T_{n-1}(x)
\end{aligned}
\end{equation}


# Part 1
The code below simulates the evolution of a stock price that follows a geometric Brownian motion. Write a JAX version of that code. You are not allowed to use functions from other libraries. For this part, the "simulate"
function does not need to be jit compiled. As we will see, jit compiling a function with for loops may introduce some complications.
 

In [None]:
# NumPy reference implementation of `simulate`. The JAX cells below redefine
# the same names, so run them after this cell to use the JAX version.
# The import is required: `simulate` below calls `np.random` / `np.stack`.
import numpy as np

# Data: drift r and volatility σ of the geometric Brownian motion
# (K is not used by `simulate`).
σ = 0.04
r = 0.01
K = 35

# Design choice: time-step size and number of simulated steps.
dt = 0.01
m = 100
# NumPy `simulate`: Euler scheme for a geometric Brownian motion, seeded
# through NumPy's global random state so every call is reproducible.
def simulate():
  """Return ``m`` Euler steps of a GBM for 20000 paths starting at 1.

  Reads the module-level ``r``, ``σ``, ``dt`` and ``m``. Re-seeds NumPy's
  global generator with 0 on every call (a side effect on ``np.random``).

  Returns:
    ndarray of shape (m, 20000); row ``t`` holds the prices after step
    ``t + 1`` -- the initial prices of 1 are not included.
  """
  np.random.seed(0)

  def advance(prices):
    # Brownian increment over dt: standard normal draws scaled by sqrt(dt).
    shocks = np.random.normal(size=prices.size) * np.sqrt(dt)
    drift = r * prices * dt
    diffusion = σ * prices * shocks
    return prices + (drift + diffusion)

  prices = np.ones(20000)
  history = []
  for _ in range(m):
    prices = advance(prices)
    history.append(prices)

  # Stack the per-step price vectors into a (steps, paths) matrix.
  return np.stack(history)

In [None]:
# jax version
import jax
import jax.numpy as jnp
from jax import random
from jax import lax

# Model parameters: drift r and volatility σ of the geometric Brownian
# motion; K is not used by `simulate`.
σ, r, K = 0.04, 0.01, 35

# Discretisation: time-step size and number of simulated steps.
dt, m = 0.01, 100

@jax.jit
def simulate():
  """Return ``m`` Euler steps of a GBM for 20000 paths starting at 1 (JAX).

  Uses a fixed PRNG seed (100), split into one key per time step. The Python
  ``for`` loop below is unrolled while ``jax.jit`` traces the function.

  Returns:
    Array of shape (m, 20000); row ``t`` holds the prices after step
    ``t + 1``.
  """
  step_keys = random.split(jax.random.PRNGKey(100), m)

  def advance(prices, key):
    # Brownian increment: one standard normal per path, scaled by sqrt(dt).
    shocks = jax.random.normal(key, (prices.size,)) * jnp.sqrt(dt)
    drift = r * prices * dt
    diffusion = σ * prices * shocks
    return prices + (drift + diffusion)

  prices = jnp.ones(20000)
  history = []
  for t in range(m):
    prices = advance(prices, step_keys[t])
    history.append(prices)

  return jnp.stack(history)


print(simulate())

[[1.0017724  1.0004742  1.0034137  ... 1.0005076  1.0015863  1.0125836 ]
 [1.0007677  0.99627924 1.0043089  ... 1.0027115  1.0032715  1.0199088 ]
 [0.99718064 1.005305   1.0090697  ... 1.0071484  1.0027707  1.0259789 ]
 ...
 [1.0061305  1.0749997  1.0193487  ... 0.98685    0.9719698  1.0534912 ]
 [1.003955   1.0690762  1.0210334  ... 0.9911981  0.97647256 1.0592455 ]
 [1.0042301  1.0759467  1.0205411  ... 0.98978984 0.9772807  1.0638406 ]]


# Part 2
Write a jit compiled version of the simulate function. You may want to check out the function jax.lax.scan.


In [None]:
import jax
import jax.numpy as jnp
from jax import random
from jax import lax

# Model parameters of the geometric Brownian motion:
# drift r, volatility σ (K is not used by `simulate`).
σ, r, K = 0.04, 0.01, 35

# Discretisation: time-step size and number of simulated steps.
dt, m = 0.01, 100


@jax.jit  # trace once, compile with XLA, and reuse the compiled program on later calls
def simulate():
    """Simulate ``m`` Euler steps of a GBM for 20000 paths using ``lax.scan``.

    Paths start at 1 and evolve as ``S + r*S*dt + σ*S*dZ`` with
    ``dZ ~ sqrt(dt) * N(0, 1)``; randomness comes from PRNG seed 100 split into
    one key per step. Expressing the time loop with ``lax.scan`` gives the
    traced program a single loop body instead of an unrolled Python loop.

    Returns:
        Array of shape (m, 20000); row ``t`` holds the prices after step
        ``t + 1``.
    """
    key = jax.random.PRNGKey(100)
    subkeys = random.split(key, m)  # one independent key per time step

    def step(S, subkey):
        dZ = jax.random.normal(subkey, (S.size,)) * jnp.sqrt(dt)
        dS = r * S * dt + σ * S * dZ
        return S + dS

    def f(carry, subkey):
        # Evaluate the step once and use it both as the next carry and as
        # the per-step output that scan stacks into the result.
        S_next = step(carry, subkey)
        return S_next, S_next

    S0 = jnp.ones(20000)
    _, result = jax.lax.scan(f, S0, subkeys)
    return result


simulate()


DeviceArray([[1.0017724 , 1.0004742 , 1.0034137 , ..., 1.0005075 ,
              1.0015863 , 1.0125835 ],
             [1.0007677 , 0.99627924, 1.0043089 , ..., 1.0027114 ,
              1.0032715 , 1.0199087 ],
             [0.99718064, 1.005305  , 1.0090697 , ..., 1.0071483 ,
              1.0027707 , 1.0259788 ],
             ...,
             [1.0061305 , 1.0749997 , 1.0193487 , ..., 0.9868499 ,
              0.9719698 , 1.0534912 ],
             [1.003955  , 1.0690762 , 1.0210334 , ..., 0.991198  ,
              0.97647256, 1.0592455 ],
             [1.0042301 , 1.0759467 , 1.0205411 , ..., 0.9897897 ,
              0.9772807 , 1.0638406 ]], dtype=float32)

# Part 3
The code below computes the price of an American Put option using Least Squares Monte Carlo (LSMC). Write a JAX version of that code. You are not allowed to use functions from other libraries. Your "compute_price" function must be jit compiled.

In [None]:
# NumPy reference implementation of the LSMC pricer. The helpers below call
# `np.*`, so NumPy must be imported for this cell to run.
import numpy as np

Spot = 36     # stock price
σ = 0.2       # stock volatility
K = 40        # strike price
r = 0.06      # risk free rate
n = 100000    # number of simulated paths
m = 50        # number of exercise dates
T = 1         # maturity
order = 12    # number of Chebyshev basis functions (degrees 0..order-1)
Δt = T / m    # interval between two exercise dates


# Build the Chebyshev design matrix [T_0(x), ..., T_{k-1}(x)] with the
# three-term recurrence T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x).
def chebyshev_basis(x, k):
    """Return the first ``k`` Chebyshev polynomials evaluated at ``x``.

    Args:
        x: 1-D array of evaluation points (``compute_price`` passes values
            already mapped into [-1, 1] by ``scale``).
        k: Number of basis functions. Column ``j`` holds ``T_j(x)``, so the
            highest degree is ``k - 1``.

    Returns:
        Array of shape ``(len(x), k)``; ``k <= 0`` yields an empty
        ``(len(x), 0)`` matrix.
    """
    x = np.asarray(x)
    # Seed the recurrence with T_0 = 1 and T_1 = x, truncated so that
    # k < 2 produces exactly k columns.
    B = [np.ones(len(x)), x][:max(k, 0)]
    for _ in range(2, k):
        B.append(2 * x * B[-1] - B[-2])
    if not B:
        return np.empty((len(x), 0))
    return np.column_stack(B)


# Affinely map x onto the closed interval [-1, 1] (min -> -1, max -> 1).
def scale(x):
    """Rescale ``x`` so that its minimum maps to -1 and its maximum to 1.

    Args:
        x: Non-empty NumPy array.

    Returns:
        ``a * x + b`` with ``a = 2 / (max - min)`` and ``b = 1 - a * max``.
        When every entry is equal the map is undefined (division by zero),
        so all entries are mapped to 0, the centre of the interval, instead
        of producing NaNs.
    """
    xmin = x.min()
    xmax = x.max()
    span = xmax - xmin
    if span == 0:
        return np.zeros(np.shape(x))
    a = 2 / span
    b = 1 - a * xmax
    return a * x + b


# One Euler step of the stock-price dynamics for every simulated path.
def step(S):
    """Advance the price vector ``S`` by one exercise interval ``Δt``.

    Draws one standard normal per path from NumPy's global random state and
    returns ``S + r*S*Δt + σ*S*sqrt(Δt)*Z``.
    """
    shocks = np.random.normal(size=S.size)
    dB = np.sqrt(Δt) * shocks
    return S + r * S * Δt + σ * S * dB


def payoff_put(S):
    """Return the put's immediate-exercise value ``max(K - S, 0)`` element-wise."""
    intrinsic = K - S
    return np.maximum(intrinsic, 0.)


# LSMC algorithm (backward induction over the exercise dates).
def compute_price():
    """Estimate the American put price with Least Squares Monte Carlo.

    Simulates ``n`` price paths from ``Spot`` over ``m`` steps of length
    ``Δt`` (NumPy's global RNG is re-seeded with 0). Starting from the payoff
    at maturity, it walks backward through dates ``m-1, ..., 1``. At each
    date the discounted future cashflows are regressed on
    ``chebyshev_basis(scale(S_t), order)`` to estimate the value of waiting.
    A path exercises when its immediate payoff is positive and at least
    that estimate.

    Returns:
        Mean over paths of the cashflows discounted to time 0.
    """
    np.random.seed(0)
    S0 = Spot * np.ones(n)
    S = [S0]

    for t in range(m):
        S_tp1 = step(S[t])
        S.append(S_tp1)

    discount = np.exp(-r * Δt)

    # Very last date: take the payoff and discount it one period back.
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    # Proceed recursively from date m-1 down to date 1 (S[-2 - i] is S[m-1-i]).
    for i in range(m - 1):
        S_t = S[-2 - i]
        X = chebyshev_basis(scale(S_t), order)
        Y = discounted_future_cashflows

        # Least-squares fit through the normal equations.
        Θ = np.linalg.solve(X.T @ X, X.T @ Y)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S_t)
        # Y is never negative, so a negative fitted value_if_wait is pure
        # regression error. Without the `> 0` guard an out-of-the-money path
        # would "exercise" for a payoff of 0 and lose its future cashflows.
        exercise = (value_if_exercise > 0) & (value_if_exercise >= value_if_wait)
        discounted_future_cashflows = discount * np.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)

    return discounted_future_cashflows.mean()


print(compute_price())


In [None]:
# Contract parameters: spot price, volatility, strike, risk-free rate.
Spot, σ, K, r = 36, 0.2, 40, 0.06

# Simulation settings.
n = 100000     # number of simulated paths
m = 50         # number of exercise dates
T = 1          # maturity
order = 12     # number of Chebyshev basis functions (degrees 0..order-1)
Δt = T / m     # interval between two exercise dates


# Build the Chebyshev design matrix [T_0(x), ..., T_{k-1}(x)] with the
# three-term recurrence T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x).
def chebyshev_basis(x, k):
    """Return the first ``k`` Chebyshev polynomials evaluated at ``x`` (JAX).

    Args:
        x: 1-D array of evaluation points (``compute_price`` passes values
            already mapped into [-1, 1] by ``scale``).
        k: Python int giving the number of basis functions. It must be
            static under ``jax.jit`` because it drives the Python loop.
            Column ``j`` holds ``T_j(x)``, so the highest degree is ``k - 1``.

    Returns:
        Array of shape ``(len(x), k)``; ``k <= 0`` yields an empty
        ``(len(x), 0)`` matrix.
    """
    x = jnp.asarray(x)
    # Seed the recurrence with T_0 = 1 and T_1 = x, truncated so that
    # k < 2 produces exactly k columns.
    B = [jnp.ones(len(x)), x][:max(k, 0)]
    for _ in range(2, k):
        B.append(2 * x * B[-1] - B[-2])
    if not B:
        return jnp.empty((len(x), 0))
    return jnp.column_stack(B)


# Affinely map x onto the closed interval [-1, 1] (min -> -1, max -> 1).
def scale(x):
    """Rescale ``x`` so that its minimum maps to -1 and its maximum to 1 (JAX).

    Args:
        x: Non-empty JAX array.

    Returns:
        ``a * x + b`` with ``a = 2 / (max - min)`` and ``b = 1 - a * max``.
        When every entry is equal the map is undefined (division by zero),
        so all entries are mapped to 0, the centre of the interval.
        Branching is done with ``jnp.where`` so the function remains
        traceable under ``jax.jit``.
    """
    xmin = x.min()
    xmax = x.max()
    span = xmax - xmin
    nondegenerate = span > 0
    # Substitute a dummy denominator so the discarded branch never divides by 0.
    a = jnp.where(nondegenerate, 2 / jnp.where(nondegenerate, span, 1), 0)
    b = jnp.where(nondegenerate, 1 - a * xmax, 0)
    return a * x + b


# One Euler step of the stock-price dynamics, with explicit PRNG key plumbing.
def step(S, subkey):
    """Advance the price vector ``S`` by one interval ``Δt``.

    Draws one standard normal per path from ``subkey`` and returns
    ``S + r*S*Δt + σ*S*sqrt(Δt)*Z``.
    """
    shocks = jax.random.normal(subkey, (S.size,))
    dB = jnp.sqrt(Δt) * shocks
    return S + r * S * Δt + σ * S * dB


def payoff_put(S):
    """Return the put's immediate-exercise value ``max(K - S, 0)`` element-wise."""
    intrinsic = K - S
    return jnp.maximum(intrinsic, 0.)


# LSMC algorithm, fully traced and compiled by jax.jit.
@jax.jit
def compute_price():
    """Estimate the American put price with Least Squares Monte Carlo (JAX).

    Simulates ``n`` price paths from ``Spot`` over ``m`` steps with ``step``,
    using PRNG seed 100 split into one key per step. Starting from the payoff
    at maturity, it walks backward through dates ``m-1, ..., 1``. At each
    date the discounted future cashflows are regressed on
    ``chebyshev_basis(scale(S_t), order)`` to estimate the value of waiting.
    A path exercises when its immediate payoff is positive and at least
    that estimate.

    Both time loops use ``jax.lax.scan``, so the compiled program holds one
    loop body each instead of ``m`` unrolled steps and ``m - 1`` unrolled
    regressions. Only ``jnp`` / ``jax`` functions are used.

    Returns:
        Mean over paths of the cashflows discounted to time 0.
    """
    subkeys = random.split(jax.random.PRNGKey(100), m)

    def advance(S, subkey):
        # Compute the step once; it is both the next carry and the output row.
        S_next = step(S, subkey)
        return S_next, S_next

    S0 = Spot * jnp.ones(n)
    # paths[t] holds the prices at date t + 1, so paths has shape (m, n).
    _, paths = jax.lax.scan(advance, S0, subkeys)

    discount = jnp.exp(-r * Δt)

    # Very last date: take the payoff and discount it one period back.
    discounted_future_cashflows = payoff_put(paths[-1]) * discount

    def backward(cashflows, S_t):
        X = chebyshev_basis(scale(S_t), order)
        # Least-squares fit through the normal equations.
        Θ = jnp.linalg.solve(X.T @ X, X.T @ cashflows)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S_t)
        # Cashflows are never negative, so a negative fitted value_if_wait
        # is regression error. Requiring a positive payoff stops
        # out-of-the-money paths from "exercising" for 0 and losing their
        # future cashflows.
        exercise = (value_if_exercise > 0) & (value_if_exercise >= value_if_wait)
        return discount * jnp.where(exercise, value_if_exercise, cashflows), None

    # Dates m-1, ..., 1: reverse=True consumes paths[:-1] from its last row
    # back to its first, i.e. backward in time.
    discounted_future_cashflows, _ = jax.lax.scan(
        backward, discounted_future_cashflows, paths[:-1], reverse=True)

    return discounted_future_cashflows.mean()


print(compute_price())

4.4752336
