Skip to content

Commit

Permalink
Add dot product, cached AD vars, stronger concepts
Browse files Browse the repository at this point in the history
- More type safety has been added in traits
- Model ad_log_pdf and variable-expression to_ad member functions now
  let users pass in a cache, in case these expressions need one.
- Dot product finally implemented (performance is still great)
- NUTS creates a long cache vector (the extra memory for adj and values
  may not be necessary)
- More unit tests with dot product
  • Loading branch information
JamesYang007 committed Jul 14, 2020
1 parent 79cc3c1 commit a79fee1
Show file tree
Hide file tree
Showing 46 changed files with 1,225 additions and 218 deletions.
99 changes: 34 additions & 65 deletions benchmark/regression_autoppl.cpp
Original file line number Diff line number Diff line change
@@ -1,75 +1,43 @@
#include <chrono>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <array>
#include <sstream>
#include <unordered_map>

#include <autoppl/expression/variable/data.hpp>
#include <autoppl/expression/variable/param.hpp>
#include <autoppl/expression/expr_builder.hpp>
#include <autoppl/mcmc/hmc/nuts/nuts.hpp>

#include "benchmark_utils.hpp"

#include <benchmark/benchmark.h>

namespace ppl {

/**
 * Computes the population standard deviation of the elements of v:
 * the sum of squared deviations from the mean is divided by v.size()
 * (not v.size() - 1) before taking the square root.
 *
 * @param v  container exposing begin(), end() and size()
 * @return   square root of the mean squared deviation
 */
template <class ArrayType>
inline double stddev(const ArrayType& v)
{
    const auto count = v.size();
    const double mean = std::accumulate(v.begin(), v.end(), 0.) / count;

    // Fold the squared deviations left-to-right, starting from 0.
    const double sq_dev_sum = std::accumulate(
        v.begin(), v.end(), 0.,
        [mean](double acc, const auto& elem) {
            const auto dev = elem - mean;
            return acc + dev * dev;
        });

    return std::sqrt(sq_dev_sum / count);
}

static void BM_Regression(benchmark::State& state) {
size_t num_samples = state.range(0);

std::array<std::string, 4> headers = {"Life expectancy", "Alcohol", "HIV/AIDS", "GDP"};

std::unordered_map<std::string, ppl::Data<double, ppl::vec>> data;
std::unordered_map<std::string, ppl::Param<double>> params;
std::array<std::vector<double>, 4> storage;

// Read in data
std::fstream fin;
fin.open("life-clean.csv", std::ios::in);
std::string line;
double value;
while (std::getline(fin, line, '\n')) {
auto it = headers.begin();
std::stringstream s(line);
while (s >> value) {
data[*it].push_back(value);
++it;
}
}

// resize each storage and bind with param
int i = 0;
for (auto it = headers.begin(); it != headers.end(); ++it, ++i) {
storage[i].resize(num_samples);
params[*it].storage() = storage[i].data();
// load data
std::string datapath = "life-clean.csv";
arma::mat data;
data.load(datapath);
arma::mat X_data = data.tail_cols(data.n_cols-1);
arma::vec y_data = data.col(0); // life expectancy

// create data and param tags
auto X = ppl::make_data_view<ppl::mat>(X_data);
auto y = ppl::make_data_view<ppl::vec>(y_data);
ppl::Param<double, ppl::vec> w(3);
ppl::Param<double> b;

// create and bind sample storage
arma::mat storage(num_samples, 4);

for (size_t i = 0; i < w.size(); ++i) {
w.storage(i) = storage.colptr(i);
}

auto model = (params["Alcohol"] |= ppl::normal(0., 5.),
params["HIV/AIDS"] |= ppl::normal(0., 5.),
params["GDP"] |= ppl::normal(0., 5.),
params["Life expectancy"] |= ppl::normal(0., 5.),

data["Life expectancy"] |= ppl::normal(
params["Alcohol"] * data["Alcohol"] +
params["HIV/AIDS"] * data["HIV/AIDS"] +
params["GDP"] * data["GDP"] + params["Life expectancy"], 5.0));
b.storage() = storage.colptr(w.size());

// define model
auto model = (b |= ppl::normal(0., 5.),
w |= ppl::normal(0., 5.),
y |= ppl::normal(ppl::dot(X, w) + b, 5.));

// perform NUTS sampling
NUTSConfig<> config = {
.warmup = num_samples,
.n_samples = num_samples
Expand All @@ -78,15 +46,16 @@ static void BM_Regression(benchmark::State& state) {
ppl::nuts(model, config);
}

std::cout << "Bias: " << sample_average(storage[0]) << std::endl;
std::cout << "Alcohol w: " << sample_average(storage[1]) << std::endl;
std::cout << "HIV/AIDS w: " << sample_average(storage[2]) << std::endl;
std::cout << "GDP: " << sample_average(storage[3]) << std::endl;
// print mean and stddev results
std::cout << "Bias: " << arma::mean(storage.col(3)) << std::endl;
std::cout << "Alcohol w: " << arma::mean(storage.col(0)) << std::endl;
std::cout << "HIV/AIDS w: " << arma::mean(storage.col(1)) << std::endl;
std::cout << "GDP: " << arma::mean(storage.col(2)) << std::endl;

std::cout << "Bias: " << stddev(storage[0]) << std::endl;
std::cout << "Alcohol w: " << stddev(storage[1]) << std::endl;
std::cout << "HIV/AIDS w: " << stddev(storage[2]) << std::endl;
std::cout << "GDP: " << stddev(storage[3]) << std::endl;
std::cout << "Bias: " << arma::stddev(storage.col(3)) << std::endl;
std::cout << "Alcohol w: " << arma::stddev(storage.col(0)) << std::endl;
std::cout << "HIV/AIDS w: " << arma::stddev(storage.col(1)) << std::endl;
std::cout << "GDP: " << arma::stddev(storage.col(2)) << std::endl;
}

BENCHMARK(BM_Regression)->Arg(100)->Arg(500)->Arg(1000)->Arg(5000)->Arg(10000)->Arg(50000)->Arg(100000);
Expand Down
18 changes: 18 additions & 0 deletions include/autoppl/expression/distribution/bernoulli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct Bernoulli : util::DistExprBase<Bernoulli<PType>>
using value_t = util::disc_param_t;
using param_value_t = typename util::var_expr_traits<PType>::value_t;
using base_t = util::DistExprBase<Bernoulli<PType>>;
using index_t = uint32_t;
using typename base_t::dist_value_t;

Bernoulli(const PType& p)
Expand Down Expand Up @@ -106,6 +107,23 @@ struct Bernoulli : util::DistExprBase<Bernoulli<PType>>
}, x.size());
}


