This notebook assumes that you are running in an environment where Numba is available, for example a Docker container running the jupyter/scipy-notebook image (https://github.com/jupyter/docker-stacks/tree/master/scipy-notebook).

We use a simple example, a trial-division primality test, to show the benefit of using Numba to optimize code.


In [1]:
def is_prime(n):
    if n < 2:
        return False
    # trial division: try every candidate divisor from 2 up to sqrt(n)
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True


In [2]:
%time is_prime(2147483647)
%time is_prime(2147483647)
%time is_prime(2147483647)
%timeit is_prime(2147483647)

Wall time: 9.65 ms
Wall time: 11.1 ms
Wall time: 6.92 ms
8.49 ms ± 474 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
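The test value is 2147483647 = 2**31 - 1, a Mersenne prime and the largest signed 32-bit integer. Because it is prime, `is_prime` never returns early and has to try every candidate divisor from 2 up to its square root. The sketch below is a hypothetical instrumented copy of the same loop (`trial_division_steps` is not part of the notebook) that also counts the candidates it tries. For this input it reports 46,339 iterations, so the `%timeit` figure above works out to roughly 8.49 ms / 46,339 ≈ 183 ns per loop iteration in pure Python.

```python
def trial_division_steps(n):
    """Same loop as is_prime, but also return how many divisors were tried."""
    if n < 2:
        return False, 0
    steps = 0
    for div in range(2, int(n**0.5) + 1):
        steps += 1
        if n % div == 0:
            return False, steps  # found a divisor: n is composite
    return True, steps  # no divisor up to sqrt(n): n is prime


n = 2**31 - 1
print(n)                            # 2147483647
print(trial_division_steps(n))      # (True, 46339)
print(trial_division_steps(n - 1))  # (False, 1): even numbers exit immediately
```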


In [3]:
import numba

@numba.jit
def is_prime(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True

In [4]:
%time is_prime(2147483647)
%time is_prime(2147483647)
%time is_prime(2147483647)
%timeit is_prime(2147483647)

Wall time: 485 ms
Wall time: 0 ns
Wall time: 0 ns
193 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
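The first `%time` after decoration (485 ms) includes Numba compiling `is_prime` for the argument type it sees. The following calls reuse the compiled code, and the `0 ns` readings most likely mean the call finished below the resolution of the clock `%time` reads. `%timeit` (193 µs, roughly 44x faster than the 8.49 ms pure-Python figure) is the steadier number. Outside IPython, a minimal way to see the same split is to time each call separately. `time_calls` below is a hypothetical helper, not part of Numba.

```python
import time


def time_calls(func, *args, repeats=3):
    """Call func(*args) `repeats` times and return each call's wall time in seconds.

    For a freshly decorated @numba.jit function, the first entry also
    includes JIT compilation; later entries measure the compiled code only.
    """
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        timings.append(time.perf_counter() - start)
    return timings


# Usage with any callable; in the notebook, pass the jitted is_prime instead.
first, *rest = time_calls(sum, range(1000), repeats=4)
print(f"first call: {first:.2e} s, later calls: {[f'{t:.2e}' for t in rest]}")
```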


In [5]:
def is_prime(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True

def count_primes(N):
    count = 0
    for i in range(1, N):
        if is_prime(i):
            count += 1
    print(count)

In [6]:
%time count_primes(1_000_000)

78498
Wall time: 5.58 s


In [7]:
import numba
@numba.jit
def is_prime(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True

def count_primes(N):
    count = 0
    for i in range(1, N):
        if is_prime(i):
            count += 1
    print(count)

In [8]:
%time count_primes(1_000_000)

78498
Wall time: 658 ms


In [9]:
@numba.njit(numba.int32(numba.int32))
def is_prime(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True

@numba.njit(numba.void(numba.int32))
def count_primes(N):
    count = 0
    for i in range(1, N):
        if is_prime(i):
            count += 1
    print(count)
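Passing an explicit signature such as `numba.int32(numba.int32)` makes `numba.njit` compile eagerly, when the decorator runs, for exactly that signature; `numba.void(numba.int32)` declares that `count_primes` takes an int32 and returns nothing. The largest value an `int32` can hold is 2**31 - 1 = 2147483647, the same number tested earlier, so these versions cover every input used in this notebook but nothing larger. The hypothetical `fits_int32` and `checked` helpers below (not part of Numba) are one cheap way to keep arguments inside the declared range rather than relying on how the dispatcher converts an out-of-range Python integer.

```python
INT32_MIN = -(2**31)   # -2147483648
INT32_MAX = 2**31 - 1  #  2147483647


def fits_int32(n: int) -> bool:
    """Return True if n can be represented as a signed 32-bit integer."""
    return INT32_MIN <= n <= INT32_MAX


def checked(func):
    """Hypothetical wrapper: reject arguments outside the int32 range before calling func."""
    def wrapper(n: int):
        if not fits_int32(n):
            raise OverflowError(f"{n} does not fit in int32")
        return func(n)
    return wrapper
```

In the notebook this could be used as `safe_is_prime = checked(is_prime)`.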

In [10]:
%time count_primes(1_000_000)

78498
Wall time: 465 ms
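Putting the three `count_primes(1_000_000)` wall times side by side: compiling only `is_prime` still leaves the counting loop in the interpreter, and its 658 ms also includes compiling `is_prime` on the first call. The `njit` version with explicit signatures runs the whole loop as compiled code, and because those signatures were compiled when the cell was defined, its 465 ms should not include compilation. These are single `%time` runs, so the ratios below are indicative only.

```python
# Wall times printed above for count_primes(1_000_000), in seconds
wall_times = {
    "pure Python (In [6])": 5.58,
    "@numba.jit on is_prime only (In [8])": 0.658,
    "@numba.njit with signatures on both (In [10])": 0.465,
}

baseline = wall_times["pure Python (In [6])"]
for label, seconds in wall_times.items():
    # speedup relative to the pure-Python run
    print(f"{label:48s} {seconds:6.3f} s  {baseline / seconds:5.1f}x")
```

This prints speedups of 1.0x, 8.5x and 12.0x.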


Hint (*Thanks, Gerry!*): restart the kernel before you run the next steps. 


In [1]:
%%file numba_helper.py
import numba

@numba.njit(numba.int32(numba.int32))
def is_prime_jit(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True


def is_prime_py(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True

Overwriting numba_helper.py
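Worker processes in a `multiprocessing.Pool` receive their task function by pickling, and functions are pickled by reference (module plus name), so the workers have to be able to import them. With the spawn start method, the default on Windows, functions defined interactively in the notebook cannot be imported that way, which is presumably why the functions are written to numba_helper.py and imported. The Pool cell that follows sends individual numbers to the workers in chunks and gets one result back per number. The sketch below shows one alternative granularity, hypothetical and not from the notebook: each task counts the primes in a whole sub-range and returns a single integer, so far less data is pickled between processes.

```python
import math
from multiprocessing import Pool


def is_prime(n):
    """Plain trial division, the same logic as is_prime_py in numba_helper.py."""
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True


def split_range(start, stop, parts):
    """Split [start, stop) into at most `parts` contiguous (lo, hi) pairs."""
    step = max(1, math.ceil((stop - start) / parts))
    return [(lo, min(lo + step, stop)) for lo in range(start, stop, step)]


def count_in_range(bounds):
    """Worker task: count primes in [lo, hi) and return a single integer."""
    lo, hi = bounds
    return sum(1 for i in range(lo, hi) if is_prime(i))


def parallel_count(stop, processes=8, parts=64):
    """Count primes below `stop`; each task ships two ints in and one int back."""
    with Pool(processes) as pool:
        return sum(pool.imap_unordered(count_in_range, split_range(1, stop, parts)))
```

`split_range`, `count_in_range` and `parallel_count` are illustrative names. In the notebook they would have to live in an importable module such as numba_helper.py, where `count_in_range` could call `is_prime_jit` instead, and then `parallel_count(10_000_000)` could be called from a cell; it has not been timed against the cells here. Using more sub-ranges (`parts=64`) than workers lets `imap_unordered` hand the remaining pieces to whichever worker finishes first, which helps because large numbers take longer to test than small ones.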


In [2]:
from multiprocessing import Pool
from numba_helper import *  # functions in a module file can be pickled by name and imported by the workers

def run(func):
    # 8 worker processes; the numbers are handed out in chunks of 400,000
    with Pool(8) as p:
        s = sum(p.imap_unordered(func, range(1, 10_000_000), chunksize=400_000))
    print(s)

%time sum(map(is_prime_py, range(10_000_000)));
%time sum(map(is_prime_jit, range(10_000_000)));
%time run(is_prime_py)
%time run(is_prime_jit)


Wall time: 3min 20s
Wall time: 12.1 s


KeyboardInterrupt: 

664579
Wall time: 20.6 s
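The last reported run printed 664579, the number of primes below 10,000,000. Another way to use several cores, not used in this notebook, is to keep the whole counting loop inside Numba and let `numba.prange` with `parallel=True` split its iterations across threads, so no per-number results cross between processes. The sketch below assumes Numba is installed; the `try/except` fallback (a no-op `njit` and `prange = range`) exists only so the code also runs, serially and uncompiled, without it. Its speed has not been measured against the timings above.

```python
try:
    from numba import njit, prange
except ImportError:  # assumption: fallback so the sketch still runs without Numba
    prange = range

    def njit(*args, **kwargs):
        """No-op stand-in for numba.njit, usable as @njit or @njit(...)."""
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit
def is_prime(n):
    if n < 2:
        return False
    for div in range(2, int(n**0.5) + 1):
        if n % div == 0:
            return False
    return True


@njit(parallel=True)
def count_primes_parallel(N):
    """Count primes in range(1, N); prange lets Numba split the loop across threads."""
    count = 0
    for i in prange(1, N):
        if is_prime(i):
            count += 1  # Numba treats this as a parallel sum reduction
    return count


print(count_primes_parallel(100_000))  # 9592 primes below 100,000
```

Unlike the `count_primes` cells above, this version returns the count instead of printing it, so the result can be checked or reused.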
