Part 1

In [1]:
import jax
import jax.numpy as np
from jax import random

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100


def step(S, key):
    dZ = random.normal(key, shape=S.shape) * np.sqrt(dt)
    dS = r * S * dt + σ * S * dZ
    S = S + dS
    return S

def simulate():
    key = random.PRNGKey(0)
    S0 = np.ones(20000)
    S_list = []

    for t in range(m):
        key, subkey = random.split(key)
        S = step(S0, subkey)
        S_list.append(S)
        S0 = S

    S_array = np.stack(S_list)
    return S_array

D:\anaconda\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
D:\anaconda\lib\site-packages\numpy\.libs\libopenblas.PYQHXLVVQ7VESDPUVUADXEVJOBGHJPAY.gfortran-win_amd64.dll


In [2]:
simulate()

DeviceArray([[0.9990368 , 0.99845755, 1.0009075 , ..., 0.9949413 ,
              0.9968931 , 0.99635816],
             [0.99916214, 0.99233675, 0.9988377 , ..., 0.9934248 ,
              0.9962843 , 1.0000815 ],
             [1.0038999 , 0.99782735, 0.9933294 , ..., 0.9843731 ,
              0.99664205, 0.99447113],
             ...,
             [0.98081917, 1.0411566 , 0.98559856, ..., 1.0382012 ,
              1.0340495 , 0.9722904 ],
             [0.9760958 , 1.0460585 , 0.9811195 , ..., 1.0345007 ,
              1.0348959 , 0.977704  ],
             [0.97404456, 1.0439519 , 0.9716244 , ..., 1.0376757 ,
              1.034841  , 0.974258  ]], dtype=float32)

Part 2

In [4]:
import jax
import jax.numpy as np
from jax import random,jit

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100


def step(S, key):
    dZ = random.normal(key, shape=S.shape) * np.sqrt(dt)
    dS = r * S * dt + σ * S * dZ
    S = S + dS
    return S

def simulate():
    key = random.PRNGKey(0)
    S0 = np.ones(20000)
    S_list = []

    for t in range(m):
        key, subkey = random.split(key)
        S = step(S0, subkey)
        S_list.append(S)
        S0 = S

    S_array = np.stack(S_list)
    return S_array

simulate_jit = jit(simulate)

In [5]:
simulate_jit()

DeviceArray([[0.9990368 , 0.99845755, 1.0009075 , ..., 0.9949413 ,
              0.9968931 , 0.99635816],
             [0.99916214, 0.99233675, 0.9988377 , ..., 0.9934248 ,
              0.9962843 , 1.0000815 ],
             [1.0038999 , 0.99782735, 0.9933294 , ..., 0.9843731 ,
              0.99664205, 0.99447113],
             ...,
             [0.98081917, 1.0411566 , 0.98559856, ..., 1.0382012 ,
              1.0340495 , 0.9722904 ],
             [0.9760958 , 1.0460585 , 0.9811195 , ..., 1.0345007 ,
              1.0348959 , 0.977704  ],
             [0.97404456, 1.0439519 , 0.9716244 , ..., 1.0376757 ,
              1.034841  , 0.974258  ]], dtype=float32)

Part 3

In [7]:
import jax
import jax.numpy as jnp
from jax import random

Spot = 36 # stock price
σ = 0.2 # stock volatility
K = 40 # strike price
r = 0.06 # risk free rate
n = 100000 # Number of simulated paths
m = 50 # number of exercise dates
T = 1 # maturity
order = 12 # Polynomial order
Δt = T / m # interval between two exercise dates

# Construct polynomial features of order up to k using the recursive formulation
def chebyshev_basis(x, k):
    B = [jnp.ones_like(x), x]
    for n in range(2, k):
        Bn = 2 * x * B[n - 1] - B[n - 2]
        B.append(Bn)
    return jnp.column_stack(B)

# Scales x to be in the interval(-1, 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b

# Simulates one step of the stock price evolution
def step(S, Δt, key):
    dB = jnp.sqrt(Δt) * random.normal(key, shape=S.shape)
    S_tp1 = S + r * S * Δt + σ * S * dB
    return S_tp1

def payoff_put(S, K):
    return jnp.maximum(K - S, 0.)

# LSMC algorithm
@jax.jit
def compute_price():
    key = random.PRNGKey(0)
    S0 = Spot * jnp.ones(n)
    S = [S0]

    for t in range(m):
        key, subkey = random.split(key)
        S_tp1 = step(S[t], Δt, subkey)
        S.append(S_tp1)

    discount = jnp.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1], K)
    discounted_future_cashflows = value_if_exercise * discount

    # Proceed recursively
    for i in range(m - 1):
        X = chebyshev_basis(scale(S[-2 - i]), order)
        Y = discounted_future_cashflows

        Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S[-2 - i], K)

        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows
        )

    return discounted_future_cashflows.mean()

print(compute_price())


4.475946
