From f79505cfa5107784a58469a14319744592e4f76b Mon Sep 17 00:00:00 2001 From: James Yang Date: Tue, 28 Apr 2020 13:00:11 -0400 Subject: [PATCH] Add support for multiple samples for NUTS with AD expr --- include/autoppl/algorithm/nuts.hpp | 23 ++-- .../expression/distribution/normal.hpp | 17 ++- include/autoppl/expression/model/eq_node.hpp | 32 +++++- include/autoppl/expression/variable/binop.hpp | 7 +- .../autoppl/expression/variable/constant.hpp | 3 +- .../expression/variable/variable_viewer.hpp | 24 +++-- include/autoppl/util/dist_expr_traits.hpp | 4 +- include/autoppl/variable.hpp | 3 + test/CMakeLists.txt | 1 + test/ad_integration_unittest.cpp | 95 ++++++++++++++++ test/algorithm/mh_regression_unittest.cpp | 13 ++- test/algorithm/nuts_unittest.cpp | 101 ++++++++++++++---- test/algorithm/sampler_tools_unittest.cpp | 1 + test/expression/variable/data_unittest.cpp | 2 + 14 files changed, 259 insertions(+), 67 deletions(-) create mode 100644 test/ad_integration_unittest.cpp diff --git a/include/autoppl/algorithm/nuts.hpp b/include/autoppl/algorithm/nuts.hpp index 5b402d14..c3d0c6c7 100644 --- a/include/autoppl/algorithm/nuts.hpp +++ b/include/autoppl/algorithm/nuts.hpp @@ -263,8 +263,7 @@ template double find_reasonable_epsilon(ADExprType& ad_expr, MatType& theta, - MatType& theta_adj, - size_t max_iter) + MatType& theta_adj) { double eps = 1.; const double diff_bound = -std::log(2); @@ -296,8 +295,7 @@ double find_reasonable_epsilon(ADExprType& ad_expr, int a = 2*(ham_curr - ham_orig > diff_bound) - 1; - while ((a * (ham_curr - ham_orig) > a * diff_bound) && - max_iter--) { + while ((a * (ham_curr - ham_orig) > a * diff_bound)) { eps *= std::pow(2, a); @@ -333,8 +331,8 @@ void nuts(ModelType& model, size_t n_adapt, size_t seed = 0, size_t max_depth = 10, - double delta = 0.6, - size_t max_init_iter = 10 + double delta = 0.6 + //size_t max_init_iter = 10 ) { @@ -397,26 +395,25 @@ void nuts(ModelType& model, // initialize model tags using model specs // copies the initialized values into theta_curr - // initialize potential energy alg::init_params(model, gen); auto theta_curr_it = theta_curr.begin(); - double potential_prev = 0.; - auto copy_params_potential = [=, &potential_prev](const auto& eq_node) mutable { + auto copy_params_potential = [=](const auto& eq_node) mutable { const auto& var = eq_node.get_variable(); - const auto& dist = eq_node.get_distribution(); using var_t = std::decay_t; if constexpr (util::is_param_v) { *theta_curr_it = var.get_value(); ++theta_curr_it; - potential_prev += dist.log_pdf_no_constant(var.get_value()); } }; model.traverse(copy_params_potential); + // initialize current potential (will be "previous" starting in for-loop) + double potential_prev = ad::evaluate(theta_curr_ad_expr); + // initialize rest of the metavariables double log_eps = std::log(alg::find_reasonable_epsilon( - theta_curr_ad_expr, theta_curr, theta_curr_adj, max_init_iter)); - const double mu = std::log(10) + log_eps; + theta_curr_ad_expr, theta_curr, theta_curr_adj)); + const double mu = std::log(10.) + log_eps; // tree output struct type using subview_t = std::decay_t; diff --git a/include/autoppl/expression/distribution/normal.hpp b/include/autoppl/expression/distribution/normal.hpp index a37e3592..92c41565 100644 --- a/include/autoppl/expression/distribution/normal.hpp +++ b/include/autoppl/expression/distribution/normal.hpp @@ -42,22 +42,17 @@ struct Normal : util::DistExpr> return -0.5 * ((z_score * z_score) + std::log(stddev(index) * stddev(index) * 2 * M_PI)); } - dist_value_t log_pdf_no_constant(value_t x) const - { - dist_value_t z_score = (x - mean()) / stddev(); - return -0.5 * (z_score * z_score) - std::log(stddev()); - } - /* * Up to constant addition, returns ad expression of log pdf */ - template - auto ad_log_pdf(const ad::Var& x, + template + auto ad_log_pdf(const ADVarType& x, const VecRefType& keys, - const VecADVarType& vars) const + const VecADVarType& vars, + size_t idx = 0) const { - auto ad_mean_expr = mean_.get_ad(keys, vars); - auto ad_stddev_expr = stddev_.get_ad(keys, vars); + auto ad_mean_expr = mean_.get_ad(keys, vars, idx); + auto ad_stddev_expr = stddev_.get_ad(keys, vars, idx); return ((ad::constant(-0.5) * ((x - ad_mean_expr) * (x - ad_mean_expr) / (ad_stddev_expr * ad_stddev_expr))) diff --git a/include/autoppl/expression/model/eq_node.hpp b/include/autoppl/expression/model/eq_node.hpp index dd353219..9b14569e 100644 --- a/include/autoppl/expression/model/eq_node.hpp +++ b/include/autoppl/expression/model/eq_node.hpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -68,11 +69,32 @@ struct EqNode : util::ModelExpr> auto ad_log_pdf(const VecRefType& keys, const VecADVarType& vars) const { - const void* addr = &orig_var_ref_.get(); - auto it = std::find(keys.begin(), keys.end(), addr); - assert(it != keys.end()); - size_t idx = std::distance(keys.begin(), it); - return dist_.ad_log_pdf(vars[idx], keys, vars); + // if parameter, find the corresponding variable + // in vars and return the AD log-pdf with this variable. + if constexpr (util::is_param_v) { + const void* addr = &orig_var_ref_.get(); + auto it = std::find(keys.begin(), keys.end(), addr); + assert(it != keys.end()); + size_t idx = std::distance(keys.begin(), it); + return dist_.ad_log_pdf(vars[idx], keys, vars); + } + + // if data, return sum of log_pdf where each element + // is a constant AD node containing each value of data. + // note: data is not copied at any point. + else if constexpr (util::is_data_v) { + const auto& var = this->get_variable(); + size_t idx = 0; + const size_t size = var.size(); + return ad::sum(var.begin(), var.end(), + [&, idx, size](auto value) mutable { + idx = idx % size; // may be important since mutable + auto&& expr = dist_.ad_log_pdf( + ad::constant(value), keys, vars, idx); + ++idx; + return expr; + }); + } } auto& get_variable() { return orig_var_ref_.get(); } diff --git a/include/autoppl/expression/variable/binop.hpp b/include/autoppl/expression/variable/binop.hpp index 226f9fe4..1a54311d 100644 --- a/include/autoppl/expression/variable/binop.hpp +++ b/include/autoppl/expression/variable/binop.hpp @@ -36,10 +36,11 @@ struct BinaryOpNode : */ template auto get_ad(const VecRefType& keys, - const VecADVarType& vars) const + const VecADVarType& vars, + size_t idx = 0) const { - return BinaryOp::evaluate(lhs_.get_ad(keys, vars), - rhs_.get_ad(keys, vars)); + return BinaryOp::evaluate(lhs_.get_ad(keys, vars, idx), + rhs_.get_ad(keys, vars, idx)); } private: diff --git a/include/autoppl/expression/variable/constant.hpp b/include/autoppl/expression/variable/constant.hpp index 0484bb2d..18f1ae96 100644 --- a/include/autoppl/expression/variable/constant.hpp +++ b/include/autoppl/expression/variable/constant.hpp @@ -23,7 +23,8 @@ struct Constant : util::VarExpr> */ template auto get_ad(const VecRefType&, - const VecADVarType&) const + const VecADVarType&, + size_t = 0) const { return ad::constant(c_); } private: diff --git a/include/autoppl/expression/variable/variable_viewer.hpp b/include/autoppl/expression/variable/variable_viewer.hpp index 35e328d1..5b7f9c83 100644 --- a/include/autoppl/expression/variable/variable_viewer.hpp +++ b/include/autoppl/expression/variable/variable_viewer.hpp @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include @@ -26,18 +27,25 @@ struct VariableViewer : util::VarExpr> size_t size() const { return var_ref_.get().size(); } /* - * Returns ad expression of the constant. - * Assumes that AD variables in vars will always be parameters. + * Returns ad expression of the variable. + * If variable is parameter, find from vars and return. + * Otherwise if data, return idx'th ad::constant of that value. */ template auto get_ad(const VecRefType& keys, - const VecADVarType& vars) const + const VecADVarType& vars, + size_t idx = 0) const { - const void* addr = &var_ref_.get(); - auto it = std::find(keys.begin(), keys.end(), addr); - assert(it != keys.end()); - size_t idx = std::distance(keys.begin(), it); - return vars[idx]; + if constexpr (util::is_param_v) { + static_cast(idx); + const void* addr = &var_ref_.get(); + auto it = std::find(keys.begin(), keys.end(), addr); + assert(it != keys.end()); + size_t i = std::distance(keys.begin(), it); + return vars[i]; + } else if constexpr (util::is_data_v) { + return ad::constant(this->get_value(idx)); + } } private: diff --git a/include/autoppl/util/dist_expr_traits.hpp b/include/autoppl/util/dist_expr_traits.hpp index 11195c46..fbc604db 100644 --- a/include/autoppl/util/dist_expr_traits.hpp +++ b/include/autoppl/util/dist_expr_traits.hpp @@ -22,7 +22,7 @@ struct DistExpr : BaseCRTP std::enable_if_t>, dist_value_t> log_pdf(const VarType& var) const { dist_value_t value = 0.0; - for (size_t i = 0; i < var.size(); i++) { + for (size_t i = 0; i < var.size(); ++i) { value += self().log_pdf(var.get_value(i), i); } @@ -33,7 +33,7 @@ struct DistExpr : BaseCRTP std::enable_if_t>, dist_value_t> pdf(const VarType& var) const { dist_value_t value = 1.0; - for (size_t i = 0; i < var.size(); i++) { + for (size_t i = 0; i < var.size(); ++i) { value *= self().pdf(var.get_value(i), i); } diff --git a/include/autoppl/variable.hpp b/include/autoppl/variable.hpp index 499c7cca..01d05276 100644 --- a/include/autoppl/variable.hpp +++ b/include/autoppl/variable.hpp @@ -88,6 +88,9 @@ struct Data : util::DataLike> void observe(value_t value) { values_.push_back(value); } void clear() { values_.clear(); } + auto begin() const { return values_.begin(); } + auto end() const { return values_.end(); } + private: std::vector values_; // store value associated with var }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 239fbec3..ee814ef1 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -149,6 +149,7 @@ add_test(algorithm_unittest algorithm_unittest) add_executable(expr_builder_unittest ${CMAKE_CURRENT_SOURCE_DIR}/expr_builder_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ad_integration_unittest.cpp ) target_compile_options(expr_builder_unittest PRIVATE -g -Wall -Werror -Wextra) target_include_directories(expr_builder_unittest PRIVATE diff --git a/test/ad_integration_unittest.cpp b/test/ad_integration_unittest.cpp new file mode 100644 index 00000000..7bef3505 --- /dev/null +++ b/test/ad_integration_unittest.cpp @@ -0,0 +1,95 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { + +struct ad_integration_fixture : ::testing::Test +{ +protected: + Data x{1., 2., 3.}, y{0., -1., 1.}; + Param theta; + std::array keys = {&theta}; + std::vector> vars; + + ad_integration_fixture() + : theta{} + , vars(1) + { + vars[0].set_value(1.); + } +}; + +TEST_F(ad_integration_fixture, ad_log_pdf_data_constant_param) +{ + auto model = (x |= normal(0., 1.)); + auto ad_expr = model.ad_log_pdf(keys, vars); + double value = ad::evaluate(ad_expr); + EXPECT_DOUBLE_EQ(value, -0.5 * 14); + value = ad::autodiff(ad_expr); // should not affect the result + EXPECT_DOUBLE_EQ(value, -0.5 * 14); +} + +TEST_F(ad_integration_fixture, ad_log_pdf_data_mean_param) +{ + auto model = ( + theta |= normal(0., 2.), + x |= normal(theta, 1.) + ); + auto ad_expr = model.ad_log_pdf(keys, vars); + + double value = ad::autodiff(ad_expr); + EXPECT_DOUBLE_EQ(value, -0.5 * 5 - 1./8 - std::log(2)); + EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 2.75); + + // after resetting adjoint, differentiating should not change anything + vars[0].reset_adjoint(); + + value = ad::autodiff(ad_expr); + EXPECT_DOUBLE_EQ(value, -0.5 * 5 - 1./8 - std::log(2)); + EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 2.75); +} + +TEST_F(ad_integration_fixture, ad_log_pdf_data_stddev_param) +{ + auto model = ( + theta |= normal(0., 2.), + x |= normal(0., theta) + ); + + auto ad_expr = model.ad_log_pdf(keys, vars); + + double value = ad::autodiff(ad_expr); + EXPECT_DOUBLE_EQ(value, -0.5 * 14 - 1./8 - std::log(2)); + EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 10.75); + + // after resetting adjoint, differentiating should not change anything + vars[0].reset_adjoint(); + + value = ad::autodiff(ad_expr); + EXPECT_DOUBLE_EQ(value, -0.5 * 14 - 1./8 - std::log(2)); + EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 10.75); +} + +TEST_F(ad_integration_fixture, ad_log_pdf_data_param_with_data) +{ + auto model = ( + theta |= normal(0., 1.), + y |= normal(theta * x, 1.) + ); + + auto ad_expr = model.ad_log_pdf(keys, vars); + + double value = ad::autodiff(ad_expr); + EXPECT_DOUBLE_EQ(value, -7.5); + EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), -14.); + + // after resetting adjoint, differentiating should not change anything + vars[0].reset_adjoint(); + + value = ad::autodiff(ad_expr); + EXPECT_DOUBLE_EQ(value, -7.5); + EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), -14.); +} + +} // namespace ppl diff --git a/test/algorithm/mh_regression_unittest.cpp b/test/algorithm/mh_regression_unittest.cpp index f7c596dd..b7fa38da 100644 --- a/test/algorithm/mh_regression_unittest.cpp +++ b/test/algorithm/mh_regression_unittest.cpp @@ -75,4 +75,15 @@ TEST_F(mh_regression_fixture, sample_regression_fuzzy_dist) { EXPECT_NEAR(sample_average(b_storage), 0.95, 0.1); } -} // ppl \ No newline at end of file +TEST_F(mh_regression_fixture, sample_regression_normal_weight) { + auto model = (w |= ppl::normal(0., 2.), + y |= ppl::normal(x * w + 1., 0.5)); + + ppl::mh_posterior(model, sample_size); + + plot_hist(w_storage, 0.2, 0., 1.); + + EXPECT_NEAR(sample_average(w_storage), 1.0, 0.1); +} + +} // ppl diff --git a/test/algorithm/nuts_unittest.cpp b/test/algorithm/nuts_unittest.cpp index 4c429981..6e225a65 100644 --- a/test/algorithm/nuts_unittest.cpp +++ b/test/algorithm/nuts_unittest.cpp @@ -7,12 +7,17 @@ namespace ppl { -struct nuts_fixture : ::testing::Test +struct nuts_tools_fixture : ::testing::Test { protected: + template + double sample_average(const VecType& v) + { + return std::accumulate(v.begin(), v.end(), 0.)/v.size(); + } }; -TEST_F(nuts_fixture, check_entropy_1d) +TEST_F(nuts_tools_fixture, check_entropy_1d) { using namespace alg; constexpr size_t n_params = 1; @@ -51,7 +56,7 @@ TEST_F(nuts_fixture, check_entropy_1d) } -TEST_F(nuts_fixture, check_entropy_3d) +TEST_F(nuts_tools_fixture, check_entropy_3d) { using namespace alg; constexpr size_t n_params = 3; @@ -99,7 +104,7 @@ TEST_F(nuts_fixture, check_entropy_3d) /* * Fixture just to test for this build_tree function */ -struct nuts_build_tree_fixture : ::testing::Test +struct nuts_build_tree_fixture : nuts_tools_fixture { protected: using ad_vec_t = std::vector>; @@ -144,12 +149,6 @@ struct nuts_build_tree_fixture : ::testing::Test // theta adjoint MUST be set theta_adj[0] = 0.; theta_adj[1] = 0.; theta_adj[2] = 0.; } - - template - double sample_average(const VecType& v) - { - return std::accumulate(v.begin(), v.end(), 0.)/v.size(); - } }; TEST_F(nuts_build_tree_fixture, build_tree_base_plus_no_opt_output) @@ -342,33 +341,89 @@ TEST_F(nuts_build_tree_fixture, find_reasonable_log_epsilon) ad_vars[1] * ad_vars[1] + ad_vars[2] * ad_vars[2] ) ; - double eps = alg::find_reasonable_epsilon(ad_expr, theta, theta_adj, 10000); + double eps = alg::find_reasonable_epsilon(ad_expr, theta, theta_adj); static_cast(eps); } -TEST_F(nuts_build_tree_fixture, nuts) +struct nuts_fixture : nuts_tools_fixture { - constexpr size_t n_samples = 10000; - constexpr size_t warmup = 10000; - constexpr size_t n_adapt = 1000; +protected: + size_t n_adapt = 1000; + size_t n_samples = 10000; + size_t warmup = 10000; double delta = 0.6; + size_t max_depth = 10; + size_t seed = 4821; + + std::vector w_storage, b_storage; + Param w, b; + ppl::Data x {2.5, 3, 3.5, 4, 4.5, 5.}; + ppl::Data y {3.5, 4, 4.5, 5, 5.5, 6.}; - std::vector> thetas(2); + nuts_fixture() + : w_storage(n_samples) + , b_storage(n_samples) + , w{w_storage.data()} + , b{b_storage.data()} + {} - std::vector samples_0(n_samples); - thetas[0].set_storage(samples_0.data()); + void reconfigure(size_t n) + { + w_storage.resize(n); + b_storage.resize(n); + w.set_storage(w_storage.data()); + b.set_storage(b_storage.data()); + } +}; +TEST_F(nuts_fixture, nuts_std_normal) +{ auto model = ( - thetas[0] |= normal(0., 1.) + w |= normal(0., 1.) + ); + + nuts(model, warmup, n_samples, n_adapt, seed, + max_depth, delta); + + plot_hist(w_storage); + EXPECT_NEAR(sample_average(w_storage), 0., 0.1); +} + +TEST_F(nuts_fixture, nuts_sample_regression_dist_weight) +{ + reconfigure(n_samples); + + auto model = (w |= normal(0., 2.), + y |= normal(x * w + 1., 0.5) + ); + + nuts(model, warmup, n_samples, n_adapt, seed, + max_depth, delta); + + plot_hist(w_storage, 0.1); + EXPECT_NEAR(sample_average(w_storage), 1.0, 0.1); +} + +TEST_F(nuts_fixture, nuts_sample_regression_dist_weight_bias) +{ + n_adapt = 1000; + n_samples = 1000; + warmup = 1000; + + reconfigure(n_samples); + + auto model = (b |= normal(0., 2.), + w |= normal(0., 2.), + y |= normal(x * w + b, 0.5) ); - size_t max_depth = 10; - size_t seed = 4821; nuts(model, warmup, n_samples, n_adapt, seed, max_depth, delta); - plot_hist(samples_0); - EXPECT_NEAR(sample_average(samples_0), 0., 0.1); + plot_hist(w_storage, 0.1); + plot_hist(b_storage); + EXPECT_NEAR(sample_average(w_storage), 1.0, 0.1); + EXPECT_NEAR(sample_average(b_storage), 1.0, 0.3); } } // namespace ppl diff --git a/test/algorithm/sampler_tools_unittest.cpp b/test/algorithm/sampler_tools_unittest.cpp index 82e3fd95..bde80a57 100644 --- a/test/algorithm/sampler_tools_unittest.cpp +++ b/test/algorithm/sampler_tools_unittest.cpp @@ -1,4 +1,5 @@ #include "gtest/gtest.h" +#include #include #include #include diff --git a/test/expression/variable/data_unittest.cpp b/test/expression/variable/data_unittest.cpp index 051029f4..bab1828b 100644 --- a/test/expression/variable/data_unittest.cpp +++ b/test/expression/variable/data_unittest.cpp @@ -29,6 +29,7 @@ TEST_F(data_fixture, test_multiple_value) { EXPECT_EQ(var1.get_value(1), 2.0); EXPECT_EQ(var1.get_value(2), 3.0); +#ifndef NDEBUG EXPECT_DEATH({ var2.get_value(1); }, ""); @@ -40,6 +41,7 @@ TEST_F(data_fixture, test_multiple_value) { EXPECT_DEATH({ var1.get_value(3); }, ""); +#endif var1.clear(); expected_size = 0;