# Summary

This notebook introduces two high-performance computing libraries for Python: Numba and JAX. Each has its advantages and neither dominates the other; which one is more efficient to use depends on your hardware, personal taste, and familiarity. What they have in common is [just-in-time compilation](https://en.wikipedia.org/wiki/Just-in-time_compilation), which is a huge advantage over pure Python.


We will look at the same value function iteration problem as last week:

In [1]:
import numpy as np

In [2]:
beta_factor  = 0.95
grid_x = np.linspace(0.1, 1, 100)
num_iterations = 100

In [3]:
def value_function_iteration(grid, num_iter, beta):
    value = np.sqrt(grid)
    for i in range(num_iter):
        value = single_iteration(grid, beta, value)
    return value


def single_iteration(grid, beta, v_old):
    v_new = np.zeros_like(v_old)
    for id_x, x in enumerate(grid):
        mask = grid <= x
        v_new[id_x] = np.max((np.sqrt(x - grid[mask]) + beta * v_old[mask]))
    return v_new

In [4]:
%timeit value_function_iteration(grid_x, num_iterations, beta_factor)

107 ms ± 24.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


This version is already more efficient than writing everything in loops. A version that uses only loops is even slower.

In [5]:
def value_function_iteration_all_loops(grid, num_iter, beta):
    value = np.sqrt(grid)
    for i in range(num_iter):
        value = single_iteration_all_loops(grid, beta, value)
    return value

def single_iteration_all_loops(grid, beta, v_old):
    v_new = np.empty_like(v_old)
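    for id_x, x in enumerate(grid):
        # Collect the candidate values for every x' first; infeasible choices stay at -inf
        array_to_max = np.full_like(grid, -np.inf)
        for id_x_prime, x_prime in enumerate(grid):
            if x_prime <= x:
                array_to_max[id_x_prime] = np.sqrt(x - x_prime) + beta * v_old[id_x_prime]
        # Take the maximum only after all candidates have been filled in
        v_new[id_x] = np.max(array_to_max)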
    return v_new


In [6]:
%timeit value_function_iteration_all_loops(grid_x, num_iterations, beta_factor)

4.64 s ± 22.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
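
Going in the opposite direction, the remaining loop over `x` can also be replaced by NumPy broadcasting. The following is a minimal sketch, not part of the original notebook: the function names `single_iteration_broadcast` and `value_function_iteration_broadcast` are hypothetical, and no timing is claimed for it. It builds the full matrix of differences `x - x'` at once, so it trades memory for fewer Python-level loops.

```python
import numpy as np


def single_iteration_broadcast(grid, beta, v_old):
    # diff[i, j] = x_i - x'_j for every state x_i and every candidate x'_j
    diff = grid[:, None] - grid[None, :]
    feasible = diff >= 0  # same condition as `grid <= x` in the masked version
    # Clip infeasible differences to 0 so that sqrt never sees a negative number
    candidates = np.sqrt(np.where(feasible, diff, 0.0)) + beta * v_old[None, :]
    # Infeasible choices must never win the maximum
    candidates = np.where(feasible, candidates, -np.inf)
    return candidates.max(axis=1)


def value_function_iteration_broadcast(grid, num_iter, beta):
    value = np.sqrt(grid)
    for _ in range(num_iter):
        value = single_iteration_broadcast(grid, beta, value)
    return value
```

Calling `value_function_iteration_broadcast(grid_x, num_iterations, beta_factor)` should give the same values as `value_function_iteration`, up to floating-point rounding.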


So always try to use NumPy functions where possible; they are usually much faster than Python loops. If you can't avoid loops, you can use Numba:

In [7]:
import numba as nb

In [8]:
@nb.jit(nopython=True)
def value_function_iteration_numba(grid, num_iter, beta):
    value = np.sqrt(grid)
    for i in range(num_iter):
        value = single_iteration_numba(grid, beta, value)
    return value

@nb.jit(nopython=True)
def single_iteration_numba(grid, beta, v_old):
    v_new = np.zeros_like(v_old)
    for id_x, x in enumerate(grid):
        mask = grid <= x
        v_new[id_x] = np.max((np.sqrt(x - grid[mask]) + beta * v_old[mask]))
    return v_new

In [9]:
# The first call triggers compilation; run it once so that %timeit measures only execution
value_function_iteration_numba(grid_x, num_iterations, beta_factor)
%timeit value_function_iteration_numba(grid_x, num_iterations, beta_factor)

3.38 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


Even though Numba worked out of the box here, it does not always do so. Its error messages can be hard to read, and debugging them takes some experience. A good strategy is to start small and extend your code step by step, adding one component at a time. As soon as compilation fails, you know which component caused the issue.

Another library you can use is JAX, developed by Google. It is built for machine learning and deep learning and provides a powerful framework. There are great resources, YouTube tutorials, and blog articles on how to use JAX. As it is the main workhorse at Google companies such as DeepMind, there is a lot of support and active development.

In [10]:
import jax
import jax.numpy as jnp

In [11]:
def value_function_iteration_jax(grid, beta):
    value = jnp.sqrt(grid)
    iter_func = jax.vmap(single_iteration_jax, in_axes=(0, None, None, None))

    for i in range(100):
        value = iter_func(grid, grid, beta, value)
    return value


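def single_iteration_jax(x, grid, beta, v_old):
    # Value of each candidate x'; clipping at 0 avoids NaNs from sqrt of negative numbers
    candidates = jnp.sqrt(jnp.maximum(x - grid, 0)) + beta * v_old
    # Exclude infeasible choices (x' > x), matching the mask in the NumPy version
    v_new = jnp.max(jnp.where(grid <= x, candidates, -jnp.inf))
    return v_new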

In [12]:
jax_jit_value_func = jax.jit(value_function_iteration_jax)

In [13]:
# The first call compiles the function; block_until_ready waits for JAX's asynchronous result
jax.block_until_ready(jax_jit_value_func(grid_x, beta_factor))
%timeit jax.block_until_ready(jax_jit_value_func(grid_x, beta_factor))

1.09 ms ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
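
The JAX function above fixes the number of iterations at 100 inside a Python `for` loop, which `jax.jit` unrolls while tracing. As a sketch of an alternative, the loop can be expressed with `jax.lax.fori_loop` and the iteration count passed as a static argument. The names `bellman_max` and `value_function_iteration_fori` are hypothetical and not part of the original notebook; the helper masks infeasible choices (`x' > x`) in the same way as the NumPy version, and no timing is claimed.

```python
from functools import partial

import jax
import jax.numpy as jnp


def bellman_max(x, grid, beta, v_old):
    # Value of each candidate x'; clip at 0 so sqrt never sees a negative number
    candidates = jnp.sqrt(jnp.maximum(x - grid, 0)) + beta * v_old
    # Infeasible choices (x' > x) are excluded from the maximum
    return jnp.max(jnp.where(grid <= x, candidates, -jnp.inf))


@partial(jax.jit, static_argnames="num_iter")
def value_function_iteration_fori(grid, beta, num_iter):
    # Vectorize the scalar helper over the first argument (the current state x)
    step = jax.vmap(bellman_max, in_axes=(0, None, None, None))

    def body(_, value):
        return step(grid, grid, beta, value)

    # The loop is a single primitive in the compiled program instead of being unrolled
    return jax.lax.fori_loop(0, num_iter, body, jnp.sqrt(grid))


grid = jnp.linspace(0.1, 1.0, 100)
value = value_function_iteration_fori(grid, 0.95, num_iter=100)
```

Because `num_iter` is static, calling the function with a different iteration count triggers a new compilation.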


Not only is JAX faster than Numba in this case (1.09 ms versus 3.38 ms), it also has a number of other advantages:

- Code can run directly on GPUs and TPUs (the best implementation may differ there because of the trade-off between memory and computation time)
- Error messages are much easier to read
- JAX supports standard Python containers such as lists, tuples, and dicts (pytrees) as function inputs and outputs without a performance loss
- JAX supports automatic differentiation
- JAX traces your code and compiles it with XLA, which optimizes it for you
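
Two of the points above, container support and automatic differentiation, can be combined in a short sketch. It is an illustration under assumptions rather than part of the original notebook: the names `bellman_max`, `mean_value`, and the `params` dict are hypothetical. The parameters are passed as a plain Python dict, and `jax.grad` returns the derivative of the average value with respect to `beta` in a dict of the same structure.

```python
import jax
import jax.numpy as jnp


def bellman_max(x, grid, beta, v_old):
    # Same masked Bellman maximum as in the NumPy version of the problem
    candidates = jnp.sqrt(jnp.maximum(x - grid, 0)) + beta * v_old
    return jnp.max(jnp.where(grid <= x, candidates, -jnp.inf))


def mean_value(params, grid, num_iter=20):
    # `params` is an ordinary dict; JAX treats it as a pytree
    beta = params["beta"]
    step = jax.vmap(bellman_max, in_axes=(0, None, None, None))
    value = jnp.sqrt(grid)
    for _ in range(num_iter):
        value = step(grid, grid, beta, value)
    return jnp.mean(value)


grid = jnp.linspace(0.1, 1.0, 50)
params = {"beta": 0.95}

# jax.grad differentiates with respect to the first argument (the dict) by default
grad_fn = jax.jit(jax.grad(mean_value))
grads = grad_fn(params, grid)
print(grads["beta"])  # derivative of the mean value with respect to beta
```

The result has the same dict structure as `params`, so adding further parameters to the dict would not require changing the differentiation code.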