Skip to content

Commit

Permalink
moved logging code to util::logging.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobaustin123 committed May 13, 2020
1 parent 04fdb31 commit 4884112
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 19 deletions.
22 changes: 10 additions & 12 deletions include/autoppl/algorithm/mh.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#pragma once
#include <chrono>
#include <random>
#include <algorithm>
#include <iostream>
#include <vector>
#include <array>
#include <variant>
#include <autoppl/algorithm/sampler_tools.hpp>
#include <autoppl/util/logging.hpp>
#include <autoppl/util/traits.hpp>
#include <autoppl/variable.hpp>
#include <autoppl/algorithm/sampler_tools.hpp>
#include <chrono>
#include <iostream>
#include <random>
#include <variant>
#include <vector>

/*
* Assumptions:
Expand Down Expand Up @@ -64,13 +65,10 @@ inline void mh__(ModelType& model,
{
std::uniform_real_distribution unif_sampler(0., 1.);

auto logger = util::ProgressLogger(n_sample + warmup, "MetropolisHastings");

for (size_t iter = 0; iter < n_sample + warmup; ++iter) {
if (iter % ((n_sample + warmup) / 100) == 0) {
int percent = static_cast<int>(static_cast<double>(iter) / (static_cast<double>(n_sample + warmup) / 100));
std::cout << '\r' << "MetropolisHastings: [" << std::string(percent, '=') <<
std::string(100 - percent, ' ') << "] (" <<
std::setw(2) << percent << "%)" << std::flush;
}
logger.printProgress(iter);

size_t n_swaps = 0; // during candidate sampling, if sample out-of-bounds,
// traversal will prematurely return and n_swaps < n_params
Expand Down
12 changes: 5 additions & 7 deletions include/autoppl/algorithm/nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <armadillo>
#include <fastad>
#include <autoppl/util/var_traits.hpp>
#include <autoppl/util/logging.hpp>
#include <autoppl/expression/model/glue_node.hpp>
#include <autoppl/expression/model/eq_node.hpp>
#include <autoppl/algorithm/sampler_tools.hpp>
Expand Down Expand Up @@ -424,13 +425,10 @@ void nuts(ModelType& model,
using subview_t = std::decay_t<decltype(rho_minus)>;
using tree_output_t = alg::TreeOutput<subview_t>;

auto logger = util::ProgressLogger(n_samples + warmup, "NUTS");

for (size_t i = 0; i < n_samples + warmup; ++i) {
if (i % ((n_samples + warmup) / 100) == 0) {
int percent = static_cast<int>(static_cast<double>(i) / (static_cast<double>(n_samples + warmup) / 100));
std::cout << '\r' << "NUTS: [" << std::string(percent, '=') <<
std::string(100 - percent, ' ') << "] (" <<
std::setw(2) << percent << "%)" << std::flush;
}
logger.printProgress(i);

// re-initialize vectors
theta_plus = theta_minus = theta_curr;
Expand Down Expand Up @@ -513,7 +511,7 @@ void nuts(ModelType& model,
model.traverse(store_sample);
}

} // end for
} // end for

std::cout << std::endl;
}
Expand Down
28 changes: 28 additions & 0 deletions include/autoppl/util/logging.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include <string>
#include <iostream>

namespace ppl {

namespace util {

struct ProgressLogger {
ProgressLogger(size_t max, const std::string & name) : _max(max), _name(name) {};

void printProgress(size_t step) {
if (step % (_max / 100) == 0) {
int percent = static_cast<int>(static_cast<double>(step) / (static_cast<double>(_max) / 100));
std::cout << '\r' << _name << " [" << std::string(percent, '=') << std::string(100 - percent, ' ') << "] (" << std::setw(2) << percent << "%)" << std::flush;
}
}

private:
size_t _max;
std::string _name;
};


} // namespace util

} // namespace ppl

0 comments on commit 4884112

Please sign in to comment.