In [98]:
import jax
import jax.numpy as jnp
import numpyro
from numpyro import infer
import numpy as np
from chainconsumer import ChainConsumer
import matplotlib.pylab as plt
import warnings
from numpyro.contrib.nested_sampling import NestedSampler #Need to have jaxns v 1.1.2 or earlier for this to work
from numpyro.infer.util import transform_fn

warnings.filterwarnings("ignore", category=FutureWarning)
numpyro.set_host_device_count(1)
jax.config.update('jax_platform_name', 'cpu')

In [118]:
nchains  = 20
nburn    = 0
nsamples = 20

xmin, xmax = 0, 0.0001
vmin, vmax = -20,20

In [119]:
def model_single():
    x  = numpyro.sample('x', numpyro.distributions.Uniform(xmin,xmax))
    v_0= numpyro.sample('v_0', numpyro.distributions.Uniform(vmin,vmax))
    v_1= numpyro.sample('v_1', numpyro.distributions.Uniform(vmin,vmax))
    v_2= numpyro.sample('v_2', numpyro.distributions.Uniform(vmin,vmax))

def model_listed():
    x = numpyro.sample('x', numpyro.distributions.Uniform(xmin,xmax))
    v = numpyro.sample('v', numpyro.distributions.Uniform(vmin,vmax), sample_shape=(3,))


In [120]:
ns_single = NestedSampler(model_single, constructor_kwargs={"num_live_points":50*4*(2+1)/10, "max_samples":50*4*(2+1)/10})
ns_single.run(jax.random.PRNGKey(0))
ns_results_single = ns_single.get_samples(jax.random.PRNGKey(0),nchains)

ns_listed = NestedSampler(model_listed, constructor_kwargs={"num_live_points":50*4*(2+1)/10, "max_samples":50*4*(2+1)/10})
ns_listed.run(jax.random.PRNGKey(0))
ns_results_listed = ns_listed.get_samples(jax.random.PRNGKey(0),nchains)

print("ns done")

ns done


In [121]:
NUTS_single = numpyro.infer.MCMC(
    infer.NUTS(model = model_single),
    num_warmup  =0,
    num_samples =nchains,
    num_chains  =1,
    progress_bar=False,
    chain_method = 'sequential',
)
NUTS_single.run(jax.random.PRNGKey(1))
NUTS_results_single = NUTS_single.get_samples()

NUTS_listed= numpyro.infer.MCMC(
    infer.NUTS(model = model_listed),
    num_warmup  =0,
    num_samples =nchains,
    num_chains  =1,
    progress_bar=False,
    chain_method = 'sequential',
)
NUTS_listed.run(jax.random.PRNGKey(1))
NUTS_results_listed = NUTS_listed.get_samples()

print("NUTS done")

NUTS done


In [122]:
for source in [ns_results_single,ns_results_listed,NUTS_results_single,NUTS_results_listed]:
    print("======================")
    for key in source.keys():
        print(key,"\n",source[key])

v_0 
 [  6.982317  -18.08641     6.340165   18.607283    5.7649803  18.702698
  -7.6079607 -15.00371    14.190512   16.208754  -18.08641     7.24391
  12.427044   -5.3977966 -19.286732    6.982317    9.254284  -18.22224
   5.7649803   6.340165 ]
v_1 
 [  3.8061714  -19.636059   -17.91304      4.0143156   11.584687
 -17.627974     7.864876   -12.330785    -8.887463    -5.7244253
 -19.636059     7.525935    15.541844    18.51957    -16.621342
   3.8061714    0.84996223   0.13863564  11.584687   -17.91304   ]
v_2 
 [  4.7610474  -3.0691671  -3.6883307  16.267357  -15.212798   -3.761592
 -11.000161   12.396107   10.970707  -10.237918   -3.0691671   5.407796
 -11.954021    4.4276285  13.9311695   4.7610474   0.6894016  -3.1046867
 -15.212798   -3.6883307]
x 
 [6.46858825e-05 5.15633583e-05 7.56085647e-05 2.81645771e-05
 3.34266297e-05 8.42745503e-05 9.30176247e-05 2.51805773e-06
 7.86664605e-05 1.40842430e-05 5.15633583e-05 2.02704655e-06
 7.45608804e-06 6.30281211e-05 8.23689188e-05 6.4685

Try to transform the listed samples:

In [173]:
transforms_listed = {"x": numpyro.distributions.biject_to(numpyro.distributions.Uniform(xmin, xmax).support),
              "v": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support)}

transforms_single = {"x": numpyro.distributions.biject_to(numpyro.distributions.Uniform(xmin, xmax).support),
                    "v_0": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support),
                    "v_1": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support),
                    "v_2": numpyro.distributions.biject_to(numpyro.distributions.Uniform(vmin, vmax).support),}


starts = [ns_results_single,ns_results_listed,NUTS_results_single,NUTS_results_listed]
transforms = [transforms_single,transforms_listed,transforms_single,transforms_listed]

for source, transform in zip(starts,transforms):
    print("============================")
    transformed_start= tformed_params_listed = transform_fn(transform, source, invert=True)
    for key in transformed_start.keys():
        print(key,"\n",transformed_start[key])



IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.

In [124]:
print("Beginning sampling...")
sampler = numpyro.infer.MCMC(
    infer.NUTS(model = model),
    num_warmup  =nburn,
    num_samples =nsamples,
    num_chains  =nchains,
    progress_bar=False,
    chain_method = 'sequential',
)

sampler.run(jax.random.PRNGKey(1))
output=sampler.get_samples()
print("Done")

Beginning sampling...
Done


In [171]:
import re
def unflatten_dict(samples):
    #------
    # Get names
    names = []
    keys = samples.keys()
    for key in keys:
        if bool(re.search(".*_[0-9]", key)): names.append(key[:-2])

    counts = [names.count(name) for name in np.unique(names)]
    names = np.unique(names)
    
    out = {key: samples[key] for key in keys if not bool(re.search(".*_[0-9]", key))}
    
    print(out)
    print(names,counts)
    
    #------
    # Assemble
    for name, count in zip(names,counts):
        N = len(samples[name+"_0"])
        print(name,count,N)
        
        to_add = {name: []}        
        for j in range(N): #For each row
            to_append = [0]*count
            for i in range(count): #Get the values from each name
                to_append[i]=samples[name+"_"+str(i)][j]
            to_append = jnp.array(to_append)
            to_add[name].append(to_append)
            
        out = out | to_add
        
    return(out)
    

test = {"c_0":[0,1,2], "a_0":[0,1,2], "a_1":[0,1,2], "x":[0,1,2], "y":[0,1,2], "b_0":[0,1,2], "b_1":[0,1,2], "b_2":[0,1,2]}
unflatten_dict(test)

{'x': [0, 1, 2], 'y': [0, 1, 2]}
['a' 'b' 'c'] [2, 3, 1]
a 2 3
b 3 3
c 1 3


{'x': [0, 1, 2],
 'y': [0, 1, 2],
 'a': [Array([0, 0], dtype=int32),
  Array([1, 1], dtype=int32),
  Array([2, 2], dtype=int32)],
 'b': [Array([0, 0, 0], dtype=int32),
  Array([1, 1, 1], dtype=int32),
  Array([2, 2, 2], dtype=int32)],
 'c': [Array([0], dtype=int32),
  Array([1], dtype=int32),
  Array([2], dtype=int32)]}

In [164]:
test2 = {key: test[key] for key in test.keys() if len(key)<3 or key[-2]!="_"}
print(test2)

{'x': [0, 1, 2], 'y': [0, 1, 2]}
