In [1]:
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax.scipy.stats import norm
from jax import grad, jit, vmap
from jax import random
import jax

from typing import Sequence
from jaxtyping import Array, Float, Int, PyTree

import equinox as eqx
import optax
import chex

from dataclasses import dataclass
from functools import partial

jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

In [9]:
@jit
def bsPrice(spot, strike, vol, T):
    d1 = (jnp.log(spot/strike) + 0.5 * vol * vol * T) / vol / jnp.sqrt(T)
    d2 = d1 - vol * jnp.sqrt(T)
    return spot * norm.cdf(d1) - strike * norm.cdf(d2)

@jit
def bsDelta(spot, strike, vol, T):
    d1 = (jnp.log(spot/strike) + 0.5 * vol * vol * T) / vol / jnp.sqrt(T)
    return norm.cdf(d1)

@jit
def bsVega(spot, strike, vol, T):
    d1 = (jnp.log(spot/strike) + 0.5 * vol * vol * T) / vol / jnp.sqrt(T)
    return spot * jnp.sqrt(T) * norm.pdf(d1)

@dataclass
class BlackScholes:
    
    spot: int = 1
    vol: float = 0.2
    T1: int = 1
    T2: int = 2
    K: float = 1.10
    volMult: float = 1.5
                        
    # training set: returns S1 (mx1), C2 (mx1) and dC2/dS1 (mx1)
    # @partial(jit, static_argnums=(0,))
    def trainingSet(self, m, anti=True, seed=42):
    
        key = random.PRNGKey(seed)
        # 2 sets of normal returns
        returns = random.normal(key, shape=(m, 2))

        # SDE
        vol0 = self.vol * self.volMult
        R1 = jnp.exp(-0.5*vol0*vol0*self.T1 + vol0*jnp.sqrt(self.T1)*returns[:,0])
        R2 = jnp.exp(-0.5*self.vol*self.vol*(self.T2-self.T1) \
                    + self.vol*jnp.sqrt(self.T2-self.T1)*returns[:,1])
        S1 = self.spot * R1
        S2 = S1 * R2 

        # payoff
        pay = jnp.maximum(0, S2 - self.K)
        
        # two antithetic paths
        if anti:
            
            R2a = jnp.exp(-0.5*self.vol*self.vol*(self.T2-self.T1) \
                    - self.vol*jnp.sqrt(self.T2-self.T1)*returns[:,1])
            S2a = S1 * R2a             
            paya = jnp.maximum(0, S2a - self.K)
            
            X = S1
            Y = 0.5 * (pay + paya)
    
            # differentials
            Z1 =  jnp.where(S2 > self.K, R2, 0.0).reshape((-1,1)) 
            Z2 =  jnp.where(S2a > self.K, R2a, 0.0).reshape((-1,1)) 
            Z = 0.5 * (Z1 + Z2)
                    
        # standard
        else:
        
            X = S1
            Y = pay
            
            # differentials
            Z =  jnp.where(S2 > self.K, R2, 0.0).reshape((-1,1)) 
        
        return X.reshape([-1,1]), Y.reshape([-1,1]), Z.reshape([-1,1])
    
    # test set: returns a grid of uniform spots 
    # with corresponding ground true prices, deltas and vegas
    def testSet(self, lower=0.35, upper=1.65, num=100, seed=42):
        
        spots = jnp.linspace(lower, upper, num).reshape((-1, 1))
        # compute prices, deltas and vegas
        prices = bsPrice(spots, self.K, self.vol, self.T2 - self.T1).reshape((-1, 1))
        deltas = bsDelta(spots, self.K, self.vol, self.T2 - self.T1).reshape((-1, 1))
        vegas = bsVega(spots, self.K, self.vol, self.T2 - self.T1).reshape((-1, 1))
        return spots, spots, prices, deltas, vegas   

In [12]:
bs = BlackScholes()
x_train, y_train, dydx_train = bs.trainingSet(8192)
x_test, x_axis, y_test, dydx_test, vegas = bs.testSet(num=1000)

In [59]:
# @partial(jit, static_argnums=(1,2,3,4,5,6,7,8))

# @jit
def bs_trainingSet(spot, vol, T1, T2, K, volMult, m, anti=True, seed=42):

    key = random.PRNGKey(seed)
    # 2 sets of normal returns
    returns = random.normal(key, shape=(m, 2))

    # SDE
    vol0 = vol * volMult
    R1 = jnp.exp(-0.5*vol0*vol0*T1 + vol0*jnp.sqrt(T1)*returns[:,0])
    R2 = jnp.exp(-0.5*vol*vol*(T2-T1) \
                + vol*jnp.sqrt(T2-T1)*returns[:,1])
    S1 = spot * R1
    S2 = S1 * R2 

    # payoff
    pay = jnp.maximum(0, S2 - K)
    
    # two antithetic paths
    if anti:
        
        R2a = jnp.exp(-0.5*vol*vol*(T2-T1) \
                - vol*jnp.sqrt(T2-T1)*returns[:,1])
        S2a = S1 * R2a             
        paya = jnp.maximum(0, S2a - K)
        
        X = S1
        Y = 0.5 * (pay + paya)

        # differentials
        Z1 =  jnp.where(S2 > K, R2, 0.0).reshape((-1,1)) 
        Z2 =  jnp.where(S2a > K, R2a, 0.0).reshape((-1,1)) 
        Z = 0.5 * (Z1 + Z2)
                
    # standard
    else:
    
        X = S1
        Y = pay
        
        # differentials
        Z =  jnp.where(S2 > K, R2, 0.0).reshape((-1,1)) 
    
    # return X.reshape([-1,1]), Y.reshape([-1,1]), Z.reshape([-1,1])
    return Y.reshape([-1,1])

bs_trainingSet(1.0, 0.2, 1.0, 2.0, 1.1, 1.5, 8192)

grad(bs_trainingSet)(1., 0.2, 1, 2, 1.1, 1.5, 8192)

TypeError: Gradient only defined for scalar-output functions. Output had shape: (8192, 1).