In [1]:
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax.scipy.stats import norm
from jax import grad, jit, vmap
from jax import random as jrandom
import jax

from jax import config
config.update("jax_enable_x64", True)

from typing import Sequence, Tuple, Union
from jaxtyping import Array, ArrayLike, Float, Int, PyTree, ScalarLike

import equinox as eqx
import optax
import chex

import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

from dataclasses import dataclass, astuple
from functools import partial

import numpy as np

import pytest
import pytest_benchmark

# Turn on JAX's extra internal consistency checks while developing.
# Assigning `jax.enable_checks = True` would only rebind a module attribute and
# leave JAX's configuration unchanged, so go through the config API instead.
# NOTE(review): these checks add runtime overhead; consider disabling them
# before running any benchmarks.
config.update("jax_enable_checks", True)

jax.devices()

2023-08-22 17:35:53.241665: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


[gpu(id=0)]

In [2]:
class EuropeanPayoff:
    """Vanilla European option payoffs, evaluated elementwise on maturity prices."""

    @staticmethod
    def call(maturity_prices: Float[ArrayLike, " n"], strike_prices: Float[ScalarLike, ""]) -> Float[Array, " n"]:
        """Call payoff ``max(S_T - K, 0)`` for every maturity price ``S_T``."""
        intrinsic = jnp.subtract(maturity_prices, strike_prices)
        return jnp.maximum(intrinsic, 0.0)

    @staticmethod
    def put(maturity_prices: Float[ArrayLike, " n"], strike_prices: Float[ScalarLike, ""]) -> Float[Array, " n"]:
        """Put payoff ``max(K - S_T, 0)`` for every maturity price ``S_T``."""
        intrinsic = jnp.subtract(strike_prices, maturity_prices)
        return jnp.maximum(intrinsic, 0.0)

