Skip to content

Commit

Permalink
Add overload of ess for 1 chain case
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Jul 17, 2020
1 parent 6874b33 commit 0851b29
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
6 changes: 1 addition & 5 deletions benchmark/regression_autoppl.cpp
Expand Up @@ -50,11 +50,7 @@ static void BM_Regression(benchmark::State& state) {
ppl::nuts(model, config);
}

arma::cube out(storage.n_rows,
storage.n_cols,
1);
out.slice(0) = storage;
arma::vec ess_res = math::ess(out);
arma::vec ess_res = math::ess(storage);
ess_res.print("ESS");

// print mean and stddev results
Expand Down
36 changes: 34 additions & 2 deletions include/autoppl/math/ess.hpp
Expand Up @@ -12,12 +12,14 @@ namespace math {
* every row is a sample of an n-dimensional vector, where n
* is the number of columns of the matrix.
*
* If number of samples is 0
*
* @tparam T underlying data type
* @param samples sample cube
*
* @return a vector of ESS for each component
* If number of samples is 1 or less, or there are 0 components,
* or number of chains is 0, return a vector of zeros.
* In either case, the dimension of the return vector is same
* as the number of components.
*/
template <class T>
inline arma::Col<T> ess(const arma::Cube<T>& samples)
Expand Down Expand Up @@ -104,5 +106,35 @@ inline arma::Col<T> ess(const arma::Cube<T>& samples)
return N*M*arma::clamp(n_eff, n_eff.min(), std::log10(N));
}

/**
* Computes the effective sample size (ESS) for a given sample matrix.
* This is an overload for when there is only chain and can supply
* a single matrix instead.
* See above overload for more details.
* Note that we take in by non-const lvalue reference to
* fit the API for armadillo in ensuring there is no copy of data.
* However, this function does not modify samples.
* It simply makes a cube viewing this matrix and delegates the call
* to the overload above, which takes in a const reference.
*
* @tparam T underlying data type
* @param samples sample matrix
*
* @return a vector of ESS for each component
*/
template <class T>
inline arma::Col<T> ess(arma::Mat<T>& samples)
{
size_t n_rows = samples.n_rows;
size_t n_cols = samples.n_cols;
size_t n_slices = 1;
arma::Cube<T> cubed(samples.memptr(), n_rows,
n_cols, n_slices,
false,
true);
return ess(cubed);
}


} // namespace math
} // namespace ppl

0 comments on commit 0851b29

Please sign in to comment.