# Efficient Economic Scenario Generation - Vectorised Numpy Approach

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import time
import jax
import jax.numpy as jnp
from cProfile import Profile


## The Numpy Approach

### Defining Our Model

In [4]:
def correlate_weiner_process(X, rho, dt):
    corr = np.array([[1,rho],[rho,1]])
    L = np.linalg.cholesky(corr)
    dZ = np.sqrt(dt)*(np.einsum('ij,ikl -> ikl',L,X))
    return dZ

def get_drift(x, alpha, mu, dt):
    return np.einsum('ij,ij -> ij', alpha, (mu-x)) * dt

def get_diffusion(x, sigma, dZ):
    return np.einsum('ij,ij->ij', np.einsum('ij,ij->ij',sigma,np.sqrt(x)),dZ)

def update(x, alpha, mu, sigma, dt, dZ):
    return x + get_drift(x, alpha, mu, dt) + get_diffusion(x, sigma, dZ)

def generate_processes(x0, dt, alpha, mu, sigma, dZ):
    x = np.zeros(dZ.shape)
    x[:, 0] = x0

    for i in range(1, 1200):
        x[:, i] = update(x[:, i-1], alpha, mu, sigma, dt, dZ[:, i-1])

        
    xt1 = np.sum(x,axis=0)
    phi = 0.045 - np.expand_dims(np.sum(x0,axis=0), axis=0)
    
    n = phi + xt1
    return n


### Setting Parameters

In [5]:
n_trials = 100000

n_years = 100
dt = 1/12
n_factors = 2
rho = 0.739

x0 = np.array([np.repeat(0.0228, n_trials),np.repeat(0.0809, n_trials)])
alpha = np.array([np.repeat(1.0682, n_trials), np.repeat(0.0469,n_trials)])
mu = np.array([np.repeat(0.0546, n_trials), np.repeat(0.0778, n_trials)])
sigma = np.array([np.repeat(0.0412, n_trials), np.repeat(0.0287, n_trials)])

### Generating Scenarios

In [6]:
key = jax.random.PRNGKey(57)
X = jax.random.normal(key=key, shape=(n_factors, int(n_years / dt), n_trials))


start = time.time()
dZ = correlate_weiner_process(X, rho, dt)
end1 = time.time()
print("Time taken for calculating Weiner process correlations: ", end1 - start)

2024-04-25 10:55:12.970408: W external/xla/xla/service/hlo_rematerialization.cc:2946] Can't reduce memory use below 3.28GiB (3519065080 bytes) by rematerialization; only reduced to 3.58GiB (3840000016 bytes), down from 3.58GiB (3840000016 bytes) originally


Time taken for calculating Weiner process correlations:  3.117891311645508


In [7]:
with Profile() as prof:
    n = generate_processes(x0, dt, alpha, mu, sigma, dZ)
    prof.print_stats(sort='tottime')

         29645 function calls (29620 primitive calls) in 3.860 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     3597    1.706    0.000    1.706    0.000 {built-in method numpy.core._multiarray_umath.c_einsum}
     1199    0.673    0.001    1.373    0.001 3917983083.py:7(get_drift)
        2    0.366    0.183    0.366    0.183 {method 'reduce' of 'numpy.ufunc' objects}
        6    0.353    0.059    2.574    0.429 base_events.py:1910(_run_once)
      2/1    0.270    0.135    1.281    1.281 3917983083.py:16(generate_processes)
     1199    0.265    0.000    2.866    0.002 3917983083.py:13(update)
     1199    0.210    0.000    1.222    0.001 3917983083.py:10(get_diffusion)
      9/4    0.006    0.001    0.001    0.000 {method 'run' of '_contextvars.Context' objects}
    17985    0.004    0.000    0.004    0.000 einsumfunc.py:1001(_einsum_dispatcher)
     3597    0.003    0.000    1.709    0.000 einsumfunc.py:1009(einsum)


In [8]:
# plt.plot(n);

## The Jax Approach

In [9]:
# Check Jax uses GPU

print(jax.devices())

[cuda(id=0)]


### Defining Our Model

