In [1]:
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
import jax.numpy as jnp
import jax
import pylab as plt
%matplotlib inline

In [2]:
def model(params_mu, params_sig, y, c):
    """Prior over a shared correlation Cholesky factor and per-source standard normals.

    Parameters
    ----------
    params_mu : array, shape (n_src, n_param)
        Only its shape is used here: it sets the number of sources and the
        dimension of the correlation matrix.
    params_sig : array
        Not used by this function (kept so the call signature is unchanged).
    y : any
        Not used; the visible call site passes ``None``.
    c : float
        Concentration of the LKJ prior on the correlation Cholesky factor.

    Returns
    -------
    L_Rho : array, shape (n_param, n_param)
        Sampled Cholesky factor of a correlation matrix.
    params : array, shape (n_src, n_param)
        Independent standard-normal draws, one row per source.
    """
    n_src, n_param = params_mu.shape
    # Dimension follows the data rather than a hard-coded 3.
    L_Rho = numpyro.sample('L_Rho', dist.LKJCholesky(n_param, concentration=c))
    with numpyro.plate('src', n_src):
        # Batch shape (n_param, 1) combined with the plate gives draws of shape
        # (n_param, n_src); transpose so each row is one source.
        params = numpyro.sample(
            'params',
            dist.Normal(jnp.zeros((n_param, 1)), jnp.ones((n_param, 1))),
        ).T
    return L_Rho, params

def transform(params, params_sig, params_mu, L_corr=None):
    """Map standard-normal draws to correlated, scaled, shifted parameters.

    Computes ``params_mu + (params_sig[..., None] * L) @ params``, i.e. row ``i``
    of the correlation Cholesky factor ``L`` is scaled by ``params_sig[..., i]``.

    Accepts a single source (1-D inputs of length n_param, e.g. inside
    ``jax.vmap``) or a batch (inputs of shape (n_src, n_param)).

    Parameters
    ----------
    params : array, shape (..., n_param)
        Standard-normal draws.
    params_sig : array, shape (..., n_param)
        Per-parameter scales.
    params_mu : array, shape (..., n_param)
        Per-parameter offsets.
    L_corr : array, shape (n_param, n_param), optional
        Cholesky factor of the correlation matrix. If omitted, the
        notebook-global ``L_Rho`` is used (original behaviour); pass it
        explicitly to avoid depending on hidden kernel state. Under
        ``jax.vmap`` pass it positionally with ``in_axes=(0, 0, 0, None)``.

    Returns
    -------
    array, shape (..., n_param)
    """
    if L_corr is None:
        # Backward-compatible fallback to the global set by the sampling cell.
        L_corr = L_Rho
    scale_tril = params_sig[..., jnp.newaxis] * L_corr
    # Contract the last axis of scale_tril with params for any leading batch
    # shape; for 1-D params this equals jnp.dot(scale_tril, params).
    return params_mu + jnp.einsum('...ij,...j->...i', scale_tril, params)

In [3]:
# Sanity check: sample a correlation Cholesky factor and per-source standard
# normals from `model`, push them through `transform`, and compare the
# empirical correlation of the result with the correlation implied by the
# sampled factor (L_Rho @ L_Rho.T). With only 100 sources the empirical
# estimate is noisy, so expect rough rather than exact agreement.
params_mu = jnp.full((100, 3), 10.0)   # float means (was an int32 array)
params_sig = jnp.ones((100, 3))
for c in [0.05]:                       # LKJ concentration(s) to try
    # `transform` reads the global L_Rho, so it must be assigned here.
    L_Rho, params = handlers.seed(model, rng_seed=0)(params_mu, params_sig, None, c)
    new_params = jax.vmap(transform)(params, params_sig, params_mu)
    print(c, jnp.corrcoef(new_params.T))
    print('implied', L_Rho @ L_Rho.T)

0.05 [[ 1.          0.11668614  0.851299  ]
 [ 0.11668614  1.         -0.3125015 ]
 [ 0.8512989  -0.3125015   1.        ]]


In [4]:
from numpyro.infer.reparam import TransformReparam

# Draw, for 10 sources, a standard normal pushed through an affine map with
# loc = scale = 1 (both shaped (2, 10)), with the 'params' site handled by
# TransformReparam.
with handlers.seed(rng_seed=0), \
        numpyro.plate('n_src', 10), \
        numpyro.handlers.reparam(config={"params": TransformReparam()}):
    params = numpyro.sample(
        'params',
        dist.TransformedDistribution(
            dist.Normal(0.0, 1.0),
            dist.transforms.AffineTransform(jnp.ones((2, 10)), jnp.ones((2, 10))),
        ),
    )


In [5]:
# Display the (2, 10) affine-transformed draws assigned to `params` by the reparam cell.
params

DeviceArray([[ 1.5199829 ,  0.7265288 ,  1.3115745 ,  1.7302203 ,
               3.0017805 ,  1.4927709 ,  1.4523036 ,  1.884373  ,
               1.0006878 ,  0.2674694 ],
             [ 1.8331263 ,  0.47367555,  0.6601328 ,  0.5734397 ,
               1.6678789 , -0.13601577,  0.69445187, -0.9193239 ,
              -1.1068289 ,  1.579773  ]], dtype=float32)

In [6]:
# Stack a 0..9 float ramp on top of a constant-0.5 row; shape is (2, 10).
jnp.stack([jnp.arange(10.0), jnp.full(10, 0.5)]).shape

(2, 10)

In [13]:
# Draw 10 truncated-normal values at site 'trunc_z'.
# Arguments are passed by keyword because the positional order of
# dist.TruncatedNormal differs between NumPyro releases (older: low, loc,
# scale; newer: loc, scale, then low/high), so the positional call was
# ambiguous.
# NOTE(review): assumes the intent was low=0, loc=4, scale=4 (the older
# positional reading) — confirm against the NumPyro version in use.
with handlers.seed(rng_seed=0):
    with numpyro.plate('n_src', 10):
        trunc_z = numpyro.sample(
            'trunc_z',
            dist.TruncatedNormal(
                loc=jnp.full(10, 4.0),
                scale=jnp.full(10, 4.0),
                low=jnp.array([0.0]),
            ),
        )
assert trunc_z.shape == (10,)
assert bool((trunc_z >= 0.0).all()), "samples must respect the lower bound"

In [10]:
# A length-4 float array filled with 10.0.
jnp.full((4,), 10.0)

DeviceArray([10., 10., 10., 10.], dtype=float32)

(10,)