In [98]:
import jax
import jax.numpy as jnp
import numpyro
from numpyro import infer
import numpy as np
from chainconsumer import ChainConsumer
import matplotlib.pylab as plt
import warnings
from numpyro.contrib.nested_sampling import NestedSampler #Need to have jaxns v 1.1.2 or earlier for this to work
from numpyro.infer.util import transform_fn

warnings.filterwarnings("ignore", category=FutureWarning)
numpyro.set_host_device_count(1)
jax.config.update('jax_platform_name', 'cpu')

In [118]:
nchains  = 20
nburn    = 0
nsamples = 20

xmin, xmax = 0, 0.0001
vmin, vmax = -20,20

In [119]:
def model_single():
    x  = numpyro.sample('x', numpyro.distributions.Uniform(xmin,xmax))
    v_0= numpyro.sample('v_0', numpyro.distributions.Uniform(vmin,vmax))
    v_1= numpyro.sample('v_1', numpyro.distributions.Uniform(vmin,vmax))
    v_2= numpyro.sample('v_2', numpyro.distributions.Uniform(vmin,vmax))

def model_listed():
    x = numpyro.sample('x', numpyro.distributions.Uniform(xmin,xmax))
    v = numpyro.sample('v', numpyro.distributions.Uniform(vmin,vmax), sample_shape=(3,))


In [120]:
ns_single = NestedSampler(model_single, constructor_kwargs={"num_live_points":50*4*(2+1)/10, "max_samples":50*4*(2+1)/10})
ns_single.run(jax.random.PRNGKey(0))
ns_results_single = ns_single.get_samples(jax.random.PRNGKey(0),nchains)

ns_listed = NestedSampler(model_listed, constructor_kwargs={"num_live_points":50*4*(2+1)/10, "max_samples":50*4*(2+1)/10})
ns_listed.run(jax.random.PRNGKey(0))
ns_results_listed = ns_listed.get_samples(jax.random.PRNGKey(0),nchains)

print("ns done")

ns done


In [121]:
NUTS_single = numpyro.infer.MCMC(
    infer.NUTS(model = model_single),
    num_warmup  =0,
    num_samples =nchains,
    num_chains  =1,
    progress_bar=False,
    chain_method = 'sequential',
)
NUTS_single.run(jax.random.PRNGKey(1))
NUTS_results_single = NUTS_single.get_samples()

NUTS_listed= numpyro.infer.MCMC(
    infer.NUTS(model = model_listed),
    num_warmup  =0,
    num_samples =nchains,
    num_chains  =1,
    progress_bar=False,
    chain_method = 'sequential',
)
NUTS_listed.run(jax.random.PRNGKey(1))
NUTS_results_listed = NUTS_listed.get_samples()

print("NUTS done")

NUTS done


In [122]:
for source in [ns_results_single,ns_results_listed,NUTS_results_single,NUTS_results_listed]:
    print("======================")
    for key in source.keys():
        print(key,"\n",source[key])

v_0 
 [  6.982317  -18.08641     6.340165   18.607283    5.7649803  18.702698
  -7.6079607 -15.00371    14.190512   16.208754  -18.08641     7.24391
  12.427044   -5.3977966 -19.286732    6.982317    9.254284  -18.22224
   5.7649803   6.340165 ]
v_1 
 [  3.8061714  -19.636059   -17.91304      4.0143156   11.584687
 -17.627974     7.864876   -12.330785    -8.887463    -5.7244253
 -19.636059     7.525935    15.541844    18.51957    -16.621342
   3.8061714    0.84996223   0.13863564  11.584687   -17.91304   ]
v_2 
 [  4.7610474  -3.0691671  -3.6883307  16.267357  -15.212798   -3.761592
 -11.000161   12.396107   10.970707  -10.237918   -3.0691671   5.407796
 -11.954021    4.4276285  13.9311695   4.7610474   0.6894016  -3.1046867
 -15.212798   -3.6883307]
