Notebook to find out if we can replace `gammaln` of scipy with a numba implementation

In [2]:
from jupyter_black import load

load(line_length=100)

In [4]:
from scipy.special import gammaln
from numba import njit, float64
from timeit import timeit
import numpy as np

In [5]:
@njit()
def gammaln_pos_int(x: float64):
    """Logarithm of the Gamma function for positive integer values (as float64)."""
    result = 1.0
    while x - 2.0 > 1e-6:
        x -= 1.0
        result *= x
    return np.log(result)

In [6]:
%%time
gammaln_pos_int(5)

CPU times: user 118 ms, sys: 6.84 ms, total: 124 ms
Wall time: 124 ms


3.1780538303479458

In [7]:
%%time
gammaln_pos_int(5)

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 7.15 µs


3.1780538303479458

In [8]:
%%timeit -r 100 -n 10000
gammaln_pos_int(25)

139 ns ± 34.4 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [9]:
%%timeit -r 100 -n 10000
gammaln(25)

672 ns ± 67.6 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [45]:
%%timeit -r 100 -n 10000
gammaln_pos_int(1000)

934 ns ± 90.6 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [46]:
%%timeit -r 100 -n 10000
gammaln(1000)

664 ns ± 44.6 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [10]:
abs(gammaln_pos_int(25) - gammaln(25))

0.0

In [267]:
@njit()
def gammaln_pos_float(x: float64):
    """Logarithm of the Gamma function for positive integer values (as float64)."""

    f = np.log(np.sqrt(2 * np.pi))

    return f + np.log(x) * (x - 0.5) - x

In [268]:
%time
gammaln_pos_float(125)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.72 µs


477.0439988273413

In [269]:
%%timeit -r 100 -n 10000
gammaln_pos_float(125)

130 ns ± 34.9 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [270]:
%%timeit -r 100 -n 10000
gammaln(125)

668 ns ± 67.5 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [271]:
for x in [0.1, 0.2, 0.5, 1.0, 1.5, 2.0, 5.0, 7.5, 10, 15, 20]:
    print(x, abs(gammaln_pos_float(x) - gammaln(x)))

0.1 0.5127400813319152
0.2 0.3222939154958817
0.5 0.1534264097200273
1.0 0.08106146679532733
1.5 0.054814121051917686
2.0 0.04134069595540946
5.0 0.016644691189821703
7.5 0.01110455975820912
10 0.008330563433359472
15 0.0055547335519641194
20 0.004166319691996989


In [261]:
@njit()
def gammaln_pos_float_2(z: float64):
    """Logarithm of the Gamma function for positive integer values (as float64)."""

    if z > 2.0:
        # Has error
        f = np.log(np.sqrt(2 * np.pi))
        return f + np.log(z) * (z - 0.5) - z
    else:
        x = 2.0
        result = 0.0
        for n in range(500):
            denominator = 1.0
            for i in range(n):
                denominator *= z + i
            result += x**n / denominator
        return np.log(x**z * np.exp(-x) * result + (x + 1) * np.exp(-x))

In [262]:
%%timeit -r 100 -n 100
gammaln_pos_float_2(1.5)

The slowest run took 11.07 times longer than the fastest. This could mean that an intermediate result is being cached.
93.8 µs ± 85.1 µs per loop (mean ± std. dev. of 100 runs, 100 loops each)


In [263]:
%%timeit -r 100 -n 10000
gammaln(1.5)

475 ns ± 10 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [266]:
for x in [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]:
    print(x, abs(gammaln_pos_float_2(x) - gammaln(x)))

1.0 0.8779680488503786
1.1 0.8818732369301827
1.2 0.8822033901534023
1.3 0.8789364739705774
1.4 0.8720828222898078
1.5 0.8616774121867616
1.6 0.8477735545627689
1.7 0.8304378017827387
1.8 0.8097458794788825
1.9 0.7857794673122874
2.0 0.7586236756795133
