From d8170ea5d9b14feffa9b8d155ad29e7f4211d7c1 Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 13:35:38 -0400 Subject: [PATCH 1/7] renamed var --- include/autoppl/expression/model.hpp | 36 +++++++++---------- include/autoppl/expression/traits.hpp | 13 ++++--- .../expression/{rv_tag.hpp => variable.hpp} | 36 +++++++++---------- test/CMakeLists.txt | 16 ++++++++- .../uniform_unittest.cpp | 13 +++---- test/expression/model_unittest.cpp | 30 ++++++++-------- 6 files changed, 81 insertions(+), 63 deletions(-) rename include/autoppl/expression/{rv_tag.hpp => variable.hpp} (66%) rename test/{expression => distribution}/uniform_unittest.cpp (63%) diff --git a/include/autoppl/expression/model.hpp b/include/autoppl/expression/model.hpp index ae236a4e..a9566e8f 100644 --- a/include/autoppl/expression/model.hpp +++ b/include/autoppl/expression/model.hpp @@ -8,29 +8,29 @@ namespace ppl { namespace details { template -struct IdentityTagFunctor +struct IdentityVarFunctor { using value_t = typename std::iterator_traits::value_type; - value_t& operator()(value_t& tag) - { return tag; } + value_t& operator()(value_t& var) + { return var; } }; } // namespace details /* * This class represents a "node" in the model expression - * that relates a tag with a distribution. + * that relates a var with a distribution. */ -template +template struct EqNode { - using tag_t = TagType; + using var_t = VarType; using dist_t = DistType; using dist_value_t = typename dist_traits::dist_value_t; - EqNode(const tag_t& tag, + EqNode(const var_t& var, const dist_t& dist) noexcept - : orig_tag_cref_{tag} + : orig_var_cref_{var} , dist_{dist} {} @@ -39,22 +39,22 @@ struct EqNode * Assumes that underlying value has been assigned properly. */ dist_value_t pdf() const - { return dist_.pdf(orig_tag_cref_.get().get_value()); } + { return dist_.pdf(orig_var_cref_.get().get_value()); } /* * Compute log-pdf of underlying distribution with underlying value. * Assumes that underlying value has been assigned properly. */ dist_value_t log_pdf() const - { return dist_.log_pdf(orig_tag_cref_.get().get_value()); } + { return dist_.log_pdf(orig_var_cref_.get().get_value()); } private: - using tag_cref_t = std::reference_wrapper; - using opt_tag_cref_t = std::optional; + using var_cref_t = std::reference_wrapper; + using opt_var_cref_t = std::optional; - tag_cref_t orig_tag_cref_; // (const) reference of the original tag since + var_cref_t orig_var_cref_; // (const) reference of the original var since // any configuration may be changed until right before update - dist_t dist_; // distribution associated with tag + dist_t dist_; // distribution associated with var }; /* @@ -104,14 +104,14 @@ struct GlueNode // with concepts! /* - * Builds an EqNode to associate tag with dist. + * Builds an EqNode to associate var with dist. * Ex. x |= uniform(0,1) */ -template -constexpr inline auto operator|=(const TagType& tag, +template +constexpr inline auto operator|=(const VarType& var, const DistType& dist) { - return EqNode(tag, dist); + return EqNode(var, dist); } /* diff --git a/include/autoppl/expression/traits.hpp b/include/autoppl/expression/traits.hpp index dce5efc6..bb4c83d2 100644 --- a/include/autoppl/expression/traits.hpp +++ b/include/autoppl/expression/traits.hpp @@ -1,4 +1,5 @@ #pragma once +#include namespace ppl { @@ -8,12 +9,14 @@ namespace ppl { * Users should rely on these classes to grab member aliases. */ -template -struct tag_traits +template +struct var_traits { - using value_t = typename TagType::value_t; - using pointer_t = typename TagType::pointer_t; - using state_t = typename TagType::state_t; + using value_t = typename VarType::value_t; + using pointer_t = typename VarType::pointer_t; + using state_t = typename VarType::state_t; + + static_assert(std::is_convertible_v); }; template diff --git a/include/autoppl/expression/rv_tag.hpp b/include/autoppl/expression/variable.hpp similarity index 66% rename from include/autoppl/expression/rv_tag.hpp rename to include/autoppl/expression/variable.hpp index 19063603..3c332e9f 100644 --- a/include/autoppl/expression/rv_tag.hpp +++ b/include/autoppl/expression/variable.hpp @@ -3,45 +3,45 @@ namespace ppl { /* - * The possible states for a tag. - * By default, all tags should be considered as a parameter. + * The possible states for a var. + * By default, all vars should be considered as a parameter. * TODO: maybe move in a different file? */ -enum class tag_state : bool { +enum class var_state : bool { data, parameter }; /* - * rv_tag is a light-weight structure that represents a univariate random variable. + * Variable is a light-weight structure that represents a univariate random variable. * It acts as an intermediate layer of communication between - * a model expression and the users, who must supply storage of values associated with this tag. + * a model expression and the users, who must supply storage of values associated with this var. */ template -struct rv_tag +struct Variable { using value_t = ValueType; using pointer_t = value_t*; using const_pointer_t = const value_t*; - using state_t = tag_state; + using state_t = var_state; // constructors - rv_tag(value_t value, + Variable(value_t value, pointer_t storage_ptr) noexcept : value_{value} , storage_ptr_{storage_ptr} , state_{state_t::parameter} {} - rv_tag(pointer_t storage_ptr) noexcept - : rv_tag(0, storage_ptr) + Variable(pointer_t storage_ptr) noexcept + : Variable(0, storage_ptr) {} - rv_tag(value_t value) noexcept - : rv_tag(value, nullptr) {} + Variable(value_t value) noexcept + : Variable(value, nullptr) {} - rv_tag() noexcept - : rv_tag(0, nullptr) + Variable() noexcept + : Variable(0, nullptr) {} void set_value(value_t value) { value_ = value; } @@ -58,7 +58,7 @@ struct rv_tag /* * Sets underlying value to "value". - * Additionally modifies the tag to be considered as data. + * Additionally modifies the var to be considered as data. * Equivalent to calling set_value(value) then set_state(state). */ void observe(value_t value) @@ -68,14 +68,14 @@ struct rv_tag } private: - value_t value_; // store value associated with tag + value_t value_; // store value associated with var pointer_t storage_ptr_; // points to beginning of storage // storage is assumed to be contiguous state_t state_; // state to determine if data or param }; // Useful aliases -using cont_rv_tag = rv_tag; // continuous RV tag -using disc_rv_tag = rv_tag; // discrete RV tag +using cont_var = Variable; // continuous RV var +using disc_var = Variable; // discrete RV var } // namespace ppl diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index be61815b..998e9576 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -11,7 +11,6 @@ endif() add_executable(expression_unittest ${CMAKE_CURRENT_SOURCE_DIR}/expression/model_unittest.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/expression/uniform_unittest.cpp ) target_compile_options(expression_unittest PRIVATE -g -Wall -Werror -Wextra) target_include_directories(expression_unittest PRIVATE ${GTEST_DIR}/include) @@ -20,3 +19,18 @@ if (AUTOPPL_ENABLE_TEST_COVERAGE) endif() target_link_libraries(expression_unittest gtest_main pthread ${PROJECT_NAME}) add_test(expression_unittest expression_unittest) + +###################################################### +# Distribution Test +###################################################### + +add_executable(distribution_unittest + ${CMAKE_CURRENT_SOURCE_DIR}/distribution/uniform_unittest.cpp + ) +target_compile_options(distribution_unittest PRIVATE -g -Wall -Werror -Wextra) +target_include_directories(distribution_unittest PRIVATE ${GTEST_DIR}/include) +if (AUTOPPL_ENABLE_TEST_COVERAGE) + target_link_libraries(distribution_unittest gcov) +endif() +target_link_libraries(distribution_unittest gtest_main pthread ${PROJECT_NAME}) +add_test(distribution_unittest distribution_unittest) \ No newline at end of file diff --git a/test/expression/uniform_unittest.cpp b/test/distribution/uniform_unittest.cpp similarity index 63% rename from test/expression/uniform_unittest.cpp rename to test/distribution/uniform_unittest.cpp index 99b9448c..f9f52a26 100644 --- a/test/expression/uniform_unittest.cpp +++ b/test/distribution/uniform_unittest.cpp @@ -1,6 +1,5 @@ #include -#include -#include +#include #include #include @@ -11,17 +10,19 @@ namespace ppl { struct uniform_dist_fixture : ::testing::Test { protected: - rv_tag x {0.5}; - rv_tag y {0.1}; + Variable x {0.5}; + Variable y {0.1}; Uniform dist1 = Uniform(0., 1.); - Uniform > dist2 = Uniform(0., x); - Uniform, rv_tag > dist3 = Uniform(y, x); + Uniform > dist2 = Uniform(0., x); + Uniform, Variable > dist3 = Uniform(y, x); }; TEST_F(uniform_dist_fixture, simple_uniform) { EXPECT_EQ(dist1.pdf(1.1), 0.0); + EXPECT_EQ(dist2.pdf(1.0), 0.0); EXPECT_EQ(dist2.pdf(0.25), 2.0); + EXPECT_EQ(dist3.pdf(-0.1), 0.0); EXPECT_EQ(dist3.pdf(0.25), 2.5); } diff --git a/test/expression/model_unittest.cpp b/test/expression/model_unittest.cpp index 7be5171f..9007b3ab 100644 --- a/test/expression/model_unittest.cpp +++ b/test/expression/model_unittest.cpp @@ -10,10 +10,10 @@ namespace ppl { ////////////////////////////////////////////////////// /* - * Mock tag object for testing purposes. - * Must meet some of the requirements of actual tag types. + * Mock var object for testing purposes. + * Must meet some of the requirements of actual var types. */ -struct MockTag +struct MockVar { using value_t = double; using pointer_t = double*; @@ -45,12 +45,12 @@ struct MockDist }; /* - * Fixture for testing one tag with distribution. + * Fixture for testing one var with distribution. */ -struct tag_dist_fixture : ::testing::Test +struct var_dist_fixture : ::testing::Test { protected: - MockTag x; + MockVar x; using model_t = std::decay_t; model_t model = (x |= MockDist()); double val; @@ -61,7 +61,7 @@ struct tag_dist_fixture : ::testing::Test } }; -TEST_F(tag_dist_fixture, pdf_valid) +TEST_F(var_dist_fixture, pdf_valid) { // MockDist pdf is identity function // so we may simply compare model.pdf() with val. @@ -79,7 +79,7 @@ TEST_F(tag_dist_fixture, pdf_valid) EXPECT_EQ(model.pdf(), val); } -TEST_F(tag_dist_fixture, pdf_invalid) +TEST_F(var_dist_fixture, pdf_invalid) { val = 0.000004123; reconfigure(val); @@ -98,7 +98,7 @@ TEST_F(tag_dist_fixture, pdf_invalid) EXPECT_EQ(model.pdf(), val); } -TEST_F(tag_dist_fixture, log_pdf_valid) +TEST_F(var_dist_fixture, log_pdf_valid) { val = 0.000001; reconfigure(val); @@ -113,7 +113,7 @@ TEST_F(tag_dist_fixture, log_pdf_valid) EXPECT_EQ(model.log_pdf(), std::log(val)); } -TEST_F(tag_dist_fixture, log_pdf_invalid) +TEST_F(var_dist_fixture, log_pdf_invalid) { val = 0.000004123; reconfigure(val); @@ -137,16 +137,16 @@ TEST_F(tag_dist_fixture, log_pdf_invalid) ////////////////////////////////////////////////////// /* - * Fixture for testing many tags with distributions. + * Fixture for testing many vars with distributions. */ -struct many_tag_dist_fixture : ::testing::Test +struct many_var_dist_fixture : ::testing::Test { protected: - MockTag x, y, z, w; + MockVar x, y, z, w; double xv, yv, zv, wv; }; -TEST_F(many_tag_dist_fixture, two_tags) +TEST_F(many_var_dist_fixture, two_vars) { auto model = ( x |= MockDist(), @@ -162,7 +162,7 @@ TEST_F(many_tag_dist_fixture, two_tags) EXPECT_EQ(model.log_pdf(), std::log(xv) + std::log(yv)); } -TEST_F(many_tag_dist_fixture, four_tags) +TEST_F(many_var_dist_fixture, four_vars) { auto model = ( x |= MockDist(), From a9f0a04a612cc07cd46b82412e8ab45f57de6c3e Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 13:38:58 -0400 Subject: [PATCH 2/7] moved distributions to a separate folder --- include/autoppl/{expression => distribution}/uniform.hpp | 0 test/distribution/uniform_unittest.cpp | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename include/autoppl/{expression => distribution}/uniform.hpp (100%) diff --git a/include/autoppl/expression/uniform.hpp b/include/autoppl/distribution/uniform.hpp similarity index 100% rename from include/autoppl/expression/uniform.hpp rename to include/autoppl/distribution/uniform.hpp diff --git a/test/distribution/uniform_unittest.cpp b/test/distribution/uniform_unittest.cpp index f9f52a26..f9b98e99 100644 --- a/test/distribution/uniform_unittest.cpp +++ b/test/distribution/uniform_unittest.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include From a371f33fbd3408213cfc8451684d487cfc43eb3a Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 14:09:26 -0400 Subject: [PATCH 3/7] added normal distributions --- include/autoppl/distribution/normal.hpp | 58 +++++++++++++++++++++++++ test/CMakeLists.txt | 3 +- test/distribution/normal_unittest.cpp | 29 +++++++++++++ test/distribution/uniform_unittest.cpp | 10 ++--- 4 files changed, 94 insertions(+), 6 deletions(-) create mode 100644 include/autoppl/distribution/normal.hpp create mode 100644 test/distribution/normal_unittest.cpp diff --git a/include/autoppl/distribution/normal.hpp b/include/autoppl/distribution/normal.hpp new file mode 100644 index 00000000..8e1e6181 --- /dev/null +++ b/include/autoppl/distribution/normal.hpp @@ -0,0 +1,58 @@ +#pragma once +#include + +#include +#include +#include +#include + +namespace ppl { + +// TODO: change name to NormalDist and make class template. +// normal should be a function that creates this kind of object. + +template +struct Normal { + using value_t = double; + using dist_value_t = double; + + static_assert(std::is_convertible_v); + static_assert(std::is_convertible_v); + + Normal(mean_type mean, var_type var) + : mean_{mean}, var_{var} { + assert(static_cast(var_) > 0); + }; + + template + value_t sample(GeneratorType& gen) const { + value_t mean, var; + mean = static_cast(mean_); + var = static_cast(var_); + + std::normal_distribution dist(mean, var); + return dist(gen); + } + + dist_value_t pdf(value_t x) const { + value_t mean, var; + mean = static_cast(mean_); + var = static_cast(var_); + + return std::exp(- 0.5 * std::pow(x - mean, 2) / var) / (std::sqrt(var * 2 * M_PI)); + } + + dist_value_t log_pdf(value_t x) const { + value_t mean, var; + mean = static_cast(mean_); + var = static_cast(var_); + + return (-0.5 * std::pow(x - mean, 2) / var) - 0.5 * (std::log(var) + std::log(2) + std::log(M_PI)); + } + + private: + mean_type mean_; + var_type var_; +}; + +} // namespace ppl diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 998e9576..7f93364b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -26,6 +26,7 @@ add_test(expression_unittest expression_unittest) add_executable(distribution_unittest ${CMAKE_CURRENT_SOURCE_DIR}/distribution/uniform_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distribution/normal_unittest.cpp ) target_compile_options(distribution_unittest PRIVATE -g -Wall -Werror -Wextra) target_include_directories(distribution_unittest PRIVATE ${GTEST_DIR}/include) @@ -33,4 +34,4 @@ if (AUTOPPL_ENABLE_TEST_COVERAGE) target_link_libraries(distribution_unittest gcov) endif() target_link_libraries(distribution_unittest gtest_main pthread ${PROJECT_NAME}) -add_test(distribution_unittest distribution_unittest) \ No newline at end of file +add_test(distribution_unittest distribution_unittest) diff --git a/test/distribution/normal_unittest.cpp b/test/distribution/normal_unittest.cpp new file mode 100644 index 00000000..7beb42b3 --- /dev/null +++ b/test/distribution/normal_unittest.cpp @@ -0,0 +1,29 @@ +#include +#include + +#include +#include + +#include "gtest/gtest.h" + +namespace ppl { + +struct normal_dist_fixture : ::testing::Test { +protected: + Variable mu {0.}; + Variable sigma {1.}; + Normal dist1 = Normal(0., 1.); + Normal, Variable > dist2 = Normal(mu, sigma); +}; + +TEST_F(normal_dist_fixture, simple_gaussian) { + EXPECT_DOUBLE_EQ(dist1.pdf(0.0), 0.3989422804014327); + EXPECT_DOUBLE_EQ(dist1.pdf(-0.5), 0.3520653267642995); + EXPECT_DOUBLE_EQ(dist1.pdf(4), 0.00013383022576488537); + + EXPECT_DOUBLE_EQ(dist1.log_pdf(0.0), std::log(dist1.pdf(0.0))); + EXPECT_DOUBLE_EQ(dist1.log_pdf(-0.5), std::log(dist1.pdf(-0.5))); + EXPECT_DOUBLE_EQ(dist1.log_pdf(4), std::log(dist1.pdf(4))); +} + +} // ppl \ No newline at end of file diff --git a/test/distribution/uniform_unittest.cpp b/test/distribution/uniform_unittest.cpp index f9b98e99..fa13c601 100644 --- a/test/distribution/uniform_unittest.cpp +++ b/test/distribution/uniform_unittest.cpp @@ -18,13 +18,13 @@ struct uniform_dist_fixture : ::testing::Test { }; TEST_F(uniform_dist_fixture, simple_uniform) { - EXPECT_EQ(dist1.pdf(1.1), 0.0); + EXPECT_DOUBLE_EQ(dist1.pdf(1.1), 0.0); - EXPECT_EQ(dist2.pdf(1.0), 0.0); - EXPECT_EQ(dist2.pdf(0.25), 2.0); + EXPECT_DOUBLE_EQ(dist2.pdf(1.0), 0.0); + EXPECT_DOUBLE_EQ(dist2.pdf(0.25), 2.0); - EXPECT_EQ(dist3.pdf(-0.1), 0.0); - EXPECT_EQ(dist3.pdf(0.25), 2.5); + EXPECT_DOUBLE_EQ(dist3.pdf(-0.1), 0.0); + EXPECT_DOUBLE_EQ(dist3.pdf(0.25), 2.5); } } // ppl \ No newline at end of file From 04bd73f61c689e7588a617d5485d5c9a7a2131ad Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 14:17:39 -0400 Subject: [PATCH 4/7] added simple sampling test --- test/distribution/uniform_unittest.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/distribution/uniform_unittest.cpp b/test/distribution/uniform_unittest.cpp index fa13c601..e7366b65 100644 --- a/test/distribution/uniform_unittest.cpp +++ b/test/distribution/uniform_unittest.cpp @@ -27,4 +27,15 @@ TEST_F(uniform_dist_fixture, simple_uniform) { EXPECT_DOUBLE_EQ(dist3.pdf(0.25), 2.5); } +TEST_F(uniform_dist_fixture, uniform_sampling) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + + for (int i = 0; i < 100; i++) { + double sample = dist1.sample(gen); + EXPECT_GT(sample, 0.0); + EXPECT_LT(sample, 1.0); + } +} + } // ppl \ No newline at end of file From bf549b54db05c389d4ed30aa05d1aa79bd5c6077 Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 14:38:16 -0400 Subject: [PATCH 5/7] improved tests and distributions --- include/autoppl/distribution/normal.hpp | 23 ++++++-------------- include/autoppl/distribution/uniform.hpp | 27 ++++++++---------------- test/distribution/normal_unittest.cpp | 8 +++++++ test/distribution/uniform_unittest.cpp | 15 +++++++++++-- 4 files changed, 37 insertions(+), 36 deletions(-) diff --git a/include/autoppl/distribution/normal.hpp b/include/autoppl/distribution/normal.hpp index 8e1e6181..ed023f19 100644 --- a/include/autoppl/distribution/normal.hpp +++ b/include/autoppl/distribution/normal.hpp @@ -21,35 +21,26 @@ struct Normal { Normal(mean_type mean, var_type var) : mean_{mean}, var_{var} { - assert(static_cast(var_) > 0); + assert(this -> var() > 0); }; template value_t sample(GeneratorType& gen) const { - value_t mean, var; - mean = static_cast(mean_); - var = static_cast(var_); - - std::normal_distribution dist(mean, var); + std::normal_distribution dist(mean(), var()); return dist(gen); } dist_value_t pdf(value_t x) const { - value_t mean, var; - mean = static_cast(mean_); - var = static_cast(var_); - - return std::exp(- 0.5 * std::pow(x - mean, 2) / var) / (std::sqrt(var * 2 * M_PI)); + return std::exp(- 0.5 * std::pow(x - mean(), 2) / var()) / (std::sqrt(var() * 2 * M_PI)); } dist_value_t log_pdf(value_t x) const { - value_t mean, var; - mean = static_cast(mean_); - var = static_cast(var_); - - return (-0.5 * std::pow(x - mean, 2) / var) - 0.5 * (std::log(var) + std::log(2) + std::log(M_PI)); + return (-0.5 * std::pow(x - mean(), 2) / var()) - 0.5 * (std::log(var()) + std::log(2) + std::log(M_PI)); } + inline value_t mean() const { return static_cast(mean_);} + inline value_t var() const { return static_cast(var_);} + private: mean_type mean_; var_type var_; diff --git a/include/autoppl/distribution/uniform.hpp b/include/autoppl/distribution/uniform.hpp index 15547ace..4fb75911 100644 --- a/include/autoppl/distribution/uniform.hpp +++ b/include/autoppl/distribution/uniform.hpp @@ -16,41 +16,32 @@ struct Uniform using dist_value_t = double; Uniform(min_type min, max_type max) - : min_{min}, max_{max} { assert(static_cast(min_) < static_cast(max_)); } + : min_{min}, max_{max} { assert(this -> min() < this -> max()); } // TODO: tag this class as "TriviallySamplable"? template value_t sample(GeneratorType& gen) const { - value_t min, max; - min = static_cast(min_); - max = static_cast(max_); - - std::uniform_real_distribution dist(min, max); + std::uniform_real_distribution dist(min(), max()); return dist(gen); } dist_value_t pdf(value_t x) const { - value_t min, max; - min = static_cast(min_); - max = static_cast(max_); - - return (min < x && x < max) ? 1. / (max - min) : 0; + return (min() < x && x < max()) ? 1. / (max() - min()) : 0; } dist_value_t log_pdf(value_t x) const { - value_t min, max; - min = static_cast(min_); - max = static_cast(max_); - - return (min < x && x < max) ? - -std::log(max - min) : + return (min() < x && x < max()) ? + -std::log(max() - min()) : std::numeric_limits::lowest(); } -private: + inline value_t min() const { return static_cast(min_); } + inline value_t max() const { return static_cast(max_); } + + private: min_type min_; max_type max_; }; diff --git a/test/distribution/normal_unittest.cpp b/test/distribution/normal_unittest.cpp index 7beb42b3..939bac7d 100644 --- a/test/distribution/normal_unittest.cpp +++ b/test/distribution/normal_unittest.cpp @@ -16,6 +16,14 @@ struct normal_dist_fixture : ::testing::Test { Normal, Variable > dist2 = Normal(mu, sigma); }; +TEST_F(normal_dist_fixture, sanity_normal_test) { + EXPECT_EQ(dist1.mean(), 0.0); + EXPECT_EQ(dist1.var(), 1.0); + + EXPECT_EQ(dist2.mean(), 0.0); + EXPECT_EQ(dist2.var(), 1.0); +} + TEST_F(normal_dist_fixture, simple_gaussian) { EXPECT_DOUBLE_EQ(dist1.pdf(0.0), 0.3989422804014327); EXPECT_DOUBLE_EQ(dist1.pdf(-0.5), 0.3520653267642995); diff --git a/test/distribution/uniform_unittest.cpp b/test/distribution/uniform_unittest.cpp index e7366b65..3007a3c1 100644 --- a/test/distribution/uniform_unittest.cpp +++ b/test/distribution/uniform_unittest.cpp @@ -17,13 +17,24 @@ struct uniform_dist_fixture : ::testing::Test { Uniform, Variable > dist3 = Uniform(y, x); }; +TEST_F(uniform_dist_fixture, sanity_uniform_test) { + EXPECT_EQ(dist1.min(), 0.0); + EXPECT_EQ(dist1.max(), 1.0); + + EXPECT_EQ(dist2.min(), 0.0); + EXPECT_EQ(dist2.max(), 0.5); + + EXPECT_EQ(dist3.min(), 0.1); + EXPECT_EQ(dist3.max(), 0.5); +} + TEST_F(uniform_dist_fixture, simple_uniform) { EXPECT_DOUBLE_EQ(dist1.pdf(1.1), 0.0); + EXPECT_DOUBLE_EQ(dist1.pdf(1.0), 0.0); - EXPECT_DOUBLE_EQ(dist2.pdf(1.0), 0.0); EXPECT_DOUBLE_EQ(dist2.pdf(0.25), 2.0); + EXPECT_DOUBLE_EQ(dist2.pdf(-0.1), 0.0); - EXPECT_DOUBLE_EQ(dist3.pdf(-0.1), 0.0); EXPECT_DOUBLE_EQ(dist3.pdf(0.25), 2.5); } From 1ee259dad93e6084ce0c895c35290516c09e3029 Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 14:56:54 -0400 Subject: [PATCH 6/7] added bernoulli distribution --- include/autoppl/distribution/bernoulli.hpp | 48 ++++++++++++++++++++++ include/autoppl/expression/model.hpp | 4 +- test/CMakeLists.txt | 1 + test/distribution/bernoulli_unittest.cpp | 42 +++++++++++++++++++ 4 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 include/autoppl/distribution/bernoulli.hpp create mode 100644 test/distribution/bernoulli_unittest.cpp diff --git a/include/autoppl/distribution/bernoulli.hpp b/include/autoppl/distribution/bernoulli.hpp new file mode 100644 index 00000000..f1bc7bd1 --- /dev/null +++ b/include/autoppl/distribution/bernoulli.hpp @@ -0,0 +1,48 @@ +#pragma once +#include +#include +#include +#include + +namespace ppl { + +// TODO: change name to BernoulliDist and make class template. +// bernoulli should be a function that creates this kind of object. + +template +struct Bernoulli +{ + using value_t = int; + using dist_value_t = double; + + Bernoulli(p_type p) + : p_{p} { assert((this -> p() >= 0) && (this -> p() <= 1)); } + + template + value_t sample(GeneratorType& gen) const + { + std::bernoulli_distribution dist(p()); + return dist(gen); + } + + dist_value_t pdf(value_t x) const + { + if (x == 1) return p(); + else if (x == 0) return 1. - p(); + else return 0.0; + } + + dist_value_t log_pdf(value_t x) const + { + if (x == 1) return std::log(p()); + else if (x == 0) return std::log(1. - p()); + else return std::numeric_limits::lowest(); + } + + inline dist_value_t p() const { return static_cast(p_); } + + private: + p_type p_; +}; + +} // namespace ppl diff --git a/include/autoppl/expression/model.hpp b/include/autoppl/expression/model.hpp index a9566e8f..776f9f22 100644 --- a/include/autoppl/expression/model.hpp +++ b/include/autoppl/expression/model.hpp @@ -49,9 +49,7 @@ struct EqNode { return dist_.log_pdf(orig_var_cref_.get().get_value()); } private: - using var_cref_t = std::reference_wrapper; - using opt_var_cref_t = std::optional; - + using var_cref_t = std::reference_wrapper; var_cref_t orig_var_cref_; // (const) reference of the original var since // any configuration may be changed until right before update dist_t dist_; // distribution associated with var diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7f93364b..670e3d9a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -27,6 +27,7 @@ add_test(expression_unittest expression_unittest) add_executable(distribution_unittest ${CMAKE_CURRENT_SOURCE_DIR}/distribution/uniform_unittest.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distribution/normal_unittest.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distribution/bernoulli_unittest.cpp ) target_compile_options(distribution_unittest PRIVATE -g -Wall -Werror -Wextra) target_include_directories(distribution_unittest PRIVATE ${GTEST_DIR}/include) diff --git a/test/distribution/bernoulli_unittest.cpp b/test/distribution/bernoulli_unittest.cpp new file mode 100644 index 00000000..d8ea2978 --- /dev/null +++ b/test/distribution/bernoulli_unittest.cpp @@ -0,0 +1,42 @@ +#include +#include + +#include +#include + +#include "gtest/gtest.h" + +namespace ppl { + +struct bernoulli_dist_fixture : ::testing::Test { +protected: + Variable x {0.6}; + + Bernoulli dist1 = Bernoulli(0.6); + Bernoulli > dist2 = Bernoulli(x); +}; + +TEST_F(bernoulli_dist_fixture, sanity_bernoulli_test) { + EXPECT_EQ(dist1.p(), 0.6); + EXPECT_EQ(dist2.p(), 0.6); +} + +TEST_F(bernoulli_dist_fixture, simple_bernoulli) { + EXPECT_DOUBLE_EQ(dist1.pdf(1), dist1.p()); + EXPECT_DOUBLE_EQ(dist1.pdf(1), 0.6); + EXPECT_DOUBLE_EQ(dist1.pdf(0), 1 - dist1.p()); + EXPECT_DOUBLE_EQ(dist1.pdf(0), 0.4); + EXPECT_DOUBLE_EQ(dist1.pdf(2), 0.0); +} + +TEST_F(bernoulli_dist_fixture, bernoulli_sampling) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + + for (int i = 0; i < 100; i++) { + int sample = dist1.sample(gen); + EXPECT_TRUE(sample == 0 || sample == 1); + } +} + +} // ppl \ No newline at end of file From 67f9edfc276aab02ed930366b57c6f561448faca Mon Sep 17 00:00:00 2001 From: Jacob Austin Date: Sat, 18 Apr 2020 14:58:23 -0400 Subject: [PATCH 7/7] addressed review comments --- include/autoppl/expression/traits.hpp | 1 + test/distribution/normal_unittest.cpp | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/include/autoppl/expression/traits.hpp b/include/autoppl/expression/traits.hpp index bb4c83d2..9d39d3d9 100644 --- a/include/autoppl/expression/traits.hpp +++ b/include/autoppl/expression/traits.hpp @@ -16,6 +16,7 @@ struct var_traits using pointer_t = typename VarType::pointer_t; using state_t = typename VarType::state_t; + // TODO may have to move this to a different class for compile-time checking static_assert(std::is_convertible_v); }; diff --git a/test/distribution/normal_unittest.cpp b/test/distribution/normal_unittest.cpp index 939bac7d..9b229a8c 100644 --- a/test/distribution/normal_unittest.cpp +++ b/test/distribution/normal_unittest.cpp @@ -32,6 +32,15 @@ TEST_F(normal_dist_fixture, simple_gaussian) { EXPECT_DOUBLE_EQ(dist1.log_pdf(0.0), std::log(dist1.pdf(0.0))); EXPECT_DOUBLE_EQ(dist1.log_pdf(-0.5), std::log(dist1.pdf(-0.5))); EXPECT_DOUBLE_EQ(dist1.log_pdf(4), std::log(dist1.pdf(4))); + + + EXPECT_DOUBLE_EQ(dist2.pdf(0.0), 0.3989422804014327); + EXPECT_DOUBLE_EQ(dist2.pdf(-0.5), 0.3520653267642995); + EXPECT_DOUBLE_EQ(dist2.pdf(4), 0.00013383022576488537); + + EXPECT_DOUBLE_EQ(dist2.log_pdf(0.0), std::log(dist2.pdf(0.0))); + EXPECT_DOUBLE_EQ(dist2.log_pdf(-0.5), std::log(dist2.pdf(-0.5))); + EXPECT_DOUBLE_EQ(dist2.log_pdf(4), std::log(dist1.pdf(4))); } } // ppl \ No newline at end of file