Example notebook showing how to use the nnest nested sampler on its built-in test likelihoods.

In [1]:
import os
import sys
import argparse
import torch
from getdist import plots, MCSamples
import getdist
import numpy as np
from scipy.stats import multivariate_normal

In [2]:
# Put the repository root (two directories above the notebook's working
# directory) at the front of sys.path, presumably so that the local `nnest`
# package takes priority over any installed copy. Any existing occurrences are
# removed first, so re-running this cell does not accumulate duplicate entries.
path = os.path.realpath(os.path.join(os.getcwd(), '../..'))
sys.path[:] = [entry for entry in sys.path if entry != path]
sys.path.insert(0, path)

In [3]:
from nnest import NestedSampler
from nnest.likelihoods import *

In [4]:
# Render matplotlib figures inline in the notebook, using the high-resolution
# 'retina' format on HiDPI displays.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [5]:
# Likelihood + prior.
# Choose a test problem by setting PROBLEM to one of the keys below. Each entry
# is a zero-argument factory returning (likelihood, transform). `transform` is
# handed to the sampler and rescales its coordinates into the prior range for
# that problem (presumably a box whose half-width is the scale factor; confirm
# against nnest's convention). Factories are lazy, so only the selected
# likelihood is constructed.
PROBLEMS = {
    'himmelblau': lambda: (Himmelblau(2), lambda x: 5 * x),
    'rosenbrock': lambda: (Rosenbrock(10), lambda x: 5 * x),
    'gaussian': lambda: (Gaussian(2, 0.9), lambda x: 3 * x),
    'eggbox': lambda: (Eggbox(2), lambda x: 5 * np.pi * x),
    'gaussian_shell': lambda: (GaussianShell(2), lambda x: 5 * x),
    'gaussian_mix': lambda: (GaussianMix(2), lambda x: 5 * x),
}
PROBLEM = 'rosenbrock'

like, transform = PROBLEMS[PROBLEM]()

In [6]:
# Seed the global NumPy and PyTorch generators so that flow initialisation and
# sampling can be repeated under Restart & Run All. This assumes nnest draws
# from these global generators (TODO: confirm).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the sampler. The printed model summary shows that `hidden_dim` sets the
# width of the hidden layers in the flow's MLPs and that flow='spline' selects
# a spline-based flow. `num_blocks` presumably sets the number of flow blocks
# (confirm in nnest). Constructing the sampler also creates a new run
# directory for logs.
sampler = NestedSampler(
    like.x_dim,
    like,
    transform=transform,
    num_live_points=2000,
    hidden_dim=16,
    num_blocks=3,
    flow='spline',
)

