In [1]:
import numpy as np
import pylab as pl
from jax import grad
from jax import numpy as jnp
from jax import random
from scipy import stats

import line_profiler
%load_ext line_profiler

In [67]:
cov = np.random.normal(0, 1, (3, 3))
cov = cov@cov.T
x = [1]*3

In [50]:
%lprun -f compute_error compute_error(f,x,cov)

Timer unit: 1e-06 s

Total time: 1.82342 s
File: <ipython-input-49-1c9a3e85c56e>
Function: compute_error at line 7

Line #      Hits         Time  Per Hit   % Time  Line Contents
     7                                           def compute_error(f, x, pcov, n_samples=10000000):
     8         1       1442.0   1442.0      0.1      x, pcov = jnp.array(x, dtype=jnp.float64), jnp.array(pcov, dtype=jnp.float64)
     9                                           
    10         1          4.0      4.0      0.0      seed = int(time.time())
    11         1       2850.0   2850.0      0.2      key = random.PRNGKey(seed)
    12                                               
    13         1        115.0    115.0      0.0      samples = random.multivariate_normal(key, mean=x, cov=pcov, shape=(int(n_samples),))
    14         1    1803234.0 1803234.0     98.9      y = f(samples.T)
    15         1      11236.0  11236.0      0.6      mean = f(x)
    16                                           
    1

In [20]:
# Monte carlo based error propagation
def f(x):
    # x, y, z = x.T
    # return np.array([x**2, np.sin(x*y), np.cos(x*y/z)])
    return np.array([x[0]**2, np.sin(x[0]*x[0]), np.cos(x[0]*x[1]*x[2])])

def compute_error(f, x, pcov, n_samples=1e6):
    x, pcov = jnp.array(x, dtype=jnp.float64), jnp.array(pcov, dtype=jnp.float64)

    seed = int(time.time())
    key = random.PRNGKey(seed)
    
    samples = random.multivariate_normal(key, mean=x, cov=pcov, shape=(int(n_samples),))
    y = f(samples.T)
    mean = f(x)

    return jnp.cov(y, ddof=1)

compute_error(f, x, cov)

  lax._check_user_dtype_supported(dtype, "array")


NameError: name 'time' is not defined

In [None]:
# Creates a multivariate uniform distrubution, mainly for fun
def multivariate_uniform(mean, cov, n_samples):
    mean = np.array(mean)
    xs = np.empty((n_samples, cov.shape[0]))
    chol = np.linalg.cholesky(cov)
    for i in range(cov.shape[0]):
        xs[:, i] = np.random.uniform(-np.sqrt(12)/2, np.sqrt(12)/2, size=n_samples)
    
    return xs@chol.T + mean

cov = np.random.normal(0, 1, (2, 2))
cov = cov@cov.T
mean = [0]*2
x, y = multivariate_uniform(mean, cov, 1000).T
xn, yn = np.random.multivariate_normal(mean, cov, 1000).T

pl.plot(x, y, '.')
pl.plot(xn, yn, '.')

In [None]:
# Gradient based tests for error propagation


def h(x):
    return jnp.sin(x)

def derivate(f, x, h=1e-4):
    return (f(x+h) - f(x-h)) / (2*h)
    
g = grad(h)

g(0.)

g(np.pi*0.5)

In [7]:
from jax.config import config
config.update("jax_enable_x64", True)

import time
import jax

jax.numpy.array([12,2,3,45], dtype=float)

DeviceArray([12.,  2.,  3., 45.], dtype=float64)

In [25]:
# Faster number generation using jax


key = random.PRNGKey(0)
mean = jnp.array([0]*3)
cov = jnp.array(cov)

%time x=np.random.multivariate_normal(mean, cov, size=1000000)
%time x=random.multivariate_normal(key, mean=mean, cov=cov, shape=(int(1e8),))

CPU times: user 134 ms, sys: 111 µs, total: 134 ms
Wall time: 108 ms
CPU times: user 38.4 ms, sys: 11.6 ms, total: 50 ms
Wall time: 14.4 ms


In [29]:
cov

DeviceArray([[2.6600852, 0.7940799, 2.9283135],
             [0.7940799, 2.3275447, 2.533829 ],
             [2.9283135, 2.533829 , 5.1483064]], dtype=float32)

In [28]:
jnp.cov(x.T)

DeviceArray([[2.660535 , 0.7941171, 2.9286234],
             [0.7941171, 2.3282158, 2.5344994],
             [2.9286234, 2.5344994, 5.149097 ]], dtype=float32)