In [18]:
@dataclass
class Heston:
    """Heston stochastic-volatility model with Euler full-truncation path simulation.

    Every path method exists twice: a ``*_for_loop`` reference implementation
    and a ``jax.lax.scan`` implementation that must agree with it. All methods
    take standard-normal draws whose leading axis is time (trailing batch axes
    such as ``(n_steps, n_paths)`` are allowed). They return a path of the same
    length whose element 0 is the initial value (``S0`` or ``v0``). Step ``i``
    is driven by draw ``i - 1``, so the last draw is never consumed.

    With ``v+ = max(v, 0)`` and ``dt = T / n`` for ``n`` draws, the variance
    uses full truncation::

        v_i = v_{i-1} + kappa*dt*(theta - v+_{i-1}) + xi*sqrt(v+_{i-1}*dt)*Zv_{i-1}

    and the spot uses the log-Euler step::

        S_i = S_{i-1} * exp((r - v+_{i-1}/2)*dt + sqrt(v+_{i-1}*dt)*Zs_{i-1})

    NOTE(review): with ``n`` points and ``dt = T / n``, the last point sits at
    ``T - dt`` rather than at maturity ``T``. Confirm that this is intended.
    """

    S0: float = 105.0     # Initial spot price
    K: float = 100.0      # Strike price
    T: float = 1.0        # Time to expiry (1.0 = one year)
    rho: float = 0.0      # Correlation of spot and variance shocks
    xi: float = 0.1       # Vol of vol
    v0: float = 0.01      # Initial variance (volatility squared)

    r: float = 0.00       # Risk-free rate
    kappa: float = 2.0    # Mean-reversion rate of the variance
    theta: float = 0.01   # Long-run average variance

    def _vol_step(self, prev_vol, vol_draw, dt):
        """Advance the variance by one full-truncation Euler step."""
        v_truncated = jnp.maximum(0.0, prev_vol)
        drift = self.kappa * dt * (self.theta - v_truncated)
        diffusion = self.xi * jnp.sqrt(v_truncated * dt) * vol_draw
        return prev_vol + drift + diffusion

    def _spot_step(self, prev_spot, prev_vol, spot_draw, dt):
        """Advance the spot by one log-Euler step driven by the truncated variance."""
        v_truncated = jnp.maximum(0.0, prev_vol)
        # Both the drift*dt and the diffusion term belong inside the exponent.
        # This keeps the spot positive and gives the log-normal step.
        log_increment = (self.r - 0.5 * v_truncated) * dt + jnp.sqrt(v_truncated * dt) * spot_draw
        return prev_spot * jnp.exp(log_increment)

    def _correlated_spot_draws(self, spot_draws, vol_draws):
        """Mix independent spot shocks with the variance shocks according to ``rho``.

        When ``vol_draws`` is None, ``spot_draws`` is returned unchanged, which
        is the old behaviour that ignored ``rho``. Assumes ``|rho| <= 1``.
        """
        if vol_draws is None:
            return spot_draws
        return self.rho * vol_draws + jnp.sqrt(1.0 - self.rho ** 2) * spot_draws

    # TODO: store vol_path inside class?
    def spot_path_for_loop(self, spot_draws: Array, vol_path: Array,
                           vol_draws: Union[Array, None] = None) -> Array:
        """Reference (Python loop) spot path. See the class docstring for the scheme.

        Args:
            spot_draws: standard-normal spot shocks, with time along axis 0.
            vol_path: variance path with the same leading length as ``spot_draws``.
            vol_draws: optional variance shocks that produced ``vol_path``. When
                given, the spot shocks are correlated with them via ``rho``.

        Returns:
            An array shaped like ``vol_path``, with ``S0`` at index 0.
        """
        n_steps = len(spot_draws)
        spot_paths = jnp.zeros_like(vol_path)
        if n_steps == 0:
            return spot_paths
        dt = self.T / n_steps
        draws = self._correlated_spot_draws(spot_draws, vol_draws)

        spot_paths = spot_paths.at[0].set(self.S0)
        for i in range(1, n_steps):
            spot_paths = spot_paths.at[i].set(
                self._spot_step(spot_paths[i - 1], vol_path[i - 1], draws[i - 1], dt))
        return spot_paths

    def spot_path(self, spot_draws: Array, vol_path: Array,
                  vol_draws: Union[Array, None] = None) -> Array:
        """Spot path built with ``jax.lax.scan``. It matches ``spot_path_for_loop``.

        Takes the same arguments as ``spot_path_for_loop``. ``spot_draws`` and
        ``vol_path`` must have the same leading length.
        """
        n_steps = len(spot_draws)
        if n_steps == 0:
            return jnp.zeros_like(vol_path)
        dt = self.T / n_steps
        draws = jnp.asarray(self._correlated_spot_draws(spot_draws, vol_draws))
        vols = jnp.asarray(vol_path)

        def step(prev_spot, inputs):
            spot_draw, prev_vol = inputs
            current = self._spot_step(prev_spot, prev_vol, spot_draw, dt)
            return current, current  # carry and emitted value are the same

        # Broadcast S0 to the per-step shape so the carry shape/dtype is fixed
        # across iterations. This also makes batched draws work.
        carry_shape = jnp.broadcast_shapes(draws.shape[1:], vols.shape[1:])
        initial = jnp.full(carry_shape, self.S0, dtype=jnp.result_type(draws, vols))
        # Step i consumes draw i-1 and variance i-1, so the final entries are unused.
        _, later = jax.lax.scan(step, initial, (draws[:-1], vols[:-1]))
        return jnp.concatenate([initial[None], later])

    def volatility_path_for_loop(self, vol_draws: Float[Array, "n ..."]) -> Float[Array, "n ..."]:
        """Reference (Python loop) variance path. See the class docstring.

        Returns an array shaped like ``vol_draws``, with ``v0`` at index 0. The
        entries are the untruncated variance and may be slightly negative.
        """
        n_steps = len(vol_draws)
        vol_path = jnp.zeros_like(vol_draws)
        if n_steps == 0:
            return vol_path
        dt = self.T / n_steps

        vol_path = vol_path.at[0].set(self.v0)
        for i in range(1, n_steps):
            vol_path = vol_path.at[i].set(self._vol_step(vol_path[i - 1], vol_draws[i - 1], dt))
        return vol_path

    def volatility_path(self, vol_draws: Float[Array, "n ..."]) -> Float[Array, "n ..."]:
        """Variance path built with ``jax.lax.scan``. It matches ``volatility_path_for_loop``."""
        draws = jnp.asarray(vol_draws)
        n_steps = len(draws)
        if n_steps == 0:
            return jnp.zeros_like(draws)
        dt = self.T / n_steps

        def step(prev_vol, vol_draw):
            current = self._vol_step(prev_vol, vol_draw, dt)
            return current, current  # carry and emitted value are the same

        # Broadcast v0 to the per-step shape/dtype so batched draws keep a fixed carry.
        initial = jnp.full_like(draws[0], self.v0)
        # Only n-1 steps are needed: v0 fills slot 0 and the last draw is unused.
        _, later = jax.lax.scan(step, initial, draws[:-1])
        return jnp.concatenate([initial[None], later])


In [19]:
# Monte Carlo settings. The cells shown here do not reference them yet.
n_sims: int = 1_000_000   # presumably the number of simulated paths
n_intervals: int = 1_000  # presumably the number of time steps per path

payoff_sum: float = 0.0   # presumably a running payoff accumulator



In [20]:
# Model instance using the default parameters declared on the Heston dataclass.
heston: Heston = Heston()

In [21]:
# Consistency check: the lax.scan variance path must match the reference for-loop.
key = jrandom.PRNGKey(0)

key, subkey = jrandom.split(key)
vol_draws = jrandom.normal(key=subkey, shape=(100,))

vol_path_for = heston.volatility_path_for_loop(vol_draws)
vol_path = heston.volatility_path(vol_draws)
# Same tolerances as jnp.allclose's defaults. Unlike a bare assert, this
# reports the mismatching entries and is not stripped under `python -O`.
np.testing.assert_allclose(vol_path_for, vol_path, rtol=1e-5, atol=1e-8, equal_nan=False)

In [25]:
# Consistency check: the lax.scan spot path must match the reference for-loop.
# It is driven by the variance path computed in the previous cell.
key, subkey = jrandom.split(key)
spot_draws = jrandom.normal(key=subkey, shape=(100,))

spot_path_for = heston.spot_path_for_loop(spot_draws, vol_path)
spot_path = heston.spot_path(spot_draws, vol_path)
# Same tolerances as jnp.allclose's defaults, but with a diagnostic diff on failure.
np.testing.assert_allclose(spot_path_for, spot_path, rtol=1e-5, atol=1e-8, equal_nan=False)