In [1]:
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax.scipy.stats import norm
from jax import grad, jit, vmap
from jax import random as jrandom
import jax

from jax import config
config.update("jax_enable_x64", True)

from typing import Sequence, Tuple, Union, Optional
from jaxtyping import Array, ArrayLike, Float, Int, PyTree, PRNGKeyArray, ScalarLike

import equinox as eqx
import optax
import chex

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

from dataclasses import dataclass, astuple
from functools import partial

import numpy as np

jax.enable_checks = True

jax.devices()

2023-09-05 00:15:21.935670: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


[gpu(id=0)]

In [3]:
class EuropeanPayoff:
    @staticmethod
    def call(maturity_prices: Float[ArrayLike, " n"], strike_prices: Float[ScalarLike, ""]) -> Float[Array, " n"]:
        return jnp.maximum(jnp.subtract(maturity_prices, strike_prices), 0.0)

    @staticmethod
    def put(maturity_prices: Float[ArrayLike, " n"], strike_prices: Float[ScalarLike, ""]) -> Float[Array, " n"]:
        return jnp.maximum(jnp.subtract(strike_prices, maturity_prices), 0.0)

    @staticmethod
    def payoff(spot_path, K, payoff_fn = call):
        spot_maturity = spot_path[..., -1]
        payoff = payoff_fn(spot_maturity, K)
        return payoff

In [19]:
@dataclass
class Range:
    min: float = 0.0
    max: float = 1.0

class HestonSoA(eqx.Module):

    n_states: int
    initial_spot_prices: Float[Array, " n"] # Array of initial spot prices
    initial_vols: Float[Array, " n"] # Array of initial volatility

    K: float = 100.0        # Strike price
    r: float = 0.05         # Risk-free rate
    T: float = 5.00         # n years until expiry, i.e. T2 - T1
    rho: float = -0.3       # Correlation of asset and volatility
    kappa: float = 2.00     # Mean-reversion rate
    theta: float = 0.09     # Long run average volatility
    xi: float = 1.00        # Volatility of volatility


    def __init__(self, key: PRNGKeyArray, n_states: int, spot_range: Range, vol_range: Range):
        self.n_states = n_states
        key, spot_keys, vol_keys = jrandom.split(key, 3)
        # self.initial_spot_prices = jrandom.uniform(spot_keys, shape=(n_states,), minval=spot_range.min, maxval=spot_range.max)
        self.initial_spot_prices = jnp.ones(n_states) * 100
        key, subkey = jrandom.split(key)
        self.initial_vols = jnp.ones(n_states) * 0.09
        # self.initial_vols = jrandom.uniform(vol_keys, shape=(n_states, ), minval=vol_range.min, maxval=vol_range.max)

    def volatility_path(self, vol_draws):
        # n_iter = len(vol_draws)
        n_iter = vol_draws.shape[0]
        dt = self.T / n_iter

        def vol_path_iter(prev_vols_path, vol_draws):
            v_truncated = jnp.maximum(0.0, prev_vols_path)
            prev_path_contribution = prev_vols_path + self.kappa * dt * (self.theta - v_truncated)
            randomness = self.xi * jnp.sqrt(v_truncated * dt) * vol_draws
            current = prev_path_contribution + randomness
            return current, current # use current both for carry and for y

        # carry, vol_path = jax.lax.scan(vol_path_iter, jnp.array(self.v0), vol_draws, length=n_iter)
        carry, vol_path = jax.lax.scan(vol_path_iter, self.initial_vols, vol_draws, length=n_iter)
        print("vol_path: ", vol_path.shape)

        # we now have iteration 1 at position 0. Place initial v0 at the end and rotate into initial slot
        vol_path = jnp.roll(vol_path.at[-1, :].set(self.initial_vols), 1, axis=0)
        return vol_path

    def spot_path(self, spot_draws: Array, vol_path: Array) -> Array:
        n_iter = spot_draws.shape[0]
        dt = self.T / n_iter

        def spot_path_iter(prev_spot_path, iter_pair):
            # print("iter pair ", iter_pair.shape)
            spot_draw = iter_pair[0]
            vol = iter_pair[1]
            v_truncated = jnp.maximum(0.0, vol)
            path_new_spot = prev_spot_path * jnp.exp((self.r - 0.5 * v_truncated) * dt + jnp.sqrt(v_truncated * dt) * spot_draw)
            return path_new_spot, path_new_spot

        print("spot draws.shape ", spot_draws.shape)
        # print("vol path.shape ", vol_path.shape)
        iter_values = jnp.stack((spot_draws, vol_path), axis=1)
        # print("itervalues ", iter_values.shape)
        carry, spot_paths = jax.lax.scan(spot_path_iter, self.initial_spot_prices, iter_values, length=n_iter)

        print("Spot path.shape ", spot_paths.shape)
        # spot_paths = jnp.roll(spot_paths.at[-1, :].set(self.initial_spot_prices), 1)
        spot_paths = jnp.roll(spot_paths.at[-1, :].set(self.initial_spot_prices), 1, axis=0)
        return spot_paths

    def path(self, key: PRNGKeyArray, n_intervals: int = 1000):
        mean = jnp.zeros(2)
        cov = jnp.array([[1.0, self.rho], [self.rho, 1.0]])

        correlated_samples = jrandom.multivariate_normal(key, mean, cov, shape=(n_intervals, self.n_states))
        print(correlated_samples.shape)

        vol_draws = correlated_samples[..., 0]
        spot_draws = correlated_samples[..., 1]
        print(vol_draws.shape)
        print("spot_draws ", spot_draws.shape)

        sample_mean = jnp.mean(vol_draws)
        print(sample_mean)

        vol_path = self.volatility_path(vol_draws)
        spot_path = self.spot_path(spot_draws, vol_path)
        return vol_path, spot_path

    def payoff(self, key: PRNGKeyArray, n_intervals: int = 1000, payoff_fn = EuropeanPayoff.call):
        _, spot_path = self.path(key, n_intervals)
        spot_maturity = spot_path[-1, :]
        # print("matutiry ", spot_maturity)
        payoff = payoff_fn(spot_maturity, self.K)
        # jax.debug.print("payoff {}", payoff)
        return payoff

