In [30]:
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, value_and_grad
import numpy as np

In [9]:
N = jsp.stats.norm.cdf  # standard normal CDF

def BlackScholesCall(S, sigma, T, r, K):
    """Closed-form Black–Scholes price of a European call option.

    Written with jax.numpy so it can be differentiated with ``jax.grad``
    (e.g. argnums=1 gives the derivative w.r.t. sigma, i.e. vega).

    Args:
        S: spot price of the underlying.
        sigma: volatility.
        T: time to expiry.
        r: continuously compounded interest rate.
        K: strike.

    Returns:
        S * N(d1) - K * exp(-r * T) * N(d2).
    """
    vol_sqrt_t = sigma * jnp.sqrt(T)  # shared by d1 and d2
    d1 = (jnp.log(S / K) + (r + sigma**2 / 2) * T) / vol_sqrt_t
    d2 = d1 - vol_sqrt_t
    discount = jnp.exp(-r * T)  # present value of 1 paid at T
    return S * N(d1) - K * discount * N(d2)

In [10]:
# Vega of the Black–Scholes call: gradient of the price w.r.t. sigma, which is
# positional argument 1 of BlackScholesCall(S, sigma, T, r, K).
grad_BS = grad(BlackScholesCall, argnums=1)
grad_BS(100.0, 0.1, 1, 0, 100.0)  # S=100, sigma=0.1, T=1, r=0, K=100

DeviceArray(39.844395, dtype=float32, weak_type=True)

In [20]:
def simulate_BS(sigma, S0=100.0, K=100.0, T=1.0, it=10, weeks=52, rng=None):
    """Monte Carlo estimate of a European call price, simulated one path at a time.

    Each of ``it`` paths is stepped through ``weeks`` equal time steps with
    S <- S * exp(-sigma**2 / 2 * dt + sigma * sqrt(dt) * Z), Z ~ N(0, 1)
    (zero-drift-rate dynamics; no discounting is applied), and the payoff
    max(S_T - K, 0) is averaged over the paths. Uses jax.numpy so the result
    can be differentiated w.r.t. ``sigma`` with ``jax.grad``.

    Args:
        sigma: volatility.
        S0: initial price of every path.
        K: strike.
        T: time horizon, split into ``weeks`` steps.
        it: number of simulated paths (must be >= 1).
        weeks: number of time steps per path (must be >= 1).
        rng: object with a ``normal()`` method, e.g. ``np.random.default_rng(seed)``
            for reproducible draws; defaults to the unseeded global ``np.random``.

    Returns:
        The average payoff over the ``it`` paths.

    Raises:
        ValueError: if ``it`` or ``weeks`` is less than 1.
    """
    if it < 1 or weeks < 1:
        raise ValueError("it and weeks must both be >= 1")
    draw = np.random if rng is None else rng
    dt = T / weeks
    # Loop-invariant pieces of the log-increment, hoisted out of the path loop.
    drift = -0.5 * sigma**2 * dt
    vol = sigma * jnp.sqrt(dt)

    total = 0.0  # avoid shadowing the builtin `sum`
    for _ in range(it):
        S = S0
        for _ in range(weeks):
            Z = draw.normal()
            S = S * jnp.exp(drift + vol * Z)
        # jnp.maximum instead of a Python `if` on a traced value: the `if` fails
        # under jax.jit; results differ only on the measure-zero event S == K.
        total = total + jnp.maximum(S - K, 0.0)
    return total / float(it)

In [21]:
# Differentiate the Monte Carlo price through the simulation w.r.t. sigma
# (argument 0), giving a simulated vega.
grad_simulateBS = grad(simulate_BS, argnums=0)
grad_simulateBS(0.1)

DeviceArray(25.481146, dtype=float32, weak_type=True)

In [28]:
def simulate_BS(sigma, S0=100.0, K=100.0, T=1.0, it=100000, weeks=52, rng=None):
    """Vectorised Monte Carlo estimate of a European call price.

    All ``it`` paths are advanced together through ``weeks`` equal time steps with
    S <- S * exp(-sigma**2 / 2 * dt + sigma * sqrt(dt) * Z), Z ~ N(0, 1)
    (zero-drift-rate dynamics; no discounting is applied), and the payoff
    max(S_T - K, 0) is averaged. Uses jax.numpy so the result can be
    differentiated w.r.t. ``sigma`` (or ``S0``) with ``jax.grad``.

    Args:
        sigma: volatility.
        S0: initial price of every path.
        K: strike.
        T: time horizon, split into ``weeks`` steps.
        it: number of simulated paths (must be >= 1).
        weeks: number of time steps (must be >= 1).
        rng: object with a ``normal(size=...)`` method, e.g.
            ``np.random.default_rng(seed)`` for reproducible draws; defaults to
            the unseeded global ``np.random``.

    Returns:
        The mean payoff over the ``it`` paths.

    Raises:
        ValueError: if ``it`` or ``weeks`` is less than 1.
    """
    if it < 1 or weeks < 1:
        raise ValueError("it and weeks must both be >= 1")
    draw = np.random if rng is None else rng
    dt = T / weeks
    drift = -0.5 * sigma**2 * dt
    vol = sigma * jnp.sqrt(dt)

    # Only terminal prices are used, so carry one vector of `it` prices instead
    # of an (it, weeks + 1) history: outside jit every `.at[].set` copies the
    # whole matrix, and reading it back relied on the leaked loop index.
    S = S0 * jnp.ones(it)  # multiply (not jnp.full) so S0 may also be traced
    for _ in range(weeks):
        Z = draw.normal(size=it)
        S = S * jnp.exp(drift + vol * Z)

    return jnp.mean(jnp.maximum(S - K, 0.0))

In [38]:
# value_and_grad differentiates w.r.t. sigma (argnums=0), so the second output is
# the simulated *vega* (dC/dsigma), not delta (dC/dS).
C, vega = value_and_grad(simulate_BS, argnums=0)(0.1)
delta = vega  # backward-compatible alias for later cells; the name is a misnomer

In [40]:
# Simulated price and its derivative w.r.t. sigma (the variable `delta` holds vega).
print(C, delta)

4.003037 39.988308
