Skip to content

Commit

Permalink
Merge pull request #4 from JamesYang007/rename_tag
Browse files Browse the repository at this point in the history
Renamed rv_tag to Variable, added Normal and Bernoulli distributions and updated tests.
  • Loading branch information
jacobaustin123 committed Apr 18, 2020
2 parents d4ce9ae + 67f9edf commit ca008a4
Show file tree
Hide file tree
Showing 12 changed files with 322 additions and 105 deletions.
48 changes: 48 additions & 0 deletions include/autoppl/distribution/bernoulli.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once
#include <cassert>
#include <random>
#include <cmath>
#include <numeric>

namespace ppl {

// TODO: change name to BernoulliDist and make class template.
// bernoulli should be a function that creates this kind of object.

template <typename p_type>
struct Bernoulli
{
using value_t = int;
using dist_value_t = double;

Bernoulli(p_type p)
: p_{p} { assert((this -> p() >= 0) && (this -> p() <= 1)); }

template <class GeneratorType>
value_t sample(GeneratorType& gen) const
{
std::bernoulli_distribution dist(p());
return dist(gen);
}

dist_value_t pdf(value_t x) const
{
if (x == 1) return p();
else if (x == 0) return 1. - p();
else return 0.0;
}

dist_value_t log_pdf(value_t x) const
{
if (x == 1) return std::log(p());
else if (x == 0) return std::log(1. - p());
else return std::numeric_limits<dist_value_t>::lowest();
}

inline dist_value_t p() const { return static_cast<dist_value_t>(p_); }

private:
p_type p_;
};

} // namespace ppl
49 changes: 49 additions & 0 deletions include/autoppl/distribution/normal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once
#include <math.h>

#include <cassert>
#include <cmath>
#include <numeric>
#include <random>

namespace ppl {

// TODO: change name to NormalDist and make class template.
// normal should be a function that creates this kind of object.

template <typename mean_type, typename var_type>
struct Normal {
using value_t = double;
using dist_value_t = double;

static_assert(std::is_convertible_v<mean_type, value_t>);
static_assert(std::is_convertible_v<var_type, value_t>);

Normal(mean_type mean, var_type var)
: mean_{mean}, var_{var} {
assert(this -> var() > 0);
};

template <class GeneratorType>
value_t sample(GeneratorType& gen) const {
std::normal_distribution<value_t> dist(mean(), var());
return dist(gen);
}

dist_value_t pdf(value_t x) const {
return std::exp(- 0.5 * std::pow(x - mean(), 2) / var()) / (std::sqrt(var() * 2 * M_PI));
}

dist_value_t log_pdf(value_t x) const {
return (-0.5 * std::pow(x - mean(), 2) / var()) - 0.5 * (std::log(var()) + std::log(2) + std::log(M_PI));
}

inline value_t mean() const { return static_cast<value_t>(mean_);}
inline value_t var() const { return static_cast<value_t>(var_);}

private:
mean_type mean_;
var_type var_;
};

} // namespace ppl
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,32 @@ struct Uniform
using dist_value_t = double;

Uniform(min_type min, max_type max)
: min_{min}, max_{max} { assert(static_cast<value_t>(min_) < static_cast<value_t>(max_)); }
: min_{min}, max_{max} { assert(this -> min() < this -> max()); }

// TODO: tag this class as "TriviallySamplable"?
template <class GeneratorType>
value_t sample(GeneratorType& gen) const
{
value_t min, max;
min = static_cast<value_t>(min_);
max = static_cast<value_t>(max_);

std::uniform_real_distribution<value_t> dist(min, max);
std::uniform_real_distribution<value_t> dist(min(), max());
return dist(gen);
}

dist_value_t pdf(value_t x) const
{
value_t min, max;
min = static_cast<value_t>(min_);
max = static_cast<value_t>(max_);

return (min < x && x < max) ? 1. / (max - min) : 0;
return (min() < x && x < max()) ? 1. / (max() - min()) : 0;
}

dist_value_t log_pdf(value_t x) const
{
value_t min, max;
min = static_cast<value_t>(min_);
max = static_cast<value_t>(max_);

return (min < x && x < max) ?
-std::log(max - min) :
return (min() < x && x < max()) ?
-std::log(max() - min()) :
std::numeric_limits<dist_value_t>::lowest();
}

private:
inline value_t min() const { return static_cast<value_t>(min_); }
inline value_t max() const { return static_cast<value_t>(max_); }