In [10]:
def correlate_weiner_process_jax(X, rho, dt):
    corr = jnp.array([[1,rho],[rho,1]])
    L = jnp.linalg.cholesky(corr)
    dZ = jnp.sqrt(dt)*jnp.tensordot(L, X, axes=([1],[0]))
    # dZ = jnp.sqrt(dt)*(jnp.einsum('ij,ikl -> jkl',L,X))
    return dZ

def get_drift_jax(x, alpha, mu, dt):
    # return jnp.einsum('ij,ij -> ij', alpha, (mu-x)) * dt
    return alpha * (mu - x) * dt

def get_diffusion_jax(x, sigma, dZ):
    # return jnp.einsum('ij,ij->ij', jnp.einsum('ij,ij->ij',sigma, jnp.sqrt(x)),dZ)
    return sigma * jnp.sqrt(x) * dZ

def update_jax(x, alpha, mu, sigma, dt, dZ):
    return x + get_drift_jax(x, alpha, mu, dt) + get_diffusion_jax(x, sigma, dZ)

def generate_processes_jax(x0, dt, alpha, mu, sigma, dZ):
    x = jnp.zeros(dZ.shape)
    x = x.at[:, 0].set(x0)

    update_jax_jit = jax.jit(update_jax)
    for i in range(1, 1200):
        x = x.at[:, i].set(update_jax_jit(x[:, i-1], alpha, mu, sigma, dt, dZ[:, i-1]))

        
    xt1 = jnp.sum(x,axis=0)
    phi = 0.045 - jnp.expand_dims(jnp.sum(x0,axis=0), axis=0)
    
    n = phi + xt1
    return n


### Setting Parameters

In [11]:
n_trials = 100000

n_years = 100
dt = 1/12
n_factors = 2
rho = 0.739

x0 = jnp.array([jnp.repeat(0.0228, n_trials),jnp.repeat(0.0809, n_trials)])
alpha = jnp.array([jnp.repeat(1.0682, n_trials), jnp.repeat(0.0469,n_trials)])
mu = jnp.array([jnp.repeat(0.0546, n_trials), jnp.repeat(0.0778, n_trials)])
sigma = jnp.array([jnp.repeat(0.0412, n_trials), jnp.repeat(0.0287, n_trials)])

### Generating Scenarios

In [12]:
key = jax.random.PRNGKey(57)
X = jax.random.normal(key=key, shape=(n_factors, int(n_years / dt), n_trials))

correlate_weiner_process_jax_jit = jax.jit(correlate_weiner_process_jax)

start = time.time()
dZ = correlate_weiner_process_jax_jit(X, rho, dt).block_until_ready()
end1 = time.time()
print("Time taken for calculating Weiner process correlations: ", end1 - start)

: 

In [None]:
# Profile the jax code with jax profiling
generate_processes_jax_jit = jax.jit(generate_processes_jax)
with Profile() as prof:
    n = generate_processes_jax(x0, dt, alpha, mu, sigma, dZ).block_until_ready()
    prof.print_stats(sort='tottime')

         2451051 function calls (2395393 primitive calls) in 5.364 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      9/5    2.607    0.290    0.011    0.002 mlir.py:1524(jaxpr_subcomp)
   8395/1    0.704    0.000    0.000    0.000 dispatch.py:79(apply_primitive)
       10    0.156    0.016    0.156    0.016 compiler.py:215(backend_compile)
    37287    0.135    0.000    0.267    0.000 dtypes.py:644(dtype)
    65148    0.118    0.000    0.222    0.000 dtypes.py:329(issubdtype)
457365/457361    0.095    0.000    0.130    0.000 {built-in method builtins.isinstance}
   1199/0    0.078    0.000    0.000          lax_numpy.py:887(squeeze)
   2397/0    0.076    0.000    0.000          lax_numpy.py:4503(_attempt_rewriting_take_via_slice)
   172360    0.062    0.000    0.091    0.000 config.py:275(value)
     8422    0.046    0.000    0.501    0.000 lax.py:518(_convert_element_type)
     1200    0.042    0.000    0.580    0.000 

In [None]:
# plt.plot(n);

NameError: name 'plt' is not defined