In [1]:
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax.scipy.stats import norm
from jax import grad, jit, vmap
from jax import random as jrandom
import jax

from jax import config
config.update("jax_enable_x64", True)

from typing import Sequence, Tuple, Union, Optional
from jaxtyping import Array, ArrayLike, Float, Int, PyTree, PRNGKeyArray, ScalarLike

import equinox as eqx
import optax
import chex

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

from dataclasses import dataclass, astuple
from functools import partial

import numpy as np

import pytest
import pytest_benchmark

jax.enable_checks = True

jax.devices()

2023-08-28 15:23:39.358005: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


[gpu(id=0)]

# Model Definition

In [2]:
class EuropeanPayoff:
    @staticmethod
    def call(maturity_prices: Float[ArrayLike, " n"], strike_prices: Float[ScalarLike, ""]) -> Float[Array, " n"]:
        return jnp.maximum(jnp.subtract(maturity_prices, strike_prices), 0.0)

    @staticmethod
    def put(maturity_prices: Float[ArrayLike, " n"], strike_prices: Float[ScalarLike, ""]) -> Float[Array, " n"]:
        return jnp.maximum(jnp.subtract(strike_prices, maturity_prices), 0.0)

In [3]:

# class HestonState
# class MarkovState

@dataclass
class Heston:
        
    # S0: float = 100.0       # Initial spot price
    # K: float = 100.0        # Strike price
    # r: float = 0.0319       # Risk-free rate
    # v0: float = 0.010201    # Initial volatility
    # T: float = 1.00         # One year until expiry
    # rho: float = -0.7       # Correlation of asset and volatility
    # kappa: float = 6.21     # Mean-reversion rate
    # theta: float = 0.019    # Long run average volatility
    # xi: float = 0.61        # Volatility of volatility

    S0: float = 100.0       # Initial spot price
    K: float = 100.0        # Strike price
    r: float = 0.05         # Risk-free rate
    v0: float = 0.09        # Initial volatility
    T: float = 5.00         # One year until expiry
    rho: float = -0.3       # Correlation of asset and volatility
    kappa: float = 2.00     # Mean-reversion rate
    theta: float = 0.09     # Long run average volatility
    xi: float = 1.00        # Volatility of volatility

    
    # TODO: store vol_path inside class?
    def spot_path_for_loop(self, spot_draws: Array, vol_path: Array) -> Array:
        vec_size = len(spot_draws)
        dt = self.T / vec_size
        spot_paths = jnp.zeros_like(vol_path)
        spot_paths = spot_paths.at[0].set(self.S0)

        
        for i in range(1, vec_size):
            v_truncated = jnp.maximum(0.0, vol_path[i-1])
            path_new_spot = spot_paths[i-1] * jnp.exp((self.r - 0.5 * v_truncated) * dt + jnp.sqrt(v_truncated * dt) * spot_draws[i-1])
            spot_paths = spot_paths.at[i].set(path_new_spot)
        return spot_paths

    def spot_path(self, spot_draws: Array, vol_path: Array) -> Array:
        n_iter = len(spot_draws)
        dt = self.T / n_iter

        def spot_path_iter(prev_spot_path, iter_pair):
            spot_draw = iter_pair[0]
            vol = iter_pair[1]
            v_truncated = jnp.maximum(0.0, vol)
            path_new_spot = prev_spot_path * jnp.exp((self.r - 0.5 * v_truncated) * dt + jnp.sqrt(v_truncated * dt) * spot_draw)
            return path_new_spot, path_new_spot

        iter_values = jnp.column_stack((spot_draws, vol_path))
        carry, spot_paths = jax.lax.scan(spot_path_iter, jnp.array(self.S0), iter_values, length=n_iter)

        spot_paths = jnp.roll(spot_paths.at[-1].set(self.S0), 1)
        return spot_paths
    
    def volatility_path_for_loop(self, vol_draws: Float[Array, "n ..."]) -> Float[Array, "n ..."]:
        vec_size = len(vol_draws)
        dt = self.T / vec_size

        vol_path = jnp.zeros_like(vol_draws)
        vol_path = vol_path.at[0].set(self.v0)

        for i in range(1, vec_size):
            v_truncated = jnp.maximum(0.0, vol_path[i-1])
            prev_path_contribution = vol_path[i-1] + self.kappa * dt * (self.theta - v_truncated)
            randomness = self.xi * jnp.sqrt(v_truncated * dt) * vol_draws[i-1]
            vol_path = vol_path.at[i].set(prev_path_contribution + randomness)
        
        return vol_path

    def volatility_path(self, vol_draws):
        n_iter = len(vol_draws)
        dt = self.T / n_iter

        def vol_path_iter(prev_vol_path, vol_draw):
            v_truncated = jnp.maximum(0.0, prev_vol_path)
            prev_path_contribution = prev_vol_path + self.kappa * dt * (self.theta - v_truncated)
            randomness = self.xi * jnp.sqrt(v_truncated * dt) * vol_draw
            current = prev_path_contribution + randomness
            return current, current # use current both for carry and for y

        carry, vol_path = jax.lax.scan(vol_path_iter, jnp.array(self.v0), vol_draws, length=n_iter)

        # we now have iteration 1 at position 0. Place initial v0 at the end and rotate into initial slot
        vol_path = jnp.roll(vol_path.at[-1].set(self.v0), 1)
        return vol_path

    def path(self, key: PRNGKeyArray, n_intervals: int = 1000): #, payoff_fn = EuropeanPayoff.call):
        mean = jnp.zeros(2)
        cov = jnp.array([[1.0, heston.rho], [heston.rho, 1.0]])
        
        correlated_samples = jrandom.multivariate_normal(key, mean, cov, shape=(n_intervals,))
        
        vol_draws = correlated_samples[:, 0]
        spot_draws = correlated_samples[:, 1]
        
        vol_path = heston.volatility_path(vol_draws)
        spot_path = heston.spot_path(spot_draws, vol_path)
        
        #spot_maturity = spot_path[-1]
        #payoff = payoff_fn(spot_maturity, self.K)
        return spot_path

    def data(self, key: PRNGKeyArray, n_intervals: int = 1000, payoff_fn = EuropeanPayoff.call):
        pass
        


