# Introduction to faster Python: Prime checking

#### Author: Andrius Burnelis

We are often tasked with handling many calculations in Python. The standard approach is to write a Jupyter notebook or a script that loops through the calculations you need to perform. Standard Python _can_ be fast, but compared to compiled languages such as C or Fortran it can feel a bit sluggish. The goal of this notebook is to equip you with a few tools that can help make your Python runtimes competitive with code written in some flavor of C or Fortran.

The example problem that we are tasked with is to use Python to define a function that will take an array of integers and perform a check to see which numbers are prime. The function will then return an array containing the indices of the prime numbers. This is a simple problem to solve, so we are using this as an illustrative example. In this notebook I will show you various approaches to speeding up your Python code.

In order to have a fair timing comparison between the different methods, it is important we use the same array throughout the notebook. I define the test array here as an array of 100 randomly generated integers between 0 and 5000. You can feel free to modify the test array to whatever size you may like. The only stipulation here is that you make sure that the array only contains integers as some functions may not be coded to handle different data types. Each function just assumes that the input array is formatted properly.

NOTE: The timing of these methods can be dependent on hardware. All timings mentioned in this notebook are from my personal laptop, so timings or improvements may be different on your computer.

In [None]:
# Useful imports
import numpy as np
from numba import njit


In [2]:
# Initialize the test array of potential primes
test_array = np.array([3907, 4809, 1118, 1391, 3941, 1246, 2827, 2035, 1000, 1804, 
                       4775, 3979, 4693, 1790, 4472, 3417, 4976, 386, 772, 4125, 
                       138, 4677, 4400, 1806, 398, 4263, 1124, 971, 1472, 2798, 
                       2176, 4819, 2965, 193, 3388, 2896, 907, 4140, 828, 1343, 
                       899, 1711, 640, 4228, 4364, 4829, 3609, 1361, 558, 3902, 
                       2329, 1178, 3414, 2300, 986, 3514, 2011, 467, 2254, 439, 
                       2624, 1329, 1951, 4188, 1505, 3756, 2928, 4084, 1780, 316, 
                       1746, 2933, 125, 1801, 4818, 654, 4458, 2274, 2121, 3859, 
                       882, 4700, 2041, 1676, 4865, 3371, 1525, 3716, 3584, 3427, 
                       4483, 3677, 62, 3449, 493, 1125, 3931, 3586, 3597, 3307])

# # If you would like to generate a larger array of random numbers, you can modify the following lines:
# np.random.seed(0)  # For reproducibility
# int_min = 1
# int_max = 5000
# size = 10000
# test_array = np.random.randint(int_min, int_max, size=size)

------------------
------------------
## Algorithmic speed-up

The first approach to speeding up your code is to make sure that you are using the right algorithm. Your code may work and give you the right answer, but the way it gets there may not be the smartest. The standard brute-force approach is to check the divisibility of $n$ by **every single number** from 2 up to $n - 1$. If any of those numbers divides $n$, then $n$ is not prime. This is incredibly inefficient, but it will be the starting point of this journey.

In [3]:
def brute_force(ns):
    """
    Brute force algorithm to check if a number is prime.
    This method is incredibly inefficient since it checks ALL numbers up to n.
    """
    # Initialize a list to store the indices of prime numbers
    prime_indices = []

    # Iterate through each number contained in array ns to check if it is prime
    for index in range(0, ns.shape[0]):
        # Check if n is less than 2
        if ns[index] < 2:
            continue

        # Check if n is prime
        for i in range(2, ns[index] - 1):
            if ns[index] % i == 0:
                break
        else:
            # If n is prime, append its index to the list
            prime_indices.append(index)

    return np.array(prime_indices)

Aside: Jupyter notebooks have some cool little commands you can run in each code cell. The one we will be using in this demo is "%%timeit". This command will loop over the segment of code in the cell a few times and then provide you with runtime details. In our case, the segment of code is a function call to check what numbers in our test array are prime.
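Outside of a notebook, the standard-library `timeit` module offers a similar measurement. The following is only a minimal sketch and not part of the original notebook: the helper `is_prime_brute` and the short `sample` list are hypothetical stand-ins for the functions and test array used here.

```python
import timeit

def is_prime_brute(n):
    """Hypothetical stand-in: trial division by every integer from 2 to n - 1."""
    if n < 2:
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True

sample = [3907, 4809, 1118, 1391, 3941]  # first few values of the test array

# 7 runs of 100 calls each, similar in spirit to what %%timeit reports
run_times = timeit.repeat(
    lambda: [is_prime_brute(n) for n in sample],
    repeat=7,
    number=100,
)

# Each entry is the total time for 100 calls, so divide to get a per-call time
best_per_call = min(run_times) / 100
print(f"best of 7 runs: {best_per_call * 1e6:.1f} µs per call")
```

Note that "%%timeit" reports a mean and standard deviation, while this sketch reports the fastest run, so the two numbers are not directly comparable.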

In [4]:
%%timeit
brute_force(test_array)

3.19 ms ± 104 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


My laptop ran this in a few milliseconds (~3 ms). This isn't bad, as it would take me much longer to check these by hand, but we can do better. This algorithm is not optimal for many reasons, but the worst of all is that we are checking every single candidate divisor. We can restrict the search space by only checking up to $\sqrt{n}$.
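To see why stopping at $\sqrt{n}$ is enough, note that divisors come in pairs $(a, b)$ with $a \cdot b = n$, and the smaller member of each pair can never exceed $\sqrt{n}$. A small sketch of this (the helper `divisor_pairs` is made up for this illustration and is not used elsewhere in the notebook):

```python
import math

def divisor_pairs(n):
    """Return all pairs (a, b) with a * b == n and a <= b."""
    pairs = []
    # math.isqrt(n) is the integer part of sqrt(n), computed exactly
    for a in range(1, math.isqrt(n) + 1):
        if n % a == 0:
            pairs.append((a, n // a))
    return pairs

print(divisor_pairs(36))  # [(1, 36), (2, 18), (3, 12), (4, 9), (6, 6)]
print(divisor_pairs(97))  # [(1, 97)]
```

If no $a$ between 2 and $\sqrt{n}$ divides $n$, then the only pair left is $(1, n)$, which means $n$ is prime.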

In [5]:
def smarter_search(ns):
    """
    Smarter search algorithm to check if a number is prime.
    This method checks only numbers up to the square root of n.
    """
    # Initialize a list to store the indices of prime numbers
    prime_indices = []

    # Iterate through each number contained in array ns to check if it is prime
    for index in range(0, ns.shape[0]):
        # Check if n is less than 2
        if ns[index] < 2:
            continue

        # Check if n is prime
        for i in range(2, int(np.sqrt(ns[index])) + 1):
            if ns[index] % i == 0:
                break
        else:
            # If n is prime, append its index to the list
            prime_indices.append(index)

    return np.array(prime_indices)

In [6]:
%%timeit
smarter_search(test_array)

137 μs ± 1.94 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


With just the simple change of reducing the search limit from $n - 1$ down to $\sqrt{n}$, my computer now takes about a hundred microseconds (~140 µs)! This is getting better, but if we needed to do this many, many times it would still be slow. We could reduce the search space further by skipping all the even numbers, but why stop there? We can skip over all multiples of prime numbers. This method is known as the sieve of Eratosthenes and can be very efficient.

In [7]:
def even_smarter_search(ns):
    """
    This uses the sieve of Eratosthenes to find all primes up to the maximum number in ns.
    """
    # Get the maximum number in ns
    max_n = np.max(ns)

    # Create a boolean array to mark non-prime numbers
    sieve = np.ones(max_n + 1, dtype = bool)
    sieve[0:2] = False  # 0 and 1 are not prime numbers
    for i in range(2, int(np.sqrt(max_n)) + 1):
        if sieve[i]:
            sieve[i*i:max_n + 1:i] = False  # Mark multiples of i as non-prime
    primes = np.nonzero(sieve)[0]
    primes_in_ns = np.intersect1d(primes, ns) # Get the prime numbers that are in ns
    return np.nonzero(np.isin(ns, primes_in_ns))[0]  # Get the indices of the prime numbers in ns

In [8]:
%%timeit
even_smarter_search(test_array)

51.5 μs ± 1.09 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Changing the algorithm in this way now gives me a runtime of tens of microseconds. Using a smarter algorithm has improved the runtime from ~3 ms down to ~50 µs. That's a factor of roughly 60, or nearly 2 orders of magnitude! This is good but not the end of the story. We now have a smarter algorithm that runs faster using only Python and NumPy, but what if this is still not fast enough?
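Since the sieve is already a lookup table indexed by value, going through `np.intersect1d` and `np.isin` is not the only way to recover the indices. As a hedged sketch of the idea in plain Python (the function names `sieve_lookup` and `trial_division` are hypothetical, the input is assumed to contain only non-negative integers, and no timing claim is made here):

```python
import math

def sieve_lookup(ns):
    """Return indices of primes in ns by indexing a sieve directly."""
    if len(ns) == 0:
        return []
    max_n = max(ns)
    # sieve[k] == 1 means "k is still considered prime"
    sieve = bytearray([1]) * (max_n + 1)
    sieve[0:2] = bytes(min(2, max_n + 1))  # 0 and 1 are not prime
    for i in range(2, math.isqrt(max_n) + 1):
        if sieve[i]:
            # Zero out i*i, i*i + i, ... in one slice assignment
            count = len(range(i * i, max_n + 1, i))
            sieve[i * i::i] = bytes(count)
    # Direct lookup: one sieve access per element of ns
    return [index for index, n in enumerate(ns) if sieve[n]]

def trial_division(ns):
    """Reference implementation: trial division up to sqrt(n)."""
    return [
        index for index, n in enumerate(ns)
        if n >= 2 and all(n % i for i in range(2, math.isqrt(n) + 1))
    ]

sample = [3907, 4809, 1118, 1391, 3941, 1246, 2, 1, 0, 97]
print(sieve_lookup(sample))
print(sieve_lookup(sample) == trial_division(sample))
```

With a NumPy boolean sieve like the one in `even_smarter_search`, the analogous lookup would be `np.nonzero(sieve[ns])[0]`.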

------------------
------------------
## Numba: Njit

[Numba](https://numba.pydata.org) is a Python library that uses just-in-time (JIT) compilation to dramatically boost the speed of some code. Importing this package lets you place a decorator such as "@njit" before a function you want to compile. JIT compilation is a clean way to speed up code because it exploits the way you will likely use it: as with many things, you will likely be calling the same function many times in loops. Compilation is triggered the first time the function is called. Numba infers the types of the arguments, compiles a machine-code version of the function specialized for those types, and then runs it. Because of this compilation step, the first call is often slower than a standard Python call, but every subsequent call with the same argument types reuses the compiled version and is much faster.

The way we implement this practically in the code is super easy. We simply need to add the decorator "@njit" right before our definition of the function. This tells Numba that we would like to JIT compile this function. While this decorator is versatile and simple to use, it also has its limits. Further down I show an example that is designed to fail, to illustrate how finicky this can get.
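The cost of the first call can be made visible by timing two consecutive calls by hand. This is a rough sketch rather than a benchmark: `count_primes_below` is a hypothetical helper that is not part of this notebook, and the `ImportError` fallback is an assumption added only so the snippet still runs (without any speed-up) where Numba is not installed.

```python
import time

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: fall back to plain Python without Numba
    def njit(func):
        return func

@njit
def count_primes_below(limit):
    """Count primes smaller than limit using trial division up to sqrt(n)."""
    count = 0
    for n in range(2, limit):
        is_prime = True
        i = 2
        while i * i <= n:
            if n % i == 0:
                is_prime = False
                break
            i += 1
        if is_prime:
            count += 1
    return count

# First call: with Numba installed, this includes the compilation time
start = time.perf_counter()
first_result = count_primes_below(20_000)
first_call = time.perf_counter() - start

# Second call: same argument type, so the compiled version is reused
start = time.perf_counter()
second_result = count_primes_below(20_000)
second_call = time.perf_counter() - start

print(f"first call : {first_call:.6f} s")
print(f"second call: {second_call:.6f} s")
```

Both calls return the same result; only the time differs.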

In [9]:
@njit
def jitted_brute_force(ns):
    """
    Brute force algorithm to check if a number is prime.
    This method is incredibly inefficient since it checks ALL numbers up to n.
    """
    # Initialize a list to store the indices of prime numbers
    prime_indices = []

    # Iterate through each number contained in array ns to check if it is prime
    for index in range(0, ns.shape[0]):
        # Check if n is less than 2
        if ns[index] < 2:
            continue

        # Check if n is prime
        for i in range(2, ns[index] - 1):
            if ns[index] % i == 0:
                break
        else:
            # If n is prime, append its index to the list
            prime_indices.append(index)

    return np.array(prime_indices)

In [10]:
%%timeit
jitted_brute_force(test_array)

20.5 μs ± 2.71 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


This is amazing! Simply adding the decorator to the worst method has made it more than twice as fast as our best algorithm so far (~20 µs compared to ~50 µs). We are finally starting to see some significant speed-up. Now what if we combine this with the first step of optimization, where we only consider divisibility checks up to $\sqrt{n}$? Let us find out!

In [11]:
@njit
def jitted_smarter_search(ns):
    """
    Smarter search algorithm to check if a number is prime.
    This method checks only numbers up to the square root of n.
    """
    # Initialize a list to store the indices of prime numbers
    prime_indices = []

    # Iterate through each number contained in array ns to check if it is prime
    for index in range(0, ns.shape[0]):
        # Check if n is less than 2
        if ns[index] < 2:
            continue

        # Check if n is prime
        for i in range(2, int(np.sqrt(ns[index])) + 1):
            if ns[index] % i == 0:
                break
        else:
            # If n is prime, append its index to the list
            prime_indices.append(index)

    return np.array(prime_indices)

In [12]:
%%timeit
jitted_smarter_search(test_array)

1.11 μs ± 15.7 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


This is even better!! My laptop went from ~3 ms for the worst method down to ~1 µs. That is a factor of nearly 3000 faster, just by being a little bit clever with the algorithm and then using the Numba decorator. Surely we can do better with the even smarter sieve algorithm, right?

In [13]:
#########
# THE FOLLOWING FUNCTION IS BROKEN INTENTIONALLY
# TO SHOW THE LIMITATIONS OF NUMBA AND HOW CAREFUL
# YOU HAVE TO BE WITH YOUR DATA TYPES
#########

@njit
def BROKEN_jitted_even_smarter_search(ns):
    """
    This function in principle uses the sieve of Eratosthenes to find all primes up to the maximum number in ns.
    However, it is broken because Numba cannot type np.ones(...) when the built-in bool is passed as dtype.
    We can run this cell because we are only defining the function, not executing it.
    When we try to run this, it will throw an error because Numba cannot compile the function.
    """
    # Get the maximum number in ns
    max_n = np.max(ns)

    # Create a boolean array to mark non-prime numbers
    sieve = np.ones(max_n + 1, dtype = bool) ### dtype = bool is what Numba rejects here; this line triggers the error
    sieve[0:2] = False  # 0 and 1 are not prime numbers
    for i in range(2, int(np.sqrt(max_n)) + 1):
        if sieve[i]:
            sieve[i*i:max_n + 1:i] = False  # Mark multiples of i as non-prime
    primes = np.nonzero(sieve)[0]
    primes_in_ns = np.intersect1d(primes, ns) # Get the prime numbers that are in ns
    return np.nonzero(np.isin(ns, primes_in_ns))[0]  # Get the indices of the prime numbers in ns

In [14]:
%%timeit
### THIS CELL WILL THROW AN ERROR
BROKEN_jitted_even_smarter_search(test_array)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function ones at 0x10a765580>) found for signature:
 
 >>> ones(int64, dtype=Function(<class 'bool'>))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'ol_np_ones': File: numba/np/arrayobj.py: Line 4586.
    With argument(s): '(int64, dtype=Function(<class 'bool'>))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<built-in function empty>) found for signature:
    
    >>> empty(int64, dtype=Function(<class 'bool'>))
    
   There are 2 candidate implementations:
         - Of which 2 did not match due to:
         Overload in function 'ol_np_empty': File: numba/np/arrayobj.py: Line 4440.
           With argument(s): '(int64, dtype=Function(<class 'bool'>))':
          Rejected as the implementation raised a specific error:
            TypingError: Cannot parse input types to function np.empty(int64, Function(<class 'bool'>))
     raised from /Users/andrius/code/gotta_go_fast/venv/lib/python3.12/site-packages/numba/np/arrayobj.py:4459
   
   During: resolving callee type: Function(<built-in function empty>)
   During: typing of call at /Users/andrius/code/gotta_go_fast/venv/lib/python3.12/site-packages/numba/np/arrayobj.py (4593)
   
   
   File "venv/lib/python3.12/site-packages/numba/np/arrayobj.py", line 4593:
       def impl(shape, dtype=None):
           arr = np.empty(shape, dtype=dtype)
           ^
   
   During: Pass nopython_type_inference
  raised from /Users/andrius/code/gotta_go_fast/venv/lib/python3.12/site-packages/numba/core/typeinfer.py:1074

During: resolving callee type: Function(<function ones at 0x10a765580>)
During: typing of call at /var/folders/mb/trbrgjlx0bd1gzxf6cd_k2h40000gn/T/ipykernel_88116/1990369097.py (19)


File "../../../../var/folders/mb/trbrgjlx0bd1gzxf6cd_k2h40000gn/T/ipykernel_88116/1990369097.py", line 19:
<source missing, REPL/exec in use?>

During: Pass nopython_type_inference

Sorry for those of you who just clicked "Run All" at the top... 

This code was not able to compile. Depending on your needs, Numba may not be able to help you. In this case, the issue comes from passing the built-in "bool" as the data type of the array of ones that makes up the sieve: the error message shows that Numba found no implementation of "ones" for "dtype=Function(<class 'bool'>)". Here we can get around it by switching to a NumPy data type such as "np.uint8", but there may not always be a workaround.

In [3]:
@njit
def jitted_even_smarter_search(ns):
    """
    This uses the sieve of Eratosthenes to find all primes up to the maximum number in ns.
    """
    # Get the maximum number in ns
    max_n = np.max(ns)

    # Create a boolean array to mark non-prime numbers
    sieve = np.ones(max_n + 1, dtype = np.uint8) ### Numba cannot compile dtype = bool here, so we use uint8
    sieve[0:2] = False  # 0 and 1 are not prime numbers
    for i in range(2, int(np.sqrt(max_n)) + 1):
        if sieve[i]:
            sieve[i*i:max_n + 1:i] = False  # Mark multiples of i as non-prime
    primes = np.nonzero(sieve)[0]
    primes_in_ns = np.intersect1d(primes, ns) # Get the prime numbers that are in ns
    return np.nonzero(np.isin(ns, primes_in_ns))[0]  # Get the indices of the prime numbers in ns

In [4]:
%%timeit
jitted_even_smarter_search(test_array)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Use of unsupported NumPy function 'numpy.isin' or unsupported use of the function.

File "../../../../var/folders/mb/trbrgjlx0bd1gzxf6cd_k2h40000gn/T/ipykernel_5671/2931390825.py", line 17:
<source missing, REPL/exec in use?>

During: typing of get attribute at /var/folders/mb/trbrgjlx0bd1gzxf6cd_k2h40000gn/T/ipykernel_5671/2931390825.py (17)

File "../../../../var/folders/mb/trbrgjlx0bd1gzxf6cd_k2h40000gn/T/ipykernel_5671/2931390825.py", line 17:
<source missing, REPL/exec in use?>


Wait, what? In the output recorded above, fixing the data type only gets us to a second "TypingError": Numba reports "numpy.isin" as unsupported (or used in an unsupported way) inside the jitted function, so whether this cell runs at all may depend on your Numba version. In a run where this function did compile, slapping the Numba decorator on our fastest algorithm gave worse results than the previous method.

This is most likely because there is a lot of array manipulation happening here. Operations like this carry a lot of memory overhead; the computer needs to read and write a lot of data multiple times. We need to build the sieve up to the largest value in the array, extract the primes from the sieve, compare them to the original test array, and then return the indices that match the primes in the test array. In that run it was still a bit faster than its non-JIT counterpart from above, but significantly slower than "jitted_smarter_search". This is an example where the optimal solution is some combination of multiple methods.

------------------
------------------
## Cython: If Python and C had a child

Now we will switch gears and use a tool called [Cython](https://cython.readthedocs.io/en/latest/index.html). Cython allows you to write Python-like code which is then translated to C and compiled into a module that you can import and call from Python. I have personally used this method in my research and took code that ran for about a week down to only a few hours. This method can speed up your code by quite a lot, but there are some downsides such as compatibility. It stays readable like Python, but there are a few minor changes to the syntax that I will show you here.

We need to tell the Jupyter notebook to load Cython, and we need to import things using Cython. To load Cython in a Jupyter notebook, we use the following: "%load_ext Cython". To declare that a cell should be interpreted as Cython code, we begin the cell with "%%cython". 

Another thing to note is that Cython treats each "%%cython" cell in a Jupyter notebook as an isolated script. Any dependencies of the cell need to be re-defined in the cell in order for it to compile properly. This is why each cell has the repeated imports. If you were to write a standalone Cython script, you can assume variables and dependencies follow the standard scope of their definitions.

In [None]:
%load_ext Cython

In [None]:
import warnings
warnings.filterwarnings("ignore", category = DeprecationWarning)

In [None]:
%%cython
cimport cython
cimport numpy as cnp


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cython_brute_force(long[:] ns):
    """
    Brute force algorithm to check if a number is prime.
    This method is incredibly inefficient since it checks ALL numbers up to n.
    """
    # Initialize a list to store the indices of prime numbers
    cdef list prime_indices = []

    # Iterate through each number contained in array ns to check if it is prime
    for index in range(0, ns.shape[0]):
        # Check if n is less than 2
        if ns[index] < 2:
            continue

        # Check if n is prime
        for i in range(2, ns[index] - 1):
            if ns[index] % i == 0:
                break
        else:
            # If n is prime, append its index to the list
            prime_indices.append(index)

    return prime_indices

In [None]:
%%timeit
cython_brute_force(test_array)

852 μs ± 39.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Using the worst method but implementing it in Cython, we see that it now takes many hundreds of microseconds (~850 µs). This is still faster than the pure Python implementation (~3 ms), but it isn't great either. We can do better again by making the same minor optimization as before.

In [None]:
%%cython
from libc.math cimport sqrt
cimport cython
cimport numpy as cnp

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef cython_smarter_search(long[:] ns):
    """
    Smarter search algorithm to check if a number is prime.
    This method checks only numbers up to the square root of n.
    """
    # Initialize a list to store the indices of prime numbers
    cdef list prime_indices = []

    # Iterate through each number contained in array ns to check if it is prime
    for index in range(0, ns.shape[0]):
        # Check if n is less than 2
        if ns[index] < 2:
            continue

        # Check if n is prime
        for i in range(2, int(sqrt(ns[index])) + 1):
            if ns[index] % i == 0:
                break
        else:
            # If n is prime, append its index to the list
            prime_indices.append(index)

    return prime_indices

In [None]:
%%timeit
cython_smarter_search(test_array)

19.8 μs ± 181 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Now we have gotten the timing down to tens of microseconds (~20 µs). This is much better than the few milliseconds that we started with, but still not as good as the corresponding Numba method (~1 µs). We can see what happens if we try Cythonizing the sieve method.

In [None]:
%%cython

from libc.math cimport sqrt
cimport cython
import numpy as np
cimport numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list cython_even_smarter_search(np.ndarray[long, ndim = 1] arr):
    cdef int i, j, val, max_val = 0
    cdef list result = []

    # Find max value in array
    for i in range(arr.shape[0]):
        if arr[i] > max_val:
            max_val = arr[i]

    # Build sieve
    cdef np.ndarray[np.uint8_t, ndim=1] sieve = np.ones(max_val + 1, dtype = np.uint8)
    sieve[0] = 0
    if max_val >= 1:
        sieve[1] = 0
    for i in range(2, int(sqrt(max_val)) + 1):
        if sieve[i]:
            for j in range(i * i, max_val + 1, i):
                sieve[j] = 0

    # Check array elements against sieve
    for i in range(arr.shape[0]):
        val = arr[i]
        if sieve[val]:
            result.append(i)

    return result


In [None]:
%%timeit
cython_even_smarter_search(test_array)
# This cell might not look like it is running for some reason

77.9 μs ± 4.25 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Again, we see that building and indexing the sieve costs additional overhead: at ~78 µs this is slower than the ~20 µs of the previous method.

------------
------------
## Outlook:

Every code you are tasked to write will be different. Some code may take almost no time and need no additional speed-up, while other code may take forever no matter what methods you try. The intention of this notebook is to show you a few simple ways to modify your approach or code to achieve a faster runtime. You may have to experiment a bit to uncover the method that runs the fastest, as we did in the cells above.

A useful tool to assist you in identifying bottlenecks in your code is the Python [profiler](https://docs.python.org/3/library/profile.html#). This tool times your code (like our %%timeit cells above) and provides you with a ton of extra info. It keeps track of every function call, including how many times each function is called and how long each evaluation takes, which helps you see where the computer is spending its time. Knowing which part of the code is the most expensive tells you where optimization will pay off the most!
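As a minimal sketch of what profiling can look like with the standard-library `cProfile` and `pstats` modules (the functions `is_prime` and `count_primes` are hypothetical examples written for this illustration, not the functions defined above):

```python
import cProfile
import io
import pstats

def is_prime(n):
    """Trial division up to sqrt(n); deliberately simple."""
    if n < 2:
        return False
    i = 2
    while i * i <= n:
        if n % i == 0:
            return False
        i += 1
    return True

def count_primes(limit):
    """Count primes below limit by calling is_prime on each candidate."""
    return sum(1 for n in range(limit) if is_prime(n))

# Record every function call made while count_primes runs
profiler = cProfile.Profile()
profiler.enable()
total = count_primes(5000)
profiler.disable()

# Collect the report in a string, sorted by cumulative time per function
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats("cumulative").print_stats(10)
report = stream.getvalue()
print(report)
```

The report lists, for each function, the number of calls and the time spent in it, which is the information needed to decide where a faster algorithm, Numba, or Cython is worth the effort.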

------------
------------
## Choose your own adventure!

Now that you have working examples of the basic function, you should be ready to start implementing these methods on your own! At this point you can either move on to the next notebook, or you can stay here and take on the challenge at the end of this notebook. The choice is yours!

------------
------------

## Challenge: 
See if you can write a function that executes faster than any of the above methods! 

Feel free to copy and paste any of the above code that may be relevant for you.