diff --git a/.travis.yml b/.travis.yml index e7130246..54bc91a1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,7 +38,7 @@ jobs: coveralls --root ../../ --build-root ./ - --include autoppl/include + --include include --exclude lib --gcov 'gcov-7' --gcov-options '\-lp' diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d10b546..06f4fa53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,6 @@ option(AUTOPPL_ENABLE_BENCHMARK "Enable benchmarks to be built." OFF) option(AUTOPPL_ENABLE_TEST_COVERAGE "Build with test coverage (AUTOPPL_ENABLE_TEST must be ON)" OFF) # Automate the choosing of config -# if CMAKE_BUILD_TYPE not defined if (NOT CMAKE_BUILD_TYPE) # if binary directory ends with "release", use release mode if (${PROJECT_BINARY_DIR} MATCHES "release$") @@ -22,6 +21,16 @@ if (NOT CMAKE_BUILD_TYPE) endif() message(STATUS "Compiling in ${CMAKE_BUILD_TYPE} mode") +# Add this library as interface (header-only) +add_library(${PROJECT_NAME} INTERFACE) + +target_include_directories(${PROJECT_NAME} + INTERFACE $ + $) + +# Set C++17 standard for project target +target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) + # Configure tests if (AUTOPPL_ENABLE_TEST) include(CTest) # enable memcheck @@ -35,6 +44,11 @@ if (AUTOPPL_ENABLE_TEST) add_subdirectory(${PROJECT_SOURCE_DIR}/test ${PROJECT_BINARY_DIR}/test) endif() +# TODO: add src dir if needed +#set(AUTOPPL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}/src) +#file(GLOB_RECURSE AUTOPPL_SOURCE_FILES RELATIVE src LIST_DIRECTORIES false *.cpp) +#set(AUTOPPL_SOURCE_FILES ${AUTOPPL_SOURCE_DIR}/autoppl.cpp) +#set(AUTOPPL_HEADER_FILES ${AUTOPPL_INCLUDE_DIR}/autoppl.h) + # Add subdirectories -add_subdirectory(${PROJECT_SOURCE_DIR}/autoppl ${PROJECT_BINARY_DIR}/autoppl) add_subdirectory(${PROJECT_SOURCE_DIR}/lib ${PROJECT_BINARY_DIR}/lib) diff --git a/autoppl/CMakeLists.txt b/autoppl/CMakeLists.txt deleted file mode 100644 index b50a56a2..00000000 --- a/autoppl/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Add this library as interface (header-only) -add_library(${PROJECT_NAME} INTERFACE) - -target_include_directories(${PROJECT_NAME} - INTERFACE $ - $) - -# Set C++17 standard for project target -target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) - -#set(AUTOPPL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}/src) -#file(GLOB_RECURSE AUTOPPL_SOURCE_FILES RELATIVE src LIST_DIRECTORIES false *.cpp) - -#set(AUTOPPL_SOURCE_FILES ${AUTOPPL_SOURCE_DIR}/autoppl.cpp) -#set(AUTOPPL_HEADER_FILES ${AUTOPPL_INCLUDE_DIR}/autoppl.h) diff --git a/autoppl/include/autoppl.hpp b/autoppl/include/autoppl.hpp deleted file mode 100644 index 729eb861..00000000 --- a/autoppl/include/autoppl.hpp +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -namespace ppl { - -inline int fib(int n) -{ - if (n <= 1) return 1; - else return fib(n-1) + fib(n-2); -} - -} // namespace autoppl diff --git a/doc/design/model_design2.cpp b/doc/design/model_design2.cpp new file mode 100644 index 00000000..5aad3057 --- /dev/null +++ b/doc/design/model_design2.cpp @@ -0,0 +1,32 @@ +Y ~ W.x + epsilon +Y ~ N(W.x, sigma^2) + +Parameter X {4.0}; // observed +Parameter Y {5.0}; // observed +Parameter W; // hidden +​ +Model m1 = Model( // Model class defines a distribution over existing Parameters. + W |= Uniform(-10, 10), // linear regression + Y |= Normal(W * X, 3), // overload multiplication to build a graph from W * X +​); + +Model m2 = Model( + W |= Normal(0, 1), // ridge regression instead + Y |= Normal(W * X, 3), +); +​ +m1.sample(1000); + +(3*x).pdf(10) => x.pdf(10 / 3) + +X.observe(3); // observe more data + +// P(Y, W | X) = P(Y | W, X) P(W | X) which is doable for multiple samples, just need to +// assert len(Y) == len(X) and then multiply out over all pairs of (X, Y) values. + +// P(Y | X) => this is a fine distribution, but I can't talk about P(Y, X) or P(X | Y) until I put a prior on Y. +// I don't have a joint distribution yet. + +// Some issues: +// how do we do (x ** 2).pdf(5)? This is pretty damn hard for non-bijective functions, need to integrate? +// \ No newline at end of file diff --git a/doc/design/model_inttest.cpp b/doc/design/model_inttest.cpp new file mode 100644 index 00000000..64caf2ca --- /dev/null +++ b/doc/design/model_inttest.cpp @@ -0,0 +1,106 @@ +#include "gtest/gtest.h" +#include +#include +#include + +namespace ppl { + +template +struct BracketNode +{ + VectorType v; + IndexType i; +}; + +struct myvector +{ + + rv_tag operator[](rv_tag) + { + return rv + } + std::vector v; // 3 things +}; + +template +auto normal(const MuType& mu, const SigType& sig) +{ + Normal(mu, sig); +} + +TEST(dummy, dummy_test) +{ + double x_data = 2.3; // 1-sample data + + std::vector sampled_theta_1(100); + std::vector sampled_theta_2(100); + + double* ptr; + rv_tag x; + rv_tag theta_1(sampled_theta_1.data()); + rv_tag theta_2(sampled_theta_2.data()); + + std::vector> v; + std::for_each(..., ... , [](){v[i].set_sample_storage(&mat.row(i));}); + + x.observe(x_data); + + x_1.observe(...); + x_2.observe(...); + + auto model = ( + mu |= uniform(-10000, 10000), + y |= uniform({1,2,3}) // + x_1 |= normal(mu[y], 1), + x_2 |= normal(mu[y], 1), + ); + + x.observe(...); + + rv_tag var, mu, x; + auto normal_model = ( + var |= normal(0,1), + mu |= normal(1,5), + x |= normal(mu, var) + ); + + std::vector var_storage(1000); + std::vector mu_storage(1000); + + var.set_storage(var_storage.data()); + mu.set_storage(mu_storage.data()); + + metropolis_hastings(model, 1000, 400); + + auto gmm_model = ( + mu |= + ); + + std::vector> vec(model.param_num); + model.bind_storage(vec.begin(), vec.end(), ...); + model.pdf(); + + metropolis_hastings(model, 100); + + std::vector sampled_theta_1_again(1000); + std::vector sampled_theta_2_again(1000); + + theta_1.set_storage(sampled_theta_1_again.data()); + theta_2.set_storage(sampled_theta_2_again.data()); + + metropolis_hastings(model, 1000); + + + + + + + + auto model = ( + w |= normal(0,1), + y |= normal(w*x, 1) + ) + metropolis_hastings(modeli) +} + +} diff --git a/include/autoppl/algorithm/mh.hpp b/include/autoppl/algorithm/mh.hpp new file mode 100644 index 00000000..f5f17fa1 --- /dev/null +++ b/include/autoppl/algorithm/mh.hpp @@ -0,0 +1,123 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +/* + * Assumptions: + * - every variable referenced in model is of type Variable + */ + +namespace ppl { +namespace details { + +struct MHData +{ + double next; + // TODO: maybe keep an array for batch sampling? +}; + +} // namespace details + +/* + * Metropolis-Hastings algorithm to sample from posterior distribution. + * The posterior distribution is a constant multiple of model.pdf(). + * Any variables that model references which are in state "parameter" + * is sampled and in state "data" are not. + * So, model.pdf() is proportional to p(parameters... | data...). + * + * User must ensure that they allocated at least as large as n_sample + * in the storage associated with every parameter referenced in model. + */ +template +inline void mh_posterior(ModelType& model, + double n_sample, + double stddev = 1.0, + double seed = std::chrono::duration_cast< + std::chrono::milliseconds>( + std::chrono::system_clock::now().time_since_epoch() + ).count() + ) +{ + using data_t = details::MHData; + + // set-up auxiliary tools + std::mt19937 gen(seed); + std::uniform_real_distribution unif_sampler(0., 1.); + + // get number of parameters to sample + size_t n_params = 0.; + auto get_n_params = [&](auto& eq_node) { + auto& var = eq_node.get_variable(); + using var_t = std::decay_t; + using state_t = typename util::var_traits::state_t; + n_params += (var.get_state() == state_t::parameter); + }; + model.traverse(get_n_params); + + // vector of parameter-related data with candidate + std::vector params(n_params); + double curr_log_pdf = model.log_pdf(); + auto params_it = params.begin(); + + for (size_t iter = 0; iter < n_sample; ++iter) { + + double log_alpha = -curr_log_pdf; + + // generate next candidates and place them in parameter + // variables as next values; update log_alpha + // The old values are temporary stored in the params vector. + auto get_candidate = [=, &gen](auto& eq_node) mutable { + auto& var = eq_node.get_variable(); + using var_t = std::decay_t; + using state_t = typename util::var_traits::state_t; + + if (var.get_state() == state_t::parameter) { + auto curr = var.get_value(); + std::normal_distribution norm_sampler(curr, stddev); + + // sample new candidate, place old value in params, + // fill next candidate in var, and update log_alpha + auto cand = norm_sampler(gen); + params_it->next = curr; + var.set_value(cand); + + ++params_it; + } + }; + model.traverse(get_candidate); + + // compute next candidate log pdf and update log_alpha + double cand_log_pdf = model.log_pdf(); + log_alpha += cand_log_pdf; + bool accept = (std::log(unif_sampler(gen)) <= log_alpha); + + // If accept, "current" sample for next iteration is already in the variables + // so simply append to storage. + // Otherwise, "current" sample for next iteration must be moved back from + // params vector into variables. + auto add_to_storage = [params_it, iter, accept](auto& eq_node) mutable { + auto& var = eq_node.get_variable(); + using var_t = std::decay_t; + using state_t = typename util::var_traits::state_t; + if (var.get_state() == state_t::parameter) { + if (!accept) { + var.set_value(params_it->next); + ++params_it; + } + auto storage = var.get_storage(); + storage[iter] = var.get_value(); + } + }; + model.traverse(add_to_storage); + + // update current log pdf for next iteration + if (accept) curr_log_pdf = cand_log_pdf; + } +} + +} // namespace ppl diff --git a/include/autoppl/autoppl b/include/autoppl/autoppl new file mode 100644 index 00000000..1532ee7e --- /dev/null +++ b/include/autoppl/autoppl @@ -0,0 +1,2 @@ +#pragma once +// TODO: export all headers later! diff --git a/include/autoppl/expr_builder.hpp b/include/autoppl/expr_builder.hpp new file mode 100644 index 00000000..43365307 --- /dev/null +++ b/include/autoppl/expr_builder.hpp @@ -0,0 +1,163 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ppl { + +/* + * The purpose for these expression builders is to + * add extra type-safety and ease the user API. + */ + +//////////////////////////////////////////////////////// +// Distribution Expression Builder +//////////////////////////////////////////////////////// + +namespace details { + +/* + * Converter from arbitrary (decayed) type to valid continuous parameter type + * by the following mapping: + * - is_var_v true => VariableViewer + * - T is same as cont_raw_param_t => Constant + * - is_var_expr_v true => T + * Assumes each condition is non-overlapping. + */ +template +struct convert_to_cont_dist_param +{}; + +template +struct convert_to_cont_dist_param> && + !std::is_same_v, util::cont_raw_param_t> && + !util::is_var_expr_v> + >> +{ + using type = expr::VariableViewer>; +}; + +template +struct convert_to_cont_dist_param> && + std::is_same_v, util::cont_raw_param_t> && + !util::is_var_expr_v> + >> +{ + using type = expr::Constant>; +}; + +template +struct convert_to_cont_dist_param> && + !std::is_same_v, util::cont_raw_param_t> && + util::is_var_expr_v> + >> +{ + using type = T; +}; + +template +using convert_to_cont_dist_param_t = + typename convert_to_cont_dist_param::type; + +} // namespace details + +#ifndef AUTOPPL_USE_CONCEPTS +/* + * Builds a Uniform expression only when the parameters + * are both valid continuous distribution parameter types. + * See var_expr.hpp for more information. + */ +template +inline constexpr auto uniform(MinType&& min_expr, + MaxType&& max_expr) +{ + using min_t = details::convert_to_cont_dist_param_t; + using max_t = details::convert_to_cont_dist_param_t; + + min_t wrap_min_expr = std::forward(min_expr); + max_t wrap_max_expr = std::forward(max_expr); + + return expr::Uniform(wrap_min_expr, wrap_max_expr); +} +#else +#endif + +#ifndef AUTOPPL_USE_CONCEPTS +/* + * Builds a Normal expression only when the parameters + * are both valid continuous distribution parameter types. + * See var_expr.hpp for more information. + */ +template +inline constexpr auto normal(MeanType&& mean_expr, + StddevType&& stddev_expr) +{ + using mean_t = details::convert_to_cont_dist_param_t; + using stddev_t = details::convert_to_cont_dist_param_t; + + mean_t wrap_mean_expr = std::forward(mean_expr); + stddev_t wrap_stddev_expr = std::forward(stddev_expr); + + return expr::Normal(wrap_mean_expr, wrap_stddev_expr); +} + +#else +#endif + +#ifndef AUTOPPL_USE_CONCEPTS +/* + * Builds a Bernoulli expression only when the parameter + * is a valid discrete distribution parameter type. + * See var_expr.hpp for more information. + * TODO: generalize as done with uniform and normal + */ +template +inline constexpr auto bernoulli(const ProbType& p_expr) +{ + return expr::Bernoulli(p_expr); +} +#else +#endif + +//////////////////////////////////////////////////////// +// Model Expression Builder +//////////////////////////////////////////////////////// + +#ifndef AUTOPPL_USE_CONCEPTS +/* + * Builds an EqNode to associate var with dist + * only when var is a Variable and dist is a valid distribution expression. + * Ex. x |= uniform(0,1) + */ +template +inline constexpr auto operator|=(Variable& var, + DistType&& dist) +{ return expr::EqNode(var, std::forward(dist)); } +#else +#endif + +#ifndef AUTOPPL_USE_CONCEPTS +/* + * Builds a GlueNode to "glue" the left expression with the right + * only when both parameters are valid model expressions. + * Ex. (x |= uniform(0,1), y |= uniform(0, 2)) + */ +template +inline constexpr auto operator,(LHSNodeType&& lhs, + RHSNodeType&& rhs) +{ + return expr::GlueNode(std::forward(lhs), + std::forward(rhs)); +} +#else +#endif + +} // namespace ppl diff --git a/include/autoppl/expression/distribution/bernoulli.hpp b/include/autoppl/expression/distribution/bernoulli.hpp new file mode 100644 index 00000000..a6870f88 --- /dev/null +++ b/include/autoppl/expression/distribution/bernoulli.hpp @@ -0,0 +1,43 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +template +struct Bernoulli +{ + static_assert(util::is_var_expr_v); + + using value_t = util::disc_raw_param_t; + using param_value_t = typename util::var_expr_traits::value_t; + using dist_value_t = typename BernoulliBase::dist_value_t; + + Bernoulli(p_type p) + : p_{p} { assert((this -> p() >= 0) && (this -> p() <= 1)); } + + template + value_t sample(GeneratorType& gen) const + { + std::bernoulli_distribution dist(p()); + return dist(gen); + } + + dist_value_t pdf(value_t x) const + { return BernoulliBase::pdf(x, p()); } + + dist_value_t log_pdf(value_t x) const + { return BernoulliBase::log_pdf(x, p()); } + + param_value_t p() const { return static_cast(p_); } + +private: + p_type p_; +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/distribution/density.hpp b/include/autoppl/expression/distribution/density.hpp new file mode 100644 index 00000000..ea267e1e --- /dev/null +++ b/include/autoppl/expression/distribution/density.hpp @@ -0,0 +1,91 @@ +#pragma once +#include +#include + +namespace ppl { +namespace expr { + +/* + * The Base objects contain static member functions + * that compute pdf and log_pdf. + * These are useful stand-alone functions and the distribution objects + * such as Uniform and Normal simply wrap these functions. + */ + +/* + * Continuous Distributions + */ + +struct UniformBase +{ + using dist_value_t = double; + + template + static dist_value_t pdf(ValueType x, + ParamValueType min, + ParamValueType max) + { + return (min < x && x < max) ? 1. / (max - min) : 0; + } + + template + static dist_value_t log_pdf(ValueType x, + ParamValueType min, + ParamValueType max) + { + return (min < x && x < max) ? + -std::log(max - min) : + std::numeric_limits::lowest(); + } +}; + +struct NormalBase +{ + using dist_value_t = double; + + template + static dist_value_t pdf(ValueType x, + ParamValueType mean, + ParamValueType stddev) + { + dist_value_t z_score = (x - mean) / stddev; + return std::exp(- 0.5 * z_score * z_score) / (stddev * std::sqrt(2 * M_PI)); + } + + template + static dist_value_t log_pdf(ValueType x, + ParamValueType mean, + ParamValueType stddev) + { + dist_value_t z_score = (x - mean) / stddev; + return -0.5 * ((z_score * z_score) + std::log(stddev * stddev * 2 * M_PI)); + } +}; + +/* + * Discrete Distributions + */ + +struct BernoulliBase +{ + using dist_value_t = double; + + template + static dist_value_t pdf(ValueType x, ParamValueType p) + { + if (x == 1) return p; + else if (x == 0) return 1. - p; + else return 0.0; + } + + template + static dist_value_t log_pdf(ValueType x, ParamValueType p) + { + if (x == 1) return std::log(p); + else if (x == 0) return std::log(1. - p); + else return std::numeric_limits::lowest(); + } +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/distribution/normal.hpp b/include/autoppl/expression/distribution/normal.hpp new file mode 100644 index 00000000..b9970126 --- /dev/null +++ b/include/autoppl/expression/distribution/normal.hpp @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +template +struct Normal +{ + static_assert(util::is_var_expr_v); + static_assert(util::is_var_expr_v); + + using value_t = util::cont_raw_param_t; + using param_value_t = std::common_type_t< + typename util::var_expr_traits::value_t, + typename util::var_expr_traits::value_t + >; + using dist_value_t = typename NormalBase::dist_value_t; + + Normal(mean_type mean, stddev_type stddev) + : mean_{mean}, stddev_{stddev} { + assert(this -> stddev() > 0); + }; + + template + value_t sample(GeneratorType& gen) const { + std::normal_distribution dist(mean(), stddev()); + return dist(gen); + } + + dist_value_t pdf(value_t x) const + { return NormalBase::pdf(x, mean(), stddev()); } + + dist_value_t log_pdf(value_t x) const + { return NormalBase::log_pdf(x, mean(), stddev()); } + + param_value_t mean() const { return static_cast(mean_);} + param_value_t stddev() const { return static_cast(stddev_);} + +private: + mean_type mean_; + stddev_type stddev_; +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/distribution/uniform.hpp b/include/autoppl/expression/distribution/uniform.hpp new file mode 100644 index 00000000..82b27bf2 --- /dev/null +++ b/include/autoppl/expression/distribution/uniform.hpp @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +template +struct Uniform +{ + static_assert(util::is_var_expr_v); + static_assert(util::is_var_expr_v); + + using value_t = util::cont_raw_param_t; + using param_value_t = std::common_type_t< + typename util::var_expr_traits::value_t, + typename util::var_expr_traits::value_t + >; + using dist_value_t = typename UniformBase::dist_value_t; + + Uniform(min_type min, max_type max) + : min_{min}, max_{max} { assert(this -> min() < this -> max()); } + + // TODO: tag this class as "TriviallySamplable"? + template + value_t sample(GeneratorType& gen) const + { + std::uniform_real_distribution dist(min(), max()); + return dist(gen); + } + + dist_value_t pdf(value_t x) const + { return UniformBase::pdf(x, min(), max()); } + + dist_value_t log_pdf(value_t x) const + { return UniformBase::log_pdf(x, min(), max()); } + + param_value_t min() const { return static_cast(min_); } + param_value_t max() const { return static_cast(max_); } + +private: + min_type min_; + max_type max_; +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/model/model.hpp b/include/autoppl/expression/model/model.hpp new file mode 100644 index 00000000..78adb7c0 --- /dev/null +++ b/include/autoppl/expression/model/model.hpp @@ -0,0 +1,122 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +/* + * This class represents a "node" in the model expression + * that relates a var with a distribution. + */ +template +struct EqNode +{ + static_assert(util::is_var_v); + static_assert(util::is_dist_expr_v); + + using var_t = VarType; + using dist_t = DistType; + using dist_value_t = typename util::dist_expr_traits::dist_value_t; + + EqNode(var_t& var, + const dist_t& dist) noexcept + : orig_var_ref_{var} + , dist_{dist} + {} + + /* + * Generic traversal function. + * Assumes that eq_f is a function that takes in 1 parameter, + * which is simply the current object. + */ + template + void traverse(EqNodeFunc&& eq_f) + { + using this_t = EqNode; + eq_f(static_cast(*this)); + } + + /* + * Compute pdf of underlying distribution with underlying value. + * Assumes that underlying value has been assigned properly. + */ + dist_value_t pdf() const + { return dist_.pdf(orig_var_ref_.get().get_value()); } + + /* + * Compute log-pdf of underlying distribution with underlying value. + * Assumes that underlying value has been assigned properly. + */ + dist_value_t log_pdf() const + { return dist_.log_pdf(orig_var_ref_.get().get_value()); } + + auto& get_variable() { return orig_var_ref_.get(); } + const auto& get_distribution() const { return dist_; } + +private: + using var_ref_t = std::reference_wrapper; + var_ref_t orig_var_ref_; // reference of the original var since + // any configuration may be changed until right before update + dist_t dist_; // distribution associated with var +}; + +/* + * This class represents a "node" in a model expression that + * "glues" two sub-model expressions. + */ +template +struct GlueNode +{ + static_assert(util::is_model_expr_v); + static_assert(util::is_model_expr_v); + + using left_node_t = LHSNodeType; + using right_node_t = RHSNodeType; + using dist_value_t = std::common_type_t< + typename util::model_expr_traits::dist_value_t, + typename util::model_expr_traits::dist_value_t + >; + + GlueNode(const left_node_t& lhs, + const right_node_t& rhs) noexcept + : left_node_{lhs} + , right_node_{rhs} + {} + + /* + * Generic traversal function. + * Recursively traverses left then right. + */ + template + void traverse(EqNodeFunc&& eq_f) + { + left_node_.traverse(eq_f); + right_node_.traverse(eq_f); + } + + /* + * Computes left node joint pdf then right node joint pdf + * and returns the product of the two. + */ + dist_value_t pdf() const + { return left_node_.pdf() * right_node_.pdf(); } + + /* + * Computes left node joint log-pdf then right node joint log-pdf + * and returns the sum of the two. + */ + dist_value_t log_pdf() const + { return left_node_.log_pdf() + right_node_.log_pdf(); } + +private: + left_node_t left_node_; + right_node_t right_node_; +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/variable/constant.hpp b/include/autoppl/expression/variable/constant.hpp new file mode 100644 index 00000000..28c95856 --- /dev/null +++ b/include/autoppl/expression/variable/constant.hpp @@ -0,0 +1,20 @@ +#pragma once + +namespace ppl { +namespace expr { + +template +struct Constant +{ + using value_t = ValueType; + Constant(value_t c) + : c_{c} + {} + operator value_t() const { return c_; } + +private: + value_t c_; +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/expression/variable/variable_viewer.hpp b/include/autoppl/expression/variable/variable_viewer.hpp new file mode 100644 index 00000000..9e137913 --- /dev/null +++ b/include/autoppl/expression/variable/variable_viewer.hpp @@ -0,0 +1,33 @@ +#pragma once +#include +#include + +namespace ppl { +namespace expr { + +/* + * VariableViewer is a viewer of some variable type. + * It will mainly be used to view Variable class defined in autoppl/variable.hpp. + */ +template +struct VariableViewer +{ + static_assert(util::is_var_v); + + using var_t = VariableType; + using value_t = typename util::var_traits::value_t; + + VariableViewer(var_t& var) + : var_ref_{var} + {} + + operator value_t() const + { return static_cast(var_ref_.get()); } + +private: + using var_ref_t = std::reference_wrapper; + var_ref_t var_ref_; +}; + +} // namespace expr +} // namespace ppl diff --git a/include/autoppl/util/concept.hpp b/include/autoppl/util/concept.hpp new file mode 100644 index 00000000..8ccaafe7 --- /dev/null +++ b/include/autoppl/util/concept.hpp @@ -0,0 +1,127 @@ +#pragma once +#include + +/* + * Metaprogramming tool to check if name is a (public) + * member alias of a given type T. + * All versions must be placed in this file for ease of maintenance. + * Macro definition is undefined at the end of the file. + * + * Ex. with "name" as "value_t" + * + namespace details { + template + struct has_type_value_t + { + private: + template static void impl(decltype(typename V::value_t(), int())); + template static bool impl(char); + public: + static constexpr bool value = std::is_same(0))>::value; + }; + } + template + inline constexpr bool has_type_value_t_v = + details::has_type_value_t::value; + */ + +#define DEFINE_HAS_TYPE(name) \ + namespace details { \ + template \ + struct has_type_##name \ + { \ + private: \ + template static void impl(typename V::name*); \ + template static bool impl(...); \ + public: \ + static constexpr bool value = std::is_same(0))>::value; \ + }; \ + \ + template \ + struct get_type_##name \ + { \ + using type = invalid_tag; \ + }; \ + template \ + struct get_type_##name \ + { \ + using type = typename T::name; \ + }; \ + } \ + template \ + inline constexpr bool has_type_##name##_v = \ + details::has_type_##name::value; \ + template \ + using get_type_##name##_t = \ + typename details::get_type_##name>::type; + +/* + * Metaprogramming tool to check if name is a (public) + * member function of a given type T. + * All versions must be placed in this file for ease of maintenance. + * Macro definition is undefined at the end of the file. + * + * Ex. with "name" as "pdf" + * + namespace details { + template + struct has_func_pdf + { + private: + template static void impl(decltype(&V::pdf)); + template static bool impl(...); + public: + static constexpr bool value = std::is_same(0))>::value; + }; + } + template + inline constexpr bool has_func_pdf_v = + details::has_func_pdf::value; + */ + +#define DEFINE_HAS_FUNC(name) \ + namespace details { \ + template \ + struct has_func_##name \ + { \ + private: \ + template static void impl(decltype(&V::name)); \ + template static bool impl(...); \ + public: \ + static constexpr bool value = std::is_same(0))>::value; \ + }; \ + } \ + template \ + inline constexpr bool has_func_##name##_v = \ + details::has_func_##name::value; + +namespace ppl { +namespace util { + +struct invalid_tag {}; + +DEFINE_HAS_TYPE(value_t); +DEFINE_HAS_TYPE(pointer_t); +DEFINE_HAS_TYPE(const_pointer_t); +DEFINE_HAS_TYPE(state_t); + +DEFINE_HAS_TYPE(dist_value_t); + +DEFINE_HAS_FUNC(set_value); +DEFINE_HAS_FUNC(get_value); +DEFINE_HAS_FUNC(set_storage); +DEFINE_HAS_FUNC(get_storage); +DEFINE_HAS_FUNC(set_state); +DEFINE_HAS_FUNC(get_state); + +DEFINE_HAS_FUNC(pdf); +DEFINE_HAS_FUNC(log_pdf); + +DEFINE_HAS_FUNC(get_variable); +DEFINE_HAS_FUNC(get_distribution); + +} // namespace util +} // namespace ppl + +#undef DEFINE_HAS_FUNC +#undef DEFINE_HAS_TYPE diff --git a/include/autoppl/util/dist_expr_traits.hpp b/include/autoppl/util/dist_expr_traits.hpp new file mode 100644 index 00000000..e674e848 --- /dev/null +++ b/include/autoppl/util/dist_expr_traits.hpp @@ -0,0 +1,60 @@ +#pragma once +#include +#include + +namespace ppl { +namespace util { + +/* + * TODO: Samplable distribution expression concept? + */ + +/* + * TODO: continuous/discrete distribution expression concept? + */ + +/* + * Continuous distribution expressions can be constructed with this type. + */ +using cont_raw_param_t = double; + +/* + * Discrete distribution expressions can be constructed with this type. + */ +using disc_raw_param_t = int64_t; + +/* + * Traits for Distribution Expression classes. + * value_t type of value Variable represents during computation + * dist_value_t type of pdf/log_pdf value + */ +template +struct dist_expr_traits +{ + using value_t = typename DistExprType::value_t; + using dist_value_t = typename DistExprType::dist_value_t; +}; + +/* + * A distribution expression is any class that the following: + * - dist_expr_traits must be well-defined for T + * - T must have member function pdf + * - T must have member function log_pdf + */ +template +inline constexpr bool is_dist_expr_v = + has_type_value_t_v && + has_type_dist_value_t_v && + has_func_pdf_v && + has_func_log_pdf_v + ; + +#ifdef AUTOPPL_USE_CONCEPTS +// TODO: definition should be extended with a stronger +// restriction on T with interface checking. +template +concept dist_expressable = is_dist_expr_v; +#endif + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/model_expr_traits.hpp b/include/autoppl/util/model_expr_traits.hpp new file mode 100644 index 00000000..f631b4cf --- /dev/null +++ b/include/autoppl/util/model_expr_traits.hpp @@ -0,0 +1,43 @@ +#pragma once +#include + +namespace ppl { +namespace util { + +/* + * Traits for Model Expression classes. + * dist_value_t type of value Variable represents during computation + */ +template +struct model_expr_traits +{ + using dist_value_t = typename NodeType::dist_value_t; +}; + +// TODO: +// - pdf and log_pdf remove from interface? +// - how to check if template member function exists (for traverse)? +template +inline constexpr bool is_model_expr_v = + has_type_dist_value_t_v && + has_func_pdf_v && + has_func_log_pdf_v + ; + +// TODO: not used currently +template +inline constexpr bool is_eq_node_expr_v = + is_model_expr_v && + has_func_get_variable_v && + has_func_get_distribution_v + ; + +#ifdef AUTOPPL_USE_CONCEPTS +// TODO: definition should be extended with a stronger +// restriction on T with interface checking. +template +concept model_expressable = is_model_expr_v; +#endif + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/traits.hpp b/include/autoppl/util/traits.hpp new file mode 100644 index 00000000..9c137cf2 --- /dev/null +++ b/include/autoppl/util/traits.hpp @@ -0,0 +1,12 @@ +#pragma once + +/* + * The following classes list member aliases that + * any such template parameter types should have. + * Users should rely on these classes to grab member aliases. + */ + +#include +#include +#include +#include diff --git a/include/autoppl/util/var_expr_traits.hpp b/include/autoppl/util/var_expr_traits.hpp new file mode 100644 index 00000000..8513448e --- /dev/null +++ b/include/autoppl/util/var_expr_traits.hpp @@ -0,0 +1,47 @@ +#pragma once +#include +#include + +namespace ppl { +namespace util { + +/* + * Traits for Variable Expression classes. + * value_t type of value Variable represents during computation + */ +template +struct var_expr_traits +{ + using value_t = typename VarExprType::value_t; +}; + +// Specialization: when double or int, considered "trivial" variable. +// TODO: this was a quick fix for generalizing distribution value_t. +template <> +struct var_expr_traits +{ + using value_t = double; +}; + +/* + * A variable expression is any class that the following: + * - is_var_v must be false + * - var_expr_traits must be well-defined for T + * - T must be convertible to its value_t + */ +template +inline constexpr bool is_var_expr_v = + !is_var_v && + has_type_value_t_v && + std::is_convertible_v> + ; + +#ifdef AUTOPPL_USE_CONCEPTS +// TODO: definition should be extended with a stronger +// restriction on T with interface checking. +template +concept var_expressable = is_var_expr_v; +#endif + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/util/var_traits.hpp b/include/autoppl/util/var_traits.hpp new file mode 100644 index 00000000..173d0818 --- /dev/null +++ b/include/autoppl/util/var_traits.hpp @@ -0,0 +1,43 @@ +#pragma once +#include +#include + +namespace ppl { +namespace util { + +/* + * Traits for Variable-like classes. + * value_t type of value Variable represents during computation + * pointer_t storage pointer type + * state_t type of enum class state; must have "data" and "parameter" + */ +template +struct var_traits +{ + using value_t = typename VarType::value_t; + using pointer_t = typename VarType::pointer_t; + using state_t = typename VarType::state_t; +}; + +/* + * C++17 version of concepts to check var properties. + * - var_traits must be well-defined under type T + * - T must be convertible to its value_t + * - not possible to get overloads + */ +template +inline constexpr bool is_var_v = + has_type_value_t_v && + has_type_pointer_t_v && + has_type_const_pointer_t_v && + has_type_state_t_v && + has_func_set_value_v && + has_func_get_value_v && + has_func_set_storage_v && + has_func_set_state_v && + has_func_get_state_v && + std::is_convertible_v> + ; + +} // namespace util +} // namespace ppl diff --git a/include/autoppl/variable.hpp b/include/autoppl/variable.hpp new file mode 100644 index 00000000..1217124c --- /dev/null +++ b/include/autoppl/variable.hpp @@ -0,0 +1,84 @@ +#pragma once +#include +#include + +namespace ppl { + +/* + * The possible states for a var. + * By default, all vars should be considered as a parameter. + * TODO: maybe move in a different file? + */ +enum class var_state : bool { + data, + parameter +}; + +/* + * Variable is a light-weight structure that represents a univariate random variable. + * It acts as an intermediate layer of communication between + * a model expression and the users, who must supply storage of values associated with this var. + */ +template +struct Variable +{ + using value_t = ValueType; + using pointer_t = value_t*; + using const_pointer_t = const value_t*; + using state_t = var_state; + + // constructors + Variable(value_t value, + pointer_t storage_ptr, + state_t state) noexcept + : value_{value} + , storage_ptr_{storage_ptr} + , state_{state} + {} + + Variable(pointer_t storage_ptr) noexcept + : Variable(0, storage_ptr, state_t::parameter) + {} + + Variable(value_t value) noexcept + : Variable(value, nullptr, state_t::data) {} + + Variable() noexcept + : Variable(0, nullptr, state_t::parameter) + {} + + void set_value(value_t value) { value_ = value; } + value_t get_value() const { return value_; } + + void set_storage(pointer_t storage_ptr) { storage_ptr_ = storage_ptr; } + pointer_t get_storage() { return storage_ptr_; } + const_pointer_t get_storage() const { return storage_ptr_; } + + void set_state(state_t state) { state_ = state; } + state_t get_state() const { return state_; } + + operator value_t () const { return value_; } + + /* + * Sets underlying value to "value". + * Additionally modifies the var to be considered as data. + * Equivalent to calling set_value(value) then set_state(state). + */ + void observe(value_t value) + { + set_value(value); + set_state(state_t::data); + } + +private: + value_t value_; // store value associated with var + pointer_t storage_ptr_; // points to beginning of storage + // storage is assumed to be contiguous + state_t state_; // state to determine if data or param +}; + +// Useful aliases +using cont_var = Variable; // continuous RV var +using disc_var = Variable; // discrete RV var + +} // namespace ppl diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7a21c768..a898ecdf 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -6,16 +6,116 @@ if (AUTOPPL_ENABLE_TEST_COVERAGE) endif() ###################################################### -# Dummy Test +# Util Test ###################################################### -add_executable(dummy_unittest - test1.cpp +add_executable(util_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/util/concept_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/dist_expr_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/var_expr_traits_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/util/var_traits_unittest.cpp + ) +target_compile_options(util_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(util_unittest PRIVATE + ${GTEST_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR} + ) +if (AUTOPPL_ENABLE_TEST_COVERAGE) + target_link_libraries(util_unittest gcov) +endif() +target_link_libraries(util_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(util_unittest util_unittest) + +###################################################### +# Variable Expression Test +###################################################### + +add_executable(var_expr_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/variable_viewer_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/variable/constant_unittest.cpp + ) +target_compile_options(var_expr_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(var_expr_unittest PRIVATE + ${GTEST_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR} + ) +if (AUTOPPL_ENABLE_TEST_COVERAGE) + target_link_libraries(var_expr_unittest gcov) +endif() +target_link_libraries(var_expr_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(var_expr_unittest var_expr_unittest) + +###################################################### +# Distribution Expression Test +###################################################### + +add_executable(dist_expr_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/bernoulli_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/density_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/normal_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/uniform_unittest.cpp + ) +target_compile_options(dist_expr_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(dist_expr_unittest PRIVATE + ${GTEST_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR} + ) +if (AUTOPPL_ENABLE_TEST_COVERAGE) + target_link_libraries(dist_expr_unittest gcov) +endif() +target_link_libraries(dist_expr_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(dist_expr_unittest dist_expr_unittest) + +###################################################### +# Model Expression Test +###################################################### + +add_executable(model_expr_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/expression/model/model_unittest.cpp + ) +target_compile_options(model_expr_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(model_expr_unittest PRIVATE + ${GTEST_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR} + ) +if (AUTOPPL_ENABLE_TEST_COVERAGE) + target_link_libraries(model_expr_unittest gcov) +endif() +target_link_libraries(model_expr_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(model_expr_unittest model_expr_unittest) + +###################################################### +# Algorithm Test +###################################################### + +add_executable(algorithm_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/algorithm/mh_unittest.cpp + ) +target_compile_options(algorithm_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(algorithm_unittest PRIVATE + ${GTEST_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR} + ) +if (AUTOPPL_ENABLE_TEST_COVERAGE) + target_link_libraries(algorithm_unittest gcov) +endif() +target_link_libraries(algorithm_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(algorithm_unittest algorithm_unittest) + +###################################################### +# Expression Builder Test +###################################################### + +add_executable(expr_builder_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/expr_builder_unittest.cpp + ) +target_compile_options(expr_builder_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(expr_builder_unittest PRIVATE + ${GTEST_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR} ) -target_compile_options(dummy_unittest PRIVATE -g -Wall -Werror -Wextra) -target_include_directories(dummy_unittest PRIVATE ${GTEST_DIR}/include) if (AUTOPPL_ENABLE_TEST_COVERAGE) - target_link_libraries(dummy_unittest gcov) + target_link_libraries(expr_builder_unittest gcov) endif() -target_link_libraries(dummy_unittest gtest_main pthread ${PROJECT_NAME}) -add_test(dummy_unittest dummy_unittest) +target_link_libraries(expr_builder_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(expr_builder_unittest expr_builder_unittest) diff --git a/test/algorithm/mh_unittest.cpp b/test/algorithm/mh_unittest.cpp new file mode 100644 index 00000000..0922112b --- /dev/null +++ b/test/algorithm/mh_unittest.cpp @@ -0,0 +1,63 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +namespace ppl { + +/* + * Fixture for Metropolis-Hastings + */ +struct mh_fixture : ::testing::Test +{ +protected: + static constexpr size_t sample_size = 20000; + std::array storage = {0.}; + Variable theta, x; + size_t burn = 1000; + + mh_fixture() + : theta{storage.data()} + {} + + double sample_average() + { + double sum = std::accumulate( + std::next(storage.begin(), burn), + storage.end(), + 0.); + return sum / (storage.size() - burn); + } +}; + +TEST_F(mh_fixture, sample_std_normal) +{ + auto model = (theta |= normal(0., 1.)); + mh_posterior(model, sample_size, 1.0, 0.); + plot_hist(storage); + EXPECT_NEAR(sample_average(), 0., 0.1); +} + +TEST_F(mh_fixture, sample_uniform) +{ + auto model = (theta |= uniform(0., 1.)); + mh_posterior(model, sample_size, 1.0, 0.); + plot_hist(storage, 0.05, 0., 1.); + EXPECT_NEAR(sample_average(), 0.5, 0.1); +} + +TEST_F(mh_fixture, sample_unif_normal_posterior) +{ + x.observe(3.); + auto model = ( + theta |= uniform(-20., 20.), + x |= normal(theta, 1.) + ); + mh_posterior(model, sample_size, 1.0, 0.); + plot_hist(storage); + EXPECT_NEAR(sample_average(), 3.0, 0.1); +} + +} // namespace ppl diff --git a/test/expr_builder_unittest.cpp b/test/expr_builder_unittest.cpp new file mode 100644 index 00000000..5e868895 --- /dev/null +++ b/test/expr_builder_unittest.cpp @@ -0,0 +1,55 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { + +struct expr_builder_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(expr_builder_fixture, convert_to_cont_dist_param_var) +{ + using namespace details; + static_assert(std::is_same_v>); + static_assert(util::is_var_v); + static_assert(!std::is_same_v); + static_assert(!util::is_var_expr_v); + static_assert(std::is_same_v< + convert_to_cont_dist_param_t, + expr::VariableViewer + >); +} + +TEST_F(expr_builder_fixture, convert_to_cont_dist_param_raw) +{ + using namespace details; + using data_t = util::cont_raw_param_t; + static_assert(std::is_same_v>); + static_assert(!util::is_var_v); + static_assert(std::is_same_v); + static_assert(!util::is_var_expr_v); + static_assert(std::is_same_v< + convert_to_cont_dist_param_t, + expr::Constant + >); +} + +TEST_F(expr_builder_fixture, convert_to_cont_dist_param_var_expr) +{ + using namespace details; + static_assert(!util::is_var_v); + static_assert(!std::is_same_v); + static_assert(util::is_var_expr_v); + static_assert(std::is_same_v< + convert_to_cont_dist_param_t, + MockVarExpr& + >); + static_assert(std::is_same_v< + convert_to_cont_dist_param_t, + MockVarExpr&& + >); +} + +} // namespace ppl diff --git a/test/expression/distribution/bernoulli_unittest.cpp b/test/expression/distribution/bernoulli_unittest.cpp new file mode 100644 index 00000000..19641b4c --- /dev/null +++ b/test/expression/distribution/bernoulli_unittest.cpp @@ -0,0 +1,65 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +struct bernoulli_fixture : ::testing::Test +{ +protected: + using value_t = typename MockVarExpr::value_t; + static constexpr size_t sample_size = 1000; + double p = 0.6; + MockVarExpr x{p}; + Bernoulli bern = {x}; + std::array sample = {0.}; +}; + +TEST_F(bernoulli_fixture, ctor) +{ + static_assert(util::is_dist_expr_v>); +} + +TEST_F(bernoulli_fixture, bernoulli_check_params) { + EXPECT_DOUBLE_EQ(bern.p(), static_cast(x)); +} + +TEST_F(bernoulli_fixture, bernoulli_pdf_delegate) { + EXPECT_DOUBLE_EQ(bern.pdf(-10), BernoulliBase::pdf(-10, p)); + EXPECT_DOUBLE_EQ(bern.pdf(-7), BernoulliBase::pdf(-7, p)); + EXPECT_DOUBLE_EQ(bern.pdf(-3), BernoulliBase::pdf(-3, p)); + EXPECT_DOUBLE_EQ(bern.pdf(0), BernoulliBase::pdf(0, p)); + EXPECT_DOUBLE_EQ(bern.pdf(1), BernoulliBase::pdf(1, p)); + EXPECT_DOUBLE_EQ(bern.pdf(3), BernoulliBase::pdf(3, p)); + EXPECT_DOUBLE_EQ(bern.pdf(6), BernoulliBase::pdf(6, p)); + EXPECT_DOUBLE_EQ(bern.pdf(16), BernoulliBase::pdf(16, p)); +} + +TEST_F(bernoulli_fixture, bernoulli_log_pdf_delegate) { + EXPECT_DOUBLE_EQ(bern.log_pdf(-10), BernoulliBase::log_pdf(-10, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(-7), BernoulliBase::log_pdf(-7, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(-3), BernoulliBase::log_pdf(-3, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(0), BernoulliBase::log_pdf(0, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(1), BernoulliBase::log_pdf(1, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(3), BernoulliBase::log_pdf(3, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(6), BernoulliBase::log_pdf(6, p)); + EXPECT_DOUBLE_EQ(bern.log_pdf(16), BernoulliBase::log_pdf(16, p)); +} + +TEST_F(bernoulli_fixture, bernoulli_sample) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + + for (size_t i = 0; i < sample_size; i++) { + sample[i] = bern.sample(gen); + } + + plot_hist(sample); +} + +} // namespace expr +} // namespace ppl diff --git a/test/expression/distribution/density_unittest.cpp b/test/expression/distribution/density_unittest.cpp new file mode 100644 index 00000000..6af92576 --- /dev/null +++ b/test/expression/distribution/density_unittest.cpp @@ -0,0 +1,157 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace expr { + +struct density_fixture : ::testing::Test +{ +protected: + double x = 0.; + double min = -2.3; + double max = 2.7; + double mean = 0.3; + double stddev = 1.3; + double tol = 1e-15; + double p = 0.41; +}; + +/* + * Continuous distribution + */ + +TEST_F(density_fixture, uniform_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(UniformBase::pdf(-2.2999999999, min, max), 0.2); + EXPECT_DOUBLE_EQ(UniformBase::pdf(-2., min, max), 0.2); + EXPECT_DOUBLE_EQ(UniformBase::pdf(-1.423, min, max), 0.2); + EXPECT_DOUBLE_EQ(UniformBase::pdf(0., min, max), 0.2); + EXPECT_DOUBLE_EQ(UniformBase::pdf(1.31, min, max), 0.2); + EXPECT_DOUBLE_EQ(UniformBase::pdf(2.41, min, max), 0.2); + EXPECT_DOUBLE_EQ(UniformBase::pdf(2.69999999999, min, max), 0.2); +} + +TEST_F(density_fixture, uniform_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(UniformBase::pdf(-100, min, max), 0.); + EXPECT_DOUBLE_EQ(UniformBase::pdf(-3.41, min, max), 0.); + EXPECT_DOUBLE_EQ(UniformBase::pdf(-2.3, min, max), 0.); + EXPECT_DOUBLE_EQ(UniformBase::pdf(2.7, min, max), 0.); + EXPECT_DOUBLE_EQ(UniformBase::pdf(3.5, min, max), 0.); + EXPECT_DOUBLE_EQ(UniformBase::pdf(3214, min, max), 0.); +} + +TEST_F(density_fixture, uniform_log_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(-2.2999999999, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(-2., min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(-1.423, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(0., min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(1.31, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(2.41, min, max), std::log(0.2)); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(2.69999999999, min, max), std::log(0.2)); +} + +TEST_F(density_fixture, uniform_log_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(-100, min, max), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(-3.41, min, max), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(-2.3, min, max), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(2.7, min, max), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(3.5, min, max), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(UniformBase::log_pdf(3214, min, max), std::numeric_limits::lowest()); +} + +TEST_F(density_fixture, normal_pdf) +{ + EXPECT_NEAR(NormalBase::pdf(-10.231, mean, stddev), 1.726752595588348216742E-15, tol); + EXPECT_NEAR(NormalBase::pdf(-5.31, mean, stddev), 2.774166877919518907166E-5, tol); + EXPECT_DOUBLE_EQ(NormalBase::pdf(-2.3141231, mean, stddev), 0.04063645713784323551341); + EXPECT_DOUBLE_EQ(NormalBase::pdf(0., mean, stddev), 0.2988151821496727914542); + EXPECT_DOUBLE_EQ(NormalBase::pdf(1.31, mean, stddev), 0.2269313951019926611687); + EXPECT_DOUBLE_EQ(NormalBase::pdf(3.21, mean, stddev), 0.02505560241243631472997); + EXPECT_NEAR(NormalBase::pdf(5.24551, mean, stddev), 2.20984513448306056291E-4, tol); + EXPECT_NEAR(NormalBase::pdf(10.5699, mean, stddev), 8.61135160183067521907E-15, tol); +} + +TEST_F(density_fixture, normal_log_pdf) +{ + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(-10.231, mean, stddev), std::log(1.726752595588348216742E-15)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(-5.31, mean, stddev), std::log(2.774166877919518907166E-5)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(-2.3141231, mean, stddev), std::log(0.04063645713784323551341)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(0., mean, stddev), std::log(0.2988151821496727914542)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(1.31, mean, stddev), std::log(0.2269313951019926611687)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(3.21, mean, stddev), std::log(0.02505560241243631472997)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(5.24551, mean, stddev), std::log(2.20984513448306056291E-4)); + EXPECT_DOUBLE_EQ(NormalBase::log_pdf(10.5699, mean, stddev), std::log(8.61135160183067521907E-15)); +} + +/* + * Discrete distributions + */ + +TEST_F(density_fixture, bernoulli_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(0, p), 1-p); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(1, p), p); +} + +TEST_F(density_fixture, bernoulli_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(-100, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(-3.41, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(-0.00000001, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(0.00000001, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(0.99999999, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(1.00000001, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(3.1423, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(5.613, p), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(100., p), 0.); +} + +TEST_F(density_fixture, bernoulli_pdf_always_tail) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(0, 0.), 1.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(1, 0.), 0.); +} + +TEST_F(density_fixture, bernoulli_pdf_always_head) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(0, 1.), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::pdf(1, 1.), 1.); +} + +TEST_F(density_fixture, bernoulli_log_pdf_in_range) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(0, p), std::log(1-p)); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(1, p), std::log(p)); +} + +TEST_F(density_fixture, bernoulli_log_pdf_out_of_range) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(-100, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(-3.41, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(-0.00000001, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(0.00000001, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(0.99999999, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(1.00000001, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(3.1423, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(5.613, p), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(100., p), std::numeric_limits::lowest()); +} + +TEST_F(density_fixture, bernoulli_log_pdf_always_tail) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(0, 0.), 0.); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(1, 0.), std::numeric_limits::lowest()); +} + +TEST_F(density_fixture, bernoulli_log_pdf_always_head) +{ + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(0, 1.), std::numeric_limits::lowest()); + EXPECT_DOUBLE_EQ(BernoulliBase::log_pdf(1, 1.), 0.); +} + +} // namespace expr +} // namespace ppl diff --git a/test/expression/distribution/normal_unittest.cpp b/test/expression/distribution/normal_unittest.cpp new file mode 100644 index 00000000..27b12291 --- /dev/null +++ b/test/expression/distribution/normal_unittest.cpp @@ -0,0 +1,68 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +struct normal_fixture : ::testing::Test { +protected: + using value_t = typename MockVarExpr::value_t; + static constexpr size_t sample_size = 1000; + double mean = 0.1; + double stddev = 0.8; + MockVarExpr x{mean}; + MockVarExpr y{stddev}; + using norm_t = Normal; + norm_t norm = {x, y}; + std::array sample = {0.}; +}; + +TEST_F(normal_fixture, ctor) +{ + static_assert(util::is_dist_expr_v); +} + +TEST_F(normal_fixture, normal_check_params) { + EXPECT_DOUBLE_EQ(norm.mean(), static_cast(x)); + EXPECT_DOUBLE_EQ(norm.stddev(), static_cast(y)); +} + +TEST_F(normal_fixture, normal_pdf_delegate) { + EXPECT_DOUBLE_EQ(norm.pdf(-10.664), NormalBase::pdf(-10.664, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(-7.324), NormalBase::pdf(-7.324, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(-3.241), NormalBase::pdf(-3.241, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(-0.359288), NormalBase::pdf(-0.359288, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(0.12314), NormalBase::pdf(0.12314, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(3.145), NormalBase::pdf(3.145, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(6.000923), NormalBase::pdf(6.000923, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.pdf(16.423), NormalBase::pdf(16.423, mean, stddev)); +} + +TEST_F(normal_fixture, normal_log_pdf_delegate) { + EXPECT_DOUBLE_EQ(norm.log_pdf(-10.664), NormalBase::log_pdf(-10.664, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(-7.324), NormalBase::log_pdf(-7.324, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(-3.241), NormalBase::log_pdf(-3.241, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(-0.359288), NormalBase::log_pdf(-0.359288, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(0.12314), NormalBase::log_pdf(0.12314, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(3.145), NormalBase::log_pdf(3.145, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(6.000923), NormalBase::log_pdf(6.000923, mean, stddev)); + EXPECT_DOUBLE_EQ(norm.log_pdf(16.423), NormalBase::log_pdf(16.423, mean, stddev)); +} + +TEST_F(normal_fixture, normal_sample) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + + for (size_t i = 0; i < sample_size; i++) { + sample[i] = norm.sample(gen); + } + + plot_hist(sample); +} + +} // namespace expr +} // namespace ppl diff --git a/test/expression/distribution/uniform_unittest.cpp b/test/expression/distribution/uniform_unittest.cpp new file mode 100644 index 00000000..7177ffb0 --- /dev/null +++ b/test/expression/distribution/uniform_unittest.cpp @@ -0,0 +1,70 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +struct uniform_fixture : ::testing::Test { +protected: + using value_t = typename MockVarExpr::value_t; + static constexpr size_t sample_size = 1000; + double min = 0.1; + double max = 0.8; + MockVarExpr x{min}; + MockVarExpr y{max}; + using unif_t = Uniform; + unif_t unif = {x, y}; + std::array sample = {0.}; +}; + +TEST_F(uniform_fixture, ctor) +{ + static_assert(util::is_dist_expr_v); +} + +TEST_F(uniform_fixture, uniform_check_params) { + EXPECT_DOUBLE_EQ(unif.min(), static_cast(x)); + EXPECT_DOUBLE_EQ(unif.max(), static_cast(y)); +} + +TEST_F(uniform_fixture, uniform_pdf_delegate) { + EXPECT_DOUBLE_EQ(unif.pdf(-10.664), UniformBase::pdf(-10.664, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(-7.324), UniformBase::pdf(-7.324, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(-3.241), UniformBase::pdf(-3.241, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(-0.359288), UniformBase::pdf(-0.359288, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(0.12314), UniformBase::pdf(0.12314, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(3.145), UniformBase::pdf(3.145, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(6.000923), UniformBase::pdf(6.000923, min, max)); + EXPECT_DOUBLE_EQ(unif.pdf(16.423), UniformBase::pdf(16.423, min, max)); +} + +TEST_F(uniform_fixture, uniform_log_pdf_delegate) { + EXPECT_DOUBLE_EQ(unif.log_pdf(-10.664), UniformBase::log_pdf(-10.664, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(-7.324), UniformBase::log_pdf(-7.324, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(-3.241), UniformBase::log_pdf(-3.241, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(-0.359288), UniformBase::log_pdf(-0.359288, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(0.12314), UniformBase::log_pdf(0.12314, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(3.145), UniformBase::log_pdf(3.145, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(6.000923), UniformBase::log_pdf(6.000923, min, max)); + EXPECT_DOUBLE_EQ(unif.log_pdf(16.423), UniformBase::log_pdf(16.423, min, max)); +} + +TEST_F(uniform_fixture, uniform_sample) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + + for (size_t i = 0; i < sample_size; i++) { + sample[i] = unif.sample(gen); + EXPECT_GT(sample[i], min); + EXPECT_LT(sample[i], max); + } + + plot_hist(sample, 0.05, min, max); +} + +} // namespace expr +} // namespace ppl diff --git a/test/expression/model/model_unittest.cpp b/test/expression/model/model_unittest.cpp new file mode 100644 index 00000000..b8ab9625 --- /dev/null +++ b/test/expression/model/model_unittest.cpp @@ -0,0 +1,161 @@ +#include "gtest/gtest.h" +#include +#include +#include +#include + +namespace ppl { +namespace expr { + +////////////////////////////////////////////////////// +// Model with one RV TESTS +////////////////////////////////////////////////////// + +/* + * Fixture for testing one var with distribution. + */ +struct var_dist_fixture : ::testing::Test +{ +protected: + MockVar x; + using model_t = EqNode; + model_t model = {x, MockDistExpr()}; + double val; + + void reconfigure() + { x.set_value(val); } +}; + +TEST_F(var_dist_fixture, ctor) +{ + static_assert(util::is_model_expr_v); +} + +TEST_F(var_dist_fixture, pdf_valid) +{ + // MockDistExpr pdf is identity function + // so we may simply compare model.pdf() with val. + + val = 0.000001; + reconfigure(); + EXPECT_EQ(model.pdf(), val); + + val = 0.5; + reconfigure(); + EXPECT_EQ(model.pdf(), val); + + val = 0.999999; + reconfigure(); + EXPECT_EQ(model.pdf(), val); +} + +TEST_F(var_dist_fixture, log_pdf_valid) +{ + val = 0.000001; + reconfigure(); + EXPECT_EQ(model.log_pdf(), std::log(val)); + + val = 0.5; + reconfigure(); + EXPECT_EQ(model.log_pdf(), std::log(val)); + + val = 0.999999; + reconfigure(); + EXPECT_EQ(model.log_pdf(), std::log(val)); +} + +////////////////////////////////////////////////////// +// Model with many RV (no dependencies) TESTS +////////////////////////////////////////////////////// + +/* + * Fixture for testing many vars with distributions. + */ +struct many_var_dist_fixture : ::testing::Test +{ +protected: + MockVar x, y, z, w; + double xv, yv, zv, wv; + using eq_t = EqNode; + + using model_two_t = GlueNode; + model_two_t model_two = { + {x, MockDistExpr()}, + {y, MockDistExpr()} + }; + + using model_four_t = + GlueNode + > + >; + + model_four_t model_four = { + {x, MockDistExpr()}, + { + {y, MockDistExpr()}, + { + {z, MockDistExpr()}, + {w, MockDistExpr()} + } + } + }; +}; + +TEST_F(many_var_dist_fixture, ctor) +{ + static_assert(util::is_model_expr_v); + static_assert(util::is_model_expr_v); +} + +TEST_F(many_var_dist_fixture, two_vars_pdf) +{ + xv = 0.2; yv = 1.8; + + x.set_value(xv); + y.set_value(yv); + + EXPECT_EQ(model_two.pdf(), xv * yv); + EXPECT_EQ(model_two.log_pdf(), std::log(xv) + std::log(yv)); +} + +TEST_F(many_var_dist_fixture, four_vars_pdf) +{ + xv = 0.2; yv = 1.8; zv = 3.2; wv = 0.3; + + x.set_value(xv); + y.set_value(yv); + z.set_value(zv); + w.set_value(wv); + + EXPECT_EQ(model_four.pdf(), xv * yv * zv * wv); + EXPECT_EQ(model_four.log_pdf(), std::log(xv) + std::log(yv) + + std::log(zv) + std::log(wv)); +} + +TEST_F(many_var_dist_fixture, four_vars_traverse_count_params) +{ + int count = 0; + z.set_state(MockState::data); + model_four.traverse([&](auto& model) { + using var_t = std::decay_t; + using state_t = typename util::var_traits::state_t; + count += (model.get_variable().get_state() == state_t::parameter); + }); + EXPECT_EQ(count, 3); +} + +TEST_F(many_var_dist_fixture, four_vars_traverse_pdf) +{ + double actual = 1.; + model_four.traverse([&](auto& model) { + auto& var = model.get_variable(); + auto& dist = model.get_distribution(); + actual *= dist.pdf(var); + }); + EXPECT_EQ(actual, model_four.pdf()); +} + +} // namespace expr +} // namespace ppl diff --git a/test/expression/variable/constant_unittest.cpp b/test/expression/variable/constant_unittest.cpp new file mode 100644 index 00000000..6efdc71e --- /dev/null +++ b/test/expression/variable/constant_unittest.cpp @@ -0,0 +1,30 @@ +#include "gtest/gtest.h" +#include +#include +#include + +namespace ppl { +namespace expr { + +struct constant_fixture : ::testing::Test +{ +protected: + using value_t = double; + value_t c = 0.3; + Constant x{c}; +}; + +TEST_F(constant_fixture, ctor) +{ + static_assert(util::is_var_expr_v>); +} + +TEST_F(constant_fixture, convertible_value) +{ + EXPECT_EQ(static_cast(x), 0.3); + c = 3.41; + EXPECT_EQ(static_cast(x), 0.3); +} + +} // namespace expr +} // namespace ppl diff --git a/test/expression/variable/variable_viewer_unittest.cpp b/test/expression/variable/variable_viewer_unittest.cpp new file mode 100644 index 00000000..1324efed --- /dev/null +++ b/test/expression/variable/variable_viewer_unittest.cpp @@ -0,0 +1,32 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace expr { + +struct variable_viewer_fixture : ::testing::Test +{ +protected: + using value_t = typename MockVar::value_t; + MockVar var; + VariableViewer x = var; +}; + +TEST_F(variable_viewer_fixture, ctor) +{ + static_assert(util::is_var_expr_v>); +} + +TEST_F(variable_viewer_fixture, convertible_value) +{ + var.set_value(1.); + EXPECT_EQ(static_cast(x), 1.); + + // Tests if viewer correctly reflects any changes that happened in var. + var.set_value(-3.14); + EXPECT_EQ(static_cast(x), -3.14); +} + +} // namespace expr +} // namespace ppl diff --git a/test/test1.cpp b/test/test1.cpp deleted file mode 100644 index 876c350c..00000000 --- a/test/test1.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "gtest/gtest.h" -#include "autoppl.hpp" - -namespace { - -TEST(blaTest, test1) { - int n = ppl::fib(10); - EXPECT_EQ(n, 89); -} - -} diff --git a/test/testutil/mock_types.hpp b/test/testutil/mock_types.hpp new file mode 100644 index 00000000..fe4cbee1 --- /dev/null +++ b/test/testutil/mock_types.hpp @@ -0,0 +1,118 @@ +#pragma once +#include + +namespace ppl { + +/* + * Mock state class for testing purposes. + */ +enum class MockState { + data, + parameter +}; + +/* + * Mock Variable class that should meet the requirements + * of is_var_v. + */ +struct MockVar +{ + using value_t = double; + using pointer_t = double*; + using const_pointer_t = const double*; + using state_t = MockState; + + operator value_t() const {return x_;} + void set_value(value_t x) {x_ = x;} + value_t get_value() const {return x_;} + + void set_storage(pointer_t ptr) {ptr_ = ptr;} + + void set_state(state_t state) {state_ = state;} + state_t get_state() const {return state_;} + +private: + value_t x_ = 0.; + pointer_t ptr_ = nullptr; + state_t state_ = state_t::parameter; +}; + +/* + * Mock variable classes that fulfill + * var_traits requirements, but do not fit the rest. + */ +struct MockVar_no_convertible +{ + using value_t = double; + using pointer_t = double*; + using state_t = void; +}; + +/* + * Mock Variable Expression class that should meet the requirements + * of is_var_expr_v. + */ +struct MockVarExpr +{ + using value_t = double; + operator value_t() const {return x_;} + + /* not part of API */ + MockVarExpr(value_t x) + : x_{x} + {} + + void set_value(value_t x) {x_ = x;} +private: + value_t x_ = 0.; +}; + +/* + * Mock variable expression classes that fulfill + * var_expr_traits requirements, but do not fit the rest. + */ +struct MockVarExpr_no_convertible +{ + using value_t = double; +}; + +/* + * Mock distribution expression class that should meet the requirements + * of is_dist_expr_v. + */ +struct MockDistExpr +{ + using value_t = double; + using dist_value_t = double; + + dist_value_t pdf(value_t x) const + { return x; } + + dist_value_t log_pdf(value_t x) const + { return std::log(x); } +}; + +/* + * Mock distribution expression classes that fulfill + * dist_expr_traits requirements, but do not fit the rest. + */ +struct MockDistExpr_no_pdf : public MockDistExpr +{ +private: + using MockDistExpr::pdf; +}; + +struct MockDistExpr_no_log_pdf : public MockDistExpr +{ +private: + using MockDistExpr::log_pdf; +}; + +/* + * TODO: + * Mock model expression clases that should meet the + * requirements of is_model_expr_v. + * Additionally, MockEqNode should satisfy is_eq_node_expr_v. + */ + +} // namespace ppl diff --git a/test/testutil/sample_tools.hpp b/test/testutil/sample_tools.hpp new file mode 100644 index 00000000..68537d5c --- /dev/null +++ b/test/testutil/sample_tools.hpp @@ -0,0 +1,53 @@ +#pragma once +#include "gtest/gtest.h" +#include +#include +#include +#include +#include +#include + +namespace ppl { + +// Plotting utility to visualize histogram of samples. +template +inline void plot_hist(const ArrayType& arr, + double step_size = .5, + double min = std::numeric_limits::lowest(), + double max = std::numeric_limits::max() + ) +{ + constexpr size_t nstars = 100; // maximum number of stars to distribute + + min = (min == std::numeric_limits::lowest()) ? + *std::min_element(arr.begin(), arr.end()) : + min; + max = (max == std::numeric_limits::max()) ? + *std::max_element(arr.begin(), arr.end()) : + max; + const int64_t nearest_min = std::floor(min); + const int64_t nearest_max = std::floor(max) + 1; + const uint64_t range = nearest_max - nearest_min; + const uint64_t n_hist = std::floor(range/step_size); + + // keeps count for each histogram bar + std::vector counter(n_hist, 0); + + for (auto x : arr) { + if (nearest_min <= x && x <= nearest_max) { + ++counter[std::floor((x - nearest_min) / step_size)]; + } + } + + if ((min == *std::min_element(arr.begin(), arr.end())) && + (max == *std::max_element(arr.begin(), arr.end()))) { + EXPECT_EQ(std::accumulate(counter.begin(), counter.end(), 0), (int) arr.size()); + } + + for (size_t i = 0; i < n_hist; ++i) { + std::cout << i << "-" << (i+1) << ": " << '\t'; + std::cout << std::string(counter[i] * nstars/arr.size(), '*') << std::endl; + } +} + +} // namespace ppl diff --git a/test/util/concept_unittest.cpp b/test/util/concept_unittest.cpp new file mode 100644 index 00000000..d2fe4899 --- /dev/null +++ b/test/util/concept_unittest.cpp @@ -0,0 +1,91 @@ +#include "gtest/gtest.h" +#include + +namespace ppl { +namespace util { + +struct MockType +{}; + +struct MockType2 +{ + using value_t = double; + using pointer_t = double*; + using state_t = void; + + void pdf() {}; + void log_pdf() {}; +}; + +struct concept_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(concept_fixture, has_type_value_t_v_false) +{ + static_assert(!has_type_value_t_v); + static_assert(!has_type_value_t_v); + static_assert(!has_type_value_t_v); + static_assert(!has_type_value_t_v); +} + +TEST_F(concept_fixture, has_type_value_t_v_true) +{ + static_assert(has_type_value_t_v); +} + +TEST_F(concept_fixture, has_type_pointer_t_v_false) +{ + static_assert(!has_type_pointer_t_v); + static_assert(!has_type_pointer_t_v); + static_assert(!has_type_pointer_t_v); + static_assert(!has_type_pointer_t_v); +} + +TEST_F(concept_fixture, has_type_pointer_t_v_true) +{ + static_assert(has_type_pointer_t_v); +} + +TEST_F(concept_fixture, has_type_state_t_v_false) +{ + static_assert(!has_type_state_t_v); + static_assert(!has_type_state_t_v); + static_assert(!has_type_state_t_v); + static_assert(!has_type_state_t_v); +} + +TEST_F(concept_fixture, has_type_state_t_v_true) +{ + static_assert(has_type_state_t_v); +} + +TEST_F(concept_fixture, has_func_pdf_v_false) +{ + static_assert(!has_func_pdf_v); + static_assert(!has_func_pdf_v); + static_assert(!has_func_pdf_v); + static_assert(!has_func_pdf_v); +} + +TEST_F(concept_fixture, has_func_pdf_v_true) +{ + static_assert(has_func_pdf_v); +} + +TEST_F(concept_fixture, has_func_log_pdf_v_false) +{ + static_assert(!has_func_log_pdf_v); + static_assert(!has_func_log_pdf_v); + static_assert(!has_func_log_pdf_v); + static_assert(!has_func_log_pdf_v); +} + +TEST_F(concept_fixture, has_func_log_pdf_v_true) +{ + static_assert(has_func_log_pdf_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/dist_expr_traits_unittest.cpp b/test/util/dist_expr_traits_unittest.cpp new file mode 100644 index 00000000..3858405f --- /dev/null +++ b/test/util/dist_expr_traits_unittest.cpp @@ -0,0 +1,25 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct dist_expr_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(dist_expr_traits_fixture, is_dist_expr_v_true) +{ + static_assert(is_dist_expr_v); +} + +TEST_F(dist_expr_traits_fixture, is_dist_expr_v_false) +{ + static_assert(!is_dist_expr_v); + static_assert(!is_dist_expr_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/var_expr_traits_unittest.cpp b/test/util/var_expr_traits_unittest.cpp new file mode 100644 index 00000000..9ca6d875 --- /dev/null +++ b/test/util/var_expr_traits_unittest.cpp @@ -0,0 +1,24 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct var_expr_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(var_expr_traits_fixture, is_var_expr_v_true) +{ + static_assert(is_var_expr_v); +} + +TEST_F(var_expr_traits_fixture, is_var_expr_v_false) +{ + static_assert(!is_var_expr_v); +} + +} // namespace util +} // namespace ppl diff --git a/test/util/var_traits_unittest.cpp b/test/util/var_traits_unittest.cpp new file mode 100644 index 00000000..30823073 --- /dev/null +++ b/test/util/var_traits_unittest.cpp @@ -0,0 +1,24 @@ +#include "gtest/gtest.h" +#include +#include + +namespace ppl { +namespace util { + +struct var_traits_fixture : ::testing::Test +{ +protected: +}; + +TEST_F(var_traits_fixture, is_var_v_true) +{ + static_assert(is_var_v); +} + +TEST_F(var_traits_fixture, is_var_v_false) +{ + static_assert(!is_var_v); +} + +} // namespace util +} // namespace ppl