Skip to content

Commit

Permalink
added normal distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobaustin123 committed Apr 18, 2020
1 parent a9f0a04 commit a371f33
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 6 deletions.
58 changes: 58 additions & 0 deletions include/autoppl/distribution/normal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#pragma once
#include <math.h>

#include <cassert>
#include <cmath>
#include <numeric>
#include <random>

namespace ppl {

// TODO: change name to NormalDist and make class template.
// normal should be a function that creates this kind of object.

template <typename mean_type, typename var_type>
struct Normal {
using value_t = double;
using dist_value_t = double;

static_assert(std::is_convertible_v<mean_type, value_t>);
static_assert(std::is_convertible_v<var_type, value_t>);

Normal(mean_type mean, var_type var)
: mean_{mean}, var_{var} {
assert(static_cast<value_t>(var_) > 0);
};

template <class GeneratorType>
value_t sample(GeneratorType& gen) const {
value_t mean, var;
mean = static_cast<value_t>(mean_);
var = static_cast<value_t>(var_);

std::normal_distribution<value_t> dist(mean, var);
return dist(gen);
}

dist_value_t pdf(value_t x) const {
value_t mean, var;
mean = static_cast<value_t>(mean_);
var = static_cast<value_t>(var_);

return std::exp(- 0.5 * std::pow(x - mean, 2) / var) / (std::sqrt(var * 2 * M_PI));
}

dist_value_t log_pdf(value_t x) const {
value_t mean, var;
mean = static_cast<value_t>(mean_);
var = static_cast<value_t>(var_);

return (-0.5 * std::pow(x - mean, 2) / var) - 0.5 * (std::log(var) + std::log(2) + std::log(M_PI));
}

private:
mean_type mean_;
var_type var_;
};

} // namespace ppl
3 changes: 2 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ add_test(expression_unittest expression_unittest)

add_executable(distribution_unittest
${CMAKE_CURRENT_SOURCE_DIR}/distribution/uniform_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distribution/normal_unittest.cpp
)
target_compile_options(distribution_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(distribution_unittest PRIVATE ${GTEST_DIR}/include)
if (AUTOPPL_ENABLE_TEST_COVERAGE)
target_link_libraries(distribution_unittest gcov)
endif()
target_link_libraries(distribution_unittest gtest_main pthread ${PROJECT_NAME})
add_test(distribution_unittest distribution_unittest)
add_test(distribution_unittest distribution_unittest)
29 changes: 29 additions & 0 deletions test/distribution/normal_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#include <autoppl/distribution/normal.hpp>
#include <autoppl/expression/variable.hpp>

#include <cmath>
#include <array>

#include "gtest/gtest.h"

namespace ppl {

struct normal_dist_fixture : ::testing::Test {
protected:
Variable<double> mu {0.};
Variable<double> sigma {1.};
Normal<double, double> dist1 = Normal(0., 1.);
Normal<Variable<double>, Variable<double> > dist2 = Normal(mu, sigma);
};

TEST_F(normal_dist_fixture, simple_gaussian) {
EXPECT_DOUBLE_EQ(dist1.pdf(0.0), 0.3989422804014327);
EXPECT_DOUBLE_EQ(dist1.pdf(-0.5), 0.3520653267642995);
EXPECT_DOUBLE_EQ(dist1.pdf(4), 0.00013383022576488537);

EXPECT_DOUBLE_EQ(dist1.log_pdf(0.0), std::log(dist1.pdf(0.0)));
EXPECT_DOUBLE_EQ(dist1.log_pdf(-0.5), std::log(dist1.pdf(-0.5)));
EXPECT_DOUBLE_EQ(dist1.log_pdf(4), std::log(dist1.pdf(4)));
}

} // ppl
10 changes: 5 additions & 5 deletions test/distribution/uniform_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ struct uniform_dist_fixture : ::testing::Test {
};

TEST_F(uniform_dist_fixture, simple_uniform) {
EXPECT_EQ(dist1.pdf(1.1), 0.0);
EXPECT_DOUBLE_EQ(dist1.pdf(1.1), 0.0);

EXPECT_EQ(dist2.pdf(1.0), 0.0);
EXPECT_EQ(dist2.pdf(0.25), 2.0);
EXPECT_DOUBLE_EQ(dist2.pdf(1.0), 0.0);
EXPECT_DOUBLE_EQ(dist2.pdf(0.25), 2.0);

EXPECT_EQ(dist3.pdf(-0.1), 0.0);
EXPECT_EQ(dist3.pdf(0.25), 2.5);
EXPECT_DOUBLE_EQ(dist3.pdf(-0.1), 0.0);
EXPECT_DOUBLE_EQ(dist3.pdf(0.25), 2.5);
}

} // ppl

0 comments on commit a371f33

Please sign in to comment.