Skip to content

Commit

Permalink
Add support for multiple samples for NUTS with AD expr
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Apr 28, 2020
1 parent 3c22ade commit f79505c
Show file tree
Hide file tree
Showing 14 changed files with 259 additions and 67 deletions.
23 changes: 10 additions & 13 deletions include/autoppl/algorithm/nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,7 @@ template <class ADExprType
, class MatType>
double find_reasonable_epsilon(ADExprType& ad_expr,
MatType& theta,
MatType& theta_adj,
size_t max_iter)
MatType& theta_adj)
{
double eps = 1.;
const double diff_bound = -std::log(2);
Expand Down Expand Up @@ -296,8 +295,7 @@ double find_reasonable_epsilon(ADExprType& ad_expr,

int a = 2*(ham_curr - ham_orig > diff_bound) - 1;

while ((a * (ham_curr - ham_orig) > a * diff_bound) &&
max_iter--) {
while ((a * (ham_curr - ham_orig) > a * diff_bound)) {

eps *= std::pow(2, a);

Expand Down Expand Up @@ -333,8 +331,8 @@ void nuts(ModelType& model,
size_t n_adapt,
size_t seed = 0,
size_t max_depth = 10,
double delta = 0.6,
size_t max_init_iter = 10
double delta = 0.6
//size_t max_init_iter = 10
)
{

Expand Down Expand Up @@ -397,26 +395,25 @@ void nuts(ModelType& model,

// initialize model tags using model specs
// copies the initialized values into theta_curr
// initialize potential energy
alg::init_params(model, gen);
auto theta_curr_it = theta_curr.begin();
double potential_prev = 0.;
auto copy_params_potential = [=, &potential_prev](const auto& eq_node) mutable {
auto copy_params_potential = [=](const auto& eq_node) mutable {
const auto& var = eq_node.get_variable();
const auto& dist = eq_node.get_distribution();
using var_t = std::decay_t<decltype(var)>;
if constexpr (util::is_param_v<var_t>) {
*theta_curr_it = var.get_value();
++theta_curr_it;
potential_prev += dist.log_pdf_no_constant(var.get_value());
}
};
model.traverse(copy_params_potential);

// initialize current potential (will be "previous" starting in for-loop)
double potential_prev = ad::evaluate(theta_curr_ad_expr);

// initialize rest of the metavariables
double log_eps = std::log(alg::find_reasonable_epsilon(
theta_curr_ad_expr, theta_curr, theta_curr_adj, max_init_iter));
const double mu = std::log(10) + log_eps;
theta_curr_ad_expr, theta_curr, theta_curr_adj));
const double mu = std::log(10.) + log_eps;

// tree output struct type
using subview_t = std::decay_t<decltype(rho_minus)>;
Expand Down
17 changes: 6 additions & 11 deletions include/autoppl/expression/distribution/normal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,17 @@ struct Normal : util::DistExpr<Normal<mean_type, stddev_type>>
return -0.5 * ((z_score * z_score) + std::log(stddev(index) * stddev(index) * 2 * M_PI));
}

/*
 * Log-density of the normal distribution at x with the additive
 * normalization constant (-0.5 * log(2 * pi)) dropped:
 * -0.5 * z^2 - log(sigma), where z = (x - mean()) / stddev().
 */
dist_value_t log_pdf_no_constant(value_t x) const
{
    const dist_value_t standardized = (x - mean()) / stddev();
    return -0.5 * standardized * standardized - std::log(stddev());
}

/*
* Up to constant addition, returns ad expression of log pdf
*/
template <class T, class VecRefType, class VecADVarType>
auto ad_log_pdf(const ad::Var<T>& x,
template <class ADVarType, class VecRefType, class VecADVarType>
auto ad_log_pdf(const ADVarType& x,
const VecRefType& keys,
const VecADVarType& vars) const
const VecADVarType& vars,
size_t idx = 0) const
{
auto ad_mean_expr = mean_.get_ad(keys, vars);
auto ad_stddev_expr = stddev_.get_ad(keys, vars);
auto ad_mean_expr = mean_.get_ad(keys, vars, idx);
auto ad_stddev_expr = stddev_.get_ad(keys, vars, idx);
return ((ad::constant(-0.5) *
((x - ad_mean_expr) * (x - ad_mean_expr) /
(ad_stddev_expr * ad_stddev_expr)))
Expand Down
32 changes: 27 additions & 5 deletions include/autoppl/expression/model/eq_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <algorithm>
#include <type_traits>
#include <functional>
#include <fastad>
#include <autoppl/util/var_traits.hpp>
#include <autoppl/util/model_expr_traits.hpp>
#include <autoppl/util/dist_expr_traits.hpp>
Expand Down Expand Up @@ -68,11 +69,32 @@ struct EqNode : util::ModelExpr<EqNode<VarType, DistType>>
auto ad_log_pdf(const VecRefType& keys,
const VecADVarType& vars) const
{
const void* addr = &orig_var_ref_.get();
auto it = std::find(keys.begin(), keys.end(), addr);
assert(it != keys.end());
size_t idx = std::distance(keys.begin(), it);
return dist_.ad_log_pdf(vars[idx], keys, vars);
// if parameter, find the corresponding variable
// in vars and return the AD log-pdf with this variable.
if constexpr (util::is_param_v<var_t>) {
const void* addr = &orig_var_ref_.get();
auto it = std::find(keys.begin(), keys.end(), addr);
assert(it != keys.end());
size_t idx = std::distance(keys.begin(), it);
return dist_.ad_log_pdf(vars[idx], keys, vars);
}

// if data, return sum of log_pdf where each element
// is a constant AD node containing each value of data.
// note: data is not copied at any point.
else if constexpr (util::is_data_v<var_t>) {
const auto& var = this->get_variable();
size_t idx = 0;
const size_t size = var.size();
return ad::sum(var.begin(), var.end(),
[&, idx, size](auto value) mutable {
idx = idx % size; // may be important since mutable
auto&& expr = dist_.ad_log_pdf(
ad::constant(value), keys, vars, idx);
++idx;
return expr;
});
}
}

auto& get_variable() { return orig_var_ref_.get(); }
Expand Down
7 changes: 4 additions & 3 deletions include/autoppl/expression/variable/binop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ struct BinaryOpNode :
*/
template <class VecRefType, class VecADVarType>
auto get_ad(const VecRefType& keys,
const VecADVarType& vars) const
const VecADVarType& vars,
size_t idx = 0) const
{
return BinaryOp::evaluate(lhs_.get_ad(keys, vars),
rhs_.get_ad(keys, vars));
return BinaryOp::evaluate(lhs_.get_ad(keys, vars, idx),
rhs_.get_ad(keys, vars, idx));
}

private:
Expand Down
3 changes: 2 additions & 1 deletion include/autoppl/expression/variable/constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ struct Constant : util::VarExpr<Constant<ValueType>>
*/
template <class VecRefType, class VecADVarType>
auto get_ad(const VecRefType&,
const VecADVarType&) const
const VecADVarType&,
size_t = 0) const
{ return ad::constant(c_); }

private:
Expand Down
24 changes: 16 additions & 8 deletions include/autoppl/expression/variable/variable_viewer.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <algorithm>
#include <fastad>
#include <autoppl/util/var_traits.hpp>
#include <autoppl/util/var_expr_traits.hpp>

Expand All @@ -26,18 +27,25 @@ struct VariableViewer : util::VarExpr<VariableViewer<VariableType>>
size_t size() const { return var_ref_.get().size(); }

/*
* Returns ad expression of the constant.
* Assumes that AD variables in vars will always be parameters.
* Returns ad expression of the variable.
* If variable is parameter, find from vars and return.
* Otherwise if data, return idx'th ad::constant of that value.
*/
template <class VecRefType, class VecADVarType>
auto get_ad(const VecRefType& keys,
const VecADVarType& vars) const
const VecADVarType& vars,
size_t idx = 0) const
{
const void* addr = &var_ref_.get();
auto it = std::find(keys.begin(), keys.end(), addr);
assert(it != keys.end());
size_t idx = std::distance(keys.begin(), it);
return vars[idx];
if constexpr (util::is_param_v<var_t>) {
static_cast<void>(idx);
const void* addr = &var_ref_.get();
auto it = std::find(keys.begin(), keys.end(), addr);
assert(it != keys.end());
size_t i = std::distance(keys.begin(), it);
return vars[i];
} else if constexpr (util::is_data_v<var_t>) {
return ad::constant(this->get_value(idx));
}
}

private:
Expand Down
4 changes: 2 additions & 2 deletions include/autoppl/util/dist_expr_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct DistExpr : BaseCRTP<T>
std::enable_if_t<is_var_v<std::decay_t<VarType>>, dist_value_t>
log_pdf(const VarType& var) const {
dist_value_t value = 0.0;
for (size_t i = 0; i < var.size(); i++) {
for (size_t i = 0; i < var.size(); ++i) {
value += self().log_pdf(var.get_value(i), i);
}

Expand All @@ -33,7 +33,7 @@ struct DistExpr : BaseCRTP<T>
std::enable_if_t<is_var_v<std::decay_t<VarType>>, dist_value_t>
pdf(const VarType& var) const {
dist_value_t value = 1.0;
for (size_t i = 0; i < var.size(); i++) {
for (size_t i = 0; i < var.size(); ++i) {
value *= self().pdf(var.get_value(i), i);
}

Expand Down
3 changes: 3 additions & 0 deletions include/autoppl/variable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ struct Data : util::DataLike<Data<ValueType>>
void observe(value_t value) { values_.push_back(value); }
void clear() { values_.clear(); }

// Read-only iterator to the start of the stored observations.
auto begin() const { return values_.begin(); }
// Read-only past-the-end iterator over the stored observations.
auto end() const { return values_.end(); }

private:
std::vector<value_t> values_; // store value associated with var
};
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ add_test(algorithm_unittest algorithm_unittest)

add_executable(expr_builder_unittest
${CMAKE_CURRENT_SOURCE_DIR}/expr_builder_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ad_integration_unittest.cpp
)
target_compile_options(expr_builder_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(expr_builder_unittest PRIVATE
Expand Down
95 changes: 95 additions & 0 deletions test/ad_integration_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "gtest/gtest.h"
#include <array>
#include <autoppl/expr_builder.hpp>

namespace ppl {

struct ad_integration_fixture : ::testing::Test
{
protected:
Data<double> x{1., 2., 3.}, y{0., -1., 1.};
Param<double> theta;
std::array<const void*, 1> keys = {&theta};
std::vector<ad::Var<double>> vars;

ad_integration_fixture()
: theta{}
, vars(1)
{
vars[0].set_value(1.);
}
};

TEST_F(ad_integration_fixture, ad_log_pdf_data_constant_param)
{
    // x ~ N(0, 1) over data {1., 2., 3.}: log-pdf (constants dropped) is -0.5 * 14.
    auto eq = (x |= normal(0., 1.));
    auto expr = eq.ad_log_pdf(keys, vars);
    double lp = ad::evaluate(expr);
    EXPECT_DOUBLE_EQ(lp, -0.5 * 14);
    // autodiff also evaluates the expression; the value must be unchanged.
    lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -0.5 * 14);
}

TEST_F(ad_integration_fixture, ad_log_pdf_data_mean_param)
{
    // theta ~ N(0, 2) prior with likelihood x ~ N(theta, 1) over data {1., 2., 3.}.
    auto joint = (
        theta |= normal(0., 2.),
        x |= normal(theta, 1.)
    );
    auto expr = joint.ad_log_pdf(keys, vars);

    // First differentiation: check value and adjoint at theta = 1.
    double lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -0.5 * 5 - 1./8 - std::log(2));
    EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 2.75);

    // After resetting the adjoint, differentiating again must reproduce both results.
    vars[0].reset_adjoint();

    lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -0.5 * 5 - 1./8 - std::log(2));
    EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 2.75);
}

TEST_F(ad_integration_fixture, ad_log_pdf_data_stddev_param)
{
    // theta ~ N(0, 2) prior; here theta appears as the likelihood's stddev.
    auto joint = (
        theta |= normal(0., 2.),
        x |= normal(0., theta)
    );

    auto expr = joint.ad_log_pdf(keys, vars);

    double lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -0.5 * 14 - 1./8 - std::log(2));
    EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 10.75);

    // Resetting the adjoint and differentiating again must be idempotent.
    vars[0].reset_adjoint();

    lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -0.5 * 14 - 1./8 - std::log(2));
    EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 10.75);
}

