Skip to content

Commit

Permalink
Merge 04fdb31 into 8e514fa
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobaustin123 committed May 10, 2020
2 parents 8e514fa + 04fdb31 commit 2ab7dad
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
Binary file added docs/design/design-overview.pdf
Binary file not shown.
11 changes: 10 additions & 1 deletion include/autoppl/algorithm/mh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <chrono>
#include <random>
#include <algorithm>
#include <iostream>
#include <vector>
#include <array>
#include <variant>
Expand Down Expand Up @@ -64,7 +65,13 @@ inline void mh__(ModelType& model,
std::uniform_real_distribution unif_sampler(0., 1.);

for (size_t iter = 0; iter < n_sample + warmup; ++iter) {

if (iter % ((n_sample + warmup) / 100) == 0) {
int percent = static_cast<int>(static_cast<double>(iter) / (static_cast<double>(n_sample + warmup) / 100));
std::cout << '\r' << "MetropolisHastings: [" << std::string(percent, '=') <<
std::string(100 - percent, ' ') << "] (" <<
std::setw(2) << percent << "%)" << std::flush;
}

size_t n_swaps = 0; // during candidate sampling, if sample out-of-bounds,
// traversal will prematurely return and n_swaps < n_params
bool early_reject = false; // indicate early sample reject
Expand Down Expand Up @@ -170,6 +177,8 @@ inline void mh__(ModelType& model,
// update current log pdf for next iteration
if (accept) curr_log_pdf = cand_log_pdf;
}

std::cout << std::endl;
}

} // namespace alg
Expand Down
9 changes: 9 additions & 0 deletions include/autoppl/algorithm/nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <stack>
#include <optional>
#include <type_traits>
#include <iostream>
#include <armadillo>
#include <fastad>
#include <autoppl/util/var_traits.hpp>
Expand Down Expand Up @@ -424,6 +425,12 @@ void nuts(ModelType& model,
using tree_output_t = alg::TreeOutput<subview_t>;

for (size_t i = 0; i < n_samples + warmup; ++i) {
if (i % ((n_samples + warmup) / 100) == 0) {
int percent = static_cast<int>(static_cast<double>(i) / (static_cast<double>(n_samples + warmup) / 100));
std::cout << '\r' << "NUTS: [" << std::string(percent, '=') <<
std::string(100 - percent, ' ') << "] (" <<
std::setw(2) << percent << "%)" << std::flush;
}

// re-initialize vectors
theta_plus = theta_minus = theta_curr;
Expand Down Expand Up @@ -507,6 +514,8 @@ void nuts(ModelType& model,
}

} // end for

std::cout << std::endl;
}

} // namespace ppl

0 comments on commit 2ab7dad

Please sign in to comment.