# Speed up with `jax`

In [13]:
import jax
import jax.numpy as np
import numpy as onp
import timeit

## Basic comparison

In [29]:
# Array length shared by the single-call and single-nesting benchmarks below.
N_ELEMENTS = 1000
# JAX array (`np` is jax.numpy in this notebook).
ar = np.arange(N_ELEMENTS)

In [44]:
@jax.jit
def fjax(ar):
    """Return the sum of all elements of ``ar``, compiled with ``jax.jit``.

    The first call for a given input shape/dtype traces and compiles the
    function; later calls with the same shape/dtype reuse the compiled code.
    """
    total = np.sum(ar)
    return total

%timeit fjax(ar)

140 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [45]:
def fnp(ar):
    """Return the sum of all elements of ``ar`` using plain NumPy (no JIT)."""
    total = onp.sum(ar)
    return total

%timeit fnp(ar)

162 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Benchmark nesting

### 1. single nesting (trivial operation)

In [57]:
@jax.jit
def gjax(ar, f=fjax):
    """Call ``f`` on ``ar`` from inside a jitted function (one level of nesting).

    NOTE(review): ``f`` is a plain Python callable, so in practice it can only
    come from the default. Passing a function explicitly to the jitted ``gjax``
    would require marking it static (e.g. via ``static_argnames``).
    """
    result = f(ar)
    return result

%timeit gjax(ar)

140 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [58]:
def gnp(ar, f=fnp):
    """Call ``f`` on ``ar``: the plain-NumPy single-nesting benchmark.

    Parameters
    ----------
    ar : array_like
        Input forwarded unchanged to ``f``.
    f : callable, default ``fnp``
        Function being nested; its result is returned as-is.
    """
    # Dispatch through `f` (not a hard-coded `fnp`) so a caller-supplied
    # function is honoured, matching `gjax`.
    return f(ar)

%timeit gnp(ar)

162 µs ± 640 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


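### 2. single nesting (non-trivial)

Caveat: under `jax.jit` the Python `for` loop in `g1jax` runs only once, while the function is being traced, and `np.size(...)` of a traced result is a static Python `int`. The compiled `g1jax` therefore returns a constant and never uses the values of the inner `fjax` calls. This is consistent with it timing faster (73 µs) than a single `fjax` call (140 µs). `g1np` really performs every inner sum, so the two timings below do not measure the same work.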

In [53]:
@jax.jit
def g1jax(ar, f=fjax):
    """Call ``f(ar)`` once per element of ``ar`` and add up the sizes of the results.

    Under ``jax.jit`` this Python loop runs only while tracing, and
    ``np.size`` of a traced value is a static Python ``int``. The compiled
    function therefore returns a constant; for the default scalar-valued
    ``fjax`` that constant is ``ar.size``.

    NOTE(review): ``f`` can only come from the default. Passing a Python
    function explicitly through ``jax.jit`` would require marking it static.
    """
    count = 0
    # Dedicated loop variable: reusing `i` as both loop index and accumulator
    # made every iteration overwrite the running total.
    for _ in range(ar.size):
        count += np.size(f(ar))
    return count

%timeit g1jax(ar)

73.1 µs ± 414 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [56]:
def g1np(ar, f=fnp):
    """Call ``f(ar)`` once per element of ``ar`` and add up the sizes of the results.

    Unlike ``g1jax``, every ``f(ar)`` call is actually executed at run time.
    With the default scalar-valued ``fnp`` the result equals ``ar.size``.
    """
    count = 0
    # Dedicated loop variable: reusing `i` as both loop index and accumulator
    # made every iteration overwrite the running total.
    for _ in range(ar.size):
        count += onp.size(f(ar))
    return count

%timeit g1np(ar)

162 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### 3. multiple nesting

In [180]:
@jax.jit
def njax(ar):
    """Fold ``ar`` into a chain of nested closures, then evaluate the chain.

    Each new link returns its argument plus the previous link applied to one
    element of ``ar``. Unrolled, the result is ``ar.size + sum(ar)``. Under
    ``jax.jit`` the loop over ``ar`` and the closure chain are unrolled while
    tracing.

    NOTE(review): evaluating the chain recurses one Python call per element,
    so very long inputs may approach Python's recursion limit.
    """
    def wrap(inner, elem):
        return lambda x, f=inner, i=elem: x + f(i)

    chain = lambda x: x
    for elem in ar:
        chain = wrap(chain, elem)

    return chain(ar.size)

%timeit njax(np.arange(1000))

259 µs ± 106 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [181]:
def nnp(ar):
    """Eager (non-jitted) closure-chain benchmark.

    Links one closure per element of ``ar``. Each link returns its argument
    plus the previous link evaluated at that element. Calling the finished
    chain on ``ar.size`` yields ``ar.size + sum(ar)``.
    """
    def make_link(prev, value):
        def link(x, f=prev, i=value):
            return x + f(i)
        return link

    def chain(x):
        return x

    for value in ar:
        chain = make_link(chain, value)
    return chain(ar.size)

%timeit nnp(np.arange(1000))

502 µs ± 4.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
