In [22]:
import jax.numpy as np
from jax import jit
from jax import random
from jax.scipy import stats as sps

import numpy as onp
from scipy import stats as osps

# Original NumPy implementation

In [23]:
def normal2marginal(x, margin):
    """Map standard-normal values onto a target marginal distribution.

    Applies the probability integral transform: ``x`` is pushed through the
    standard-normal CDF to obtain probabilities, which are then passed to the
    margin's ``ppf`` (inverse CDF).

    Parameters
    ----------
    x : array_like
        Values assumed to be standard-normal draws.
    margin : tuple
        ``(dist, *params)``: ``dist`` must expose ``ppf``; ``params`` are
        forwarded positionally to ``dist.ppf`` after the probabilities.

    Returns
    -------
    ndarray
        ``dist.ppf(norm.cdf(x), *params)``.
    """
    dist = margin[0]
    params = margin[1:]
    return dist.ppf(osps.norm.cdf(x), *params)


def rvec(n, rho, margins):
    """Draw ``n`` samples from a Gaussian copula with the given margins.

    Parameters
    ----------
    n : int
        Number of samples (rows of the result).
    rho : ndarray, shape (d, d)
        Covariance of the latent multivariate normal. Presumably a
        correlation matrix (unit diagonal) — TODO confirm.
    margins : sequence of tuple
        One ``(dist, *params)`` entry per column, as accepted by
        ``normal2marginal``; entry ``i`` is used for column ``i``.

    Returns
    -------
    ndarray, shape (n, d)
        Column ``i`` is the latent normal column ``i`` mapped through
        ``margins[i]``.
    """
    d = rho.shape[0]
    z = onp.random.multivariate_normal(mean=onp.zeros(d), cov=rho, size=n)
    u = onp.empty_like(z)
    for i, column in enumerate(z.T):
        # Use the `margins` argument -- not a notebook global -- so the
        # function works on a fresh kernel and honours its caller's margins.
        u[:, i] = normal2marginal(column, margins[i])
    return u

In [34]:
# Each margin is (distribution, *params); the params are passed positionally
# to `dist.ppf` after the probabilities. For scipy's gamma that means shape
# `a` followed by `loc`.
my_margins = [(osps.gamma, 8, 2),
              (osps.gamma, 12, 3),
              (osps.gamma, 4, 4),
              (osps.gamma, 16, 1)]
# rvec draws from NumPy's global RNG; seed it so the sample is reproducible.
onp.random.seed(0)
x = rvec(10, onp.eye(4), my_margins)

In [35]:
# Time the NumPy sampler: 2000 draws over the four margins with identity correlation.
%timeit rvec(2000, onp.eye(4), my_margins)

9.52 ms ± 229 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
# Show the NumPy sample: 10 draws (rows) x 4 margins (columns).
x

array([[ 2., 14.,  9., 49.],
       [ 1., 18., 25., 48.],
       [ 3., 16.,  3., 29.],
       [ 5., 18., 21., 37.],
       [ 8., 24.,  8., 34.],
       [11., 20.,  6., 41.],
       [ 3., 16.,  7., 48.],
       [ 0., 19.,  8., 35.],
       [ 3., 12., 11., 40.],
       [ 4., 25.,  2., 23.]])

# JAX (`jax.numpy`) implementation

In [40]:
def normal2marginal2(x, margin):
    """JAX counterpart of ``normal2marginal``: map standard-normal values onto a margin.

    The standard-normal CDF is evaluated with ``jax.scipy.stats.norm`` and the
    resulting probabilities are passed to ``margin[0].ppf``.

    Deliberately not wrapped in ``jax.jit``. ``margin`` holds a distribution
    object, which is not a valid JAX argument type. A SciPy ``ppf`` also
    cannot operate on traced values.

    Parameters
    ----------
    x : array_like
        Values assumed to be standard-normal draws.
    margin : tuple
        ``(dist, *params)``: ``dist`` must expose ``ppf``; ``params`` are
        forwarded positionally to ``dist.ppf`` after the probabilities.
    """
    dist = margin[0]
    params = margin[1:]
    return dist.ppf(sps.norm.cdf(x), *params)