# Exploration

In [23]:
n_sims : int = 1000000
n_intervals: int = 1000
heston = Heston()

In [24]:
key = jrandom.PRNGKey(0)

key, subkey = jrandom.split(key)
vol_draws = jrandom.normal(key=subkey, shape=(100, ))
# print(vol_draws)

vol_path_for = heston.volatility_path_for_loop(vol_draws)
vol_path = heston.volatility_path(vol_draws)
assert(jnp.allclose(vol_path_for, vol_path))

In [25]:
key, subkey = jrandom.split(key)
spot_draws = jrandom.normal(key=subkey, shape=(100, ))

spot_path_for = heston.spot_path_for_loop(spot_draws, vol_path)
spot_path = heston.spot_path(spot_draws, vol_path)
assert(jnp.allclose(spot_path_for, spot_path))

In [26]:
key, subkey = jrandom.split(key)

mean = jnp.zeros(2)
cov = jnp.array([[1, heston.rho], [heston.rho, 1]])

print(cov)
correlated_samples = jrandom.multivariate_normal(key, mean, cov, shape=(100000,))
correlated_samples.shape

[[ 1.  -0.3]
 [-0.3  1. ]]


(100000, 2)

In [27]:
sample_mean = jnp.mean(correlated_samples, axis=0)
assert(jnp.allclose(sample_mean, mean, atol=1e-2))

In [28]:
sample_cov = jnp.cov(correlated_samples.T)
assert(jnp.allclose(sample_cov, cov, atol=1e-2))

In [29]:
jnp.corrcoef(correlated_samples.T)[0, 1]

Array(-0.30700636, dtype=float64)

In [30]:
jnp.corrcoef(correlated_samples[:, 0], correlated_samples[:, 1])

Array([[ 1.        , -0.30700636],
       [-0.30700636,  1.        ]], dtype=float64)

In [31]:
mean = jnp.zeros(2)
cov = jnp.array([[1, heston.rho], [heston.rho, 1]])

print(cov)

[[ 1.  -0.3]
 [-0.3  1. ]]


# Price

