Skip to content

Commit

Permalink
Extend sample view to sample from random distributions (#296)
Browse files Browse the repository at this point in the history
This saves lines of code by extending the sample view implementation to
convert any random distribution (a callable that takes a URNG as an
input argument) into a range.

It also makes the implementation easier to follow by removing the use of
default arguments (which play poorly with overloads).

Related to #279.

Signed-off-by: Nahuel Espinosa <nespinosa@ekumenlabs.com>
  • Loading branch information
nahueespinosa committed Jan 27, 2024
1 parent f138eb5 commit a819c82
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 52 deletions.
143 changes: 95 additions & 48 deletions beluga/include/beluga/views/sample.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <range/v3/utility/random.hpp>
#include <range/v3/view/common.hpp>
#include <range/v3/view/generate.hpp>

#include <beluga/type_traits/particle_traits.hpp>
#include <beluga/views/particles.hpp>
Expand Down Expand Up @@ -112,69 +113,112 @@ struct sample_view : public ranges::view_facade<sample_view<Range, Distribution,
URNG* engine_;
};

/// Implementation detail for a sample range adaptor object.
struct sample_fn {
/// Overload that implements the sample algorithm for weighted ranges.
/**
* It uses std::discrete_distribution to sample from the range.
*/
template <
class Range,
class Weights,
class URNG = typename ranges::detail::default_random_engine,
std::enable_if_t<ranges::range<Range> && ranges::range<Weights>, int> = 0>
constexpr auto operator()(Range&& range, Weights&& weights, URNG& engine = ranges::detail::get_random_engine())
const {
/// \cond

template <class T, class Enable = void>
struct is_random_distribution : public std::false_type {};

template <class T>
struct is_random_distribution<T, std::void_t<decltype(std::declval<T&>()(std::declval<std::mt19937&>()))>>
: std::true_type {};

template <class T>
inline constexpr bool is_random_distribution_v = is_random_distribution<T>::value;

/// \endcond

/// Implementation detail for a sample algorithm.
struct sample_base_fn {
protected:
/// Sample from weighted ranges.
template <class Range, class Weights, class URNG>
constexpr auto sample_from_range(Range&& range, Weights&& weights, URNG& engine) const {
static_assert(ranges::sized_range<Range>);
static_assert(ranges::random_access_range<Range>);
static_assert(ranges::input_range<Weights>);
using result_type = ranges::range_difference_t<Range>;
auto w = ranges::views::common(weights);
auto distribution = std::discrete_distribution<result_type>{ranges::begin(w), ranges::end(w)};
return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
}

/// Overload that implements the sample algorithm for non-weighted ranges.
/// Sample from any range.
/**
* It uses std::uniform_int_distribution to sample from the range.
* If the input range is a particle range, it will extract the weights and treat it as a weighted range.
* The new particles will all have a weight equal to 1, since, after resampling, the probability will be
* represented by the number of particles rather than their individual weight.
*
* If the input range is not a particle range, it will assume a uniform distribution.
*/
template <
class Range,
class URNG = typename ranges::detail::default_random_engine,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<!is_particle_range_v<Range>, int> = 0,
std::enable_if_t<!ranges::range<URNG>, int> = 0>
constexpr auto operator()(Range&& range, URNG& engine = ranges::detail::get_random_engine()) const {
using result_type = ranges::range_difference_t<Range>;
auto distribution =
std::uniform_int_distribution<result_type>{0, static_cast<result_type>(ranges::size(range) - 1)};
return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
template <class Range, class URNG>
constexpr auto sample_from_range(Range&& range, URNG& engine) const {
static_assert(ranges::sized_range<Range>);
static_assert(ranges::random_access_range<Range>);
if constexpr (beluga::is_particle_range_v<Range>) {
return sample_from_range(beluga::views::states(range), beluga::views::weights(range), engine) |
ranges::views::transform(beluga::make_from_state<ranges::range_value_t<Range>>);
} else {
using result_type = ranges::range_difference_t<Range>;
auto distribution =
std::uniform_int_distribution<result_type>{0, static_cast<result_type>(ranges::size(range) - 1)};
return sample_view{ranges::views::all(std::forward<Range>(range)), std::move(distribution), engine};
}
}

/// Overload that handles particle ranges.
/**
* The new particles will all have a weight equal to 1, since, after resampling, the probability
* will be represented by the number of particles rather than their individual weight.
*/
template <
class Range,
class URNG = typename ranges::detail::default_random_engine,
std::enable_if_t<ranges::range<Range>, int> = 0,
std::enable_if_t<is_particle_range_v<Range>, int> = 0,
std::enable_if_t<!ranges::range<URNG>, int> = 0>
constexpr auto operator()(Range&& range, URNG& engine = ranges::detail::get_random_engine()) const {
return (*this)(beluga::views::states(range), beluga::views::weights(range), engine) |
ranges::views::transform(beluga::make_from_state<ranges::range_value_t<Range>>);
/// Sample from random distributions.
template <class Distribution, class URNG>
constexpr auto sample_from_distribution(Distribution distribution, URNG& engine) const {
return ranges::views::generate(
[distribution = std::move(distribution), &engine]() mutable { return distribution(engine); });
}
};

/// Overload that unwraps the engine reference from a view closure.
template <class Range, class URNG, typename std::enable_if_t<ranges::range<Range>, int> = 0>
constexpr auto operator()(Range&& range, std::reference_wrapper<URNG> engine) const {
return (*this)(ranges::views::all(std::forward<Range>(range)), engine.get());
/// Implementation detail for a sample range adaptor object.
struct sample_fn : public sample_base_fn {
/// Overload that takes three arguments.
template <class T, class U, class V>
constexpr auto operator()(T&& t, U&& u, V& v) const {
static_assert(ranges::range<T>);
static_assert(ranges::range<U>);
return sample_from_range(std::forward<T>(t), std::forward<U>(u), v); // Assume V is a URNG
}

/// Overload that takes two arguments.
template <class T, class U>
constexpr auto operator()(T&& t, U&& u) const {
if constexpr (ranges::range<T> && ranges::range<U>) {
auto& engine = ranges::detail::get_random_engine();
return sample_from_range(std::forward<T>(t), std::forward<U>(u), engine);
} else if constexpr (is_random_distribution_v<T>) {
static_assert(std::is_lvalue_reference_v<U&&>); // Assume U is a URNG
return sample_from_distribution(std::forward<T>(t), u);
} else {
static_assert(ranges::range<T>);
static_assert(std::is_lvalue_reference_v<U&&>); // Assume U is a URNG
return sample_from_range(std::forward<T>(t), u);
}
}

/// Overload that returns a view closure to compose with other views.
template <class URNG, std::enable_if_t<!ranges::range<URNG>, int> = 0>
constexpr auto operator()(URNG& engine) const {
return ranges::make_view_closure(ranges::bind_back(sample_fn{}, std::ref(engine)));
/// Overload that takes one argument.
template <class T>
constexpr auto operator()(T&& t) const {
if constexpr (ranges::range<T>) {
auto& engine = ranges::detail::get_random_engine();
return sample_from_range(std::forward<T>(t), engine);
} else if constexpr (is_random_distribution_v<T>) {
auto& engine = ranges::detail::get_random_engine();
return sample_from_distribution(std::forward<T>(t), engine);
} else {
static_assert(std::is_lvalue_reference_v<T&&>); // Assume T is a URNG
return ranges::make_view_closure(ranges::bind_back(sample_fn{}, std::ref(t)));
}
}

/// Overload that unwraps the engine reference from a view closure.
template <class Range, class URNG>
constexpr auto operator()(Range&& range, std::reference_wrapper<URNG> engine) const {
static_assert(ranges::range<Range>);
return sample_from_range(std::forward<Range>(range), engine.get());
}
};

Expand All @@ -193,6 +237,9 @@ struct sample_fn {
* The core idea is to draw random indices / iterators to the input particle range
* from a [multinomial distribution](https://en.wikipedia.org/wiki/Multinomial_distribution)
* parameterized after particle weights (and assumed uniform for non-weighted particle ranges).
*
* This view can also be used to convert any random distribution (a callable that takes a URNG as an
* input argument) into an infinite view that generates values from that distribution.
*/
inline constexpr ranges::views::view_closure<detail::sample_fn> sample;

Expand Down
12 changes: 12 additions & 0 deletions beluga/test/beluga/views/test_sample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,16 @@ TEST(SampleView, DiscreteDistributionProbability) {
ASSERT_NEAR(static_cast<double>(buckets[4]) / size, 0.2, 0.01);
}

TEST(SampleView, FromRandomDistributionFalse) {
auto distribution = std::bernoulli_distribution{0.0};
auto output = beluga::views::sample(distribution) | ranges::views::take_exactly(10);
ASSERT_EQ(ranges::count(output, false), 10);
}

TEST(SampleView, FromRandomDistributionTrue) {
auto distribution = std::bernoulli_distribution{1.0};
auto output = beluga::views::sample(distribution) | ranges::views::take_exactly(10);
ASSERT_EQ(ranges::count(output, true), 10);
}

} // namespace
6 changes: 2 additions & 4 deletions beluga_system_tests/test/test_system_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ auto particle_filter_test(
auto hasher = beluga::spatial_hash<Sophus::SE2d>{0.1, 0.1, 0.1};

// Use the initial distribution to initialize particles.
// TODO(nahuel): We should have a view to sample from an existing distribution.
// TODO(nahuel): We should have a view to convert from Eigen to Sophus types.
/**
* auto particles = beluga::views::sample(initial_distribution) |
Expand All @@ -135,9 +134,8 @@ auto particle_filter_test(
* ranges::views::take_exactly(params.max_particles) |
* ranges::to<beluga::TupleVector>;
*/
auto particles = ranges::views::generate([initial_distribution]() mutable {
static thread_local auto engine = std::mt19937{std::random_device()()};
const auto sample = initial_distribution(engine);
auto particles = beluga::views::sample(initial_distribution) | //
ranges::views::transform([](const auto& sample) {
return Sophus::SE2d{Sophus::SO2d{sample.z()}, Eigen::Vector2d{sample.x(), sample.y()}};
}) |
ranges::views::transform(beluga::make_from_state<Particle>) | //
Expand Down

0 comments on commit a819c82

Please sign in to comment.