TEST_F(ad_integration_fixture, ad_log_pdf_data_param_with_data)
{
    // theta ~ N(0, 1) prior; y ~ N(theta * x, 1) mixes the parameter with data x.
    auto joint = (
        theta |= normal(0., 1.),
        y |= normal(theta * x, 1.)
    );

    auto expr = joint.ad_log_pdf(keys, vars);

    double lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -7.5);
    EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), -14.);

    // Second pass after resetting the adjoint must match the first exactly.
    vars[0].reset_adjoint();

    lp = ad::autodiff(expr);
    EXPECT_DOUBLE_EQ(lp, -7.5);
    EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), -14.);
}

} // namespace ppl
13 changes: 12 additions & 1 deletion test/algorithm/mh_regression_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,15 @@ TEST_F(mh_regression_fixture, sample_regression_fuzzy_dist) {
EXPECT_NEAR(sample_average(b_storage), 0.95, 0.1);
}

} // ppl
// Regression smoke test: w ~ N(0, 2) prior with likelihood y ~ N(x*w + 1, 0.5).
// NOTE(review): w, y, x, sample_size, w_storage, plot_hist and sample_average
// come from mh_regression_fixture (not visible in this view) -- presumably w is
// the regression weight and (x, y) the observed covariate/response; confirm there.
TEST_F(mh_regression_fixture, sample_regression_normal_weight) {
auto model = (w |= ppl::normal(0., 2.),
y |= ppl::normal(x * w + 1., 0.5));

// Draw sample_size Metropolis-Hastings samples from the model's posterior.
ppl::mh_posterior(model, sample_size);

plot_hist(w_storage, 0.2, 0., 1.);

// Posterior mean of w should land near 1.0 within the 0.1 tolerance.
EXPECT_NEAR(sample_average(w_storage), 1.0, 0.1);
}

} // ppl

0 comments on commit f79505c

Please sign in to comment.