In [1]:
# Interactive "notebook" (nbagg) matplotlib backend. Outside a live classic
# Notebook session the figure is saved only as
# `<IPython.core.display.Javascript object>` (as in the plotting cell's output);
# `%matplotlib inline` would record a static image instead.
%matplotlib notebook

In [2]:
# JAX defaults to 32-bit floats/ints. Enable 64-bit types so the JAX results are
# float64 like the NumPy reference (the traced jaxprs below show f64/i64).
# This has to run before any JAX arrays are created, hence its own early cell.
from jax import config
config.update("jax_enable_x64", True)

In [3]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

In [4]:
FIL_BASE = 2_000_000_000.0
# Storage mining receives 55% of the base amount ...
STORAGE_MINING = FIL_BASE * 0.55
# ... and 30% of that is the total simple minting allocation.
SIMPLE_ALLOC = STORAGE_MINING * 0.3

In [5]:
# Inclusive range of days on which the minting curves are compared.
FIRST_DAY, LAST_DAY = 152, 1698

# The same integer day grid, built once with NumPy and once with JAX.
days_np = np.arange(FIRST_DAY, LAST_DAY + 1)
days_jax = jnp.arange(FIRST_DAY, LAST_DAY + 1)

In [6]:
def ssse(x, y):
    """Square root of the sum of squared element-wise differences of x and y.

    This is the Euclidean (L2) distance ||x - y||; it is 0.0 when the inputs
    are identical.
    """
    residual = x - y
    return np.sqrt(np.sum(residual ** 2))

In [7]:
# Sanity check: the NumPy and JAX day grids must be identical, otherwise the
# element-wise comparisons below would mix different days.
grid_error = ssse(days_np, days_jax)
assert grid_error == 0.0, "NumPy and JAX day grids differ"
grid_error

0.0

In [8]:
# NP
def np_cum_simple_minting(
    day: "int | np.ndarray",
    scale=True,
    half_life_years: float = 6.0,
) -> "float | np.ndarray":
    """Cumulative simple-minting curve, NumPy implementation.

    Evaluates ``1 - exp(-LAMBDA * day)`` with
    ``LAMBDA = ln(2) / (half_life_years * 365)``, i.e. the fraction of the
    allocation released by ``day``; it reaches one half after
    ``half_life_years`` years.

    Args:
        day: day index; a scalar or an array of days (evaluated element-wise).
        scale: if True, multiply the fraction by the module-level ``SIMPLE_ALLOC``.
        half_life_years: half-life of the exponential decay, in 365-day years
            (default 6.0, i.e. 2190 days).

    Returns:
        The cumulative fraction (``scale=False``) or amount (``scale=True``),
        with the same shape as ``day``.
    """
    # Minting exponential reward decay rate, per day.
    LAMBDA = np.log(2) / (half_life_years * 365)
    # Plain 1 - exp(.) form; the JAX variants in this notebook are plotted as
    # differences from this function's output.
    y = 1 - np.exp(-LAMBDA * day)
    if scale:
        y = y * SIMPLE_ALLOC
    return y

In [9]:
def jax1_cum_simple_minting(
    day: "int | jax.Array",
    scale=True,
    half_life_years: float = 6.0,
) -> "jax.Array":
    """Cumulative simple-minting curve in JAX, using the ``1 - exp(.)`` form.

    JAX counterpart of ``np_cum_simple_minting``: evaluates
    ``1 - exp(-LAMBDA * day)`` with ``LAMBDA = ln(2) / (half_life_years * 365)``,
    so the fraction reaches one half after ``half_life_years`` years.

    Args:
        day: day index; a scalar or an array of days (evaluated element-wise).
        scale: if True, multiply the fraction by the module-level ``SIMPLE_ALLOC``.
        half_life_years: half-life of the exponential decay, in 365-day years
            (default 6.0, i.e. 2190 days).

    Returns:
        A JAX array holding the cumulative fraction (``scale=False``) or amount
        (``scale=True``), with the same shape as ``day``.
    """
    # Minting exponential reward decay rate, per day.
    LAMBDA = jnp.log(2) / (half_life_years * 365)
    y = 1 - jnp.exp(-LAMBDA * day)
    if scale:
        y = y * SIMPLE_ALLOC
    return y
    
