Optimize to use fixed matrix
JamesYang007 committed Apr 28, 2020
1 parent f79505c commit 070d9a4
Showing 2 changed files with 17 additions and 18 deletions.
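
The core of the change: scratch matrices whose row count was read from the input at run time (e.g. arma::mat r_mat(n_params, 2), dynamically allocated in general) become arma::mat::fixed temporaries, whose dimensions are compile-time constants and whose storage lives inside the object itself. Because arma::mat::fixed needs the dimension at compile time, build_tree and find_reasonable_epsilon gain a size_t n_params template parameter instead of computing n_params via arma::size(...). A minimal before/after sketch of the pattern (not taken from the diff; step_dynamic, step_fixed, and the 2-column scratch matrix are illustrative only):

#include <armadillo>
#include <cstddef>

// Before: dimension known only at run time; arma::mat generally allocates
// its element buffer dynamically (very small matrices may fit Armadillo's
// preallocation buffer).
template <class MatType>
void step_dynamic(const MatType& theta)
{
    const std::size_t n_params = arma::size(theta)[0];  // run-time value
    arma::mat scratch(n_params, 2);                      // dynamic storage
    scratch.zeros();
}

// After: dimension is a template parameter; arma::mat::fixed keeps its
// elements inside the object, so no separate allocation is needed.
template <std::size_t n_params, class MatType>
void step_fixed(const MatType& theta)
{
    arma::mat::fixed<n_params, 2> scratch;               // in-object storage
    scratch.zeros();
    (void)theta;
}
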
25 changes: 12 additions & 13 deletions include/autoppl/algorithm/nuts.hpp
@@ -128,7 +128,8 @@ bool check_entropy(const MatType& theta_plus,
*
* Note that the caller MUST have input theta_adj already pre-computed.
*/
-template <class InputType
+template <size_t n_params
+, class InputType
, class OutputType
, class UniformDistType = std::uniform_real_distribution<double>
, class GenType = std::mt19937
@@ -173,7 +174,6 @@ void build_tree(InputType& input,
}

// recursion
-const size_t n_params = arma::size(input.theta_ref.get())[0];
arma::mat pm(0, 2);
OutputType first_output = output; // first recursive output

@@ -192,7 +192,7 @@
first_output.opt_rho_ref = rho;
}

-build_tree(input, first_output, depth - 1);
+build_tree<n_params>(input, first_output, depth - 1);

// ham way below threshold of delta_max: early finish
// simply copy first output's values into caller's output.
@@ -204,14 +204,14 @@
// second recursion with same input from original caller.
// This time, we don't have any other pm to update.
// Need a new theta_prime storage though.
-arma::mat theta_tmp(n_params, 1);
+arma::mat::fixed<n_params, 1> theta_tmp;
auto theta_double_prime = theta_tmp.unsafe_col(0);
OutputType second_output = output;
second_output.opt_theta_ref.reset();
second_output.opt_rho_ref.reset();
second_output.theta_prime_ref = theta_double_prime;

-build_tree(input, second_output, depth - 1);
+build_tree<n_params>(input, second_output, depth - 1);

// accept with n''/(n' + n'') probability
// if accepting, also copy over potential from second output
@@ -259,21 +259,21 @@
* Finds a reasonable epsilon for NUTS algorithm.
* @param ad_expr AD expression bound to theta and theta_adj
*/
-template <class ADExprType
+template <size_t n_params
+, class ADExprType
, class MatType>
double find_reasonable_epsilon(ADExprType& ad_expr,
MatType& theta,
MatType& theta_adj)
{
double eps = 1.;
const double diff_bound = -std::log(2);
-const size_t n_params = arma::size(theta)[0];

-arma::mat r_mat(n_params, 2);
+arma::mat::fixed<n_params, 2> r_mat;
auto r = r_mat.unsafe_col(0);
auto r_orig = r_mat.unsafe_col(1);

-arma::mat theta_mat(n_params, 2);
+arma::mat::fixed<n_params, 2> theta_mat;
auto theta_orig = theta_mat.unsafe_col(0);
auto theta_adj_orig = theta_mat.unsafe_col(1);

@@ -332,7 +332,6 @@ void nuts(ModelType& model,
size_t seed = 0,
size_t max_depth = 10,
double delta = 0.6
-//size_t max_init_iter = 10
)
{

@@ -411,7 +410,7 @@ void nuts(ModelType& model,
double potential_prev = ad::evaluate(theta_curr_ad_expr);

// initialize rest of the metavariables
-double log_eps = std::log(alg::find_reasonable_epsilon(
+double log_eps = std::log(alg::find_reasonable_epsilon<n_params>(
theta_curr_ad_expr, theta_curr, theta_curr_adj));
const double mu = std::log(10.) + log_eps;

@@ -450,15 +449,15 @@
);
output.opt_theta_ref.reset();
output.opt_rho_ref.reset();
-alg::build_tree(input, output, j, metrop_sampler, gen);
+alg::build_tree<n_params>(input, output, j, metrop_sampler, gen);
} else {
auto input = alg::TreeInput(
theta_plus_ad_expr, theta_plus, theta_plus_adj, rho_plus,
log_u, v, std::exp(log_eps), ham_prev
);
output.opt_theta_ref.reset();
output.opt_rho_ref.reset();
-alg::build_tree(input, output, j, metrop_sampler, gen);
+alg::build_tree<n_params>(input, output, j, metrop_sampler, gen);
}

if (output.s) {
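
The fixed-size matrices are still consumed through the same unsafe_col() views as before, so only the declarations change, and every caller now has to supply the dimension as an explicit template argument, as the updated unit tests below do with build_tree<3> and find_reasonable_epsilon<3>. A small self-contained sketch of that view pattern (hypothetical standalone program, not part of the diff; the 3x2 shape mirrors the test fixture's 3 parameters):

#include <armadillo>

int main()
{
    // Fixed-size matrix: dimensions are template arguments, storage is
    // kept inside the object rather than allocated separately.
    arma::mat::fixed<3, 2> r_mat;

    // unsafe_col() returns a column that aliases r_mat's memory, so code
    // written against column views keeps working unchanged.
    auto r      = r_mat.unsafe_col(0);
    auto r_orig = r_mat.unsafe_col(1);

    r.fill(1.0);        // writes through to r_mat
    r_orig = 2.0 * r;   // ordinary column arithmetic on the view

    r_mat.print("r_mat:");
    return 0;
}
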
10 changes: 5 additions & 5 deletions test/algorithm/nuts_unittest.cpp
@@ -167,7 +167,7 @@ TEST_F(nuts_build_tree_fixture, build_tree_base_plus_no_opt_output)
epsilon, ham
);

-build_tree(input, output, 0);
+build_tree<3>(input, output, 0);

// output optional theta/rho still unset
EXPECT_FALSE(output.opt_theta_ref.has_value());
@@ -220,7 +220,7 @@ TEST_F(nuts_build_tree_fixture, build_tree_base_plus_opt_output)
output.opt_theta_ref = opt_theta;
output.opt_rho_ref = opt_rho;

-build_tree(input, output, 0);
+build_tree<3>(input, output, 0);

// optional theta and rho are the same as input ones
EXPECT_DOUBLE_EQ(opt_theta[0], theta[0]);
@@ -253,7 +253,7 @@ TEST_F(nuts_build_tree_fixture, build_tree_base_plus_no_opt_output_2)
epsilon, ham
);

-build_tree(input, output, 0);
+build_tree<3>(input, output, 0);

// input theta properly updated
EXPECT_DOUBLE_EQ(theta[0], 4.);
@@ -302,7 +302,7 @@ TEST_F(nuts_build_tree_fixture, build_tree_recursion_plus_no_opt_output)

// custom uniform distribution will always accept candidate
// except when optimized for n'' == 0 in the recursion
-build_tree(input, output, 1, [](const auto&) {return 0;});
+build_tree<3>(input, output, 1, [](const auto&) {return 0;});

// input theta properly updated
EXPECT_DOUBLE_EQ(theta[0], 4.);
@@ -341,7 +341,7 @@ TEST_F(nuts_build_tree_fixture, find_reasonable_log_epsilon)
ad_vars[1] * ad_vars[1] +
ad_vars[2] * ad_vars[2]
) ;
-double eps = alg::find_reasonable_epsilon(ad_expr, theta, theta_adj);
+double eps = alg::find_reasonable_epsilon<3>(ad_expr, theta, theta_adj);
static_cast<void>(eps);
}

