In [47]:
from functools import partial

import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)


In [48]:
# For pricing European options, the Heston model has a semi-closed-form solution based on its characteristic function (Heston, 1993).

In [90]:
def heston_characteristic_function(phi, S0, v0, T, r, q, kappa, theta, sigma, rho, lamda, j):
    """Characteristic function f_j(phi) of the Heston (1993) model, equation (17).

    A continuous dividend yield ``q`` is included by replacing the drift ``r``
    with ``r - q``. The formula is evaluated in the algebraically equivalent
    "little Heston trap" form (Albrecher et al., 2007): it uses
    ``g2 = 1 / g`` and ``exp(-d * T)`` instead of Heston's ``g`` and
    ``exp(d * T)``. Heston's original form has a zero denominator in ``g`` at
    ``phi = 0``. It can also jump across the branch cut of the complex
    logarithm when ``sigma`` and/or ``T`` are large. The form used here avoids
    both problems.

    Args:
        phi: Integration variable (real or complex; scalar or array).
        S0: Spot price.
        v0: Initial variance.
        T: Time to maturity.
        r: Risk-free interest rate.
        q: Dividend rate.
        kappa: Mean-reversion rate of the variance.
        theta: Long-term mean of the variance.
        sigma: Volatility of volatility.
        rho: Correlation between the two Brownian motions.
        lamda: Market price of volatility risk.
        j: Selects f_1 when equal to 1; any other value selects f_2.

    Returns:
        Complex value(s) of f_j(phi), broadcast like ``phi``.
    """
    # u_j and b_j of Heston (1993).
    if j == 1:
        u = 0.5
        b = kappa + lamda - rho * sigma
    else:
        u = -0.5
        b = kappa + lamda

    a = kappa * theta
    i_phi = 1j * phi
    # beta = b_j - rho*sigma*phi*i recurs in every term below.
    beta = b - rho * sigma * i_phi
    # Same as sqrt((rho*sigma*phi*i - b)^2 - ...): the square removes the sign.
    d = jnp.sqrt(beta**2 - sigma**2 * (2 * u * i_phi - phi**2))

    # Little-trap form: g2 = 1/g and exp(-dT) replace g and exp(dT).
    beta_minus_d = beta - d
    g2 = beta_minus_d / (beta + d)
    exp_neg_dT = jnp.exp(-d * T)
    one_minus_g2_exp = 1 - g2 * exp_neg_dT

    C = (r - q) * i_phi * T + (a / sigma**2) * (
        beta_minus_d * T - 2 * jnp.log(one_minus_g2_exp / (1 - g2))
    )
    D = beta_minus_d / sigma**2 * ((1 - exp_neg_dT) / one_minus_g2_exp)

    return jnp.exp(C + D * v0 + i_phi * jnp.log(S0))

def heston_probability(S0, v0, K, r, q, T, kappa, theta, sigma, rho, lmbda, j,
                       phi_min=1e-7, phi_max=100.0, n_points=100):
    """Probability P_j of Heston (1993), equation (18).

        P_j = 1/2 + (1/pi) * integral_0^inf Re[exp(-i*phi*ln K) * f_j(phi) / (i*phi)] dphi

    The improper integral is truncated to ``[phi_min, phi_max]``. It is then
    approximated with the trapezoidal rule on ``n_points`` uniformly spaced
    nodes. The defaults reproduce the grid this function always used. If the
    price moves noticeably when ``phi_max`` or ``n_points`` is increased, the
    grid is too coarse for the parameters at hand.

    Args:
        S0, v0, K, r, q, T, kappa, theta, sigma, rho: Model and contract
            parameters, forwarded to ``heston_characteristic_function``.
        lmbda: Market price of volatility risk.
        j: 1 or 2, selecting P_1 or P_2.
        phi_min: Lower integration limit. The integrand's ``1 / (i*phi)``
            factor is undefined at ``phi = 0``, so the grid starts just above it.
        phi_max: Upper truncation point of the integral.
        n_points: Number of quadrature nodes (must be a Python int).

    Returns:
        P_j as a real scalar array.
    """
    log_K = jnp.log(K)

    def integrand(phi):
        f_j = heston_characteristic_function(phi, S0, v0, T, r, q, kappa, theta, sigma, rho, lmbda, j)
        return jnp.real(jnp.exp(-1j * phi * log_K) * f_j / (1j * phi))

    phi_grid = jnp.linspace(phi_min, phi_max, n_points)
    integral = jax.scipy.integrate.trapezoid(integrand(phi_grid), phi_grid)
    return 0.5 + (1 / jnp.pi) * integral
    
def heston_european_option_call_price(S0, v0, K, T, r, q, kappa, theta, sigma, rho, lamda):
    """European call price under the Heston (1993) model, equation (10).

        C = S0 * exp(-q*T) * P_1 - K * exp(-r*T) * P_2

    where P_1 and P_2 come from ``heston_probability``. The dividend rate
    ``q`` discounts the spot leg.
    """
    # Argument order expected by heston_probability after (S0, v0, K).
    shared = (r, q, T, kappa, theta, sigma, rho, lamda)
    p1, p2 = [heston_probability(S0, v0, K, *shared, which) for which in (1, 2)]

    spot_leg = S0 * jnp.exp(-q * T)
    strike_leg = K * jnp.exp(-r * T)
    return spot_leg * p1 - strike_leg * p2


In [91]:
# Parameters
# based on Case 3 of: https://papers.ssrn.com/sol3/papers.cfm?abstract_id=1718102
T = 1.0        # maturity
S0 = 100.0     # spot price
K = 100.0      # strike price
r = 0.00       # risk-free interest rate
q = 0.00       # dividend rate
v0 = 0.09      # initial variance
rho = -0.3     # correlation between Brownian motions
kappa = 1.0    # mean reversion rate
theta = 0.09   # long-term mean of variance
sigma = 1.0    # volatility of volatility
lamda = 0.0    # market price of volatility risk

# Option value. Pass `lamda` as defined above, so this cell runs on a fresh
# kernel without relying on leftover names.
option_price = heston_european_option_call_price(S0, v0, K, T, r, q, kappa, theta, sigma, rho, lamda)

print(f"European call option price: {option_price:.6}")

European call option price: 9.77379


In [92]:
# Option greeks: derivatives of the call price with respect to the spot S0
# (delta) and to the initial *variance* v0 (not vega, which is taken with
# respect to volatility). Both come from a single reverse-mode pass.
call_price_of_spot_and_var = partial(
    heston_european_option_call_price,
    K=K, T=T, r=r, q=q, kappa=kappa, theta=theta, sigma=sigma, rho=rho, lamda=lamda,
)
dS, dV = jax.grad(call_price_of_spot_and_var, argnums=(0, 1))(S0, v0)
print(f"dS: {dS:.6}")
print(f"dV: {dV:.6}")

dS: 0.608057
dV: 39.4946


In [93]:
# Option second-order greeks: the full 2x2 Hessian of the call price in
# (S0, v0) from one call. hessian[a][b] = d^2 price / d arg_a d arg_b.
call_price_of_spot_and_var = partial(
    heston_european_option_call_price,
    K=K, T=T, r=r, q=q, kappa=kappa, theta=theta, sigma=sigma, rho=rho, lamda=lamda,
)
(dSdS, dSdV), (dVdS, dVdV) = jax.hessian(call_price_of_spot_and_var, argnums=(0, 1))(S0, v0)
print(f"dSdS: {dSdS:.6}")
print(f"dVdV: {dVdV:.6}")
print(f"dSdV: {dSdV:.6}")

dSdS: 0.0215111
dVdV: -109.28
