## Pricing a European Call Option (Numba vs JAX)

#### Written for the CBC QuantEcon Workshop (September 2022)

#### Author: [John Stachurski](http://johnstachurski.net/)

```python
import numpy as np
import matplotlib.pyplot as plt
```

Recall that we want to compute


$$ P = \beta^n \mathbb E \max\{ S_n - K, 0 \} $$
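Under stochastic volatility this expectation generally has no simple closed form, so we estimate it by simulation. We draw $M$ independent terminal prices $S_n^{(1)}, \ldots, S_n^{(M)}$ from the dynamics and average the discounted payoffs. In the notation above, the estimator is

$$ \hat P_M = \beta^n \frac{1}{M} \sum_{m=1}^{M} \max\{ S_n^{(m)} - K, 0 \} $$

By the law of large numbers, $\hat P_M \to P$ as $M \to \infty$. This is the average that the Numba and JAX code below computes.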

We suppose that the number of periods, the discount factor and the strike price are

```python
n  = 20
β = 0.99
K = 100
```

The dynamics are

$$ \ln \frac{S_{t+1}}{S_t} = \mu + \sigma_t \xi_{t+1} $$

where 

$$ 
    \sigma_t = \exp(h_t), 
    \quad
        h_{t+1} = \rho h_t + \nu \eta_{t+1}
$$

Here $\{\xi_t\}$ and $\{\eta_t\}$ are IID and standard normal.

With $s_t := \ln S_t$, the price dynamics become

$$ s_{t+1} = s_t + \mu + \exp(h_t) \xi_{t+1} $$
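To make the recursion concrete, here is a minimal vectorized NumPy sketch that simulates a handful of paths of $(s_t, h_t)$ directly from the two update rules. The function name `simulate_log_paths` and the use of `np.random.default_rng` are our own choices and are not part of the original notebook. The parameter values in the usage line match the defaults adopted below.

```python
import numpy as np

def simulate_log_paths(s0, h0, mu, rho, nu, n, num_paths, seed=0):
    """Simulate num_paths paths of (s_t, h_t) for t = 0, ..., n.

    Returns arrays s and h of shape (n + 1, num_paths).
    """
    rng = np.random.default_rng(seed)
    s = np.empty((n + 1, num_paths))
    h = np.empty((n + 1, num_paths))
    s[0], h[0] = s0, h0
    for t in range(n):
        xi = rng.standard_normal(num_paths)   # shocks ξ to the log price
        eta = rng.standard_normal(num_paths)  # shocks η to log volatility
        s[t + 1] = s[t] + mu + np.exp(h[t]) * xi
        h[t + 1] = rho * h[t] + nu * eta
    return s, h

s, h = simulate_log_paths(np.log(10), 0.0, 0.0001, 0.1, 0.001, n=20, num_paths=5)
print(np.exp(s[-1]))  # terminal prices S_n of the 5 simulated paths
```

Storing whole paths is convenient for plotting (for example, `plt.plot(s)` draws one line per path). Pricing only needs the terminal value $s_n$, which is why the pricing code below keeps just the current state.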

We use the following defaults.

```python
μ  = 0.0001
ρ  = 0.1
ν  = 0.001
S0 = 10
h0 = 0
```

(Here `S0` is $S_0$ and `h0` is $h_0$.)
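A rough sanity check is available if we switch off the volatility shocks. Set $\nu = 0$, which is close to our tiny $\nu = 0.001$, and keep $h_0 = 0$. Then $h_t = 0$ and $\sigma_t = 1$ for all $t$, so $s_n$ is normal with mean $m = \ln S_0 + n\mu$ and variance $v = n$. The expectation then has the familiar lognormal (Black-Scholes-type) closed form

$$ \mathbb E \max\{S_n - K, 0\} = e^{m + v/2}\,\Phi(d_1) - K\,\Phi(d_2), \quad d_2 = \frac{m - \ln K}{\sqrt v}, \quad d_1 = d_2 + \sqrt v $$

The sketch below evaluates this formula using only the standard library. The function name is ours, and because of the $\nu = 0$ simplification the result is an approximate benchmark, not the exact price in the model above.

```python
import math

def lognormal_call_price(S0, K, mu, beta, n):
    """Discounted E max{S_n - K, 0} when ln S_n ~ N(ln S0 + n*mu, n).

    This is the nu = 0, h0 = 0 special case of the model (sigma_t = 1).
    """
    m = math.log(S0) + n * mu          # mean of s_n
    v = n                              # variance of s_n
    def Phi(x):                        # standard normal CDF
        return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    d2 = (m - math.log(K)) / math.sqrt(v)
    d1 = d2 + math.sqrt(v)
    expectation = math.exp(m + v / 2) * Phi(d1) - K * Phi(d2)
    return beta**n * expectation

print(lognormal_call_price(S0=10, K=100, mu=0.0001, beta=0.99, n=20))  # about 1.8e5
```

This benchmark is somewhat higher than the Monte Carlo estimates reported below, which range from roughly 1.5 to 1.7 × 10^5. With $\sigma_t \approx 1$ over 20 periods, $S_n$ is extremely right-skewed, so its mean is driven by rare extreme paths that a finite sample often misses. For that reason, the individual digits of any single estimate should be read with caution.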

Previously we estimated the price by Monte Carlo, using Numba to compile the simulation and parallelize it across CPU cores.

```python
from numba import njit, prange
from numpy.random import randn
```

```python
M = 10_000_000
```

```python
@njit(parallel=True)
def compute_call_price_parallel(β=β,
                                μ=μ,
                                S0=S0,
                                h0=h0,
                                K=K,
                                n=n,
                                ρ=ρ,
                                ν=ν,
                                M=M):
    current_sum = 0.0
    # For each sample path (prange splits the paths across CPU threads,
    # and Numba treats `current_sum += ...` as a parallel reduction)
    for m in prange(M):
        s = np.log(S0)
        h = h0
        # Simulate forward in time
        for t in range(n):
            s = s + μ + np.exp(h) * randn()
            h = ρ * h + ν * randn()
        # And add the value max{S_n - K, 0} to current_sum
        current_sum += max(np.exp(s) - K, 0.0)

    return β**n * current_sum / M
```

```python
%%time
compute_call_price_parallel()
```

```
CPU times: user 14.3 s, sys: 19.5 ms, total: 14.4 s
Wall time: 2.33 s

167647.93969568016
```

The first call includes Numba's compilation time, so we time a second call:

```python
%%time
compute_call_price_parallel()
```

```
CPU times: user 13.8 s, sys: 11.7 ms, total: 13.8 s
Wall time: 1.8 s

157306.5698800889
```
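The two calls above return noticeably different estimates because each call draws fresh random numbers. A standard way to quantify this is the Monte Carlo standard error, which is the sample standard deviation of the discounted payoffs divided by $\sqrt M$. The sketch below computes it with vectorized NumPy rather than Numba, using a smaller `M` so that it runs quickly. The function name `call_price_with_se` and the seed are our own illustrative choices.

```python
import numpy as np

def call_price_with_se(β=0.99, μ=0.0001, S0=10, h0=0.0, K=100,
                       n=20, ρ=0.1, ν=0.001, M=100_000, seed=1234):
    """Monte Carlo price estimate and its standard error (vectorized NumPy)."""
    rng = np.random.default_rng(seed)
    s = np.full(M, np.log(S0))
    h = np.full(M, h0)
    for t in range(n):
        s = s + μ + np.exp(h) * rng.standard_normal(M)
        h = ρ * h + ν * rng.standard_normal(M)
    payoffs = β**n * np.maximum(np.exp(s) - K, 0)   # discounted payoff per path
    estimate = payoffs.mean()
    std_error = payoffs.std(ddof=1) / np.sqrt(M)
    return estimate, std_error

estimate, std_error = call_price_with_se()
print(f"{estimate:.1f} ± {1.96 * std_error:.1f} (approx. 95% interval)")
```

With payoffs this heavy-tailed, the sample standard deviation is itself a noisy and typically understated estimate. Treat the interval as indicative only.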

### Exercise

Try to shift the whole computation to the GPU using JAX and measure the speed gain.

### Solution

```python
!nvidia-smi
```

```
Sat Sep 10 14:42:50 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:3B:00.0 Off |                  N/A |
| 30%   27C    P8    25W / 320W |      1MiB / 10240MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------
```

```python
import jax
import jax.numpy as jnp
```

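```python
from functools import partial

# n sets the loop length and M sets the array shapes, so both are marked
# static; this lets callers pass them explicitly (e.g. a smaller M)
@partial(jax.jit, static_argnames=("n", "M"))
def compute_call_price_jax(β=β,
                           μ=μ,
                           S0=S0,
                           h0=h0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=M,
                           key=jax.random.PRNGKey(1)):
    # One entry per sample path: all M paths are updated in parallel
    s = jnp.full(M, jnp.log(S0))
    h = jnp.full(M, h0)
    # Simulate forward in time (this Python loop is unrolled during tracing)
    for t in range(n):
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    # Sample mean of max{S_n - K, 0} across paths
    expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))

    return β**n * expectation
```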
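The Python `for` loop in `compute_call_price_jax` is unrolled when JAX traces the function, so the compiled program grows with `n`. For larger `n`, one option is `jax.lax.fori_loop`, which carries the state `(s, h, key)` through a single compiled loop body. The sketch below is our own variant, not part of the original solution. The name `compute_call_price_fori` is hypothetical, the default `M` is reduced to keep it light, and the explicit `float32` dtypes keep the loop carry types fixed under JAX's default 32-bit mode.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("n", "M"))
def compute_call_price_fori(key, β=0.99, μ=0.0001, S0=10.0, h0=0.0, K=100.0,
                            n=20, ρ=0.1, ν=0.001, M=1_000_000):
    """Same estimator, with the time loop expressed via jax.lax.fori_loop."""
    def update(t, state):
        s, h, key = state
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M), dtype=jnp.float32)
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
        return s, h, key

    s = jnp.full(M, jnp.log(S0), dtype=jnp.float32)
    h = jnp.full(M, h0, dtype=jnp.float32)
    s, h, key = jax.lax.fori_loop(0, n, update, (s, h, key))
    return β**n * jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))

price = compute_call_price_fori(jax.random.PRNGKey(1))
print(price)
```

For `n = 20` the unrolled loop is perfectly reasonable, so this is mainly a design option for longer horizons.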

```python
%%time 
compute_call_price_jax()
```

```
CPU times: user 2.64 s, sys: 108 ms, total: 2.75 s
Wall time: 4.42 s

DeviceArray(152444.25, dtype=float32)
```

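The first call includes tracing and XLA compilation; a second call reuses the compiled function:

```python
%%time 
compute_call_price_jax()
```

```
CPU times: user 587 µs, sys: 10 µs, total: 597 µs
Wall time: 398 µs

DeviceArray(152444.25, dtype=float32)
```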
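Two caveats apply when reading these timings. First, JAX dispatches work asynchronously, so a call can return before the device has finished. The JAX documentation recommends calling `.block_until_ready()` on the result when benchmarking, so the 398 µs above may understate the true cost. Second, because `key` defaults to `jax.random.PRNGKey(1)`, every call returns the identical estimate. Passing fresh keys gives independent draws, which makes repeated calls comparable to the repeated Numba calls. The sketch below illustrates both points with a small stand-in workload, `mc_mean_payoff`, which is a hypothetical function defined only so that the snippet is self-contained.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def mc_mean_payoff(key):
    # Stand-in workload: mean of max{exp(s) - K, 0} for s ~ N(ln 10, 20)
    s = jnp.log(10.0) + jnp.sqrt(20.0) * jax.random.normal(key, (1_000_000,))
    return jnp.mean(jnp.maximum(jnp.exp(s) - 100.0, 0.0))

key = jax.random.PRNGKey(0)
mc_mean_payoff(key).block_until_ready()        # first call: trace and compile

estimates = []
start = time.perf_counter()
for i in range(5):
    key, subkey = jax.random.split(key)         # fresh key, hence fresh draws
    estimates.append(mc_mean_payoff(subkey).block_until_ready())
elapsed = time.perf_counter() - start
print(f"{elapsed / 5 * 1e3:.2f} ms per call;", [float(e) for e in estimates])
```

Note also that JAX computes in `float32` by default, whereas the Numba version uses `float64`. Calling `jax.config.update("jax_enable_x64", True)` at startup switches JAX to 64-bit if more precision is wanted.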