<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_0_NumericsCheatSheet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# That's a nice calculation you got there. Be a real shame if something happened to it, or: Numerics Cheat Sheet
This document is meant to illustrate how the numerics of common calculations can fail, and how "built-in" functions that replace the direct operations can improve numerical stability.

As usual for this course, we will be operating in [JAX](https://jax.readthedocs.io/en/latest/) using its `numpy` and `scipy` submodules.

In [None]:
# let's import our necessary modules before getting started

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.scipy.special as jspec

## log(1 + p) and exp(p) - 1
These functions typically arise in the calculation of [cross-entropy](https://en.wikipedia.org/wiki/Cross-entropy#Cross-entropy_loss_function_and_logistic_regression) for a [binary classifier](https://en.wikipedia.org/wiki/Binary_classification) or the [log-likelihood](https://en.wikipedia.org/wiki/Likelihood_function) for [logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) under a [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) model.

[log1p](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.log1p.html) calculates `jnp.log(1 + p)` in a numerically stable manner. Calculating this directly loses precision when `p` is very small: the sum `1 + p` is rounded to the nearest representable float before the logarithm is taken, so the low-order digits of `p` are discarded and the result can be `0.0` where a small positive value is expected.
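
To see where the `0.0` comes from, it can help to look at the rounding of `1 + p` on its own. The sketch below assumes JAX's default 32-bit floats (the explicit `jnp.float32` casts make that assumption visible), and the variable names are illustrative only.

```python
import jax.numpy as jnp

# Machine epsilon: the gap between 1.0 and the next larger float32 (2**-23).
eps = jnp.finfo(jnp.float32).eps

one = jnp.float32(1.0)
below = jnp.float32(1e-8)  # smaller than eps / 2
above = jnp.float32(1e-7)  # between eps / 2 and eps

# 1 + 1e-8 rounds back to exactly 1, so log(1 + p) only ever sees log(1) = 0.
absorbed = (one + below) == one

# 1 + 1e-7 rounds up to 1 + eps, so the value of p is replaced by eps.
recovered = (one + above) - one

print(f"eps            = {float(eps):.5e}")
print(f"1 + 1e-8 == 1  ? {bool(absorbed)}")
print(f"(1 + 1e-7) - 1 = {float(recovered):.5e}")
```

Under these assumptions the last line prints the same `1.19209e-07` that appears in the unsafe column of the output for `p = 1e-7`: the direct calculation is effectively reporting the machine epsilon rather than anything about `p`.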

Similarly, [expm1](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.expm1.html) calculates `jnp.exp(p) - 1` in a numerically stable manner; direct calculation can produce `0.0` when `p` is very small, because `jnp.exp(p)` rounds to exactly `1`.
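
As one hedged illustration of the two functions working together, consider converting a small event rate into a probability with `1 - exp(-rate)` and back with `-log(1 - prob)`. This rate/probability setting is an assumption chosen for illustration, not something taken from the course material, and float32 is assumed throughout.

```python
import jax.numpy as jnp

rate = jnp.float32(1e-9)

# Direct formula: exp(-rate) rounds to exactly 1, so the probability collapses to 0.
naive_prob = 1.0 - jnp.exp(-rate)

# Stable formulas: 1 - exp(-r) == -expm1(-r) and log(1 - p) == log1p(-p).
stable_prob = -jnp.expm1(-rate)
stable_rate = -jnp.log1p(-stable_prob)

print(f"naive prob  = {float(naive_prob):.5e}")
print(f"stable prob = {float(stable_prob):.5e}")
print(f"round trip  = {float(stable_rate):.5e}")
```

The stable versions keep the round trip close to the original rate, while the direct version has already lost it after the first step.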

In [None]:
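# compare the stable and direct forms of log(1 + p) for p = 1e-10, ..., 1e-1
# (the loop variable is named `exponent` to avoid shadowing the built-in `pow`)
for exponent in jnp.arange(10, 0, -1):
  p = 10. ** (-exponent)
  safe_res = jnp.log1p(p)
  unsafe_res = jnp.log(1. + p)
  print(f"log(1 + {p:.3e}) | Safe = {safe_res:.5e} | Unsafe = {unsafe_res:.5e}")

print("-"*62)
# same comparison for exp(p) - 1
for exponent in jnp.arange(10, 0, -1):
  p = 10. ** (-exponent)
  safe_res = jnp.expm1(p)
  unsafe_res = jnp.exp(p) - 1.
  print(f"exp({p:.3e}) - 1 | Safe = {safe_res:.5e} | Unsafe = {unsafe_res:.5e}")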

log(1 + 1.000e-10) | Safe = 1.00000e-10 | Unsafe = 0.00000e+00
log(1 + 1.000e-09) | Safe = 1.00000e-09 | Unsafe = 0.00000e+00
log(1 + 1.000e-08) | Safe = 1.00000e-08 | Unsafe = 0.00000e+00
log(1 + 1.000e-07) | Safe = 1.00000e-07 | Unsafe = 1.19209e-07
log(1 + 1.000e-06) | Safe = 1.00000e-06 | Unsafe = 9.53674e-07
log(1 + 1.000e-05) | Safe = 9.99995e-06 | Unsafe = 1.00135e-05
log(1 + 1.000e-04) | Safe = 9.99950e-05 | Unsafe = 1.00012e-04
log(1 + 1.000e-03) | Safe = 9.99500e-04 | Unsafe = 9.99547e-04
log(1 + 1.000e-02) | Safe = 9.95033e-03 | Unsafe = 9.95032e-03
log(1 + 1.000e-01) | Safe = 9.53102e-02 | Unsafe = 9.53102e-02
--------------------------------------------------------------
exp(1.000e-10) - 1 | Safe = 1.00000e-10 | Unsafe = 0.00000e+00
exp(1.000e-09) - 1 | Safe = 1.00000e-09 | Unsafe = 0.00000e+00
exp(1.000e-08) - 1 | Safe = 1.00000e-08 | Unsafe = 0.00000e+00
exp(1.000e-07) - 1 | Safe = 1.00000e-07 | Unsafe = 1.19209e-07
exp(1.000e-06) - 1 | Safe = 1.00000e-06 | Unsafe = 9.53

## xlogy and xlog1py

In [None]:
jspec.xlogy
jspec.xlog1py
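
The cell above only names the two functions, so here is a small sketch of the kind of problem they address: terms of the form `x * log(y)` where `x` and `y` can both be zero, as in an entropy or a Bernoulli log-likelihood with a probability on the boundary. The example values and variable names are assumptions for illustration.

```python
import jax.numpy as jnp
import jax.scipy.special as jspec

# Entropy of a distribution that puts zero mass on one outcome.
probs = jnp.array([0.5, 0.5, 0.0], dtype=jnp.float32)

# Direct form: 0 * log(0) = 0 * (-inf) = nan, which poisons the whole sum.
naive_entropy = -jnp.sum(probs * jnp.log(probs))

# xlogy(x, y) computes x * log(y) but returns 0 when x == 0.
safe_entropy = -jnp.sum(jspec.xlogy(probs, probs))

# Bernoulli log-likelihood y*log(p) + (1 - y)*log(1 - p) at the boundary p = 1.
y = jnp.float32(1.0)
p = jnp.float32(1.0)
naive_ll = y * jnp.log(p) + (1.0 - y) * jnp.log1p(-p)

# xlog1py(x, y) computes x * log1p(y), again returning 0 when x == 0.
safe_ll = jspec.xlogy(y, p) + jspec.xlog1py(1.0 - y, -p)

print(f"entropy | Safe = {float(safe_entropy):.5e} | Unsafe = {float(naive_entropy):.5e}")
print(f"log-lik | Safe = {float(safe_ll):.5e} | Unsafe = {float(naive_ll):.5e}")
```

In both cases the direct expression evaluates to `nan`, while the `xlogy`/`xlog1py` forms follow the convention that a zero weight contributes nothing.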

## logsumexp, logaddexp, and softmax

In [None]:
jspec.logsumexp
jnp.logaddexp
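
As with the previous section, the cell above only names the functions. The sketch below shows the overflow they are designed to avoid, using inputs large enough that `exp` overflows in float32. The inputs are illustrative assumptions, and `jax.nn.softmax` is used here for the softmax, which may or may not be the variant the course intends.

```python
import jax
import jax.numpy as jnp
import jax.scipy.special as jspec

x = jnp.array([1000.0, 1000.0], dtype=jnp.float32)

# Direct form: exp(1000) overflows to inf, so the log of the sum is inf.
naive_lse = jnp.log(jnp.sum(jnp.exp(x)))

# logsumexp handles the whole vector; logaddexp handles a pair of values.
safe_lse = jspec.logsumexp(x)
pair_lse = jnp.logaddexp(x[0], x[1])

# Direct softmax: inf / inf = nan.
naive_softmax = jnp.exp(x) / jnp.sum(jnp.exp(x))

# Stable softmax, either built in or written via logsumexp.
safe_softmax = jax.nn.softmax(x)
manual_softmax = jnp.exp(x - safe_lse)

print(f"logsumexp | Safe = {float(safe_lse):.5e} | Unsafe = {float(naive_lse):.5e}")
print(f"logaddexp | {float(pair_lse):.5e}")
print(f"softmax   | Safe = {safe_softmax} | Unsafe = {naive_softmax}")
```

The exact answer here is `1000 + log(2)` for the log-sum and `[0.5, 0.5]` for the softmax; the stable functions get there by working relative to the largest input instead of exponentiating it directly.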