# Fast exponential and softmax functions

Based on bit manipulation of IEEE 754 floating-point numbers.
See N. N. Schraudolph, “A Fast, Compact Approximation of the Exponential Function,” Neural Computation, vol. 11, no. 4, pp. 853–862, May 1999, doi: 10.1162/089976699300016467.

A 32-bit floating point number is given by:
$$ x = (-1)^s (1+m)2^{e-127} $$
where $s$ - sign bit; $m$ - 23-bit mantissa; $e$ - 8-bit exponent;

For bit layout see [this diagram](https://upload.wikimedia.org/wikipedia/commons/d/d2/Float_example.svg).

Writing an integer value $i$ into the exponent bit field and reading the result back as a float gives the value $2^{i-127}$.

So dividing $x$ by $\ln(2)$, shifting the value $x/\ln(2)+127$ left by 23 bits, and reading back the bits of the resulting integer as a float gives an approximation of $e^x$. The correction factor $C$ allows tuning the approximation, e.g. to minimise a chosen error measure.

$$f(x) = 2^{23} \left(\frac{x}{\ln(2)} + 127 - C\right)$$
For example, the following approximation is used in code, with the specific value $C = 0.057985$:
```C
uint32_t value = (1<<23)*(1.4426950409*x+126.94201519f);
```
Which is then bitcast to floating point number via union type punning.


## Softmax

Softmax is a function over a vector $\mathbf{X}$:

$$ \sigma(\mathbf{X})_i = \frac{e^{x_i}}{\sum\limits_{j=0}^{N}e^{x_j}} $$


When calculating softmax, the maximum is subtracted from the input data for numerical stability, i.e. to avoid overflow.

## Further reading

F. Perini and R. D. Reitz, “Fast approximations of exponential and logarithm functions combined with efficient storage/retrieval for combustion kinetics calculations,” Combustion and Flame, vol. 194, pp. 37–51, 2018, doi: 10.1016/j.combustflame.2018.04.013.

J.-M. Muller, “Elementary Functions and Approximate Computing,” Proc. IEEE, vol. 108, no. 12, pp. 2136–2149, 2020, doi: 10.1109/JPROC.2020.2991885.





In [None]:
import numpy as np
from matplotlib import pyplot as plt
from functools import partial

# Number of sample points used by the experiments in this cell.
SIZE = 10000

# Correction term: the offset between the float32 exponent bias (127) and
# the constant 126.94201519 used in the C snippet of the notes.
C = 127 - 126.94201519
# log2(e) == 1 / ln(2), stored in single precision.
LOG2_e = np.float32(1.4426950409)

print("C = {:f}".format(C))
print("1/log2(e)", 1 / np.log2(np.e))

# Constants pre-multiplied by 2**23 (float32 has 23 mantissa bits).
_scale_f32 = np.float32(1 << 23)
print("A: {:0.1f}".format(LOG2_e * _scale_f32))
print("C: {:0.1f}".format(np.float32(126.94201519) * _scale_f32))

# NOTE(review): 2**20 and 1023 presumably target the upper 32-bit word of a
# float64 (20 mantissa bits there, exponent bias 1023) - confirm.
_scale_f64 = np.float64(1 << 20)
_bias_f64 = 1023 * np.float32(1 << 20)
print("f64 A: {:0.1f}/ln(2) = {:0.1f}".format(_scale_f64, LOG2_e * _scale_f64))
print("f64 C: {:0.1f}-{:0.1f}={:0.1f}".format(_bias_f64, 60801, _bias_f64 - 60801))


def fast_exp(x, c=C):
    """Approximate ``exp(x)`` by building the float32 bit pattern directly.

    The value ``2**23 * (LOG2_e * x + 127 - c)`` is computed, converted to a
    32-bit integer and its bits are reinterpreted as float32.

    The result has the shape of ``x``; it does not depend on the module-level
    ``SIZE``, so inputs of any length (or shape) are accepted.

    Args:
        x: Array-like of exponents.
        c: Correction term subtracted from the exponent bias of 127.

    Returns:
        float32 array with the approximated exponentials.
    """
    x = np.asarray(x)
    # Scalars broadcast against ``x``; no per-call work arrays are needed.
    scaled = np.multiply(np.float32(1 << 23), LOG2_e * x + 127 - c, dtype="f")
    # Convert to int32, then reinterpret the same 32 bits as a float32.
    return np.asarray(scaled).astype(np.int32).view(np.float32)

def delta_yf_3(yf):
    """Evaluate a cubic correction polynomial (no constant term) at ``yf``.

    Passed as the ``corr`` callable of ``fast_exp_delta``.
    NOTE(review): the coefficients are presumably a fit of
    ``1 + yf - 2**yf`` on [0, 1] - confirm against the cited literature.

    Args:
        yf: Value(s) at which the polynomial is evaluated.

    Returns:
        ``s1*yf + s2*yf**2 + s3*yf**3`` with float32 coefficients.
    """
    coeff_1 = np.float32(0.30758037765820823)
    coeff_2 = np.float32(-0.23141283591588344)
    coeff_3 = np.float32(-7.6167541742324804e-2)

    linear = coeff_1 * yf
    quadratic = coeff_2 * yf * yf
    cubic = coeff_3 * yf * yf * yf
    return linear + quadratic + cubic

def delta_yf_5(yf):
    """Evaluate a fifth-degree correction polynomial (no constant term).

    Passed as the ``corr`` callable of ``fast_exp_delta``.
    NOTE(review): the coefficients are presumably a fit of
    ``1 + yf - 2**yf`` on [0, 1] - confirm against the cited literature.

    Args:
        yf: Value(s) at which the polynomial is evaluated.

    Returns:
        ``sum(s_k * yf**k for k in 1..5)`` with float32 coefficients.
    """
    coefficients = (
        np.float32(3.06852819440055e-1),
        np.float32(-2.40226506959101e-1),
        np.float32(-5.57129652016652e-2),
        np.float32(-9.01146535969578e-3),
        np.float32(-1.90188191959304e-3),
    )
    # Build the powers by repeated multiplication and accumulate the terms
    # from the lowest degree upwards.
    power = yf
    total = coefficients[0] * power
    for coefficient in coefficients[1:]:
        power = power * yf
        total = total + coefficient * power
    return total

def perini_taylor(yf):
    """Return ``1 + yf - 2**yf``.

    NOTE(review): despite the name, no series expansion is used here; the
    power of two is evaluated directly.
    """
    linear_part = 1 + yf
    return linear_part - 2**yf

def fast_exp_delta(x, corr):
    """Approximate ``exp(x)`` via float32 bits with a mantissa correction.

    Args:
        x: Array of exponents (the masked assignment below requires an array,
            not a scalar).
        corr: Callable receiving the fractional part of ``x * LOG2_e`` and
            returning the correction subtracted before the bits are formed.

    Returns:
        float32 view of the int32 value
        ``2**23 * (x * LOG2_e - corr(frac) + 127)``.
    """
    scaled = np.float32(x * LOG2_e)
    # Fractional part, shifted by one to suit negative values (the int32
    # conversion drops the fraction rather than flooring).
    frac = scaled - np.int32(scaled) + 1
    positive = scaled > 0
    # Undo the shift where the scaled exponent is positive.
    frac[positive] = frac[positive] - 1
    bits = np.int32((1 << 23) * (scaled - corr(frac) + 127))
    return bits.view(np.float32)

def fast_log(x):
    """Approximate the natural logarithm from the raw float32 bit pattern.

    The bits of ``x`` are read as int32, divided by ``2**23`` and offset by
    ``127 - sigma``; the result is multiplied by ln(2).

    Args:
        x: NumPy array whose buffer can be viewed as int32 (e.g. float32).

    Returns:
        float32 array with the approximated logarithms.
    """
    sigma = 0.0430357
    ln2 = 0.6931471805599453
    raw_bits = x.view(np.int32)
    log2_estimate = np.float32(raw_bits / (2**23) - (127 - sigma))
    return np.float32(ln2 * log2_estimate)


def fast_exp_newtons(x, exp_func, log_func, q=1):
    """Refine an exponential estimate with Newton steps on ``ln(y) = x``.

    Starting from ``exp_func(x)``, each step applies
    ``y <- y - y * log_func(y) + y * x``.

    Args:
        x: Exponent(s).
        exp_func: Callable giving the initial estimate of ``exp(x)``.
        log_func: Callable giving the logarithm used in each step.
        q: Number of refinement steps.

    Returns:
        The estimate after ``q`` steps.
    """
    estimate = exp_func(x)
    for _ in range(q):
        log_term = estimate * log_func(estimate)
        estimate = estimate - log_term + estimate * x
    return estimate


def bounds():
    """Print bounds and candidate values for the correction constant ``c``.

    Reports, in order: the upper bound (``fast_exp`` at 0 with ``c = -1``),
    the lower bound, and the constants for the lowest maximum relative
    error, lowest RMS error and lowest mean error.  Each constant is printed
    both scaled by ``2**23`` and unscaled.
    """
    ln2 = np.log(2)
    scale = 2**23

    print("Upper bound: {:f}".format(fast_exp([0], -1)[0]))
    lower_bound = scale * ((1 - (np.log(ln2) + 1)) / ln2)
    print(
        "Lower bound: {:f}".format(lower_bound)
    )

    gamma_max_rel = np.log(ln2 + 2 / np.e) - ln2 - np.log(ln2)
    c_max_rel = gamma_max_rel * scale / ln2
    print("Lowest maximum relative error: gamma = {:f}".format(gamma_max_rel))
    print(
        "Lowest maximum relative error: c = {:f} {:f}".format(
            c_max_rel, c_max_rel / scale
        )
    )

    c_rms = np.log(3 / (8 * ln2) + 0.5) / ln2
    print("Lowest RMS error: c = {:f} {:f}".format(scale * c_rms, c_rms))

    gamma_mean = 0.045111411
    c_mean = gamma_mean / ln2
    print("Lowest mean error: c = {:f} {:f}".format(scale * c_mean, c_mean))


def softmax(x, func):
    """Return ``func(x)`` normalised by the sum of its elements.

    No maximum is subtracted from ``x`` here; ``func`` is applied to the
    input as given.

    Args:
        x: Input array.
        func: Callable computing (an approximation of) the exponential.

    Returns:
        ``func(x) / np.sum(func(x))``.
    """
    # Evaluate the exponential once and reuse it for the sum and the ratio.
    exps = func(x)
    return exps / np.sum(exps)


def get_err(x, tr, app):
    """Return the relative error of ``app`` against ``tr`` at ``x`` in percent.

    Args:
        x: Input passed to both callables.
        tr: Reference callable.
        app: Approximating callable.

    Returns:
        ``100 * (tr(x) - app(x)) / tr(x)``.
    """
    # Evaluate the reference once; it is needed in numerator and denominator.
    reference = tr(x)
    return 100 * (reference - app(x)) / reference


def plot(x, funcs, labels):
    """Plot every function over ``x`` and its percentage error vs ``funcs[0]``.

    The upper axes show ``func(x)`` for each entry of ``funcs``; the lower
    axes show ``get_err(x, funcs[0], func)``.  One summary line with the
    maximum absolute error and the RMS error is printed per function.

    Args:
        x: Sample points.
        funcs: Callables to compare; the first one is the reference.
        labels: Legend/summary labels, indexed in step with ``funcs``.
    """
    fig, (values_ax, error_ax) = plt.subplots(2, 1)
    for axis in (values_ax, error_ax):
        axis.grid(True)

    for index, func in enumerate(funcs):
        values_ax.plot(x, func(x))
        err = get_err(x, funcs[0], func)
        error_ax.plot(x, err)
        summary = "{:25s} | max err = {:f} | rms err = {:f} ".format(
            labels[index], max(np.abs(err)), np.sqrt(np.mean(err**2))
        )
        print(summary)

    for axis in (values_ax, error_ax):
        axis.legend(labels)
        axis.set_xlabel("x")
    values_ax.set_ylabel("f(x)")
    error_ax.set_ylabel("error, %")


# Print the bounds and candidate correction constants first.
bounds()

# Sample points for the comparison plots.
x = np.linspace(-2, 2, SIZE, dtype=np.float32)

# Corrected variants using the cubic and fifth-degree polynomials.
fast_exp_3 = partial(fast_exp_delta, corr=delta_yf_3)
fast_exp_5 = partial(fast_exp_delta, corr=delta_yf_5)

# The first entry is the reference that `plot` compares the others against.
_exp_variants = [np.exp, fast_exp, fast_exp_3, fast_exp_5]
_variant_labels = ["True", "fast_exp", "fast_exp_3", "fast_exp_5"]

plot(x, _exp_variants, _variant_labels)

print("softmax")
plot(
    x,
    [partial(softmax, func=variant) for variant in _exp_variants],
    _variant_labels,
)
# x2 = np.linspace(0.5, 10.0, SIZE, dtype=np.float32)
# plot(x2, [np.log, fast_log], ["True", "Fast log"])

In [None]:
import numpy as np
# Number of random samples written to the data file.
SIZE = 133
# Random samples from np.random.rand, stored as float32.
x = np.random.rand(SIZE).astype(np.float32)
def softmax(x, func):
    """Return ``func(x)`` normalised by the sum of its elements.

    No maximum is subtracted from ``x`` here; ``func`` is applied to the
    input as given.

    Args:
        x: Input array.
        func: Callable computing (an approximation of) the exponential.

    Returns:
        ``func(x) / np.sum(func(x))``.
    """
    # Evaluate the exponential once and reuse it for the sum and the ratio.
    exps = func(x)
    return exps / np.sum(exps)

# Dump the raw bytes of the input samples.
with open("../bin/data", "wb") as data_file:
    data_file.write(x.tobytes())

# Dump the reference ("golden") softmax computed with np.exp.
with open("../bin/golden", "wb") as golden_file:
    golden_file.write(softmax(x, np.exp).tobytes())