In [4]:
# Price calculation of heston model on european call option
def price(key):
    n_multiplier = 10
    n_simulations = 1000000 * n_multiplier
    n_outside_iter = 10 * n_multiplier
    n_sims_per_iter = n_simulations // n_outside_iter
    payoff_sum: float = 0.0
    key, subkey = jrandom.split(key)
    
    def pathwise_payoff_fn(payoff_sum, key):
        keys = jrandom.split(key, num=n_sims_per_iter)
        pathwise_payoff = vmap(heston.payoff)(keys)
        payoff_sum_iter = jnp.sum(pathwise_payoff)
        payoff_sum += payoff_sum_iter
        return payoff_sum, payoff_sum_iter  
    
    keys = jrandom.split(subkey, num=n_outside_iter)
    payoff_sum, _ = jax.lax.scan(pathwise_payoff_fn, 0.0, keys, length=n_outside_iter)
    
    price = (payoff_sum / n_simulations) * jnp.exp(-heston.r * heston.T)
    jax.debug.print("price {}", price)


In [9]:
key = jrandom.PRNGKey(0)
heston = Heston()

In [10]:
price_fn = jit(price)

In [96]:
%timeit price_fn(key)

price 35.0712355975864


2023-08-28 15:22:50.764056: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(1963): _wrapped_callback
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1221): __call__
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/profiler.py(314): wrapper
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1148): _pjit_call_impl_python
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1192): call_impl_cache_miss
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1209): _pjit_call_impl
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(821): process_primitive
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/neil/miniconda3/lib/python3.10/site-pa

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(1963): _wrapped_callback
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1221): __call__
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/profiler.py(314): wrapper
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1148): _pjit_call_impl_python
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1192): call_impl_cache_miss
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(1209): _pjit_call_impl
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(821): process_primitive
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(389): bind_with_trace
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/core.py(2596): bind
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(166): _python_pjit_helper
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/pjit.py(253): cache_miss
  /home/neil/miniconda3/lib/python3.10/site-packages/jax/_src/traceback_util.py(166): reraise_with_filtered_traceback
  <magic-timeit>(1): inner
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/magics/execution.py(158): timeit
  /home/neil/miniconda3/lib/python3.10/timeit.py(206): repeat
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/magics/execution.py(1168): timeit
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2369): run_line_magic
  /tmp/ipykernel_1152/2421794830.py(1): <module>
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3460): run_code
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3400): run_ast_nodes
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3221): run_cell_async
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3016): _run_cell
  /home/neil/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py(2961): run_cell
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py(540): run_cell
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py(422): do_execute
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(729): execute_request
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(409): dispatch_shell
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(502): process_one
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py(513): dispatch_queue
  /home/neil/miniconda3/lib/python3.10/asyncio/events.py(80): _run
  /home/neil/miniconda3/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /home/neil/miniconda3/lib/python3.10/asyncio/base_events.py(603): run_forever
  /home/neil/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py(215): start
  /home/neil/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py(725): start
  /home/neil/.local/lib/python3.10/site-packages/traitlets/config/application.py(1043): launch_instance
  /home/neil/.local/lib/python3.10/site-packages/ipykernel_launcher.py(17): <module>
  /home/neil/miniconda3/lib/python3.10/runpy.py(86): _run_code
  /home/neil/miniconda3/lib/python3.10/runpy.py(196): _run_module_as_main
; current tracing scope: custom-call.597; current profiling annotation: XlaModule:#hlo_module=jit_price,program_id=464#.

In [11]:
price_fn(key)

price 34.96811035689572


# Visualize Data

In [80]:
n_sims : int = 100000
n_intervals: int = 1000
heston = Heston()

In [81]:
heston_pathwise_payoffs_fn = jit(vmap(heston.payoff))

In [85]:
key, subkey = jrandom.split(key)
keys = jrandom.split(key, num=n_sims)
pathwise_payoff = heston_pathwise_payoffs_fn(keys)

print(pathwise_payoff)
print(pathwise_payoff.shape)

#heston.payoff(key=subkey, n_intervals=n_intervals)

[  0.           8.59688653  81.09357113 ... 118.6389061   51.95726374
   7.62013775]
(100000,)


In [34]:
vol_draws.shape

(100,)