private:
min_type min_;
max_type max_;
};
Expand Down
36 changes: 17 additions & 19 deletions include/autoppl/expression/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@ namespace ppl {
namespace details {

template <class Iter>
struct IdentityTagFunctor
struct IdentityVarFunctor
{
using value_t = typename std::iterator_traits<Iter>::value_type;
value_t& operator()(value_t& tag)
{ return tag; }
value_t& operator()(value_t& var)
{ return var; }
};

} // namespace details

/*
* This class represents a "node" in the model expression
* that relates a tag with a distribution.
* that relates a var with a distribution.
*/
template <class TagType, class DistType>
template <class VarType, class DistType>
struct EqNode
{
using tag_t = TagType;
using var_t = VarType;
using dist_t = DistType;
using dist_value_t = typename dist_traits<dist_t>::dist_value_t;

EqNode(const tag_t& tag,
EqNode(const var_t& var,
const dist_t& dist) noexcept
: orig_tag_cref_{tag}
: orig_var_cref_{var}
, dist_{dist}
{}

Expand All @@ -39,22 +39,20 @@ struct EqNode
* Assumes that underlying value has been assigned properly.
*/
dist_value_t pdf() const
{ return dist_.pdf(orig_tag_cref_.get().get_value()); }
{ return dist_.pdf(orig_var_cref_.get().get_value()); }

/*
* Compute log-pdf of underlying distribution with underlying value.
* Assumes that underlying value has been assigned properly.
*/
dist_value_t log_pdf() const
{ return dist_.log_pdf(orig_tag_cref_.get().get_value()); }
{ return dist_.log_pdf(orig_var_cref_.get().get_value()); }

private:
using tag_cref_t = std::reference_wrapper<const tag_t>;
using opt_tag_cref_t = std::optional<tag_cref_t>;

tag_cref_t orig_tag_cref_; // (const) reference of the original tag since
using var_cref_t = std::reference_wrapper<const var_t>;
var_cref_t orig_var_cref_; // (const) reference of the original var since
// any configuration may be changed until right before update
dist_t dist_; // distribution associated with tag
dist_t dist_; // distribution associated with var
};

/*
Expand Down Expand Up @@ -104,14 +102,14 @@ struct GlueNode
// with concepts!

/*
* Builds an EqNode to associate tag with dist.
* Builds an EqNode to associate var with dist.
* Ex. x |= uniform(0,1)
*/
template <class TagType, class DistType>
constexpr inline auto operator|=(const TagType& tag,
template <class VarType, class DistType>
constexpr inline auto operator|=(const VarType& var,
const DistType& dist)
{
return EqNode<TagType, DistType>(tag, dist);
return EqNode<VarType, DistType>(var, dist);
}

/*
Expand Down
14 changes: 9 additions & 5 deletions include/autoppl/expression/traits.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#pragma once
#include <type_traits>

namespace ppl {

Expand All @@ -8,12 +9,15 @@ namespace ppl {
* Users should rely on these classes to grab member aliases.
*/

template <class TagType>
struct tag_traits
template <class VarType>
struct var_traits
{
using value_t = typename TagType::value_t;
using pointer_t = typename TagType::pointer_t;
using state_t = typename TagType::state_t;
using value_t = typename VarType::value_t;
using pointer_t = typename VarType::pointer_t;
using state_t = typename VarType::state_t;

// TODO may have to move this to a different class for compile-time checking
static_assert(std::is_convertible_v<VarType, value_t>);
};

template <class DistType>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,45 @@
namespace ppl {

/*
* The possible states for a tag.
* By default, all tags should be considered as a parameter.
* The possible states for a var.
* By default, all vars should be considered as a parameter.
* TODO: maybe move in a different file?
*/
enum class tag_state : bool {
enum class var_state : bool {
data,
parameter
};

/*
* rv_tag is a light-weight structure that represents a univariate random variable.
* Variable is a light-weight structure that represents a univariate random variable.
* It acts as an intermediate layer of communication between
* a model expression and the users, who must supply storage of values associated with this tag.
* a model expression and the users, who must supply storage of values associated with this var.
*/
template <class ValueType>
struct rv_tag
struct Variable
{
using value_t = ValueType;
using pointer_t = value_t*;
using const_pointer_t = const value_t*;
using state_t = tag_state;
using state_t = var_state;

// constructors
rv_tag(value_t value,
Variable(value_t value,
pointer_t storage_ptr) noexcept
: value_{value}
, storage_ptr_{storage_ptr}
, state_{state_t::parameter}
{}

rv_tag(pointer_t storage_ptr) noexcept
: rv_tag(0, storage_ptr)
Variable(pointer_t storage_ptr) noexcept
: Variable(0, storage_ptr)
{}

rv_tag(value_t value) noexcept
: rv_tag(value, nullptr) {}
Variable(value_t value) noexcept
: Variable(value, nullptr) {}

rv_tag() noexcept
: rv_tag(0, nullptr)
Variable() noexcept
: Variable(0, nullptr)
{}

void set_value(value_t value) { value_ = value; }
Expand All @@ -58,7 +58,7 @@ struct rv_tag

/*
* Sets underlying value to "value".
* Additionally modifies the tag to be considered as data.
* Additionally modifies the var to be considered as data.
* Equivalent to calling set_value(value) then set_state(state).
*/
void observe(value_t value)
Expand All @@ -68,14 +68,14 @@ struct rv_tag
}

private:
value_t value_; // store value associated with tag
value_t value_; // store value associated with var
pointer_t storage_ptr_; // points to beginning of storage
// storage is assumed to be contiguous
state_t state_; // state to determine if data or param
};

// Useful aliases
using cont_rv_tag = rv_tag<double>; // continuous RV tag
using disc_rv_tag = rv_tag<int>; // discrete RV tag
using cont_var = Variable<double>; // continuous RV var
using disc_var = Variable<int>; // discrete RV var

} // namespace ppl
18 changes: 17 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ endif()

add_executable(expression_unittest
${CMAKE_CURRENT_SOURCE_DIR}/expression/model_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/expression/uniform_unittest.cpp
)
target_compile_options(expression_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(expression_unittest PRIVATE ${GTEST_DIR}/include)
Expand All @@ -20,3 +19,20 @@ if (AUTOPPL_ENABLE_TEST_COVERAGE)
endif()
target_link_libraries(expression_unittest gtest_main pthread ${PROJECT_NAME})
add_test(expression_unittest expression_unittest)

######################################################
# Distribution Test
######################################################

add_executable(distribution_unittest
${CMAKE_CURRENT_SOURCE_DIR}/distribution/uniform_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distribution/normal_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distribution/bernoulli_unittest.cpp
)
target_compile_options(distribution_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(distribution_unittest PRIVATE ${GTEST_DIR}/include)
if (AUTOPPL_ENABLE_TEST_COVERAGE)
target_link_libraries(distribution_unittest gcov)
endif()
target_link_libraries(distribution_unittest gtest_main pthread ${PROJECT_NAME})
add_test(distribution_unittest distribution_unittest)
Loading

0 comments on commit ca008a4

Please sign in to comment.