Notebook to find out if we can replace `gammaln` of scipy with a numba implementation

In [2]:
from jupyter_black import load

load(line_length=100)

In [331]:
from scipy.special import gammaln, gamma
from numba import njit, float64
from timeit import timeit
import numpy as np

In [5]:
@njit()
def gammaln_pos_int(x: float64):
    """Logarithm of the Gamma function for positive integer values (as float64)."""
    result = 1.0
    while x - 2.0 > 1e-6:
        x -= 1.0
        result *= x
    return np.log(result)

In [6]:
%%time
gammaln_pos_int(5)

CPU times: user 118 ms, sys: 6.84 ms, total: 124 ms
Wall time: 124 ms


3.1780538303479458

In [7]:
%%time
gammaln_pos_int(5)

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 7.15 µs


3.1780538303479458

In [8]:
%%timeit -r 100 -n 10000
gammaln_pos_int(25)

139 ns ± 34.4 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [9]:
%%timeit -r 100 -n 10000
gammaln(25)

672 ns ± 67.6 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [45]:
%%timeit -r 100 -n 10000
gammaln_pos_int(1000)

934 ns ± 90.6 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [46]:
%%timeit -r 100 -n 10000
gammaln(1000)

664 ns ± 44.6 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [10]:
abs(gammaln_pos_int(25) - gammaln(25))

0.0

In [267]:
@njit()
def gammaln_pos_float(x: float64):
    """Logarithm of the Gamma function for positive integer values (as float64)."""

    f = np.log(np.sqrt(2 * np.pi))

    return f + np.log(x) * (x - 0.5) - x

In [268]:
%time
gammaln_pos_float(125)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.72 µs


477.0439988273413

In [269]:
%%timeit -r 100 -n 10000
gammaln_pos_float(125)

130 ns ± 34.9 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [270]:
%%timeit -r 100 -n 10000
gammaln(125)

668 ns ± 67.5 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [305]:
for x in [0.1, 0.2, 0.5, 1.0, 1.5, 2.0, 5.0, 7.5, 10, 15, 20, 30]:
    print(
        x,
        abs(gammaln_pos_float(x) - gammaln(x)),
        abs(np.exp(gammaln_pos_float_2(x)) - np.exp(gammaln(x))),
    )

0.1 0.5127400813319152 9.956790299542256
0.2 0.3222939154958817 5.033155735657984
0.5 0.1534264097200273 2.208587837438079
1.0 0.08106146679532733 1.4060204363255755
1.5 0.054814121051917686 1.2115990758104906
2.0 0.04134069595540946 1.1353525765580148
5.0 0.016644691189821703 0.3961664084819958
7.5 0.01110455975820912 20.664507784007355
10 0.008330563433359472 3010.43812589522
15 0.0055547335519641194 482909720.0357361
20 0.004166319691996989 505758070775344.0
30 0.002777674929745899 2.452546297416522e+28


In [378]:
@njit()
def gammaln_pos_float_2(z: float64):
    """Logarithm of the Gamma function for positive integer values (as float64)."""

    if z > 2.0:
        # Has error
        f = np.log(np.sqrt(2 * np.pi))
        return f + np.log(z) * (z - 0.5) - z
    else:
        x = 10.0
        result = 0.0
        for n in range(30):
            denominator = 1.0
            for i in range(n + 1):
                denominator *= z + i
            result += x**n / denominator
        return np.log(x**z * np.exp(-x) * result)

In [380]:
%%timeit -r 100 -n 10000
gammaln_pos_float_2(1.5)

350 ns ± 21.2 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [381]:
%%timeit -r 100 -n 10000
gammaln(1.5)

463 ns ± 2.5 ns per loop (mean ± std. dev. of 100 runs, 10,000 loops each)


In [382]:
for x in [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]:
    print(
        x,
        abs(gammaln_pos_float_2(x) - gammaln(x)),
        abs(np.exp(gammaln_pos_float_2(x)) - np.exp(gammaln(x))),
        abs(np.exp(gammaln_pos_float_2(x)) - gamma(x)),
    )

1.0 4.548080194518237e-05 4.5479767709188934e-05 4.5479767709188934e-05
1.1 6.0705584301476434e-05 5.775055145218477e-05 5.775055145240682e-05
1.2 7.989320621294704e-05 7.335251445206303e-05 7.335251445217406e-05
1.3 0.0001038399154587033 9.318844280736815e-05 9.318844280736815e-05
1.4 0.00013346079695723823 0.00011840703467091984 0.00011840703467091984
1.5 0.00016980135989839995 0.00015046976177379712 0.00015046976177379712
1.6 0.00021404948813072422 0.00019123603543402812 0.00019123603543413914
1.7 0.0002675476772296681 0.00024307166433623628 0.00024307166433645833
1.8 0.0003318054730065828 0.0003089869680623636 0.00030898696806225256
1.9 0.0004085120182111962 0.0003928126612298133 0.00039281266122970226
2.0 0.0004995486069731168 0.000499423853342118 0.000499423853342118


In [383]:
for x in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
    print(
        x,
        abs(gammaln_pos_float_2(x) - gammaln(x)),
        abs(np.exp(gammaln_pos_float_2(x)) - np.exp(gammaln(x))),
        abs(np.exp(gammaln_pos_float_2(x)) - gamma(x)),
    )

0.1 7.789454952167318e-07 7.410501080329368e-06 7.410501080329368e-06
0.2 1.659965198630431e-06 7.620634469240883e-06 7.620634469240883e-06
0.3 3.023438297722336e-06 9.044810574287965e-06 9.044810574287965e-06
0.4 5.031803163246806e-06 1.1161314128038669e-05 1.1161314128038669e-05
0.5 7.886359934095033e-06 1.3978153916349711e-05 1.3978153916349711e-05
0.6 1.183369738017026e-05 1.762254614323666e-05 1.762254614323666e-05
0.7 1.717275130724838e-05 2.2290990011164524e-05 2.2290990011608613e-05
0.8 2.4262499677685234e-05 2.8246780383822667e-05 2.8246780383600623e-05
0.9 3.353029196677404e-05 3.5830831673688834e-05 3.583083167346679e-05


In [376]:
z = 1.5
x = 10.00002
result = 0.0
for n in range(30):
    denominator = 1.0
    for i in range(n):
        denominator *= z + i
    result += x**n / denominator

result, (x + 1) * np.exp(-x)

(6172.975970230312, 0.0004993901474831005)