# Calculating pi by counting collisions with numba
Based on 3Blue1Brown's videos and G. Galperin's work. Numba speeds the simulation up about 150x (922 ms in pure Python vs. 6.17 ms for the first compiled call).<br>
Original paper (G. Galperin, "Playing pool with pi"): https://www.maths.tcd.ie/~lebed/Galperin.%20Playing%20pool%20with%20pi.pdf<br>
Elastic-collision formula from https://en.wikipedia.org/wiki/Elastic_collision#One-dimensional_Newtonian
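For reference, the one-dimensional elastic-collision formulas give the new velocities as `((m1 - m2)*v1 + 2*m2*v2)/(m1 + m2)` and `(2*m1*v1 + (m2 - m1)*v2)/(m1 + m2)`; the notebook code writes `m1 - m2` as `mdiff` and `m1 + m2` as `mtotal`. A minimal stdlib sketch (the helper name `elastic` is only for illustration) that applies them once, for a mass ratio of 100, and checks that momentum and kinetic energy are conserved and that the relative velocity just flips sign:

```python
import math

def elastic(m1, m2, v1, v2):
    """Velocities after a 1D elastic collision (illustrative helper)."""
    mtotal, mdiff = m1 + m2, m1 - m2
    u1 = (v1*mdiff + 2*m2*v2) / mtotal
    u2 = (2*m1*v1 - mdiff*v2) / mtotal
    return u1, u2

# Same starting state as the notebook, but with a mass ratio of 100.
m1, m2, v1, v2 = 1.0, 100.0, 0.0, -1.0
u1, u2 = elastic(m1, m2, v1, v2)

# Momentum and kinetic energy are conserved (up to rounding) ...
assert math.isclose(m1*v1 + m2*v2, m1*u1 + m2*u2)
assert math.isclose(m1*v1**2 + m2*v2**2, m1*u1**2 + m2*u2**2)
# ... and the relative velocity only changes sign.
assert math.isclose(u1 - u2, -(v1 - v2))
print(u1, u2)  # both blocks now move towards the wall; the small one is faster
```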
## Original videos
https://www.youtube.com/watch?v=HEfHFsfGXjs <br>
https://www.youtube.com/watch?v=jsYwFizhncE
## Awesome demo

https://prajwalsouza.github.io/Experiments/Colliding-Blocks.html

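```python
def collisions(m1, m2):
    """Count all collisions: block 1 (mass m1) starts at rest between a wall
    on the left and block 2 (mass m2), which moves left at speed 1."""
    v1, v2 = 0, -1                   # negative velocity = moving towards the wall
    mtotal, mdiff = m1 + m2, m1 - m2
    count = 0
    while v1 > v2:                   # gap is closing: another block-block collision
        count += 1
        # 1D elastic collision (Wikipedia formula linked above)
        v1, v2 = (v1*mdiff + 2*m2*v2)/mtotal, (2*m1*v1 - mdiff*v2)/mtotal
        if v1 < 0:                   # block 1 hits the wall and bounces back
            count += 1
            v1 = -v1
    return count
```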
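Before timing the big case, a quick sanity check: mass ratios of 1, 100, 10 000 and 1 000 000 should give 3, 31, 314 and 3141 collisions, the leading digits of pi. The function is repeated in this sketch so it runs on its own:

```python
def collisions(m1, m2):
    # Same simulation as above, repeated so this snippet is self-contained.
    v1, v2 = 0, -1
    mtotal, mdiff = m1 + m2, m1 - m2
    count = 0
    while v1 > v2:
        count += 1
        v1, v2 = (v1*mdiff + 2*m2*v2)/mtotal, (2*m1*v1 - mdiff*v2)/mtotal
        if v1 < 0:
            count += 1
            v1 = -v1
    return count

# For these ratios, a mass ratio of 100**k yields the first k + 1 digits of pi.
for k in range(4):
    print(100**k, collisions(1, 100**k))
# 1 3
# 100 31
# 10000 314
# 1000000 3141
```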

```python
%%timeit -n 1 -r 1
print(collisions(1, int(1e12)))
```

```
3141592
922 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```
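The simulation is not the only way to get this number. As the 3Blue1Brown video and Galperin's paper explain, with `theta = atan(sqrt(m1/m2))` the count works out to the largest integer `N` with `N*theta < pi`, i.e. `ceil(pi/theta) - 1`. A sketch of that formula as a cross-check (`collisions_closed_form` is a hypothetical name, not from the original notebook):

```python
import math

def collisions_closed_form(m1, m2):
    """Largest N with N * theta < pi, where theta = atan(sqrt(m1 / m2)).

    Hypothetical helper; reliable only when pi / theta is not within
    floating-point rounding error of an integer.
    """
    theta = math.atan(math.sqrt(m1 / m2))
    return math.ceil(math.pi / theta) - 1

print(collisions_closed_form(1, 10**12))  # 3141592, same as the simulation
```

This needs one `atan` call instead of millions of loop iterations, but it inherits floating-point limits: if `pi/theta` lands within rounding error of an integer (equal masses give exactly 4), the `ceil` can be off by one.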


```python
from numba import jit

# Eager compilation: the explicit signature makes Numba compile at decoration time.
# fastmath=True makes it about 2x faster; declaring the counter as int32
# (locals={'count': numba.int32}) did not help.
@jit(["int32(float32,float32)"], fastmath=True)
def collisions_jit(m1, m2):
    v1, v2 = 0, -1
    mtotal, mdiff = m1 + m2, m1 - m2
    count = 0
    while v1 > v2:
        count += 1
        v1, v2 = (v1*mdiff + 2*m2*v2)/mtotal, (2*m1*v1 - mdiff*v2)/mtotal
        if v1 < 0:
            count += 1
            v1 = -v1
    return count
```
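The `float32` arguments in the signature deserve a second look. float32 has a 24-bit significand, so `1e12` is stored as 999999995904.0, and adding `m1 = 1` to it changes nothing: inside the compiled function `mtotal` equals `m2` and `mdiff` equals `-mtotal`. A stdlib-only sketch that shows this, assuming `struct`'s `"f"` format as a stand-in for float32 rounding (`to_f32` is a hypothetical helper):

```python
import struct

def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

m1, m2 = to_f32(1.0), to_f32(1e12)
print(m2)                       # 999999995904.0, not 1e12
print(to_f32(m1 + m2) == m2)    # True: m1 vanishes from mtotal
print(to_f32(m1 - m2) == -m2)   # True: mdiff == -mtotal
```

For this input the count still comes out as 3141592, but a `float64` signature would likely be the safer choice for larger mass ratios, probably at some speed cost (not measured here).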

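```python
%%timeit -n 1 -r 1
# First call. The explicit signature means Numba already compiled the function
# when the decorator ran, so compilation time is not part of this measurement.
print(collisions_jit(1, int(1e12)))
```

```
3141592
6.17 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```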


```python
%%timeit -n 10 -r 10
collisions_jit(1, int(1e12))
```

```
5.32 ms ± 64.4 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
```
