# Pricing a European Call Option Version 2

----

#### John Stachurski (August 2024)

----

In this notebook we accelerate our option pricing code using several different
libraries.

In [1]:
import numpy as np
import matplotlib.pyplot as plt

## Why is Pure Python Slow?

We saw that our Python code for option pricing was pretty slow.

In essence, this is because our loops were written in pure Python.

Pure Python loops are not fast.

This has led some people to claim that Python is too slow for computational
economics.

These people are ~idiots~ misinformed -- so please ignore them.

Evidence: AI teams are solving optimization problems in $\mathbb R^d$ with
$d > 10^{12}$ (over 1 trillion) using PyTorch / JAX.

So I'm pretty sure we can use Python for computational economics.

But first let's try to understand the issues.

### Issue 1: Type Checking

Consider the following Python code:

In [2]:
x, y = 1, 2
x + y

3

This is integer addition, which is different from floating point addition:

In [3]:
x, y = 1.0, 2.0
x + y

3.0

Now consider this code:

In [4]:
x, y = 'foo', 'bar'
x + y

'foobar'

Notice that we use the same symbol `+` on each occasion.

The Python interpreter figures out the correct action at runtime by checking the types of the operands:

In [5]:
a, b = 'foo', 10
type(a)

str

In [6]:
type(b)

int
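As a rough sketch of what this dispatch looks like, the snippet below calls each type's `__add__` method directly. This is a simplification (the interpreter can also fall back to `__radd__` on the right operand), but the key point holds: the type of each operand must be inspected at runtime before any arithmetic happens.

```python
# For built-in types, `x + y` is (roughly) handled by the __add__ method of type(x)
int_sum = int.__add__(1, 2)            # integer addition
float_sum = float.__add__(1.0, 2.0)    # floating point addition
str_sum = str.__add__('foo', 'bar')    # string concatenation
print(int_sum, float_sum, str_sum)     # prints: 3 3.0 foobar

# Incompatible types are only detected when the interpreter checks them at runtime
a, b = 'foo', 10
try:
    a + b
except TypeError as err:
    print('TypeError:', err)
```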

Now think of all the type checking in our option pricing function below: every
operation in the inner loop is type-checked again on each of the $M \times n$
passes, so the overhead is huge!

In [7]:
n, β, K = 10, 0.99, 100
μ, ρ, ν, S_0, h_0 = 0.0001, 0.01, 0.001, 10.0, 0.0
def compute_call_price_py(β=β,
                           μ=μ,
                           S_0=S_0,
                           h_0=h_0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=1_000_000,
                           seed=1234):
    np.random.seed(seed)

    s_0 = np.log(S_0)
    s_n = np.empty(M)

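    for m in range(M):
        s, h = s_0, h_0
        for t in range(n):
            # Each line below is type-checked by the interpreter on every pass
            U, V = np.random.randn(2)
            s = s + μ + np.exp(h) * U    # update the log price
            h = ρ * h + ν * V            # update the volatility state
        s_n[m] = s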

    S_n = np.exp(s_n)

    expectation = np.mean(np.maximum(S_n - K, 0))

    return β**n * expectation

### Issue 2: Memory Management

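Pure Python favors flexibility over efficiency in memory management. A list,
for instance, does not store its numbers as raw values in one contiguous block;
it stores pointers to separate Python objects, each carrying its own type
information and reference count.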

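For example, here is the size of a list holding two floats. Note that
`sys.getsizeof` counts only the list object itself, not the float objects it
points to: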

In [8]:
import sys
x = [1.0, 2.0]  
sys.getsizeof(x) * 8   # number of bits

576
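To see what this means in practice, here is a minimal sketch comparing the list with a NumPy array holding the same two numbers. The exact Python object sizes depend on the interpreter version and platform; the NumPy data buffer is simply two 64-bit floats.

```python
import sys
import numpy as np

x_list = [1.0, 2.0]
x_array = np.array([1.0, 2.0])

# The list object alone: a header plus two pointers
list_bits = sys.getsizeof(x_list) * 8
# The float objects the list points to, each a separate heap object
float_bits = sum(sys.getsizeof(v) for v in x_list) * 8
# The array's data buffer: two raw float64 values stored contiguously
array_bits = x_array.nbytes * 8

print(list_bits, float_bits, array_bits)
```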

### Issue 3: Parallelization

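There are opportunities to parallelize our code above, since the `M` simulated
paths are independent of each other and can be divided across multiple workers.

This can't be done efficiently in pure Python, since in standard CPython the
global interpreter lock (GIL) prevents threads from executing Python bytecode in
parallel. It certainly can be done with the right Python libraries.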
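Because each path is simulated independently, the Monte Carlo average can be split into chunks whose averages are combined at the end. The sketch below shows only this decomposition: the list comprehension runs the chunks one after another, but each call could in principle be handed to a separate worker (for example through Python's `concurrent.futures` module). The helper `mean_payoff_chunk` is hypothetical, and it swaps in NumPy's `default_rng` and `SeedSequence.spawn` so that each chunk gets its own independent random stream.

```python
import numpy as np

def mean_payoff_chunk(seed_seq, M_chunk, n=10, μ=0.0001, ρ=0.01, ν=0.001,
                      S_0=10.0, h_0=0.0, K=100):
    """Simulate M_chunk independent paths and return their average payoff."""
    rng = np.random.default_rng(seed_seq)
    s = np.full(M_chunk, np.log(S_0))
    h = np.full(M_chunk, h_0)
    for t in range(n):
        s = s + μ + np.exp(h) * rng.standard_normal(M_chunk)
        h = ρ * h + ν * rng.standard_normal(M_chunk)
    return np.mean(np.maximum(np.exp(s) - K, 0))

num_workers, M_chunk = 4, 100_000
# One independent random stream per (hypothetical) worker
child_seeds = np.random.SeedSequence(1234).spawn(num_workers)
# Runs sequentially here; each call is independent of the others
chunk_means = [mean_payoff_chunk(ss, M_chunk) for ss in child_seeds]
# With equal chunk sizes, the overall mean is the mean of the chunk means
price = 0.99**10 * np.mean(chunk_means)
print(price)
```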

## Vectorization

As a first pass at improving efficiency, here's a vectorized version where all paths are updated together.

We use NumPy arrays to store and update the log prices and volatility states of all `M` paths at once.

When we use NumPy, type-checking is done per-array, not per element!
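A small illustration, using nothing beyond NumPy itself: the array carries a single `dtype`, so NumPy checks the type once for the whole `x + y` operation and then runs a compiled loop over the raw numbers, whereas the Python loop goes through the interpreter for every element.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])

# One dtype describes every element of the array
print(x.dtype)  # float64

# Python-level loop: each x[i] + y[i] is dispatched separately by the interpreter
z_loop = np.array([x[i] + y[i] for i in range(len(x))])

# Vectorized: the dtype is checked once, then a compiled loop runs over the data
z_vec = x + y
print(z_loop, z_vec)
```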

In [9]:
def compute_call_price_np(β=β,
                          μ=μ,
                          S_0=S_0,
                          h_0=h_0,
                          K=K,
                          n=n,
                          ρ=ρ,
                          ν=ν,
                          M=10_000_000,
                          seed=1234):
    np.random.seed(seed)
    # Log prices and volatility states for all M paths at once
    s = np.full(M, np.log(S_0))
    h = np.full(M, h_0)
    for t in range(n):
        # Row 0 holds the price shocks, row 1 the volatility shocks
        Z = np.random.randn(2, M)
        s = s + μ + np.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    expectation = np.mean(np.maximum(np.exp(s) - K, 0))

    return β**n * expectation

Now computation of the option price is reasonably fast:

In [10]:
%time compute_call_price_np()

CPU times: user 3.93 s, sys: 360 ms, total: 4.29 s
Wall time: 4.29 s


1286.3220531978766

But we can still do better...

## Numba Version

Let's try a Numba version.

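This version uses a just-in-time (JIT) compiler. On the first call, Numba infers the argument types and compiles the function to machine code specialized to those types, so the loops then run without per-operation type checking.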

In [11]:
import numba

@numba.jit()
def compute_call_price_numba(β=β,
                               μ=μ,
                               S_0=S_0,
                               h_0=h_0,
                               K=K,
                               n=n,
                               ρ=ρ,
                               ν=ν,
                               M=10_000_000,
                               seed=1234):
    np.random.seed(seed)
    s_0 = np.log(S_0)
    s_n = np.empty(M)
    for m in range(M):
        s, h = s_0, h_0
        for t in range(n):
            s = s + μ + np.exp(h) * np.random.randn()
            h = ρ * h + ν * np.random.randn()
        s_n[m] = s
    S_n = np.exp(s_n)
    expectation = np.mean(np.maximum(S_n - K, 0))

    return β**n * expectation

In [12]:
%time compute_call_price_numba()

CPU times: user 6.14 s, sys: 83 ms, total: 6.22 s
Wall time: 6.22 s


1389.1371007387315

In [13]:
%time compute_call_price_numba()

CPU times: user 5.6 s, sys: 86.7 ms, total: 5.68 s
Wall time: 5.66 s


1389.1371007387315
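The Numba price above (about 1389) differs noticeably from the NumPy price (about 1286). Each version draws different random numbers, and the payoff is highly skewed: with $h$ staying near zero, the log price after $n$ steps has a standard deviation of roughly $\sqrt{n}$, so a handful of extreme paths dominate the average. A minimal sketch for gauging this noise reports the Monte Carlo standard error alongside the price. It swaps in NumPy's `default_rng` generator, and the function name `call_price_with_se` is hypothetical.

```python
import numpy as np

def call_price_with_se(β=0.99, μ=0.0001, S_0=10.0, h_0=0.0, K=100, n=10,
                       ρ=0.01, ν=0.001, M=1_000_000, seed=1234):
    """Return the Monte Carlo price and its standard error."""
    rng = np.random.default_rng(seed)
    s = np.full(M, np.log(S_0))
    h = np.full(M, h_0)
    for t in range(n):
        s = s + μ + np.exp(h) * rng.standard_normal(M)
        h = ρ * h + ν * rng.standard_normal(M)
    # Discounted payoff on each path
    payoffs = β**n * np.maximum(np.exp(s) - K, 0)
    price = payoffs.mean()
    # Sample standard deviation divided by sqrt(M)
    std_error = payoffs.std(ddof=1) / np.sqrt(M)
    return price, std_error

price, std_error = call_price_with_se()
print(price, std_error)
```

An approximate 95% confidence interval is `price ± 1.96 * std_error`. Because the payoff distribution has very heavy tails, the standard error estimate is itself noisy, so treat it as a rough guide.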

## Numba Plus Parallelization

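The second call above excludes compilation time, and even so the serial Numba version is no faster than NumPy here. One limitation is that it runs on only one core.

Next let's try a Numba version with parallelization. Setting `parallel=True` and replacing `range` with `prange` in the outer loop tells Numba that the `M` paths can be simulated independently and split across threads.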

In [14]:
from numba import prange

@numba.jit(parallel=True)
def compute_call_price_numba_parallel(β=β,
                                      μ=μ,
                                      S_0=S_0,
                                      h_0=h_0,
                                      K=K,
                                      n=n,
                                      ρ=ρ,
                                      ν=ν,
                                      M=10_000_000,
                                      seed=1234):
    np.random.seed(seed)
    s_0 = np.log(S_0)
    s_n = np.empty(M)
    for m in prange(M):
        s, h = s_0, h_0
        for t in range(n):
            s = s + μ + np.exp(h) * np.random.randn()
            h = ρ * h + ν * np.random.randn()
        s_n[m] = s
    S_n = np.exp(s_n)
    expectation = np.mean(np.maximum(S_n - K, 0))

    return β**n * expectation

In [15]:
%time compute_call_price_numba_parallel()

CPU times: user 8.4 s, sys: 40 ms, total: 8.44 s
Wall time: 1.15 s


1296.4341728772752

In [16]:
%time compute_call_price_numba_parallel()

CPU times: user 7.68 s, sys: 19.7 ms, total: 7.7 s
Wall time: 492 ms


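1285.8905806670182

On the second call the wall time falls to about 0.5 s, more than ten times faster than the serial Numba version. Notice, however, that the two calls return different prices even though the seed is fixed. With `parallel=True` the draws are spread across threads, each of which appears to keep its own random state, so `np.random.seed` no longer pins down the whole stream of draws.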

## JAX Version

We can do even better if we exploit a hardware accelerator such as a GPU.

Let's see if we have a GPU:

In [17]:
!nvidia-smi

Wed Aug 14 15:08:59 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3080        Off |   00000000:3E:00.0 Off |                  N/A |
| 30%   30C    P8             25W /  320W |       2MiB /  10240MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

The following imports are standard. Here `jax.numpy`, imported as `jnp`, plays the role that `numpy` played above:

In [18]:
import jax
import jax.numpy as jnp

### Simple JAX version

Let's start with a simple version that looks like the NumPy version.

In [19]:
def compute_call_price_jax(β=β,
                           μ=μ,
                           S_0=S_0,
                           h_0=h_0,
                           K=K,
                           n=n,
                           ρ=ρ,
                           ν=ν,
                           M=10_000_000,
                           seed=1234):

    key = jax.random.PRNGKey(seed)
    s = jnp.full(M, jnp.log(S_0))
    h = jnp.full(M, h_0)
    for t in range(n):
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
    S = jnp.exp(s)
    expectation = jnp.mean(jnp.maximum(S - K, 0))
        
    return β**n * expectation
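The loop above calls `jax.random.split` at every step because JAX uses explicit, functional random keys rather than a hidden global state: the same key always produces the same draws. A minimal illustration, using the same `jax.random` functions as the code above:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1234)

# Reusing the same key reproduces exactly the same draws
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields a key to carry forward and a subkey to consume now
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(jnp.array_equal(a, b))  # True
print(a, c)
```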

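Let's run it once. Even without `jax.jit`, JAX compiles each array operation the first time it runs, so this first call includes compilation time: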

In [20]:
%%time 
price = compute_call_price_jax()
print(price)

1269.854
CPU times: user 427 ms, sys: 192 ms, total: 619 ms
Wall time: 717 ms


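And now let's time it. Since JAX dispatches work asynchronously, calling `print` on the result also ensures that the timing covers the full computation: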

In [21]:
%%time 
price = compute_call_price_jax()
print(price)

1269.854
CPU times: user 3.22 ms, sys: 15.7 ms, total: 18.9 ms
Wall time: 28.4 ms


### Compiled JAX version

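Let's take the simple JAX version above and compile the entire function with `jax.jit`. JAX traces the function once, unrolling the Python loop over `t`, and passes the whole computation to the XLA compiler, which can fuse operations and cut down on intermediate arrays.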

In [22]:
# `n` controls a Python loop and `M` sets array shapes, so both are marked static
compute_call_price_jax_compiled = jax.jit(compute_call_price_jax, static_argnames=('n', 'M'))

We run once to compile.

In [23]:
%%time 
price = compute_call_price_jax_compiled()
print(price)

1269.8539
CPU times: user 950 ms, sys: 13 ms, total: 963 ms
Wall time: 1.82 s


And now let's time it.

In [24]:
%%time 
price = compute_call_price_jax_compiled()
print(price)

1269.8539
CPU times: user 1.73 ms, sys: 8 µs, total: 1.74 ms
Wall time: 10.7 ms


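Now we have a really big speed gain relative to NumPy: about 10.7 ms of wall time versus 4.29 s, roughly a 400-fold improvement, and more than twice as fast as the simple JAX version (28.4 ms). Bear in mind that this compares a GPU run with a CPU run.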
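Because `jax.jit` traces the Python `for` loop, the compiled program contains `n` unrolled copies of the update step, and `n` must be known when the function is traced. A possible alternative, sketched below, keeps the loop rolled with `jax.lax.fori_loop` and carries the PRNG key in the loop state. The function name `call_price_fori` is hypothetical, and the explicit `float32` dtypes are an assumption made so that the loop state keeps a fixed type from one iteration to the next.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=('M',))
def call_price_fori(β=0.99, μ=0.0001, S_0=10.0, h_0=0.0, K=100.0,
                    n=10, ρ=0.01, ν=0.001, M=100_000, seed=1234):
    def update(t, state):
        # One time step for all M paths; the key travels in the loop state
        s, h, key = state
        key, subkey = jax.random.split(key)
        Z = jax.random.normal(subkey, (2, M))
        s = s + μ + jnp.exp(h) * Z[0, :]
        h = ρ * h + ν * Z[1, :]
        return s, h, key

    key = jax.random.PRNGKey(seed)
    s = jnp.full(M, jnp.log(S_0), dtype=jnp.float32)
    h = jnp.full(M, h_0, dtype=jnp.float32)
    # fori_loop keeps a single loop body instead of n unrolled copies
    s, h, key = jax.lax.fori_loop(0, n, update, (s, h, key))
    return β**n * jnp.mean(jnp.maximum(jnp.exp(s) - K, 0))


price = call_price_fori()
print(price)
```

With this design `n` can be passed as a traced argument, so changing `n` should not trigger recompilation, while `M` stays static because it determines array shapes. Whether this is faster than the unrolled version for a small `n` such as 10 is not guaranteed.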