# Efficient Economic Scenario Generation - Vectorised Numpy Approach

In [18]:
import matplotlib.pyplot as plt
import numpy as np
import time
import jax
import jax.numpy as jnp
from cProfile import Profile


## The Numpy Approach

### Defining Our Model

In [19]:
def correlate_weiner_process(X, rho, dt):
    corr = np.array([[1,rho],[rho,1]])
    L = np.linalg.cholesky(corr)
    # dZ = np.sqrt(dt)*(np.einsum('ij,ikl -> ikl',L,X))
    dZ = np.sqrt(dt)* np.tensordot(L,X, axes=([1],[0]))
    return dZ

def get_drift(x, alpha, mu, dt):
    # return np.einsum('ij,ij -> ij', alpha, (mu-x)) * dt
    return alpha * (mu-x) * dt

def get_diffusion(x, sigma, dZ):
    # return np.einsum('ij,ij->ij', np.einsum('ij,ij->ij',sigma,np.sqrt(x)),dZ)
    return sigma * np.sqrt(x) * dZ

def update(x, alpha, mu, sigma, dt, dZ):
    return x + get_drift(x, alpha, mu, dt) + get_diffusion(x, sigma, dZ)

def generate_processes(x0, dt, alpha, mu, sigma, dZ):
    x = np.zeros(dZ.shape)
    x[:, 0] = x0

    for i in range(1, 1200):
        x[:, i] = update(x[:, i-1], alpha, mu, sigma, dt, dZ[:, i-1])

        
    xt1 = np.sum(x,axis=0)
    phi = 0.045 - np.expand_dims(np.sum(x0,axis=0), axis=0)
    
    n = phi + xt1
    return n


### Setting Parameters

In [20]:
n_trials = 100

n_years = 100
dt = 1/12
n_factors = 2
rho = 0.739

x0 = np.array([np.repeat(0.0228, n_trials),np.repeat(0.0809, n_trials)])
alpha = np.array([np.repeat(1.0682, n_trials), np.repeat(0.0469,n_trials)])
mu = np.array([np.repeat(0.0546, n_trials), np.repeat(0.0778, n_trials)])
sigma = np.array([np.repeat(0.0412, n_trials), np.repeat(0.0287, n_trials)])

### Generating Scenarios

In [21]:
key = jax.random.PRNGKey(57)
X = jax.random.normal(key=key, shape=(n_factors, int(n_years / dt), n_trials))


start = time.time()
dZ = correlate_weiner_process(X, rho, dt)
end1 = time.time()
print("Time taken for calculating Weiner process correlations: ", end1 - start)

Time taken for calculating Weiner process correlations:  0.004537820816040039