def rvec2(key, n, rho, margins):
    """Draw ``n`` Gaussian-copula samples using JAX's random number generator.

    Parameters
    ----------
    key : jax PRNG key
        Key for ``jax.random.multivariate_normal``.
    n : int
        Number of samples.
    rho : array, shape (d, d)
        Covariance of the latent multivariate normal.
    margins : sequence of tuple
        One ``(dist, *params)`` entry per dimension, as accepted by
        ``normal2marginal2``; entry ``i`` is used for dimension ``i``.

    Returns
    -------
    jax array, shape (d, n)
        Row ``i`` is latent dimension ``i`` mapped through ``margins[i]``.
        NOTE(review): this is the transpose of ``rvec``'s ``(n, d)`` layout;
        kept as-is to preserve the existing return shape.
    """
    d = rho.shape[0]
    z = random.multivariate_normal(key,
                                   mean=np.zeros(d),
                                   cov=rho,
                                   shape=(n,),
                                   dtype=np.float32)
    # Use the JAX helper (the NumPy `normal2marginal` was called here before).
    return np.stack([normal2marginal2(column, margins[i])
                     for i, column in enumerate(z.T)])

In [41]:
key = random.PRNGKey(0)
# Margins use scipy.stats distributions (imported as `osps`), whose `ppf` the
# copula transform calls. Each entry is (distribution, *params forwarded to ppf).
my_margins2 = [(osps.gamma, 8, 2),
               (osps.gamma, 12, 3),
               (osps.gamma, 4, 4),
               (osps.gamma, 16, 1)]
x = rvec2(key, 10, np.eye(4), my_margins2)

In [42]:
# Show the JAX sample. rvec2 stacks one row per margin, so this is
# 4 margins (rows) x 10 draws (columns) -- transposed relative to rvec.
x

DeviceArray([[ 9.833338 , 10.337629 , 11.287834 ,  9.608943 ,  8.870486 ,
              10.056011 ,  7.523633 ,  8.918815 , 16.096806 ,  6.259414 ],
             [ 9.586226 , 17.54643  , 10.3423395, 21.2349   , 12.142886 ,
              15.855058 , 13.0485325, 15.689952 , 23.330044 , 12.770042 ],
             [10.530828 ,  5.928883 ,  7.152106 ,  6.982293 ,  6.4090323,
               6.9920473,  6.5529485,  9.491214 , 10.033285 ,  8.352106 ],
             [18.552755 , 19.784525 , 12.112805 , 27.52418  , 17.31361  ,
              17.112753 , 17.965212 , 21.758514 , 14.025983 , 16.05636  ]],            dtype=float32)

In [43]:
# Time the JAX sampler. block_until_ready() waits for JAX's asynchronous
# dispatch to finish, so the timing includes the actual computation.
%timeit rvec2(key, 2000, np.eye(4), my_margins2).block_until_ready()

9.82 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [44]:
# Attempt to compile rvec2 with jax.jit.
# NOTE(review): calling this fails. `margins` contains scipy distribution
# objects, which jit rejects as "not a valid JAX type".
rvec2_jitted = jit(rvec2)

In [46]:
# Expected failure: jax.jit rejects the scipy distribution objects held in
# `my_margins2` as "not a valid JAX type". Catch the TypeError and show it, so
# Restart & Run All continues past this demonstration.
try:
    rvec2_jitted(key, 2000, np.eye(4), my_margins2).block_until_ready()
except TypeError as err:
    print(f"jit(rvec2) failed as expected: {err}")

TypeError: Argument '<scipy.stats._continuous_distns.gamma_gen object at 0x7f4aea669e50>' of type <class 'scipy.stats._continuous_distns.gamma_gen'> is not a valid JAX type