In [1]:
import jax
import jax.numpy as jnp

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100

def simulate():
    # Set random seed for reproducibility
    key = jax.random.PRNGKey(0)

    def step(S, key):
        key, subkey = jax.random.split(key)
        dZ = jax.random.normal(subkey, shape=S.shape) * jnp.sqrt(dt)
        dS = r * S * dt + σ * S * dZ
        S = S + dS
        return S, key

    S0 = jnp.ones(20000)
    S = S0
    S_list = []

    key = jax.random.PRNGKey(0)
    for t in range(m):
        S, key = step(S, key)
        S_list.append(S)

    S_array = jnp.stack(S_list)
    return S_array

# Run the simulation
S_array = simulate()


In [2]:
import jax
import jax.numpy as jnp
from jax import jit, lax

# Data
σ = 0.04
r = 0.01
K = 35

# Design choice
dt = 0.01
m = 100

def simulate():
    # Set random seed for reproducibility
    key = jax.random.PRNGKey(0)

    def step(carry, _):
        S, key = carry
        key, subkey = jax.random.split(key)
        dZ = jax.random.normal(subkey, shape=S.shape) * jnp.sqrt(dt)
        dS = r * S * dt + σ * S * dZ
        S = S + dS
        return (S, key), S

    S0 = jnp.ones(20000)
    key = jax.random.PRNGKey(0)
    (final_S, _), S_list = lax.scan(step, (S0, key), None, length=m)
    S_array = jnp.stack(S_list)
    return S_array

# JIT compile the simulate function
simulate_jit = jit(simulate)

# Run the JIT compiled simulation
S_array = simulate_jit()


In [3]:
import jax
import jax.numpy as jnp
from jax import jit, lax

# Data
Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simulated paths
m = 50      # number of exercise dates
T = 1       # maturity
order = 12  # Polynomial order (reduced to avoid overfitting)
Δt = T / m  # interval between two exercise dates

# Construct polynomial features of order up to k using the recursive formulation
def chebyshev_basis(x, k):
    B = [jnp.ones(len(x)), x]
    for n in range(2, k):
        Bn = 2 * x * B[n - 1] - B[n - 2]
        B.append(Bn)
    return jnp.column_stack(B)

# scales x to be in the interval (-1, 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b

# simulates one step of the stock price evolution
def step(S, key):
    key, subkey = jax.random.split(key)
    dB = jnp.sqrt(Δt) * jax.random.normal(subkey, shape=S.shape)
    S_tp1 = S * jnp.exp((r - 0.5 * σ**2) * Δt + σ * dB)  # Adjusted to use geometric Brownian motion
    return S_tp1, key

def payoff_put(S):
    return jnp.maximum(K - S, 0.)

# LSMC algorithm
@jit
def compute_price():
    key = jax.random.PRNGKey(42) 
    S0 = Spot * jnp.ones(n)
    S = [S0]

    for t in range(m):
        S_tp1, key = step(S[t], key)
        S.append(S_tp1)

    discount = jnp.exp(-r * Δt)

    # Very last date
    value_if_exercise = payoff_put(S[-1])
    discounted_future_cashflows = value_if_exercise * discount

    # Proceed recursively
    for i in range(m - 1):
        X = chebyshev_basis(scale(S[-2 - i]), order)
        Y = discounted_future_cashflows

        # Adding L2 regularization to avoid overfitting
        reg = 1e-6 * jnp.eye(X.shape[1])
        Θ = jnp.linalg.solve(X.T @ X + reg, X.T @ Y)
        value_if_wait = X @ Θ
        value_if_exercise = payoff_put(S[-2 - i])
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)

    # Apply a final discounting step to improve precision
    return discounted_future_cashflows.mean() * jnp.exp(-r * Δt)

# Run the JIT compiled LSMC pricing
price = compute_price()
print(price)


4.4707355


In [4]:
import jax
import jax.numpy as jnp
from jax import jit, lax
 
# Data
Spot = 36   # stock price
σ = 0.2     # stock volatility
K = 40      # strike price
r = 0.06    # risk free rate
n = 100000  # Number of simulated paths
m = 50      # number of exercise dates
T = 1       # maturity
order = 12  # Polynomial order
Δt = T / m  # interval between two exercise dates
 
# Construct polynomial features of order up to k using the recursive formulation
def chebyshev_basis(x, k):
    B = [jnp.ones_like(x), x]
    for n in range(2, k):
        Bn = 2 * x * B[-1] - B[-2]
        B.append(Bn)
    return jnp.column_stack(B)
 
# scales x to be in the interval (-1, 1)
def scale(x):
    xmin = x.min()
    xmax = x.max()
    a = 2 / (xmax - xmin)
    b = 1 - a * xmax
    return a * x + b
 
# simulates one step of the stock price evolution
def step(S, key):
    dB = jnp.sqrt(Δt) * jax.random.normal(key, shape=S.shape)
    S_tp1 = S + r * S * Δt + σ * S * dB
    return S_tp1
 
# LSMC algorithm enhanced with jax.lax.scan for better performance
@jit
def compute_price():
    key = jax.random.PRNGKey(0)
    S0 = Spot * jnp.ones(n)
 
    # Using jax.lax.scan to simulate the stock paths
    def body_fn(S, x):
        return step(S, x), S
 
    keys = jax.random.split(key, m)
    _, S_list = lax.scan(body_fn, S0, keys)
    
    discount = jnp.exp(-r * Δt)
 
    # Processing the very last date
    value_if_exercise = jnp.maximum(K - S_list[-1], 0.)
    discounted_future_cashflows = value_if_exercise * discount
 
    # Recursively calculate the option price
    for i in range(m - 1):
        X = chebyshev_basis(scale(S_list[-2 - i]), order)
        Y = discounted_future_cashflows
 
        Θ = jnp.linalg.solve(X.T @ X, X.T @ Y)
        value_if_wait = X @ Θ
        value_if_exercise = jnp.maximum(K - S_list[-2 - i], 0.)
        exercise = value_if_exercise >= value_if_wait
        discounted_future_cashflows = discount * jnp.where(
            exercise,
            value_if_exercise,
            discounted_future_cashflows)
 
    return discounted_future_cashflows.mean()
 
# Run the JIT compiled LSMC pricing
price = compute_price()
print(price)

4.4637995