def jax2_cum_simple_minting(
    day: "int | jax.Array",
    scale=True,
    half_life_years: float = 6.0,
) -> "jax.Array":
    """Cumulative simple-minting curve in JAX, using the ``-expm1(.)`` form.

    Mathematically equal to ``1 - exp(-LAMBDA * day)`` with
    ``LAMBDA = ln(2) / (half_life_years * 365)``, but computed as
    ``-1 * expm1(-LAMBDA * day)``, which avoids the cancellation of
    ``1 - exp(x)`` when ``|x|`` is small.

    Args:
        day: day index; a scalar or an array of days (evaluated element-wise).
        scale: if True, multiply the fraction by the module-level ``SIMPLE_ALLOC``.
        half_life_years: half-life of the exponential decay, in 365-day years
            (default 6.0, i.e. 2190 days).

    Returns:
        A JAX array holding the cumulative fraction (``scale=False``) or amount
        (``scale=True``), with the same shape as ``day``.
    """
    # Minting exponential reward decay rate, per day.
    LAMBDA = jnp.log(2) / (half_life_years * 365)
    # Kept as `-1 * ...` (traced as `mul -1.0`) rather than a bare negation.
    y = -1 * jnp.expm1(-LAMBDA * day)
    if scale:
        y = y * SIMPLE_ALLOC
    return y

In [10]:
# Trace the 1 - exp(.) variant with a sample integer day (4) and print the
# primitives JAX emits: `exp` followed by `sub 1.0`.
print(jax.make_jaxpr(jax1_cum_simple_minting)(4))

{ lambda ; a:i64[]. let
    b:f64[] = log 2.0
    c:f64[] = div b 2190.0
    d:f64[] = neg c
    e:f64[] = convert_element_type[new_dtype=float64 weak_type=True] a
    f:f64[] = mul d e
    g:f64[] = exp f
    h:f64[] = sub 1.0 g
    i:f64[] = mul h 330000000.0
  in (i,) }


In [11]:
# Trace the -1 * expm1(.) variant with the same sample day (4): here JAX emits
# `expm1` followed by `mul -1.0` instead of `exp` / `sub 1.0`.
print(jax.make_jaxpr(jax2_cum_simple_minting)(4))

{ lambda ; a:i64[]. let
    b:f64[] = log 2.0
    c:f64[] = div b 2190.0
    d:f64[] = neg c
    e:f64[] = convert_element_type[new_dtype=float64 weak_type=True] a
    f:f64[] = mul d e
    g:f64[] = expm1 f
    h:f64[] = mul -1.0 g
    i:f64[] = mul h 330000000.0
  in (i,) }


In [12]:
# For each implementation: the raw cumulative fraction (scale=False) and the
# amount scaled by SIMPLE_ALLOC (scale=True, the default).
unscaled_np_res = np_cum_simple_minting(days_np, scale=False)
scaled_np_res = np_cum_simple_minting(days_np, scale=True)

unscaled_jax1_res = jax1_cum_simple_minting(days_jax, scale=False)
scaled_jax1_res = jax1_cum_simple_minting(days_jax, scale=True)

unscaled_jax2_res = jax2_cum_simple_minting(days_jax, scale=False)
scaled_jax2_res = jax2_cum_simple_minting(days_jax, scale=True)

In [13]:
# Difference between the NumPy result and each JAX variant, per day.
# Rows: unscaled fraction, then scaled by SIMPLE_ALLOC.
# Columns: JAX 1 - exp(.), then JAX -1 * expm1(.).
# Each title reports ssse(NumPy, JAX), i.e. the L2 norm of the difference.
panels = [
    ("Unscaled", "1-exp(.)", unscaled_np_res, unscaled_jax1_res),
    ("Unscaled", "-1*expm1(.)", unscaled_np_res, unscaled_jax2_res),
    ("Scaled", "1-exp(.)", scaled_np_res, scaled_jax1_res),
    ("Scaled", "-1*expm1(.)", scaled_np_res, scaled_jax2_res),
]

fig, axes = plt.subplots(2, 2, figsize=(8, 6))
# axes.flat is row-major, matching the order of `panels`.
for ax, (kind, formula, reference, candidate) in zip(axes.flat, panels):
    ax.plot(days_np, reference - candidate)
    ax.set_title("%s [%s] Error=%.6e" % (kind, formula, ssse(reference, candidate)))
    ax.set_xlabel("day")
    ax.set_ylabel("NumPy - JAX")

fig.tight_layout()

<IPython.core.display.Javascript object>