In [117]:
from functools import partial
from typing import Optional
from jax import Array
import math
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import lax
import equinox as eqx

from jax_src.sst import load_sst_net, SSTNet
from jax_src.discriminator import marginal_wass2_error


jax.config.update("jax_enable_x64", True)
dtype = jnp.float64


@eqx.filter_jit
def compute_ll(w_01: Array, hh_01: Array, c_01: Array, dt):
    return ((dt**2) / 2.0) * (c_01 - w_01 * hh_01 - 1 / 3 * w_01**2)


@eqx.filter_jit
def generate_w_hh_ll(key, net: Optional[SSTNet], dt, num_samples: int, dtype):
    key_w, key_hh, key_c = jr.split(key, 3)
    w_01 = jr.normal(key_w, (num_samples,), dtype=dtype)
    hh_01 = math.sqrt(1 / 12) * jr.normal(key_hh, (num_samples,), dtype=dtype)
    if net is None:
        ll = dt**2 * (1 / 30 + (3 / 5) * (hh_01**2))
    else:
        c_01 = jnp.squeeze(net.generate_c(key_c, w_01, hh_01))
        ll = compute_ll(w_01, hh_01, c_01, dt)
    w = jnp.sqrt(dt) * w_01
    hh = jnp.sqrt(dt) * hh_01
    return w, hh, ll


@eqx.filter_jit
def step(key, dt, args, y_n: Array, net: Optional[SSTNet]):
    a, b, sigma = args
    num_samples = y_n.shape[0]
    dtype = y_n.dtype
    assert y_n.ndim == 1
    w, hh, ll = generate_w_hh_ll(key, net, dt, num_samples, dtype)
    c1 = sigma * w - a * dt
    c2 = a * b * (dt - dt * sigma * hh + (sigma**2) * ll)
    exp_c1 = jnp.exp(c1)
    exp_c1_by_c1 = (
        1.0
        + 1 / 2 * c1
        + 1 / 6 * c1**2
        + 1 / 24 * c1**3
        + 1 / 120 * c1**4
        + 1 / 720 * c1**5
    )
    y_next = y_n * exp_c1 + c2 * exp_c1_by_c1
    return y_next


@eqx.filter_jit
def simulate(key, args, y0: Array, net: Optional[SSTNet], t1, num_steps: int):
    dt0 = jnp.asarray(t1 / num_steps, dtype=dtype)
    carry = y0, jnp.asarray(0.0, dtype=dtype)
    keys = jnp.array(jr.split(key, num_steps))

    def scan_step(_carry, _key):
        y, t = _carry
        dt = jnp.minimum(t1 - t, dt0)
        t = t + dt
        return (step(_key, dt, args, y, net), t), None

    (y_t1, _), _ = lax.scan(scan_step, carry, keys, length=num_steps)
    return y_t1


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)

In [130]:
from jax_src.sst import eval_net

noise_size = 3
hidden_dim = 16
num_layers = 3
leaky_slope = 0.01
use_multlayer = True
net = load_sst_net(
    "/home/andy/PycharmProjects/Levy_CFGAN/numpy_nets/",
    noise_size,
    hidden_dim,
    num_layers,
    leaky_slope,
    use_multlayer,
    jnp.float64,
    use_batch_norm=False,
    use_activation=True,
)


def ito_to_stratonovich(a: float, b: float, sigma: float):
    tilde_a = a + sigma**2 / 2
    tilde_b = a * b / tilde_a
    return tilde_a, tilde_b, sigma


def get_args(a: float, b: float, sigma: float):
    return (
        jnp.array(a, dtype=dtype),
        jnp.array(b, dtype=dtype),
        jnp.array(sigma, dtype=dtype),
    )


ito_args = (1.0, 0.1, 1.4)
strat_args = get_args(*ito_to_stratonovich(*ito_args))
# ito_args = (1.0, 1.0, 1.0)
# strat_args = get_args(*ito_args)

_ = eval_net(net, jr.key(6), 100, 2**20, -1, True, False)

