Skip to content

Commit

Permalink
Forbid absl::Uniform<absl::int128>(gen)
Browse files Browse the repository at this point in the history
std::is_signed can't be specialized, so this actually lets through non-unsigned types where the types are not language primitives (i.e. it lets absl::int128 through). However, std::numeric_limits can be specialized, and is indeed specialized, so we can use that instead.
PiperOrigin-RevId: 636983590
Change-Id: Ic993518e9cac7c453b08deaf3784b6fba49f15d0
  • Loading branch information
Quincunx271 authored and Copybara-Service committed May 24, 2024
1 parent 0ef5bc6 commit 4a7c2ec
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 16 deletions.
2 changes: 2 additions & 0 deletions absl/random/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ cc_test(
deps = [
":distributions",
":random",
"//absl/meta:type_traits",
"//absl/numeric:int128",
"//absl/random/internal:distribution_test_util",
"@com_google_googletest//:gtest",
"@com_google_googletest//:gtest_main",
Expand Down
2 changes: 2 additions & 0 deletions absl/random/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ absl_cc_test(
DEPS
absl::random_distributions
absl::random_random
absl::type_traits
absl::int128
absl::random_internal_distribution_test_util
GTest::gmock
GTest::gtest_main
Expand Down
16 changes: 8 additions & 8 deletions absl/random/distributions.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,23 @@
#ifndef ABSL_RANDOM_DISTRIBUTIONS_H_
#define ABSL_RANDOM_DISTRIBUTIONS_H_

#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
#include <type_traits>

#include "absl/base/config.h"
#include "absl/base/internal/inline_variable.h"
#include "absl/meta/type_traits.h"
#include "absl/random/bernoulli_distribution.h"
#include "absl/random/beta_distribution.h"
#include "absl/random/exponential_distribution.h"
#include "absl/random/gaussian_distribution.h"
#include "absl/random/internal/distribution_caller.h" // IWYU pragma: export
#include "absl/random/internal/traits.h"
#include "absl/random/internal/uniform_helper.h" // IWYU pragma: export
#include "absl/random/log_uniform_int_distribution.h"
#include "absl/random/poisson_distribution.h"
#include "absl/random/uniform_int_distribution.h"
#include "absl/random/uniform_real_distribution.h"
#include "absl/random/uniform_int_distribution.h" // IWYU pragma: export
#include "absl/random/uniform_real_distribution.h" // IWYU pragma: export
#include "absl/random/zipf_distribution.h"

namespace absl {
Expand Down Expand Up @@ -176,7 +176,7 @@ Uniform(TagType tag,

return random_internal::DistributionCaller<gen_t>::template Call<
distribution_t>(&urbg, tag, static_cast<return_t>(lo),
static_cast<return_t>(hi));
static_cast<return_t>(hi));
}

// absl::Uniform(bitgen, lo, hi)
Expand All @@ -200,15 +200,15 @@ Uniform(URBG&& urbg, // NOLINT(runtime/references)

return random_internal::DistributionCaller<gen_t>::template Call<
distribution_t>(&urbg, static_cast<return_t>(lo),
static_cast<return_t>(hi));
static_cast<return_t>(hi));
}

// absl::Uniform<unsigned T>(bitgen)
//
// Overload of Uniform() using the minimum and maximum values of a given type
// `T` (which must be unsigned), returning a value of type `unsigned T`
template <typename R, typename URBG>
typename absl::enable_if_t<!std::is_signed<R>::value, R> //
typename absl::enable_if_t<!std::numeric_limits<R>::is_signed, R> //
Uniform(URBG&& urbg) { // NOLINT(runtime/references)
using gen_t = absl::decay_t<URBG>;
using distribution_t = random_internal::UniformDistributionWrapper<R>;
Expand Down
60 changes: 52 additions & 8 deletions absl/random/distributions_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <random>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "absl/meta/type_traits.h"
#include "absl/numeric/int128.h"
#include "absl/random/internal/distribution_test_util.h"
#include "absl/random/random.h"

Expand All @@ -30,7 +34,6 @@ constexpr int kSize = 400000;

class RandomDistributionsTest : public testing::Test {};


struct Invalid {};

template <typename A, typename B>
Expand Down Expand Up @@ -93,17 +96,18 @@ void CheckArgsInferType() {
}

template <typename A, typename B, typename ExplicitRet>
auto ExplicitUniformReturnT(int) -> decltype(
absl::Uniform<ExplicitRet>(*std::declval<absl::InsecureBitGen*>(),
std::declval<A>(), std::declval<B>()));
auto ExplicitUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>(
std::declval<absl::InsecureBitGen&>(),
std::declval<A>(), std::declval<B>()));

template <typename, typename, typename ExplicitRet>
Invalid ExplicitUniformReturnT(...);

template <typename TagType, typename A, typename B, typename ExplicitRet>
auto ExplicitTaggedUniformReturnT(int) -> decltype(absl::Uniform<ExplicitRet>(
std::declval<TagType>(), *std::declval<absl::InsecureBitGen*>(),
std::declval<A>(), std::declval<B>()));
auto ExplicitTaggedUniformReturnT(int)
-> decltype(absl::Uniform<ExplicitRet>(
std::declval<TagType>(), std::declval<absl::InsecureBitGen&>(),
std::declval<A>(), std::declval<B>()));

template <typename, typename, typename, typename ExplicitRet>
Invalid ExplicitTaggedUniformReturnT(...);
Expand Down Expand Up @@ -135,6 +139,14 @@ void CheckArgsReturnExpectedType() {
"");
}

// Takes the type of `absl::Uniform<R>(gen)` if valid or `Invalid` otherwise.
template <typename R>
auto UniformNoBoundsReturnT(int)
-> decltype(absl::Uniform<R>(std::declval<absl::InsecureBitGen&>()));

template <typename>
Invalid UniformNoBoundsReturnT(...);

TEST_F(RandomDistributionsTest, UniformTypeInference) {
// Infers common types.
CheckArgsInferType<uint16_t, uint16_t, uint16_t>();
Expand Down Expand Up @@ -221,6 +233,38 @@ TEST_F(RandomDistributionsTest, UniformNoBounds) {
absl::Uniform<uint32_t>(gen);
absl::Uniform<uint64_t>(gen);
absl::Uniform<absl::uint128>(gen);

// Compile-time validity tests.

// Allows unsigned ints.
testing::StaticAssertTypeEq<uint8_t,
decltype(UniformNoBoundsReturnT<uint8_t>(0))>();
testing::StaticAssertTypeEq<uint16_t,
decltype(UniformNoBoundsReturnT<uint16_t>(0))>();
testing::StaticAssertTypeEq<uint32_t,
decltype(UniformNoBoundsReturnT<uint32_t>(0))>();
testing::StaticAssertTypeEq<uint64_t,
decltype(UniformNoBoundsReturnT<uint64_t>(0))>();
testing::StaticAssertTypeEq<
absl::uint128, decltype(UniformNoBoundsReturnT<absl::uint128>(0))>();

// Disallows signed ints.
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int8_t>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int16_t>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int32_t>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<int64_t>(0))>();
testing::StaticAssertTypeEq<
Invalid, decltype(UniformNoBoundsReturnT<absl::int128>(0))>();

// Disallows float types.
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<float>(0))>();
testing::StaticAssertTypeEq<Invalid,
decltype(UniformNoBoundsReturnT<double>(0))>();
}

TEST_F(RandomDistributionsTest, UniformNonsenseRanges) {
Expand Down

0 comments on commit 4a7c2ec

Please sign in to comment.