In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


# Numba Example

Related: JIT-compiled, thread-parallel Python code: https://gist.github.com/safijari/fa4eba922cea19b3bc6a693fe2a97af7

We want to solve a deliberately silly version of the underdamped spring-mass problem, with a nonlinear ("funky") damper:

![spring-mass](https://upload.wikimedia.org/wikipedia/commons/f/fa/Spring-mass_under-damped.gif)
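
Reading the code in the next cell, the model is a unit-mass spring (stiffness 100) released at rest from $x_0$, with a nonlinear friction force controlled by the threshold velocity $v_t$ (`vt`):

$$
\ddot{x} = f(\dot{x}) - 100\,x, \qquad x(0) = x_0, \quad \dot{x}(0) = 0,
$$

$$
f(v) = \begin{cases} -3v, & v > v_t \\ -3\,v_t\,\operatorname{sign}(v), & \text{otherwise} \end{cases}
$$

The case condition uses the signed velocity, exactly as `friction_fn` does, so large negative velocities fall into the constant-magnitude branch.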

In [2]:
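def friction_fn(v, vt):
    # Nonlinear ("funky") damper: linear drag above the threshold velocity vt,
    # constant-magnitude friction otherwise. The comparison uses the signed
    # velocity, so the linear regime only applies for v > vt.
    if v > vt:
        return - v * 3
    else:
        return - vt * 3 * np.sign(v)


def simulate_spring_mass_funky_damper(x0, T=10, dt=0.0001, vt=1.0):
    """Integrate a unit-mass spring (k = 100) released at rest from x0.

    Returns the time grid and the position normalized by x0.
    """
    times = np.arange(0, T, dt)
    positions = np.zeros_like(times)

    v = 0.0
    x = x0
    positions[0] = 1.0  # x(0) / x0

    for ii in range(1, len(times)):
        a = friction_fn(v, vt) - 100*x  # friction + spring force (unit mass)
        v = v + a*dt                    # update velocity first...
        x = x + v*dt                    # ...then position with the new velocity
        positions[ii] = x/x0

    return times, positions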
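
The loop updates `v` before `x`, so each position step uses the freshly updated velocity; this ordering is the semi-implicit (symplectic) Euler method. The minimal sketch below drops the friction term (an assumption made only to isolate the integrator) and compares it with explicit Euler on an undamped spring; `energy` and `integrate` are hypothetical helper names, not part of the notebook.

```python
import numpy as np


def energy(x, v, k):
    """Mechanical energy of a unit-mass spring: kinetic + potential."""
    return 0.5 * v**2 + 0.5 * k * x**2


def integrate(x0, k=100.0, dt=1e-3, steps=10_000, semi_implicit=True):
    """Integrate an undamped unit-mass spring; return E_n / E_0 after each step."""
    x, v = x0, 0.0
    e0 = energy(x, v, k)
    ratios = np.empty(steps)
    for n in range(steps):
        if semi_implicit:
            v = v - k * x * dt  # update velocity first...
            x = x + v * dt      # ...then position with the NEW velocity
        else:
            # explicit Euler: both updates use the OLD state
            x, v = x + v * dt, v - k * x * dt
        ratios[n] = energy(x, v, k) / e0
    return ratios


explicit = integrate(1.0, semi_implicit=False)
semi = integrate(1.0, semi_implicit=True)
print(f"explicit Euler:      E_final/E0 = {explicit[-1]:.3f}")
print(f"semi-implicit Euler: max |E/E0 - 1| = {np.abs(semi - 1).max():.4f}")
```

For explicit Euler the energy is multiplied by exactly $1 + k\,\Delta t^2$ per step, so with $k = 100$ and $\Delta t = 10^{-3}$ it grows by about $1.0001^{10000} \approx 2.72$ over 10,000 steps, while the semi-implicit variant keeps it within a fraction of a percent. That is one reason the update order in the loop matters.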

In [3]:
plot(*simulate_spring_mass_funky_damper(0.1))
plot(*simulate_spring_mass_funky_damper(1))
plot(*simulate_spring_mass_funky_damper(10))
legend(['0.1', '1', '10'])

savefig("ts_python.png")
close()

This produces time series of the normalized position $x(t)/x_0$ for three different initial positions $x_0$:

![ts](ts_python.png)
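
The normalized curves differ only because the damper is nonlinear. If the friction were purely linear in $v$ (a hypothetical $f(v) = -3v$, used here only for comparison), the whole system would be linear, $x(t)$ would scale with $x_0$, and $x(t)/x_0$ would be the same curve for every $x_0$. A minimal sketch checking this with a hypothetical `simulate` helper that mirrors the notebook's loop (a shorter run and a coarser `dt` keep it fast):

```python
import numpy as np


def simulate(x0, friction, T=2.0, dt=1e-3):
    """Semi-implicit Euler for a unit-mass spring (k = 100) released at rest.

    Returns the position normalized by x0, as in the notebook's plots.
    """
    n = int(round(T / dt))
    out = np.empty(n)
    x, v = float(x0), 0.0
    out[0] = 1.0
    for i in range(1, n):
        a = friction(v) - 100.0 * x
        v = v + a * dt
        x = x + v * dt
        out[i] = x / x0
    return out


def linear(v):
    # hypothetical purely linear damper, for comparison only
    return -3.0 * v


def funky(v, vt=1.0):
    # same rule as friction_fn in the notebook
    return -3.0 * v if v > vt else -3.0 * vt * np.sign(v)


x0s = [0.1, 1.0, 10.0]
lin = [simulate(x0, linear) for x0 in x0s]
nonlin = [simulate(x0, funky) for x0 in x0s]

# Linear damper: the normalized curves coincide (up to rounding).
print(all(np.allclose(lin[0], c) for c in lin[1:]))     # True
# Funky damper: they do not.
print(all(np.allclose(nonlin[0], c) for c in nonlin[1:]))  # False
```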

In [4]:
%time _ = simulate_spring_mass_funky_damper(1)

CPU times: user 280 ms, sys: 4.82 ms, total: 285 ms
Wall time: 289 ms


## Compile with Numba

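Use the `njit` decorator (shorthand for `@jit(nopython=True)`) so that Numba compiles the whole function to machine code and raises an error instead of silently falling back to slow Python object mode.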

In [5]:
from numba import njit

In [6]:
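@njit
def numba_friction_fn(v, vt):
    # Same damper rule as friction_fn, compiled in nopython mode
    if v > vt:
        return - v * 3
    else:
        return - vt * 3 * np.sign(v)

@njit
def numba_simulate_spring_mass_funky_damper(x0, T=10, dt=0.0001, vt=1.0):
    # Same loop as the pure-Python version, calling the jitted friction function
    times = np.arange(0, T, dt)
    positions = np.zeros_like(times)

    v = 0.0  # start as a float so Numba infers a single float64 type for v
    x = x0
    positions[0] = 1.0  # x(0) / x0

    for ii in range(1, len(times)):
        a = numba_friction_fn(v, vt) - 100*x  # friction + spring force (unit mass)
        v = v + a*dt                          # velocity first...
        x = x + v*dt                          # ...then position with the new velocity
        positions[ii] = x/x0
    return times, positions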

In [7]:
# First call triggers JIT compilation; later calls (and the timings below) reuse the compiled code
_ = numba_simulate_spring_mass_funky_damper(0.1)

In [8]:
_, ax = subplots(nrows=1, ncols=2, sharey=True, figsize=(12,5))

ax[0].plot(*numba_simulate_spring_mass_funky_damper(0.1))
ax[0].plot(*numba_simulate_spring_mass_funky_damper(1))
ax[0].plot(*numba_simulate_spring_mass_funky_damper(10))
ax[0].legend(['0.1', '1', '10'])

ax[1].plot(*simulate_spring_mass_funky_damper(0.1))
ax[1].plot(*simulate_spring_mass_funky_damper(1))
ax[1].plot(*simulate_spring_mass_funky_damper(10))
ax[1].legend(['0.1', '1', '10'])

savefig("ts_numba.png")
close()

Time series generated by the Numba version (left) vs. the original Python version (right):

![ts](ts_numba.png)

In [9]:
%time _ = simulate_spring_mass_funky_damper(0.1)

CPU times: user 308 ms, sys: 5.5 ms, total: 313 ms
Wall time: 319 ms


In [10]:
%time _ = numba_simulate_spring_mass_funky_damper(1)

CPU times: user 1.49 ms, sys: 44 µs, total: 1.53 ms
Wall time: 1.66 ms
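
A single `%time` run is noisy, and the first call of a jitted function also pays for compilation. IPython's `%timeit` magic repeats the measurement for you; outside IPython, a minimal best-of-N timer can be built on the standard library's `timeit.repeat`. `best_of` and `slow_sum` are hypothetical names:

```python
import timeit


def best_of(fn, *args, repeat=5, number=1):
    """Fastest time (seconds) per call of fn(*args), over `repeat` trials of `number` calls."""
    times = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    return min(times) / number


def slow_sum(n):
    # plain-Python loop, standing in for an un-jitted simulation
    total = 0.0
    for i in range(n):
        total += i * 1e-4
    return total


t = best_of(slow_sum, 100_000)
print(f"best of 5: {t * 1e3:.2f} ms")
```

In the notebook, after the warm-up call, `best_of(simulate_spring_mass_funky_damper, 1)` and `best_of(numba_simulate_spring_mass_funky_damper, 1)` would compare both versions on the same input. Taking the minimum rather than the mean reduces the influence of background load on the measurement.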