num_samples = 2**16
# y0_flt = 0.06
y0_flt = 0.16
y0 = jnp.broadcast_to(jnp.array(y0_flt, dtype=dtype), (num_samples,))

t1 = jnp.array(5.0, dtype=dtype)

Mean error: 0.006662, variance error: 0.0009496, avg var: 0.05002, wass error: 3.271e-05, score: 0.02005
Inital score: 0.02005


In [128]:
# ito_args_str = "_".join([str(arg) for arg in ito_args])
# filename = f"igbm_{ito_args_str}_SlowRK_2^21.npy"
# # filename = f"igbm_y{y0_flt}_args_{ito_args_str}_SlowRK_2^21.npy"
# print(filename)
# with open(filename, "rb") as f:
#     y_true = np.load(f)
y_true = simulate(jr.key(601), strat_args, y0, net, t1, 2**8 + 7)
y_true2 = simulate(jr.key(91), strat_args, y0, net, t1, 2**7 + 1)
# y_true = jnp.sqrt(jnp.abs(jr.normal(jr.key(9), (num_samples+6, 3), dtype=dtype)))
# y_true2 = jnp.sqrt(jnp.abs(jr.normal(jr.key(10), (num_samples+6, 3), dtype=dtype)))
# wasserstein_bias = 0.
wasserstein_bias = marginal_wass2_error(y_true, y_true2)

energy_bias = energy_distance(y_true, y_true2, max_len=2**16)
del y_true2
print(
    f"Y_true shape: {y_true.shape}, mean: {jnp.mean(y_true):.4}, std: {jnp.std(y_true):.4}, wasserstein_bias: {wasserstein_bias:.4}, enn_bias: {energy_bias:.4}"
)

Y_true shape: (65536,), mean: 0.1003, std: 0.2873, wasserstein_bias: 0.02386, enn_bias: 5.999e-06


In [129]:
steps_net = []
errors_net = []
errors_mean_only = []
for n in range(1, 20):
    num_steps = n
    steps_net.append(num_steps)
    y1_net = simulate(jr.key(1), strat_args, y0, net, t1, num_steps)
    energy_err_net = energy_distance(y1_net, y_true, max_len=2**16)
    errors_net.append(energy_err_net)

    y1_m_o = simulate(jr.key(1), strat_args, y0, None, t1, num_steps)
    energy_err_m_o = energy_distance(y1_m_o, y_true, max_len=2**16)
    errors_mean_only.append(energy_err_m_o)

    print(
        f"p: {n}, energy_err_net: {energy_err_net:.5}, energy_err_m_o: {energy_err_m_o:.5}"
    )

p: 1, energy_err_net: 101.74, energy_err_m_o: 105.99
p: 2, energy_err_net: 1.4216, energy_err_m_o: 1.4434
p: 3, energy_err_net: 0.12183, energy_err_m_o: 0.12307
p: 4, energy_err_net: 0.019207, energy_err_m_o: 0.019225
p: 5, energy_err_net: 0.0041282, energy_err_m_o: 0.0041345
p: 6, energy_err_net: 0.0011118, energy_err_m_o: 0.0011144
p: 7, energy_err_net: 0.00033826, energy_err_m_o: 0.00033833
p: 8, energy_err_net: 0.00011376, energy_err_m_o: 0.00011433
p: 9, energy_err_net: 4.2253e-05, energy_err_m_o: 4.2925e-05
p: 10, energy_err_net: 1.3267e-05, energy_err_m_o: 1.3024e-05
p: 11, energy_err_net: 1.0568e-05, energy_err_m_o: 1.0866e-05
p: 12, energy_err_net: 5.4887e-06, energy_err_m_o: 5.3018e-06
p: 13, energy_err_net: 3.0385e-06, energy_err_m_o: 3.2076e-06
p: 14, energy_err_net: 2.3121e-06, energy_err_m_o: 2.4498e-06
p: 15, energy_err_net: 3.2824e-06, energy_err_m_o: 3.5947e-06
p: 16, energy_err_net: 4.1792e-06, energy_err_m_o: 3.9896e-06
p: 17, energy_err_net: 2.3887e-06, energy_err_m