In [22]:
with Profile() as prof:
    n_npy = generate_processes(x0, dt, alpha, mu, sigma, dZ)
    prof.print_stats(sort='tottime')

         4152 function calls (4145 primitive calls) in 0.045 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1199    0.015    0.000    0.032    0.000 708260610.py:16(update)
     1199    0.010    0.000    0.010    0.000 708260610.py:8(get_drift)
        2    0.007    0.003    0.007    0.003 {method 'reduce' of 'numpy.ufunc' objects}
     1199    0.007    0.000    0.007    0.000 708260610.py:12(get_diffusion)
        1    0.003    0.003    0.033    0.033 708260610.py:19(generate_processes)
       13    0.002    0.000    0.003    0.000 socket.py:621(send)
        2    0.001    0.001    0.008    0.004 {method '__exit__' of 'sqlite3.Connection' objects}
        1    0.000    0.000    0.000    0.000 {method 'execute' of 'sqlite3.Connection' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       54    0.000    0.000    0.000    0.000 enum.py:1544(_get_value)
       13  

In [23]:
plt.plot(n_npy);

## The Jax Approach

In [24]:
# Check Jax uses GPU

print(jax.devices())

[cuda(id=0)]


### Defining Our Model

In [25]:
def correlate_weiner_process_jax(X, rho, dt):
    corr = jnp.array([[1,rho],[rho,1]])
    L = jnp.linalg.cholesky(corr)
    dZ = jnp.sqrt(dt)*jnp.tensordot(L, X, axes=([1],[0]))
    # dZ = jnp.sqrt(dt)*(jnp.einsum('ij,ikl -> jkl',L,X))
    return dZ

def get_drift_jax(x, alpha, mu, dt):
    # return jnp.einsum('ij,ij -> ij', alpha, (mu-x)) * dt
    return alpha * (mu - x) * dt

def get_diffusion_jax(x, sigma, dZ):
    # return jnp.einsum('ij,ij->ij', jnp.einsum('ij,ij->ij',sigma, jnp.sqrt(x)),dZ)
    return sigma * jnp.sqrt(x) * dZ

def update_jax(x, alpha, mu, sigma, dt, dZ):
    return x + get_drift_jax(x, alpha, mu, dt) + get_diffusion_jax(x, sigma, dZ)

def generate_processes_jax(x0, dt, alpha, mu, sigma, dZ):
    x = jnp.zeros(dZ.shape)
    x = x.at[:, 0].set(x0)

    update_jax_jit = jax.jit(update_jax)
    for i in range(1, dZ.shape[1]):
        x = x.at[:, i].set(update_jax_jit(x[:, i-1], alpha, mu, sigma, dt, dZ[:, i-1]))
        
    xt1 = jnp.sum(x,axis=0)
    phi = 0.045 - jnp.expand_dims(jnp.sum(x0,axis=0), axis=0)
    
    n = phi + xt1
    return n


### Setting Parameters

In [26]:
n_trials = 100

n_years = 100
dt = 1/12
n_factors = 2
rho = 0.739

x0 = jnp.array([jnp.repeat(0.0228, n_trials),jnp.repeat(0.0809, n_trials)])
alpha = jnp.array([jnp.repeat(1.0682, n_trials), jnp.repeat(0.0469,n_trials)])
mu = jnp.array([jnp.repeat(0.0546, n_trials), jnp.repeat(0.0778, n_trials)])
sigma = jnp.array([jnp.repeat(0.0412, n_trials), jnp.repeat(0.0287, n_trials)])

### Generating Scenarios

In [27]:
key = jax.random.PRNGKey(57)
X = jax.random.normal(key=key, shape=(n_factors, int(n_years / dt), n_trials))

correlate_weiner_process_jax_jit = jax.jit(correlate_weiner_process_jax)

start = time.time()
dZ = correlate_weiner_process_jax_jit(X, rho, dt).block_until_ready()
end1 = time.time()
print("Time taken for calculating Weiner process correlations: ", end1 - start)

Time taken for calculating Weiner process correlations:  0.12123584747314453


In [28]:
# Profile the jax code with jax profiling
generate_processes_jax_jit = jax.jit(generate_processes_jax)

with Profile() as prof:
    n_jax = generate_processes_jax_jit(x0, dt, alpha, mu, sigma, dZ)
    prof.print_stats(sort='tottime')

         7123858 function calls (7002356 primitive calls) in 139.860 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     30/1  127.385    4.246    0.000    0.000 {method 'poll' of 'select.epoll' objects}
        1    7.318    7.318    7.318    7.318 compiler.py:215(backend_compile)
117799/61340    0.351    0.000    0.570    0.000 {jaxlib.utils.safe_map}
1492125/1492121    0.222    0.000    0.234    0.000 {built-in method builtins.isinstance}
        1    0.198    0.198    0.245    0.245 mlir.py:488(make_ir_context)
    74441    0.129    0.000    0.239    0.000 dtypes.py:329(issubdtype)
    37198    0.121    0.000    0.252    0.000 dtypes.py:644(dtype)
    12024    0.117    0.000    0.140    0.000 source_info_util.py:219(current)
   381294    0.103    0.000    0.144    0.000 {method 'get' of 'dict' objects}
        1    0.080    0.080    0.080    0.080 pxla.py:1196(__call__)
   2411/0    0.076    0.000    0.000          pji

In [29]:
plt.plot(n_jax[:, :1000]);

## Benchmarking!
To examine how efficient our approaches are at scenario generation we'll be comparing 6 options over an exponentially increasing number of runs. The options we'll be including for benchmarking are:
- PYESG
- Our Numpy approach.
- Our JAX approach with Just in Time (JIT) compilation on a CPU.
- Our JAX approach with Just in Time (JIT) compilation on a GPU.
- Our Cython approach.

Benchmarking is being performed on a Dell G15 5510 laptop running Arch Linux with kernel version 6.8.5-arch1-1 with the following computational resources:
- 16 × Intel® Core™ i7-10870H CPU @ 2.20GHz
- 31.1 GiB of RAM 
- NVIDIA GeForce RTX 3060

In [5]:
n_trials = [1*10**(i) for i in range(7)]


[1, 10, 100, 1000, 10000, 100000, 1000000]