Skip to content

Commit

Permalink
Merge 04bd73f into d4ce9ae
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobaustin123 committed Apr 18, 2020
2 parents d4ce9ae + 04bd73f commit 20d62ce
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 86 deletions.
58 changes: 58 additions & 0 deletions include/autoppl/distribution/normal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#pragma once
#include <math.h>

#include <cassert>
#include <cmath>
#include <numeric>
#include <random>

namespace ppl {

// TODO: change name to NormalDist and make class template.
// normal should be a function that creates this kind of object.

template <typename mean_type, typename var_type>
struct Normal {
using value_t = double;
using dist_value_t = double;

static_assert(std::is_convertible_v<mean_type, value_t>);
static_assert(std::is_convertible_v<var_type, value_t>);

Normal(mean_type mean, var_type var)
: mean_{mean}, var_{var} {
assert(static_cast<value_t>(var_) > 0);
};

template <class GeneratorType>
value_t sample(GeneratorType& gen) const {
value_t mean, var;
mean = static_cast<value_t>(mean_);
var = static_cast<value_t>(var_);

std::normal_distribution<value_t> dist(mean, var);
return dist(gen);
}

dist_value_t pdf(value_t x) const {
value_t mean, var;
mean = static_cast<value_t>(mean_);
var = static_cast<value_t>(var_);

return std::exp(- 0.5 * std::pow(x - mean, 2) / var) / (std::sqrt(var * 2 * M_PI));
}

dist_value_t log_pdf(value_t x) const {
value_t mean, var;
mean = static_cast<value_t>(mean_);
var = static_cast<value_t>(var_);

return (-0.5 * std::pow(x - mean, 2) / var) - 0.5 * (std::log(var) + std::log(2) + std::log(M_PI));
}

private:
mean_type mean_;
var_type var_;
};

} // namespace ppl
File renamed without changes.
36 changes: 18 additions & 18 deletions include/autoppl/expression/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@ namespace ppl {
namespace details {

template <class Iter>
struct IdentityTagFunctor
struct IdentityVarFunctor
{
using value_t = typename std::iterator_traits<Iter>::value_type;
value_t& operator()(value_t& tag)
{ return tag; }
value_t& operator()(value_t& var)
{ return var; }
};

} // namespace details

/*
* This class represents a "node" in the model expression
* that relates a tag with a distribution.
* that relates a var with a distribution.
*/
template <class TagType, class DistType>
template <class VarType, class DistType>
struct EqNode
{
using tag_t = TagType;
using var_t = VarType;
using dist_t = DistType;
using dist_value_t = typename dist_traits<dist_t>::dist_value_t;

EqNode(const tag_t& tag,
EqNode(const var_t& var,
const dist_t& dist) noexcept
: orig_tag_cref_{tag}
: orig_var_cref_{var}
, dist_{dist}
{}

Expand All @@ -39,22 +39,22 @@ struct EqNode
* Assumes that underlying value has been assigned properly.
*/
dist_value_t pdf() const
{ return dist_.pdf(orig_tag_cref_.get().get_value()); }
{ return dist_.pdf(orig_var_cref_.get().get_value()); }

/*
* Compute log-pdf of underlying distribution with underlying value.
* Assumes that underlying value has been assigned properly.
*/
dist_value_t log_pdf() const
{ return dist_.log_pdf(orig_tag_cref_.get().get_value()); }
{ return dist_.log_pdf(orig_var_cref_.get().get_value()); }

private:
using tag_cref_t = std::reference_wrapper<const tag_t>;
using opt_tag_cref_t = std::optional<tag_cref_t>;
using var_cref_t = std::reference_wrapper<const var_t>;
using opt_var_cref_t = std::optional<var_cref_t>;

tag_cref_t orig_tag_cref_; // (const) reference of the original tag since
var_cref_t orig_var_cref_; // (const) reference of the original var since
// any configuration may be changed until right before update
dist_t dist_; // distribution associated with tag
dist_t dist_; // distribution associated with var
};

/*
Expand Down Expand Up @@ -104,14 +104,14 @@ struct GlueNode
// with concepts!

/*
* Builds an EqNode to associate tag with dist.
* Builds an EqNode to associate var with dist.
* Ex. x |= uniform(0,1)
*/
template <class TagType, class DistType>
constexpr inline auto operator|=(const TagType& tag,
template <class VarType, class DistType>
constexpr inline auto operator|=(const VarType& var,
const DistType& dist)
{
return EqNode<TagType, DistType>(tag, dist);
return EqNode<VarType, DistType>(var, dist);
}

/*
Expand Down
13 changes: 8 additions & 5 deletions include/autoppl/expression/traits.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <type_traits>

namespace ppl {

Expand All @@ -8,12 +9,14 @@ namespace ppl {
* Users should rely on these classes to grab member aliases.
*/

template <class TagType>
struct tag_traits
template <class VarType>
struct var_traits
{
using value_t = typename TagType::value_t;
using pointer_t = typename TagType::pointer_t;
using state_t = typename TagType::state_t;
using value_t = typename VarType::value_t;
using pointer_t = typename VarType::pointer_t;
using state_t = typename VarType::state_t;

static_assert(std::is_convertible_v<VarType, value_t>);
};

template <class DistType>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,45 @@
namespace ppl {

/*
* The possible states for a tag.
* By default, all tags should be considered as a parameter.
* The possible states for a var.
* By default, all vars should be considered as a parameter.
* TODO: maybe move in a different file?
*/
enum class tag_state : bool {
enum class var_state : bool {
data,
parameter
};

/*
* rv_tag is a light-weight structure that represents a univariate random variable.
* Variable is a light-weight structure that represents a univariate random variable.
* It acts as an intermediate layer of communication between
* a model expression and the users, who must supply storage of values associated with this tag.
* a model expression and the users, who must supply storage of values associated with this var.
*/
template <class ValueType>
struct rv_tag
struct Variable
{
using value_t = ValueType;
using pointer_t = value_t*;
using const_pointer_t = const value_t*;
using state_t = tag_state;
using state_t = var_state;

// constructors
rv_tag(value_t value,
Variable(value_t value,
pointer_t storage_ptr) noexcept
: value_{value}
, storage_ptr_{storage_ptr}
, state_{state_t::parameter}
{}

rv_tag(pointer_t storage_ptr) noexcept
: rv_tag(0, storage_ptr)
Variable(pointer_t storage_ptr) noexcept
: Variable(0, storage_ptr)
{}

rv_tag(value_t value) noexcept
: rv_tag(value, nullptr) {}
Variable(value_t value) noexcept
: Variable(value, nullptr) {}

rv_tag() noexcept
: rv_tag(0, nullptr)
Variable() noexcept
: Variable(0, nullptr)
{}

void set_value(value_t value) { value_ = value; }
Expand All @@ -58,7 +58,7 @@ struct rv_tag

/*
* Sets underlying value to "value".
* Additionally modifies the tag to be considered as data.
* Additionally modifies the var to be considered as data.
* Equivalent to calling set_value(value) then set_state(state).
*/
void observe(value_t value)
Expand All @@ -68,14 +68,14 @@ struct rv_tag
}

private:
value_t value_; // store value associated with tag
value_t value_; // store value associated with var
pointer_t storage_ptr_; // points to beginning of storage
// storage is assumed to be contiguous
state_t state_; // state to determine if data or param
};

// Useful aliases
using cont_rv_tag = rv_tag<double>; // continuous RV tag
using disc_rv_tag = rv_tag<int>; // discrete RV tag
using cont_var = Variable<double>; // continuous RV var
using disc_var = Variable<int>; // discrete RV var

} // namespace ppl
17 changes: 16 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ endif()

add_executable(expression_unittest
${CMAKE_CURRENT_SOURCE_DIR}/expression/model_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/expression/uniform_unittest.cpp
)
target_compile_options(expression_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(expression_unittest PRIVATE ${GTEST_DIR}/include)
Expand All @@ -20,3 +19,19 @@ if (AUTOPPL_ENABLE_TEST_COVERAGE)
endif()
target_link_libraries(expression_unittest gtest_main pthread ${PROJECT_NAME})
add_test(expression_unittest expression_unittest)

######################################################
# Distribution Test
######################################################

add_executable(distribution_unittest
${CMAKE_CURRENT_SOURCE_DIR}/distribution/uniform_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distribution/normal_unittest.cpp
)
target_compile_options(distribution_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(distribution_unittest PRIVATE ${GTEST_DIR}/include)
if (AUTOPPL_ENABLE_TEST_COVERAGE)
target_link_libraries(distribution_unittest gcov)
endif()
target_link_libraries(distribution_unittest gtest_main pthread ${PROJECT_NAME})
add_test(distribution_unittest distribution_unittest)
29 changes: 29 additions & 0 deletions test/distribution/normal_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <autoppl/distribution/normal.hpp>
#include <autoppl/expression/variable.hpp>

#include <cmath>
#include <array>

#include "gtest/gtest.h"

namespace ppl {

struct normal_dist_fixture : ::testing::Test {
protected:
Variable<double> mu {0.};
Variable<double> sigma {1.};
Normal<double, double> dist1 = Normal(0., 1.);
Normal<Variable<double>, Variable<double> > dist2 = Normal(mu, sigma);
};

TEST_F(normal_dist_fixture, simple_gaussian) {
EXPECT_DOUBLE_EQ(dist1.pdf(0.0), 0.3989422804014327);
EXPECT_DOUBLE_EQ(dist1.pdf(-0.5), 0.3520653267642995);
EXPECT_DOUBLE_EQ(dist1.pdf(4), 0.00013383022576488537);

EXPECT_DOUBLE_EQ(dist1.log_pdf(0.0), std::log(dist1.pdf(0.0)));
EXPECT_DOUBLE_EQ(dist1.log_pdf(-0.5), std::log(dist1.pdf(-0.5)));
EXPECT_DOUBLE_EQ(dist1.log_pdf(4), std::log(dist1.pdf(4)));
}

} // ppl
41 changes: 41 additions & 0 deletions test/distribution/uniform_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include <autoppl/distribution/uniform.hpp>
#include <autoppl/expression/variable.hpp>

#include <cmath>
#include <array>

#include "gtest/gtest.h"

namespace ppl {

struct uniform_dist_fixture : ::testing::Test {
protected:
Variable<double> x {0.5};
Variable<double> y {0.1};
Uniform<double, double> dist1 = Uniform(0., 1.);
Uniform<double, Variable<double> > dist2 = Uniform(0., x);
Uniform<Variable<double>, Variable<double> > dist3 = Uniform(y, x);
};

TEST_F(uniform_dist_fixture, simple_uniform) {
EXPECT_DOUBLE_EQ(dist1.pdf(1.1), 0.0);

EXPECT_DOUBLE_EQ(dist2.pdf(1.0), 0.0);
EXPECT_DOUBLE_EQ(dist2.pdf(0.25), 2.0);

EXPECT_DOUBLE_EQ(dist3.pdf(-0.1), 0.0);
EXPECT_DOUBLE_EQ(dist3.pdf(0.25), 2.5);
}

TEST_F(uniform_dist_fixture, uniform_sampling) {
std::random_device rd{};
std::mt19937 gen{rd()};

for (int i = 0; i < 100; i++) {
double sample = dist1.sample(gen);
EXPECT_GT(sample, 0.0);
EXPECT_LT(sample, 1.0);
}
}

} // ppl
Loading

0 comments on commit 20d62ce

Please sign in to comment.