Skip to content

Commit

Permalink
Added serde1 feature to WeightedIndex
Browse files Browse the repository at this point in the history
It required adding serde as an optional dependency.
Also added a test using bincode like in the other crate.
  • Loading branch information
CGMossa committed May 1, 2020
1 parent bf8b5a9 commit 9c823fa
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 32 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Expand Up @@ -24,7 +24,7 @@ appveyor = { repository = "rust-random/rand" }
# Meta-features:
default = ["std", "std_rng"]
nightly = ["simd_support"] # enables all features requiring nightly rust
serde1 = [] # does nothing, deprecated
serde1 = ["serde"]

# Option (enabled by default): without "std" rand uses libcore; this option
# enables functionality expected to be available on a standard platform.
Expand Down Expand Up @@ -58,6 +58,7 @@ members = [
rand_core = { path = "rand_core", version = "0.5.1" }
rand_pcg = { path = "rand_pcg", version = "0.2", optional = true }
log = { version = "0.4.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }

[dependencies.packed_simd]
# NOTE: so far no version works reliably due to dependence on unstable features
Expand All @@ -81,6 +82,8 @@ rand_hc = { path = "rand_hc", version = "0.2", optional = true }
rand_pcg = { path = "rand_pcg", version = "0.2" }
# Only for benches:
rand_hc = { path = "rand_hc", version = "0.2" }
# only to test serde1
bincode = {version = "1.2.1"}

[package.metadata.docs.rs]
all-features = true
46 changes: 28 additions & 18 deletions src/distributions/uniform.rs
Expand Up @@ -103,8 +103,10 @@
//! [`UniformDuration`]: crate::distributions::uniform::UniformDuration
//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow

#[cfg(not(feature = "std"))] use core::time::Duration;
#[cfg(feature = "std")] use std::time::Duration;
#[cfg(not(feature = "std"))]
use core::time::Duration;
#[cfg(feature = "std")]
use std::time::Duration;

use crate::distributions::float::IntoFloat;
use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, WideningMultiply};
Expand All @@ -115,8 +117,11 @@ use crate::Rng;
#[allow(unused_imports)] // rustc doesn't detect that this is actually used
use crate::distributions::utils::Float;

#[cfg(feature = "simd_support")]
use packed_simd::*;

#[cfg(feature = "simd_support")] use packed_simd::*;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// Sample values uniformly between two bounds.
///
Expand Down Expand Up @@ -291,15 +296,17 @@ pub trait SampleBorrow<Borrowed> {
fn borrow(&self) -> &Borrowed;
}
impl<Borrowed> SampleBorrow<Borrowed> for Borrowed
where Borrowed: SampleUniform
where
Borrowed: SampleUniform,
{
#[inline(always)]
fn borrow(&self) -> &Borrowed {
self
}
}
impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed
where Borrowed: SampleUniform
where
Borrowed: SampleUniform,
{
#[inline(always)]
fn borrow(&self) -> &Borrowed {
Expand All @@ -311,7 +318,6 @@ where Borrowed: SampleUniform

// What follows are all back-ends.


/// The back-end implementing [`UniformSampler`] for integer types.
///
/// Unless you are implementing [`UniformSampler`] for your own type, this type
Expand Down Expand Up @@ -347,6 +353,7 @@ where Borrowed: SampleUniform
/// multiply by `range`, the result is in the high word. Then comparing the low
/// word against `zone` makes sure our distribution is uniform.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformInt<X> {
low: X,
range: X,
Expand Down Expand Up @@ -623,7 +630,6 @@ uniform_simd_int_impl! {
u8
}


/// The back-end implementing [`UniformSampler`] for floating-point types.
///
/// Unless you are implementing [`UniformSampler`] for your own type, this type
Expand Down Expand Up @@ -831,7 +837,6 @@ uniform_float_impl! { f64x4, u64x4, f64, u64, 64 - 52 }
#[cfg(feature = "simd_support")]
uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 }


/// The back-end implementing [`UniformSampler`] for `Duration`.
///
/// Unless you are implementing [`UniformSampler`] for your own types, this type
Expand Down Expand Up @@ -991,7 +996,8 @@ mod tests {
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_integers() {
#[cfg(not(target_os = "emscripten"))] use core::{i128, u128};
#[cfg(not(target_os = "emscripten"))]
use core::{i128, u128};
use core::{i16, i32, i64, i8, isize};
use core::{u16, u32, u64, u8, usize};

Expand Down Expand Up @@ -1228,12 +1234,13 @@ mod tests {
}
}


#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_durations() {
#[cfg(not(feature = "std"))] use core::time::Duration;
#[cfg(feature = "std")] use std::time::Duration;
#[cfg(not(feature = "std"))]
use core::time::Duration;
#[cfg(feature = "std")]
use std::time::Duration;

let mut rng = crate::test::rng(253);

Expand Down Expand Up @@ -1328,7 +1335,9 @@ mod tests {
fn value_stability() {
fn test_samples<T: SampleUniform + Copy + core::fmt::Debug + PartialEq>(
lb: T, ub: T, expected_single: &[T], expected_multiple: &[T],
) where Uniform<T>: Distribution<T> {
) where
Uniform<T>: Distribution<T>,
{
let mut rng = crate::test::rng(897);
let mut buf = [lb; 3];

Expand All @@ -1350,11 +1359,12 @@ mod tests {
test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]);
test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]);

test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[
0.008194133,
0.00398172,
0.007428536,
]);
test_samples(
0f32,
1e-2f32,
&[0.0003070104, 0.0026630748, 0.00979833],
&[0.008194133, 0.00398172, 0.007428536],
);
test_samples(
-1e10f64,
1e10f64,
Expand Down
58 changes: 45 additions & 13 deletions src/distributions/weighted_index.rs
Expand Up @@ -15,7 +15,11 @@ use core::cmp::PartialOrd;
use core::fmt;

// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use crate::alloc::vec::Vec;

#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// A distribution using weighted sampling of discrete items
///
Expand Down Expand Up @@ -73,6 +77,7 @@ use core::fmt;
/// [`Uniform<X>`]: crate::distributions::uniform::Uniform
/// [`RngCore`]: crate::RngCore
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
total_weight: X,
Expand Down Expand Up @@ -133,10 +138,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
///
/// In case of error, `self` is not modified.
pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
where X: for<'a> ::core::ops::AddAssign<&'a X>
where
X: for<'a> ::core::ops::AddAssign<&'a X>
+ for<'a> ::core::ops::SubAssign<&'a X>
+ Clone
+ Default {
+ Default,
{
if new_weights.is_empty() {
return Ok(());
}
Expand Down Expand Up @@ -214,7 +221,8 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
}

impl<X> Distribution<usize> for WeightedIndex<X>
where X: SampleUniform + PartialOrd
where
X: SampleUniform + PartialOrd,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
use ::core::cmp::Ordering;
Expand All @@ -236,6 +244,24 @@ where X: SampleUniform + PartialOrd
mod test {
use super::*;

#[cfg(feature = "serde1")]
#[test]
fn test_weightedindex_serde1() {
let weighted_index = WeightedIndex::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
let de_weighted_index: WeightedIndex<i32> =
bincode::deserialize(&ser_weighted_index).unwrap();

// these doesn't work because lack of PartialEq, Eq
// assert_eq!(de_weighted_index, weighted_index);
assert_eq!(
de_weighted_index.cumulative_weights,
weighted_index.cumulative_weights
);
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
}

#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weightedindex() {
Expand Down Expand Up @@ -360,15 +386,21 @@ mod test {
}

let mut buf = [0; 10];
test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
]);
test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
]);
test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
]);
test_samples(
&[1i32, 1, 1, 1, 1, 1, 1, 1, 1],
&mut buf,
&[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
);
test_samples(
&[0.7f32, 0.1, 0.1, 0.1],
&mut buf,
&[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
);
test_samples(
&[1.0f64, 0.999, 0.998, 0.997],
&mut buf,
&[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
);
}
}

Expand Down

0 comments on commit 9c823fa

Please sign in to comment.