/**
 * Placeholder AD log-pdf for Bernoulli.
 *
 * Bernoulli doesn't need to support this function; the dummy body
 * exists only so that the type satisfies the concept checks.
 * All three arguments are unnamed and ignored. The returned expression
 * is always ad::constant(math::neg_inf<dist_value_t>).
 */
template <class VarType, class VecADVarType>
auto ad_log_pdf(const VarType&,
const VecADVarType&,
const VecADVarType&) const
{
return ad::constant(math::neg_inf<dist_value_t>);
}

/**
 * Hands the given cache offset to the parameter expression p_ and
 * returns whatever offset p_ reports back.
 *
 * @param idx  offset passed on to p_.set_cache_offset
 * @return     the value returned by p_.set_cache_offset(idx)
 */
index_t set_cache_offset(index_t idx)
{
    return p_.set_cache_offset(idx);
}

template <class PVecType
, class F = util::identity>
value_t min(const PVecType&,
Expand Down
33 changes: 21 additions & 12 deletions include/autoppl/expression/distribution/normal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ struct Normal:

using value_t = util::cont_param_t;
using base_t = util::DistExprBase<Normal<MeanType, SDType>>;
using index_t = uint32_t;
using typename base_t::dist_value_t;

Normal(const MeanType& mean,
Expand Down Expand Up @@ -148,7 +149,8 @@ struct Normal:
*/
template <class VarType, class VecADVarType>
auto ad_log_pdf(const VarType& x,
const VecADVarType& ad_vars) const
const VecADVarType& ad_vars,
const VecADVarType& cache) const
{
static_assert(util::is_var_v<VarType>);
static_assert(details::normal_valid_dim_v<VarType, MeanType, SDType>,
Expand All @@ -159,9 +161,9 @@ struct Normal:
util::is_scl_v<MeanType> &&
util::is_scl_v<SDType>)
{
auto&& ad_x = x.to_ad(ad_vars);
auto&& ad_mean = mean_.to_ad(ad_vars);
auto&& ad_sd = sd_.to_ad(ad_vars);
auto&& ad_x = x.to_ad(ad_vars, cache);
auto&& ad_mean = mean_.to_ad(ad_vars, cache);
auto&& ad_sd = sd_.to_ad(ad_vars, cache);

// Subcase 1: sd -> has no param
if constexpr (!SDType::has_param) {
Expand Down Expand Up @@ -203,8 +205,8 @@ struct Normal:
util::is_scl_v<SDType>)
{
size_t x_size = x.size();
auto&& ad_mean = mean_.to_ad(ad_vars);
auto&& ad_sd = sd_.to_ad(ad_vars);
auto&& ad_mean = mean_.to_ad(ad_vars, cache);
auto&& ad_sd = sd_.to_ad(ad_vars, cache);

// Subcase 1: x -> has param
if constexpr (VarType::has_param) {
Expand All @@ -214,7 +216,7 @@ struct Normal:
* ad::sum(util::counting_iterator<size_t>(0),
util::counting_iterator<size_t>(x_size),
[&](size_t i) {
return ad::pow<2>(x.to_ad(ad_vars, i) - ad_mean);
return ad::pow<2>(x.to_ad(ad_vars, cache, i) - ad_mean);
})
- (ad::constant<dist_value_t>(x_size) * ad::log(ad_sd)),
ad::constant(math::neg_inf<dist_value_t>)
Expand All @@ -227,12 +229,12 @@ struct Normal:
auto sample_mean = ad::sum(util::counting_iterator<size_t>(0),
util::counting_iterator<size_t>(x_size),
[&](size_t i) {
return x.to_ad(ad_vars, i);
return x.to_ad(ad_vars, cache, i);
}) / ad::constant<dist_value_t>(x_size);
auto sample_variance = ad::sum(util::counting_iterator<size_t>(0),
util::counting_iterator<size_t>(x_size),
[&](size_t i) {
return ad::pow<2>(x.to_ad(ad_vars, i) - sample_mean);
return ad::pow<2>(x.to_ad(ad_vars, cache, i) - sample_mean);
}) / ad::constant<dist_value_t>(x_size);
return ad::if_else(
ad_sd > ad::constant(0.),
Expand All @@ -251,15 +253,15 @@ struct Normal:
{
assert(x.size() == mean_.size());
size_t x_size = x.size();
auto&& ad_sd = sd_.to_ad(ad_vars);
auto&& ad_sd = sd_.to_ad(ad_vars, cache);
return ad::if_else(
ad_sd > ad::constant(0.),
(ad::constant(-0.5) / ad::pow<2>(ad_sd))
* ad::sum(util::counting_iterator<size_t>(0),
util::counting_iterator<size_t>(x_size),
[&](size_t i) {
return ad::pow<2>(x.to_ad(ad_vars, i)
- mean_.to_ad(ad_vars, i));
return ad::pow<2>(x.to_ad(ad_vars, cache, i)
- mean_.to_ad(ad_vars, cache, i));
})
- (ad::constant<dist_value_t>(x_size) * ad::log(ad_sd)),
ad::constant(math::neg_inf<dist_value_t>)
Expand All @@ -282,6 +284,13 @@ struct Normal:
F = F()) const
{ return math::inf<value_t>; }

/**
 * Threads the cache offset through the sub-expressions in order:
 * first mean_, then sd_ (which receives the offset returned by mean_).
 *
 * @param idx  offset passed to mean_.set_cache_offset
 * @return     the value returned by sd_.set_cache_offset
 */
index_t set_cache_offset(index_t idx)
{
    index_t after_mean = mean_.set_cache_offset(idx);
    return sd_.set_cache_offset(after_mean);
}

private:
MeanType mean_;
SDType sd_;
Expand Down
25 changes: 17 additions & 8 deletions include/autoppl/expression/distribution/uniform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ struct Uniform: util::DistExprBase<Uniform<MinType, MaxType>>

using value_t = util::cont_param_t;
using base_t = util::DistExprBase<Uniform<MinType, MaxType>>;
using index_t = uint32_t;
using typename base_t::dist_value_t;

Uniform(const MinType& min,
Expand Down Expand Up @@ -125,15 +126,16 @@ struct Uniform: util::DistExprBase<Uniform<MinType, MaxType>>
*/
template <class VarType, class VecADVarType>
auto ad_log_pdf(const VarType& x,
const VecADVarType& vars) const
const VecADVarType& vars,
const VecADVarType& cache) const
{
// Case 1: x -> vec, min -> scl, max -> scl
if constexpr (util::is_vec_v<VarType> &&
util::is_scl_v<MinType> &&
util::is_scl_v<MaxType>)
{
auto&& ad_min = min_.to_ad(vars);
auto&& ad_max = max_.to_ad(vars);
auto&& ad_min = min_.to_ad(vars, cache);
auto&& ad_max = max_.to_ad(vars, cache);

// Subcase 1: x -> has no param
if constexpr (!VarType::has_param) {
Expand Down Expand Up @@ -165,8 +167,8 @@ struct Uniform: util::DistExprBase<Uniform<MinType, MaxType>>
util::counting_iterator<>(x.size()),
[&](auto i) {
return ad::if_else(
( (ad_min < x.to_ad(vars, i)) &&
(x.to_ad(vars, i) < ad_max) ),
( (ad_min < x.to_ad(vars, cache, i)) &&
(x.to_ad(vars, cache, i) < ad_max) ),
ad::constant<dist_value_t>(0),
ad::constant(math::neg_inf<dist_value_t>)
);
Expand All @@ -180,9 +182,9 @@ struct Uniform: util::DistExprBase<Uniform<MinType, MaxType>>
return ad::sum(util::counting_iterator<>(0),
util::counting_iterator<>(x.size()),
[&](auto i) {
auto&& ad_x = x.to_ad(vars, i);
auto&& ad_min = min_.to_ad(vars, i);
auto&& ad_max = max_.to_ad(vars, i);
auto&& ad_x = x.to_ad(vars, cache, i);
auto&& ad_min = min_.to_ad(vars, cache, i);
auto&& ad_max = max_.to_ad(vars, cache, i);
return ad::if_else(
(ad_min < ad_x) && (ad_x < ad_max),
-ad::log(ad_max - ad_min),
Expand All @@ -206,6 +208,13 @@ struct Uniform: util::DistExprBase<Uniform<MinType, MaxType>>
F f = F()) const
{ return max_.value(pvalues, i, f); }

/**
 * Threads the cache offset through the sub-expressions in order:
 * first min_, then max_ (which receives the offset returned by min_).
 *
 * @param idx  offset passed to min_.set_cache_offset
 * @return     the value returned by max_.set_cache_offset
 */
index_t set_cache_offset(index_t idx)
{
    index_t after_min = min_.set_cache_offset(idx);
    return max_.set_cache_offset(after_min);
}

private:
MinType min_;
MaxType max_;
Expand Down
13 changes: 13 additions & 0 deletions include/autoppl/expression/expr_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <autoppl/expression/variable/data.hpp>
#include <autoppl/expression/variable/param.hpp>
#include <autoppl/expression/variable/binop.hpp>
#include <autoppl/expression/variable/dot.hpp>
#include <autoppl/expression/distribution/uniform.hpp>
#include <autoppl/expression/distribution/normal.hpp>
#include <autoppl/expression/distribution/bernoulli.hpp>
Expand Down Expand Up @@ -333,4 +334,16 @@ inline constexpr auto operator/(LHSType&& lhs, RHSType&& rhs)
std::forward<RHSType>(rhs));
}

/**
* Builds a dot product expression for two expressions.
*/
template <class LHSVarExprType
, class RHSVarExprType>
inline constexpr auto dot(const LHSVarExprType& lhs,
const RHSVarExprType& rhs)
{
return expr::DotNode<LHSVarExprType,
RHSVarExprType>(lhs, rhs);
}

} // namespace ppl
9 changes: 7 additions & 2 deletions include/autoppl/expression/model/eq_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ struct EqNode: util::ModelExprBase<EqNode<VarType, DistType>>
util::dist_expr_traits<dist_t>::is_disc_v),
PPL_VAR_DIST_CONT_DISC_MATCH);

using dist_value_t = typename
util::dist_expr_traits<dist_t>::dist_value_t;

EqNode(const var_t& var,
const dist_t& dist) noexcept
: var_{var}
Expand Down Expand Up @@ -86,11 +89,13 @@ struct EqNode: util::ModelExprBase<EqNode<VarType, DistType>>
* @param ad_vars container of AD variables that correspond to parameters.
*/
template <class VecADVarType>
auto ad_log_pdf(const VecADVarType& ad_vars) const
{ return dist_.ad_log_pdf(get_variable(), ad_vars); }
auto ad_log_pdf(const VecADVarType& ad_vars,
const VecADVarType& cache) const
{ return dist_.ad_log_pdf(get_variable(), ad_vars, cache); }

// Mutable access to the stored variable (var_).
var_t& get_variable() { return var_; }
// Read-only access to the stored variable (var_).
const var_t& get_variable() const { return var_; }
// Mutable access to the stored distribution (dist_).
dist_t& get_distribution() { return dist_; }
// Read-only access to the stored distribution (dist_).
const dist_t& get_distribution() const { return dist_; }

private:
Expand Down
Loading

0 comments on commit a79fee1

Please sign in to comment.