Skip to content

Commit

Permalink
Add config for mh, unify mcmc result, add bayes net example with mh
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Aug 4, 2020
1 parent 5e59e69 commit c96953f
Show file tree
Hide file tree
Showing 19 changed files with 240 additions and 122 deletions.
7 changes: 7 additions & 0 deletions benchmark/CMakeLists.txt
@@ -1,3 +1,10 @@
add_executable(mh_bayes_net ${CMAKE_CURRENT_SOURCE_DIR}/mh_bayes_net.cpp)
target_include_directories(mh_bayes_net PRIVATE ${GBENCH_DIR}/include ${AUTOPPL_INCLUDE_DIRS})
target_link_libraries(mh_bayes_net benchmark benchmark_main ${AUTOPPL_LIBS})
if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
target_link_libraries(mh_bayes_net pthread)
endif()

add_executable(normal_two_prior_distribution ${CMAKE_CURRENT_SOURCE_DIR}/normal_two_prior_distribution.cpp)
target_include_directories(normal_two_prior_distribution PRIVATE ${GBENCH_DIR}/include ${AUTOPPL_INCLUDE_DIRS})
target_link_libraries(normal_two_prior_distribution benchmark benchmark_main ${AUTOPPL_LIBS})
Expand Down
60 changes: 60 additions & 0 deletions benchmark/mh_bayes_net.cpp
@@ -0,0 +1,60 @@
#include <random>
#include <benchmark/benchmark.h>
#include "benchmark_utils.hpp"
#include <autoppl/expression/expr_builder.hpp>
#include <autoppl/expression/variable/data.hpp>
#include <autoppl/expression/variable/param.hpp>
#include <autoppl/mcmc/mh/mh.hpp>

namespace ppl {

static void BM_MHBayesNet(benchmark::State& state) {
size_t n_samples = state.range(0);
constexpr size_t n_data = 1000;

std::bernoulli_distribution b(0.341);
std::mt19937 gen(0);
ppl::Data<util::disc_param_t, ppl::vec> y(n_data);

ppl::Param<double> p1, p2, m1, m2, M1, M2;
ppl::Param<int> w;
auto model = (
m1 |= ppl::uniform(0., 1.),
m2 |= ppl::uniform(0., 1.),
M1 |= ppl::uniform(0., 1.),
M2 |= ppl::uniform(0., 1.),
p1 |= ppl::uniform(m1, M1),
p2 |= ppl::uniform(m2, M2),
w |= ppl::bernoulli(0.3 * p1),
y |= ppl::bernoulli(w * p1 + (1-w) * p2)
);

for (size_t i = 0; i < n_data; ++i) {
y.get()(i) = b(gen);
}

ppl::MCMCResult res;

ppl::MHConfig config;
config.warmup = n_samples;
config.samples = n_samples;

for (auto _ : state) {
res = ppl::mh(model, config);
}

ppl::summary("m1, m2, M1, M2, p1, p2",
res.cont_samples,
res.warmup_time,
res.sampling_time);
ppl::summary("w",
res.disc_samples.cast<double>(),
res.warmup_time,
res.sampling_time);
}

BENCHMARK(BM_MHBayesNet)->Arg(100000)
->Arg(200000)
->Arg(300000)
;
} // namespace ppl
4 changes: 2 additions & 2 deletions benchmark/normal_two_prior_distribution.cpp
Expand Up @@ -28,10 +28,10 @@ static void BM_NormalTwoPrior(benchmark::State& state) {
}

ppl::NUTSConfig<> config;
config.n_samples = n_samples;
config.samples = n_samples;
config.warmup = n_samples;

ppl::NUTSResult res(0,0);
ppl::MCMCResult res;

