diff --git a/benchmark/normal_two_prior_distribution.cpp b/benchmark/normal_two_prior_distribution.cpp index bf121667..8588de17 100644 --- a/benchmark/normal_two_prior_distribution.cpp +++ b/benchmark/normal_two_prior_distribution.cpp @@ -1,7 +1,11 @@ #include +#include #include -#include #include "benchmark_utils.hpp" +#include +#include +#include +#include namespace ppl { @@ -11,7 +15,7 @@ static void BM_NormalTwoPrior(benchmark::State& state) { std::normal_distribution n(0.0, 1.0); std::mt19937 gen(0); - ppl::Data y; + ppl::Data y; ppl::Param lambda1, lambda2, sigma; auto model = ( @@ -22,7 +26,7 @@ static void BM_NormalTwoPrior(benchmark::State& state) { ); for (size_t i = 0; i < n_data; ++i) { - y.observe(n(gen)); + y.push_back(n(gen)); } std::array l1_storage, l2_storage, s_storage; @@ -30,8 +34,12 @@ static void BM_NormalTwoPrior(benchmark::State& state) { lambda2.set_storage(l2_storage.data()); sigma.set_storage(s_storage.data()); + ppl::NUTSConfig<> config; + config.n_samples = n_samples; + config.warmup = n_samples; + for (auto _ : state) { - ppl::nuts(model); + ppl::nuts(model, config); } std::cout << "l1: " << sample_average(l1_storage) << std::endl; diff --git a/benchmark/normal_two_prior_distribution_stan.py b/benchmark/normal_two_prior_distribution_stan.py index 99f7da46..129b96f8 100644 --- a/benchmark/normal_two_prior_distribution_stan.py +++ b/benchmark/normal_two_prior_distribution_stan.py @@ -7,7 +7,7 @@ stan_file = 'normal_two_prior_distribution_stan.stan' sm = CmdStanModel(stan_file=stan_file) -fit = sm.sample(data=cool_dat, chains=4, cores=1, +fit = sm.sample(data=cool_dat, chains=1, cores=1, iter_warmup=1000, iter_sampling=1000, thin=1, max_treedepth=10, metric='diag', adapt_engaged=True, output_dir='.') diff --git a/benchmark/regression_autoppl.cpp b/benchmark/regression_autoppl.cpp index 35892e58..512bfabc 100644 --- a/benchmark/regression_autoppl.cpp +++ b/benchmark/regression_autoppl.cpp @@ -7,8 +7,9 @@ #include #include -#include -#include +#include +#include +#include #include #include "benchmark_utils.hpp" @@ -34,7 +35,7 @@ static void BM_Regression(benchmark::State& state) { std::array headers = {"Life expectancy", "Alcohol", "HIV/AIDS", "GDP"}; - std::unordered_map> data; + std::unordered_map> data; std::unordered_map> params; std::array, 4> storage; @@ -47,7 +48,7 @@ static void BM_Regression(benchmark::State& state) { auto it = headers.begin(); std::stringstream s(line); while (s >> value) { - data[*it].observe(value); + data[*it].push_back(value); ++it; } } diff --git a/benchmark/regression_autoppl_2.cpp b/benchmark/regression_autoppl_2.cpp index d6f8158f..90e52437 100644 --- a/benchmark/regression_autoppl_2.cpp +++ b/benchmark/regression_autoppl_2.cpp @@ -7,8 +7,9 @@ #include #include -#include -#include +#include +#include +#include #include #include "benchmark_utils.hpp" @@ -23,7 +24,7 @@ static void BM_Regression(benchmark::State& state) { std::array headers = {"b", "x1", "x2", "x3"}; - std::unordered_map> data; + std::unordered_map> data; std::unordered_map> params; std::array, 4> storage; @@ -37,10 +38,10 @@ static void BM_Regression(benchmark::State& state) { double x1 = n1(gen); double x2 = n2(gen); double x3 = n3(gen); - data[headers[1]].observe(x1); - data[headers[2]].observe(x2); - data[headers[3]].observe(x3); - data["y"].observe(x1 * 1.4 + x2 * 2. + x3 * 0.32 + eps(gen)); + data[headers[1]].push_back(x1); + data[headers[2]].push_back(x2); + data[headers[3]].push_back(x3); + data["y"].push_back(x1 * 1.4 + x2 * 2. + x3 * 0.32 + eps(gen)); } // resize each storage and bind with param diff --git a/docs/design/README.md b/docs/design/README.md new file mode 100644 index 00000000..add4ddb3 --- /dev/null +++ b/docs/design/README.md @@ -0,0 +1,157 @@ +# Design Overview + +## Example + +```cpp +DataView, ppl::vec> x(raw_x); +// Data x({...}); // another option +Param l1; +ParamFixed l2; +// Param l2(3); // another option +auto model = ( + l1 |= normal(0., 1.), + l2 |= normal(l1, 2.), + x |= normal(l2[0] * l2[1] - l2[2], 1.) +); +l1.storage(ptr); +l2.storage(ptr, i); +ppl::nuts(model); +``` + +- `l1` is a scalar that is standard normally distributed +- `l2` is a vector of size 3 that is each independently ~ N(l1, 2) +- `x` is a vector of data ~ N(l2[0]*l2[1]-l2[2], 1.) + - `l2` is subscriptable + +## Variable + +A variable really is only satisfied by Param, ParamView, Data, DataView, or alike. +Every first variable has a unique ID or views a unique ID. +This is so that we have a way to know which variable that gets referenced +in the model is pointing to the "same" entity. +This can be useful when checking correct construction of model such as: +- no variable gets assigned a distribution more than once +- no variable gets assigned a distribution, which references the same variable +- no distribution uses variables that reference variables below it + +### Param + +A Param should be a variable expression and also a variable. +The model will only be built using ParamView since Param may own values +that the model should only view. + +If Param is multi-dimensional (vec, mat), size of the shape must be known +at construction and cannot change. +The model may reference old size values if changed. +Logically, a parameter denoted by a symbol was defined from fathoming a model. +If it is immediately used in a different model, it's most likely that the parameter +represents the same kind of quantity, but assigned to a different distribution. + +## Concepts + +### model_expr + +Implements: + +```cpp +template +void traverse(F&& elt_f); // + const version + +template +void traverse(F1&& elt_f, F2&& combine_f); // + const version + +/*...*/ pdf() const; +/*...*/ log_pdf() const; + +template +/*...*/ ad_log_pdf(const MapType& map, + const VecType& vars) const; +``` + +- map is expected to be a hashmap of: + ``` + addresses of unique parameters (const void*) -> + begin idx of corresponding vector of vars + ``` +- Ex. + ``` + (mu |= normal(0,1), s |= normal(0,1), x |= normal(mu, s)) + addr(mu) -> 0 + addr(s) -> 1 + AD Var vec: [v1, v2] + ``` + +## Expression Nodes + +The core of AutoPPL is how we construct expressions. +These expressions and their interaction define a language to express model construction. + +#### Glue Node + +``` +glue_node = (model_expr, model_expr); +``` + +##### Sketch of Interface + +```cpp +struct GlueNode +{ + traverse(elt_f) + traverse(elt_f, combine_f) + pdf() + log_pdf() + ad_log_pdf(map, vars) +}; +``` + +Example: + +```cpp +// apply log_pdf to get and add them all +double lgpdf = model.traverse(log_pdf, add); + +// apply ad_log_pdf to get AD expr and add them all +// if ad_log_pdf or add requires extra parameters, lambdafy them: +// [&](auto& elt) {return ad_log_pdf(elt, other_params...);} +auto ad_expr = model.traverse(ad_log_pdf, add); + +// get each "unique quantity" and add them to the mapping +model.traverse(update_map); +``` + +#### Eq Node + +``` +eq_node = (quantity_expr |= dist_expr); +``` + +An eq expression relates a quantity with a distribution. +While the arguments can be generalized further, +we're most motivated by the example when quantity is a parameter/data +of either variable/vector/mat (vvm) form and dist_expr is one such as normal distribution. + +##### Sketch of Interface + +```cpp +struct EqNode +{ + traverse(eq_f); + traverse(eq_f, combine_f); + pdf(); + log_pdf(); + ad_log_pdf(map, vars); + get_variable(); + get_distribution(); +}; +``` + +- map is the mapping of addresses of params/data to corresponding + index of a vector of AD vectors. + - Ex. + ``` + mu |= normal(0,1), x |= normal(mu, 1) + addr(mu) -> 0 + addr(x) -> 1 + AD Var vec: [v1, v2] + ``` diff --git a/docs/example/normal_posterior_mean_stddev.cpp b/docs/example/normal_posterior_mean_stddev.cpp index c351fb9e..edd8c22f 100644 --- a/docs/example/normal_posterior_mean_stddev.cpp +++ b/docs/example/normal_posterior_mean_stddev.cpp @@ -5,7 +5,7 @@ int main() { std::array mu_samples, sigma_samples; - ppl::Data x {1.0, 1.5, 1.7, 1.2, 1.5}; + ppl::Data x {1.0, 1.5, 1.7, 1.2, 1.5}; ppl::Param mu {mu_samples.data()}; ppl::Param sigma {sigma_samples.data()}; diff --git a/include/autoppl/autoppl.hpp b/include/autoppl/autoppl.hpp index 1a58b9ab..38a183f8 100644 --- a/include/autoppl/autoppl.hpp +++ b/include/autoppl/autoppl.hpp @@ -5,11 +5,10 @@ #include "expression/distribution/normal.hpp" #include "expression/model/eq_node.hpp" #include "expression/model/glue_node.hpp" -#include "expression/model/model_utils.hpp" #include "expression/variable/binop.hpp" +#include "expression/variable/data.hpp" +#include "expression/variable/param.hpp" #include "expression/variable/constant.hpp" -#include "expression/variable/variable_viewer.hpp" +#include "expression/expr_builder.hpp" #include "mcmc/mh.hpp" #include "mcmc/hmc/nuts/nuts.hpp" -#include "expr_builder.hpp" -#include "variable.hpp" diff --git a/include/autoppl/expression/distribution/bernoulli.hpp b/include/autoppl/expression/distribution/bernoulli.hpp index 3050ba9f..7e33e870 100644 --- a/include/autoppl/expression/distribution/bernoulli.hpp +++ b/include/autoppl/expression/distribution/bernoulli.hpp @@ -1,71 +1,125 @@ #pragma once #include #include -#include -#include +#include +#include +#include +#include +#include + +#define PPL_BERNOULLI_PARAM_DIM \ + "Bernoulli distribution probability must either be a scalar or vector. " \ namespace ppl { namespace expr { +namespace details { + +/** + * Checks whether prob has proper dimensions. + * Must be proper shape and cannot be matrix. + */ +template +struct bern_valid_param_dim +{ + static constexpr bool value = + util::is_shape_v && + !util::is_mat_v; +}; -#if __cplusplus <= 201703L -template -#else -template -#endif -struct Bernoulli : util::DistExpr> +/** + * Checks if var, prob have proper relative dimensions. + * Currently, we only allow up to vector dimension (no matrix). + */ +template +struct bern_valid_dim { + static constexpr bool value = + util::is_shape_v && + ( + (util::is_scl_v && + util::is_scl_v) || + (util::is_vec_v && + bern_valid_param_dim::value) + ); +}; + +template +inline constexpr bool bern_valid_param_dim_v = + bern_valid_param_dim::value; + +template +inline constexpr bool bern_valid_dim_v = + bern_valid_dim::value; -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_expr_v); -#endif +} // namespace details + +template +struct Bernoulli : util::DistExprBase> +{ + static_assert(util::is_var_expr_v); + static_assert(details::bern_valid_param_dim_v, + PPL_DIST_DIM_MISMATCH + PPL_BERNOULLI_PARAM_DIM + ); using value_t = util::disc_param_t; - using param_value_t = typename util::var_expr_traits::value_t; - using base_t = util::DistExpr>; - using dist_value_t = typename base_t::dist_value_t; - using base_t::pdf; - using base_t::log_pdf; + using param_value_t = typename util::var_expr_traits::value_t; + using base_t = util::DistExprBase>; + using typename base_t::dist_value_t; - Bernoulli(p_type p) + // TODO: const ref? + Bernoulli(PType p) : p_{p} {} - template - value_t sample(GeneratorType& gen) const + template + dist_value_t pdf(const VarType& x, + const PVecType& pvalues) const { - std::bernoulli_distribution dist(p()); - return dist(gen); + static_assert(util::is_var_v); + static_assert(details::bern_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + return pdf_indep([&](size_t i) { + return math::bernoulli_pdf( + x.value(pvalues, i), + p_.value(pvalues, i)); + }, x.size()); } - dist_value_t pdf(value_t x, size_t index=0) const + template + dist_value_t log_pdf(const VarType& x, + const PVecType& pvalues, + F f = F()) const { - if (x == 1) return p(index); - else if (x == 0) return 1. - p(); - else return 0.0; + static_assert(util::is_var_v); + static_assert(details::bern_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + return pdf_indep([&](size_t i) { + return math::bernoulli_log_pdf( + x.value(pvalues, i, f), + p_.value(pvalues, i, f)); + }, x.size()); } - dist_value_t log_pdf(value_t x, size_t index=0) const - { - if (x == 1) return std::log(p(index)); - else if (x == 0) return std::log(1. - p(index)); - else return std::numeric_limits::lowest(); - } - - param_value_t p(size_t index=0) const - { - return std::max( - std::min( - p_.get_value(index), - static_cast(max()) - ), - static_cast(min()) - ); - } + template + value_t min(const PVecType&, + size_t=0, + F = F()) const + { return 0; } - value_t min() const { return 0; } - value_t max() const { return 1; } + template + value_t max(const PVecType&, + size_t=0, + F = F()) const + { return 1; } private: - p_type p_; + PType p_; }; } // namespace expr diff --git a/include/autoppl/expression/distribution/dist_utils.hpp b/include/autoppl/expression/distribution/dist_utils.hpp new file mode 100644 index 00000000..4af62020 --- /dev/null +++ b/include/autoppl/expression/distribution/dist_utils.hpp @@ -0,0 +1,49 @@ +#pragma once +#include + +#define PPL_DIST_DIM_MISMATCH \ + "Unsupported variable and/or distribution parameter dimensions. " +#define PPL_PDF_INVOCABLE \ + "Log-pdf and pdf functors must be invocable with a single size_t argument. " + +namespace ppl { +namespace expr { + +/** + * Computes joint log pdf defined by size number of independent variables. + * log_pdf(i) computes the log pdf of ith variable. + */ +template +inline constexpr auto log_pdf_indep(LogPDFType&& log_pdf, + size_t size) +{ + static_assert(std::is_invocable_v, + PPL_PDF_INVOCABLE); + using dist_value_t = std::decay_t; + dist_value_t value = 0.0; + for (size_t i = 0ul; i < size; ++i) { + value += log_pdf(i); + } + return value; +} + +/** + * Computes joint pdf defined by size number of independent variables. + * pdf(i) computes the pdf of ith variable. + */ +template +inline constexpr auto pdf_indep(PDFType&& pdf, + size_t size) +{ + static_assert(std::is_invocable_v, + PPL_PDF_INVOCABLE); + using dist_value_t = std::decay_t; + dist_value_t value = 1.0; + for (size_t i = 0ul; i < size; ++i) { + value *= pdf(i); + } + return value; +} + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/distribution/normal.hpp b/include/autoppl/expression/distribution/normal.hpp index 5f41050f..b1c60e8e 100644 --- a/include/autoppl/expression/distribution/normal.hpp +++ b/include/autoppl/expression/distribution/normal.hpp @@ -1,88 +1,289 @@ #pragma once #include #include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include +#include -// MSVC does not seem to support M_PI -#ifndef M_PI -#define M_PI 3.14159265358979323846 -#endif +#define PPL_NORMAL_PARAM_DIM \ + "Normal distribution mean must either be a scalar or vector " \ + "and standard deviation must be scalar. " namespace ppl { namespace expr { +namespace details { -#if __cplusplus <= 201703L -template -#else -template -#endif -struct Normal : util::DistExpr> +/** + * Checks case 1 of whether mean, and sd have proper relative dimensions. + * Case 1: mean, sd are all scalars. + */ +template +struct normal_valid_param_dim_case_1 { + static constexpr bool value = + util::is_shape_v && + util::is_shape_v && + util::is_scl_v && + util::is_scl_v; +}; + +/** + * Checks case 2 of whether mean, and sd have proper relative dimensions. + * Case 2: mean, sd are both non-matrices. + */ +template +struct normal_valid_param_dim_case_2 +{ + static constexpr bool value = + util::is_shape_v && + util::is_shape_v && + !util::is_mat_v && + util::is_scl_v; +}; + +/** + * Checks if var, mean, and sd have proper relative dimensions. + * Currently, we only allow up to vector dimension (no matrix). + */ +template +struct normal_valid_dim +{ + static constexpr bool value = + util::is_shape_v && + ( + (util::is_scl_v && + normal_valid_param_dim_case_1::value) || + (util::is_vec_v && + normal_valid_param_dim_case_2::value) + ); +}; + +template +inline constexpr bool normal_valid_param_dim_case_1_v = + normal_valid_param_dim_case_1::value; + +template +inline constexpr bool normal_valid_param_dim_case_2_v = + normal_valid_param_dim_case_2::value; -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_expr_v); - static_assert(util::assert_is_var_expr_v); -#endif +template +inline constexpr bool normal_valid_dim_v = + normal_valid_dim::value; + +} // namespace details + +template +struct Normal: + util::DistExprBase> +{ + static_assert(util::is_var_expr_v); + static_assert(util::is_var_expr_v); + static_assert(details::normal_valid_param_dim_case_2_v, + PPL_DIST_DIM_MISMATCH + PPL_NORMAL_PARAM_DIM + ); using value_t = util::cont_param_t; - using base_t = util::DistExpr>; - using dist_value_t = typename base_t::dist_value_t; - using base_t::pdf; - using base_t::log_pdf; + using base_t = util::DistExprBase>; + using typename base_t::dist_value_t; - Normal(mean_type mean, stddev_type stddev) - : mean_{mean}, stddev_{stddev} + // TODO: const ref? + Normal(MeanType mean, SDType sd) + : mean_{mean}, sd_{sd} {} - template - value_t sample(GeneratorType& gen) const { - std::normal_distribution dist(mean(), stddev()); - return dist(gen); + // TODO: size check on x, mean, sd? + template + dist_value_t pdf(const VarType& x, + const PVecType& pvalues) const + { + static_assert(util::is_var_v); + static_assert(details::normal_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + return pdf_indep([&](size_t i) { + return math::normal_pdf( + x.value(pvalues, i), + mean_.value(pvalues, i), + sd_.value(pvalues, i)); + }, x.size()); } - dist_value_t pdf(value_t x, size_t index=0) const + // TODO: size check on x, mean, sd? + template + dist_value_t log_pdf(const VarType& x, + const PVecType& pvalues, + F f = F()) const { - dist_value_t z_score = (x - mean(index)) / stddev(index); - return std::exp(- 0.5 * z_score * z_score) / (stddev(index) * std::sqrt(2 * M_PI)); - } - - dist_value_t log_pdf(value_t x, size_t index=0) const - { - dist_value_t z_score = (x - mean(index)) / stddev(index); - return -0.5 * ((z_score * z_score) + std::log(stddev(index) * stddev(index) * 2 * M_PI)); + static_assert(util::is_var_v); + static_assert(details::normal_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + return log_pdf_indep([&](size_t i) { + return math::normal_log_pdf( + x.value(pvalues, i, f), + mean_.value(pvalues, i, f), + sd_.value(pvalues, i, f)); + }, x.size()); } /** - * Up to constant addition, returns ad expression of log pdf + * Up to constant addition, returns AD expression of log pdf. + * TODO: save mean and sd in separate variable? */ - template - auto ad_log_pdf(const ADVarType& x, - const VecRefType& keys, - const VecADVarType& vars, - size_t idx = 0) const + template + auto ad_log_pdf(const VarType& x, + const VecADVarType& ad_vars) const { - auto&& ad_mean_expr = mean_.get_ad(keys, vars, idx); - auto&& ad_stddev_expr = stddev_.get_ad(keys, vars, idx); - return ad::if_else( - ad_stddev_expr > ad::constant(0.), - (ad::constant(-0.5) * - ad::pow<2>((x - ad_mean_expr) / ad_stddev_expr)) - - ad::log(ad_stddev_expr), - ad::constant(std::numeric_limits::lowest()) - ); + static_assert(util::is_var_v); + static_assert(details::normal_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + + // Case 1: x -> scalar, mean -> scalar, sd -> scalar + if constexpr (util::is_scl_v && + util::is_scl_v && + util::is_scl_v) + { + auto&& ad_x = x.to_ad(ad_vars); + auto&& ad_mean = mean_.to_ad(ad_vars); + auto&& ad_sd = sd_.to_ad(ad_vars); + + // Subcase 1: sd -> has no param + if constexpr (!SDType::has_param) { + return ad::if_else( + ad_sd > ad::constant(0.), + ( (ad::constant(-0.5) / ad::pow<2>(ad_sd)) * + ad::pow<2>(ad_x - ad_mean) ) + - ad::log(ad_sd), + ad::constant(math::neg_inf) + ); + } + + // Subcase 2: x -> has param or mean -> has param, sd -> has param + else if constexpr (VarType::has_param || MeanType::has_param) { + return ad::if_else( + ad_sd > ad::constant(0.), + (ad::constant(-0.5) * + ad::pow<2>( (ad_x - ad_mean) / ad_sd )) + - ad::log(ad_sd), + ad::constant(math::neg_inf) + ); + } + + // Subcase 3: x-> has no param, mean -> has no param, sd -> has param + else { + return ad::if_else( + ad_sd > ad::constant(0.), + ( ad::constant(-0.5) * ad::pow<2>(ad_x - ad_mean) ) + / ad::pow<2>(ad_sd) + - ad::log(ad_sd), + ad::constant(math::neg_inf) + ); + } + } + + // Case 2: x -> vec, mean -> scalar, sd -> scalar + else if constexpr (util::is_vec_v && + util::is_scl_v && + util::is_scl_v) + { + size_t x_size = x.size(); + auto&& ad_mean = mean_.to_ad(ad_vars); + auto&& ad_sd = sd_.to_ad(ad_vars); + + // Subcase 1: x -> has param + if constexpr (VarType::has_param) { + return ad::if_else( + ad_sd > ad::constant(0.), + (ad::constant(-0.5) / ad::pow<2>(ad_sd)) + * ad::sum(util::counting_iterator(0), + util::counting_iterator(x_size), + [&](size_t i) { + return ad::pow<2>(x.to_ad(ad_vars, i) - ad_mean); + }) + - (ad::constant(x_size) * ad::log(ad_sd)), + ad::constant(math::neg_inf) + ); + } + + // Subcase 2: x -> has no param + // Note: this is HUGE optimization here + else { + auto sample_mean = ad::sum(util::counting_iterator(0), + util::counting_iterator(x_size), + [&](size_t i) { + return x.to_ad(ad_vars, i); + }) / ad::constant(x_size); + auto sample_variance = ad::sum(util::counting_iterator(0), + util::counting_iterator(x_size), + [&](size_t i) { + return ad::pow<2>(x.to_ad(ad_vars, i) - sample_mean); + }) / ad::constant(x_size); + return ad::if_else( + ad_sd > ad::constant(0.), + (ad::constant(-0.5 * x_size) / ad::pow<2>(ad_sd)) + * ( ad::pow<2>(ad_mean - sample_mean) + sample_variance ) + - ( ad::constant(x_size) * ad::log(ad_sd) ), + ad::constant(math::neg_inf) + ); + } + } + + // Case 3: x -> vector, mean -> vector, sd -> scalar + else if constexpr (util::is_vec_v && + util::is_vec_v && + util::is_scl_v) + { + assert(x.size() == mean_.size()); + size_t x_size = x.size(); + auto&& ad_sd = sd_.to_ad(ad_vars); + return ad::if_else( + ad_sd > ad::constant(0.), + (ad::constant(-0.5) / ad::pow<2>(ad_sd)) + * ad::sum(util::counting_iterator(0), + util::counting_iterator(x_size), + [&](size_t i) { + return ad::pow<2>(x.to_ad(ad_vars, i) + - mean_.to_ad(ad_vars, i)); + }) + - (ad::constant(x_size) * ad::log(ad_sd)), + ad::constant(math::neg_inf) + ); + } } + + template + value_t min(const PVecType&, + size_t=0, + F = F()) const + { return math::neg_inf; } - auto mean(size_t index=0) const { return mean_.get_value(index);} - auto stddev(size_t index=0) const { return stddev_.get_value(index);} - value_t min() const { return std::numeric_limits::lowest(); } - value_t max() const { return std::numeric_limits::max(); } + + template + value_t max(const PVecType&, + size_t=0, + F = F()) const + { return math::inf; } private: - mean_type mean_; // TODO enforce that these are at least descended from a Param class. - stddev_type stddev_; + MeanType mean_; // TODO enforce that these are at least descended from a Param class. + SDType sd_; }; } // namespace expr diff --git a/include/autoppl/expression/distribution/uniform.hpp b/include/autoppl/expression/distribution/uniform.hpp index 83db4ca2..05dee30d 100644 --- a/include/autoppl/expression/distribution/uniform.hpp +++ b/include/autoppl/expression/distribution/uniform.hpp @@ -1,80 +1,211 @@ #pragma once #include -#include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include + +#define PPL_UNIFORM_PARAM_DIM \ + "Uniform parameters min and max must be either scalar or vector. " namespace ppl { namespace expr { +namespace details { + +/** + * Checks whether min, max have proper relative dimensions. + * Must be proper shapes and cannot be matrices. + */ +template +struct uniform_valid_param_dim +{ + static constexpr bool value = + util::is_shape_v && + util::is_shape_v && + !util::is_mat_v && + !util::is_mat_v; +}; -#if __cplusplus <= 201703L -template -#else -template -#endif -struct Uniform : util::DistExpr> +/** + * Checks if var, min, max have proper relative dimensions. + * Currently, we only allow up to vector dimension (no matrix). + */ +template +struct uniform_valid_dim { + static constexpr bool value = + util::is_shape_v && + ( + (util::is_scl_v && + util::is_scl_v && + util::is_scl_v) || + (util::is_vec_v && + uniform_valid_param_dim::value) + ); +}; + +template +inline constexpr bool uniform_valid_param_dim_v = + uniform_valid_param_dim::value; + +template +inline constexpr bool uniform_valid_dim_v = + uniform_valid_dim::value; + +} // namespace details -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_expr_v); - static_assert(util::assert_is_var_expr_v); -#endif +template +struct Uniform: util::DistExprBase> +{ + static_assert(util::is_var_expr_v); + static_assert(util::is_var_expr_v); + static_assert(details::uniform_valid_param_dim_v, + PPL_DIST_DIM_MISMATCH + PPL_UNIFORM_PARAM_DIM + ); using value_t = util::cont_param_t; - using base_t = util::DistExpr>; - using dist_value_t = typename base_t::dist_value_t; - using base_t::pdf; - using base_t::log_pdf; + using base_t = util::DistExprBase>; + using typename base_t::dist_value_t; - Uniform(min_type min, max_type max) + // TODO: const ref? + Uniform(MinType min, MaxType max) : min_{min}, max_{max} {} - // TODO: tag this class as "TriviallySamplable"? - template - value_t sample(GeneratorType& gen) const - { - std::uniform_real_distribution dist(min(), max()); - return dist(gen); - } - - dist_value_t pdf(value_t x, size_t index=0) const + // TODO: size check on x, mean, sd? + template + dist_value_t pdf(const VarType& x, + const PVecType& pvalues) const { - return (min(index) < x && x < max(index)) ? 1. / (max(index) - min(index)) : 0; + static_assert(util::is_var_v); + static_assert(details::uniform_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + return pdf_indep([&](size_t i) { + return math::uniform_pdf( + x.value(pvalues, i), + min_.value(pvalues, i), + max_.value(pvalues, i)); + }, x.size()); } - dist_value_t log_pdf(value_t x, size_t index=0) const + // TODO: size check on x, mean, sd? + template + dist_value_t log_pdf(const VarType& x, + const PVecType& pvalues, + F f = F()) const { - return (min(index) < x && x < max(index)) ? - -std::log(max(index) - min(index)) : - std::numeric_limits::lowest(); + static_assert(util::is_var_v); + static_assert(details::uniform_valid_dim_v, + PPL_DIST_DIM_MISMATCH); + return log_pdf_indep([&](size_t i) { + return math::uniform_log_pdf( + x.value(pvalues, i, f), + min_.value(pvalues, i, f), + max_.value(pvalues, i, f)); + }, x.size()); } /** * Up to constant addition, returns ad expression of log pdf */ - template - auto ad_log_pdf(const ADVarType& x, - const VecRefType& keys, - const VecADVarType& vars, - size_t idx = 0) const + template + auto ad_log_pdf(const VarType& x, + const VecADVarType& vars) const { - auto&& ad_min_expr = min_.get_ad(keys, vars, idx); - auto&& ad_max_expr = max_.get_ad(keys, vars, idx); - return ad::if_else( - ((ad_min_expr < x) && (x < ad_max_expr)), - -ad::log(ad_max_expr - ad_min_expr), - ad::constant(std::numeric_limits::lowest()) - ); + // Case 1: x -> vec, min -> scl, max -> scl + if constexpr (util::is_vec_v && + util::is_scl_v && + util::is_scl_v) + { + auto&& ad_min = min_.to_ad(vars); + auto&& ad_max = max_.to_ad(vars); + + // Subcase 1: x -> has no param + if constexpr (!VarType::has_param) { + // Note: value can be used instead of to_ad because + // vars will be ignored by anything that does not have param + // TODO: wait for support for ad::min for constants + auto x_min = math::min(util::counting_iterator<>(0), + util::counting_iterator<>(x.size()), + [&](auto i) { return x.value(vars, i); }); + auto x_max = math::max(util::counting_iterator<>(0), + util::counting_iterator<>(x.size()), + [&](auto i) { return x.value(vars, i); }); + return ad::if_else( + ((ad_min < ad::constant(x_min)) && + (ad::constant(x_max) < ad_max)), + -ad::constant(x.size()) * + ad::log(ad_max - ad_min), + ad::constant(math::neg_inf) + ); + } + + // Subcase 2: x -> has param + else { + return (-ad::constant(x.size()) * + ad::log(ad_max - ad_min)) + + ad::sum(util::counting_iterator<>(0), + util::counting_iterator<>(x.size()), + [&](auto i) { + return ad::if_else( + ( (ad_min < x.to_ad(vars, i)) && + (x.to_ad(vars, i) < ad_max) ), + ad::constant(0), + ad::constant(math::neg_inf) + ); + } + ); + } + } + + // Case 2: all other cases + else { + return ad::sum(util::counting_iterator<>(0), + util::counting_iterator<>(x.size()), + [&](auto i) { + auto&& ad_x = x.to_ad(vars, i); + auto&& ad_min = min_.to_ad(vars, i); + auto&& ad_max = max_.to_ad(vars, i); + return ad::if_else( + (ad_min < ad_x) && (ad_x < ad_max), + -ad::log(ad_max - ad_min), + ad::constant(math::neg_inf) + ); + }); + } } - value_t min(size_t index=0) const { return min_.get_value(index); } - value_t max(size_t index=0) const { return max_.get_value(index); } + template + value_t min(const PVecType& pvalues, + size_t i=0, + F f = F()) const + { return min_.value(pvalues, i, f); } + + template + value_t max(const PVecType& pvalues, + size_t i=0, + F f = F()) const + { return max_.value(pvalues, i, f); } private: - min_type min_; // TODO enforce that these are at least descended from a Param class. - max_type max_; + MinType min_; // TODO enforce that these are at least descended from a Param class. + MaxType max_; }; } // namespace expr diff --git a/include/autoppl/expr_builder.hpp b/include/autoppl/expression/expr_builder.hpp similarity index 72% rename from include/autoppl/expr_builder.hpp rename to include/autoppl/expression/expr_builder.hpp index 5f062f0f..3adf37dd 100644 --- a/include/autoppl/expr_builder.hpp +++ b/include/autoppl/expression/expr_builder.hpp @@ -2,10 +2,10 @@ #include #include #include -#include #include +#include +#include #include -#include #include #include #include @@ -32,70 +32,96 @@ namespace details { * Assumes each condition is non-overlapping. */ -#if __cplusplus <= 201703L - template struct convert_to_param {}; +// Convert from param to param viewer template struct convert_to_param> > - > + std::enable_if_t< + util::is_param_v> && + util::is_scl_v> + > > { - using type = expr::VariableViewer>; +private: + using raw_t = std::decay_t; + using pointer_t = typename + util::param_traits::pointer_t; +public: + using type = ppl::ParamView; }; template -struct convert_to_param> > - > +struct convert_to_param> && + util::is_vec_v> + > > { - using type = expr::Constant>; +private: + using raw_t = std::decay_t; + using vec_t = typename + util::param_traits::vec_t; +public: + using type = ppl::ParamView; }; +// Convert from data to data viewer template -struct convert_to_param> > - > +struct convert_to_param> && + util::is_scl_v> + > > { - using type = T; +private: + using raw_t = std::decay_t; + using value_t = typename + util::data_traits::value_t; +public: + using type = ppl::DataView; }; -#else - template -struct convert_to_param; - -template -requires util::var> -struct convert_to_param +struct convert_to_param> && + util::is_vec_v> + > > { - using type = expr::VariableViewer>; +private: + using raw_t = std::decay_t; + using vec_t = typename + util::data_traits::vec_t; +public: + using type = ppl::DataView; }; +// Convert arithmetic types to Constants template -requires std::is_arithmetic_v> -struct convert_to_param +struct convert_to_param> > + > { using type = expr::Constant>; }; +// Convert variable expressions (not variables) into itself (no change) template -requires util::var_expr> -struct convert_to_param +struct convert_to_param> && + !util::is_var_v> > + > { using type = T; }; -#endif - template using convert_to_param_t = typename convert_to_param::type; -#if __cplusplus <= 201703L - /** * Checks if valid distribution parameter: * - can be arithmetic @@ -121,25 +147,6 @@ inline constexpr bool is_not_both_arithmetic_v = std::is_arithmetic_v>) ; -#else - -template -concept valid_dist_param = - std::is_arithmetic_v> || - (util::var> && - !std::is_rvalue_reference_v && - !std::is_const_v>) || - (util::var_expr>) - ; - -template -concept not_both_arithmetic = - !(std::is_arithmetic_v> && - std::is_arithmetic_v>) - ; - -#endif - } // namespace details /** @@ -147,7 +154,6 @@ concept not_both_arithmetic = * are both valid continuous distribution parameter types. * See var_expr.hpp for more information. */ -#if __cplusplus <= 201703L template && @@ -155,12 +161,6 @@ template > inline constexpr auto uniform(MinType&& min_expr, MaxType&& max_expr) -#else -template -inline constexpr auto uniform(MinType&& min_expr, - MaxType&& max_expr) -#endif { using min_t = details::convert_to_param_t; using max_t = details::convert_to_param_t; @@ -176,7 +176,6 @@ inline constexpr auto uniform(MinType&& min_expr, * are both valid continuous distribution parameter types. * See var_expr.hpp for more information. */ -#if __cplusplus <= 201703L template && @@ -184,12 +183,6 @@ template > inline constexpr auto normal(MeanType&& mean_expr, StddevType&& stddev_expr) -#else -template -inline constexpr auto normal(MeanType&& mean_expr, - StddevType&& stddev_expr) -#endif { using mean_t = details::convert_to_param_t; using stddev_t = details::convert_to_param_t; @@ -205,16 +198,11 @@ inline constexpr auto normal(MeanType&& mean_expr, * is a valid discrete distribution parameter type. * See var_expr.hpp for more information. */ -#if __cplusplus <= 201703L template > > inline constexpr auto bernoulli(ProbType&& p_expr) -#else -template -inline constexpr auto bernoulli(ProbType&& p_expr) -#endif { using p_t = details::convert_to_param_t; p_t wrap_p_expr = std::forward(p_expr); @@ -230,11 +218,13 @@ inline constexpr auto bernoulli(ProbType&& p_expr) * only when var is a Variable and dist is a valid distribution expression. * Ex. x |= uniform(0,1) */ -template +template > > inline constexpr auto operator|=( - util::Var& var, - const util::DistExpr& dist) -{ return expr::EqNode(var.self(), dist.self()); } + const VarType& var, + const util::DistExprBase& dist) +{ return expr::EqNode(var, dist.self()); } /** * Builds a GlueNode to "glue" the left expression with the right @@ -242,8 +232,8 @@ inline constexpr auto operator|=( * Ex. (x |= uniform(0,1), y |= uniform(0, 2)) */ template -inline constexpr auto operator,(const util::ModelExpr& lhs, - const util::ModelExpr& rhs) +inline constexpr auto operator,(const util::ModelExprBase& lhs, + const util::ModelExprBase& rhs) { return expr::GlueNode(lhs.self(), rhs.self()); } //////////////////////////////////////////////////////// @@ -252,8 +242,6 @@ inline constexpr auto operator,(const util::ModelExpr& lhs, namespace details { -#if __cplusplus <= 201703L - /** * Concept of valid variable expression parameter * for the operator overloads: @@ -264,19 +252,8 @@ namespace details { template inline constexpr bool is_valid_op_param_v = std::is_arithmetic_v> || - util::is_var_expr_v> || - util::is_var_v> + util::is_var_expr_v> ; -#else - -template -concept valid_op_param = - std::is_arithmetic_v> || - util::var_expr> || - util::var> - ; - -#endif template inline constexpr auto operator_helper(LHSType&& lhs, RHSType&& rhs) @@ -303,18 +280,12 @@ inline constexpr auto operator_helper(LHSType&& lhs, RHSType&& rhs) * SFINAE to ensure concepts are placed. */ -#if __cplusplus <= 201703L template && details::is_valid_op_param_v && details::is_valid_op_param_v > > -#else -template -requires details::not_both_arithmetic -#endif inline constexpr auto operator+(LHSType&& lhs, RHSType&& rhs) { @@ -323,18 +294,12 @@ inline constexpr auto operator+(LHSType&& lhs, std::forward(rhs)); } -#if __cplusplus <= 201703L template && details::is_valid_op_param_v && details::is_valid_op_param_v > > -#else -template -requires details::not_both_arithmetic -#endif inline constexpr auto operator-(LHSType&& lhs, RHSType&& rhs) { return details::operator_helper( @@ -342,18 +307,12 @@ inline constexpr auto operator-(LHSType&& lhs, RHSType&& rhs) std::forward(rhs)); } -#if __cplusplus <= 201703L template && details::is_valid_op_param_v && details::is_valid_op_param_v > > -#else -template -requires details::not_both_arithmetic -#endif inline constexpr auto operator*(LHSType&& lhs, RHSType&& rhs) { return details::operator_helper( @@ -361,18 +320,12 @@ inline constexpr auto operator*(LHSType&& lhs, RHSType&& rhs) std::forward(rhs)); } -#if __cplusplus <= 201703L template && details::is_valid_op_param_v && details::is_valid_op_param_v > > -#else -template -requires details::not_both_arithmetic -#endif inline constexpr auto operator/(LHSType&& lhs, RHSType&& rhs) { return details::operator_helper( diff --git a/include/autoppl/expression/model/eq_node.hpp b/include/autoppl/expression/model/eq_node.hpp index ad2c8046..cb06f4f1 100644 --- a/include/autoppl/expression/model/eq_node.hpp +++ b/include/autoppl/expression/model/eq_node.hpp @@ -3,9 +3,14 @@ #include #include #include -#include -#include -#include +#include +#include +#include +#include + +#define PPL_VAR_DIST_CONT_DISC_MATCH \ + "A continuous variable can only be assigned to a continuous distribution. " \ + "A discrete variable can only be assigned to a discrete distribution. " namespace ppl { namespace expr { @@ -14,26 +19,25 @@ namespace expr { * This class represents a "node" in the model expression * that relates a var with a distribution. */ -#if __cplusplus <= 201703L -template -#else -template -#endif -struct EqNode : util::ModelExpr> +template +struct EqNode: util::ModelExprBase> { - -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_v); - static_assert(util::assert_is_dist_expr_v); -#endif - using var_t = VarType; using dist_t = DistType; - using dist_value_t = typename util::dist_expr_traits::dist_value_t; - EqNode(var_t& var, + static_assert(util::is_var_v); + static_assert(util::is_dist_expr_v); + + static_assert((util::var_traits::is_cont_v && + util::dist_expr_traits::is_cont_v) || + (util::var_traits::is_disc_v && + util::dist_expr_traits::is_disc_v), + PPL_VAR_DIST_CONT_DISC_MATCH); + + EqNode(const var_t& var, const dist_t& dist) noexcept - : orig_var_ref_{var} + : var_{var} , dist_{dist} {} @@ -60,67 +64,36 @@ struct EqNode : util::ModelExpr> * Compute pdf of underlying distribution with underlying value. * Assumes that underlying value has been assigned properly. */ - dist_value_t pdf() const { - return dist_.pdf(get_variable()); - } + template + auto pdf(const PVecType& pvalues) const + { return dist_.pdf(get_variable(), pvalues); } /** * Compute log-pdf of underlying distribution with underlying value. * Assumes that underlying value has been assigned properly. */ - dist_value_t log_pdf() const { - return dist_.log_pdf(get_variable()); - } - - template - auto ad_log_pdf(const VecRefType& keys, - const VecADVarType& vars) const - { - // if parameter, find the corresponding variable - // in vars and return the AD log-pdf with this variable. -#if __cplusplus <= 201703L - if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - const void* addr = &orig_var_ref_.get(); - auto it = std::find(keys.begin(), keys.end(), addr); - assert(it != keys.end()); - size_t idx = std::distance(keys.begin(), it); - return dist_.ad_log_pdf(vars[idx], keys, vars); - } + template + auto log_pdf(const PVecType& pvalues, + F f = F()) const + { return dist_.log_pdf(get_variable(), pvalues, f); } - // if data, return sum of log_pdf where each element - // is a constant AD node containing each value of data. - // note: data is not copied at any point. -#if __cplusplus <= 201703L - else if constexpr (util::is_data_v) { -#else - else if constexpr (util::data) { -#endif - const auto& var = this->get_variable(); - size_t idx = 0; - const size_t size = var.size(); - return ad::sum(var.begin(), var.end(), - [&, idx, size](auto value) mutable { - idx = idx % size; // may be important since mutable - auto&& expr = dist_.ad_log_pdf( - ad::constant(value), keys, vars, idx); - ++idx; - return expr; - }); - } - } + /** + * Generates AD expression for log pdf of underlying distribution. + * @param map mapping of variable IDs to offset in ad_vars + * @param ad_vars container of AD variables that correspond to parameters. + */ + template + auto ad_log_pdf(const VecADVarType& ad_vars) const + { return dist_.ad_log_pdf(get_variable(), ad_vars); } - auto& get_variable() { return orig_var_ref_.get(); } - const auto& get_variable() const { return orig_var_ref_.get(); } - const auto& get_distribution() const { return dist_; } + var_t& get_variable() { return var_; } + const var_t& get_variable() const { return var_; } + const dist_t& get_distribution() const { return dist_; } private: - using var_ref_t = std::reference_wrapper; - var_ref_t orig_var_ref_; // reference of the original var since - // any configuration may be changed until right before update - dist_t dist_; // distribution associated with var + var_t var_; + dist_t dist_; }; } // namespace expr diff --git a/include/autoppl/expression/model/glue_node.hpp b/include/autoppl/expression/model/glue_node.hpp index 4682c0d0..f50f3f9d 100644 --- a/include/autoppl/expression/model/glue_node.hpp +++ b/include/autoppl/expression/model/glue_node.hpp @@ -1,7 +1,8 @@ #pragma once #include #include -#include +#include +#include namespace ppl { namespace expr { @@ -10,25 +11,15 @@ namespace expr { * This class represents a "node" in a model expression that * "glues" two sub-model expressions. */ -#if __cplusplus <= 201703L -template -#else -template -#endif -struct GlueNode : util::ModelExpr> +template +struct GlueNode: util::ModelExprBase> { - -#if __cplusplus <= 201703L - static_assert(util::assert_is_model_expr_v); - static_assert(util::assert_is_model_expr_v); -#endif + static_assert(util::is_model_expr_v); + static_assert(util::is_model_expr_v); using left_node_t = LHSNodeType; using right_node_t = RHSNodeType; - using dist_value_t = std::common_type_t< - typename util::model_expr_traits::dist_value_t, - typename util::model_expr_traits::dist_value_t - >; GlueNode(const left_node_t& lhs, const right_node_t& rhs) noexcept @@ -58,26 +49,32 @@ struct GlueNode : util::ModelExpr> * Computes left node joint pdf then right node joint pdf * and returns the product of the two. */ - dist_value_t pdf() const - { return left_node_.pdf() * right_node_.pdf(); } + template + auto pdf(const PVecType& pvalues) const + { return left_node_.pdf(pvalues) * right_node_.pdf(pvalues); } /** * Computes left node joint log-pdf then right node joint log-pdf * and returns the sum of the two. */ - dist_value_t log_pdf() const - { return left_node_.log_pdf() + right_node_.log_pdf(); } + template + auto log_pdf(const PVecType& pvalues, + F f = F()) const + { + return left_node_.log_pdf(pvalues, f) + + right_node_.log_pdf(pvalues, f); + } /** * Up to constant addition, returns ad expression of log pdf * of both sides added together. */ - template - auto ad_log_pdf(const VecRefType& keys, - const VecADVarType& vars) const + template + auto ad_log_pdf(const VecADVarType& vars) const { - return (left_node_.ad_log_pdf(keys, vars) + - right_node_.ad_log_pdf(keys, vars)); + return (left_node_.ad_log_pdf(vars) + + right_node_.ad_log_pdf(vars)); } private: diff --git a/include/autoppl/expression/model/model_utils.hpp b/include/autoppl/expression/model/model_utils.hpp deleted file mode 100644 index ea7fb577..00000000 --- a/include/autoppl/expression/model/model_utils.hpp +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once -#include -#include - -namespace ppl { - -/** - * Returns number of parameters in model. - */ -namespace details { - -template -struct get_n_params {}; - -template -struct get_n_params> -{ - static constexpr size_t value = -#if __cplusplus <= 201703L - 1 * util::is_param_v; -#else - 1 * util::param; -#endif -}; - -template -struct get_n_params> -{ - static constexpr size_t value = - get_n_params::value + - get_n_params::value; -}; - -} // namespace details - -template -inline constexpr size_t get_n_params_v = - details::get_n_params::value; - -} // namespace ppl diff --git a/include/autoppl/expression/variable/binop.hpp b/include/autoppl/expression/variable/binop.hpp index a7a7b5c4..18bff340 100644 --- a/include/autoppl/expression/variable/binop.hpp +++ b/include/autoppl/expression/variable/binop.hpp @@ -1,35 +1,44 @@ #pragma once -#include -#include +#include +#include +#include namespace ppl { namespace expr { -#if __cplusplus <= 201703L -template -#else -template -#endif +template struct BinaryOpNode : - util::VarExpr> + util::VarExprBase> { -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_expr_v); - static_assert(util::assert_is_var_expr_v); -#endif + static_assert(util::is_var_expr_v); + static_assert(util::is_var_expr_v); using value_t = std::common_type_t< typename util::var_expr_traits::value_t, typename util::var_expr_traits::value_t >; - - BinaryOpNode(const LHSVarExprType& lhs, const RHSVarExprType& rhs) + using shape_t = util::max_shape_t< + typename util::shape_traits::shape_t, + typename util::shape_traits::shape_t + >; + static constexpr bool has_param = + LHSVarExprType::has_param || RHSVarExprType::has_param; + + BinaryOpNode(const LHSVarExprType& lhs, + const RHSVarExprType& rhs) : lhs_{lhs}, rhs_{rhs} - { assert(lhs.size() == rhs.size() || lhs.size() == 1 || rhs.size() == 1); } - - value_t get_value(size_t i = 0) const { - auto lhs_value = lhs_.get_value(i); - auto rhs_value = rhs_.get_value(i); + {} + + template + value_t value(const PVecType& pvalues, + size_t i, + F f = F()) const + { + auto lhs_value = lhs_.value(pvalues, i, f); + auto rhs_value = rhs_.value(pvalues, i, f); return BinaryOp::evaluate(lhs_value, rhs_value); } @@ -38,59 +47,41 @@ struct BinaryOpNode : /** * Returns ad expression of the binary operation. */ - template - auto get_ad(const VecRefType& keys, - const VecADVarType& vars, - size_t idx = 0) const + template + auto to_ad(const VecADVarType& vars, + size_t i=0) const { - return BinaryOp::evaluate(lhs_.get_ad(keys, vars, idx), - rhs_.get_ad(keys, vars, idx)); + return BinaryOp::evaluate(lhs_.to_ad(vars, i), + rhs_.to_ad(vars, i)); } private: LHSVarExprType lhs_; RHSVarExprType rhs_; - }; struct AddOp { - template static auto evaluate(LHSValueType x, RHSValueType y) - { - return x + y; - } - + { return x + y; } }; struct SubOp { - template static auto evaluate(LHSValueType x, RHSValueType y) - { - return x - y; - } - + { return x - y; } }; struct MultOp { - template static auto evaluate(LHSValueType x, RHSValueType y) - { - return x * y; - } - + { return x * y; } }; struct DivOp { - template static auto evaluate(LHSValueType x, RHSValueType y) - { - return x / y; - } - + { return x / y; } }; } // namespace expr diff --git a/include/autoppl/expression/variable/constant.hpp b/include/autoppl/expression/variable/constant.hpp index fe5955d6..6bcf213a 100644 --- a/include/autoppl/expression/variable/constant.hpp +++ b/include/autoppl/expression/variable/constant.hpp @@ -1,30 +1,34 @@ #pragma once -#include #include +#include +#include namespace ppl { namespace expr { -template -struct Constant : util::VarExpr> +template +struct Constant: + util::VarExprBase> { using value_t = ValueType; - Constant(value_t c) - : c_{c} - {} - value_t get_value(size_t = 0) const { - return c_; - } - - constexpr size_t size() const { return 1; } - - /** - * Returns ad expression of the constant. - */ - template - auto get_ad(const VecRefType&, - const VecADVarType&, - size_t = 0) const + using shape_t = ShapeType; + static constexpr bool has_param = false; + + Constant(value_t c) : c_{c} {} + + template + value_t value(const PVecType&, + size_t=0, + F = F()) const + { return c_; } + + constexpr size_t size() const { return 1ul; } + + template + auto to_ad(const VecADVarType&, + size_t = 0) const { return ad::constant(c_); } private: diff --git a/include/autoppl/expression/variable/data.hpp b/include/autoppl/expression/variable/data.hpp new file mode 100644 index 00000000..441b605d --- /dev/null +++ b/include/autoppl/expression/variable/data.hpp @@ -0,0 +1,160 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace ppl { + +/** + * DataView is a class that only views data values. + * It cannot modify the underlying value. + * If there are multiple values, i.e. shape is vec or mat, + * it views all of the elements. + */ +template +struct DataView: + util::VarExprBase>, + util::DataBase> +{ + using value_t = ValueType; + using const_pointer_t = const value_t*; + using id_t = const void*; + using shape_t = ShapeType; + static constexpr bool has_param = false; + + DataView(const value_t& v) noexcept + : value_ptr_{&v} + , id_{this} + {} + + template + value_t value(const VecType&, + size_t=0, + F = F()) const + { return *value_ptr_; } + + constexpr size_t size() const { return 1ul; } + id_t id() const { return id_; } + + template + auto to_ad(const VecADVarType&, + size_t=0) const + { return ad::constant(*value_ptr_); } + +private: + const_pointer_t value_ptr_; + id_t id_; +}; + +template +struct DataView : + util::VarExprBase>, + util::DataBase> +{ + using vec_t = VecType; + using vec_const_pointer_t = const vec_t*; + using value_t = typename vec_t::value_type; + using id_t = const void*; + using shape_t = ppl::vec; + static constexpr bool has_param = false; + + DataView(const vec_t& v) noexcept + : vec_ptr_{&v} + , id_{this} + {} + + template + value_t value(const PVecType&, + size_t i, + F = F()) const + { return (*vec_ptr_)[i]; } + + size_t size() const { return vec_ptr_->size(); } + + id_t id() const { return id_; } + + template + auto to_ad(const VecADVarType&, + size_t i) const + { return ad::constant((*vec_ptr_)[i]); } + +private: + vec_const_pointer_t vec_ptr_; + id_t id_; +}; + +// Primary: var-like +template +struct Data: + DataView, + util::VarExprBase>, + util::DataBase> +{ + using base_t = DataView; + using typename base_t::value_t; + using typename base_t::shape_t; + using typename base_t::id_t; + using base_t::value; + using base_t::size; + using base_t::id; + using base_t::to_ad; + + Data(value_t v) noexcept + : base_t(value_) + , value_(v) + {} + Data() noexcept : Data(0) {} + +private: + value_t value_; // store value associated with data +}; + +// Specialization: vec-like +template +struct Data: + DataView, ppl::vec>, + util::VarExprBase>, + util::DataBase> +{ + using base_t = DataView, ppl::vec>; + using typename base_t::value_t; + using typename base_t::shape_t; + using typename base_t::id_t; + using base_t::value; + using base_t::size; + using base_t::id; + using base_t::to_ad; + + Data(std::initializer_list l) noexcept + : base_t(vec_) + , vec_(l) + {} + + Data(size_t n) + : base_t(vec_) + , vec_(n) + {} + + Data() noexcept : Data(0) {} + + void push_back(value_t x) { vec_.push_back(x); } + +private: + std::vector vec_; +}; + +// TODO: Specialization: mat-like + +// Compiler should choose this when ShapeType is ppl::scl +template +inline constexpr auto make_data_viewer(const Container& x) +{ return DataView(x); } + +} // namespace ppl diff --git a/include/autoppl/expression/variable/param.hpp b/include/autoppl/expression/variable/param.hpp new file mode 100644 index 00000000..cd4b0d27 --- /dev/null +++ b/include/autoppl/expression/variable/param.hpp @@ -0,0 +1,249 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace ppl { + +/** + * ParamView is a class that views both data values and storage pointers. + * Note that it is viewing a storage pointer and not the storage. + * This is because user can externally choose to change the storage pointer. + * + * It can bind to view a different value but not storage pointer. + * It cannot modify the underlying value or storage pointer. + * It can modify storage values by dereferencing storage pointer. + * If there are multiple values, i.e. shape is vec or mat, + * it views all of the elements. + * If vec or mat, must know the size at construction, but the actual viewees. + */ + +template +struct ParamView: + util::VarExprBase>, + util::ParamBase> +{ + using pointer_t = PointerType; + using value_t = std::remove_const_t< + std::remove_pointer_t >; + using const_pointer_t = const value_t*; + using const_storage_pointer_t = const pointer_t*; + using id_t = const void*; + using shape_t = ShapeType; + using index_t = uint32_t; + static constexpr bool has_param = true; + + // Note: id may need to be provided when subscripting + ParamView(index_t& offset, + const pointer_t& storage_ptr, + id_t id, + index_t rel_offset = 0) noexcept + : offset_ptr_{&offset} + , rel_offset_{rel_offset} + , storage_ptr_ptr_{&storage_ptr} + , id_{id} + {} + + ParamView(index_t& offset, + const pointer_t& storage_ptr, + index_t rel_offset = 0) noexcept + : ParamView(offset, storage_ptr, this, rel_offset) + {} + + template + auto& value(VecType& vars, + size_t=0, + F f = F()) const + { + return f.template operator()( + vars[*offset_ptr_ + rel_offset_]); + } + + template + auto value(const VecType& vars, + size_t=0, + F f = F()) const + { + return f.template operator()( + vars[*offset_ptr_ + rel_offset_]); + } + + constexpr size_t size() const { return 1ul; } + + pointer_t storage(size_t=0) const + { return *storage_ptr_ptr_; } + + id_t id() const { return id_; } + + // TODO: type check that it's a vector of ad vars? + template + auto to_ad(const VecType& vars, + size_t=0) const + { return vars[*offset_ptr_ + rel_offset_]; } + + index_t& offset() { return *offset_ptr_; } + +private: + index_t* const offset_ptr_; + const index_t rel_offset_; + const_storage_pointer_t storage_ptr_ptr_; + id_t id_; +}; + +template +struct ParamView: + util::VarExprBase>, + util::ParamBase> +{ + using vec_t = VecType; + using pointer_t = typename VecType::value_type; + using value_t = std::remove_const_t< + std::remove_pointer_t >; + using const_pointer_t = const value_t*; + using shape_t = ppl::vec; + using index_t = uint32_t; + using id_t = const void*; + static constexpr bool has_param = true; + + ParamView(index_t& offset, + const vec_t& storages, + index_t size) noexcept + : offset_ptr_{&offset} + , storages_ptr_{&storages} + , id_{this} + , size_{size} + {} + + template + auto& value(PVecType& vars, + size_t i, + F f = F()) const + { + return f.template operator()( + vars[*offset_ptr_ + i]); + } + + template + auto value(const PVecType& vars, + size_t i, + F f = F()) const + { + return f.template operator()( + vars[*offset_ptr_ + i]); + } + + size_t size() const { return size_; } + + pointer_t storage(size_t i) const + { return (*storages_ptr_)[i]; } + + id_t id() const { return id_; } + + template + auto to_ad(const VecADVarType& vars, + size_t i) const + { return vars[*offset_ptr_ + i]; } + + index_t& offset() { return *offset_ptr_; } + + auto operator[](index_t i) { + return ParamView( + *offset_ptr_, + (*storages_ptr_)[i], + id_, + i); + } + +private: + index_t* offset_ptr_; + const vec_t* storages_ptr_; + id_t id_; + index_t size_; +}; + +template +struct Param: + ParamView, + util::VarExprBase>, + util::ParamBase> +{ + using base_t = ParamView; + using typename base_t::value_t; + using typename base_t::pointer_t; + using typename base_t::const_pointer_t; + using typename base_t::id_t; + using typename base_t::index_t; + using typename base_t::shape_t; + using base_t::value; + using base_t::size; + using base_t::storage; + using base_t::to_ad; + using base_t::id; + using base_t::offset; + + Param(pointer_t ptr=nullptr) noexcept + : base_t(offset_, storage_ptr_) + , offset_(0) + , storage_ptr_(ptr) + {} + + void set_storage(pointer_t ptr) + { storage_ptr_ = ptr; } + +private: + index_t offset_; + pointer_t storage_ptr_; +}; + +template +struct Param : + ParamView, ppl::vec>, + util::VarExprBase>, + util::ParamBase> +{ + using base_t = ParamView, ppl::vec>; + using typename base_t::value_t; + using typename base_t::pointer_t; + using typename base_t::const_pointer_t; + using typename base_t::id_t; + using typename base_t::index_t; + using typename base_t::shape_t; + using base_t::value; + using base_t::size; + using base_t::storage; + using base_t::to_ad; + using base_t::id; + using base_t::offset; + + Param(size_t n) + : base_t(offset_, storage_ptrs_, n) + , storage_ptrs_(n, nullptr) + {} + + Param(std::initializer_list ptrs) noexcept + : base_t(offset_, storage_ptrs_, ptrs.size()) + , offset_(0) + , storage_ptrs_(ptrs) + {} + + void set_storage(pointer_t ptr, size_t i) + { storage_ptrs_[i] = ptr; } + +private: + + index_t offset_; + std::vector storage_ptrs_; +}; + +// TODO: ParamFixed + +} // namespace ppl diff --git a/include/autoppl/expression/variable/variable_viewer.hpp b/include/autoppl/expression/variable/variable_viewer.hpp deleted file mode 100644 index 73ad9733..00000000 --- a/include/autoppl/expression/variable/variable_viewer.hpp +++ /dev/null @@ -1,71 +0,0 @@ -#pragma once -#include -#include -#include -#include - -namespace ppl { -namespace expr { - -/** - * VariableViewer is a viewer of some variable type. - * It will mainly be used to view Variable class defined in autoppl/variable.hpp. - */ -#if __cplusplus <= 201703L -template -#else -template -#endif -struct VariableViewer : util::VarExpr> -{ -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_v); -#endif - - using var_t = VariableType; - using value_t = typename util::var_traits::value_t; - - VariableViewer(var_t& var) - : var_ref_{var} - {} - - value_t get_value(size_t i = 0) const { return var_ref_.get().get_value(i); } - size_t size() const { return var_ref_.get().size(); } - - /** - * Returns ad expression of the variable. - * If variable is parameter, find from vars and return. - * Otherwise if data, return idx'th ad::constant of that value. - */ - template - auto get_ad(const VecRefType& keys, - const VecADVarType& vars, - size_t idx = 0) const - { -#if __cplusplus <= 201703L - if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - static_cast(idx); - const void* addr = &var_ref_.get(); - auto it = std::find(keys.begin(), keys.end(), addr); - assert(it != keys.end()); - size_t i = std::distance(keys.begin(), it); - return vars[i]; -#if __cplusplus <= 201703L - } else if constexpr (util::is_data_v) { -#else - } else if constexpr (util::data) { -#endif - return ad::constant(this->get_value(idx)); - } - } - -private: - using var_ref_t = std::reference_wrapper; - var_ref_t var_ref_; -}; - -} // namespace expr -} // namespace ppl diff --git a/include/autoppl/math/density.hpp b/include/autoppl/math/density.hpp new file mode 100644 index 00000000..4d722a3c --- /dev/null +++ b/include/autoppl/math/density.hpp @@ -0,0 +1,72 @@ +#pragma once +#include +#include + +// MSVC does not seem to support M_PI +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +namespace ppl { +namespace math { + +///////////////////////////////// +// Compile-time Constants +///////////////////////////////// + +inline constexpr double SQRT_TWO_PI = + 2.506628274631000502415765284811045; +inline constexpr double LOG_SQRT_TWO_PI = + 0.918938533204672741780329736405617; + +///////////////////////////////// +// Univariate densities +///////////////////////////////// + +template +inline constexpr auto normal_pdf(T x, T mean, T sd) +{ + T z_score = (x - mean) / sd; + return std::exp(-0.5 * z_score * z_score) / + (sd * SQRT_TWO_PI); +} + +template +inline constexpr auto normal_log_pdf(T x, T mean, T sd) +{ + T z_score = (x - mean) / sd; + return (-0.5 * z_score * z_score) - std::log(sd) - LOG_SQRT_TWO_PI; +} + +template +inline constexpr auto uniform_pdf(T x, T min, T max) +{ + return (min < x && x < max) ? 1. / (max - min) : 0; +} + +template +inline constexpr auto uniform_log_pdf(T x, T min, T max) +{ + return (min < x && x < max) ? + -std::log(max - min) : + neg_inf; +} + +template +inline constexpr auto bernoulli_pdf(IntType x, T p) +{ + if (x == 1) return p; + else if (x == 0) return 1. - p; + else return 0.0; +} + +template +inline constexpr auto bernoulli_log_pdf(IntType x, T p) +{ + if (x == 1) return std::log(p); + else if (x == 0) return std::log(1. - p); + else return neg_inf; +} + +} // namespace math +} // namespace ppl diff --git a/include/autoppl/math/math.hpp b/include/autoppl/math/math.hpp new file mode 100644 index 00000000..9e637df0 --- /dev/null +++ b/include/autoppl/math/math.hpp @@ -0,0 +1,61 @@ +#pragma once +#include +#include +#include +#include + +namespace ppl { +namespace math { + +template +inline constexpr T inf = + std::numeric_limits::is_iec559 ? + std::numeric_limits::infinity() : + std::numeric_limits::max(); + +template +inline constexpr T neg_inf = + std::numeric_limits::is_iec559 ? + -std::numeric_limits::infinity() : + std::numeric_limits::lowest(); + +template +inline constexpr auto min(Iter begin, Iter end, F f = F()) +{ + using value_t = typename std::iterator_traits::value_type; + static_assert(std::is_invocable_v); + using ret_value_t = std::decay_t< + decltype(f(std::declval())) >; + + if (std::distance(begin, end) <= 0) { + return inf; + } + + ret_value_t res = inf; + std::for_each(begin, end, + [&](value_t x) + { res = std::min(res, f(x)); }); + return res; +} + +template +inline constexpr auto max(Iter begin, Iter end, F f = F()) +{ + using value_t = typename std::iterator_traits::value_type; + static_assert(std::is_invocable_v); + using ret_value_t = std::decay_t< + decltype(f(std::declval())) >; + + if (std::distance(begin, end) <= 0) { + return neg_inf; + } + + ret_value_t res = neg_inf; + std::for_each(begin, end, + [&](value_t x) + { res = std::max(res, f(x)); }); + return res; +} + +} // namespace math +} // namespace ppl diff --git a/include/autoppl/mcmc/hmc/nuts/configs.hpp b/include/autoppl/mcmc/hmc/nuts/configs.hpp index 88bd8e9b..e8ce533f 100644 --- a/include/autoppl/mcmc/hmc/nuts/configs.hpp +++ b/include/autoppl/mcmc/hmc/nuts/configs.hpp @@ -1,6 +1,6 @@ #pragma once #include -#include +#include #include #include diff --git a/include/autoppl/mcmc/hmc/nuts/nuts.hpp b/include/autoppl/mcmc/hmc/nuts/nuts.hpp index b9fa2add..5947c3bc 100644 --- a/include/autoppl/mcmc/hmc/nuts/nuts.hpp +++ b/include/autoppl/mcmc/hmc/nuts/nuts.hpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -51,18 +51,17 @@ bool check_entropy(const MatType1& rho, * * Note that the caller MUST have input theta_adj already pre-computed. */ -template -TreeOutput build_tree(InputType& input, +TreeOutput build_tree(size_t n_params, + InputType& input, uint8_t depth, UniformDistType& unif_sampler, GenType& gen, - const MomentumHandlerType& momentum_handler - ) + const MomentumHandlerType& momentum_handler) { constexpr double delta_max = 1000; // suggested by Gelman @@ -114,11 +113,11 @@ TreeOutput build_tree(InputType& input, } // recursion - arma::mat::fixed mat_first(arma::fill::zeros); + arma::mat mat_first(n_params, 3, arma::fill::zeros); auto p_end_inner = mat_first.col(0); auto p_end_scaled_inner = mat_first.col(1); auto rho_first = mat_first.col(2); - double log_sum_weight_first = -std::numeric_limits::infinity(); + double log_sum_weight_first = math::neg_inf; // create a new input for first recursion // some references have to rebound @@ -130,8 +129,8 @@ TreeOutput build_tree(InputType& input, // build first subtree TreeOutput first_output = - build_tree(first_input, depth - 1, - unif_sampler, gen, momentum_handler); + build_tree(n_params, first_input, depth - 1, + unif_sampler, gen, momentum_handler); // if first subtree is already invalid, early exit // note that caller will break out of doubling process now, @@ -139,12 +138,12 @@ TreeOutput build_tree(InputType& input, if (!first_output.valid) { return first_output; } // second recursion - arma::mat::fixed mat_second(arma::fill::zeros); + arma::mat mat_second(n_params, 4, arma::fill::zeros); auto theta_double_prime = mat_second.col(0); auto p_beg_inner = mat_second.col(1); auto p_beg_scaled_inner = mat_second.col(2); auto rho_second = mat_second.col(3); - double log_sum_weight_second = -std::numeric_limits::infinity(); + double log_sum_weight_second = math::neg_inf; // create a new input for second recursion InputType second_input = input; @@ -156,8 +155,8 @@ TreeOutput build_tree(InputType& input, // build second subtree TreeOutput second_output = - build_tree(second_input, depth - 1, - unif_sampler, gen, momentum_handler); + build_tree(n_params, second_input, depth - 1, + unif_sampler, gen, momentum_handler); // if second subtree is invalid, early exit // note that we must return first output since it has the potential @@ -211,8 +210,7 @@ TreeOutput build_tree(InputType& input, * Finds a reasonable epsilon for NUTS algorithm. * @param ad_expr AD expression bound to theta and theta_adj */ -template double find_reasonable_epsilon(double eps, @@ -226,12 +224,12 @@ double find_reasonable_epsilon(double eps, const double diff_bound = std::log(0.8); - arma::mat::fixed r_mat(arma::fill::zeros); - auto r = r_mat.col(0); + size_t n_params = theta.n_elem; // theta is expected to be vector-like - arma::mat::fixed theta_mat(arma::fill::zeros); - auto theta_orig = theta_mat.col(0); - auto theta_adj_orig = theta_mat.col(1); + arma::mat r_theta_mat(n_params, 3, arma::fill::zeros); + auto r = r_theta_mat.col(0); + auto theta_orig = r_theta_mat.col(1); + auto theta_adj_orig = r_theta_mat.col(2); // sample momentum vector based on handler momentum_handler.sample(r); @@ -295,10 +293,14 @@ double find_reasonable_epsilon(double eps, */ template > -void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) +void nuts(ModelType& model, + NUTSConfigType config = NUTSConfigType()) { + // activate model + mcmc::activate(model); + // initialization of meta-variables - constexpr size_t n_params = get_n_params_v; + size_t n_params = mcmc::param_size(model); std::mt19937 gen(config.seed); std::uniform_int_distribution direction_sampler(0, 1); std::uniform_real_distribution unif_sampler(0., 1.); @@ -310,7 +312,7 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) // right-subtree forwardmost momentum => ff // scaled versions are based on hamiltonian adjusted covariance matrix constexpr uint8_t n_p_cached = 8; - arma::mat::fixed p_mat(arma::fill::zeros); + arma::mat p_mat(n_params, n_p_cached, arma::fill::zeros); auto p_bb = p_mat.col(0); auto p_bb_scaled = p_mat.col(1); auto p_bf = p_mat.col(2); @@ -322,7 +324,7 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) // position matrix for thetas and adjoints constexpr uint8_t n_thetas_cached = 7; - arma::mat::fixed theta_mat(arma::fill::zeros); + arma::mat theta_mat(n_params, n_thetas_cached, arma::fill::zeros); auto theta_bb = theta_mat.col(0); auto theta_bb_adj = theta_mat.col(1); auto theta_ff = theta_mat.col(2); @@ -336,7 +338,7 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) // backward-subtree => rho_b // combined subtrees => rho constexpr uint8_t n_rho_cached = 3; - arma::mat::fixed rho_mat(arma::fill::zeros); + arma::mat rho_mat(n_params, n_rho_cached, arma::fill::zeros); auto rho_f = rho_mat.col(0); auto rho_b = rho_mat.col(1); auto rho = rho_mat.col(2); @@ -353,18 +355,18 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) // keys needed to construct a correct AD expression from model // key: address of original variable tags - std::vector keys; - mcmc::get_keys(model, keys); + //std::vector keys; + //mcmc::get_keys(model, keys); // AD Expressions for L(theta) (log-pdf up to constant at theta) // Note that these expressions are the only ones used ever. - auto theta_bb_ad_expr = model.ad_log_pdf(keys, theta_bb_ad); - auto theta_ff_ad_expr = model.ad_log_pdf(keys, theta_ff_ad); - auto theta_curr_ad_expr = model.ad_log_pdf(keys, theta_curr_ad); + auto theta_bb_ad_expr = model.ad_log_pdf(theta_bb_ad); + auto theta_ff_ad_expr = model.ad_log_pdf(theta_ff_ad); + auto theta_curr_ad_expr = model.ad_log_pdf(theta_curr_ad); // initializes first sample into theta_curr // TODO: allow users to choose how to initialize first point? - mcmc::init_sample(model, theta_curr, gen); + mcmc::init_params(model, gen, theta_curr); // initialize current potential (will be "previous" starting in for-loop) double potential_prev = -ad::evaluate(theta_curr_ad_expr); @@ -376,7 +378,7 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) // initialize step adapter const double log_eps = std::log( - mcmc::find_reasonable_epsilon( + mcmc::find_reasonable_epsilon( 1., // initial epsilon theta_curr_ad_expr, theta_curr, theta_curr_adj, momentum_handler)); @@ -437,7 +439,7 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) rho_b.zeros(); rho_f.zeros(); - double log_sum_weight_subtree = std::numeric_limits::lowest(); + double log_sum_weight_subtree = math::neg_inf; int8_t v = 2 * direction_sampler(gen) - 1; // -1 or 1 if (v == -1) { @@ -457,8 +459,8 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) p_fb = p_bb; p_fb_scaled = p_bb_scaled; - output = mcmc::build_tree(input, depth, - unif_sampler, gen, momentum_handler); + output = mcmc::build_tree(n_params, input, depth, + unif_sampler, gen, momentum_handler); } else { auto input = mcmc::TreeInput( // correct position information to update @@ -475,8 +477,8 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) p_bf = p_ff; p_bf_scaled = p_ff_scaled; - output = mcmc::build_tree(input, depth, - unif_sampler, gen, momentum_handler); + output = mcmc::build_tree(n_params, input, depth, + unif_sampler, gen, momentum_handler); } // early break if starting to U-Turn @@ -528,10 +530,10 @@ void nuts(ModelType& model, NUTSConfigType config = NUTSConfigType()) std::is_same_v) { const bool update = var_adapter.adapt(theta_curr, momentum_handler.get_m_inverse()); if (update) { - double log_eps = std::log(mcmc::find_reasonable_epsilon( + double log_eps = std::log( mcmc::find_reasonable_epsilon( std::exp(step_adapter.log_eps), theta_curr_ad_expr, theta_curr, - theta_curr_adj, momentum_handler)); + theta_curr_adj, momentum_handler) ); step_adapter.reset(); step_adapter.init(log_eps); } diff --git a/include/autoppl/mcmc/hmc/nuts/step_adapter.hpp b/include/autoppl/mcmc/hmc/step_adapter.hpp similarity index 100% rename from include/autoppl/mcmc/hmc/nuts/step_adapter.hpp rename to include/autoppl/mcmc/hmc/step_adapter.hpp diff --git a/include/autoppl/mcmc/hmc/var_adapter.hpp b/include/autoppl/mcmc/hmc/var_adapter.hpp index 4eacac20..c4d1fe89 100644 --- a/include/autoppl/mcmc/hmc/var_adapter.hpp +++ b/include/autoppl/mcmc/hmc/var_adapter.hpp @@ -47,8 +47,11 @@ struct VarAdapter }; /** - * Diagonal variance matrix M is estimated for momentum covariance matrix. + * Diagonal precision matrix M is estimated for momentum covariance matrix. * M inverse is estimated as sample variance and is regularized towards identity. + * + * Follows STAN guide: https://mc-stan.org/docs/2_18/reference-manual/hmc-algorithm-parameters.html + * STAN implementation: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/windowed_adaptation.hpp */ template <> struct VarAdapter diff --git a/include/autoppl/mcmc/mh.hpp b/include/autoppl/mcmc/mh.hpp index 5addae90..b97eac91 100644 --- a/include/autoppl/mcmc/mh.hpp +++ b/include/autoppl/mcmc/mh.hpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -18,6 +17,7 @@ namespace ppl { namespace mcmc { +namespace details { /** * Convert ValueType to either util::cont_param_t if floating point @@ -27,17 +27,19 @@ namespace mcmc { template struct value_to_param { - static_assert(!(std::is_integral_v || - std::is_floating_point_v), - AUTOPPL_MH_UNKNOWN_VALUE_TYPE_ERROR); + static_assert(!(util::is_cont_v || + util::is_disc_v), + PPL_CONT_XOR_DISC); }; template -struct value_to_param>> +struct value_to_param> > { using type = util::disc_param_t; }; template -struct value_to_param>> +struct value_to_param> > { using type = util::cont_param_t; }; @@ -49,13 +51,26 @@ using value_to_param_t = typename value_to_param::type; */ struct MHData { + std::variant curr; std::variant next; // TODO: maybe keep an array for batch sampling? }; -template -inline void mh__(ModelType& model, - Iter params_it, +// Helper functor to get the correct variant value. +struct get_curr +{ + template + constexpr auto&& operator()(MHDataType&& d) noexcept + { return *std::get_if(&d.curr); } +}; + +} // namespace details + +template +inline void mh__(const ModelType& model, + PVecType& pvalues, RGenType& gen, size_t n_sample, size_t warmup, @@ -63,9 +78,11 @@ inline void mh__(ModelType& model, double alpha, double stddev) { - std::uniform_real_distribution unif_sampler(0., 1.); + std::uniform_real_distribution metrop_sampler(0., 1.); + std::discrete_distribution disc_sampler({alpha, 1-2*alpha, alpha}); + std::normal_distribution norm_sampler(0., stddev); - auto logger = util::ProgressLogger(n_sample + warmup, "MetropolisHastings"); + auto logger = util::ProgressLogger(n_sample + warmup, "Metropolis-Hastings"); for (size_t iter = 0; iter < n_sample + warmup; ++iter) { logger.printProgress(iter); @@ -77,50 +94,36 @@ inline void mh__(ModelType& model, // generate next candidates and place them in parameter // variables as next values; update log_alpha - // The old values are temporary stored in the params vector. - auto get_candidate = [=, &n_swaps, &early_reject, &gen](auto& eq_node) mutable { + auto get_candidate = [&](const auto& eq_node) mutable { if (early_reject) return; - auto& var = eq_node.get_variable(); + const auto& var = eq_node.get_variable(); + const auto& dist = eq_node.get_distribution(); using var_t = std::decay_t; using value_t = typename util::var_traits::value_t; + using converted_value_t = details::value_to_param_t; -#if __cplusplus <= 201703L if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - auto curr = var.get_value(0); - const auto& dist = eq_node.get_distribution(); - - // Choose either continuous or discrete sampler depending on value_t - if constexpr (std::is_integral_v) { - std::discrete_distribution disc_sampler({alpha, 1-2*alpha, alpha}); - auto cand = disc_sampler(gen) - 1 + curr; // new candidate in curr + [-1, 0, 1] - // TODO: refactor common logic - if (dist.min() <= cand && cand <= dist.max()) { // if within dist bound - var.set_value(cand); + // generate next candidates for each element of parameter + for (size_t i = 0; i < var.size(); ++i) { + auto& pstate = var.value(pvalues, i); // MHData object corresponding to ith param elt + converted_value_t& curr_val = *std::get_if(&pstate.curr); + converted_value_t& next_val = *std::get_if(&pstate.next); + + converted_value_t min = dist.min(pvalues, i, details::get_curr()); + converted_value_t max = dist.max(pvalues, i, details::get_curr()); + + // choose delta based on if discrete or continuous param + if constexpr (util::is_disc_v) + { next_val = curr_val + disc_sampler(gen) - 1; } + else { next_val = curr_val + norm_sampler(gen); } + + if (min <= next_val && next_val <= max) { // if within dist bound + std::swap(pstate.curr, pstate.next); ++n_swaps; - } - else { early_reject = true; return; } - } else if constexpr (std::is_floating_point_v) { - std::normal_distribution norm_sampler(static_cast(curr), stddev); - auto cand = norm_sampler(gen); - if (dist.min() <= cand && cand <= dist.max()) { // if within dist bound - var.set_value(cand); - ++n_swaps; - } - else { early_reject = true; return; } - } else { - static_assert(!(std::is_integral_v || - std::is_floating_point_v), - AUTOPPL_MH_UNKNOWN_VALUE_TYPE_ERROR); - } + } else { early_reject = true; return; } - // move old value into params - using converted_value_t = value_to_param_t; - params_it->next = static_cast(curr); - ++params_it; + } // end for } }; model.traverse(get_candidate); @@ -128,64 +131,33 @@ inline void mh__(ModelType& model, if (early_reject) { // swap back original params only up until when candidate was out of bounds. - auto add_to_storage = [=, &n_swaps](auto& eq_node) mutable { - auto& var = eq_node.get_variable(); - using var_t = std::decay_t; - using value_t = typename util::var_traits::value_t; -#if __cplusplus <= 201703L - if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - if (n_swaps) { - using converted_value_t = value_to_param_t; - var.set_value(*std::get_if(¶ms_it->next)); - ++params_it; - --n_swaps; - } - if (iter >= warmup) { - auto storage = var.get_storage(); - storage[iter - warmup] = var.get_value(0); - } - } - }; - model.traverse(add_to_storage); - continue; - } + for (size_t i = 0; i < n_swaps; ++i) { + std::swap(pvalues[i].curr, pvalues[i].next); + } - // compute next candidate log pdf and update log_alpha - double cand_log_pdf = model.log_pdf(); - log_alpha += cand_log_pdf; - bool accept = (std::log(unif_sampler(gen)) <= log_alpha); - - // If accept, "current" sample for next iteration is already in the variables - // so simply append to storage. - // Otherwise, "current" sample for next iteration must be moved back from - // params vector into variables. - auto add_to_storage = [=](auto& eq_node) mutable { - auto& var = eq_node.get_variable(); - using var_t = std::decay_t; - using value_t = typename util::var_traits::value_t; -#if __cplusplus <= 201703L - if constexpr(util::is_param_v) { -#else - if constexpr(util::param) { -#endif - if (!accept) { - using converted_value_t = value_to_param_t; - var.set_value(*std::get_if(¶ms_it->next)); - ++params_it; - } - if (iter >= warmup) { - auto storage = var.get_storage(); - storage[iter - warmup] = var.get_value(0); + } else { + + // compute next candidate log pdf and update log_alpha + double cand_log_pdf = model.log_pdf(pvalues, details::get_curr()); + log_alpha += cand_log_pdf; + bool accept = (std::log(metrop_sampler(gen)) <= log_alpha); + + // if not accept, "current" sample for next iteration is in next: swap the two! + if (!accept) { + for (auto& pvalue : pvalues) { + std::swap(pvalue.curr, pvalue.next); } - } - }; - model.traverse(add_to_storage); + } else { + // update current log pdf for next iteration + curr_log_pdf = cand_log_pdf; + } + + } - // update current log pdf for next iteration - if (accept) curr_log_pdf = cand_log_pdf; + if (iter >= warmup) { + store_sample(model, pvalues, + iter-warmup, details::get_curr()); + } } std::cout << std::endl; @@ -212,55 +184,27 @@ inline void mh(ModelType& model, size_t seed = mcmc::random_seed() ) { - using data_t = mcmc::MHData; + using data_t = mcmc::details::MHData; - // set-up auxiliary tools - constexpr double initial_radius = 5.; - std::mt19937 gen(seed); - size_t n_params = 0; - double curr_log_pdf = 0.; // current log pdf - - // 1. initialize parameters with values in valid range - // - discrete valued params sampled uniformly within the distribution range - // - continuous valued params sampled uniformly within the intersection range - // of distribution min and max and [-initial_radius, initial_radius] - // 2. update n_params with number of parameters - // 3. compute current log-pdf - auto init_params = [&](auto& eq_node) { - auto& var = eq_node.get_variable(); - const auto& dist = eq_node.get_distribution(); - - using var_t = std::decay_t; - using value_t = typename util::var_traits::value_t; - -#if __cplusplus <= 201703L - if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - if constexpr (std::is_integral_v) { - std::uniform_int_distribution init_sampler(dist.min(), dist.max()); - var.set_value(init_sampler(gen)); - } else if constexpr (std::is_floating_point_v) { - std::uniform_real_distribution init_sampler( - std::max(dist.min(), -initial_radius), - std::min(dist.max(), initial_radius) - ); - var.set_value(init_sampler(gen)); - } else { - static_assert(!(std::is_integral_v || - std::is_floating_point_v), - AUTOPPL_MH_UNKNOWN_VALUE_TYPE_ERROR); - } - ++n_params; - } - curr_log_pdf += dist.log_pdf(var); - }; - model.traverse(init_params); + // REALLY important + // TODO: should inference really do this? + mcmc::activate(model); + size_t n_params = mcmc::param_size(model); + + // data structure to keep track of param candidates std::vector params(n_params); // vector of parameter-related data with candidate + + // initialize sample 0 + std::mt19937 gen(seed); + mcmc::init_params(model, gen, params, mcmc::details::get_curr()); + + // compute log pdf with sample 0 + double curr_log_pdf = model.log_pdf(params, mcmc::details::get_curr()); + + // sample the rest mcmc::mh__(model, - params.begin(), + params, gen, n_sample, warmup, diff --git a/include/autoppl/mcmc/sampler_tools.hpp b/include/autoppl/mcmc/sampler_tools.hpp index a03568f4..426b432b 100644 --- a/include/autoppl/mcmc/sampler_tools.hpp +++ b/include/autoppl/mcmc/sampler_tools.hpp @@ -3,12 +3,10 @@ #include #include #include -#include -#include - -#define AUTOPPL_MH_UNKNOWN_VALUE_TYPE_ERROR \ - "Unknown value type: must be convertible to util::disc_param_t " \ - "such as uint64_t or util::cont_param_t such as double." +#include +#include +#include +#include namespace ppl { namespace mcmc { @@ -23,148 +21,147 @@ inline size_t random_seed() } /** - * Initializes parameters with the given priors and - * conditional distributions based on the model. - * Random numbers are generated with gen. + * Get number of parameters in a model. + * If a parameter is a vector, the size of the vector is accumulated. */ -template -void init_params(ModelType& model, GenType& gen) +template +inline size_t param_size(const ModelType& model) { - // arbitrarily chosen radius for initial sampling - constexpr double initial_radius = 2.; - - auto init_params__ = [&](auto& eq_node) { - auto& var = eq_node.get_variable(); - const auto& dist = eq_node.get_distribution(); - + size_t n_params = 0; + auto param_size__ = [&](const auto& eq_node) { + const auto& var = eq_node.get_variable(); using var_t = std::decay_t; - using value_t = typename util::var_traits::value_t; - -#if __cplusplus <= 201703L if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - - if constexpr (std::is_integral_v) { - std::uniform_int_distribution init_sampler(dist.min(), dist.max()); - var.set_value(init_sampler(gen)); - - } else if constexpr (std::is_floating_point_v) { - std::uniform_real_distribution init_sampler(-initial_radius, initial_radius); - - // if unbounded prior - if (dist.min() == std::numeric_limits::lowest() && - dist.max() == std::numeric_limits::max()) { - var.set_value(init_sampler(gen)); - } - - // TODO: uncomment once there exists distributions with these properties - //// if bounded above but not below - //else if (dist.min() == std::numeric_limits::lowest()) { - // var.set_value(dist.max() - std::exp(init_sampler(gen))); - //} - - //// if bounded below but not above - //else if (dist.max() == std::numeric_limits::max()) { - // var.set_value(std::exp(init_sampler(gen)) + dist.min()); - //} - - // bounded below and above - else { - value_t range = dist.max() - dist.min(); - value_t avg = dist.min() + range / 2.; - var.set_value(avg + range / (2 * initial_radius) * init_sampler(gen)); - } - - } else { - static_assert(!(std::is_integral_v || - std::is_floating_point_v), - AUTOPPL_MH_UNKNOWN_VALUE_TYPE_ERROR); - } + n_params += var.size(); } }; - model.traverse(init_params__); + model.traverse(param_size__); + return n_params; } /** - * Initializes first sample of parameters using mcmc::init_params. - * Helper function to copy the samples into theta_curr. + * Initializes parameters with the given priors and + * conditional distributions based on the model. + * Random numbers are generated with gen. + * Assumes that model was initialized before. */ template -void init_sample(ModelType& model, - MatType& theta_curr, - GenType& gen) + , class GenType + , class PVecType + , class F = util::identity> +inline void init_params(const ModelType& model, + GenType& gen, + PVecType& pvalues, + F f = F()) { - mcmc::init_params(model, gen); - auto theta_curr_it = theta_curr.begin(); - auto copy_params_potential = [&](const auto& eq_node) { + // arbitrarily chosen radius for initial sampling + constexpr double initial_radius = 2.; + + // initialize each parameter + auto init_params__ = [&](const auto& eq_node) { const auto& var = eq_node.get_variable(); + const auto& dist = eq_node.get_distribution(); + using var_t = std::decay_t; -#if __cplusplus <= 201703L + using value_t = typename util::var_traits::value_t; + if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - *theta_curr_it = var.get_value(); - ++theta_curr_it; - } + + // initialization routine for each element of that parameter + for (size_t i = 0; i < var.size(); ++i) { + + auto min = dist.min(pvalues, i, f); + auto max = dist.max(pvalues, i, f); + + if constexpr (util::var_traits::is_disc_v) { + std::uniform_int_distribution init_sampler(min, max); + auto new_val = init_sampler(gen); + var.value(pvalues, i, f) = new_val; + + } else { + std::uniform_real_distribution init_sampler(-initial_radius, initial_radius); + + // if unbounded prior + if (min == math::neg_inf && + max == math::inf) { + var.value(pvalues, i, f) = init_sampler(gen); + } + + // TODO: uncomment once there exists distributions with these properties + //// if bounded above but not below + //else if (dist.min() == std::numeric_limits::lowest()) { + // var.set_value(dist.max() - std::exp(init_sampler(gen))); + //} + + //// if bounded below but not above + //else if (dist.max() == std::numeric_limits::max()) { + // var.set_value(std::exp(init_sampler(gen)) + dist.min()); + //} + + // bounded below and above + else { + value_t range = max - min; + value_t avg = min + range / 2.; + var.value(pvalues, i, f) = + avg + range / (2 * initial_radius) * init_sampler(gen); + } + + } // end outer else + } // end for + } // end if + }; - model.traverse(copy_params_potential); + model.traverse(init_params__); } /** - * Get unique raw addresses of the referenced variables in the model. - * Can be used to bind algorithm specific storage associated with each variable. + * Activates model with the correct offset values for each parameter. + * Every inference algorithm must invoke this call. + * Otherwise, undefined behavior. */ template -void get_keys(const ModelType& model, - std::vector& keys) +inline ModelType&& activate(ModelType&& model) { - constexpr size_t n_params = get_n_params_v; - keys.resize(n_params); - auto keys_it = keys.begin(); - auto get_keys = [&](auto& eq_node) { + size_t offset = 0; + auto activate__ = [&](auto& eq_node) { auto& var = eq_node.get_variable(); using var_t = std::decay_t; -#if __cplusplus <= 201703L if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - *keys_it = &var; - ++keys_it; + var.offset() = offset; + offset += var.size(); } }; - model.traverse(get_keys); + model.traverse(activate__); + return std::forward(model); } /** * Store ith sample currently in theta_curr into * storage by traversing model. + * Assumes that theta_curr[i] is the value of the ith parameter in model. + * If the parameter is a vector and theta_curr[i] is the value for the first + * element of the parameter, theta_curr[i+j] is the jth value within the parameter. */ -template -void store_sample(ModelType& model, - MatType& theta_curr, - size_t i) +template +inline void store_sample(const ModelType& model, + const MatType& theta_curr, + size_t i, + F f = F()) { - auto theta_curr_it = theta_curr.begin(); - auto store_sample = [&, i](auto& eq_node) { - auto& var = eq_node.get_variable(); + auto store_sample__ = [&, i](const auto& eq_node) { + const auto& var = eq_node.get_variable(); using var_t = std::decay_t; -#if __cplusplus <= 201703L if constexpr (util::is_param_v) { -#else - if constexpr (util::param) { -#endif - auto storage_ptr = var.get_storage(); - storage_ptr[i] = *theta_curr_it; - ++theta_curr_it; + for (size_t j = 0; j < var.size(); ++j) { + auto var_val = var.value(theta_curr, j, f); + auto storage_ptr = var.storage(j); + storage_ptr[i] = var_val; + } } }; - model.traverse(store_sample); + model.traverse(store_sample__); } /** @@ -173,12 +170,11 @@ void store_sample(ModelType& model, * The uniform sampler must sample from [0,1]. */ template -bool accept_or_reject(double p, - UniformDistType&& unif_sampler, - GenType&& gen) +inline bool accept_or_reject(double p, + UniformDistType&& unif_sampler, + GenType&& gen) { - double u = unif_sampler(gen); - return (u <= p); + return (unif_sampler(gen) <= p); } } // namespace mcmc diff --git a/include/autoppl/util/dist_expr_traits.hpp b/include/autoppl/util/dist_expr_traits.hpp deleted file mode 100644 index 0ff14cf6..00000000 --- a/include/autoppl/util/dist_expr_traits.hpp +++ /dev/null @@ -1,147 +0,0 @@ -#pragma once -#if __cplusplus <= 201703L -#include -#endif -#include -#include -#include -#include - -namespace ppl { -namespace util { - -/** - * Base class for all distribution expressions. - * It is necessary for all distribution expressions to - * derive from this class. - */ -template -struct DistExpr : BaseCRTP -{ - using BaseCRTP::self; - using dist_value_t = double; - - template -#if __cplusplus <= 201703L - std::enable_if_t>, dist_value_t> - log_pdf(const VarType& v) const { -#else - dist_value_t log_pdf(const VarType& v) const - requires var> { -#endif - dist_value_t value = 0.0; - for (size_t i = 0; i < v.size(); ++i) { - value += self().log_pdf(v.get_value(i), i); - } - - return value; - } - - template -#if __cplusplus <= 201703L - std::enable_if_t>, dist_value_t> - pdf(const VarType& v) const { -#else - dist_value_t pdf(const VarType& v) const - requires var> { -#endif - dist_value_t value = 1.0; - for (size_t i = 0; i < v.size(); ++i) { - value *= self().pdf(v.get_value(i), i); - } - - return value; - } -}; - -/** - * Checks if DistExpr is base of type T - */ -template -inline constexpr bool dist_expr_is_base_of_v = - std::is_base_of_v, T>; - -#if __cplusplus <= 201703L -DEFINE_ASSERT_ONE_PARAM(dist_expr_is_base_of_v); -#endif - -/* - * TODO: Samplable distribution expression concept? - */ - -/* - * TODO: continuous/discrete distribution expression concept? - */ - -/** - * Continuous distribution expressions can be constructed with this type. - */ -using cont_param_t = double; - -/** - * Discrete distribution expressions can be constructed with this type. - */ -using disc_param_t = int64_t; - -/** - * Traits for Distribution Expression classes. - * value_t type of value Variable represents during computation - * dist_value_t type of pdf/log_pdf value - */ -template -struct dist_expr_traits -{ - using value_t = typename DistExprType::value_t; - using dist_value_t = typename DistExprType::dist_value_t; -}; - -#if __cplusplus <= 201703L - -/** - * A distribution expression is any class that satisfies the following concept: - */ -template -inline constexpr bool is_dist_expr_v = - dist_expr_is_base_of_v && - has_type_value_t_v && - has_type_dist_value_t_v && - // has_func_pdf_v && // removed to allow overloading - // has_func_log_pdf_v && - has_func_min_v && - has_func_max_v - ; - -template -inline constexpr bool assert_is_dist_expr_v = - assert_dist_expr_is_base_of_v && - assert_has_type_value_t_v && - assert_has_type_dist_value_t_v && - // assert_has_func_pdf_v && // removed to allow overloading - // assert_has_func_log_pdf_v && - assert_has_func_min_v && - assert_has_func_max_v - ; - -#else - -template -concept dist_expr = - dist_expr_is_base_of_v && - requires () { - typename dist_expr_traits::value_t; - typename dist_expr_traits::dist_value_t; - } && - requires (T x, const T cx, - typename dist_expr_traits::value_t val, - size_t i) { - {cx.pdf(val, i)} -> std::same_as::dist_value_t>; - {cx.log_pdf(val, i)} -> std::same_as::dist_value_t>; - {cx.min()} -> std::same_as::value_t>; - {cx.max()} -> std::same_as::value_t>; - } - ; - -#endif - -} // namespace util -} // namespace ppl diff --git a/include/autoppl/util/functional.hpp b/include/autoppl/util/functional.hpp new file mode 100644 index 00000000..1deca9af --- /dev/null +++ b/include/autoppl/util/functional.hpp @@ -0,0 +1,15 @@ +#pragma once +#include + +namespace ppl { +namespace util { + +struct identity +{ + template + constexpr T&& operator()(T&& x) const noexcept + { return std::forward(x); } +}; + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/iterator/counting_iterator.hpp b/include/autoppl/util/iterator/counting_iterator.hpp new file mode 100644 index 00000000..eac3f7c1 --- /dev/null +++ b/include/autoppl/util/iterator/counting_iterator.hpp @@ -0,0 +1,52 @@ +#pragma once +#include +#include + +namespace ppl { +namespace util { + +// forward declaration +template +struct counting_iterator; + +template +inline constexpr bool +operator==(const counting_iterator& it1, + const counting_iterator& it2) +{ return it1.curr_ == it2.curr_; } + +template +inline constexpr bool +operator!=(const counting_iterator& it1, + const counting_iterator& it2) +{ return it1.curr_ != it2.curr_; } + +template +struct counting_iterator +{ + using difference_type = int32_t; + using value_type = IntType; + using pointer = value_type*; + using reference = IntType&; + using iterator_category = std::bidirectional_iterator_tag; + + counting_iterator(value_type begin) + : curr_(begin) + {} + + counting_iterator& operator++() { ++curr_; return *this; } + counting_iterator& operator--() { --curr_; return *this; } + counting_iterator operator++(int) { auto tmp = *this; ++curr_; return tmp; } + counting_iterator operator--(int) { auto tmp = *this; --curr_; return tmp; } + reference operator*() { return curr_; } + + friend constexpr bool operator==<>(const counting_iterator&, + const counting_iterator&); + friend constexpr bool operator!=<>(const counting_iterator&, + const counting_iterator&); +private: + value_type curr_; +}; + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/iterator/range.hpp b/include/autoppl/util/iterator/range.hpp new file mode 100644 index 00000000..fc4a1355 --- /dev/null +++ b/include/autoppl/util/iterator/range.hpp @@ -0,0 +1,54 @@ +#pragma once +#include +#include + +namespace ppl { +namespace util { + +/** + * Small class to view a range of elements. + */ +template +struct range +{ + using iter_t = Iter; + + range(iter_t begin, iter_t end) + : begin_{begin} + , end_{end} + , size_{static_cast(std::distance(begin, end))} + {} + + auto& operator()(size_t i) { + assert(i < size_); + return *std::next(begin_, i); + } + + const auto& operator()(size_t i) const { + assert(i < size_); + return *std::next(begin_, i); + } + + iter_t begin() { return begin_; } + const iter_t begin() const { return begin_; } + + iter_t end() { return end_; } + const iter_t end() const { return end_; } + + size_t size() const { return size_; } + + void bind(iter_t begin, iter_t end) + { + begin_ = begin; + end_ = end; + size_ = static_cast(std::distance(begin, end)); + } + +private: + iter_t begin_; + iter_t end_; + size_t size_; +}; + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/model_expr_traits.hpp b/include/autoppl/util/model_expr_traits.hpp deleted file mode 100644 index cc4d514e..00000000 --- a/include/autoppl/util/model_expr_traits.hpp +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once -#if __cplusplus <= 201703L -#include -#endif -#include - -namespace ppl { -namespace util { - -/** - * Base class for all model expressions. - * It is necessary for all model expressions to - * derive from this class. - */ -template -struct ModelExpr : BaseCRTP -{ using BaseCRTP::self; }; - -/** - * Checks if DistExpr is base of type T - */ -template -inline constexpr bool model_expr_is_base_of_v = - std::is_base_of_v, T>; - -#if __cplusplus <= 201703L -DEFINE_ASSERT_ONE_PARAM(model_expr_is_base_of_v); -#endif - -/** - * Traits for Model Expression classes. - * dist_value_t type of value Variable represents during computation - */ -template -struct model_expr_traits -{ - using dist_value_t = typename NodeType::dist_value_t; -}; - -#if __cplusplus <= 201703L - -// TODO: -// - pdf and log_pdf remove from interface? -// - how to check if template member function exists (for traverse)? -template -inline constexpr bool is_model_expr_v = - model_expr_is_base_of_v && - has_type_dist_value_t_v && - has_func_pdf_v && - has_func_log_pdf_v - ; - -template -inline constexpr bool assert_is_model_expr_v = - assert_model_expr_is_base_of_v && - assert_has_type_dist_value_t_v && - assert_has_func_pdf_v && - assert_has_func_log_pdf_v - ; - -#else - -template -concept model_expr = - model_expr_is_base_of_v && - requires (const T cx) { - typename model_expr_traits::dist_value_t; - {cx.pdf()} -> std::same_as::dist_value_t>; - {cx.log_pdf()} -> std::same_as::dist_value_t>; - } - ; - -#endif - -} // namespace util -} // namespace ppl diff --git a/include/autoppl/util/traits.hpp b/include/autoppl/util/traits.hpp index c0f9e4ab..785f3236 100644 --- a/include/autoppl/util/traits.hpp +++ b/include/autoppl/util/traits.hpp @@ -6,9 +6,9 @@ * Users should rely on these classes to grab member aliases. */ -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include diff --git a/include/autoppl/util/concept.hpp b/include/autoppl/util/traits/concept.hpp similarity index 98% rename from include/autoppl/util/concept.hpp rename to include/autoppl/util/traits/concept.hpp index 184e4ea3..f1d1ea0b 100644 --- a/include/autoppl/util/concept.hpp +++ b/include/autoppl/util/traits/concept.hpp @@ -208,6 +208,10 @@ struct invalid_tag DEFINE_HAS_TYPE(value_t); DEFINE_HAS_TYPE(pointer_t); DEFINE_HAS_TYPE(const_pointer_t); +DEFINE_HAS_TYPE(id_t); +DEFINE_HAS_TYPE(vec_t); + +DEFINE_HAS_TYPE(shape_t); DEFINE_HAS_TYPE(dist_value_t); @@ -216,6 +220,10 @@ DEFINE_HAS_FUNC(get_value); DEFINE_HAS_FUNC(set_storage); DEFINE_HAS_FUNC(get_storage); +DEFINE_HAS_FUNC(value); +DEFINE_HAS_FUNC(size); +DEFINE_HAS_FUNC(id); + DEFINE_HAS_FUNC(pdf); DEFINE_HAS_FUNC(log_pdf); DEFINE_HAS_FUNC(min); diff --git a/include/autoppl/util/traits/dist_expr_traits.hpp b/include/autoppl/util/traits/dist_expr_traits.hpp new file mode 100644 index 00000000..0e978672 --- /dev/null +++ b/include/autoppl/util/traits/dist_expr_traits.hpp @@ -0,0 +1,111 @@ +#pragma once +#if __cplusplus <= 201703L +#include +#endif +#include +#include +#include +#include + +namespace ppl { +namespace util { + +/** + * Base class for all distribution expressions. + * It is necessary for all distribution expressions to + * derive from this class. + */ +template +struct DistExprBase : BaseCRTP +{ + using BaseCRTP::self; + using dist_value_t = double; +}; + +template +inline constexpr bool dist_expr_is_base_of_v = + std::is_base_of_v, T>; + +/** + * Continuous distribution expressions can be constructed with this type. + */ +using cont_param_t = double; + +/** + * Discrete distribution expressions can be constructed with this type. + */ +using disc_param_t = int32_t; + +/** + * Traits for Distribution Expression classes. + * value_t type of value Variable represents during computation + * dist_value_t type of pdf/log_pdf value + */ +template +struct dist_expr_traits +{ + using value_t = typename DistExprType::value_t; + using dist_value_t = typename DistExprType::dist_value_t; + static constexpr bool is_cont_v = util::is_cont_v; + static constexpr bool is_disc_v = util::is_disc_v; + + static_assert(is_cont_v == !is_disc_v, + PPL_CONT_XOR_DISC); +}; + +#if __cplusplus <= 201703L + +DEFINE_ASSERT_ONE_PARAM(dist_expr_is_base_of_v); + +/** + * A distribution expression is any class that satisfies the following concept: + */ +template +inline constexpr bool is_dist_expr_v = + dist_expr_is_base_of_v && + has_type_value_t_v && + has_type_dist_value_t_v + //has_func_pdf_v && // removed to allow overloading + //has_func_log_pdf_v && + //has_func_min_v && + //has_func_max_v + ; + +template +inline constexpr bool assert_is_dist_expr_v = + assert_dist_expr_is_base_of_v && + assert_has_type_value_t_v && + assert_has_type_dist_value_t_v + // assert_has_func_pdf_v && // removed to allow overloading + // assert_has_func_log_pdf_v && + //assert_has_func_min_v && + //assert_has_func_max_v + ; + +#else + +template +concept dist_expr_c = + dist_expr_is_base_of_v && + requires () { + typename dist_expr_traits::value_t; + typename dist_expr_traits::dist_value_t; + } && + requires (T x, const T cx, + typename dist_expr_traits::value_t val, + size_t i) { + // TODO: pdf, log_pdf, ad_log_pdf? + //{ cx.pdf(val, i) } -> std::same_as::dist_value_t>; + //{ cx.log_pdf(val, i) } -> std::same_as::dist_value_t>; + //{ cx.min() } -> std::same_as::value_t>; + //{ cx.max() } -> std::same_as::value_t>; + } + ; + +template +concept is_dist_expr_v = dist_expr_c; + +#endif + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/traits/model_expr_traits.hpp b/include/autoppl/util/traits/model_expr_traits.hpp new file mode 100644 index 00000000..9144a230 --- /dev/null +++ b/include/autoppl/util/traits/model_expr_traits.hpp @@ -0,0 +1,64 @@ +#pragma once +#if __cplusplus <= 201703L +#include +#endif +#include + +namespace ppl { +namespace util { + +/** + * Base class for all model expressions. + * It is necessary for all model expressions to + * derive from this class. + */ +template +struct ModelExprBase : BaseCRTP +{ using BaseCRTP::self; }; + +/** + * Checks if DistExpr is base of type T + */ +template +inline constexpr bool model_expr_is_base_of_v = + std::is_base_of_v, T>; + +#if __cplusplus <= 201703L + +DEFINE_ASSERT_ONE_PARAM(model_expr_is_base_of_v); + +// TODO: +// - ad_log_pdf? +// - how to check if template member function exists (for traverse)? +template +inline constexpr bool is_model_expr_v = + model_expr_is_base_of_v + //has_func_pdf_v && + //has_func_log_pdf_v + ; + +template +inline constexpr bool assert_is_model_expr_v = + assert_model_expr_is_base_of_v + //assert_has_func_pdf_v && + //assert_has_func_log_pdf_v + ; + +#else + +template +concept model_expr_c = + model_expr_is_base_of_v && + requires (const T cx) { + //{cx.pdf()} -> std::same_as::dist_value_t>; + //{cx.log_pdf()} -> std::same_as::dist_value_t>; + } + ; + +template +concept is_model_expr_v = model_expr_c; + +#endif + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/traits/shape_traits.hpp b/include/autoppl/util/traits/shape_traits.hpp new file mode 100644 index 00000000..5fab468c --- /dev/null +++ b/include/autoppl/util/traits/shape_traits.hpp @@ -0,0 +1,187 @@ +#pragma once +#include +#if __cplusplus <= 201703L +#include +#else +#include +#endif +#include + +namespace ppl { + +inline constexpr size_t DIM_SCALAR = 0; +inline constexpr size_t DIM_VECTOR = 1; +inline constexpr size_t DIM_MATRIX = 2; + +/** + * Class tags to determine which shape a Data or Param is expected to be. + */ +struct scl { static constexpr size_t dim = DIM_SCALAR; }; +struct vec { static constexpr size_t dim = DIM_VECTOR; }; +struct mat { static constexpr size_t dim = DIM_MATRIX; }; + +namespace util { + +/** + * Base class for all variables. + * It is necessary for all variables to + * derive from this class. + */ +//template +//struct SclBase : BaseCRTP +//{ using BaseCRTP::self; }; +// +//template +//struct VecBase : BaseCRTP +//{ using BaseCRTP::self; }; +// +//template +//inline constexpr bool scl_is_base_of_v = +// std::is_base_of_v, T>; +// +//template +//inline constexpr bool vec_is_base_of_v = +// std::is_base_of_v, T>; +// + +template +struct shape_traits +{ + using shape_t = typename T::shape_t; +}; + +#if __cplusplus <= 201703L + +//DEFINE_ASSERT_ONE_PARAM(scl_is_base_of_v); +//DEFINE_ASSERT_ONE_PARAM(vec_is_base_of_v); + +/** + * C++17 version of concepts to check var properties. + * - var_traits must be well-defined under type T + * - T must be explicitly convertible to its value_t + * - not possible to get overloads + */ + +template +inline constexpr bool is_scl_v = + has_type_shape_t_v && + std::is_same_v, ppl::scl> && + has_func_size_v + ; +DEFINE_ASSERT_ONE_PARAM(is_scl_v); + +template +inline constexpr bool is_vec_v = + has_type_shape_t_v && + std::is_same_v, ppl::vec> && + has_func_size_v + ; +DEFINE_ASSERT_ONE_PARAM(is_vec_v); + +template +inline constexpr bool is_mat_v = + has_type_shape_t_v && + std::is_same_v, ppl::mat> && + has_func_size_v + ; +DEFINE_ASSERT_ONE_PARAM(is_mat_v); + +template +inline constexpr bool is_shape_v = + is_scl_v || + is_vec_v || + is_mat_v + ; +DEFINE_ASSERT_ONE_PARAM(is_shape_v); + +#else + +template +concept scl_c = + requires(const T cx) { + typename T::shape_t; + { cx.size() } -> std::same_as; + } && + std::same_as + ; + +template +concept vec_c = + requires(const T cx) { + typename T::shape_t; + { cx.size() } -> std::same_as; + } && + std::same_as + ; + +template +concept mat_c = + requires(const T cx) { + typename T::shape_t; + { cx.size() } -> std::same_as; // TODO: return type? + } && + std::same_as + ; + +template +concept shape_c = + scl_c || + vec_c || + mat_c + ; + +template +concept is_scl_v = scl_c; + +template +concept is_vec_v = vec_c; + +template +concept is_mat_v = mat_c; + +template +concept is_shape_v = shape_c; + +#endif + +////////////////////////////////////////////////// +// Useful tools to manage shapes +////////////////////////////////////////////////// + +/** + * Checks if T is a shape tag. + */ +template +inline constexpr bool is_shape_tag_v = + std::is_same_v || + std::is_same_v + //std::is_same_v + ; + +namespace details { + +template && is_shape_tag_v> +struct max_shape; + +template +struct max_shape +{ + using type = std::conditional_t< + S1::dim >= S2::dim, + S1, + S2>; +}; + +} // namespace details + +/** + * Returns the type whose shape has more dimension. + * Undefined behavior if S1 and S2 are not shape tags. + */ +template +using max_shape_t = typename details::max_shape::type; + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/type_traits.hpp b/include/autoppl/util/traits/type_traits.hpp similarity index 86% rename from include/autoppl/util/type_traits.hpp rename to include/autoppl/util/traits/type_traits.hpp index c33a335f..64f27624 100644 --- a/include/autoppl/util/type_traits.hpp +++ b/include/autoppl/util/traits/type_traits.hpp @@ -37,7 +37,13 @@ inline constexpr bool assert_##name = \ details::assert_##name>::value; \ +// Important type checking error messages +#define PPL_CONT_XOR_DISC \ + "Expression must be either continuous or discrete. " \ + "It cannot be both continuous and discrete. " + namespace ppl { +namespace util { /** * Checks if type From can be explicitly converted to type To. @@ -60,4 +66,12 @@ struct BaseCRTP const T& self() const { return static_cast(*this); } }; +template +inline constexpr bool is_cont_v = std::is_floating_point_v; + +template +inline constexpr bool is_disc_v = std::is_integral_v; + + +} // namespace util } // namespace ppl diff --git a/include/autoppl/util/var_expr_traits.hpp b/include/autoppl/util/traits/var_expr_traits.hpp similarity index 54% rename from include/autoppl/util/var_expr_traits.hpp rename to include/autoppl/util/traits/var_expr_traits.hpp index 85799368..e0ec537d 100644 --- a/include/autoppl/util/var_expr_traits.hpp +++ b/include/autoppl/util/traits/var_expr_traits.hpp @@ -1,11 +1,11 @@ #pragma once #if __cplusplus <= 201703L -#include +#include #else #include #endif -#include -#include +#include +#include namespace ppl { namespace util { @@ -16,24 +16,13 @@ namespace util { * derive from this class. */ template -struct VarExpr : BaseCRTP +struct VarExprBase : BaseCRTP { using BaseCRTP::self; }; -/** - * Checks if VarExpr is base of type T - */ template inline constexpr bool var_expr_is_base_of_v = - std::is_base_of_v, T>; - -#if __cplusplus <= 201703L -DEFINE_ASSERT_ONE_PARAM(var_expr_is_base_of_v); -#endif + std::is_base_of_v, T>; -/** - * Traits for Variable Expression classes. - * value_t type of value Variable represents during computation - */ template struct var_expr_traits { @@ -42,46 +31,44 @@ struct var_expr_traits #if __cplusplus <= 201703L +DEFINE_ASSERT_ONE_PARAM(var_expr_is_base_of_v); + /** * A variable expression is any class that satisfies the following concept. */ template inline constexpr bool is_var_expr_v = + is_shape_v && var_expr_is_base_of_v && - !is_var_v && - has_type_value_t_v && - has_func_get_value_v + has_type_value_t_v + //has_func_value_v ; -namespace details { - -// Tool needed to assert -template -inline constexpr bool is_not_var_v = !is_var_v; -DEFINE_ASSERT_ONE_PARAM(is_not_var_v); - -} // namespace details - template inline constexpr bool assert_is_var_expr_v = + assert_is_shape_v && assert_var_expr_is_base_of_v && - details::assert_is_not_var_v && - assert_has_type_value_t_v && - assert_has_func_get_value_v + assert_has_type_value_t_v + //assert_has_func_value_v ; #else template -concept var_expr = +concept var_expr_c = + shape_c && var_expr_is_base_of_v && - !var && requires (const T cx, size_t i) { + { T::has_param } -> std::same_as; typename var_expr_traits::value_t; - {cx.get_value(i)} -> std::same_as::value_t>; + {cx.value(i)} -> std::convertible_to< + typename var_expr_traits::value_t>; } ; +template +concept is_var_expr_v = var_expr_c; + #endif diff --git a/include/autoppl/util/traits/var_traits.hpp b/include/autoppl/util/traits/var_traits.hpp new file mode 100644 index 00000000..b19027dd --- /dev/null +++ b/include/autoppl/util/traits/var_traits.hpp @@ -0,0 +1,158 @@ +#pragma once +#include +#include +#if __cplusplus <= 201703L +#include +#endif + +/* + * We say Param or Data, etc. are vars. + */ + +namespace ppl { +namespace util { + +template +struct ParamBase : BaseCRTP +{ using BaseCRTP::self; }; + +template +struct DataBase : BaseCRTP +{ using BaseCRTP::self; }; + +template +inline constexpr bool param_is_base_of_v = + std::is_base_of_v, T>; + +template +inline constexpr bool data_is_base_of_v = + std::is_base_of_v, T>; + +template +struct var_traits : var_expr_traits +{ +private: + using base_t = var_expr_traits; +public: + using id_t = typename VarType::id_t; + using vec_t = get_type_vec_t_t; + static constexpr bool is_cont_v = util::is_cont_v; + static constexpr bool is_disc_v = util::is_disc_v; + + static_assert(is_cont_v == !is_disc_v, + PPL_CONT_XOR_DISC); +}; + +template +struct param_traits : var_traits +{ + using pointer_t = typename VarType::pointer_t; + using const_pointer_t = typename VarType::const_pointer_t; + using index_t = typename VarType::index_t; +}; + +template +struct data_traits : var_traits +{}; + +#if __cplusplus <= 201703L + +DEFINE_ASSERT_ONE_PARAM(param_is_base_of_v); +DEFINE_ASSERT_ONE_PARAM(data_is_base_of_v); + +template +inline constexpr bool is_param_v = + // T itself is a parameter-like variable + is_var_expr_v && + param_is_base_of_v && + has_type_id_t_v && + has_type_pointer_t_v && + has_type_const_pointer_t_v && + has_func_id_v + // TODO: set, get value may not be needed + //has_func_set_value_v && + //has_func_get_value_v && + //has_func_set_storage_v + ; + +template +inline constexpr bool is_data_v = + is_var_expr_v && + data_is_base_of_v && + has_type_id_t_v && + has_func_id_v + ; + +template +inline constexpr bool is_var_v = + is_param_v || + is_data_v + ; +DEFINE_ASSERT_ONE_PARAM(is_var_v); + +template +inline constexpr bool assert_is_param_v = + assert_is_var_expr_v && + assert_param_is_base_of_v && + assert_has_type_pointer_t_v && + assert_has_type_const_pointer_t_v && + assert_has_type_id_t_v && + assert_has_func_id_v + // TODO: may not be needed + //assert_has_func_set_value_v && + //assert_has_func_get_value_v && + //assert_has_func_set_storage_v + ; + +template +inline constexpr bool assert_is_data_v = + assert_is_var_expr_v && + assert_data_is_base_of_v && + assert_has_type_id_t_v && + assert_has_func_id_v + ; + +#else + +template +concept data_c = + var_expr_c && + data_is_base_of_v && + requires (const T cx, size_t i) { + typename var_traits::id_t; + { cx.id() } -> std::same_as::id_t>; + } + ; + +template +concept param_c = + var_expr_c && + param_is_base_of_v && + requires () { + typename var_traits::id_t; + typename param_traits::pointer_t; + typename param_traits::const_pointer_t; + } && + requires (T x, const T cx, size_t i) { + { cx.storage(i) } -> std::convertible_to::pointer_t>; + { cx.id() } -> std::same_as::id_t>; + } + ; + +template +concept var_c = + data_c || + param_c + ; + +template +concept is_data_v = data_c; +template +concept is_param_v = param_c; +template +concept is_var_v = var_c; + +#endif + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/var_traits.hpp b/include/autoppl/util/var_traits.hpp deleted file mode 100644 index 8b8dfceb..00000000 --- a/include/autoppl/util/var_traits.hpp +++ /dev/null @@ -1,170 +0,0 @@ -#pragma once -#include -#if __cplusplus <= 201703L -#include -#else -#include -#endif -#include - -namespace ppl { -namespace util { - -/** - * Base class for all variables. - * It is necessary for all variables to - * derive from this class. - */ -template -struct Var : BaseCRTP -{ using BaseCRTP::self; }; - -/** - * Base class for all Data-like variables. - * It is necessary for all Data-like variables to - * derive from this class. - */ -template -struct DataLike : Var -{ using Var::self; }; - -/** - * Base class for all Param-like variables. - * It is necessary for all Param-like variables to - * derive from this class. - */ -template -struct ParamLike : Var -{ using Var::self; }; - - -/** - * Checks if DataLike, ParamLike or Var - * is base of type T - */ - -template -inline constexpr bool data_is_base_of_v = - std::is_base_of_v, T>; - -template -inline constexpr bool param_is_base_of_v = - std::is_base_of_v, T>; - -template -inline constexpr bool var_is_base_of_v = - std::is_base_of_v, T>; - -#if __cplusplus <= 201703L -DEFINE_ASSERT_ONE_PARAM(var_is_base_of_v); -DEFINE_ASSERT_ONE_PARAM(param_is_base_of_v); -DEFINE_ASSERT_ONE_PARAM(data_is_base_of_v); -#endif - -/** - * Traits for Variable-like classes. - * value_t type of value Variable represents during computation - * pointer_t storage pointer type - */ -template -struct var_traits -{ - using value_t = typename VarType::value_t; - using pointer_t = typename VarType::pointer_t; - using const_pointer_t = typename VarType::const_pointer_t; -}; - -/** - * C++17 version of concepts to check var properties. - * - var_traits must be well-defined under type T - * - T must be explicitly convertible to its value_t - * - not possible to get overloads - */ - -#if __cplusplus <= 201703L - -template -inline constexpr bool is_data_v = - data_is_base_of_v && - has_type_value_t_v && - has_type_pointer_t_v && - has_type_const_pointer_t_v && - has_func_get_value_v - ; - -template -inline constexpr bool is_param_v = - param_is_base_of_v && - has_type_value_t_v && - has_type_pointer_t_v && - has_type_const_pointer_t_v && - has_func_set_value_v && - has_func_get_value_v && - has_func_set_storage_v - ; - -template -inline constexpr bool assert_is_data_v = - assert_data_is_base_of_v && - assert_has_type_value_t_v && - assert_has_type_pointer_t_v && - assert_has_type_const_pointer_t_v && - assert_has_func_get_value_v - ; - -template -inline constexpr bool assert_is_param_v = - assert_param_is_base_of_v && - assert_has_type_value_t_v && - assert_has_type_pointer_t_v && - assert_has_type_const_pointer_t_v && - assert_has_func_set_value_v && - assert_has_func_get_value_v && - assert_has_func_set_storage_v - ; - -template -inline constexpr bool is_var_v = - is_data_v || is_param_v - ; - -DEFINE_ASSERT_ONE_PARAM(is_var_v); - -#else - -template -concept data = - data_is_base_of_v && - requires (const T cx, size_t i) { - typename var_traits::value_t; - typename var_traits::pointer_t; - typename var_traits::const_pointer_t; - {cx.get_value(i)} -> std::same_as::value_t>; - } - ; - -template -concept param = - param_is_base_of_v && - requires () { - typename var_traits::value_t; - typename var_traits::pointer_t; - typename var_traits::const_pointer_t; - } && - requires (T x, const T cx, - typename var_traits::value_t val, - typename var_traits::pointer_t p, - size_t i) { - {x.set_value(val)}; - {x.set_storage(p)}; - {cx.get_value(i)} -> std::same_as::value_t>; - } - ; - -template -concept var = data || param; - -#endif - -} // namespace util -} // namespace ppl diff --git a/include/autoppl/variable.hpp b/include/autoppl/variable.hpp deleted file mode 100644 index 32d1cc5f..00000000 --- a/include/autoppl/variable.hpp +++ /dev/null @@ -1,101 +0,0 @@ -#pragma once -#include -#include -#include -#include -#include - -namespace ppl { - -/** - * Param is a light-weight structure that represents a univariate hidden random variable. - * That means the parameter does not hold samples, but it does contain a value that is used - * by model.pdf and get_value. Param requires user-provided external storage for samples and - * other algorithms. It is up to the user to ensure the storage pointer has enough capacity - * to support algorithms like metropolis-hastings which store data in this pointer. get_value - * supports an integer argument for compatibility with the get_value Data API, but this is never - * used. - */ -template -struct Param : util::ParamLike> { - using value_t = ValueType; - using pointer_t = value_t*; - using const_pointer_t = const value_t*; - - Param(value_t value, pointer_t storage_ptr) noexcept - : value_{value}, storage_ptr_{storage_ptr} {} - - Param(pointer_t storage_ptr) noexcept - : Param(0., storage_ptr) {} - - Param(value_t value) noexcept - : Param(value, nullptr) {} - - Param() noexcept - : Param(0., nullptr) {} - - void set_value(value_t value) { value_ = value; } - - constexpr size_t size() const { return 1; } - value_t get_value(size_t = 0) const { - return value_; - } - - void set_storage(pointer_t storage_ptr) { storage_ptr_ = storage_ptr; } - pointer_t get_storage() { return storage_ptr_; } - const_pointer_t get_storage() const { return storage_ptr_; } - -private: - value_t value_; // store value associated with var - pointer_t storage_ptr_; // points to beginning of storage - // storage is assumed to be contiguous -}; - -/** - * Data is a light-weight structure that represents a set of samples from an observed random variable. - * It acts as an intermediate layer of communication between a model expression and the users. - * A Data object is different from a Param object in that it can hold multiple values but cannot - * be sampled. To this end, the user does not provide external storage for samples. It does not - * support set_value, but you can instead var.observe() to add an extra observation internally. - */ -template -struct Data : util::DataLike> -{ - using value_t = ValueType; - using pointer_t = value_t*; - using const_pointer_t = const value_t*; - - template - Data(iterator begin, iterator end) noexcept - : values_{begin, end} {} - - Data(std::initializer_list values) noexcept - : Data(values.begin(), values.end()) {} - - Data(value_t value) noexcept - : values_{{value}} {} - - Data() noexcept : values_{} {} - - size_t size() const { return values_.size(); } - - value_t get_value(size_t i) const { - assert((i >= 0) && (i < size())); // TODO change this to exception - return values_[i]; - } - - void observe(value_t value) { values_.push_back(value); } - void clear() { values_.clear(); } - - auto begin() const { return values_.begin(); } - auto end() const { return values_.end(); } - -private: - std::vector values_; // store value associated with var -}; - -// Useful aliases -using cont_var = Data; // continuous RV var -using disc_var = Data; // discrete RV var - -} // namespace ppl diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 52d74739..49bbf87d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,16 +10,19 @@ endif() ###################################################### add_executable(util_unittest - ${CMAKE_CURRENT_SOURCE_DIR}/util/concept_unittest.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/util/dist_expr_traits_unittest.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/util/var_expr_traits_unittest.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/util/var_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/traits/concept_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/traits/dist_expr_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/traits/var_expr_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/traits/var_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/traits/shape_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/iterator/counting_iterator_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/iterator/range_unittest.cpp ) if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") target_compile_options(util_unittest PRIVATE -g -Wall) else() - target_compile_options(util_unittest PRIVATE -g -Wall -Werror -Wextra) + target_compile_options(util_unittest PRIVATE -g -Wall -Werror -Wextra) endif() target_include_directories(util_unittest PRIVATE @@ -39,131 +42,45 @@ endif() add_test(util_unittest util_unittest) ###################################################### -# Sample Test +# Expression Test ###################################################### -add_executable(sample_unittest - ${CMAKE_CURRENT_SOURCE_DIR}/expression/samples/dist_sample_unittest.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/expression/samples/model_sample_unittest.cpp - ) - -if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_compile_options(sample_unittest PRIVATE -g -Wall) -else() - target_compile_options(sample_unittest PRIVATE -g -Wall -Werror -Wextra) -endif() - -target_include_directories(sample_unittest PRIVATE - ${GTEST_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR} - ${AUTOPPL_INCLUDE_DIRS} - ) -if (AUTOPPL_ENABLE_TEST_COVERAGE) - target_link_libraries(sample_unittest gcov) -endif() - -target_link_libraries(sample_unittest autoppl_gtest_main ${AUTOPPL_LIBS}) -if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_link_libraries(sample_unittest pthread) -endif() - -add_test(sample_unittest sample_unittest) - -###################################################### -# Variable Test -###################################################### - -add_executable(var_unittest - ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/variable_viewer_unittest.cpp +add_executable(expr_unittest ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/param_unittest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/data_unittest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/constant_unittest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/binop_unittest.cpp - ) - -if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_compile_options(var_unittest PRIVATE -g -Wall) -else() - target_compile_options(var_unittest PRIVATE -g -Wall -Werror -Wextra) -endif() - -target_include_directories(var_unittest PRIVATE - ${GTEST_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR} - ${AUTOPPL_INCLUDE_DIRS} - ) -if (AUTOPPL_ENABLE_TEST_COVERAGE) - target_link_libraries(var_unittest gcov) -endif() - -target_link_libraries(var_unittest autoppl_gtest_main ${AUTOPPL_LIBS}) -if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_link_libraries(var_unittest pthread) -endif() - -add_test(var_unittest var_unittest) - -###################################################### -# Distribution Expression Test -###################################################### - -add_executable(dist_expr_unittest ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/bernoulli_unittest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/normal_unittest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/uniform_unittest.cpp - ) - -if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_compile_options(dist_expr_unittest PRIVATE -g -Wall) -else() - target_compile_options(dist_expr_unittest PRIVATE -g -Wall -Werror -Wextra) -endif() - -target_include_directories(dist_expr_unittest PRIVATE - ${GTEST_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR} - ${AUTOPPL_INCLUDE_DIRS} - ) -if (AUTOPPL_ENABLE_TEST_COVERAGE) - target_link_libraries(dist_expr_unittest gcov) -endif() - -target_link_libraries(dist_expr_unittest autoppl_gtest_main ${AUTOPPL_LIBS}) -if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_link_libraries(dist_expr_unittest pthread) -endif() - -add_test(dist_expr_unittest dist_expr_unittest) - -###################################################### -# Model Expression Test -###################################################### - -add_executable(model_expr_unittest ${CMAKE_CURRENT_SOURCE_DIR}/expression/model/model_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/expr_builder_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/integration/dist_inttest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/integration/model_inttest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/integration/ad_inttest.cpp ) if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_compile_options(model_expr_unittest PRIVATE -g -Wall) + target_compile_options(expr_unittest PRIVATE -g -Wall) else() - target_compile_options(model_expr_unittest PRIVATE -g -Wall -Werror -Wextra) + target_compile_options(expr_unittest PRIVATE -g -Wall -Werror -Wextra) endif() -target_include_directories(model_expr_unittest PRIVATE +target_include_directories(expr_unittest PRIVATE ${GTEST_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR} ${AUTOPPL_INCLUDE_DIRS} ) if (AUTOPPL_ENABLE_TEST_COVERAGE) - target_link_libraries(model_expr_unittest gcov) + target_link_libraries(expr_unittest gcov) endif() -target_link_libraries(model_expr_unittest autoppl_gtest_main ${AUTOPPL_LIBS}) +target_link_libraries(expr_unittest autoppl_gtest_main ${AUTOPPL_LIBS}) if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_link_libraries(model_expr_unittest pthread) + target_link_libraries(expr_unittest pthread) endif() -add_test(model_expr_unittest model_expr_unittest) +add_test(expr_unittest expr_unittest) ###################################################### # Math Test @@ -171,6 +88,8 @@ add_test(model_expr_unittest model_expr_unittest) add_executable(math_unittest ${CMAKE_CURRENT_SOURCE_DIR}/math/welford_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/math/density_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/math/math_unittest.cpp ) if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") @@ -212,7 +131,13 @@ add_executable(mcmc_unittest if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") target_compile_options(mcmc_unittest PRIVATE -g -Wall) else() + # -Wno-error=maybe-uninitialized: + # GCC8 throws weird compiler error about lambda possibly uninitialized before use. + # Strongly suspect it's a false positive. target_compile_options(mcmc_unittest PRIVATE -g -Wall -Werror -Wextra) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + target_compile_options(mcmc_unittest PRIVATE -Wno-error=maybe-uninitialized) + endif() endif() target_include_directories(mcmc_unittest PRIVATE @@ -235,34 +160,3 @@ if (UNIX AND NOT APPLE) openblas lapack) endif() add_test(mcmc_unittest mcmc_unittest) - -###################################################### -# Expression Builder Test -###################################################### - -add_executable(expr_builder_unittest - ${CMAKE_CURRENT_SOURCE_DIR}/expr_builder_unittest.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/ad_integration_unittest.cpp - ) - -if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_compile_options(expr_builder_unittest PRIVATE -g -Wall) -else() - target_compile_options(expr_builder_unittest PRIVATE -g -Wall -Werror -Wextra) -endif() - -target_include_directories(expr_builder_unittest PRIVATE - ${GTEST_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR} - ${AUTOPPL_INCLUDE_DIRS} - ) -if (AUTOPPL_ENABLE_TEST_COVERAGE) - target_link_libraries(expr_builder_unittest gcov) -endif() - -target_link_libraries(expr_builder_unittest autoppl_gtest_main ${AUTOPPL_LIBS}) -if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - target_link_libraries(expr_builder_unittest pthread) -endif() - -add_test(expr_builder_unittest expr_builder_unittest) diff --git a/test/expression/distribution/bernoulli_unittest.cpp b/test/expression/distribution/bernoulli_unittest.cpp index 63ab7c58..82263ea3 100644 --- a/test/expression/distribution/bernoulli_unittest.cpp +++ b/test/expression/distribution/bernoulli_unittest.cpp @@ -1,6 +1,5 @@ #include "gtest/gtest.h" -#include -#include +#include "dist_fixture_base.hpp" #include #include #include @@ -8,109 +7,68 @@ namespace ppl { namespace expr { -struct bernoulli_fixture : ::testing::Test +struct bernoulli_fixture : + dist_fixture_base, + dist_fixture_base, + ::testing::Test { protected: - using value_t = typename MockVarExpr::value_t; - static constexpr size_t sample_size = 1000; - double p = 0.6; - MockVarExpr x{p}; - Bernoulli bern = {x}; - std::array sample = {0.}; -}; - -TEST_F(bernoulli_fixture, ctor) -{ -#if __cplusplus <= 201703L - static_assert(util::assert_is_dist_expr_v>); -#else - static_assert(util::dist_expr>); -#endif -} - -TEST_F(bernoulli_fixture, bernoulli_check_params) { - EXPECT_DOUBLE_EQ(bern.p(), x.get_value(0)); -} - -TEST_F(bernoulli_fixture, bernoulli_pdf_in_range) -{ - EXPECT_DOUBLE_EQ(bern.pdf(0), 1-p); - EXPECT_DOUBLE_EQ(bern.pdf(1), p); -} + using disc_base_t = dist_fixture_base; + using cont_base_t = dist_fixture_base; -TEST_F(bernoulli_fixture, bernoulli_pdf_out_of_range) -{ - EXPECT_DOUBLE_EQ(bern.pdf(-100), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(-3), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(-2), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(2), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(3), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(5), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(100), 0.); -} + disc_base_t::value_t x_val_in = 0; + disc_base_t::value_t x_val_out = -1; -TEST_F(bernoulli_fixture, bernoulli_pdf_always_tail) -{ - double p = 0.; - MockVarExpr x{p}; - Bernoulli bern = {x}; - EXPECT_DOUBLE_EQ(bern.pdf(0), 1.); - EXPECT_DOUBLE_EQ(bern.pdf(1), 0.); -} + cont_base_t::value_t p_val = 0.6; +}; -TEST_F(bernoulli_fixture, bernoulli_pdf_always_head) +TEST_F(bernoulli_fixture, ctor) { - double p = 1.; - MockVarExpr x{p}; - Bernoulli bern = {x}; - EXPECT_DOUBLE_EQ(bern.pdf(0), 0.); - EXPECT_DOUBLE_EQ(bern.pdf(1), 1.); + static_assert(util::is_dist_expr_v>); } -TEST_F(bernoulli_fixture, bernoulli_log_pdf_in_range) +TEST_F(bernoulli_fixture, pdf_in) { - EXPECT_DOUBLE_EQ(bern.log_pdf(0), std::log(1-p)); - EXPECT_DOUBLE_EQ(bern.log_pdf(1), std::log(p)); + using bern_t = Bernoulli; + disc_base_t::dv_scl_t x(x_val_in); + cont_base_t::dv_scl_t p(p_val); + bern_t bern(p); + cont_base_t::vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(bern.pdf(x, pvalues), + 1-p_val); } -TEST_F(bernoulli_fixture, bernoulli_log_pdf_out_of_range) +TEST_F(bernoulli_fixture, pdf_out) { - EXPECT_DOUBLE_EQ(bern.log_pdf(-100), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(-3), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(-1), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(2), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(3), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(5), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(100), std::numeric_limits::lowest()); + using bern_t = Bernoulli; + disc_base_t::dv_scl_t x(x_val_out); + cont_base_t::dv_scl_t p(p_val); + bern_t bern(p); + cont_base_t::vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(bern.pdf(x, pvalues), + 0.); } -TEST_F(bernoulli_fixture, bernoulli_log_pdf_always_tail) +TEST_F(bernoulli_fixture, log_pdf_in) { - double p = 0.; - MockVarExpr x{p}; - Bernoulli bern = {x}; - EXPECT_DOUBLE_EQ(bern.log_pdf(0), 0.); - EXPECT_DOUBLE_EQ(bern.log_pdf(1), std::numeric_limits::lowest()); + using bern_t = Bernoulli; + disc_base_t::dv_scl_t x(x_val_in); + cont_base_t::dv_scl_t p(p_val); + bern_t bern(p); + cont_base_t::vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(bern.log_pdf(x, pvalues), + std::log(1-p_val)); } -TEST_F(bernoulli_fixture, bernoulli_log_pdf_always_head) +TEST_F(bernoulli_fixture, log_pdf_out) { - double p = 1.; - MockVarExpr x{p}; - Bernoulli bern = {x}; - EXPECT_DOUBLE_EQ(bern.log_pdf(0), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(bern.log_pdf(1), 0.); -} - -TEST_F(bernoulli_fixture, bernoulli_sample) { - std::random_device rd{}; - std::mt19937 gen{rd()}; - - for (size_t i = 0; i < sample_size; i++) { - sample[i] = bern.sample(gen); - } - - plot_hist(sample); + using bern_t = Bernoulli; + disc_base_t::dv_scl_t x(x_val_out); + cont_base_t::dv_scl_t p(p_val); + bern_t bern(p); + cont_base_t::vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(bern.log_pdf(x, pvalues), + math::neg_inf); } } // namespace expr diff --git a/test/expression/distribution/dist_fixture_base.hpp b/test/expression/distribution/dist_fixture_base.hpp new file mode 100644 index 00000000..8fcd7093 --- /dev/null +++ b/test/expression/distribution/dist_fixture_base.hpp @@ -0,0 +1,33 @@ +#pragma once +#include +#include +#include + +namespace ppl { +namespace expr { + +template +struct dist_fixture_base { +protected: + static constexpr size_t vec_size = 3; + static constexpr size_t offset_max_size = 3; + + using value_t = ValueType; + using pointer_t = value_t*; + using vec_t = std::vector; + using vec_pointer_t = std::array; + + using dv_scl_t = DataView; + using dv_vec_t = DataView; + using pv_scl_t = ParamView; + using pv_vec_t = ParamView; + using id_t = typename util::var_traits::id_t; + using index_t = typename util::param_traits::index_t; + using ad_vec_t = std::vector>; + + std::array offsets = {0}; + vec_pointer_t storage = {nullptr}; +}; + +} // namespace expr +} // namespace ppl diff --git a/test/expression/distribution/normal_unittest.cpp b/test/expression/distribution/normal_unittest.cpp index ddc52a00..bbeb6dd1 100644 --- a/test/expression/distribution/normal_unittest.cpp +++ b/test/expression/distribution/normal_unittest.cpp @@ -1,74 +1,205 @@ #include "gtest/gtest.h" -#include -#include +#include "dist_fixture_base.hpp" #include #include -#include namespace ppl { namespace expr { -struct normal_fixture : ::testing::Test { +struct normal_fixture: + dist_fixture_base, + ::testing::Test +{ protected: - using value_t = typename MockVarExpr::value_t; - static constexpr size_t sample_size = 1000; - double mean = 0.3; - double stddev = 1.3; - double tol = 1e-15; - MockVarExpr x{mean}; - MockVarExpr y{stddev}; - using norm_t = Normal; - norm_t norm = {x, y}; - std::array sample = {0.}; + // vectors must be size 3 for consistency in this fixture + value_t x_val = -0.2; + vec_t x_vec = {0., 1., 2.}; + value_t mean_val = 0.; + vec_t mean_vec = {-1., 0., 1.}; + value_t sd_val = 1.; + vec_t sd_vec = {1., 2., 3.}; }; -TEST_F(normal_fixture, ctor) +TEST_F(normal_fixture, type_check) +{ + using norm_scl_t = Normal; + static_assert(util::is_dist_expr_v); +} + +TEST_F(normal_fixture, pdf) { -#if __cplusplus <= 201703L - static_assert(util::assert_is_dist_expr_v); -#else - static_assert(util::dist_expr); -#endif + using norm_t = Normal; + dv_vec_t x(x_vec); + dv_scl_t mean(mean_val); + dv_scl_t sd(sd_val); + norm_t norm(mean, sd); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(norm.pdf(x, pvalues), + 0.005211875018288502); } -TEST_F(normal_fixture, normal_check_params) { - EXPECT_DOUBLE_EQ(norm.mean(), x.get_value(0)); - EXPECT_DOUBLE_EQ(norm.stddev(), y.get_value(0)); +TEST_F(normal_fixture, log_pdf) +{ + using norm_t = Normal; + dv_vec_t x(x_vec); + dv_scl_t mean(mean_val); + dv_scl_t sd(sd_val); + norm_t norm(mean, sd); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(norm.log_pdf(x, pvalues), + -5.2568155996140185); } -TEST_F(normal_fixture, normal_pdf) +// AD log pdf case 1, subcase 1 +TEST_F(normal_fixture, ad_log_pdf_case_11) { - EXPECT_NEAR(norm.pdf(-10.231), 1.726752595588348216742E-15, tol); - EXPECT_NEAR(norm.pdf(-5.31), 2.774166877919518907166E-5, tol); - EXPECT_DOUBLE_EQ(norm.pdf(-2.3141231), 0.04063645713784323551341); - EXPECT_DOUBLE_EQ(norm.pdf(0.), 0.2988151821496727914542); - EXPECT_DOUBLE_EQ(norm.pdf(1.31), 0.2269313951019926611687); - EXPECT_DOUBLE_EQ(norm.pdf(3.21), 0.02505560241243631472997); - EXPECT_NEAR(norm.pdf(5.24551), 2.20984513448306056291E-4, tol); - EXPECT_NEAR(norm.pdf(10.5699), 8.61135160183067521907E-15, tol); + using norm_t = Normal; + dv_scl_t x(x_val); + dv_scl_t mean(mean_val); + dv_scl_t sd(sd_val); + norm_t norm(mean, sd); + + auto expr = norm.ad_log_pdf(x, 0); // arbitrary last param + + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -0.020000000000000018); } -TEST_F(normal_fixture, normal_log_pdf) +// AD log pdf case 1, subcase 2 when x has param +TEST_F(normal_fixture, ad_log_pdf_case_12_xparam) { - EXPECT_DOUBLE_EQ(norm.log_pdf(-10.231), std::log(1.726752595588348216742E-15)); - EXPECT_DOUBLE_EQ(norm.log_pdf(-5.31), std::log(2.774166877919518907166E-5)); - EXPECT_DOUBLE_EQ(norm.log_pdf(-2.3141231), std::log(0.04063645713784323551341)); - EXPECT_DOUBLE_EQ(norm.log_pdf(0.), std::log(0.2988151821496727914542)); - EXPECT_DOUBLE_EQ(norm.log_pdf(1.31), std::log(0.2269313951019926611687)); - EXPECT_DOUBLE_EQ(norm.log_pdf(3.21), std::log(0.02505560241243631472997)); - EXPECT_DOUBLE_EQ(norm.log_pdf(5.24551), std::log(2.20984513448306056291E-4)); - EXPECT_DOUBLE_EQ(norm.log_pdf(10.5699), std::log(8.61135160183067521907E-15)); + using norm_t = Normal; + + ad_vec_t ad_vars(2); + ad_vars[0].set_value(x_val); + ad_vars[1].set_value(sd_val); + + // initialize offsets that params will view + // MUST correspond to begin indices in ad_vars + offsets[0] = 0; + offsets[1] = 1; + + pv_scl_t x(offsets[0], storage[0]); // storage not used + dv_scl_t mean(mean_val); + pv_scl_t sd(offsets[1], storage[1]); // storage not used + norm_t norm(mean, sd); + + auto expr = norm.ad_log_pdf(x, ad_vars); + + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -0.020000000000000018); } -TEST_F(normal_fixture, normal_sample) { - std::random_device rd{}; - std::mt19937 gen{rd()}; +// AD log pdf case 1, subcase 2 when mean has param +TEST_F(normal_fixture, ad_log_pdf_case_12_mparam) +{ + using norm_t = Normal; + + ad_vec_t ad_vars(2); + ad_vars[0].set_value(mean_val); + ad_vars[1].set_value(sd_val); + + offsets[0] = 0; + offsets[1] = 1; + + dv_scl_t x(x_val); + pv_scl_t mean(offsets[0], storage[0]); + pv_scl_t sd(offsets[1], storage[1]); + norm_t norm(mean, sd); + + auto expr = norm.ad_log_pdf(x, ad_vars); + + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -0.020000000000000018); +} + +// AD log pdf case 1, subcase 3 +TEST_F(normal_fixture, ad_log_pdf_case_13) +{ + using norm_t = Normal; + + ad_vec_t ad_vars(1); + ad_vars[0].set_value(sd_val); + + offsets[0] = 0; + + dv_scl_t x(x_val); + dv_scl_t mean(mean_val); + pv_scl_t sd(offsets[0], storage[0]); + norm_t norm(mean, sd); + + auto expr = norm.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -0.020000000000000018); +} + +// AD log pdf case 2, subcase 1 +TEST_F(normal_fixture, ad_log_pdf_case_21) +{ + using norm_t = Normal; + + offsets[0] = 0; + + pv_vec_t x(offsets[0], storage, vec_size); + dv_scl_t mean(mean_val); + dv_scl_t sd(sd_val); + norm_t norm(mean, sd); + + ad_vec_t ad_vars(x_vec.size()); + std::for_each(util::counting_iterator(0), + util::counting_iterator(x_vec.size()), + [&](size_t i) { ad_vars[i].set_value(x_vec[i]); }); + + auto expr = norm.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -2.5000000000000004); +} + +// AD log pdf case 2, subcase 2 +TEST_F(normal_fixture, ad_log_pdf_case_22) +{ + using norm_t = Normal; + + ad_vec_t ad_vars(2); + ad_vars[0].set_value(mean_val); + ad_vars[1].set_value(sd_val); + + offsets[0] = 0; + offsets[1] = 1; + + dv_vec_t x(x_vec); + pv_scl_t mean(offsets[0], storage[0]); + pv_scl_t sd(offsets[1], storage[1]); + norm_t norm(mean, sd); + + auto expr = norm.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -2.5000000000000004); +} + +// AD log pdf case 3 +TEST_F(normal_fixture, ad_log_pdf_case_3) +{ + using norm_t = Normal; + + ad_vec_t ad_vars(vec_size + 1); + + std::for_each(util::counting_iterator<>(0), + util::counting_iterator<>(vec_size), + [&](auto i) { ad_vars[i].set_value(mean_vec[i]); }); + ad_vars[vec_size].set_value(sd_val); + + offsets[0] = 0; + offsets[1] = offsets[0] + vec_size; - for (size_t i = 0; i < sample_size; i++) { - sample[i] = norm.sample(gen); - } + dv_vec_t x(x_vec); + pv_vec_t mean(offsets[0], storage, vec_size); + pv_scl_t sd(offsets[1], storage[vec_size]); + norm_t norm(mean, sd); - plot_hist(sample); + auto expr = norm.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -1.5000000000000004); } } // namespace expr diff --git a/test/expression/distribution/uniform_unittest.cpp b/test/expression/distribution/uniform_unittest.cpp index 7fa4658c..9202c79f 100644 --- a/test/expression/distribution/uniform_unittest.cpp +++ b/test/expression/distribution/uniform_unittest.cpp @@ -1,93 +1,191 @@ #include "gtest/gtest.h" -#include -#include +#include "dist_fixture_base.hpp" #include #include -#include namespace ppl { namespace expr { -struct uniform_fixture : ::testing::Test { +struct uniform_fixture: + dist_fixture_base, + ::testing::Test +{ protected: - using value_t = typename MockVarExpr::value_t; - static constexpr size_t sample_size = 1000; - double min = -2.3; - double max = 2.7; - MockVarExpr x{min}; - MockVarExpr y{max}; - using unif_t = Uniform; - unif_t unif = {x, y}; - std::array sample = {0.}; + // vectors must be size 3 for consistency in this fixture + value_t x_val_in = 0.; + value_t x_val_out = -1.; + vec_t x_vec_in = {0., 0.3, 1.1}; + vec_t x_vec_out = {0., 1., 2.}; // last number is changed to be at the edge of range + value_t min_val = -1.; + vec_t min_vec = {-1., 0., 1.}; + value_t max_val = 2.; + vec_t max_vec = {1., 2., 3.}; }; -TEST_F(uniform_fixture, ctor) +TEST_F(uniform_fixture, type_check) { -#if __cplusplus <= 201703L - static_assert(util::assert_is_dist_expr_v); -#else - static_assert(util::dist_expr); -#endif + using unif_scl_t = Uniform; + static_assert(util::is_dist_expr_v); } -TEST_F(uniform_fixture, uniform_check_params) { - EXPECT_DOUBLE_EQ(unif.min(), x.get_value(0)); - EXPECT_DOUBLE_EQ(unif.max(), y.get_value(0)); +//////////////////////////////////////////////////////////// +// PDF TEST +//////////////////////////////////////////////////////////// + +TEST_F(uniform_fixture, pdf_in_scl) +{ + using unif_t = Uniform; + dv_scl_t x(x_val_in); + dv_scl_t min(min_val); + dv_scl_t max(max_val); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.pdf(x, pvalues), + 1./3); } -TEST_F(uniform_fixture, uniform_pdf_in_range) +TEST_F(uniform_fixture, pdf_in_vec) { - EXPECT_DOUBLE_EQ(unif.pdf(-2.2999999999), 0.2); - EXPECT_DOUBLE_EQ(unif.pdf(-2.), 0.2); - EXPECT_DOUBLE_EQ(unif.pdf(-1.423), 0.2); - EXPECT_DOUBLE_EQ(unif.pdf(0.), 0.2); - EXPECT_DOUBLE_EQ(unif.pdf(1.31), 0.2); - EXPECT_DOUBLE_EQ(unif.pdf(2.41), 0.2); - EXPECT_DOUBLE_EQ(unif.pdf(2.69999999999), 0.2); + using unif_t = Uniform; + dv_vec_t x(x_vec_in); + dv_vec_t min(min_vec); + dv_vec_t max(max_vec); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.pdf(x, pvalues), + 0.125); } -TEST_F(uniform_fixture, uniform_pdf_out_of_range) +TEST_F(uniform_fixture, pdf_in_scl_vec) { - EXPECT_DOUBLE_EQ(unif.pdf(-100), 0.); - EXPECT_DOUBLE_EQ(unif.pdf(-3.41), 0.); - EXPECT_DOUBLE_EQ(unif.pdf(-2.3), 0.); - EXPECT_DOUBLE_EQ(unif.pdf(2.7), 0.); - EXPECT_DOUBLE_EQ(unif.pdf(3.5), 0.); - EXPECT_DOUBLE_EQ(unif.pdf(3214), 0.); + using unif_t = Uniform; + dv_vec_t x(x_vec_in); + dv_scl_t min(min_val); + dv_vec_t max(max_vec); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.pdf(x, pvalues), + 0.5 * 1./3 * 0.25); } -TEST_F(uniform_fixture, uniform_log_pdf_in_range) +TEST_F(uniform_fixture, pdf_out) +{ + using unif_t = Uniform; + dv_scl_t x(x_val_out); + dv_scl_t min(min_val); + dv_scl_t max(max_val); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.pdf(x, pvalues), + 0.0); +} + +//////////////////////////////////////////////////////////// +// Log-PDF TEST +//////////////////////////////////////////////////////////// + +TEST_F(uniform_fixture, log_pdf_in) { - EXPECT_DOUBLE_EQ(unif.log_pdf(-2.2999999999), std::log(0.2)); - EXPECT_DOUBLE_EQ(unif.log_pdf(-2.), std::log(0.2)); - EXPECT_DOUBLE_EQ(unif.log_pdf(-1.423), std::log(0.2)); - EXPECT_DOUBLE_EQ(unif.log_pdf(0.), std::log(0.2)); - EXPECT_DOUBLE_EQ(unif.log_pdf(1.31), std::log(0.2)); - EXPECT_DOUBLE_EQ(unif.log_pdf(2.41), std::log(0.2)); - EXPECT_DOUBLE_EQ(unif.log_pdf(2.69999999999), std::log(0.2)); + using unif_t = Uniform; + dv_scl_t x(x_val_in); + dv_scl_t min(min_val); + dv_scl_t max(max_val); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.log_pdf(x, pvalues), + -std::log(3.)); } -TEST_F(uniform_fixture, uniform_log_pdf_out_of_range) +TEST_F(uniform_fixture, log_pdf_in_scl_vec) { - EXPECT_DOUBLE_EQ(unif.log_pdf(-100), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(unif.log_pdf(-3.41), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(unif.log_pdf(-2.3), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(unif.log_pdf(2.7), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(unif.log_pdf(3.5), std::numeric_limits::lowest()); - EXPECT_DOUBLE_EQ(unif.log_pdf(3214), std::numeric_limits::lowest()); + using unif_t = Uniform; + dv_vec_t x(x_vec_in); + dv_scl_t min(min_val); + dv_vec_t max(max_vec); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.log_pdf(x, pvalues), + std::log(0.5 * 1./3 * 0.25)); } -TEST_F(uniform_fixture, uniform_sample) { - std::random_device rd{}; - std::mt19937 gen{rd()}; - for (size_t i = 0; i < sample_size; i++) { - sample[i] = unif.sample(gen); - EXPECT_GT(sample[i], min); - EXPECT_LT(sample[i], max); - } +TEST_F(uniform_fixture, log_pdf_out) +{ + using unif_t = Uniform; + dv_scl_t x(x_val_out); + dv_scl_t min(min_val); + dv_scl_t max(max_val); + unif_t unif(min, max); + vec_t pvalues; // no parameter values + EXPECT_DOUBLE_EQ(unif.log_pdf(x, pvalues), + math::neg_inf); +} + +//////////////////////////////////////////////////////////// +// ad_log_pdf TEST +//////////////////////////////////////////////////////////// + +// Case 1, Subcase 1: +TEST_F(uniform_fixture, ad_log_pdf_case11) +{ + using unif_t = Uniform; + dv_vec_t x(x_vec_in); + dv_scl_t min(min_val); + dv_scl_t max(max_val); + unif_t unif(min, max); + ad_vec_t ad_vars; + + auto expr = unif.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -std::log(27.)); +} + +// Case 1, Subcase 2: +TEST_F(uniform_fixture, ad_log_pdf_case12) +{ + using unif_t = Uniform; + pv_vec_t x(offsets[0], storage, vec_size); + dv_scl_t min(min_val); + dv_scl_t max(max_val); + unif_t unif(min, max); + + offsets[0] = 0; + + ad_vec_t ad_vars(vec_size); + std::for_each(util::counting_iterator<>(0), + util::counting_iterator<>(vec_size), + [&](size_t i) { ad_vars[i].set_value(x_vec_in[i]); }); + + auto expr = unif.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + -std::log(27.)); +} + +// Case 2: +TEST_F(uniform_fixture, ad_log_pdf_case2) +{ + using unif_t = Uniform; + + // storage is ignored for now + pv_vec_t x(offsets[0], storage, vec_size); + dv_scl_t min(min_val); + pv_vec_t max(offsets[1], storage, vec_size); + unif_t unif(min, max); + + offsets[0] = 0; + offsets[1] = vec_size; + + ad_vec_t ad_vars(vec_size * 2); + std::for_each(util::counting_iterator<>(0), + util::counting_iterator<>(vec_size), + [&](size_t i) { + ad_vars[i].set_value(x_vec_in[i]); + ad_vars[i+vec_size].set_value(max_vec[i]); + }); - plot_hist(sample, 0.5, min, max); + auto expr = unif.ad_log_pdf(x, ad_vars); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), + std::log(0.5 * 1./3. * 0.25)); } } // namespace expr diff --git a/test/expr_builder_unittest.cpp b/test/expression/expr_builder_unittest.cpp similarity index 82% rename from test/expr_builder_unittest.cpp rename to test/expression/expr_builder_unittest.cpp index 612cdcfd..1b17b54c 100644 --- a/test/expr_builder_unittest.cpp +++ b/test/expression/expr_builder_unittest.cpp @@ -1,5 +1,5 @@ #include "gtest/gtest.h" -#include +#include #include namespace ppl { @@ -7,9 +7,13 @@ namespace ppl { struct expr_builder_fixture : ::testing::Test { protected: + using param_t = ppl::Param; + using pview_t = ppl::ParamView< + typename util::param_traits::pointer_t, + ppl::scl>; MockVarExpr x; MockVarExpr y; - MockParam v; + param_t v; double d; long int i; }; @@ -18,20 +22,11 @@ TEST_F(expr_builder_fixture, convert_to_param_var) { using namespace details; static_assert(std::is_same_v>); -#if __cplusplus <= 201703L static_assert(util::is_var_v); -#else - static_assert(util::var); -#endif static_assert(!std::is_same_v); -#if __cplusplus <= 201703L - static_assert(!util::is_var_expr_v); -#else - static_assert(!util::var_expr); -#endif static_assert(std::is_same_v< convert_to_param_t, - expr::VariableViewer + pview_t >); } @@ -40,17 +35,9 @@ TEST_F(expr_builder_fixture, convert_to_param_raw) using namespace details; using data_t = util::cont_param_t; static_assert(std::is_same_v>); -#if __cplusplus <= 201703L static_assert(!util::is_var_v); -#else - static_assert(!util::var); -#endif static_assert(std::is_same_v); -#if __cplusplus <= 201703L static_assert(!util::is_var_expr_v); -#else - static_assert(!util::var_expr); -#endif static_assert(std::is_same_v< convert_to_param_t, expr::Constant @@ -60,17 +47,9 @@ TEST_F(expr_builder_fixture, convert_to_param_raw) TEST_F(expr_builder_fixture, convert_to_param_var_expr) { using namespace details; -#if __cplusplus <= 201703L static_assert(!util::is_var_v); -#else - static_assert(!util::var); -#endif static_assert(!std::is_same_v); -#if __cplusplus <= 201703L static_assert(util::is_var_expr_v); -#else - static_assert(util::var_expr); -#endif static_assert(std::is_same_v< convert_to_param_t, MockVarExpr& @@ -97,7 +76,7 @@ TEST_F(expr_builder_fixture, op_plus) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, @@ -109,7 +88,7 @@ TEST_F(expr_builder_fixture, op_plus) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); // double, [MockVarExpr, double, long int, MockVar] @@ -123,7 +102,7 @@ TEST_F(expr_builder_fixture, op_plus) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -135,7 +114,7 @@ TEST_F(expr_builder_fixture, op_plus) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // long int, [MockVarExpr, double, long int, MockVar] @@ -149,7 +128,7 @@ TEST_F(expr_builder_fixture, op_plus) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -161,21 +140,21 @@ TEST_F(expr_builder_fixture, op_plus) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // MockVar, [MockVarExpr, double, long int, MockVar] static_assert(std::is_same_v< - expr::BinaryOpNode, MockVarExpr>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, std::decay_t >); } @@ -192,7 +171,7 @@ TEST_F(expr_builder_fixture, op_minus) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, @@ -204,7 +183,7 @@ TEST_F(expr_builder_fixture, op_minus) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); // double, [MockVarExpr, double, long int, MockVar] @@ -218,7 +197,7 @@ TEST_F(expr_builder_fixture, op_minus) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -230,7 +209,7 @@ TEST_F(expr_builder_fixture, op_minus) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // long int, [MockVarExpr, double, long int, MockVar] @@ -244,7 +223,7 @@ TEST_F(expr_builder_fixture, op_minus) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -256,21 +235,21 @@ TEST_F(expr_builder_fixture, op_minus) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // MockVar, [MockVarExpr, double, long int, MockVar] static_assert(std::is_same_v< - expr::BinaryOpNode, MockVarExpr>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, std::decay_t >); } @@ -287,7 +266,7 @@ TEST_F(expr_builder_fixture, op_times) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, @@ -299,7 +278,7 @@ TEST_F(expr_builder_fixture, op_times) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); // double, [MockVarExpr, double, long int, MockVar] @@ -313,7 +292,7 @@ TEST_F(expr_builder_fixture, op_times) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -325,7 +304,7 @@ TEST_F(expr_builder_fixture, op_times) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // long int, [MockVarExpr, double, long int, MockVar] @@ -339,7 +318,7 @@ TEST_F(expr_builder_fixture, op_times) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -351,21 +330,21 @@ TEST_F(expr_builder_fixture, op_times) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // MockVar, [MockVarExpr, double, long int, MockVar] static_assert(std::is_same_v< - expr::BinaryOpNode, MockVarExpr>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, std::decay_t >); } @@ -382,7 +361,7 @@ TEST_F(expr_builder_fixture, op_div) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, @@ -394,7 +373,7 @@ TEST_F(expr_builder_fixture, op_div) expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode>, + expr::BinaryOpNode, std::decay_t >); // double, [MockVarExpr, double, long int, MockVar] @@ -408,7 +387,7 @@ TEST_F(expr_builder_fixture, op_div) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -420,7 +399,7 @@ TEST_F(expr_builder_fixture, op_div) double, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // long int, [MockVarExpr, double, long int, MockVar] @@ -434,7 +413,7 @@ TEST_F(expr_builder_fixture, op_div) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); static_assert(std::is_same_v< expr::BinaryOpNode, MockVarExpr>, @@ -446,21 +425,21 @@ TEST_F(expr_builder_fixture, op_div) long, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, pview_t>, std::decay_t >); // MockVar, [MockVarExpr, double, long int, MockVar] static_assert(std::is_same_v< - expr::BinaryOpNode, MockVarExpr>, + expr::BinaryOpNode, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::Constant>, + expr::BinaryOpNode>, std::decay_t >); static_assert(std::is_same_v< - expr::BinaryOpNode, expr::VariableViewer>, + expr::BinaryOpNode, std::decay_t >); } } // namespace ppl diff --git a/test/ad_integration_unittest.cpp b/test/expression/integration/ad_inttest.cpp similarity index 77% rename from test/ad_integration_unittest.cpp rename to test/expression/integration/ad_inttest.cpp index 3c0f0dcb..72a7bef4 100644 --- a/test/ad_integration_unittest.cpp +++ b/test/expression/integration/ad_inttest.cpp @@ -1,21 +1,29 @@ #include "gtest/gtest.h" #include -#include +#include namespace ppl { struct ad_integration_fixture : ::testing::Test { protected: - Data x{1., 2., 3.}, y{0., -1., 1.}; - Param theta; - std::array keys = {&theta}; - std::vector> vars; + using value_t = double; + using data_t = Data; + using param_t = Param; + using pview_t = ParamView< + typename util::param_traits::pointer_t, + ppl::scl>; + + data_t x{1., 2., 3.}, y{0., -1., 1.}; + param_t theta; + std::vector> vars; ad_integration_fixture() : theta{} , vars(1) { + pview_t theta_view = theta; + theta_view.offset() = 0; vars[0].set_value(1.); } }; @@ -23,7 +31,7 @@ struct ad_integration_fixture : ::testing::Test TEST_F(ad_integration_fixture, ad_log_pdf_data_constant_param) { auto model = (x |= normal(0., 1.)); - auto ad_expr = model.ad_log_pdf(keys, vars); + auto ad_expr = model.ad_log_pdf(vars); double value = ad::evaluate(ad_expr); EXPECT_DOUBLE_EQ(value, -0.5 * 14); value = ad::autodiff(ad_expr); // should not affect the result @@ -36,7 +44,7 @@ TEST_F(ad_integration_fixture, ad_log_pdf_data_mean_param) theta |= normal(0., 2.), x |= normal(theta, 1.) ); - auto ad_expr = model.ad_log_pdf(keys, vars); + auto ad_expr = model.ad_log_pdf(vars); double value = ad::autodiff(ad_expr); EXPECT_DOUBLE_EQ(value, -0.5 * 5 - 1./8 - std::log(2)); @@ -57,7 +65,7 @@ TEST_F(ad_integration_fixture, ad_log_pdf_data_stddev_param) x |= normal(0., theta) ); - auto ad_expr = model.ad_log_pdf(keys, vars); + auto ad_expr = model.ad_log_pdf(vars); double value = ad::autodiff(ad_expr); EXPECT_DOUBLE_EQ(value, -0.5 * 14 - 1./8 - std::log(2)); @@ -78,7 +86,7 @@ TEST_F(ad_integration_fixture, ad_log_pdf_data_param_with_data) y |= normal(theta * x, 1.) ); - auto ad_expr = model.ad_log_pdf(keys, vars); + auto ad_expr = model.ad_log_pdf(vars); double value = ad::autodiff(ad_expr); EXPECT_DOUBLE_EQ(value, -7.5); @@ -97,9 +105,9 @@ TEST_F(ad_integration_fixture, ad_log_pdf_constant_param_within_bounds) auto model = ( theta |= uniform(-1., 0.5) ); - auto expr = model.ad_log_pdf(keys, vars); + auto expr = model.ad_log_pdf(vars); double value = ad::autodiff(expr); - EXPECT_DOUBLE_EQ(value, std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(value, math::neg_inf); EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 0); } @@ -109,7 +117,7 @@ TEST_F(ad_integration_fixture, ad_log_pdf_constant_param_out_of_bounds) auto model = ( theta |= uniform(-1., 0.5) ); - auto expr = model.ad_log_pdf(keys, vars); + auto expr = model.ad_log_pdf(vars); double value = ad::autodiff(expr); EXPECT_DOUBLE_EQ(value, -std::log(1.5)); EXPECT_DOUBLE_EQ(vars[0].get_adjoint(), 0); @@ -120,9 +128,9 @@ TEST_F(ad_integration_fixture, ad_log_pdf_var_param_within_bounds) vars[0].set_value(0.42); auto model = ( theta |= normal(-1., 0.5), - x |= uniform(theta, theta + 5) + x |= uniform(theta, theta + 5.) ); - auto expr = model.ad_log_pdf(keys, vars); + auto expr = model.ad_log_pdf(vars); double value = ad::autodiff(expr); EXPECT_DOUBLE_EQ(value, -2*(1.42 * 1.42) + std::log(2) - 3*std::log(5)); } @@ -134,9 +142,9 @@ TEST_F(ad_integration_fixture, ad_log_pdf_var_param_out_of_bounds) theta |= normal(-1., 0.5), x |= uniform(theta, theta + 2) ); - auto expr = model.ad_log_pdf(keys, vars); + auto expr = model.ad_log_pdf(vars); double value = ad::autodiff(expr); - EXPECT_DOUBLE_EQ(value, std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(value, math::neg_inf); } } // namespace ppl diff --git a/test/expression/integration/dist_inttest.cpp b/test/expression/integration/dist_inttest.cpp new file mode 100644 index 00000000..437b2a3a --- /dev/null +++ b/test/expression/integration/dist_inttest.cpp @@ -0,0 +1,52 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +namespace ppl { + +struct normal_integration_fixture : ::testing::Test { +protected: + using value_t = double; + using param_t = Param; + using data_t = Data; + using pview_t = typename param_t::base_t; + + data_t v1 {0.1, 0.2, 0.3, 0.4, 0.5}; + std::array pvalues = {0.1, -0.1}; + param_t x = 2; + + value_t tol = 1e-15; + + normal_integration_fixture() + { + // manually set offset + // in real-use case, user will call an initialization function + pview_t x_view = x; + x_view.offset() = 0; + } +}; + +TEST_F(normal_integration_fixture, normal_pdfs) { + + auto dist1 = normal(0., 1.); + + EXPECT_NEAR(dist1.pdf(v1, pvalues), 0.007675723936191419, tol); + EXPECT_NEAR(dist1.log_pdf(v1, pvalues), -4.869692666023363, tol); + + auto dist2 = normal(x[0], 1.); + + pvalues[0] = 0.; + EXPECT_NEAR(dist2.pdf(v1, pvalues), 0.0076757239361914193, tol); + EXPECT_NEAR(dist2.log_pdf(v1, pvalues), -4.869692666023363, tol); + + auto dist3 = normal(x[0] + x[1], 1.); + + pvalues[0] = 0.1; + EXPECT_NEAR(dist3.pdf(v1, pvalues), 0.0076757239361914193, tol); + EXPECT_NEAR(dist3.log_pdf(v1, pvalues), -4.869692666023363, tol); +} + +} // namespace ppl diff --git a/test/expression/integration/model_inttest.cpp b/test/expression/integration/model_inttest.cpp new file mode 100644 index 00000000..f60e989d --- /dev/null +++ b/test/expression/integration/model_inttest.cpp @@ -0,0 +1,66 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include + +namespace ppl { + +struct model_integration_fixture : ::testing::Test { +protected: + using value_t = double; + using data_t = Data; + using param_t = Param; + using pview_t = typename param_t::base_t; + + value_t tol = 1e-15; + + data_t v1 {0.1, 0.2, 0.3, 0.4, 0.5}; + param_t mu, sigma, w, b; + std::array pvalues; + + data_t x{2.5, 3, 3.5, 4, 4.5, 5.}; + data_t y{3.5, 4, 4.5, 5, 5.5, 6.}; + data_t q{2.4, 3.1, 3.6, 4, 4.5, 5.}; + data_t r{3.5, 4, 4.4, 5.01, 5.46, 6.1}; + + model_integration_fixture() + { + // manually set offset + // in real-use case, user will call an initialization function + pview_t mu_view = mu; + pview_t sigma_view = sigma; + pview_t w_view = w; + pview_t b_view = b; + mu_view.offset() = 0; + sigma_view.offset() = 1; + w_view.offset() = 2; + b_view.offset() = 3; + } +}; + +TEST_F(model_integration_fixture, simple_model_pdfs) { + auto model = ( + mu |= uniform(-0.5, 2.), + v1 |= normal(mu, 1.0) + ); + + pvalues[0] = 0.0; + + EXPECT_NEAR(model.pdf(pvalues), 0.003070289574476568, tol); + EXPECT_NEAR(model.log_pdf(pvalues), -5.785983397897518, tol); +} + +TEST_F(model_integration_fixture, regression_pdfs) { + pvalues[2] = 1.0; + pvalues[3] = 1.0; + + auto model = (w |= ppl::uniform(0., 2.), + b |= ppl::uniform(0., 2.), + r |= ppl::normal(q * w + b, 0.5)); + + EXPECT_NEAR(model.pdf(pvalues), 0.055885938549306326, tol); + EXPECT_NEAR(model.log_pdf(pvalues), -2.884442476988254, tol); +} + +} // namespace ppl diff --git a/test/expression/model/model_unittest.cpp b/test/expression/model/model_unittest.cpp index 75b6825d..16233c89 100644 --- a/test/expression/model/model_unittest.cpp +++ b/test/expression/model/model_unittest.cpp @@ -1,7 +1,9 @@ #include "gtest/gtest.h" #include #include -#include +#include +#include +#include #include namespace ppl { @@ -14,58 +16,38 @@ namespace expr { /* * Fixture for testing one var with distribution. */ -struct var_dist_fixture : ::testing::Test +struct model_fixture : ::testing::Test { protected: - MockParam x; - using model_t = EqNode; - model_t model = {x, MockDistExpr()}; - double val; - - void reconfigure() - { x.set_value(val); } + using param_t = MockParam; + using value_t = typename util::param_traits::value_t; + using dist_t = MockDistExpr; + using eq_t = EqNode; }; -TEST_F(var_dist_fixture, ctor) +TEST_F(model_fixture, type_check) { -#if __cplusplus <= 201703L - static_assert(util::assert_is_model_expr_v); -#else - static_assert(util::model_expr); -#endif + static_assert(util::is_model_expr_v); } -TEST_F(var_dist_fixture, pdf_valid) +TEST_F(model_fixture, eq_pdf_valid) { - // MockDistExpr pdf is identity function - // so we may simply compare model.pdf() with val. - - val = 0.000001; - reconfigure(); - EXPECT_EQ(model.pdf(), val); - - val = 0.5; - reconfigure(); - EXPECT_EQ(model.pdf(), val); - - val = 0.999999; - reconfigure(); - EXPECT_EQ(model.pdf(), val); + param_t x(3.); + dist_t d(0.5); + eq_t model(x, d); + value_t val = 1.5; + // parameter ignored (arbitrary) + EXPECT_DOUBLE_EQ(model.pdf(0), val); } -TEST_F(var_dist_fixture, log_pdf_valid) +TEST_F(model_fixture, eq_log_pdf_valid) { - val = 0.000001; - reconfigure(); - EXPECT_EQ(model.log_pdf(), std::log(val)); - - val = 0.5; - reconfigure(); - EXPECT_EQ(model.log_pdf(), std::log(val)); - - val = 0.999999; - reconfigure(); - EXPECT_EQ(model.log_pdf(), std::log(val)); + param_t x(5.); + dist_t d(1.32); + eq_t model(x, d); + value_t val = std::log(5. * 1.32); + // parameter ignored (arbitrary) + EXPECT_DOUBLE_EQ(model.log_pdf(0), val); } ////////////////////////////////////////////////////// @@ -75,129 +57,68 @@ TEST_F(var_dist_fixture, log_pdf_valid) /* * Fixture for testing many vars with distributions. */ -struct many_var_dist_fixture : ::testing::Test +struct many_model_fixture : ::testing::Test { protected: using value_t = double; using eq_t = EqNode; - MockParam x, y, z, w; - value_t xv, yv, zv, wv; + value_t xv = 0.2; + value_t yv = 1.8; + value_t zv = 0.32; + value_t xd = 1.5; + value_t yd = 1.523; + value_t zd = 0.00132; + MockParam x = xv; + MockParam y = yv; + MockParam z = zv; using model_two_t = GlueNode; model_two_t model_two = { - {x, MockDistExpr()}, - {y, MockDistExpr()} + {x, MockDistExpr(xd)}, + {y, MockDistExpr(yd)} }; - using model_four_t = - GlueNode - > - >; + using model_three_t = + GlueNode>; - model_four_t model_four = { - {x, MockDistExpr()}, + model_three_t model_three = { + {x, MockDistExpr(xd)}, { - {y, MockDistExpr()}, - { - {z, MockDistExpr()}, - {w, MockDistExpr()} - } + {y, MockDistExpr(yd)}, + {z, MockDistExpr(zd)} } }; }; -TEST_F(many_var_dist_fixture, ctor) +TEST_F(many_model_fixture, type_check) { -#if __cplusplus <= 201703L - static_assert(util::assert_is_model_expr_v); - static_assert(util::assert_is_model_expr_v); -#else - static_assert(util::model_expr); - static_assert(util::model_expr); -#endif + static_assert(util::is_model_expr_v); + static_assert(util::is_model_expr_v); } -TEST_F(many_var_dist_fixture, two_vars_pdf) +TEST_F(many_model_fixture, two_vars_pdf) { - xv = 0.2; yv = 1.8; - - x.set_value(xv); - y.set_value(yv); - - EXPECT_EQ(model_two.pdf(), xv * yv); - EXPECT_EQ(model_two.log_pdf(), std::log(xv) + std::log(yv)); + EXPECT_DOUBLE_EQ(model_two.pdf(0), xv * xd * yv * yd); + EXPECT_DOUBLE_EQ(model_two.log_pdf(0), std::log(xv*xd) + std::log(yv*yd)); } -TEST_F(many_var_dist_fixture, four_vars_pdf) +TEST_F(many_model_fixture, three_vars_pdf) { - xv = 0.2; yv = 1.8; zv = 3.2; wv = 0.3; - - x.set_value(xv); - y.set_value(yv); - z.set_value(zv); - w.set_value(wv); - - EXPECT_EQ(model_four.pdf(), xv * yv * zv * wv); - EXPECT_EQ(model_four.log_pdf(), std::log(xv) + std::log(yv) - + std::log(zv) + std::log(wv)); + EXPECT_DOUBLE_EQ(model_three.pdf(0), xv * xd * yv * yd * zv * zd); + EXPECT_DOUBLE_EQ(model_three.log_pdf(0), + std::log(xv*xd) + std::log(yv*yd) + std::log(zv*zd)); } -TEST_F(many_var_dist_fixture, four_vars_traverse_count_params) -{ - int count = 0; - model_four.traverse([&](auto&) { - count++; - }); - EXPECT_EQ(count, 4); -} - -TEST_F(many_var_dist_fixture, four_vars_traverse_pdf) +TEST_F(many_model_fixture, three_vars_traverse_pdf) { double actual = 1.; - model_four.traverse([&](auto& model) { - auto& var = model.get_variable(); - auto& dist = model.get_distribution(); - actual *= dist.pdf(var.get_value(0)); + model_three.traverse([&](auto& eq) { + auto& var = eq.get_variable(); + auto& dist = eq.get_distribution(); + actual *= dist.pdf(var, 0); }); - EXPECT_EQ(actual, model_four.pdf()); + EXPECT_DOUBLE_EQ(actual, model_three.pdf(0)); } -//////////////////////////////////////////////////////////// -// get_n_params TESTS -//////////////////////////////////////////////////////////// - -TEST_F(many_var_dist_fixture, get_n_params_zero) -{ - using eq_node_t = EqNode; - static_assert(get_n_params_v == 0); -} - -TEST_F(many_var_dist_fixture, get_n_params_one) -{ - using eq_node_t = EqNode; - static_assert(get_n_params_v == 1); -} - -TEST_F(many_var_dist_fixture, get_n_params_one_with_data) -{ - using model_t = GlueNode< - EqNode, - EqNode - >; - static_assert(get_n_params_v == 1); -} - -TEST_F(many_var_dist_fixture, get_n_params_two) -{ - using model_t = GlueNode< - EqNode, - EqNode - >; - static_assert(get_n_params_v == 2); -} - - } // namespace expr } // namespace ppl diff --git a/test/expression/samples/dist_sample_unittest.cpp b/test/expression/samples/dist_sample_unittest.cpp deleted file mode 100644 index 248ab60d..00000000 --- a/test/expression/samples/dist_sample_unittest.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "gtest/gtest.h" - -namespace ppl { - -struct normal_fixture : ::testing::Test { - protected: - Data v1 {0.1, 0.2, 0.3, 0.4, 0.5}; - Param x, y; - - double tol = 1e-15; -}; - -TEST_F(normal_fixture, normal_check_pdf) { - auto dist1 = normal(0., 1.); - - EXPECT_NEAR(dist1.pdf(v1), 0.0076757239361914193, tol); - EXPECT_NEAR(dist1.log_pdf(v1), -4.869692666023363, tol); - - auto dist2 = normal(x, 1.); - - EXPECT_NEAR(dist2.pdf(v1), 0.0076757239361914193, tol); - EXPECT_NEAR(dist2.log_pdf(v1), -4.869692666023363, tol); - - x.set_value(0.1); - y.set_value(-0.1); - auto dist3 = normal(x + y, 1.); - - EXPECT_NEAR(dist3.pdf(v1), 0.0076757239361914193, tol); - EXPECT_NEAR(dist3.log_pdf(v1), -4.869692666023363, tol); -} - -} // namespace ppl diff --git a/test/expression/samples/model_sample_unittest.cpp b/test/expression/samples/model_sample_unittest.cpp deleted file mode 100644 index 3795503d..00000000 --- a/test/expression/samples/model_sample_unittest.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "gtest/gtest.h" - -namespace ppl { - -struct model_sample_fixture : ::testing::Test { - protected: - Data v1 {0.1, 0.2, 0.3, 0.4, 0.5}; - Param mu, sigma; - - Param w, b; - ppl::Data x{2.5, 3, 3.5, 4, 4.5, 5.}; - ppl::Data y{3.5, 4, 4.5, 5, 5.5, 6.}; - - ppl::Data q{2.4, 3.1, 3.6, 4, 4.5, 5.}; - ppl::Data r{3.5, 4, 4.4, 5.01, 5.46, 6.1}; - - double tol = 1e-10; -}; - -TEST_F(model_sample_fixture, simple_model_test) { - auto model = ( - mu |= uniform(-0.5, 2), - v1 |= normal(mu, 1.0) - ); - - mu.set_value(0.0); - - EXPECT_NEAR(model.pdf(), 0.003070289574476568, tol); - EXPECT_NEAR(model.log_pdf(), -5.785983397897518, tol); -} - -TEST_F(model_sample_fixture, test_regression_pdf) { - w.set_value(1.0); - b.set_value(1.0); - - auto model = (w |= ppl::uniform(0, 2), - b |= ppl::uniform(0, 2), - r |= ppl::normal(q * w + b, 0.5)); - - EXPECT_NEAR(model.pdf(), 0.055885938549306326, tol); - EXPECT_NEAR(model.log_pdf(), -2.884442476988254, tol); -} - -} // namespace ppl diff --git a/test/expression/variable/binop_unittest.cpp b/test/expression/variable/binop_unittest.cpp index cb21f182..bc572b0b 100644 --- a/test/expression/variable/binop_unittest.cpp +++ b/test/expression/variable/binop_unittest.cpp @@ -14,21 +14,13 @@ namespace expr { struct binop_fixture : ::testing::Test { protected: - MockVarExpr x = 0; - MockVarExpr y = 0; - - using binop_result_t = double; - - using binop_node_t = BinaryOpNode; - - void reconfigureX(double val) - { x.set_value(val); } - - void reconfigureY(double val) - { y.set_value(val); } - + using addop_node_t = BinaryOpNode; }; +////////////////////////////////////////////////////// +// Functor TESTS +////////////////////////////////////////////////////// + TEST_F(binop_fixture, add) { double val1 = 3.5; @@ -78,16 +70,34 @@ TEST_F(binop_fixture, div) EXPECT_EQ(divInt, 4); } -TEST_F(binop_fixture, binop_node) +////////////////////////////////////////////////////// +// Binop Node TESTS +////////////////////////////////////////////////////// + +TEST_F(binop_fixture, binop_node_value) { - reconfigureX(3); - reconfigureY(4); + addop_node_t node(MockVarExpr(3), MockVarExpr(4)); + // first parameter is always ignored + // second parameter is ignored because MockVarExprs are scalars + EXPECT_DOUBLE_EQ(node.value(0, 0), 7); + EXPECT_DOUBLE_EQ(node.value(0, 1), 7); +} - binop_node_t addNode = {x, y}; - double res = addNode.get_value(0); +TEST_F(binop_fixture, binop_node_size) +{ + addop_node_t node(MockVarExpr(0), MockVarExpr(1)); + EXPECT_EQ(node.size(), 1ul); - EXPECT_EQ(res, 7); + addop_node_t node2(MockVarExpr(3), MockVarExpr(1)); + EXPECT_EQ(node2.size(), 3ul); +} +TEST_F(binop_fixture, binop_node_to_ad) +{ + addop_node_t node(MockVarExpr(2), MockVarExpr(4)); + // all parameters are ignored in this case by MockVarExpr + auto expr = node.to_ad(0,0); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), 6.0); } } // namespace expr diff --git a/test/expression/variable/constant_unittest.cpp b/test/expression/variable/constant_unittest.cpp index 3357d184..2ce8a5ff 100644 --- a/test/expression/variable/constant_unittest.cpp +++ b/test/expression/variable/constant_unittest.cpp @@ -1,6 +1,6 @@ #include "gtest/gtest.h" #include -#include +#include #include namespace ppl { @@ -9,25 +9,35 @@ namespace expr { struct constant_fixture : ::testing::Test { protected: + static constexpr double defval = 0.3; using value_t = double; - value_t c = 0.3; + value_t c = defval; Constant x{c}; }; TEST_F(constant_fixture, ctor) { -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_expr_v>); -#else - static_assert(util::var_expr>); -#endif + static_assert(util::is_var_expr_v>); } -TEST_F(constant_fixture, convertible_value) +TEST_F(constant_fixture, value) { - EXPECT_EQ(x.get_value(0), 0.3); + // first parameter ignored and was chosen arbitrarily + EXPECT_DOUBLE_EQ(x.value(0), defval); c = 3.41; - EXPECT_EQ(x.get_value(0), 0.3); + EXPECT_DOUBLE_EQ(x.value(0), defval); +} + +TEST_F(constant_fixture, size) +{ + EXPECT_EQ(x.size(), 1ul); +} + +TEST_F(constant_fixture, to_ad) +{ + // Note: arbitrarily first 2 inputs (will ignore) + auto expr = x.to_ad(0,0); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), defval); } } // namespace expr diff --git a/test/expression/variable/data_unittest.cpp b/test/expression/variable/data_unittest.cpp index bab1828b..3183d176 100644 --- a/test/expression/variable/data_unittest.cpp +++ b/test/expression/variable/data_unittest.cpp @@ -1,59 +1,112 @@ -#include - #include "gtest/gtest.h" +#include +#include +#include namespace ppl { namespace expr { struct data_fixture : ::testing::Test { - protected: - Data var1 {1.0, 2.0, 3.0}; - Data var2 {1.0}; +protected: + using value_type = double; + using vec_type = std::vector; + using dview_scl_t = DataView; + using dview_vec_t = DataView; + using d_scl_t = Data; + using d_vec_t = Data; + + static constexpr value_type defval1 = 1.0; + static constexpr value_type defval2 = 2.0; + static constexpr size_t size1 = 7; + static constexpr size_t size2 = 17; + + value_type d1 = defval1; + value_type d2 = defval2; + + vec_type values1; + vec_type values2; + + data_fixture() + : values1(size1) + , values2(size2) + { + std::transform(util::counting_iterator<>(0), + util::counting_iterator<>(size1), + values1.begin(), + [=](auto i) { return i + defval1; }); - size_t expected_size; - size_t real_size; + std::transform(util::counting_iterator<>(0), + util::counting_iterator<>(size2), + values2.begin(), + [=](auto i) { return i + defval2; }); + } }; -TEST_F(data_fixture, test_multiple_value) { - expected_size = 3; - real_size = var1.size(); - - EXPECT_EQ(expected_size, real_size); - - expected_size = 1; - real_size = var2.size(); - - EXPECT_EQ(expected_size, real_size); - - EXPECT_EQ(var1.get_value(0), 1.0); - EXPECT_EQ(var1.get_value(1), 2.0); - EXPECT_EQ(var1.get_value(2), 3.0); - -#ifndef NDEBUG - EXPECT_DEATH({ - var2.get_value(1); - }, ""); - - EXPECT_DEATH({ - var2.get_value(-1); - }, ""); - - EXPECT_DEATH({ - var1.get_value(3); - }, ""); -#endif - - var1.clear(); - expected_size = 0; - real_size = var1.size(); - EXPECT_EQ(expected_size, real_size); - - var1.observe(0.1); - var1.observe(0.2); - - expected_size = 2; - real_size = var1.size(); - EXPECT_EQ(expected_size, real_size); +TEST_F(data_fixture, type_check) +{ + static_assert(util::is_data_v); + static_assert(util::is_data_v); + static_assert(util::is_data_v); + static_assert(util::is_data_v); +} + +//////////////////////////////////////// +// DataView: scl +//////////////////////////////////////// + +TEST_F(data_fixture, dview_scl_value) +{ + dview_scl_t view(d1); + + // all parameters should not matter + // this is was just to match API for variable expressions + // data already views its own values + EXPECT_DOUBLE_EQ(view.value(values1, 0), d1); + EXPECT_DOUBLE_EQ(view.value(values1, 1), d1); + EXPECT_DOUBLE_EQ(view.value(values2, 2), d1); +} + +TEST_F(data_fixture, dview_scl_size) +{ + dview_scl_t view(d1); + EXPECT_EQ(view.size(), 1ul); +} + +TEST_F(data_fixture, dview_scl_to_ad) +{ + dview_scl_t view(d1); + // both parameters are ignored + auto expr = view.to_ad(0,0); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), defval1); +} + +//////////////////////////////////////// +// DataView: vec +//////////////////////////////////////// + +TEST_F(data_fixture, dview_vec_value) +{ + dview_vec_t view(values1); + // passed in values should not matter at all + // data already views its own values + // the index matters though + EXPECT_DOUBLE_EQ(view.value(values2, 0), values1[0]); + EXPECT_DOUBLE_EQ(view.value(values2, 1), values1[1]); + EXPECT_DOUBLE_EQ(view.value(values2, 2), values1[2]); +} + +TEST_F(data_fixture, dview_vec_size) +{ + dview_vec_t view(values1); + EXPECT_EQ(view.size(), values1.size()); +} + +TEST_F(data_fixture, dview_vec_to_ad) +{ + dview_vec_t view(values1); + // only the last argument is not ignored + auto expr = view.to_ad(0,3); + EXPECT_DOUBLE_EQ(ad::evaluate(expr), values1[3]); } } // namespace expr diff --git a/test/expression/variable/param_unittest.cpp b/test/expression/variable/param_unittest.cpp index 82283c6b..f0fff56c 100644 --- a/test/expression/variable/param_unittest.cpp +++ b/test/expression/variable/param_unittest.cpp @@ -1,38 +1,175 @@ -#include - #include "gtest/gtest.h" +#include +#include +#include +#include +#include namespace ppl { namespace expr { struct param_fixture : ::testing::Test { - protected: - Param param1; - Param param2 {3.}; +protected: + using value_type = double; + using pointer_t = value_type*; + using vec_pointer_t = std::vector; + + using pview_scl_t = ParamView; + using pview_vec_t = ParamView; + using p_scl_t = Param; + using p_vec_t = Param; + + using index_t = typename util::param_traits::index_t; + + static constexpr value_type defval1 = 1.0; + static constexpr value_type defval2 = 2.0; + static constexpr size_t size1 = 7; + static constexpr size_t size2 = 17; + + // hypothetical storage: one sample for each param value + std::array storage1 = {0}; + std::array storage2 = {0}; + + // hypothetical parameter values + std::vector values1; + std::vector values2; + + // hypothetical storage ptrs for sample + vec_pointer_t storage_ptrs1; + vec_pointer_t storage_ptrs2; + + // hypothetical offsets + index_t offset = 0; + + param_fixture() + : values1(size1) + , values2(size2) + , storage_ptrs1(size1) + , storage_ptrs2(size2) + { + std::transform(util::counting_iterator<>(0), + util::counting_iterator<>(size1), + values1.begin(), + [=](auto i) { return i + defval1; }); - size_t expected_size; - size_t real_size; + std::transform(util::counting_iterator<>(0), + util::counting_iterator<>(size2), + values2.begin(), + [=](auto i) { return i + defval2; }); + + std::transform(storage1.begin(), + storage1.end(), + storage_ptrs1.begin(), + [](auto& x) { return &x; }); + + std::transform(storage2.begin(), + storage2.end(), + storage_ptrs2.begin(), + [](auto& x) { return &x; }); + } }; -TEST_F(param_fixture, test_multiple_value) { - expected_size = 1; - real_size = param1.size(); - - EXPECT_EQ(expected_size, real_size); +TEST_F(param_fixture, type_check) +{ + static_assert(util::is_param_v); + static_assert(util::is_param_v); + static_assert(util::is_param_v); + static_assert(util::is_param_v); +} + +//////////////////////////////////////// +// DataView: scl +//////////////////////////////////////// + +TEST_F(param_fixture, pview_scl_value) +{ + auto&& s1 = storage_ptrs1[0]; + pview_scl_t view(offset, s1, 1); + + // last parameter should not matter + EXPECT_DOUBLE_EQ(view.value(values1, 0), values1[1]); + EXPECT_DOUBLE_EQ(view.value(values1, 1), values1[1]); + EXPECT_DOUBLE_EQ(view.value(values1, 2), values1[1]); + + // able to view a different array of values + EXPECT_DOUBLE_EQ(view.value(values2, 0), values2[1]); + EXPECT_DOUBLE_EQ(view.value(values2, 1), values2[1]); + EXPECT_DOUBLE_EQ(view.value(values2, 2), values2[1]); +} - EXPECT_EQ(param1.get_value(0), 0.0); - param1.set_value(1.0); +TEST_F(param_fixture, pview_scl_storage) +{ + auto&& s1 = storage_ptrs1[0]; + + pview_scl_t view(offset, s1, 2); + // parameter should not matter + EXPECT_EQ(view.storage(0), s1); + EXPECT_EQ(view.storage(1), s1); + EXPECT_EQ(view.storage(2), s1); + + // relative offset should not affect storage + pview_scl_t view2(offset, s1, 13124); + EXPECT_EQ(view2.storage(0), s1); + EXPECT_EQ(view2.storage(1), s1); + EXPECT_EQ(view2.storage(2), s1); +} + +TEST_F(param_fixture, pview_scl_size) +{ + pview_scl_t view(offset, storage_ptrs1[0]); + EXPECT_EQ(view.size(), 1ul); +} + +TEST_F(param_fixture, pview_scl_to_ad) +{ + auto&& s1 = storage_ptrs1[0]; + pview_scl_t view(offset, s1); + + // simply tests if gets correct elt from passed in array + // last parameter should be ignored + const auto& elt = view.to_ad(storage_ptrs1, 0); + EXPECT_EQ(elt, s1); + + const auto& elt2 = view.to_ad(storage_ptrs1, 1); + EXPECT_EQ(elt2, s1); +} + +//////////////////////////////////////// +// DataView: vec +//////////////////////////////////////// + +TEST_F(param_fixture, pview_vec_value) +{ + pview_vec_t view(offset, storage_ptrs1, storage_ptrs1.size()); + // parameter SHOULD matter + EXPECT_DOUBLE_EQ(view.value(values1, 0), values1[0]); + EXPECT_DOUBLE_EQ(view.value(values1, 1), values1[1]); + EXPECT_DOUBLE_EQ(view.value(values1, 2), values1[2]); +} + +TEST_F(param_fixture, pview_vec_size) +{ + pview_vec_t view(offset, storage_ptrs1, storage_ptrs1.size()); + EXPECT_EQ(view.size(), size1); +} + +TEST_F(param_fixture, pview_vec_storage) +{ + pview_vec_t view(offset, storage_ptrs1, storage_ptrs1.size()); + EXPECT_EQ(view.storage(0), storage_ptrs1[0]); + EXPECT_EQ(view.storage(1), storage_ptrs1[1]); + EXPECT_EQ(view.storage(2), storage_ptrs1[2]); +} - EXPECT_EQ(param1.get_value(0), 1.0); - EXPECT_EQ(param1.get_value(10), 1.0); // all indices return the same +TEST_F(param_fixture, pview_vec_to_ad) +{ + pview_vec_t view(offset, storage_ptrs1, storage_ptrs1.size()); - EXPECT_EQ(param2.get_value(0), 3.0); // all indices return the same + auto elt = view.to_ad(storage_ptrs1, 0); + EXPECT_EQ(elt, &storage1[0]); - EXPECT_EQ(param1.get_storage(), nullptr); - - double storage[5]; - param1.set_storage(storage); - EXPECT_EQ(param1.get_storage(), storage); + elt = view.to_ad(storage_ptrs1, 3); + EXPECT_EQ(elt, &storage1[3]); } } // namespace expr diff --git a/test/expression/variable/variable_viewer_unittest.cpp b/test/expression/variable/variable_viewer_unittest.cpp deleted file mode 100644 index 6462ba67..00000000 --- a/test/expression/variable/variable_viewer_unittest.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "gtest/gtest.h" -#include -#include - -namespace ppl { -namespace expr { - -struct variable_viewer_fixture : ::testing::Test -{ -protected: - using value_t = typename MockParam::value_t; - MockParam var; - VariableViewer x = var; -}; - -TEST_F(variable_viewer_fixture, ctor) -{ -#if __cplusplus <= 201703L - static_assert(util::assert_is_var_expr_v>); -#else - static_assert(util::var_expr>); -#endif -} - -TEST_F(variable_viewer_fixture, convertible_value) -{ - var.set_value(1.); - EXPECT_EQ(x.get_value(0), 1.); - - // Tests if viewer correctly reflects any changes that happened in var. - var.set_value(-3.14); - EXPECT_EQ(x.get_value(0), -3.14); -} - -} // namespace expr -} // namespace ppl diff --git a/test/math/density_unittest.cpp b/test/math/density_unittest.cpp new file mode 100644 index 00000000..d3f62b3b --- /dev/null +++ b/test/math/density_unittest.cpp @@ -0,0 +1,158 @@ +#include "gtest/gtest.h" +#include + +namespace ppl { +namespace math { + +struct normal_fixture : ::testing::Test +{ +protected: + static constexpr double tol = 1e-15; + double mean = 0.3; + double sd = 1.3; +}; + +TEST_F(normal_fixture, pdf) +{ + EXPECT_NEAR(normal_pdf(-10.231, mean, sd), 1.726752595588348216742E-15, tol); + EXPECT_NEAR(normal_pdf(-5.31, mean, sd), 2.774166877919518907166E-5, tol); + EXPECT_DOUBLE_EQ(normal_pdf(-2.3141231, mean, sd), 0.04063645713784323551341); + EXPECT_DOUBLE_EQ(normal_pdf(0., mean, sd), 0.2988151821496727914542); + EXPECT_DOUBLE_EQ(normal_pdf(1.31, mean, sd), 0.2269313951019926611687); + EXPECT_DOUBLE_EQ(normal_pdf(3.21, mean, sd), 0.02505560241243631472997); + EXPECT_NEAR(normal_pdf(5.24551, mean, sd), 2.20984513448306056291E-4, tol); + EXPECT_NEAR(normal_pdf(10.5699, mean, sd), 8.61135160183067521907E-15, tol); +} + +TEST_F(normal_fixture, log_pdf) +{ + EXPECT_DOUBLE_EQ(normal_log_pdf(-10.231, mean, sd), std::log(1.726752595588348216742E-15)); + EXPECT_DOUBLE_EQ(normal_log_pdf(-5.31, mean, sd), std::log(2.774166877919518907166E-5)); + EXPECT_DOUBLE_EQ(normal_log_pdf(-2.3141231, mean, sd), std::log(0.04063645713784323551341)); + EXPECT_DOUBLE_EQ(normal_log_pdf(0., mean, sd), std::log(0.2988151821496727914542)); + EXPECT_DOUBLE_EQ(normal_log_pdf(1.31, mean, sd), std::log(0.2269313951019926611687)); + EXPECT_DOUBLE_EQ(normal_log_pdf(3.21, mean, sd), std::log(0.02505560241243631472997)); + EXPECT_DOUBLE_EQ(normal_log_pdf(5.24551, mean, sd), std::log(2.20984513448306056291E-4)); + EXPECT_DOUBLE_EQ(normal_log_pdf(10.5699, mean, sd), std::log(8.61135160183067521907E-15)); +} + +struct uniform_fixture : ::testing::Test +{ +protected: + double min = -2.3; + double max = 2.7; +}; + +TEST_F(uniform_fixture, uniform_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(uniform_pdf(-2.2999999999, min, max), 0.2); + EXPECT_DOUBLE_EQ(uniform_pdf(-2., min, max), 0.2); + EXPECT_DOUBLE_EQ(uniform_pdf(-1.423, min, max), 0.2); + EXPECT_DOUBLE_EQ(uniform_pdf(0., min, max), 0.2); + EXPECT_DOUBLE_EQ(uniform_pdf(1.31, min, max), 0.2); + EXPECT_DOUBLE_EQ(uniform_pdf(2.41, min, max), 0.2); + EXPECT_DOUBLE_EQ(uniform_pdf(2.69999999999, min, max), 0.2); +} + +TEST_F(uniform_fixture, uniform_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(uniform_pdf(-100., min, max), 0.); + EXPECT_DOUBLE_EQ(uniform_pdf(-3.41, min, max), 0.); + EXPECT_DOUBLE_EQ(uniform_pdf(-2.3, min, max), 0.); + EXPECT_DOUBLE_EQ(uniform_pdf(2.7, min, max), 0.); + EXPECT_DOUBLE_EQ(uniform_pdf(3.5, min, max), 0.); + EXPECT_DOUBLE_EQ(uniform_pdf(3214., min, max), 0.); +} + +TEST_F(uniform_fixture, uniform_log_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(uniform_log_pdf(-2.2999999999, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(uniform_log_pdf(-2., min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(uniform_log_pdf(-1.423, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(uniform_log_pdf(0., min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(uniform_log_pdf(1.31, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(uniform_log_pdf(2.41, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(uniform_log_pdf(2.69999999999, min, max), std::log(0.2)); +} + +TEST_F(uniform_fixture, uniform_log_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(uniform_log_pdf(-100., min, max), neg_inf); + EXPECT_DOUBLE_EQ(uniform_log_pdf(-3.41, min, max), neg_inf); + EXPECT_DOUBLE_EQ(uniform_log_pdf(-2.3, min, max), neg_inf); + EXPECT_DOUBLE_EQ(uniform_log_pdf(2.7, min, max), neg_inf); + EXPECT_DOUBLE_EQ(uniform_log_pdf(3.5, min, max), neg_inf); + EXPECT_DOUBLE_EQ(uniform_log_pdf(3214., min, max), neg_inf); +} + +struct bernoulli_fixture : ::testing::Test +{ +protected: + double p = 0.6; +}; + +TEST_F(bernoulli_fixture, bernoulli_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(bernoulli_pdf(0, p), 1-p); + EXPECT_DOUBLE_EQ(bernoulli_pdf(1, p), p); +} + +TEST_F(bernoulli_fixture, bernoulli_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(bernoulli_pdf(-100, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(-3, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(-2, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(2, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(3, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(5, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(100, p), 0.); +} + +TEST_F(bernoulli_fixture, bernoulli_pdf_always_tail) +{ + double p = 0.; + EXPECT_DOUBLE_EQ(bernoulli_pdf(0, p), 1.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(1, p), 0.); +} + +TEST_F(bernoulli_fixture, bernoulli_pdf_always_head) +{ + double p = 1.; + EXPECT_DOUBLE_EQ(bernoulli_pdf(0, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_pdf(1, p), 1.); +} + +TEST_F(bernoulli_fixture, bernoulli_log_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(0, p), std::log(1-p)); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(1, p), std::log(p)); +} + +TEST_F(bernoulli_fixture, bernoulli_log_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(-100, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(-3, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(-1, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(2, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(3, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(5, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(100, p), neg_inf); +} + +TEST_F(bernoulli_fixture, bernoulli_log_pdf_always_tail) +{ + double p = 0.; + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(0, p), 0.); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(1, p), neg_inf); +} + +TEST_F(bernoulli_fixture, bernoulli_log_pdf_always_head) +{ + double p = 1.; + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(0, p), neg_inf); + EXPECT_DOUBLE_EQ(bernoulli_log_pdf(1, p), 0.); +} + + +} // namespace math +} // namespace ppl diff --git a/test/math/math_unittest.cpp b/test/math/math_unittest.cpp new file mode 100644 index 00000000..86329f77 --- /dev/null +++ b/test/math/math_unittest.cpp @@ -0,0 +1,27 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace math { + +struct math_fixture : ::testing::Test +{ +protected: + std::array x = {0}; +}; + +TEST_F(math_fixture, min_edge_case) +{ + auto res = min(x.end(), x.begin()); + EXPECT_DOUBLE_EQ(res, inf); +} + +TEST_F(math_fixture, max_edge_case) +{ + auto res = max(x.end(), x.begin()); + EXPECT_DOUBLE_EQ(res, neg_inf); +} + +} // namespace math +} // namespace ppl diff --git a/test/mcmc/hmc/nuts/nuts_unittest.cpp b/test/mcmc/hmc/nuts/nuts_unittest.cpp index f9579dde..a4f12916 100644 --- a/test/mcmc/hmc/nuts/nuts_unittest.cpp +++ b/test/mcmc/hmc/nuts/nuts_unittest.cpp @@ -1,6 +1,6 @@ #include "gtest/gtest.h" #include -#include +#include #include #include #include @@ -164,7 +164,8 @@ TEST_F(nuts_build_tree_fixture, find_reasonable_log_epsilon) ad_vars[1] * ad_vars[1] + ad_vars[2] * ad_vars[2] ) ; - double eps = mcmc::find_reasonable_epsilon<3>(1., ad_expr, theta, theta_adj, m_handler); + double eps = mcmc::find_reasonable_epsilon( + 1., ad_expr, theta, theta_adj, m_handler); static_cast(eps); } @@ -172,12 +173,15 @@ struct nuts_fixture : nuts_tools_fixture { protected: size_t n_samples = 5000; - std::vector w_storage, b_storage; - Param w, b; - ppl::Data x {2.5, 3, 3.5, 4, 4.5, 5.}; - ppl::Data y {3.5, 4, 4.5, 5, 5.5, 6.}; - ppl::Data q{2.4, 3.1, 3.6, 4, 4.5, 5.}; - ppl::Data r{3.5, 4, 4.4, 5.01, 5.46, 6.1}; + using value_t = double; + using p_scl_t = ppl::Param; + using d_vec_t = ppl::Data; + std::vector w_storage, b_storage; + p_scl_t w, b; + d_vec_t x {2.5, 3, 3.5, 4, 4.5, 5.}; + d_vec_t y {3.5, 4, 4.5, 5, 5.5, 6.}; + d_vec_t q{2.4, 3.1, 3.6, 4, 4.5, 5.}; + d_vec_t r{3.5, 4, 4.4, 5.01, 5.46, 6.1}; NUTSConfig<> config; nuts_fixture() diff --git a/test/mcmc/hmc/var_adapter_unittest.cpp b/test/mcmc/hmc/var_adapter_unittest.cpp index c65376dd..16bcefc8 100644 --- a/test/mcmc/hmc/var_adapter_unittest.cpp +++ b/test/mcmc/hmc/var_adapter_unittest.cpp @@ -1,5 +1,5 @@ -#include #include +#include namespace ppl { namespace mcmc { @@ -7,12 +7,221 @@ namespace mcmc { struct var_adapter_fixture : ::testing::Test { protected: + using diag_adapter_t = VarAdapter; + arma::vec x = arma::zeros(1); + arma::vec var = arma::zeros(1); + + size_t n_params = 1; + + void test_case_1(size_t warmup, + size_t init_buffer, + size_t term_buffer, + size_t window_base) + { + diag_adapter_t adapter(n_params, warmup, init_buffer, + term_buffer, window_base); + + bool res; + for (size_t i = 0; i < warmup-1; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + + res = adapter.adapt(x, var); + EXPECT_TRUE(res); + } + + void test_case_2(size_t warmup, + size_t init_buffer, + size_t term_buffer, + size_t window_base) + { + diag_adapter_t adapter(n_params, warmup, init_buffer, + term_buffer, window_base); + + bool res; + + size_t new_init_buffer = 0.15 * warmup; + size_t new_term_buffer = 0.1 * warmup; + size_t new_window_base = warmup - new_init_buffer - new_term_buffer; + + // init buffer always returns false + for (size_t i = 0; i < new_init_buffer; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + + // first window always returns false except at the very end + for (size_t i = 0; i < new_window_base-1; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + res = adapter.adapt(x, var); + EXPECT_TRUE(res); + + // termination always returns false + for (size_t i = 0; i < new_term_buffer; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + } + + void test_case_3(size_t warmup, + size_t init_buffer, + size_t term_buffer, + size_t window_base) + { + diag_adapter_t adapter(n_params, warmup, init_buffer, + term_buffer, window_base); + + bool res; + + // init buffer always returns false + for (size_t i = 0; i < init_buffer; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + + // Adapt for every window + for (size_t i = init_buffer; + i < warmup - term_buffer; + window_base *= 2) { + + // check if at the last window that may have just been extended to term + size_t window_end = (i + 3*window_base < warmup-term_buffer) ? + init_buffer+window_base : warmup-term_buffer; + + // within window always returns false except at the very end + for (; i < window_end - 1; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + + // reached last iteration of window - check that returns true + res = adapter.adapt(x, var); + EXPECT_TRUE(res); + + if (++i == warmup - term_buffer) break; + } + + // termination always returns false + for (size_t i = 0; i < term_buffer; ++i) { + res = adapter.adapt(x, var); + EXPECT_FALSE(res); + } + } }; -TEST_F(var_adapter_fixture, diag) +// Case 1: warmup <= 20 +// Subcase 1: large term buffer +TEST_F(var_adapter_fixture, diag_ctor_case_11) +{ + size_t warmup = 10; + size_t init_buffer = 1; + size_t term_buffer = 13; + size_t window_base = 4; + test_case_1(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 1: warmup <= 20 +// Subcase 2: large init buffer +TEST_F(var_adapter_fixture, diag_ctor_case_12) +{ + size_t warmup = 10; + size_t init_buffer = 9; + size_t term_buffer = 0; + size_t window_base = 5; + test_case_1(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 1: warmup <= 20 +// Subcase 3: large window +TEST_F(var_adapter_fixture, diag_ctor_case_13) +{ + size_t warmup = 10; + size_t init_buffer = 9; + size_t term_buffer = 1; + size_t window_base = 20; + test_case_1(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 2: 20 < warmup < init + window_base + term +// Subcase 1: large init buffer +TEST_F(var_adapter_fixture, diag_ctor_case_21) +{ + size_t warmup = 100; + size_t init_buffer = 110; + size_t term_buffer = 10; + size_t window_base = 10; + test_case_2(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 2: 20 < warmup < init + window_base + term +// Subcase 2: large init buffer +TEST_F(var_adapter_fixture, diag_ctor_case_22) +{ + size_t warmup = 100; + size_t init_buffer = 10; + size_t term_buffer = 110; + size_t window_base = 10; + test_case_2(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 2: 20 < warmup < init + window_base + term +// Subcase 3: large term buffer +TEST_F(var_adapter_fixture, diag_ctor_case_23) +{ + size_t warmup = 100; + size_t init_buffer = 50; + size_t term_buffer = 10; + size_t window_base = 110; + test_case_2(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 3: warmup >= init + window_base + term +// Subcase 1: large init buffer +TEST_F(var_adapter_fixture, diag_ctor_case_31) +{ + size_t warmup = 100; + size_t init_buffer = 50; + size_t term_buffer = 10; + size_t window_base = 30; + test_case_3(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 3: warmup >= init + window_base + term +// Subcase 2: large term buffer +TEST_F(var_adapter_fixture, diag_ctor_case_32) +{ + size_t warmup = 100; + size_t init_buffer = 5; + size_t term_buffer = 80; + size_t window_base = 10; + test_case_3(warmup, init_buffer, + term_buffer, window_base); + + term_buffer = 30; + test_case_3(warmup, init_buffer, + term_buffer, window_base); +} + +// Case 3: warmup >= init + window_base + term +// Subcase 3: large window buffer +TEST_F(var_adapter_fixture, diag_ctor_case_33) { - VarAdapter adapter1(3, 3, 1, 1, 1); - VarAdapter adapter2(3, 30, 10, 20, 10); + size_t warmup = 10031; + size_t init_buffer = 63; + size_t term_buffer = 59; + size_t window_base = 1582; + test_case_3(warmup, init_buffer, + term_buffer, window_base); } } // namespace mcmc diff --git a/test/mcmc/mh_regression_unittest.cpp b/test/mcmc/mh_regression_unittest.cpp index f0ddf5cf..aeeaaad4 100644 --- a/test/mcmc/mh_regression_unittest.cpp +++ b/test/mcmc/mh_regression_unittest.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include #include @@ -12,20 +12,24 @@ namespace ppl { * Fixture for Metropolis-Hastings */ struct mh_regression_fixture : ::testing::Test { - protected: +protected: + using cont_value_t = double; + using p_cont_scl_t = Param; + using d_cont_vec_t = Data; + size_t sample_size = 50000; - double tol = 1e-8; + cont_value_t tol = 1e-8; - std::vector w_storage, b_storage; - Param w, b; + std::vector w_storage, b_storage; + p_cont_scl_t w, b; - ppl::Data x {2.5, 3, 3.5, 4, 4.5, 5.}; - ppl::Data y {3.5, 4, 4.5, 5, 5.5, 6.}; + d_cont_vec_t x {2.5, 3, 3.5, 4, 4.5, 5.}; + d_cont_vec_t y {3.5, 4, 4.5, 5, 5.5, 6.}; - ppl::Data q{2.4, 3.1, 3.6, 4, 4.5, 5.}; - ppl::Data r{3.5, 4, 4.4, 5.01, 5.46, 6.1}; + d_cont_vec_t q{2.4, 3.1, 3.6, 4, 4.5, 5.}; + d_cont_vec_t r{3.5, 4, 4.4, 5.01, 5.46, 6.1}; - size_t burn = 1000; + size_t warmup = 1000; mh_regression_fixture() : w_storage(sample_size) @@ -35,19 +39,19 @@ struct mh_regression_fixture : ::testing::Test { {} template - double sample_average(const ArrayType& storage) + cont_value_t sample_average(const ArrayType& storage) { - double sum = std::accumulate( - std::next(storage.begin(), burn), + cont_value_t sum = std::accumulate( + std::next(storage.begin(), warmup), storage.end(), 0.); - return sum / (storage.size() - burn); + return sum / (storage.size() - warmup); } }; TEST_F(mh_regression_fixture, sample_regression_dist) { - auto model = (w |= ppl::uniform(0, 2), - b |= ppl::uniform(0, 2), + auto model = (w |= ppl::uniform(0., 2.), + b |= ppl::uniform(0., 2.), y |= ppl::normal(x * w + b, 0.5) ); @@ -61,8 +65,8 @@ TEST_F(mh_regression_fixture, sample_regression_dist) { } TEST_F(mh_regression_fixture, sample_regression_fuzzy_dist) { - auto model = (w |= ppl::uniform(0, 2), - b |= ppl::uniform(0, 2), + auto model = (w |= ppl::uniform(0., 2.), + b |= ppl::uniform(0., 2.), r |= ppl::normal(q * w + b, 0.5)); ppl::mh(model, sample_size); @@ -85,4 +89,4 @@ TEST_F(mh_regression_fixture, sample_regression_normal_weight) { EXPECT_NEAR(sample_average(w_storage), 1.0, 0.1); } -} // ppl +} // namespace ppl diff --git a/test/mcmc/mh_unittest.cpp b/test/mcmc/mh_unittest.cpp index 3cae8ab4..c376ab3d 100644 --- a/test/mcmc/mh_unittest.cpp +++ b/test/mcmc/mh_unittest.cpp @@ -2,7 +2,7 @@ #include #include #include -#include +#include #include namespace ppl { @@ -13,85 +13,95 @@ namespace ppl { struct mh_fixture : ::testing::Test { protected: + using cont_value_t = double; + using disc_value_t = int; + using p_cont_scl_t = Param; + using p_cont_vec_t = Param; + using p_disc_scl_t = Param; + using d_cont_scl_t = Data; + using d_disc_scl_t = Data; + using d_cont_vec_t = Data; + size_t sample_size = 20000; - std::vector storage, storage_2; - Param theta, theta_2; - Data y {0.1, 0.2, 0.3, 0.4, 0.5}; - Data x; - Data x_discrete; - size_t burn = 1000; + size_t warmup = 1000; + std::vector cont_storage, cont_storage_2; + std::vector disc_storage, disc_storage_2; + p_cont_scl_t theta, theta_2; + d_cont_vec_t y {0.1, 0.2, 0.3, 0.4, 0.5}; mh_fixture() - : storage(sample_size) - , storage_2(sample_size) - , theta{storage.data()} - , theta_2{storage_2.data()} + : cont_storage(sample_size) + , cont_storage_2(sample_size) + , disc_storage(sample_size) + , disc_storage_2(sample_size) + , theta{cont_storage.data()} + , theta_2{cont_storage_2.data()} {} template double sample_average(const ArrayType& storage) { double sum = std::accumulate( - std::next(storage.begin(), burn), + std::next(storage.begin(), warmup), storage.end(), 0.); - return sum / (storage.size() - burn); + return sum / (storage.size() - warmup); } }; TEST_F(mh_fixture, sample_std_normal) { auto model = (theta |= normal(0., 1.)); - mh(model, sample_size, 1000, 1.0, 0.25, 0.); - plot_hist(storage); - EXPECT_NEAR(sample_average(storage), 0., 0.1); + mh(model, sample_size, warmup, 1.0, 0.25, 0); + plot_hist(cont_storage); + EXPECT_NEAR(sample_average(cont_storage), 0., 0.1); } TEST_F(mh_fixture, sample_uniform) { auto model = (theta |= uniform(0., 1.)); - mh(model, sample_size, 1000, 1.0, 0.25, 0.); - plot_hist(storage, 0.1, 0., 1.); - EXPECT_NEAR(sample_average(storage), 0.5, 0.1); + mh(model, sample_size, warmup, 1.0, 0.25, 0); + plot_hist(cont_storage, 0.1, 0., 1.); + EXPECT_NEAR(sample_average(cont_storage), 0.5, 0.1); } TEST_F(mh_fixture, sample_unif_normal_posterior_mean) { - x.observe(3.); + d_cont_scl_t x(3.); auto model = ( theta |= uniform(-20., 20.), x |= normal(theta, 1.) ); - mh(model, sample_size, 1000, 1.0, 0.25, 0.); - plot_hist(storage); - EXPECT_NEAR(sample_average(storage), 3.0, 0.1); + mh(model, sample_size, warmup, 1.0, 0.25, 0.); + plot_hist(cont_storage); + EXPECT_NEAR(sample_average(cont_storage), 3.0, 0.1); } TEST_F(mh_fixture, sample_unif_normal_posterior_stddev) { - x.observe(3.14); + d_cont_scl_t x(3.14); auto model = ( theta |= uniform(0.1, 5.), x |= normal(0., theta) ); - mh(model, sample_size, 1000, 0.5, 0.25, 0.); - plot_hist(storage, 0.2); - EXPECT_NEAR(sample_average(storage), 3.27226, 0.1); + mh(model, sample_size, warmup, 0.5, 0.25, 0.); + plot_hist(cont_storage, 0.2); + EXPECT_NEAR(sample_average(cont_storage), 3.27226, 0.1); } TEST_F(mh_fixture, sample_unif_normal_posterior_mean_stddev) { - x.observe(-0.314); + d_cont_scl_t x(-0.314); auto model = ( theta |= normal(0., 1.), theta_2 |= uniform(0.1, 5.), x |= normal(theta, theta_2) ); - mh(model, sample_size, 1000, 0.5, 0.25, 0.); - plot_hist(storage); - plot_hist(storage_2, 0.2); - EXPECT_NEAR(sample_average(storage), -0.1235305689822228, 0.1); - EXPECT_NEAR(sample_average(storage_2), 1.868814361437099766, 0.1); + mh(model, sample_size, warmup, 0.5, 0.25, 0.); + plot_hist(cont_storage); + plot_hist(cont_storage_2, 0.2); + EXPECT_NEAR(sample_average(cont_storage), -0.1235305689822228, 0.1); + EXPECT_NEAR(sample_average(cont_storage_2), 1.868814361437099766, 0.1); } TEST_F(mh_fixture, sample_unif_normal_posterior_mean_samples) { @@ -100,9 +110,9 @@ TEST_F(mh_fixture, sample_unif_normal_posterior_mean_samples) { y |= normal(theta, 1.0) // {0.1, 0.2, 0.3, 0.4, 0.5} ); - mh(model, sample_size, 1000, 0.5, 0.25, 0.); - plot_hist(storage); - EXPECT_NEAR(sample_average(storage), 0.3, 0.1); + mh(model, sample_size, warmup, 0.5, 0.25, 0.); + plot_hist(cont_storage); + EXPECT_NEAR(sample_average(cont_storage), 0.3, 0.1); } TEST_F(mh_fixture, sample_unif_normal_posterior_mean_std_samples) { @@ -112,51 +122,52 @@ TEST_F(mh_fixture, sample_unif_normal_posterior_mean_std_samples) { y |= normal(theta, theta_2) // {0.1, 0.2, 0.3, 0.4, 0.5} ); - mh(model, sample_size, 1000, 0.5, 0.25, 0.); + mh(model, sample_size, warmup, 0.5, 0.25, 0.); - plot_hist(storage, 0.5); - plot_hist(storage_2, 0.5); + plot_hist(cont_storage, 0.5); + plot_hist(cont_storage_2, 0.5); - EXPECT_NEAR(sample_average(storage), 0.29951, 0.05); // found numerical with Mathematica - EXPECT_NEAR(sample_average(storage_2), 0.241658, 0.05); + EXPECT_NEAR(sample_average(cont_storage), 0.29951, 0.05); // found numerical with Mathematica + EXPECT_NEAR(sample_average(cont_storage_2), 0.241658, 0.05); } TEST_F(mh_fixture, sample_unif_bern_posterior_observe_zero) { - x_discrete.observe(0); + d_disc_scl_t x_discrete(0); auto model = ( theta |= uniform(0., 1.), x_discrete |= bernoulli(theta) ); - mh(model, sample_size, 1000, 1.0, 0.25, 0.); - plot_hist(storage, 0.2, 0., 1.); - EXPECT_NEAR(sample_average(storage), 1./3., 0.1); + mh(model, sample_size, warmup, 1.0, 0.25, 0.); + plot_hist(cont_storage, 0.2, 0., 1.); + EXPECT_NEAR(sample_average(cont_storage), 1./3., 0.1); } TEST_F(mh_fixture, sample_unif_bern_posterior_observe_one) { - x_discrete.observe(1); + d_disc_scl_t x_discrete(1); auto model = ( theta |= uniform(0., 1.), x_discrete |= bernoulli(theta) ); - mh(model, sample_size, 1000, 1.0, 0.25, 0.); - plot_hist(storage, 0.2, 0., 1.); - EXPECT_NEAR(sample_average(storage), 2./3., 0.1); + mh(model, sample_size, warmup, 1.0, 0.25, 0.); + plot_hist(cont_storage, 0.2, 0., 1.); + EXPECT_NEAR(sample_average(cont_storage), 2./3., 0.1); } -TEST_F(mh_fixture, sample_bern_normal_posterior) -{ - std::vector storage(sample_size); - Param theta{storage.data()}; - x.observe(1.); - auto model = ( - theta |= bernoulli(0.5), - x |= normal(theta, 1.) - ); - mh(model, sample_size, 1000, 1.0, 1./3, 0.); - plot_hist(storage, 0.2, 0., 1.); - EXPECT_NEAR(sample_average(storage), 0.62245933120185456463890056, 0.1); -} +// COMPILER ERROR: good :) discrete param should not be a continuous parameter +//TEST_F(mh_fixture, sample_bern_normal_posterior) +//{ +// p_disc_scl_t theta(disc_storage.data()); +// d_cont_scl_t x(1.); +// auto model = ( +// theta |= bernoulli(0.5), +// x |= normal(theta, 1.) +// ); +// mh(model, sample_size, warmup, 1.0, 1./3, 0.); +// plot_hist(disc_storage, 0.2, 0., 1.); +// EXPECT_NEAR(sample_average(disc_storage), +// 0.62245933120185456463890056, 0.1); +//} } // namespace ppl diff --git a/test/mcmc/sampler_tools_unittest.cpp b/test/mcmc/sampler_tools_unittest.cpp index bc2dd48a..27e8174c 100644 --- a/test/mcmc/sampler_tools_unittest.cpp +++ b/test/mcmc/sampler_tools_unittest.cpp @@ -1,7 +1,7 @@ #include "gtest/gtest.h" #include -#include -#include +#include +#include #include namespace ppl { @@ -10,15 +10,76 @@ namespace mcmc { struct sampler_tools_fixture : ::testing::Test { protected: - using var_t = Param; + using cont_value_t = double; + using disc_value_t = int; + using cont_param_t = Param; + using disc_param_t = Param; - static constexpr size_t n_params = 10; - std::array, n_params> thetas; - Data x; + static constexpr size_t size = 3; + + std::array disc_values; + std::array cont_values; + std::array cont_one_samples; + cont_param_t cw = size; + disc_param_t dw = size; + + std::mt19937 gen; sampler_tools_fixture() - {} + : disc_values{{0, 1, 1}} + , cont_values{{-3., 0.2, 13.23}} + , cont_one_samples{{0,0,0}} + { + for (size_t i = 0; i < size; ++i) { + cw.set_storage(&cont_one_samples[i], i); + } + } }; +TEST_F(sampler_tools_fixture, init_param_disc) +{ + auto model = (dw |= bernoulli(0.5)); + activate(model); + init_params(model, gen, disc_values); + for (size_t i = 0; i < size; ++i) { + EXPECT_LE(0, disc_values[i]); + EXPECT_LE(disc_values[i], 1); + } +} + +TEST_F(sampler_tools_fixture, init_param_cont_unbounded) +{ + auto model = (cw |= normal(0., 1.)); + activate(model); + init_params(model, gen, cont_values); + for (size_t i = 0; i < size; ++i) { + EXPECT_LT(math::neg_inf, cont_values[i]); + EXPECT_LT(cont_values[i], math::inf); + } +} + +TEST_F(sampler_tools_fixture, init_param_cont_bounded) +{ + cont_value_t min = 0.; + cont_value_t max = 0.000001; + auto model = (cw |= uniform(min, max)); + activate(model); + init_params(model, gen, cont_values); + for (size_t i = 0; i < size; ++i) { + EXPECT_LE(min, cont_values[i]); + EXPECT_LE(cont_values[i], max); + } +} + +TEST_F(sampler_tools_fixture, store_sample) +{ + auto model = (cw |= normal(0., 1.)); + activate(model); + store_sample(model, cont_values, 0); // store first sample + for (size_t i = 0; i < size; ++i) { + EXPECT_DOUBLE_EQ(cont_one_samples[i], cont_values[i]); + } +} + } // namespace mcmc } // namespace ppl diff --git a/test/testutil/mock_types.hpp b/test/testutil/mock_types.hpp index 55a6dc87..130ef6bd 100644 --- a/test/testutil/mock_types.hpp +++ b/test/testutil/mock_types.hpp @@ -1,6 +1,8 @@ #pragma once #include +#include #include +#include #include namespace ppl { @@ -13,142 +15,198 @@ enum class MockState { parameter }; -/* - * Mock Variable class that should meet the requirements - * of is_var_v. - */ -struct MockParam : util::ParamLike { - +struct MockParam: + util::VarExprBase, + util::ParamBase +{ using value_t = double; using pointer_t = double*; using const_pointer_t = const double*; - - void set_value(value_t x) { value_ = x; } - value_t get_value(size_t) const { return value_; } - constexpr size_t size() const { return 1; } - - void set_storage(pointer_t ptr) {ptr_ = ptr;} + using shape_t = ppl::scl; + using index_t = uint32_t; + using id_t = int; + static constexpr bool has_param = true; + + template + value_t value(const PVecType&, + size_t=0, + F f = F()) const + { return f(value_); } + + constexpr size_t size() const { return 1ul; } + const pointer_t& storage(size_t=0) const { return ptr_; } + id_t id() const { return id_; } + + /* Not part of API */ + MockParam(value_t value) : value_{value} {} + MockParam() =default; private: + id_t id_ = 0; value_t value_ = 0.0; pointer_t ptr_ = nullptr; }; -struct MockData : util::DataLike +struct MockData: + util::VarExprBase, + util::DataBase { using value_t = double; - using pointer_t = double*; - using const_pointer_t = const double*; + using shape_t = ppl::scl; + using id_t = int; + static constexpr bool has_param = true; - value_t get_value(size_t) const { - return value_; - } + template + const value_t& value(const PVecType&, + size_t=0, + F = F()) const + { return value_; } - constexpr size_t size() const { return 1; } + constexpr size_t size() const { return 1ul; } + id_t id() const { return id_; } private: + id_t id_ = 0; value_t value_ = 0.0; }; - /* - * Mock variable classes that fulfill - * var_traits requirements, but do not fit the rest. + * Mock param class that fits all but the "new" conditions of param. */ -struct MockParam_no_convertible : util::Var +struct MockNotParam: + util::VarExprBase { using value_t = double; - using pointer_t = double*; - using const_pointer_t = const double*; -}; + using shape_t = ppl::scl; + static constexpr bool has_param = true; -struct MockData_no_convertible : util::Var { - using value_t = double; - using pointer_t = double*; - using const_pointer_t = const double*; + template + const value_t& value(const PVecType&, + size_t=0, + F = F()) const + { return value_; } + + constexpr size_t size() const { return 1ul; } + +private: + value_t value_ = 0.0; }; /* - * Mock Variable Expression class that should meet the requirements - * of is_var_expr_v. + * Mock data class that fits all but the "new" conditions of data. */ -struct MockVarExpr : util::VarExpr +struct MockNotData: + util::VarExprBase { using value_t = double; - value_t get_value(size_t) const { - return x_; - } + using shape_t = ppl::scl; + static constexpr bool has_param = true; - constexpr size_t size() const { return 1; } + template + const value_t& value(const PVecType&, + size_t=0, + F = F()) const + { return value_; } - /* not part of API */ - MockVarExpr(value_t x = 0.) - : x_{x} - {} - void set_value(value_t x) {x_ = x;} + constexpr size_t size() const { return 1ul; } private: - value_t x_ = 0.; + value_t value_ = 0.0; }; /* - * Mock variable expression classes that fulfill - * var_expr_traits requirements, but do not fit the rest. + * Mock variable expression class that fits all + * conditions of variable expression. */ -struct MockVarExpr_no_convertible : util::VarExpr +struct MockVarExpr: + util::VarExprBase { using value_t = double; -}; + using shape_t = ppl::scl; + static constexpr bool has_param = true; -/* - * Mock distribution expression class that should meet the requirements - * of is_dist_expr_v. - */ -struct MockDistExpr : util::DistExpr -{ - using value_t = double; + template + const value_t& value(const PVecType&, + size_t=0, + F = F()) const + { return x_; } - using base_t = util::DistExpr; - using dist_value_t = typename base_t::dist_value_t; - using base_t::pdf; - using base_t::log_pdf; + size_t size() const { return x_; } - dist_value_t pdf(value_t x, size_t=0) const { return x; } + template + auto to_ad(const T&, const U&, size_t=0) const { + return ad::constant(x_); + } - dist_value_t log_pdf(value_t x, size_t=0) const { return std::log(x); } + /* not part of API */ + MockVarExpr(value_t x = 0.) + : x_{x} + {} - value_t min() const { return 0.; } - value_t max() const { return 1.; } +private: + value_t x_; }; /* - * Mock distribution expression classes that fulfill - * dist_expr_traits requirements, but do not fit the rest. + * Mock variable expression class that fits all but the "new" + * conditions of variable expression. */ -struct MockDistExpr_no_pdf : - util::DistExpr, - public MockDistExpr +struct MockNotVarExpr { -private: - using dist_value_t = typename MockDistExpr::dist_value_t; - using MockDistExpr::pdf; + using shape_t = ppl::scl; + constexpr size_t size() const { return 1ul; } }; -struct MockDistExpr_no_log_pdf : public MockDistExpr +/* + * Mock shaped class that fits all conditions of shape. + */ +struct MockScalar { -private: - using MockDistExpr::log_pdf; + using shape_t = ppl::scl; + constexpr size_t size() const { return 1ul; } }; /* - * Mock binary operation node for testing purposes. + * Mock distribution expression class that fits all + * conditions of is_dist_expr_v. */ -struct MockBinaryOp +struct MockDistExpr: util::DistExprBase { - // mock operation -- returns the sum - static double evaluate(double x, double y) { - return x + y; - } +private: + using base_t = util::DistExprBase; +public: + using value_t = double; + using dist_value_t = typename base_t::dist_value_t; + + value_t min() const { return 0.; } + value_t max() const { return 1.; } + + /* Not part of API */ + MockDistExpr(value_t p=0) : p_{p} {} + + template + value_t pdf(const VarType& x, + const PVecType& pvalues, + F f = F()) const + { return x.value(pvalues, 0, f) * p_; } + + template + value_t log_pdf(const VarType& x, + const PVecType& pvalues, + F f = F()) const + { return std::log(this->pdf(x, pvalues, f)); } + +private: + value_t p_; }; /* diff --git a/test/util/dist_expr_traits_unittest.cpp b/test/util/dist_expr_traits_unittest.cpp deleted file mode 100644 index f4f8350e..00000000 --- a/test/util/dist_expr_traits_unittest.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "gtest/gtest.h" -#include -#include - -namespace ppl { -namespace util { - -struct dist_expr_traits_fixture : ::testing::Test -{ -protected: -}; - -TEST_F(dist_expr_traits_fixture, is_dist_expr_v_true) -{ -#if __cplusplus <= 201703L - static_assert(assert_is_dist_expr_v); -#else - static_assert(dist_expr); -#endif -} - -TEST_F(dist_expr_traits_fixture, is_dist_expr_v_false) -{ -#if __cplusplus <= 201703L - static_assert(!is_dist_expr_v); - static_assert(!is_dist_expr_v); -#else - static_assert(!dist_expr); - static_assert(!dist_expr); -#endif -} - -} // namespace util -} // namespace ppl diff --git a/test/util/iterator/counting_iterator_unittest.cpp b/test/util/iterator/counting_iterator_unittest.cpp new file mode 100644 index 00000000..22af9e06 --- /dev/null +++ b/test/util/iterator/counting_iterator_unittest.cpp @@ -0,0 +1,49 @@ +#include "gtest/gtest.h" +#include + +namespace ppl { +namespace util { + +struct counting_iterator_fixture : ::testing::Test +{ +protected: + size_t val = 2; + counting_iterator it; + counting_iterator_fixture() + : it(val) + {} +}; + +TEST_F(counting_iterator_fixture, op_star) +{ + EXPECT_EQ(*it, val); +} + +TEST_F(counting_iterator_fixture, op_plus_plus) +{ + EXPECT_EQ(*(++it), val + 1); + EXPECT_EQ(*it++, val + 1); + EXPECT_EQ(*it, val + 2); +} + +TEST_F(counting_iterator_fixture, op_minus_minus) +{ + EXPECT_EQ(*(--it), val - 1); + EXPECT_EQ(*it--, val - 1); + EXPECT_EQ(*it, val - 2); +} + +TEST_F(counting_iterator_fixture, op_eq) +{ + EXPECT_EQ(counting_iterator(val), it); +} + +TEST_F(counting_iterator_fixture, op_neq) +{ + EXPECT_NE(counting_iterator(0), it); + EXPECT_NE(counting_iterator(1), it); + EXPECT_NE(counting_iterator(3), it); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/iterator/range_unittest.cpp b/test/util/iterator/range_unittest.cpp new file mode 100644 index 00000000..60619929 --- /dev/null +++ b/test/util/iterator/range_unittest.cpp @@ -0,0 +1,85 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct range_fixture : ::testing::Test +{ +protected: + static constexpr size_t size = 5; + static constexpr int defval = 0; + static constexpr size_t special_idx = 2; + static constexpr int special_val = 10; + + using vector_t = std::vector; + using array_t = std::array; + using raw_array_t = int[size]; + vector_t v1; + array_t v2; + raw_array_t v3; + range_fixture() + : v1(size, defval) + , v2{defval} + , v3{defval} + { + v1[2] = special_val; + v2[2] = special_val; + v3[2] = special_val; + } + + template + void test_size(const Container& c) + { + if constexpr (std::is_array_v) { + auto r = range(c, c + size); + EXPECT_EQ(r.size(), size); + } else { + auto r = range(c.begin(), c.end()); + EXPECT_EQ(r.size(), size); + } + } + + template + void test_op_paren(const Container& c) + { + if constexpr (std::is_array_v) { + auto r = range(c, c + size); + EXPECT_EQ(r(special_idx), special_val); + for (size_t i = 0; i < size; ++i) { + if (i != special_idx) { EXPECT_EQ(r(i), defval); } + } + } else { + auto r = range(c.begin(), c.end()); + EXPECT_EQ(r(special_idx), special_val); + for (size_t i = 0; i < size; ++i) { + if (i != special_idx) { EXPECT_EQ(r(i), defval); } + } + } + } +}; + +TEST_F(range_fixture, size) +{ + test_size(v1); + test_size(v2); + test_size(v3); +} + +TEST_F(range_fixture, op_paren) +{ + test_op_paren(v1); + test_op_paren(v2); + test_op_paren(v3); +} + +TEST_F(range_fixture, subrange) +{ + auto r = range(std::next(v1.begin(), 2), v1.end()); + EXPECT_EQ(r.size(), size - 2ul); + EXPECT_EQ(r(special_idx - 2ul), special_val); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/concept_unittest.cpp b/test/util/traits/concept_unittest.cpp similarity index 97% rename from test/util/concept_unittest.cpp rename to test/util/traits/concept_unittest.cpp index 7fce787b..1005381d 100644 --- a/test/util/concept_unittest.cpp +++ b/test/util/traits/concept_unittest.cpp @@ -1,5 +1,5 @@ #include "gtest/gtest.h" -#include +#include namespace ppl { namespace util { diff --git a/test/util/traits/dist_expr_traits_unittest.cpp b/test/util/traits/dist_expr_traits_unittest.cpp new file mode 100644 index 00000000..3f21d882 --- /dev/null +++ b/test/util/traits/dist_expr_traits_unittest.cpp @@ -0,0 +1,19 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct dist_expr_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(dist_expr_traits_fixture, is_dist_expr_v_true) +{ + static_assert(is_dist_expr_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/traits/shape_traits_unittest.cpp b/test/util/traits/shape_traits_unittest.cpp new file mode 100644 index 00000000..481b9ac1 --- /dev/null +++ b/test/util/traits/shape_traits_unittest.cpp @@ -0,0 +1,20 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct shape_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(shape_traits_fixture, is_shape_v_true) +{ + static_assert(assert_is_shape_v); + static_assert(assert_is_scl_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/traits/var_expr_traits_unittest.cpp b/test/util/traits/var_expr_traits_unittest.cpp new file mode 100644 index 00000000..188908c7 --- /dev/null +++ b/test/util/traits/var_expr_traits_unittest.cpp @@ -0,0 +1,28 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct var_expr_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(var_expr_traits_fixture, is_var_expr_v_true) +{ + static_assert(is_var_expr_v); +} + +TEST_F(var_expr_traits_fixture, is_var_expr_v_false) +{ + static_assert(!is_var_expr_v); + static_assert(is_shape_v); + static_assert(!var_expr_is_base_of_v); + static_assert(!has_type_value_t_v); + static_assert(!has_func_value_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/traits/var_traits_unittest.cpp b/test/util/traits/var_traits_unittest.cpp new file mode 100644 index 00000000..3357ebe1 --- /dev/null +++ b/test/util/traits/var_traits_unittest.cpp @@ -0,0 +1,41 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct var_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(var_traits_fixture, is_var_v_true) +{ + static_assert(is_var_v); + static_assert(is_param_v); + static_assert(is_var_v); + static_assert(is_data_v); +} + +TEST_F(var_traits_fixture, is_var_v_false) +{ + static_assert(!is_param_v); + static_assert(!is_var_v); + static_assert(is_var_expr_v); + static_assert(!param_is_base_of_v); + static_assert(!has_type_id_t_v); + static_assert(!has_type_pointer_t_v); + static_assert(!has_type_const_pointer_t_v); + static_assert(!has_func_id_v); + + static_assert(!is_data_v); + static_assert(!is_var_v); + static_assert(is_var_expr_v); + static_assert(!data_is_base_of_v); + static_assert(!has_type_id_t_v); + static_assert(!has_func_id_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/var_expr_traits_unittest.cpp b/test/util/var_expr_traits_unittest.cpp deleted file mode 100644 index 6ca210a2..00000000 --- a/test/util/var_expr_traits_unittest.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "gtest/gtest.h" -#include -#include - -namespace ppl { -namespace util { - -struct var_expr_traits_fixture : ::testing::Test -{ -protected: -}; - -TEST_F(var_expr_traits_fixture, is_var_expr_v_true) -{ -#if __cplusplus <= 201703L - static_assert(assert_is_var_expr_v); -#else - static_assert(var_expr); -#endif -} - -TEST_F(var_expr_traits_fixture, is_var_expr_v_false) -{ -#if __cplusplus <= 201703L - static_assert(!is_var_expr_v); -#else - static_assert(!var_expr); -#endif -} - -} // namespace util -} // namespace ppl diff --git a/test/util/var_traits_unittest.cpp b/test/util/var_traits_unittest.cpp deleted file mode 100644 index fa79d63d..00000000 --- a/test/util/var_traits_unittest.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "gtest/gtest.h" -#include -#include - -namespace ppl { -namespace util { - -struct var_traits_fixture : ::testing::Test -{ -protected: -}; - -TEST_F(var_traits_fixture, is_var_v_true) -{ -#if __cplusplus <= 201703L - static_assert(assert_is_var_v); -#else - static_assert(param); - static_assert(var); -#endif -} - -TEST_F(var_traits_fixture, is_var_v_false) -{ -#if __cplusplus <= 201703L - static_assert(!is_var_v); -#else - static_assert(!param); - static_assert(!var); -#endif -} - -} // namespace util -} // namespace ppl