Creating directory for new run logs/test/run50
[nnest.trainer] [INFO] SingleSpeedSpline(
  (flow): NormalizingFlow(
    (flows): ModuleList(
      (0): ActNorm()
      (1): Invertible1x1Conv()
      (2): NSF_CL(
        (f1): MLP(
          (net): Sequential(
            (0): Linear(in_features=5, out_features=16, bias=True)
            (1): LeakyReLU(negative_slope=0.2)
            (2): Linear(in_features=16, out_features=16, bias=True)
            (3): LeakyReLU(negative_slope=0.2)
            (4): Linear(in_features=16, out_features=16, bias=True)
            (5): LeakyReLU(negative_slope=0.2)
            (6): Linear(in_features=16, out_features=115, bias=True)
          )
        )
        (f2): MLP(
          (net): Sequential(
            (0): Linear(in_features=5, out_features=16, bias=True)
            (1): LeakyReLU(negative_slope=0.2)
            (2): Linear(in_features=16, out_features=16, bias=True)
            (3): LeakyReLU(negative_slope=0.2)
            (4): Linear(in_f

In [None]:
# Run nested sampling. New live points are first drawn by rejection from the
# prior. The sampler then moves to proposals from the trained normalising flow
# ('density_flow'), presumably once the remaining prior volume falls below
# `volume_switch` (confirm in nnest).
sampler.run(
    strategy=['rejection_prior', 'density_flow'],
    volume_switch=0.1,
)

[nnest.sampler] [INFO] MCMC steps [50]
[nnest.sampler] [INFO] Initial scale [0.6325]
[nnest.sampler] [INFO] Volume switch [0.1000]
[nnest.sampler] [INFO] Step [0] loglstar [-3.2505e+05] max logl [-6.0433e+03] logz [-3.2506e+05] vol [1.00000e+00] ncalls [2001] mean calls [0.0000]
[nnest.sampler] [INFO] Step [400] loglstar [-1.6843e+05] max logl [-6.0433e+03] logz [-1.6844e+05] vol [8.18731e-01] ncalls [2454] mean calls [1.2000]
[nnest.sampler] [INFO] Step [800] loglstar [-1.4042e+05] max logl [-6.0433e+03] logz [-1.4043e+05] vol [6.70320e-01] ncalls [2974] mean calls [1.5000]
[nnest.sampler] [INFO] Step [1200] loglstar [-1.2028e+05] max logl [-6.0433e+03] logz [-1.2029e+05] vol [5.48812e-01] ncalls [3609] mean calls [1.9000]
[nnest.sampler] [INFO] Step [1600] loglstar [-1.0717e+05] max logl [-6.0433e+03] logz [-1.0718e+05] vol [4.49329e-01] ncalls [4431] mean calls [2.5000]
[nnest.sampler] [INFO] Step [2000] loglstar [-9.6120e+04] max logl [-6.0433e+03] logz [-9.6128e+04] vol [3.67879e-

[nnest.trainer] [INFO] Epoch [50] train loss [-0.0221] validation loss [-0.0118]
[nnest.trainer] [INFO] Epoch [74] ran out of patience
[nnest.trainer] [INFO] Best epoch [24] validation loss [-0.0119]
[nnest.sampler] [INFO] Step [14000] loglstar [-2.8812e+03] max logl [-2.1628e+02] logz [-2.8953e+03] vol [9.11882e-04] ncalls [38130] mean calls [3.8000]
[nnest.sampler] [INFO] Step [14400] loglstar [-2.5736e+03] max logl [-1.9390e+02] logz [-2.5875e+03] vol [7.46586e-04] ncalls [38763] mean calls [1.8000]
[nnest.sampler] [INFO] Step [14800] loglstar [-2.3106e+03] max logl [-1.9390e+02] logz [-2.3252e+03] vol [6.11253e-04] ncalls [39523] mean calls [2.0000]
[nnest.sampler] [INFO] Step [15200] loglstar [-2.0970e+03] max logl [-1.9390e+02] logz [-2.1106e+03] vol [5.00451e-04] ncalls [40416] mean calls [1.9000]
[nnest.sampler] [INFO] Step [15600] loglstar [-1.8795e+03] max logl [-1.9390e+02] logz [-1.8937e+03] vol [4.09735e-04] ncalls [41450] mean calls [3.4000]
[nnest.trainer] [INFO] Number 

[nnest.sampler] [INFO] Step [26800] loglstar [-1.9011e+02] max logl [-3.0757e+01] logz [-2.0774e+02] vol [1.51514e-06] ncalls [65886] mean calls [1.4000]
[nnest.sampler] [INFO] Step [27200] loglstar [-1.7756e+02] max logl [-3.0757e+01] logz [-1.9534e+02] vol [1.24050e-06] ncalls [66757] mean calls [3.1000]
[nnest.sampler] [INFO] Step [27600] loglstar [-1.6619e+02] max logl [-3.0757e+01] logz [-1.8396e+02] vol [1.01563e-06] ncalls [67796] mean calls [2.6000]
[nnest.trainer] [INFO] Number of training samples [2000]
[nnest.trainer] [INFO] Training jitter [0.0124]
[nnest.trainer] [INFO] Epoch [1] train loss [-0.1168] validation loss [-0.0597]
[nnest.trainer] [INFO] Epoch [50] train loss [-0.1215] validation loss [-0.0607]
[nnest.trainer] [INFO] Epoch [63] ran out of patience
[nnest.trainer] [INFO] Best epoch [13] validation loss [-0.0608]
[nnest.sampler] [INFO] Step [28000] loglstar [-1.5536e+02] max logl [-3.0757e+01] logz [-1.7322e+02] vol [8.31529e-07] ncalls [68916] mean calls [5.4000]

[nnest.trainer] [INFO] Training jitter [0.0053]
[nnest.trainer] [INFO] Epoch [1] train loss [-0.2031] validation loss [-0.1036]
[nnest.trainer] [INFO] Epoch [50] train loss [-0.2068] validation loss [-0.1041]
[nnest.trainer] [INFO] Epoch [58] ran out of patience
[nnest.trainer] [INFO] Best epoch [8] validation loss [-0.1042]
[nnest.sampler] [INFO] Step [40000] loglstar [-3.0773e+01] max logl [-9.1441e+00] logz [-5.2480e+01] vol [2.06115e-09] ncalls [94964] mean calls [3.9000]
[nnest.sampler] [INFO] Step [40400] loglstar [-2.9390e+01] max logl [-9.1441e+00] logz [-5.1360e+01] vol [1.68753e-09] ncalls [95701] mean calls [2.0000]
[nnest.sampler] [INFO] Step [40800] loglstar [-2.8203e+01] max logl [-9.1441e+00] logz [-5.0257e+01] vol [1.38163e-09] ncalls [96556] mean calls [2.0000]
[nnest.sampler] [INFO] Step [41200] loglstar [-2.7166e+01] max logl [-9.1441e+00] logz [-4.9286e+01] vol [1.13119e-09] ncalls [97441] mean calls [2.8000]
[nnest.sampler] [INFO] Step [41600] loglstar [-2.6115e+01

In [None]:
# Final log-evidence (log Z) estimate accumulated by the sampler. It is
# labelled so the output stands on its own when the notebook is skimmed.
print('log Z =', sampler.logz)

In [None]:
# Wrap the weighted posterior samples in a getdist MCSamples object.
# getdist's `loglikes` argument expects -log(likelihood) per sample, which is
# why the sampler's log-likelihoods are negated.
mc = MCSamples(
    samples=sampler.samples,
    weights=sampler.weights,
    loglikes=-sampler.loglikes,
)

In [None]:
# Summary statistics of the weighted posterior samples.
print('Effective number of samples:', mc.getEffectiveSamples())
print(mc.getMargeStats())
# Use getLikeStats() rather than the bare `likeStats` attribute. The attribute
# may not be populated yet, while the accessor computes the statistics on
# demand.
print(mc.getLikeStats())

In [None]:
# Triangle plot of the 1D and 2D marginalised posteriors for every parameter,
# drawn with filled contours in a figure 8 inches wide.
g = plots.getSubplotPlotter(width_inch=8)
g.triangle_plot(
    mc,
    filled=True,
)