x 
 [6.46858825e-05 5.15633583e-05 7.56085647e-05 2.81645771e-05
 3.34266297e-05 8.42745503e-05 9.30176247e-05 2.51805773e-06
 7.86664605e-05 1.40842430e-05 5.15633583e-05 2.02704655e-06
 7.45608804e-06 6.30281211e-05 8.23689188e-05 6.4685

Try to transform the listed samples:

In [174]:
transforms_listed = {"x": numpyro.distributions.biject_to(numpyro.distributions.Uniform(xmin, xmax).support),
              "v": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support)}

transforms_single = {"x": numpyro.distributions.biject_to(numpyro.distributions.Uniform(xmin, xmax).support),
                    "v_0": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support),
                    "v_1": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support),
                    "v_2": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support),}


starts = [ns_results_single,ns_results_listed,NUTS_results_single,NUTS_results_listed]
transforms = [transforms_single,transforms_listed,transforms_single,transforms_listed]

for source, transform in zip(starts,transforms):
    print("============================")
    transformed_start= tformed_params_listed = transform_fn(transform, source, invert=True)
    for key in transformed_start.keys():
        print(key,"\n",transformed_start[key])

v_0 
 [ 0.7288731  -2.9908767   0.6566352   3.322186    0.59331113  3.3956225
 -0.8010499  -1.9467584   1.7724562   2.2566068  -2.9908767   0.7588211
  1.4544092  -0.55349    -4.0087843   0.7288731   1.0015188  -3.0680637
  0.59331113  0.6566352 ]
v_1 
 [ 0.38531464 -4.690502   -2.8995867   0.40695658  1.3226194  -2.764003
  0.8312629  -1.4388056  -0.9553337  -0.5888909  -4.690502    0.79147696
  2.0759757   3.2588327  -2.3831525   0.38531464  0.08504745  0.01386381
  1.3226194  -2.8995867 ]
v_2 
 [ 0.48541692 -0.30936062 -0.37310183  2.2738013  -1.9954635  -0.3806913
 -1.2367858   1.4493779   1.2325677  -1.130591   -0.30936062  0.55456865
 -1.3791256   0.45021653  1.7211676   0.48541692  0.06896748 -0.31299922
 -1.9954635  -0.37310183]
x 
 [ 0.6052604   0.06255467  1.1313378  -0.9363126  -0.68895173  1.6787999
  2.5894     -3.6561792   1.3049364  -1.8083105   0.06255467 -3.8781116
 -2.5186524   0.5334234   1.541545    0.6052604  -2.014918   -4.4738665
 -0.68895173  1.1313378 ]