for (auto _ : state) {
res = ppl::nuts(model, config);
Expand Down
4 changes: 2 additions & 2 deletions benchmark/regression_autoppl.cpp
Expand Up @@ -59,9 +59,9 @@ static void BM_Regression(benchmark::State& state) {
// perform NUTS sampling
NUTSConfig<> config;
config.warmup = num_samples;
config.n_samples = num_samples;
config.samples = num_samples;

NUTSResult res(0,0);
MCMCResult res;

for (auto _ : state) {
res = ppl::nuts(model, config);
Expand Down
2 changes: 1 addition & 1 deletion benchmark/regression_autoppl_2.cpp
Expand Up @@ -41,7 +41,7 @@ static void BM_Regression(benchmark::State& state) {
w |= ppl::normal(0., 5.),
y |= ppl::normal(ppl::dot(X, w) + b, 1.0));

NUTSResult res(0,0);
MCMCResult res;

for (auto _ : state) {
res = ppl::nuts(model);
Expand Down
2 changes: 1 addition & 1 deletion include/autoppl/autoppl.hpp
@@ -1,6 +1,6 @@
#pragma once
#include "util/traits/traits.hpp"
#include "expression/expr_builder.hpp"
#include "mcmc/mh.hpp"
#include "mcmc/mh/mh.hpp"
#include "mcmc/hmc/nuts/nuts.hpp"
#include "math/ess.hpp"
17 changes: 2 additions & 15 deletions include/autoppl/math/welford.hpp
Expand Up @@ -25,24 +25,11 @@ struct WelfordVar
{
++n_;
auto delta = x - mean_;
mean_ += delta/static_cast<double>(n_);
mean_ += (1./static_cast<double>(n_)) * delta;
m2n_ += (delta.array() * (x - mean_).array()).matrix();
}

/**
* Populate v with sample variance vector.
* If sample size is not greater than 1, v is zeroed out.
*/
template <class MatType>
void get_variance(MatType& v)
{
if (n_ > 1) {
v = m2n_/static_cast<double>(n_ - 1);
} else {
v.setZero();
}
}

const auto& get_variance() const { return m2n_; }
size_t get_n_samples() const { return n_; }

/**
Expand Down
14 changes: 14 additions & 0 deletions include/autoppl/mcmc/config_base.hpp
@@ -0,0 +1,14 @@
#pragma once
#include <cstddef>
#include <autoppl/mcmc/sampler_tools.hpp>

namespace ppl {

struct ConfigBase
{
size_t warmup = 1000;
size_t samples = 1000;
size_t seed = mcmc::random_seed();
};

} // namespace ppl
1 change: 1 addition & 0 deletions include/autoppl/mcmc/hmc/momentum_handler.hpp
Expand Up @@ -93,6 +93,7 @@ struct MomentumHandler<diag_var>
{ return (m_inverse_.array() * rho.array()).matrix(); }

variance_t& get_m_inverse() { return m_inverse_; }
const variance_t& get_m_inverse() const { return m_inverse_; }

private:
std::normal_distribution<> dist;
Expand Down
20 changes: 2 additions & 18 deletions include/autoppl/mcmc/hmc/nuts/configs.hpp
Expand Up @@ -3,21 +3,19 @@
#include <autoppl/mcmc/hmc/step_adapter.hpp>
#include <autoppl/mcmc/hmc/momentum_handler.hpp>
#include <autoppl/mcmc/sampler_tools.hpp>
#include <autoppl/mcmc/config_base.hpp>

namespace ppl {

/**
* User configuration for NUTS algorithm.
*/
template <class VarAdapterPolicy=diag_var>
struct NUTSConfig
struct NUTSConfig: ConfigBase
{
using var_adapter_policy_t = VarAdapterPolicy;

// configuration for sampling
size_t warmup = 1000;
size_t n_samples = 1000;
size_t seed = mcmc::random_seed();
size_t max_depth = 10;

// configuration for step-size adaptation
Expand All @@ -37,18 +35,4 @@ struct nuts_config_traits
typename NUTSConfigType::var_adapter_policy_t;
};

/**
* Output result for NUTS
*/
struct NUTSResult
{
NUTSResult(size_t n_samples, size_t n_params)
: cont_samples(n_samples, n_params)
{}

Eigen::MatrixXd cont_samples;
double warmup_time;
double sampling_time;
};

} // namespace ppl
43 changes: 27 additions & 16 deletions include/autoppl/mcmc/hmc/nuts/nuts.hpp
@@ -1,6 +1,5 @@
#pragma once
#include <type_traits>
#include <iostream>
#include <Eigen/Dense>
#include <fastad_bits/reverse/core/var_view.hpp>
#include <fastad_bits/reverse/core/eval.hpp>
Expand All @@ -14,6 +13,7 @@
#include <autoppl/mcmc/hmc/leapfrog.hpp>
#include <autoppl/mcmc/hmc/hamiltonian.hpp>
#include <autoppl/mcmc/hmc/nuts/configs.hpp>
#include <autoppl/mcmc/result.hpp>

namespace ppl {
namespace mcmc {
Expand Down Expand Up @@ -41,10 +41,20 @@ bool check_entropy(const MatType1& rho,
* Building binary tree for sampling candidates.
* Helper function to obtain the forward/backward-most position and momentum.
* Accept/reject policy is based on UniformDistType parameter and GenType
* By default, uses standard library uniform_real_distribution on (0.,1.).
* By default, GenType usees standard library mt19937.
*
* Note that the caller MUST have input theta_adj already pre-computed.
* Note that the caller, i.e. nuts(), MUST have theta_adj already pre-computed
* (theta_adj is a member of input and input will be an instance of TreeInput).
*
* @param n_params number of (continuous) parameters
* @param input TreeInput-like input object
* @param depth current depth of building tree
* @param unif_sampler an object like std::uniform_distribution(0,1)
* used for metropolis acceptance
* @param gen rng device
* @param momentum_handler MomentumHandler-like object to compute
* kinetic energy and momentum
* @param tree_cache pointer to cache memory that will be used by build_tree.
* The array of doubles must be of size n_params * 7 * max_depth.
*/
template <class InputType
, class UniformDistType
Expand Down Expand Up @@ -210,7 +220,13 @@ TreeOutput build_tree(size_t n_params,

/**
* Finds a reasonable epsilon for NUTS algorithm.
* @param ad_expr AD expression bound to theta and theta_adj
*
* @param eps initial epsilon (see Gelman's paper)
* @param ad_expr AD expression bound to theta and theta_adj
* @param theta vector of theta values
* @param theta_adj vector of theta adjoints
* @param gen rng device
* @param momentum_handler MomentumHandler-like object
*/
template <class ADExprType
, class MatType
Expand All @@ -231,7 +247,6 @@ double find_reasonable_epsilon(double eps,
size_t n_params = theta.rows(); // theta is expected to be vector-like

Eigen::MatrixXd mat(n_params, 3);
mat.setZero();
Eigen::Map<Eigen::VectorXd> r(mat.col(0).data(), n_params);
Eigen::Map<Eigen::VectorXd> theta_orig(mat.col(1).data(), n_params);
Eigen::Map<Eigen::VectorXd> theta_adj_orig(mat.col(2).data(), n_params);
Expand Down Expand Up @@ -300,7 +315,7 @@ double find_reasonable_epsilon(double eps,
*/
template <class ModelType
, class NUTSConfigType = NUTSConfig<>>
NUTSResult nuts(ModelType& model,
MCMCResult nuts(ModelType& model,
NUTSConfigType config = NUTSConfigType())
{
// activate model
Expand Down Expand Up @@ -389,18 +404,17 @@ NUTSResult nuts(ModelType& model,
config.var_config.term_buffer, config.var_config.window_base
);

// construct miscellaneous objects
auto logger = util::ProgressLogger(config.n_samples + config.warmup, "NUTS");
// construct miscellaneous objects
MCMCResult res(config.samples, n_params, 0);
res.name = "nuts";
auto logger = util::ProgressLogger(config.samples + config.warmup, "NUTS");
util::StopWatch<> stopwatch_warmup;
util::StopWatch<> stopwatch_sampling;

// create output object
NUTSResult res(config.n_samples, n_params);

// start timing warmup
stopwatch_warmup.start();

for (size_t i = 0; i < config.n_samples + config.warmup; ++i) {
for (size_t i = 0; i < config.samples + config.warmup; ++i) {

// if warmup is finished, stop timing warmup and start timing sampling
if (i == config.warmup) {
Expand Down Expand Up @@ -569,9 +583,6 @@ NUTSResult nuts(ModelType& model,
// stop timing sampling
stopwatch_sampling.stop();

// end progress bar with a newline
std::cout << std::endl;

// save output results
res.warmup_time = stopwatch_warmup.elapsed();
res.sampling_time = stopwatch_sampling.elapsed();
Expand Down
5 changes: 3 additions & 2 deletions include/autoppl/mcmc/hmc/var_adapter.hpp
Expand Up @@ -113,10 +113,11 @@ struct VarAdapter<diag_var>
// if currently at the end of the window,
// get updated variance and reset estimator
if (counter_ == window_end_ - 1) {
var_estimator_.get_variance(var);
auto&& v = var_estimator_.get_variance();
double n = var_estimator_.get_n_samples();
// regularized sample variance (see STAN)
var = ((n / (n + 5.0)) * var.array() + 1e-3 * (5.0 / (n + 5.0))).matrix();
var.array() = ( (n / ((n + 5.0) * (n - 1.))) * v.array() +
1e-3 * (5.0 / (n + 5.0)) );
var_estimator_.reset();
shift_window();
++counter_;
Expand Down
12 changes: 12 additions & 0 deletions include/autoppl/mcmc/mh/config.hpp
@@ -0,0 +1,12 @@
#pragma once
#include <autoppl/mcmc/config_base.hpp>

namespace ppl {

struct MHConfig : ConfigBase
{
double sigma = 1.0;
double alpha = 0.25;
};

} // namespace ppl

0 comments on commit c96953f

Please sign in to comment.