key = jrandom.PRNGKey(0)
key, subkey = jrandom.split(key)
spot_range = Range(50.0, 150.0)
vol_range = Range(0.01, 0.1)
n_states = 4
hset = HestonSoA(key, n_states, spot_range, vol_range)
samples = hset.path(key)

hset.payoff(key)

(1000, 4, 2)
(1000, 4)
spot_draws  (1000, 4)
-0.002678439312463839
vol_path:  (1000, 4)
spot draws.shape  (1000, 4)
Spot path.shape  (1000, 4)
(1000, 4, 2)
(1000, 4)
spot_draws  (1000, 4)
-0.002678439312463839
vol_path:  (1000, 4)
spot draws.shape  (1000, 4)
Spot path.shape  (1000, 4)


Array([ 0.        ,  0.        , 31.21334734, 42.90625302], dtype=float64)

In [21]:
# Price calculation of heston model on european call option
def price(key, heston: HestonSoA):
    n_multiplier = 1
    n_simulations = 10000 * n_multiplier
    n_outside_iter = 10 * n_multiplier
    n_sims_per_iter = n_simulations // n_outside_iter
    payoff_sum: float = 0.0
    # key, subkey = jrandom.split(key)

    def pathwise_payoff_fn(payoff_sum, key):
        # print("Payoff sum shape", payoff_sum.shape)
        keys = jrandom.split(key, num=n_sims_per_iter)
        pathwise_payoff = vmap(heston.payoff)(keys)
        print("pathwise payoff", pathwise_payoff.shape)
        jax.debug.print("pathwise payoff: {}", pathwise_payoff)
        
        payoff_sum_iter = jnp.sum(pathwise_payoff, axis=0)
        print("payoff sum iter", payoff_sum_iter.shape)
        payoff_sum += payoff_sum_iter
        return payoff_sum, payoff_sum_iter

    keys = jrandom.split(key, num=n_outside_iter)
    print("heston.n_states: :", heston.n_states)
    payoff_sum, _ = jax.lax.scan(pathwise_payoff_fn, jnp.zeros(heston.n_states), keys, length=n_outside_iter)

    price = (payoff_sum / n_simulations) * jnp.exp(-heston.r * heston.T)
    return price

In [22]:
key = jrandom.PRNGKey(0)
key, subkey = jrandom.split(key)
spot_range = Range(50.0, 150.0)
vol_range = Range(0.01, 0.1)
hset = HestonSoA(key=subkey, n_states=3, spot_range=spot_range, vol_range=vol_range)
print(hset)
print(hset.initial_vols.shape)

key, subkey = jrandom.split(key)
prices = price(subkey, hset)
jax.debug.print("prices {}", prices)

HestonSoA(
  n_states=3,
  initial_spot_prices=f64[3],
  initial_vols=f64[3],
  K=100.0,
  r=0.05,
  T=5.0,
  rho=-0.3,
  kappa=2.0,
  theta=0.09,
  xi=1.0
)
(3,)
heston.n_states: : 3
(1000, 3, 2)
(1000, 3)
spot_draws  (1000, 3)
Traced<ShapedArray(float64[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float64[1000])>with<DynamicJaxprTrace(level=1/0)>
  batch_dim = 0
vol_path:  (1000, 3)
spot draws.shape  (1000, 3)
Spot path.shape  (1000, 3)
pathwise payoff (1000, 3)
payoff sum iter (3,)
pathwise payoff: [[  0.           0.           0.        ]
 [118.28149295  55.26424119  83.17233871]
 [  3.30730845   0.           0.        ]
 ...
 [  4.5936793    0.          30.4587072 ]
 [ 16.48845507   0.          47.48411766]
 [ 80.71687884 115.41183646  65.14656448]]
pathwise payoff: [[  0.5022662  103.09776662  21.41585488]
 [ 42.73275349  36.99347792 237.40515421]
 [  0.           0.           1.64502354]
 ...
 [ 83.80805129 277.67362124  26.23683167]
 [  0.           0.        