v 
 [ 1

In [199]:
import re
def unflatten_dict(samples):
    #------
    # Get names
    names = []
    keys = samples.keys()
    for key in keys:
        if bool(re.search(".*_[0-9]", key)): names.append(key[:-2])

    counts = [names.count(name) for name in np.unique(names)]
    names = np.unique(names)
    
    out = {key: samples[key] for key in keys if not bool(re.search(".*_[0-9]", key))}
    
    print(out)
    print(names,counts)
    
    #------
    # Assemble
    for name, count in zip(names,counts):
        N = len(samples[name+"_0"])
        print(name,count,N)
        
        to_add = {name: []}        
        for j in range(N): #For each row
            to_append = [0]*count
            for i in range(count): #Get the values from each name
                to_append[i]=samples[name+"_"+str(i)][j]
            to_append = jnp.array(to_append)
            to_add[name].append(to_append)
            
        out = out | to_add
        
        out_sorted = {}
        for key in sorted(out.keys()):
            out_sorted = out_sorted | {key:jnp.array(out[key])}
    return(out_sorted)

test = {"c_0":[0,1,2], "a_0":[0,1,2], "a_1":[0.1,1.1,2.1], "x":[0,1,2], "y":[0,1,2], "b_0":[0,1,2], "b_1":[0.1,1.1,2.1], "b_2":[0.2,1.2,2.2]}
unflatten_dict(test)

{'x': [0, 1, 2], 'y': [0, 1, 2]}
['a' 'b' 'c'] [2, 3, 1]
a 2 3
b 3 3
c 1 3


{'a': Array([[0. , 0.1],
        [1. , 1.1],
        [2. , 2.1]], dtype=float32),
 'b': Array([[0. , 0.1, 0.2],
        [1. , 1.1, 1.2],
        [2. , 2.1, 2.2]], dtype=float32),
 'c': Array([[0],
        [1],
        [2]], dtype=int32),
 'x': Array([0, 1, 2], dtype=int32),
 'y': Array([0, 1, 2], dtype=int32)}

In [198]:
test2 = {key: test[key] for key in test.keys() if len(key)<3 or key[-2]!="_"}
print(test2)

{'x': [0, 1, 2], 'y': [0, 1, 2]}


In [200]:
print(transforms_single)
tformed_starts = transform_fn(transforms_single, ns_results_single, invert=True)
print(tformed_starts)
tformed_starts = unflatten_dict(tformed_starts)
print(tformed_starts)


print("Beginning sampling...")
sampler = numpyro.infer.MCMC(
    infer.NUTS(model = model_listed, step_size = 0),
    num_warmup  =nburn,
    num_samples =nsamples,
    num_chains  =nchains,
    progress_bar=False,
    chain_method = 'sequential',
)

sampler.run(jax.random.PRNGKey(1), init_params=tformed_starts)
output=sampler.get_samples()
print("Done")

{'x': <numpyro.distributions.transforms.ComposeTransform object at 0x7f4c744e7490>, 'v_0': <numpyro.distributions.transforms.ComposeTransform object at 0x7f4c744e5f00>, 'v_1': <numpyro.distributions.transforms.ComposeTransform object at 0x7f4c744e5fc0>, 'v_2': <numpyro.distributions.transforms.ComposeTransform object at 0x7f4c744e60b0>}
{'v_0': Array([ 0.7288731 , -2.9908767 ,  0.6566352 ,  3.322186  ,  0.59331113,
        3.3956225 , -0.8010499 , -1.9467584 ,  1.7724562 ,  2.2566068 ,
       -2.9908767 ,  0.7588211 ,  1.4544092 , -0.55349   , -4.0087843 ,
        0.7288731 ,  1.0015188 , -3.0680637 ,  0.59331113,  0.6566352 ],      dtype=float32), 'v_1': Array([ 0.38531464, -4.690502  , -2.8995867 ,  0.40695658,  1.3226194 ,
       -2.764003  ,  0.8312629 , -1.4388056 , -0.9553337 , -0.5888909 ,
       -4.690502  ,  0.79147696,  2.0759757 ,  3.2588327 , -2.3831525 ,
        0.38531464,  0.08504745,  0.01386381,  1.3226194 , -2.8995867 ],      dtype=float32), 'v_2': Array([ 0.48541692,

In [201]:
print(output)
print("================")
print(tformed_starts)

{'v': Array([[  6.982317 ,   3.8061714,   4.7610474],
       [  6.982317 ,   3.8061714,   4.7610474],
       [  6.982317 ,   3.8061714,   4.7610474],
       ...,
       [  6.340165 , -17.91304  ,  -3.6883307],
       [  6.340165 , -17.91304  ,  -3.6883307],
       [  6.340165 , -17.91304  ,  -3.6883307]], dtype=float32), 'x': Array([6.46858825e-05, 6.46858825e-05, 6.46858825e-05, 6.46858825e-05,
       6.46858825e-05, 6.46858825e-05, 6.46858825e-05, 6.46858825e-05,
       6.46858825e-05, 6.46858825e-05, 6.46858825e-05, 6.46858825e-05,
       6.46858825e-05, 6.46858825e-05, 6.46858825e-05, 6.46858825e-05,
       6.46858825e-05, 6.46858825e-05, 6.46858825e-05, 6.46858825e-05,
       5.15633583e-05, 5.15633583e-05, 5.15633583e-05, 5.15633583e-05,
       5.15633583e-05, 5.15633583e-05, 5.15633583e-05, 5.15633583e-05,
       5.15633583e-05, 5.15633583e-05, 5.15633583e-05, 5.15633583e-05,
       5.15633583e-05, 5.15633583e-05, 5.15633583e-05, 5.15633583e-05,
       5.15633583e-05, 5.15633583