diff --git a/include/autoppl/algorithm/mh.hpp b/include/autoppl/algorithm/mh.hpp index aa552434..94af11e1 100644 --- a/include/autoppl/algorithm/mh.hpp +++ b/include/autoppl/algorithm/mh.hpp @@ -1,14 +1,15 @@ #pragma once -#include -#include #include -#include -#include #include -#include +#include +#include #include #include -#include +#include +#include +#include +#include +#include /* * Assumptions: @@ -64,13 +65,10 @@ inline void mh__(ModelType& model, { std::uniform_real_distribution unif_sampler(0., 1.); + auto logger = util::ProgressLogger(n_sample + warmup, "MetropolisHastings"); + for (size_t iter = 0; iter < n_sample + warmup; ++iter) { - if (iter % ((n_sample + warmup) / 100) == 0) { - int percent = static_cast(static_cast(iter) / (static_cast(n_sample + warmup) / 100)); - std::cout << '\r' << "MetropolisHastings: [" << std::string(percent, '=') << - std::string(100 - percent, ' ') << "] (" << - std::setw(2) << percent << "%)" << std::flush; - } + logger.printProgress(iter); size_t n_swaps = 0; // during candidate sampling, if sample out-of-bounds, // traversal will prematurely return and n_swaps < n_params diff --git a/include/autoppl/algorithm/nuts.hpp b/include/autoppl/algorithm/nuts.hpp index 92d57f61..b90ebe93 100644 --- a/include/autoppl/algorithm/nuts.hpp +++ b/include/autoppl/algorithm/nuts.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -424,13 +425,10 @@ void nuts(ModelType& model, using subview_t = std::decay_t; using tree_output_t = alg::TreeOutput; + auto logger = util::ProgressLogger(n_samples + warmup, "NUTS"); + for (size_t i = 0; i < n_samples + warmup; ++i) { - if (i % ((n_samples + warmup) / 100) == 0) { - int percent = static_cast(static_cast(i) / (static_cast(n_samples + warmup) / 100)); - std::cout << '\r' << "NUTS: [" << std::string(percent, '=') << - std::string(100 - percent, ' ') << "] (" << - std::setw(2) << percent << "%)" << std::flush; - } + logger.printProgress(i); // re-initialize vectors theta_plus = theta_minus = theta_curr; @@ -513,7 +511,7 @@ void nuts(ModelType& model, model.traverse(store_sample); } - } // end for + } // end for std::cout << std::endl; } diff --git a/include/autoppl/util/logging.hpp b/include/autoppl/util/logging.hpp new file mode 100644 index 00000000..9e03e4d9 --- /dev/null +++ b/include/autoppl/util/logging.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +namespace ppl { + +namespace util { + +struct ProgressLogger { + ProgressLogger(size_t max, const std::string & name) : _max(max), _name(name) {}; + + void printProgress(size_t step) { + if (step % (_max / 100) == 0) { + int percent = static_cast(static_cast(step) / (static_cast(_max) / 100)); + std::cout << '\r' << _name << " [" << std::string(percent, '=') << std::string(100 - percent, ' ') << "] (" << std::setw(2) << percent << "%)" << std::flush; + } + } + +private: + size_t _max; + std::string _name; +}; + + +} // namespace util + +} // namespace ppl \ No newline at end of file