Skip to content

Commit

Permalink
Merge 4771a1b into bad3f55
Browse files Browse the repository at this point in the history
  • Loading branch information
jenchen1398 authored Apr 21, 2020
2 parents bad3f55 + 4771a1b commit aef7469
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
48 changes: 48 additions & 0 deletions include/autoppl/distribution/discrete.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#pragma once
#include <cassert>
#include <random>
#include <cmath>
#include <numeric>

namespace ppl {

// TODO: change name to DiscreteDist and make class template.
// Discrete should be a function that creates this kind of object.

template <typename weight_type>
struct Discrete
{
using value_t = int;
using dist_value_t = double;

Discrete(std::initializer_list<weight_type> weights)
: weights_{weights} { assert(weights.size() > 0); }

template <class GeneratorType>
value_t sample(GeneratorType& gen) const
{
std::discrete_distribution dist(weights_.begin(), weights_.end());
return dist(gen);
}

dist_value_t pdf(value_t i) const
{
assert( i >= 0 && i < (int) weights_.size() );
return weights(i) / std::accumulate(weights_.begin(), weights_.end(), 0.0 );

}

dist_value_t log_pdf(value_t i) const
{
assert( i >= 0 && i < (int) weights_.size() );
return log(weights(i) / std::accumulate(weights_.begin(), weights_.end(), 0.0 ));
}

inline dist_value_t weights(value_t i) const { return static_cast<dist_value_t>(weights_[i]); }

private:
std::vector<weight_type> weights_;
};

} // namespace ppl

1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ add_executable(dist_expr_unittest
${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/density_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/normal_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/expression/distribution/uniform_unittest.cpp
${CMAKE_CURRENT_SOURCE_DIR}/distribution/discrete_unittest.cpp
)
target_compile_options(dist_expr_unittest PRIVATE -g -Wall -Werror -Wextra)
target_include_directories(dist_expr_unittest PRIVATE
Expand Down
44 changes: 44 additions & 0 deletions test/distribution/discrete_unittest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <autoppl/distribution/discrete.hpp>

#include <cmath>
#include <array>

#include "gtest/gtest.h"

namespace ppl {
namespace dist {

struct discrete_dist_fixture : ::testing::Test {
protected:
std::vector<double> weights {1.0, 2.0, 3.0, 4.0};
Discrete<double> dist1 = {1.0, 2.0, 3.0, 4.0};
};

TEST_F(discrete_dist_fixture, sanity_Discrete_test) {
EXPECT_DOUBLE_EQ(dist1.weights(0), 1.0);
EXPECT_DOUBLE_EQ(dist1.weights(1), 2.0);
EXPECT_DOUBLE_EQ(dist1.weights(2), 3.0);
EXPECT_DOUBLE_EQ(dist1.weights(3), 4.0);
}

TEST_F(discrete_dist_fixture, simple_Discrete) {
EXPECT_DOUBLE_EQ(dist1.pdf(0), dist1.weights(0) / 10.0);
EXPECT_DOUBLE_EQ(dist1.pdf(1), dist1.weights(1) / 10.0);
EXPECT_DOUBLE_EQ(dist1.pdf(2), dist1.weights(2) / 10.0);
EXPECT_DOUBLE_EQ(dist1.pdf(3), dist1.weights(3) / 10.0);
// std::accumulate(weights.begin(), weights.end())

}

TEST_F(discrete_dist_fixture, Discrete_sampling) {
std::random_device rd{};
std::mt19937 gen{rd()};

for (int i = 0; i < 100; i++) {
int sample = dist1.sample(gen);
EXPECT_TRUE(sample == 0 || sample == 1 || sample == 2 || sample == 3);
}
}

} // namespace dist
} // ppl

0 comments on commit aef7469

Please sign in to comment.