In [1]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import equinox as eqx

import optax

import timeit

from scipy.stats import qmc

In [75]:
model_key = jax.random.PRNGKey(1)

input_dim = 4
output_dim = 2

layers_size = 4
n_layers = 4

model = eqx.nn.MLP(in_size=input_dim, out_size=output_dim, width_size=layers_size, depth=n_layers,
                    activation=jax.nn.relu, use_final_bias=False, key=model_key)

In [76]:
model(jnp.array([0,1,1,2]))

Array([-0.045782  , -0.10597677], dtype=float32)

In [77]:
def func(x, a): #That would be f(\rho)
    return a*x + a**2 # Random model for test

In [78]:
def sample_fn(mu, sigma=0.1, dt=0.01):
    key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, mu.size)
    return mu*dt + jax.vmap(lambda k, m: jax.random.normal(k)*jnp.sqrt(dt)*sigma)(keys, mu)

In [79]:
arr = jnp.array([1,1,1])

In [80]:
sample(func(jnp.array([1]), 0.2))

TypeError: 'list' object is not callable

In [96]:
def loss(input, output):
    mu_0, sigma_0, x, y = input
    mu_1, sigma_1 = output

    prior_entropy = jnp.log(sigma_0**2)     #jnp.linalg.slogdt()
    posterior_entropy = jnp.log(sigma_1**2)

    log_gain = jnp.log(prior_entropy/posterior_entropy)

    log_likeli = -0.5*((y-func(x,mu_1))/sigma_1)**2 - jnp.log(sigma_1*jnp.sqrt(2*jnp.pi))

    return log_gain + log_likeli

In [97]:
vec_a = jnp.linspace(1,2,10)

a = vec_a[0]
func(jnp.array([1]), 1.)

Array([2.], dtype=float32, weak_type=True)

In [98]:
vec_a = jnp.linspace(1,2,10)
vec_x = jnp.linspace(0,1,10)

mu_0 = 1.5
sigma_0 = 1

vec_samples = []

for a in vec_a:

    for x in vec_x:

        for i in range(10):

            sample = [mu_0, sigma_0, x, sample_fn(func(jnp.array([x]), a))[0]]
            vec_samples += jnp.array(sample),


In [99]:
vec_samples

[Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.        , -0.00400884], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.11111111, -0.00289773], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.11111111, -0.00289773], dtype=float32),
 Array([ 1.5       ,  1.        ,  0.11111111, -0.00289773], dtype=float32),

In [100]:
outputs = jax.vmap(model)(jnp.array(vec_samples))

In [101]:
outputs

Array([[-0.04236208, -0.10637414],
       [-0.04236208, -0.10637414],
       [-0.04236208, -0.10637414],
       ...,
       [-0.05058081, -0.08290085],
       [-0.05058081, -0.08290085],
       [-0.05058081, -0.08290085]], dtype=float32)

In [102]:
loss([1.5, 1, 0, 0.00400884], [-0.48346964, -0.33878967])

Array(nan, dtype=float32, weak_type=True)

In [103]:
jax.vmap(loss)(jnp.array(vec_samples), outputs)

Array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na