From bbc17442267835a1ffa616b1a5138e4c44b4eaf9 Mon Sep 17 00:00:00 2001 From: tortar Date: Wed, 13 Aug 2025 01:55:31 +0200 Subject: [PATCH 1/5] Add merging strategies for more samplers --- Project.toml | 2 +- .../benchmark_comparison_non_stream_WWR.jl | 27 +-- benchmark/benchmark_comparison_stream_WWR.jl | 14 +- docs/Project.toml | 1 + docs/make.jl | 6 +- docs/src/basics.md | 55 +++++ docs/src/example.md | 41 ++-- docs/src/index.md | 73 ------ docs/src/perf_tips.md | 65 +++++ src/SamplingInterface.jl | 131 +++++------ src/SamplingReduction.jl | 23 +- src/SamplingUtils.jl | 31 ++- src/SortedSamplingMulti.jl | 20 +- src/SortedSamplingSingle.jl | 2 +- src/StreamSampling.jl | 7 +- src/UnweightedSamplingMulti.jl | 187 ++++++++------- src/UnweightedSamplingSingle.jl | 49 ++-- src/WeightedSamplingMulti.jl | 222 +++++++++++------- src/WeightedSamplingSingle.jl | 52 +++- src/precompile.jl | 35 ++- test/merge_tests.jl | 23 +- test/unweighted_sampling_multi_tests.jl | 4 +- test/weighted_sampling_multi_tests.jl | 2 +- 23 files changed, 601 insertions(+), 471 deletions(-) create mode 100644 docs/src/basics.md create mode 100644 docs/src/perf_tips.md diff --git a/Project.toml b/Project.toml index 4f891207..418c66e5 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] -julia = "1.8" +julia = "1.10" Accessors = "0.1" DataStructures = "0.18" Distributions = "0.25" diff --git a/benchmark/benchmark_comparison_non_stream_WWR.jl b/benchmark/benchmark_comparison_non_stream_WWR.jl index 885e0444..6fa1e9da 100644 --- a/benchmark/benchmark_comparison_non_stream_WWR.jl +++ b/benchmark/benchmark_comparison_non_stream_WWR.jl @@ -10,7 +10,7 @@ using CairoMakie ################ function weighted_reservoir_sample(rng, a, ws, n) - return shuffle!(rng, weighted_reservoir_sample_seq(rng, a, ws, n)[1]) + return StreamSampling.fshuffle!(rng, weighted_reservoir_sample_seq(rng, a, ws, 
n)[1]) end function weighted_reservoir_sample_seq(rng, a, ws, n) @@ -25,12 +25,9 @@ function weighted_reservoir_sample_seq(rng, a, ws, n) w_sum += w_el if w_sum > w_skip p = w_el/w_sum - q = 1-p - z = exp((n-4)*log1p(-p)) - t = rand(rng, Uniform(z*q*q*q*q,1.0)) - k = @inline choose(n, p, q, t, z) + k = StreamSampling.choose(rng, n, p) @inbounds for j in 1:k - r = rand(rng, j:n) + r = rand(rng, Random.Sampler(rng, j:n, Val(1))) reservoir[r], reservoir[j] = reservoir[j], a[i] end w_skip = @inline skip(rng, w_sum, n) @@ -44,18 +41,6 @@ function skip(rng, w_sum::AbstractFloat, n) return w_sum/k end -function choose(n, p, q, t, z) - x = z*q*q*q*(q + n*p) - x > t && return 1 - x += n*p*(n-1)*p*z*q*q/2 - x > t && return 2 - x += n*p*(n-1)*p*(n-2)*p*z*q/6 - x > t && return 3 - x += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 - x > t && return 4 - return quantile(Binomial(n, p), t) -end - ##################### ## parallel 1 pass ## ##################### @@ -75,7 +60,7 @@ function weighted_reservoir_sample_parallel_1_pass(rngs, a, ws, n) Threads.@threads for i in 1:nt ss[i] = sample(rngs[i], ss[i], ns[i]; replace = false) end - return shuffle!(rngs[1], reduce(vcat, ss)) + return StreamSampling.fshuffle!(rngs[1], reduce(vcat, ss)) end ##################### @@ -97,7 +82,7 @@ function weighted_reservoir_sample_parallel_2_pass(rngs, a, ws, n) s = weighted_reservoir_sample_seq(rngs[i], @view(a[inds]), @view(ws[inds]), ns[i]) ss[i] = s[1] end - return shuffle!(rngs[1], reduce(vcat, ss)) + return StreamSampling.fshuffle!(rngs[1], reduce(vcat, ss)) end function sample_parallel_2_pass(rngs, a, ws, n) @@ -115,7 +100,7 @@ function sample_parallel_2_pass(rngs, a, ws, n) s = sample(rngs[i], @view(a[inds]), Weights(@view(ws[inds])), ns[i]; replace = true) ss[i] = s end - return shuffle!(rngs[1], reduce(vcat, ss)) + return StreamSampling.fshuffle!(rngs[1], reduce(vcat, ss)) end ################ diff --git a/benchmark/benchmark_comparison_stream_WWR.jl b/benchmark/benchmark_comparison_stream_WWR.jl index 1b5992e6..0d3fd6a6 100644 --- 
a/benchmark/benchmark_comparison_stream_WWR.jl +++ b/benchmark/benchmark_comparison_stream_WWR.jl @@ -8,7 +8,7 @@ using Random struct AlgAExpJWR end -struct SampleMultiAlgAExpJWR{B, R, T} <: AbstractReservoirSampler +struct MultiAlgAExpJSamplerWR{B, R, T} <: AbstractReservoirSampler n::Int seen_k::Int w_sum::Float64 @@ -19,15 +19,15 @@ struct SampleMultiAlgAExpJWR{B, R, T} <: AbstractReservoirSampler end function StreamSampling.ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJWR, - ::StreamSampling.ImmutSample, ::StreamSampling.Unord) where T + ::StreamSampling.ImmutSampler, ::StreamSampling.Unord) where T value = BinaryHeap(Base.By(first, DataStructures.FasterForward()), Tuple{Float64,T}[]) sizehint!(value, n) v = Vector{T}(undef, n) w = Vector{Float64}(undef, n) - return SampleMultiAlgAExpJWR(n, 0, 0.0, rng, value, v, w) + return MultiAlgAExpJSamplerWR(n, 0, 0.0, rng, value, v, w) end -@inline function OnlineStatsBase._fit!(s::SampleMultiAlgAExpJWR, el, w) +@inline function OnlineStatsBase._fit!(s::MultiAlgAExpJSamplerWR, el, w) n = s.n s = @inline update_state!(s, w) if s.seen_k <= n @@ -51,14 +51,14 @@ end skip_single(rng, n) = n/rand(rng) -function update_state!(s::SampleMultiAlgAExpJWR, w) +function update_state!(s::MultiAlgAExpJSamplerWR, w) @update s.seen_k += 1 @update s.w_sum += w return s end -function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR) - return shuffle!(s.rng, last.(s.value.valtree)) +function OnlineStatsBase.value(s::MultiAlgAExpJSamplerWR) + return StreamSampling.fshuffle!(s.rng, last.(s.value.valtree)) end a = Iterators.filter(x -> x != 1, 1:10^8) diff --git a/docs/Project.toml b/docs/Project.toml index 4312540c..82df54a4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" StreamSampling = "ff63dad9-3335-55d8-95ec-f8139d39e468" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git 
a/docs/make.jl b/docs/make.jl index 1ca09c2f..fac7a66b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,6 @@ using Documenter + +using BenchmarkTools using StreamSampling println("Documentation Build") @@ -7,9 +9,11 @@ makedocs( sitename = "StreamSampling.jl", pages = [ "StreamSampling.jl" => "index.md", + "Basics" => "basics.md", "An Illustrative Example" => "example.md", "API" => "api.md", - "Benchmark Comparison" => "benchmark.md" + "Performance Tips" => "perf_tips.md", + "Benchmarks" => "benchmark.md" ], warnonly = [:doctest, :missing_docs, :cross_references], ) diff --git a/docs/src/basics.md b/docs/src/basics.md new file mode 100644 index 00000000..ef6ab9b5 --- /dev/null +++ b/docs/src/basics.md @@ -0,0 +1,55 @@ + +# Overview of the functionalities + +The `itsample` function allows you to consume the whole stream at once and returns the collected sample: + +```julia +using StreamSampling + +st = 1:100; + +itsample(st, 5) +``` + +In some cases, one needs to control the updates the `ReservoirSampler` will be subject to. In this case +you can simply use the `fit!` function to update the reservoir: + +```julia +using StreamSampling + +st = 1:100; + +rs = ReservoirSampler{Int}(5); + +for x in st + fit!(rs, x) +end + +value(rs) +``` + +If the total number of elements in the stream is known beforehand and the sampling is unweighted, it is +also possible to iterate over a `StreamSampler` like so: + +```julia +using StreamSampling + +st = 1:100; + +ss = StreamSampler{Int}(st, 5, 100); + +r = Int[]; + +for x in ss + push!(r, x) +end + +r +``` + +The advantage of `StreamSampler` iterators with respect to `ReservoirSampler` is that they require `O(1)` +memory if not collected, while reservoir techniques require `O(k)` memory where `k` is the number +of elements in the sample. + +Consult the [API page](https://juliadynamics.github.io/StreamSampling.jl/stable/api) for more information +about the package interface. 
\ No newline at end of file diff --git a/docs/src/example.md b/docs/src/example.md index 8a829b56..b21163b4 100644 --- a/docs/src/example.md +++ b/docs/src/example.md @@ -1,3 +1,4 @@ + # An Illustrative Example Suppose to receive data about some process in the form of a stream and you want @@ -9,43 +10,41 @@ you want that to be lower than a certain threshold otherwise some malfunctioning is expected. ```julia -julia> using StreamSampling, Statistics, Random - -julia> function monitor(stream, thr) - rng = Xoshiro(42) - # we use a reservoir sample of 10^4 elements - rs = ReservoirSampler{Int}(rng, 10^4) - # we loop over the stream and fit the data in the reservoir - for (i, e) in enumerate(stream) - fit!(rs, e) - # we check the mean value every 1000 iterations - if iszero(mod(i, 1000)) && mean(value(rs)) >= thr - return rs - end - end - end +using StreamSampling, Statistics, Random + +function monitor(stream, thr) + rng = Xoshiro(42) + # we use a reservoir sample of 10^4 elements + rs = ReservoirSampler{Int}(rng, 10^4) + # we loop over the stream and fit the data in the reservoir + for (i, e) in enumerate(stream) + fit!(rs, e) + # we check the mean value every 1000 iterations + if iszero(mod(i, 1000)) && mean(value(rs)) >= thr + return rs + end + end +end ``` We use some toy data for illustration ```julia -julia> stream = 1:10^8; # the data stream - -julia> thr = 2*10^7; # the threshold for the mean monitoring +stream = 1:10^8; # the data stream +thr = 2*10^7; # the threshold for the mean monitoring ``` Then, we run the monitoring ```julia -julia> rs = monitor(stream, thr); +rs = monitor(stream, thr); ``` The number of observations until the detection is triggered is given by ```julia -julia> nobs(rs) -40009000 +nobs(rs) ``` which is very close to the true value of `4*10^7 - 1` observations. 
diff --git a/docs/src/index.md b/docs/src/index.md index f52b760c..7b9e9f52 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -5,79 +5,6 @@ StreamSampling ``` -# Overview of the functionalities - -The `itsample` function allows to consume all the stream at once and return the sample collected: - -```julia -julia> using StreamSampling - -julia> st = 1:100; - -julia> itsample(st, 5) -5-element Vector{Int64}: - 9 - 15 - 52 - 96 - 91 -``` - -In some cases, one needs to control the updates the `ReservoirSampler` will be subject to. In this case -you can simply use the `fit!` function to update the reservoir: - -```julia -julia> using StreamSampling - -julia> st = 1:100; - -julia> rs = ReservoirSampler{Int}(5); - -julia> for x in st - fit!(rs, x) - end - -julia> value(rs) -5-element Vector{Int64}: - 7 - 9 - 20 - 49 - 74 -``` - -If the total number of elements in the stream is known beforehand and the sampling is unweighted, it is -also possible to iterate over a `StreamSampler` like so - -```julia -julia> using StreamSampling - -julia> st = 1:100; - -julia> ss = StreamSampler{Int}(st, 5, 100); - -julia> r = Int[]; - -julia> for x in ss - push!(r, x) - end - -julia> r -5-element Vector{Int64}: - 10 - 22 - 26 - 35 - 75 -``` - -The advantage of `StreamSampler` iterators in respect to `ReservoirSampler` is that they require `O(1)` -memory if not collected, while reservoir techniques require `O(k)` memory where `k` is the number -of elements in the sample. - -Consult the [API page](https://juliadynamics.github.io/StreamSampling.jl/stable/api) for more information -about the package interface. 
- ## Reproducibility ```@raw html diff --git a/docs/src/perf_tips.md b/docs/src/perf_tips.md new file mode 100644 index 00000000..dc2d33dd --- /dev/null +++ b/docs/src/perf_tips.md @@ -0,0 +1,65 @@ + +# Use Immutable Reservoir Samplers + +By default, a `ReservoirSampler` is mutable; however, it is +also possible to use an immutable version which supports +all the basic operations. It uses `Accessors.jl` under the +hood to update the reservoir: + +```julia +using BenchmarkTools + +function fit_iter!(rs, iter) + for i in iter + rs = fit!(rs, i) # the reassignment is necessary when `rs` is immutable + end + return rs +end + +iter = 1:10^7 + +@btime fit_iter!(rs, $iter) setup=(rs = ReservoirSampler{Int}(10, AlgRSWRSKIP(); mutable = true)) +@btime fit_iter!(rs, $iter) setup=(rs = ReservoirSampler{Int}(10, AlgRSWRSKIP(); mutable = false)) +``` + +As you can see, the immutable version is 50% faster than +the mutable one. In general, the smaller the ratio between reservoir +size and stream size, the faster the immutable version +will be compared to the mutable one. Be careful though, because +calling `fit!` on an immutable sampler won't modify it in place; +it only returns a new, updated instance. + +# Parallel Sampling from Multiple Streams + +Let's say that you want to split the sampling of an iterator. If you can split the iterator into +different partitions, then you can update a reservoir sample for each partition in parallel and then +merge them together at the end. 
+ +Suppose, for instance, that you have these two iterators: + +```julia +iters = [1:100, 101:200] +``` + +then you create two reservoirs of the same type + +```julia +rs = [ReservoirSampler{Int}(10, AlgRSWRSKIP()) for i in 1:length(iters)] +``` + +and after that you can just update them in parallel like so + +```julia +Threads.@threads for i in 1:length(iters) + for e in iters[i] + fit!(rs[i], e) + end +end +``` + +then you can obtain a single reservoir containing a summary of the union of the streams +with + +```julia +merge(rs...) +``` diff --git a/src/SamplingInterface.jl b/src/SamplingInterface.jl index 60794f72..c499ee1d 100644 --- a/src/SamplingInterface.jl +++ b/src/SamplingInterface.jl @@ -3,28 +3,31 @@ ReservoirSampler{T}([rng], method = AlgRSWRSKIP()) ReservoirSampler{T}([rng], n::Int, method = AlgL(); ordered = false) -Initializes a reservoir sample which can then be fitted with [`fit!`](@ref). +Initializes a reservoir sampler with elements of type `T`. + The first signature represents a sample where only a single element is collected. -If `ordered` is true, the reservoir sample values can be retrived in the order -they were collected with [`ordvalue`](@ref). +If `ordered` is true, the sampled values can be retrieved in the order +they were collected using [`ordvalue`](@ref). Look at the [`Sampling Algorithms`](@ref) section for the supported methods. """ -struct ReservoirSampler{T} 1 === 1 end +struct ReservoirSampler{T,F} 1 === 1 end + +ReservoirSampler{T}(args::Vararg{Any, N}; kwargs...) where {T,N} = ReservoirSampler{T,Float64}(args...; kwargs...) -function ReservoirSampler{T}(method::ReservoirAlgorithm = AlgRSWRSKIP()) where T - return ReservoirSampler{T}(Random.default_rng(), method, MutSample()) +function ReservoirSampler{T,F}(method::ReservoirAlgorithm = AlgRSWRSKIP(); mutable = true) where {T,F} + return ReservoirSampler{T,F}(Random.default_rng(), method, mutable ? 
MutSampler() : ImmutSampler()) end -function ReservoirSampler{T}(rng::AbstractRNG, method::ReservoirAlgorithm = AlgRSWRSKIP()) where T - return ReservoirSampler{T}(rng, method, MutSample()) +function ReservoirSampler{T,F}(rng::AbstractRNG, method::ReservoirAlgorithm = AlgRSWRSKIP(); mutable = true) where {T,F} + return ReservoirSampler{T,F}(rng, method, mutable ? MutSampler() : ImmutSampler()) end -Base.@constprop :aggressive function ReservoirSampler{T}(n::Integer, method::ReservoirAlgorithm=AlgL(); - ordered = false) where T - return ReservoirSampler{T}(Random.default_rng(), n, method, MutSample(), ordered ? Ord() : Unord()) +Base.@constprop :aggressive function ReservoirSampler{T,F}(n::Integer, method::ReservoirAlgorithm=AlgL(); + ordered = false, mutable = true) where {T,F} + return ReservoirSampler{T,F}(Random.default_rng(), n, method, mutable ? MutSampler() : ImmutSampler(), ordered ? Ord() : Unord()) end -Base.@constprop :aggressive function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, - method::ReservoirAlgorithm=AlgL(); ordered = false) where T - return ReservoirSampler{T}(rng, n, method, MutSample(), ordered ? Ord() : Unord()) +Base.@constprop :aggressive function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, + method::ReservoirAlgorithm=AlgL(); ordered = false, mutable = true) where {T,F} + return ReservoirSampler{T,F}(rng, n, method, mutable ? MutSampler() : ImmutSampler(), ordered ? Ord() : Unord()) end """ @@ -36,7 +39,7 @@ If the sampling is weighted also the weight of the elements needs to be passed. """ @inline OnlineStatsBase.fit!(s::AbstractReservoirSampler, el) = OnlineStatsBase._fit!(s, el) -@inline OnlineStatsBase.fit!(s::AbstractReservoirSampler, el, w) = OnlineStatsBase._fit!(s, el, w) +@inline OnlineStatsBase.fit!(s::AbstractWeightedReservoirSampler, el, w) = OnlineStatsBase._fit!(s, el, w) """ value(rs::AbstractReservoirSampler) @@ -44,10 +47,14 @@ passed. 
Returns the elements collected in the sample at the current sampling stage. +If the sampler is empty, it returns `nothing` for single element +sampling. For multi-valued samplers, it always returns the sample +collected so far instead. + Note that even if the sampling respects the schema it is assigned when [`ReservoirSampler`](@ref) is instantiated, some ordering in the sample can be more probable than others. To represent each one -with the same probability call `shuffle!` over the result. +with the same probability call `fshuffle!` on the result. """ OnlineStatsBase.value(s::AbstractReservoirSampler) = error("Abstract version") @@ -57,6 +64,10 @@ OnlineStatsBase.value(s::AbstractReservoirSampler) = error("Abstract version") Returns the elements collected in the sample at the current sampling stage in the order they were collected. This applies only when `ordered = true` is passed in [`ReservoirSampler`](@ref). + +If the sampler is empty, it returns `nothing` for single element +sampling. For multi-valued samplers, it always returns the sample +collected so far instead. """ ordvalue(s::AbstractReservoirSampler) = error("Not an ordered sample") @@ -66,45 +77,54 @@ ordvalue(s::AbstractReservoirSampler) = error("Not an ordered sample") Returns the total number of elements that have been observed so far during the sampling process. """ -OnlineStatsBase.nobs(s::AbstractReservoirSampler) = s.seen_k +OnlineStatsBase.nobs(rs::AbstractReservoirSampler) = rs.seen_k """ Base.empty!(rs::AbstractReservoirSampler) Resets the reservoir sample to its initial state. -Useful to avoid allocating a new sample in some cases. +Useful to avoid allocating a new sampler in some cases. """ function Base.empty!(::AbstractReservoirSampler) error("Abstract Version") end """ - Base.merge!(rs::AbstractReservoirSampler, rs::AbstractReservoirSampler...) + Base.merge!(rs::AbstractReservoirSampler, rs_others::AbstractReservoirSampler...) 
-Updates the first reservoir sample by merging its value with the values -of the other samples. Currently only supported for samples with replacement. +Updates the first reservoir sampler by merging its value with the values +of the other samplers. The number of elements after merging will be +the minimum number of elements in the merged reservoirs. """ function Base.merge!(::AbstractReservoirSampler) error("Abstract Version") end """ - Base.merge(rs::AbstractReservoirSampler...) + Base.merge(rs_all::AbstractReservoirSampler...) -Creates a new reservoir sample by merging the values -of the samples passed. Currently only supported for sample -with replacement. +Creates a new reservoir sampler by merging the values +of the samplers passed. The number of elements in the new +sampler will be the minimum number of elements in the merged +reservoirs. """ function OnlineStatsBase.merge(::AbstractReservoirSampler) error("Abstract Version") end +""" + Base.size(rs::AbstractReservoirSampler) + +Returns the maximum number of elements that can be stored in the reservoir. +""" +Base.size(rs::AbstractReservoirSampler) = rs.n + """ StreamSampler{T}([rng], iter, n, [N], method = AlgD()) -Initializes a stream sample, which can then be iterated over +Initializes a stream sampler, which can then be iterated over to return the sampling elements of the iterable `iter` which -is assumed to have a eltype of `T`. The methods implemented in +is assumed to have an `eltype` of `T`. The methods implemented in [`StreamSampler`](@ref) require the knowledge of the total number of elements in the stream `N`, if not provided it is assumed to be available by calling `length(iter)`. @@ -146,14 +166,6 @@ appear in the same order as in `iter`) must be collected. If the iterator has less than `n` elements, in the case of sampling without replacement, it returns a vector of those elements. 
- ------ - - itsample(rngs, iters, n::Int) - itsample(rngs, iters, wfuncs, n::Int) - -Parallel implementation which returns a sample with replacement of size `n` -from the multiple iterables. All the arguments except from `n` must be tuples. """ function itsample(iter, method = AlgRSWRSKIP(); iter_type = infer_eltype(iter)) return itsample(Random.default_rng(), iter, method; iter_type) @@ -171,7 +183,7 @@ end Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, method = AlgRSWRSKIP(); iter_type = infer_eltype(iter)) if Base.IteratorSize(iter) isa Base.SizeUnknown - s = ReservoirSampler{iter_type}(rng, method, ImmutSample()) + s = ReservoirSampler{iter_type,Float64}(rng, method, ImmutSampler()) return update_all!(s, iter) else return sorted_sample_single(rng, iter) @@ -180,45 +192,23 @@ end Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, n::Int, method = AlgL(); iter_type = infer_eltype(iter), ordered = false) if Base.IteratorSize(iter) isa Base.SizeUnknown - s = ReservoirSampler{iter_type}(rng, n, method, ImmutSample(), ordered ? Ord() : Unord()) + s = ReservoirSampler{iter_type,Float64}(rng, n, method, ImmutSampler(), ordered ? Ord() : Unord()) return update_all!(s, iter, ordered) else m = method isa AlgL || method isa AlgR || method isa AlgD ? AlgD() : AlgORDSWR() s = collect(StreamSampler{iter_type}(rng, iter, n, length(iter), m)) - return ordered ? s : shuffle!(rng, s) + return ordered ? 
s : fshuffle!(rng, s) end end function itsample(rng::AbstractRNG, iter, wv::Function, method = AlgWRSWRSKIP(); iter_type = infer_eltype(iter)) - s = ReservoirSampler{iter_type}(rng, method, ImmutSample()) + s = ReservoirSampler{iter_type,Float64}(rng, method, ImmutSampler()) return update_all!(s, iter, wv) end Base.@constprop :aggressive function itsample(rng::AbstractRNG, iter, wv::Function, n::Int, method = AlgAExpJ(); iter_type = infer_eltype(iter), ordered = false) - s = ReservoirSampler{iter_type}(rng, n, method, ImmutSample(), ordered ? Ord() : Unord()) + s = ReservoirSampler{iter_type,Float64}(rng, n, method, ImmutSampler(), ordered ? Ord() : Unord()) return update_all!(s, iter, ordered, wv) end -function itsample(rngs::Tuple, iters::Tuple, n::Int,; iter_types = infer_eltype.(iters)) - n_it = length(iters) - vs = Vector{Vector{Union{iter_types...}}}(undef, n_it) - ps = Vector{Float64}(undef, n_it) - Threads.@threads for i in 1:n_it - s = ReservoirSampler{iter_types[i]}(rngs[i], n, AlgRSWRSKIP(), ImmutSample(), Unord()) - vs[i], ps[i] = update_all_p!(s, iters[i]) - end - ps /= sum(ps) - return shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) -end -function itsample(rngs::Tuple, iters::Tuple, wfuncs::Tuple, n::Int; iter_types = infer_eltype.(iters)) - n_it = length(iters) - vs = Vector{Vector{Union{iter_types...}}}(undef, n_it) - ps = Vector{Float64}(undef, n_it) - Threads.@threads for i in 1:n_it - s = ReservoirSampler{iter_types[i]}(rngs[i], n, AlgWRSWRSKIP(), ImmutSample(), Unord()) - vs[i], ps[i] = update_all_p!(s, iters[i], wfuncs[i]) - end - ps /= sum(ps) - return shuffle!(rngs[1], reduce_samples(rngs, ps, vs)) -end function update_all!(s, iter) for x in iter @@ -236,24 +226,11 @@ function update_all!(s, iter, ordered::Bool) for x in iter s = fit!(s, x) end - return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) + return ordered ? 
ordvalue(s) : fshuffle!(s.rng, value(s)) end function update_all!(s, iter, ordered, wv) for x in iter s = fit!(s, x, wv(x)) end - return ordered ? ordvalue(s) : shuffle!(s.rng, value(s)) -end - -function update_all_p!(s, iter) - for x in iter - s = fit!(s, x) - end - return value(s), s.seen_k -end -function update_all_p!(s, iter, wv) - for x in iter - s = fit!(s, x, wv(x)) - end - return value(s), s.state + return ordered ? ordvalue(s) : fshuffle!(s.rng, value(s)) end diff --git a/src/SamplingReduction.jl b/src/SamplingReduction.jl index 9a22eff0..d7a1c2ad 100644 --- a/src/SamplingReduction.jl +++ b/src/SamplingReduction.jl @@ -1,11 +1,19 @@ -const SMWR = Union{SampleMultiAlgRSWRSKIP, SampleMultiAlgWRSWRSKIP} +const SMWR = Union{MultiAlgRSWRSKIPSampler, MultiAlgWRSWRSKIPSampler} +const SMWOWR = Union{MultiAlgAResSampler, MultiAlgAExpJSampler} reduce_samples(t) = error() +function reduce_samples(t, ss::T...) where {T<:SMWOWR} + nt = length(ss) + n = minimum(length.(value.(ss))) + lkeys = sort(reduce(vcat, [s.value.valtree for s in ss]), by=(x->x[end]), rev=true)[1:n] + return lkeys +end function reduce_samples(t, ss::T...) where {T<:SMWR} nt = length(ss) v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt) - ns = rand(ss[1].rng, Multinomial(length(value(ss[1])), get_ps(ss...))) + n = minimum(length.(value.(ss))) + ns = rand(ss[1].rng, Multinomial(n, get_ps(ss...))) Threads.@threads for i in 1:nt v[i] = sample(ss[i].rng, value(ss[i]), ns[i]; replace = false) end @@ -13,23 +21,24 @@ function reduce_samples(t, ss::T...) where {T<:SMWR} end function reduce_samples(rngs, ps::Vector, vs::Vector) nt = length(vs) - ns = rand(rngs[1], Multinomial(length(vs[1]), ps)) + n = minimum(length.(vs)) + ns = rand(rngs[1], Multinomial(n, ps)) Threads.@threads for i in 1:nt vs[i] = sample(rngs[i], vs[i], ns[i]; replace = false) end return reduce(vcat, vs) end -function get_ps(ss::SampleMultiAlgRSWRSKIP...) +function get_ps(ss::MultiAlgRSWRSKIPSampler...) 
sum_w = sum(getfield(s, :seen_k) for s in ss) return [s.seen_k/sum_w for s in ss] end -function get_ps(ss::SampleMultiAlgWRSWRSKIP...) +function get_ps(ss::MultiAlgWRSWRSKIPSampler...) sum_w = sum(getfield(s, :state) for s in ss) return [s.state/sum_w for s in ss] end -get_type_rs(::TypeS, s1::T, ss::T...) where {T<:SMWR} = eltype(value(s1)) -function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T<:SMWR} +get_type_rs(::TypeS, s1::T, ss::T...) where {T} = eltype(value(s1)) +function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T} return Union{eltype(value(s1)), Union{(eltype(value(s)) for s in ss)...}} end diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 5c8c729f..40b2622f 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -13,26 +13,26 @@ function infer_eltype(itr) ifelse(T2 !== Union{} && T2 <: T1, T2, T1) end -struct SeqSampleIterWR{R} +struct SeqIterWRSampler{R} rng::R N::Int n::Int end -@inline function Base.iterate(s::SeqSampleIterWR) +@inline function Base.iterate(s::SeqIterWRSampler) curmax = -log(Float64(s.N)) + randexp(s.rng)/s.n return (s.N - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax)) end -@inline function Base.iterate(s::SeqSampleIterWR, state) +@inline function Base.iterate(s::SeqIterWRSampler, state) state[1] == 0 && return nothing curmax = state[2] + randexp(s.rng)/state[1] return (s.N - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax)) end -Base.IteratorEltype(::SeqSampleIterWR) = Base.HasEltype() -Base.eltype(::SeqSampleIterWR) = Int -Base.IteratorSize(::SeqSampleIterWR) = Base.HasLength() -Base.length(s::SeqSampleIterWR) = s.n +Base.IteratorEltype(::SeqIterWRSampler) = Base.HasEltype() +Base.eltype(::SeqIterWRSampler) = Int +Base.IteratorSize(::SeqIterWRSampler) = Base.HasLength() +Base.length(s::SeqIterWRSampler) = s.n # courtesy of StatsBase.jl for part of the implementation struct SeqSampleIter{R} @@ -144,4 +144,19 @@ end Base.IteratorEltype(::SeqSampleIter) = Base.HasEltype() Base.eltype(::SeqSampleIter) = 
Int Base.IteratorSize(::SeqSampleIter) = Base.HasLength() -Base.length(s::SeqSampleIter) = s.n \ No newline at end of file +Base.length(s::SeqSampleIter) = s.n + +function fshuffle!(rng::AbstractRNG, vec::AbstractVector) + for i in 2:length(vec) + endi = (i-1) % UInt + j = @inline rand(rng, Random.Sampler(rng, UInt(0):endi, Val(1))) % Int + 1 + vec[i], vec[j] = vec[j], vec[i] + end + vec +end + +function ordmemory(n) + ord = Base.Memory{Int}(undef, n) + for i in eachindex(ord) ord[i] = i end + ord +end \ No newline at end of file diff --git a/src/SortedSamplingMulti.jl b/src/SortedSamplingMulti.jl index 421a4031..f39e4079 100644 --- a/src/SortedSamplingMulti.jl +++ b/src/SortedSamplingMulti.jl @@ -1,22 +1,22 @@ -struct SampleMultiAlgORD{T,R,I,D} <: AbstractStreamSampler +struct MultiAlgORDSampler{T,R,I,D} <: AbstractStreamSampler rng::R it::I n::Int inds::D - function SampleMultiAlgORD{T}(rng::R, it::I, n, inds::D) where {T,R,I,D} + function MultiAlgORDSampler{T}(rng::R, it::I, n, inds::D) where {T,R,I,D} return new{T,R,I,D}(rng, it, n, inds) end end function StreamSampler{T}(rng::AbstractRNG, iter, n, N, ::AlgD) where T - return SampleMultiAlgORD{T}(rng, iter, min(n, N), SeqSampleIter(rng, N, min(n, N))) + return MultiAlgORDSampler{T}(rng, iter, min(n, N), SeqSampleIter(rng, N, min(n, N))) end function StreamSampler{T}(rng::AbstractRNG, iter, n, N, ::AlgORDSWR) where T - return SampleMultiAlgORD{T}(rng, iter, n, SeqSampleIterWR(rng, N, n)) + return MultiAlgORDSampler{T}(rng, iter, n, SeqIterWRSampler(rng, N, n)) end -@inline function Base.iterate(s::SampleMultiAlgORD) +@inline function Base.iterate(s::MultiAlgORDSampler) indices, iter = s.inds, s.it curr_idx, state_idx = iterate(indices)::Tuple el, state_el = iterate(iter)::Tuple @@ -25,7 +25,7 @@ end end return (el, (el, state_el, curr_idx, state_idx)) end -@inline function Base.iterate(s::SampleMultiAlgORD, state) +@inline function Base.iterate(s::MultiAlgORDSampler, state) el, state_el, curr_idx, state_idx = 
state indices, iter = s.inds, s.it it_indices = iterate(indices, state_idx) @@ -37,7 +37,7 @@ end return (el, (el, state_el, next_idx, state_idx)) end -Base.IteratorEltype(::SampleMultiAlgORD) = Base.HasEltype() -Base.eltype(::SampleMultiAlgORD{T}) where T = T -Base.IteratorSize(::SampleMultiAlgORD) = Base.HasLength() -Base.length(s::SampleMultiAlgORD) = s.n +Base.IteratorEltype(::MultiAlgORDSampler) = Base.HasEltype() +Base.eltype(::MultiAlgORDSampler{T}) where T = T +Base.IteratorSize(::MultiAlgORDSampler) = Base.HasLength() +Base.length(s::MultiAlgORDSampler) = s.n diff --git a/src/SortedSamplingSingle.jl b/src/SortedSamplingSingle.jl index 26bd98ea..1bb551b9 100644 --- a/src/SortedSamplingSingle.jl +++ b/src/SortedSamplingSingle.jl @@ -1,6 +1,6 @@ function sorted_sample_single(rng, iter) - k = rand(rng, 1:length(iter)) + k = rand(rng, Random.Sampler(rng, 1:length(iter), Val(1))) for (i, el) in enumerate(iter) i == k && return el end diff --git a/src/StreamSampling.jl b/src/StreamSampling.jl index 3469f326..a7ac6b89 100644 --- a/src/StreamSampling.jl +++ b/src/StreamSampling.jl @@ -19,20 +19,21 @@ export fit!, merge!, value, ordvalue, nobs, itsample export AbstractReservoirSampler, ReservoirSampler, StreamSampler export AlgL, AlgR, AlgRSWRSKIP, AlgARes, AlgAExpJ, AlgWRSWRSKIP, AlgD, AlgORDSWR -struct ImmutSample end -struct MutSample end +struct ImmutSampler end +struct MutSampler end struct Ord end struct Unord end abstract type AbstractStreamSampler end abstract type AbstractReservoirSampler <: OnlineStat{Any} end +abstract type AbstractWeightedReservoirSampler <: AbstractReservoirSampler end abstract type StreamAlgorithm end abstract type ReservoirAlgorithm <: StreamAlgorithm end """ -Implements random sampling without replacement. To be used with [`StreamSampler`](@ref) +Implements random stream sampling without replacement. To be used with [`StreamSampler`](@ref) or [`itsample`](@ref). 
Adapted from algorithm D described in "An Efficient Algorithm for Sequential Random Sampling, diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 68383103..adc8d1bc 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -1,25 +1,25 @@ -@hybrid struct SampleMultiAlgR{O,T,R} <: AbstractReservoirSampler +@hybrid struct MultiAlgRSampler{O,T,R} <: AbstractReservoirSampler const n::Int seen_k::Int const rng::R const value::Vector{T} const ord::O end -const SampleMultiOrdAlgR = SampleMultiAlgR{<:Vector} +const MultiOrdAlgRSampler = MultiAlgRSampler{<:Base.Memory} -@hybrid struct SampleMultiAlgL{O,T,R} <: AbstractReservoirSampler +@hybrid struct MultiAlgLSampler{O,T,R,F} <: AbstractReservoirSampler const n::Int - state::Float64 + state::F skip_k::Int seen_k::Int const rng::R const value::Vector{T} const ord::O end -const SampleMultiOrdAlgL = SampleMultiAlgL{<:Vector} +const MultiOrdAlgLSampler = MultiAlgLSampler{<:Base.Memory} -@hybrid struct SampleMultiAlgRSWRSKIP{O,T,R} <: AbstractReservoirSampler +@hybrid struct MultiAlgRSWRSKIPSampler{O,T,R} <: AbstractReservoirSampler const n::Int skip_k::Int seen_k::Int @@ -27,60 +27,60 @@ const SampleMultiOrdAlgL = SampleMultiAlgL{<:Vector} const value::Vector{T} const ord::O end -const SampleMultiOrdAlgRSWRSKIP = SampleMultiAlgRSWRSKIP{<:Vector} +const MultiOrdAlgRSWRSKIPSampler = MultiAlgRSWRSKIPSampler{<:Base.Memory} -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSample, ::Ord) where T - return SampleMultiAlgL_Mut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSampler, ::Ord) where {T,F} + return MultiAlgLSampler_Mut(n, zero(F), 0, 0, rng, Vector{T}(undef, n), ordmemory(n)) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSample, ::Unord) where T - return SampleMultiAlgL_Mut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), nothing) +function 
ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSampler, ::Unord) where {T,F} + return MultiAlgLSampler_Mut(n, zero(F), 0, 0, rng, Vector{T}(undef, n), nothing) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::ImmutSample, ::Ord) where T - return SampleMultiAlgL_Immut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgL, ::ImmutSampler, ::Ord) where {T,F} + return MultiAlgLSampler_Immut(n, zero(F), 0, 0, rng, Vector{T}(undef, n), ordmemory(n)) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::ImmutSample, ::Unord) where T - return SampleMultiAlgL_Immut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgL, ::ImmutSampler, ::Unord) where {T,F} + return MultiAlgLSampler_Immut(n, zero(F), 0, 0, rng, Vector{T}(undef, n), nothing) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::MutSample, ::Ord) where T - return SampleMultiAlgR_Mut(n, 0, rng, Vector{T}(undef, n), collect(1:n)) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgR, ::MutSampler, ::Ord) where {T,F} + return MultiAlgRSampler_Mut(n, 0, rng, Vector{T}(undef, n), ordmemory(n)) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::MutSample, ::Unord) where T - return SampleMultiAlgR_Mut(n, 0, rng, Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgR, ::MutSampler, ::Unord) where {T,F} + return MultiAlgRSampler_Mut(n, 0, rng, Vector{T}(undef, n), nothing) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::ImmutSample, ::Ord) where T - return SampleMultiAlgR_Immut(n, 0, rng, Vector{T}(undef, n), collect(1:n)) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgR, ::ImmutSampler, ::Ord) where {T,F} + return MultiAlgRSampler_Immut(n, 0, rng, Vector{T}(undef, n), ordmemory(n)) 
end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::ImmutSample, ::Unord) where T - return SampleMultiAlgR_Immut(n, 0, rng, Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgR, ::ImmutSampler, ::Unord) where {T,F} + return MultiAlgRSampler_Immut(n, 0, rng, Vector{T}(undef, n), nothing) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::MutSample, ::Ord) where T - return SampleMultiAlgRSWRSKIP_Mut(n, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::MutSampler, ::Ord) where {T,F} + return MultiAlgRSWRSKIPSampler_Mut(n, 0, 0, rng, Vector{T}(undef, n), ordmemory(n)) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::MutSample, ::Unord) where T - return SampleMultiAlgRSWRSKIP_Mut(n, 0, 0, rng, Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::MutSampler, ::Unord) where {T,F} + return MultiAlgRSWRSKIPSampler_Mut(n, 0, 0, rng, Vector{T}(undef, n), nothing) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::ImmutSample, ::Ord) where T - return SampleMultiAlgRSWRSKIP_Immut(n, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::ImmutSampler, ::Ord) where {T,F} + return MultiAlgRSWRSKIPSampler_Immut(n, 0, 0, rng, Vector{T}(undef, n), ordmemory(n)) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::ImmutSample, ::Unord) where T - return SampleMultiAlgRSWRSKIP_Immut(n, 0, 0, rng, Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::ImmutSampler, ::Unord) where {T,F} + return MultiAlgRSWRSKIPSampler_Immut(n, 0, 0, rng, Vector{T}(undef, n), nothing) end -@inline function OnlineStatsBase._fit!(s::SampleMultiAlgR, el) +@inline function 
OnlineStatsBase._fit!(s::MultiAlgRSampler, el) n = s.n s = @inline update_state!(s) if s.seen_k <= n @inbounds s.value[s.seen_k] = el return s end - j = rand(s.rng, 1:s.seen_k) + j = rand(s.rng, Random.Sampler(s.rng, 1:s.seen_k, Val(1))) if j <= n @inbounds s.value[j] = el update_order!(s, j) end return s end -@inline function OnlineStatsBase._fit!(s::SampleMultiAlgL, el) +@inline function OnlineStatsBase._fit!(s::MultiAlgLSampler, el) n = s.n s = @inline update_state!(s) if s.seen_k <= n @@ -91,122 +91,143 @@ end return s end if s.skip_k < s.seen_k - j = rand(s.rng, 1:n) + j = rand(s.rng, Random.Sampler(s.rng, 1:n, Val(1))) @inbounds s.value[j] = el update_order!(s, j) s = @inline recompute_skip!(s, n) end return s end -@inline function OnlineStatsBase._fit!(s::SampleMultiAlgRSWRSKIP, el) +@inline function OnlineStatsBase._fit!(s::MultiAlgRSWRSKIPSampler, el) n = s.n s = @inline update_state!(s) if s.seen_k <= n @inbounds s.value[s.seen_k] = el if s.seen_k === n s = @inline recompute_skip!(s, n) - new_values = sample(s.rng, s.value, n, ordered=is_ordered(s)) - @inbounds for i in 1:n - s.value[i] = new_values[i] - end + s.value .= sample(s.rng, s.value, n, ordered=is_ordered(s)) end return s end if s.skip_k < s.seen_k p = 1/s.seen_k - z = exp((n-4)*log1p(-p)) - c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) - k = @inline choose(n, p, c, z) + k = @inline choose(s.rng, n, p) @inbounds for j in 1:k - r = rand(s.rng, j:n) + r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) s.value[r], s.value[j] = s.value[j], el update_order_multi!(s, r, j) - end + end s = @inline recompute_skip!(s, n) end return s end -function Base.empty!(s::SampleMultiAlgR_Mut) +function Base.empty!(s::MultiAlgRSampler_Mut) s.seen_k = 0 return s end -function Base.empty!(s::SampleMultiAlgL_Mut) +function Base.empty!(s::MultiAlgLSampler_Mut) s.state = 0.0 s.skip_k = 0 s.seen_k = 0 return s end -function Base.empty!(s::SampleMultiAlgRSWRSKIP_Mut) +function 
Base.empty!(s::MultiAlgRSWRSKIPSampler_Mut) s.skip_k = 0 s.seen_k = 0 return s end -function update_state!(s::SampleMultiAlgR) +function update_state!(s::MultiAlgRSampler) @update s.seen_k += 1 return s end -function update_state!(s::SampleMultiAlgL) +function update_state!(s::MultiAlgLSampler) @update s.seen_k += 1 return s end -function update_state!(s::SampleMultiAlgRSWRSKIP) +function update_state!(s::MultiAlgRSWRSKIPSampler) @update s.seen_k += 1 return s end -function recompute_skip!(s::SampleMultiAlgL, n) +function recompute_skip!(s::MultiAlgLSampler, n) @update s.state += randexp(s.rng) @update s.skip_k = s.seen_k-ceil(Int, randexp(s.rng)/log1p(-exp(-s.state/n))) return s end -function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n) - q = exp(-randexp(s.rng)/n) - @update s.skip_k = ceil(Int, s.seen_k/q)-1 +function recompute_skip!(s::MultiAlgRSWRSKIPSampler, n) + q = exp(randexp(s.rng)/n) + @update s.skip_k = ceil(Int, s.seen_k*q)-1 return s end -function choose(n, p, c, z) +@inline function choose(rng, n, p) + z = exp(n*log1p(-p)) + t = rand(rng, Uniform(z, 1.0)) + s = n*p q = 1-p - k = z*q*q*q*(q + n*p) - k > c && return 1 - k += n*p*(n-1)*p*z*q*q/2 - k > c && return 2 - k += n*p*(n-1)*p*(n-2)*p*z*q/6 - k > c && return 3 - k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 - k > c && return 4 - b = Binomial(n, p) - return quantile(b, c) -end - -update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing -function update_order!(s::Union{SampleMultiOrdAlgR, SampleMultiOrdAlgL}, j) + x = z + z*s/q + x > t && return 1 + s *= (n-1)*p + q *= 1-p + x += (s*z/q)/2 + x > t && return 2 + s *= (n-2)*p + q *= 1-p + x += (s*z/q)/6 + x > t && return 3 + s *= (n-3)*p + q *= 1-p + x += (s*z/q)/24 + x > t && return 4 + s *= (n-4)*p + q *= 1-p + x += (s*z/q)/120 + x > t && return 5 + return quantile(Binomial(n, p), t) +end + +update_order!(s::Union{MultiAlgRSampler, MultiAlgLSampler}, j) = nothing +function update_order!(s::Union{MultiOrdAlgRSampler, MultiOrdAlgLSampler}, j) 
s.ord[j] = nobs(s) end -update_order_single!(s::SampleMultiAlgRSWRSKIP, r) = nothing -function update_order_single!(s::SampleMultiOrdAlgRSWRSKIP, r) +update_order_single!(s::MultiAlgRSWRSKIPSampler, r) = nothing +function update_order_single!(s::MultiOrdAlgRSWRSKIPSampler, r) s.ord[r] = nobs(s) end -update_order_multi!(s::SampleMultiAlgRSWRSKIP, r, j) = nothing -function update_order_multi!(s::SampleMultiOrdAlgRSWRSKIP, r, j) +update_order_multi!(s::MultiAlgRSWRSKIPSampler, r, j) = nothing +function update_order_multi!(s::MultiOrdAlgRSWRSKIPSampler, r, j) s.ord[r], s.ord[j] = s.ord[j], nobs(s) end -is_ordered(s::SampleMultiOrdAlgRSWRSKIP) = true -is_ordered(s::SampleMultiAlgRSWRSKIP) = false +is_ordered(s::MultiOrdAlgRSWRSKIPSampler) = true +is_ordered(s::MultiAlgRSWRSKIPSampler) = false -function Base.merge(ss::SampleMultiAlgRSWRSKIP...) +function Base.merge(ss::MultiAlgRSampler...) + error("To Be Implemented") +end +function Base.merge(ss::MultiAlgLSampler...) + error("To Be Implemented") +end +function Base.merge(ss::MultiAlgRSWRSKIPSampler...) newvalue = reduce_samples(TypeUnion(), ss...) skip_k = sum(getfield(s, :skip_k) for s in ss) seen_k = sum(getfield(s, :seen_k) for s in ss) - return SampleMultiAlgRSWRSKIP_Mut(ss[1].n, skip_k, seen_k, ss[1].rng, newvalue, nothing) + n = minimum(s.n for s in ss) + return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing) end -function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRSWRSKIP...) +function Base.merge!(ss::MultiAlgRSampler...) + error("To Be Implemented") +end +function Base.merge!(ss::MultiAlgLSampler...) + error("To Be Implemented") +end +function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...) + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") newvalue = reduce_samples(TypeS(), s1, ss...) 
for i in 1:length(newvalue) @inbounds s1.value[i] = newvalue[i] @@ -216,14 +237,14 @@ function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRS return s1 end -function OnlineStatsBase.value(s::Union{SampleMultiAlgR, SampleMultiAlgL}) +function OnlineStatsBase.value(s::Union{MultiAlgRSampler, MultiAlgLSampler}) if nobs(s) < length(s.value) return s.value[1:nobs(s)] else return s.value end end -function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP) +function OnlineStatsBase.value(s::MultiAlgRSWRSKIPSampler) if nobs(s) < length(s.value) if nobs(s) == 0 return s.value[1:0] @@ -235,14 +256,14 @@ function OnlineStatsBase.value(s::SampleMultiAlgRSWRSKIP) end end -function ordvalue(s::Union{SampleMultiOrdAlgR, SampleMultiOrdAlgL}) +function ordvalue(s::Union{MultiOrdAlgRSampler, MultiOrdAlgLSampler}) if nobs(s) < length(s.value) return s.value[1:nobs(s)] else return s.value[sortperm(s.ord)] end end -function ordvalue(s::SampleMultiOrdAlgRSWRSKIP) +function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) if nobs(s) < length(s.value) if nobs(s) == 0 return s.value[1:0] diff --git a/src/UnweightedSamplingSingle.jl b/src/UnweightedSamplingSingle.jl index 99e151ab..108e7f04 100644 --- a/src/UnweightedSamplingSingle.jl +++ b/src/UnweightedSamplingSingle.jl @@ -1,24 +1,24 @@ -@hybrid struct SampleSingleAlgRSWRSKIP{RT,R} <: AbstractReservoirSampler +@hybrid struct SingleAlgRSWRSKIPSampler{RT,R} <: AbstractReservoirSampler seen_k::Int skip_k::Int const rng::R rvalue::RT end -function ReservoirSampler{T}(rng::AbstractRNG, ::AlgRSWRSKIP, ::MutSample) where T - return SampleSingleAlgRSWRSKIP_Mut(0, 0, rng, RefVal_Immut{T}()) +function ReservoirSampler{T,F}(rng::AbstractRNG, ::AlgRSWRSKIP, ::MutSampler) where {T,F} + return SingleAlgRSWRSKIPSampler_Mut(0, 0, rng, RefVal_Immut{T}()) end -function ReservoirSampler{T}(rng::AbstractRNG, ::AlgRSWRSKIP, ::ImmutSample) where T - return SampleSingleAlgRSWRSKIP_Immut(0, 0, rng, RefVal_Mut{T}()) +function 
ReservoirSampler{T,F}(rng::AbstractRNG, ::AlgRSWRSKIP, ::ImmutSampler) where {T,F} + return SingleAlgRSWRSKIPSampler_Immut(0, 0, rng, RefVal_Mut{T}()) end -function OnlineStatsBase.value(s::SampleSingleAlgRSWRSKIP) +function OnlineStatsBase.value(s::SingleAlgRSWRSKIPSampler) s.seen_k === 0 && return nothing return s.rvalue.value end -@inline function OnlineStatsBase._fit!(s::SampleSingleAlgRSWRSKIP, el) +@inline function OnlineStatsBase._fit!(s::SingleAlgRSWRSKIPSampler, el) @update s.seen_k += 1 if s.skip_k <= s.seen_k @update s.skip_k = ceil(Int, s.seen_k/rand(s.rng)) @@ -27,35 +27,38 @@ end return s end -function reset_value!(s::SampleSingleAlgRSWRSKIP_Mut, el) +function reset_value!(s::SingleAlgRSWRSKIPSampler_Mut, el) s.rvalue = RefVal_Immut(el) end -function reset_value!(s::SampleSingleAlgRSWRSKIP_Immut, el) +function reset_value!(s::SingleAlgRSWRSKIPSampler_Immut, el) s.rvalue.value = el end -function Base.empty!(s::SampleSingleAlgRSWRSKIP) +function Base.empty!(s::SingleAlgRSWRSKIPSampler) s.seen_k = 0 s.skip_k = 0 return s end -function Base.merge(s1::SampleSingleAlgRSWRSKIP, s2::SampleSingleAlgRSWRSKIP) - n1, n2 = nobs(s1), nobs(s2) - n_tot = n1 + n2 - value = rand(s1.rng) < n1/n_tot ? s1.rvalue : s2.rvalue - return typeof(s1)(n_tot, s1.skip_k + s2.skip_k, s1.rng, value) +function Base.merge(ss::SingleAlgRSWRSKIPSampler...) + ns = [nobs(s) for s in ss] + n_tot = sum(ns) + ps = cumsum(ns ./ n_tot) + r = rand(s1.rng) + value = ss[findfirst(p -> r < p, ps)].value + return typeof(s1)(n_tot, sum(s.skip_k for s in ss), ss[1].rng, value) end -function Base.merge!(s1::SampleSingleAlgRSWRSKIP_Mut, s2::SampleSingleAlgRSWRSKIP_Mut) - n1, n2 = nobs(s1), nobs(s2) - n_tot = n1 + n2 +function Base.merge!(s1::SingleAlgRSWRSKIPSampler_Mut, ss::SingleAlgRSWRSKIPSampler_Mut...) + ns = [nobs(s1), [nobs(s) for s in ss]...] 
+ n_tot = sum(ns) + ps = cumsum(ns ./ n_tot) r = rand(s1.rng) - p = n2 / n_tot - if r < p - s1.rvalue = RefVal_Immut(s2.rvalue.value) + i = findfirst(p -> r < p, ps) + if i > 1 + s1.rvalue = RefVal_Immut(ss[i-1].rvalue.value) end - s1.seen_k = n_tot - s1.skip_k += s2.skip_k + s1.seen_k += sum(s.seen_k for s in ss) + s1.skip_k += sum(s.skip_k for s in ss) return s1 end diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index d70f8f8b..e979d8fa 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -1,92 +1,93 @@ const OrdWeighted = BinaryHeap{Tuple{T, Int64, Float64}, Base.Order.By{typeof(last), DataStructures.FasterForward}} where T -@hybrid struct SampleMultiAlgARes{BH,R} <: AbstractReservoirSampler +@hybrid struct MultiAlgAResSampler{BH,R} <: AbstractWeightedReservoirSampler seen_k::Int n::Int const rng::R value::BH end -const SampleMultiOrdAlgARes = Union{SampleMultiAlgARes_Immut{<:OrdWeighted}, SampleMultiAlgARes_Mut{<:OrdWeighted}} +const MultiOrdAlgAResSampler = Union{MultiAlgAResSampler_Immut{<:OrdWeighted}, MultiAlgAResSampler_Mut{<:OrdWeighted}} -@hybrid struct SampleMultiAlgAExpJ{BH,R} <: AbstractReservoirSampler - state::Float64 - min_priority::Float64 +@hybrid struct MultiAlgAExpJSampler{BH,R,F} <: AbstractWeightedReservoirSampler + state::F + min_priority::F seen_k::Int const n::Int const rng::R value::BH end -const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}} +const MultiOrdAlgAExpJSampler = Union{MultiAlgAExpJSampler_Immut{<:OrdWeighted}, MultiAlgAExpJSampler_Mut{<:OrdWeighted}} -@hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractReservoirSampler +@hybrid struct MultiAlgWRSWRSKIPSampler{O,T,R,F} <: AbstractWeightedReservoirSampler const n::Int - state::Float64 - skip_w::Float64 + state::F + skip_w::F seen_k::Int const rng::R - const weights::Vector{Float64} + const weights::Base.Memory{F} const value::Vector{T} const ord::O end 
-const SampleMultiOrdAlgWRSWRSKIP = Union{SampleMultiAlgWRSWRSKIP_Immut{<:Vector}, SampleMultiAlgWRSWRSKIP_Mut{<:Vector}} +const MultiOrdAlgWRSWRSKIPSampler = Union{MultiAlgWRSWRSKIPSampler_Immut{<:Base.Memory}, MultiAlgWRSWRSKIPSampler_Mut{<:Base.Memory}} -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::MutSample, ::Ord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::MutSampler, ::Ord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, F}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value) + return MultiAlgAExpJSampler_Mut(zero(F), zero(F), 0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::MutSample, ::Unord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::MutSampler, ::Unord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, F}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Mut(0.0, 0.0, 0, n, rng, value) + return MultiAlgAExpJSampler_Mut(zero(F), zero(F), 0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Ord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::ImmutSampler, ::Ord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, F}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value) + return MultiAlgAExpJSampler_Immut(zero(F), zero(F), 0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::ImmutSample, ::Unord) where T - 
value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::ImmutSampler, ::Unord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, F}[]) sizehint!(value, n) - return SampleMultiAlgAExpJ_Immut(0.0, 0.0, 0, n, rng, value) + return MultiAlgAExpJSampler_Immut(zero(F), zero(F), 0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgARes, ::MutSample, ::Ord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgARes, ::MutSampler, ::Ord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, F}[]) sizehint!(value, n) - return SampleMultiAlgARes_Mut(0, n, rng, value) + return MultiAlgAResSampler_Mut(0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgARes, ::MutSample, ::Unord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgARes, ::MutSampler, ::Unord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, F}[]) sizehint!(value, n) - return SampleMultiAlgARes_Mut(0, n, rng, value) + return MultiAlgAResSampler_Mut(0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgARes, ::ImmutSample, ::Ord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgARes, ::ImmutSampler, ::Ord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, F}[]) sizehint!(value, n) - return SampleMultiAlgARes_Immut(0, n, rng, value) + return MultiAlgAResSampler_Immut(0, n, rng, value) end -function 
ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgARes, ::ImmutSample, ::Unord) where T - value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, Float64}[]) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgARes, ::ImmutSampler, ::Unord) where {T,F} + value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Pair{T, F}[]) sizehint!(value, n) - return SampleMultiAlgARes_Immut(0, n, rng, value) + return MultiAlgAResSampler_Immut(0, n, rng, value) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Ord) where T - ord = collect(1:n) - return SampleMultiAlgWRSWRSKIP_Mut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSampler, ::Ord) where {T,F} + ord = ordmemory(n) + return MultiAlgWRSWRSKIPSampler_Mut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), ord) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Unord) where T - return SampleMultiAlgWRSWRSKIP_Mut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSampler, ::Unord) where {T,F} + return MultiAlgWRSWRSKIPSampler_Mut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), nothing) end -function ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Ord) where T - ord = collect(1:n) - return SampleMultiAlgWRSWRSKIP_Immut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSampler, ::Ord) where {T,F} + ord = ordmemory(n) + return MultiAlgWRSWRSKIPSampler_Immut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), ord) end -function 
ReservoirSampler{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Unord) where T - return SampleMultiAlgWRSWRSKIP_Immut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) +function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSampler, ::Unord) where {T,F} + return MultiAlgWRSWRSKIPSampler_Immut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), nothing) end -@inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w) +@inline function OnlineStatsBase._fit!(s::Union{MultiAlgAResSampler, MultiOrdAlgAResSampler}, el, w) + w < 0.0 && error(lazy"Passed element $(el) with weight $(w), but weights must be positive.") n = s.n s = @inline update_state!(s, w) priority = -randexp(s.rng)/w @@ -101,11 +102,12 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::SampleMultiAlgAExpJ, el, w) +@inline function OnlineStatsBase._fit!(s::MultiAlgAExpJSampler, el, w) + w < 0.0 && error(lazy"Passed element $(el) with weight $(w), but weights must be positive.") n = s.n s = @inline update_state!(s, w) if s.seen_k <= n - priority = exp(-randexp(s.rng)/w) + priority = -randexp(s.rng)/w @inline push_value!(s, el, priority) if s.seen_k == n s = @inline recompute_skip!(s) @@ -120,41 +122,36 @@ end end return s end -@inline function OnlineStatsBase._fit!(s::SampleMultiAlgWRSWRSKIP, el, w) +@inline function OnlineStatsBase._fit!(s::MultiAlgWRSWRSKIPSampler, el, w) + w < 0.0 && error(lazy"Passed element $(el) with weight $(w), but weights must be positive.") n = s.n s = @inline update_state!(s, w) if s.seen_k <= n @inbounds s.value[s.seen_k] = el @inbounds s.weights[s.seen_k] = w if s.seen_k == n - new_values = sample(s.rng, s.value, Weights(s.weights, s.state), n; - ordered = is_ordered(s)) - @inbounds for i in 1:n - s.value[i] = new_values[i] - end + s.value .= sample(s.rng, s.value, Weights(s.weights, s.state), n; + ordered = is_ordered(s)) s 
= @inline recompute_skip!(s, n) - empty!(s.weights) end return s end if s.skip_w <= s.state p = w/s.state - z = exp((n-4)*log1p(-p)) - c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p), 1.0)) - k = @inline choose(n, p, c, z) + k = @inline choose(s.rng, n, p) @inbounds for j in 1:k - r = rand(s.rng, j:n) + r = @inline rand(s.rng, Random.Sampler(s.rng, j:n, Val(1))) s.value[r], s.value[j] = s.value[j], el update_order_multi!(s, r, j) - end + end s = @inline recompute_skip!(s, n) end return s end -function Base.empty!(s::SampleMultiAlgARes_Mut) +function Base.empty!(s::MultiAlgAResSampler_Mut) s.seen_k = 0 - if s isa SampleMultiAlgWRSWRSKIP_Mut{<:Vector} + if s isa MultiAlgWRSWRSKIPSampler_Mut{<:Vector} s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[]) else s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[]) @@ -162,11 +159,11 @@ function Base.empty!(s::SampleMultiAlgARes_Mut) sizehint!(s.value, s.n) return s end -function Base.empty!(s::SampleMultiAlgAExpJ_Mut) +function Base.empty!(s::MultiAlgAExpJSampler_Mut) s.state = 0.0 s.min_priority = 0.0 s.seen_k = 0 - if s isa SampleMultiAlgWRSWRSKIP_Mut{<:Vector} + if s isa MultiAlgWRSWRSKIPSampler_Mut{<:Vector} s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[]) else s.value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), extract_T(s.value)[]) @@ -174,7 +171,7 @@ function Base.empty!(s::SampleMultiAlgAExpJ_Mut) sizehint!(s.value, s.n) return s end -function Base.empty!(s::SampleMultiAlgWRSWRSKIP_Mut) +function Base.empty!(s::MultiAlgWRSWRSKIPSampler_Mut) s.state = 0.0 s.skip_w = 0.0 s.seen_k = 0 @@ -183,16 +180,60 @@ end extract_T(::DataStructures.BinaryHeap{T}) where T = T -function Base.merge(ss::SampleMultiAlgWRSWRSKIP...) +function Base.merge(ss::MultiAlgAResSampler...) + newvalue = reduce_samples(TypeUnion(), ss...) 
+ newheap = BinaryHeap(Base.By(last, DataStructures.FasterForward()), newvalue) + seen_k = sum(getfield(s, :seen_k) for s in ss) + n = minimum(s.n for s in ss) + s = MultiAlgAResSampler_Mut(seen_k, n, ss[1].rng, newheap) + return s +end +function Base.merge(ss::MultiAlgAExpJSampler...) + newvalue = reduce_samples(TypeUnion(), ss...) + newheap = BinaryHeap(Base.By(last, DataStructures.FasterForward()), newvalue) + seen_k = sum(getfield(s, :seen_k) for s in ss) + state = sum(getfield(s, :state) for s in ss) + min_priority = minimum(getfield(s, :min_priority) for s in ss) + n = minimum(s.n for s in ss) + s = MultiAlgAExpJSampler_Mut(state, min_priority, seen_k, n, ss[1].rng, newheap) + return s +end +function Base.merge(ss::MultiAlgWRSWRSKIPSampler...) newvalue = reduce_samples(TypeUnion(), ss...) skip_w = sum(getfield(s, :skip_w) for s in ss) state = sum(getfield(s, :state) for s in ss) seen_k = sum(getfield(s, :seen_k) for s in ss) - s = SampleMultiAlgWRSWRSKIP_Mut(ss[1].n, state, skip_w, seen_k, ss[1].rng, Float64[], newvalue, nothing) + n = minimum(s.n for s in ss) + s = MultiAlgWRSWRSKIPSampler_Mut(n, state, skip_w, seen_k, ss[1].rng, Memory{Float64}(undef,0), newvalue, nothing) return s end -function Base.merge!(s1::SampleMultiAlgWRSWRSKIP{<:Nothing}, ss::SampleMultiAlgWRSWRSKIP...) +function Base.merge!(s1::MultiAlgAResSampler, ss::MultiAlgAResSampler...) + length(typeof(s1.value.valtree).parameters) == 3 && error("Merging ordered reservoirs is not possible") + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + empty!(s1.value) + newvalue = reduce_samples(TypeS(), s1, ss...) + for e in newvalue + push!(s1.value, e[1] => e[2]) + end + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + return s1 +end +function Base.merge!(s1::MultiAlgAExpJSampler, ss::MultiAlgAExpJSampler...) 
+ length(typeof(s1.value.valtree).parameters) == 3 && error("Merging ordered reservoirs is not possible") + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") + empty!(s1.value) + newvalue = reduce_samples(TypeS(), s1, ss...) + for e in newvalue + push!(s1.value, e[1] => e[2]) + end + s1.seen_k += sum(getfield(s, :seen_k) for s in ss) + s1.state += sum(getfield(s, :state) for s in ss) + s1.min_priority = min(s1.min_priority, minimum(getfield(s, :min_priority) for s in ss)) + return s1 +end +function Base.merge!(s1::MultiAlgWRSWRSKIPSampler{<:Nothing}, ss::MultiAlgWRSWRSKIPSampler...) + s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir") newvalue = reduce_samples(TypeS(), s1, ss...) for i in 1:length(newvalue) @inbounds s1.value[i] = newvalue[i] @@ -200,68 +241,67 @@ function Base.merge!(s1::SampleMultiAlgWRSWRSKIP{<:Nothing}, ss::SampleMultiAlgW s1.skip_w += sum(getfield(s, :skip_w) for s in ss) s1.state += sum(getfield(s, :state) for s in ss) s1.seen_k += sum(getfield(s, :seen_k) for s in ss) - empty!(s1.weights) return s1 end -function update_state!(s::SampleMultiAlgARes, w) +function update_state!(s::MultiAlgAResSampler, w) @update s.seen_k += 1 return s end -function update_state!(s::SampleMultiAlgAExpJ, w) +function update_state!(s::MultiAlgAExpJSampler, w) @update s.seen_k += 1 @update s.state -= w return s end -function update_state!(s::SampleMultiAlgWRSWRSKIP, w) +function update_state!(s::MultiAlgWRSWRSKIPSampler, w) @update s.seen_k += 1 @update s.state += w return s end function compute_skip_priority(s, w) - t = exp(log(s.min_priority)*w) - return exp(log(rand(s.rng, Uniform(t,1)))/w) + t = exp(s.min_priority*w) + return log(rand(s.rng, Uniform(t,1)))/w end -function recompute_skip!(s::SampleMultiAlgAExpJ) +function recompute_skip!(s::MultiAlgAExpJSampler) @update s.min_priority = 
last(first(s.value)) - @update s.state = -randexp(s.rng)/log(s.min_priority) + @update s.state = -randexp(s.rng)/s.min_priority return s end -function recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n) - q = exp(-randexp(s.rng)/n) - @update s.skip_w = s.state/q +function recompute_skip!(s::MultiAlgWRSWRSKIPSampler, n) + q = exp(randexp(s.rng)/n) + @update s.skip_w = s.state*q return s end -function push_value!(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ}, el, priority) +function push_value!(s::Union{MultiAlgAResSampler, MultiAlgAExpJSampler}, el, priority) push!(s.value, el => priority) end -function push_value!(s::Union{SampleMultiOrdAlgARes, SampleMultiOrdAlgAExpJ}, el, priority) +function push_value!(s::Union{MultiOrdAlgAResSampler, MultiOrdAlgAExpJSampler}, el, priority) push!(s.value, (el, s.seen_k, priority)) end -update_order_single!(s::SampleMultiAlgWRSWRSKIP, r) = nothing -function update_order_single!(s::SampleMultiOrdAlgWRSWRSKIP, r) +update_order_single!(s::MultiAlgWRSWRSKIPSampler, r) = nothing +function update_order_single!(s::MultiOrdAlgWRSWRSKIPSampler, r) s.ord[r] = nobs(s) end -update_order_multi!(s::SampleMultiAlgWRSWRSKIP, r, j) = nothing -function update_order_multi!(s::SampleMultiOrdAlgWRSWRSKIP, r, j) +update_order_multi!(s::MultiAlgWRSWRSKIPSampler, r, j) = nothing +function update_order_multi!(s::MultiOrdAlgWRSWRSKIPSampler, r, j) s.ord[r], s.ord[j] = s.ord[j], nobs(s) end -is_ordered(s::SampleMultiOrdAlgWRSWRSKIP) = true -is_ordered(s::SampleMultiAlgWRSWRSKIP) = false +is_ordered(s::MultiOrdAlgWRSWRSKIPSampler) = true +is_ordered(s::MultiAlgWRSWRSKIPSampler) = false -function OnlineStatsBase.value(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ}) +function OnlineStatsBase.value(s::Union{MultiAlgAResSampler, MultiAlgAExpJSampler}) if nobs(s) < s.n return first.(s.value.valtree[1:nobs(s)]) else return first.(s.value.valtree) end end -function OnlineStatsBase.value(s::SampleMultiAlgWRSWRSKIP) +function 
OnlineStatsBase.value(s::MultiAlgWRSWRSKIPSampler) if nobs(s) < length(s.value) return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value)) else @@ -269,7 +309,7 @@ function OnlineStatsBase.value(s::SampleMultiAlgWRSWRSKIP) end end -function ordvalue(s::Union{SampleMultiOrdAlgARes, SampleMultiOrdAlgAExpJ}) +function ordvalue(s::Union{MultiOrdAlgAResSampler, MultiOrdAlgAExpJSampler}) if nobs(s) < length(s.value) vals = s.value.valtree[1:nobs(s)] else @@ -277,7 +317,7 @@ function ordvalue(s::Union{SampleMultiOrdAlgARes, SampleMultiOrdAlgAExpJ}) end return first.(vals[sortperm(map(x -> x[2], vals))]) end -function ordvalue(s::SampleMultiOrdAlgWRSWRSKIP) +function ordvalue(s::MultiOrdAlgWRSWRSKIPSampler) if nobs(s) < length(s.value) return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value); ordered=true) else diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl index 01c3480b..45799606 100644 --- a/src/WeightedSamplingSingle.jl +++ b/src/WeightedSamplingSingle.jl @@ -1,25 +1,25 @@ -@hybrid struct SampleSingleAlgWRSWRSKIP{RT,R} <: AbstractReservoirSampler +@hybrid struct SingleAlgWRSWRSKIPSampler{RT,R,F} <: AbstractWeightedReservoirSampler seen_k::Int - total_w::Float64 - skip_w::Float64 + total_w::F + skip_w::F const rng::R rvalue::RT end -function ReservoirSampler{T}(rng::AbstractRNG, ::AlgWRSWRSKIP, ::MutSample) where T - return SampleSingleAlgWRSWRSKIP_Mut(0, 0.0, 0.0, rng, RefVal_Immut{T}()) +function ReservoirSampler{T,F}(rng::AbstractRNG, ::AlgWRSWRSKIP, ::MutSampler) where {T,F} + return SingleAlgWRSWRSKIPSampler_Mut(0, zero(F), zero(F), rng, RefVal_Immut{T}()) end -function ReservoirSampler{T}(rng::AbstractRNG, ::AlgWRSWRSKIP, ::ImmutSample) where T - return SampleSingleAlgWRSWRSKIP_Immut(0, 0.0, 0.0, rng, RefVal_Mut{T}()) +function ReservoirSampler{T,F}(rng::AbstractRNG, ::AlgWRSWRSKIP, ::ImmutSampler) where {T,F} + return 
SingleAlgWRSWRSKIPSampler_Immut(0, zero(F), zero(F), rng, RefVal_Mut{T}()) end -function OnlineStatsBase.value(s::SampleSingleAlgWRSWRSKIP) +function OnlineStatsBase.value(s::SingleAlgWRSWRSKIPSampler) s.seen_k === 0 && return nothing return get_value(s) end -@inline function OnlineStatsBase._fit!(s::SampleSingleAlgWRSWRSKIP, el, w) +@inline function OnlineStatsBase._fit!(s::SingleAlgWRSWRSKIPSampler, el, w) @update s.seen_k += 1 @update s.total_w += w if s.skip_w <= s.total_w @@ -29,18 +29,44 @@ end return s end -function Base.empty!(s::SampleSingleAlgWRSWRSKIP_Mut) +function Base.empty!(s::SingleAlgWRSWRSKIPSampler_Mut) s.seen_k = 0 s.total_w = 0.0 s.skip_w = 0.0 return s end -get_value(s::SampleSingleAlgWRSWRSKIP) = s.rvalue.value +get_value(s::SingleAlgWRSWRSKIPSampler) = s.rvalue.value -function reset_value!(s::SampleSingleAlgWRSWRSKIP_Mut, el) +function reset_value!(s::SingleAlgWRSWRSKIPSampler_Mut, el) s.rvalue = RefVal_Immut(el) end -function reset_value!(s::SampleSingleAlgWRSWRSKIP_Immut, el) +function reset_value!(s::SingleAlgWRSWRSKIPSampler_Immut, el) s.rvalue.value = el end + +function Base.merge(ss::SingleAlgWRSWRSKIPSampler...) + ns = [s.total_w for s in ss] + n_tot = sum(ns) + ps = cumsum(ns ./ n_tot) + r = rand(s1.rng) + value = ss[findfirst(p -> r < p, ps)].value + return typeof(s1)(sum(s.seen_k for s in ss), sum(s.total_w for s in ss), sum(s.skip_w for s in ss), + ss[1].rng, value) +end + +function Base.merge!(s1::SingleAlgWRSWRSKIPSampler_Mut, ss::SingleAlgWRSWRSKIPSampler_Mut...) + ns = [s1.total_w, [s.total_w for s in ss]...] 
+ n_tot = sum(ns) + ps = cumsum(ns ./ n_tot) + r = rand(s1.rng) + i = findfirst(p -> r < p, ps) + if i > 1 + s1.rvalue = RefVal_Immut(ss[i-1].rvalue.value) + end + s1.seen_k += sum(s.seen_k for s in ss) + s1.skip_w += sum(s.skip_w for s in ss) + s1.total_w += sum(s.total_w for s in ss) + return s1 +end + diff --git a/src/precompile.jl b/src/precompile.jl index 1a700b66..b2a2e63d 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -3,25 +3,22 @@ using PrecompileTools @setup_workload let iter = Iterators.filter(x -> x != 10, 1:20); - wv(el) = 1.0 - update_s!(rs, iter) = for x in iter fit!(rs, x) end - update_s!(rs, iter, wv) = for x in iter fit!(rs, x, wv(x)) end @compile_workload let - rs = ReservoirSampler{Int}(AlgRSWRSKIP()) - update_s!(rs, iter) - rs = ReservoirSampler{Int}(AlgWRSWRSKIP()) - update_s!(rs, iter, wv) - rs = ReservoirSampler{Int}(2, AlgR()) - update_s!(rs, iter) - rs = ReservoirSampler{Int}(2, AlgL()) - update_s!(rs, iter) - rs = ReservoirSampler{Int}(2, AlgRSWRSKIP()) - update_s!(rs, iter) - rs = ReservoirSampler{Int}(2, AlgARes()) - update_s!(rs, iter, wv) - rs = ReservoirSampler{Int}(2, AlgAExpJ()) - update_s!(rs, iter, wv) - rs = ReservoirSampler{Int}(2, AlgWRSWRSKIP()) - update_s!(rs, iter, wv) + for alg in (AlgRSWRSKIP(),) + rs = ReservoirSampler{Int}(alg) + for x in iter fit!(rs, x) end + end + for alg in (AlgR(), AlgL(), AlgRSWRSKIP()) + rs = ReservoirSampler{Int}(2, alg) + for x in iter fit!(rs, x) end + end + for alg in (AlgWRSWRSKIP(),) + rs = ReservoirSampler{Int}(alg) + for x in iter fit!(rs, x, 1.0) end + end + for alg in (AlgARes(), AlgAExpJ(), AlgWRSWRSKIP()) + rs = ReservoirSampler{Int}(2, alg) + for x in iter fit!(rs, x, 1.0) end + end end end diff --git a/test/merge_tests.jl b/test/merge_tests.jl index 86ec0634..d2fc6371 100644 --- a/test/merge_tests.jl +++ b/test/merge_tests.jl @@ -4,7 +4,10 @@ iters = (1:2, 3:10) reps = 10^5 size = 2 - for (m1, m2) in [(AlgRSWRSKIP(), AlgRSWRSKIP())] + for (m1, m2) in [(AlgRSWRSKIP(), 
AlgRSWRSKIP()), + (AlgWRSWRSKIP(), AlgWRSWRSKIP()), + (AlgARes(), AlgARes()), + (AlgAExpJ(), AlgAExpJ())] res = zeros(Int, 10, 10) for _ in 1:reps s1 = ReservoirSampler{Int}(rng, size, m1) @@ -12,15 +15,15 @@ s_all = (s1, s2) for (s, it) in zip(s_all, iters) for x in it - fit!(s, x) + m1 == AlgRSWRSKIP() ? fit!(s, x) : fit!(s, x, 1.0) end end s_merged = merge(s1, s2) res[shuffle!(rng, value(s_merged))...] += 1 end - cases = m1 == AlgRSWRSKIP() ? 10^size : factorial(10)/factorial(10-size) + cases = (m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP()) ? 10^size : factorial(10)/factorial(10-size) ps_exact = [1/cases for _ in 1:cases] - count_est = vec(res) + count_est = [x for x in vec(res) if x != 0] chisq_test = ChisqTest(count_est, ps_exact) @test pvalue(chisq_test) > 0.05 end @@ -33,9 +36,11 @@ end end @test length(value(merge!(s1, s2))) == 2 - s1 = ReservoirSampler{Int}(rng, AlgRSWRSKIP()) - s2 = ReservoirSampler{Int}(rng, AlgRSWRSKIP()) - fit!(s1, 1) - fit!(s2, 2) - @test value(merge!(s1, s2)) in (1, 2) + for m in (AlgRSWRSKIP(), AlgWRSWRSKIP()) + s1 = ReservoirSampler{Int}(rng, m) + s2 = ReservoirSampler{Int}(rng, m) + m == AlgRSWRSKIP() ? fit!(s1, 1) : fit!(s1, 1, 1.0) + m == AlgRSWRSKIP() ? 
fit!(s2, 2) : fit!(s2, 2, 1.0) + @test value(merge!(s1, s2)) in (1, 2) + end end diff --git a/test/unweighted_sampling_multi_tests.jl b/test/unweighted_sampling_multi_tests.jl index 07b6da44..0ea47f1c 100644 --- a/test/unweighted_sampling_multi_tests.jl +++ b/test/unweighted_sampling_multi_tests.jl @@ -49,8 +49,8 @@ @test all(x -> a <= x <= b, value(rs)) @test nobs(rs) == 10 - rngs = (StableRNG(46), StableRNG(47)) - iters = (a:b, Iterators.filter(x -> x != b + 1, a:b+1), (a:floor(Int, b/2), (floor(Int, b/2)+1):b)) + rngs = (StableRNG(47), StableRNG(48)) + iters = (a:b, Iterators.filter(x -> x != b + 1, a:b+1)) sizes = (2, 3) for it in iters for size in sizes diff --git a/test/weighted_sampling_multi_tests.jl b/test/weighted_sampling_multi_tests.jl index a3e6162a..a95f1ff5 100644 --- a/test/weighted_sampling_multi_tests.jl +++ b/test/weighted_sampling_multi_tests.jl @@ -79,7 +79,7 @@ end weight3(el) = el <= 5 ? 1.0 : 2.0 wfuncs = (weight2, weight3) rngs = (StableRNG(41), StableRNG(42)) - iters = (a:b, Iterators.filter(x -> x != b+1, a:b+1), (a:floor(Int, b/2), (floor(Int, b/2)+1):b)) + iters = (a:b, Iterators.filter(x -> x != b+1, a:b+1)) sizes = (1, 2) for it in iters for size in sizes From 64cd409223031bd1aabdb4105db0db11e8f8c87a Mon Sep 17 00:00:00 2001 From: tortar Date: Wed, 13 Aug 2025 02:03:01 +0200 Subject: [PATCH 2/5] fix --- Project.toml | 2 +- src/SamplingUtils.jl | 2 +- src/StreamSampling.jl | 2 ++ src/UnweightedSamplingMulti.jl | 6 +++--- src/WeightedSamplingMulti.jl | 12 ++++++------ 5 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 418c66e5..4f891207 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] -julia = "1.10" +julia = "1.8" Accessors = "0.1" DataStructures = "0.18" Distributions = "0.25" diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 40b2622f..4b4c08fc 100644 --- 
a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -156,7 +156,7 @@ function fshuffle!(rng::AbstractRNG, vec::AbstractVector) end function ordmemory(n) - ord = Base.Memory{Int}(undef, n) + ord = Memory{Int}(undef, n) for i in eachindex(ord) ord[i] = i end ord end \ No newline at end of file diff --git a/src/StreamSampling.jl b/src/StreamSampling.jl index a7ac6b89..38760565 100644 --- a/src/StreamSampling.jl +++ b/src/StreamSampling.jl @@ -7,6 +7,8 @@ module StreamSampling read(path, String) end StreamSampling +isdefined(@__MODULE__, :Memory) || const Memory = Vector # Compat for Julia < 1.11 + using Accessors using DataStructures using Distributions diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index adc8d1bc..29979477 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -6,7 +6,7 @@ const value::Vector{T} const ord::O end -const MultiOrdAlgRSampler = MultiAlgRSampler{<:Base.Memory} +const MultiOrdAlgRSampler = MultiAlgRSampler{<:Memory} @hybrid struct MultiAlgLSampler{O,T,R,F} <: AbstractReservoirSampler const n::Int @@ -17,7 +17,7 @@ const MultiOrdAlgRSampler = MultiAlgRSampler{<:Base.Memory} const value::Vector{T} const ord::O end -const MultiOrdAlgLSampler = MultiAlgLSampler{<:Base.Memory} +const MultiOrdAlgLSampler = MultiAlgLSampler{<:Memory} @hybrid struct MultiAlgRSWRSKIPSampler{O,T,R} <: AbstractReservoirSampler const n::Int @@ -27,7 +27,7 @@ const MultiOrdAlgLSampler = MultiAlgLSampler{<:Base.Memory} const value::Vector{T} const ord::O end -const MultiOrdAlgRSWRSKIPSampler = MultiAlgRSWRSKIPSampler{<:Base.Memory} +const MultiOrdAlgRSWRSKIPSampler = MultiAlgRSWRSKIPSampler{<:Memory} function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSampler, ::Ord) where {T,F} return MultiAlgLSampler_Mut(n, zero(F), 0, 0, rng, Vector{T}(undef, n), ordmemory(n)) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index e979d8fa..a1360302 100644 --- 
a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -25,11 +25,11 @@ const MultiOrdAlgAExpJSampler = Union{MultiAlgAExpJSampler_Immut{<:OrdWeighted}, skip_w::F seen_k::Int const rng::R - const weights::Base.Memory{F} + const weights::Memory{F} const value::Vector{T} const ord::O end -const MultiOrdAlgWRSWRSKIPSampler = Union{MultiAlgWRSWRSKIPSampler_Immut{<:Base.Memory}, MultiAlgWRSWRSKIPSampler_Mut{<:Base.Memory}} +const MultiOrdAlgWRSWRSKIPSampler = Union{MultiAlgWRSWRSKIPSampler_Immut{<:Memory}, MultiAlgWRSWRSKIPSampler_Mut{<:Memory}} function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgAExpJ, ::MutSampler, ::Ord) where {T,F} value = BinaryHeap(Base.By(last, DataStructures.FasterForward()), Tuple{T, Int, F}[]) @@ -73,17 +73,17 @@ function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgARes, ::ImmutS end function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSampler, ::Ord) where {T,F} ord = ordmemory(n) - return MultiAlgWRSWRSKIPSampler_Mut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), ord) + return MultiAlgWRSWRSKIPSampler_Mut(n, zero(F), zero(F), 0, rng, Memory{F}(undef, n), Vector{T}(undef, n), ord) end function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSampler, ::Unord) where {T,F} - return MultiAlgWRSWRSKIPSampler_Mut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), nothing) + return MultiAlgWRSWRSKIPSampler_Mut(n, zero(F), zero(F), 0, rng, Memory{F}(undef, n), Vector{T}(undef, n), nothing) end function ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSampler, ::Ord) where {T,F} ord = ordmemory(n) - return MultiAlgWRSWRSKIPSampler_Immut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), ord) + return MultiAlgWRSWRSKIPSampler_Immut(n, zero(F), zero(F), 0, rng, Memory{F}(undef, n), Vector{T}(undef, n), ord) end function 
ReservoirSampler{T,F}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSampler, ::Unord) where {T,F} - return MultiAlgWRSWRSKIPSampler_Immut(n, zero(F), zero(F), 0, rng, Base.Memory{F}(undef, n), Vector{T}(undef, n), nothing) + return MultiAlgWRSWRSKIPSampler_Immut(n, zero(F), zero(F), 0, rng, Memory{F}(undef, n), Vector{T}(undef, n), nothing) end @inline function OnlineStatsBase._fit!(s::Union{MultiAlgAResSampler, MultiOrdAlgAResSampler}, el, w) From a642408204f472f1fcfa6a3f94207f1ab94ae1d0 Mon Sep 17 00:00:00 2001 From: tortar Date: Wed, 13 Aug 2025 02:07:21 +0200 Subject: [PATCH 3/5] fix2 --- docs/Project.toml | 1 + docs/make.jl | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 82df54a4..bb9c3d31 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,7 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" StreamSampling = "ff63dad9-3335-55d8-95ec-f8139d39e468" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/docs/make.jl b/docs/make.jl index fac7a66b..4fea4301 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,13 +1,16 @@ -using Documenter -using BenchmarkTools +@info "Loading packages..." 
using StreamSampling -println("Documentation Build") +using BenchmarkTools +using Documenter +using Literate + +@info "Building Documentation" makedocs( - modules = [StreamSampling], - sitename = "StreamSampling.jl", - pages = [ + sitename = "BeforeIT.jl", + format = Documenter.HTML(prettyurls = false, size_threshold = 409600), + pages = [ "StreamSampling.jl" => "index.md", "Basics" => "basics.md", "An Illustrative Example" => "example.md", @@ -15,7 +18,6 @@ makedocs( "Performance Tips" => "perf_tips.md", "Benchmarks" => "benchmark.md" ], - warnonly = [:doctest, :missing_docs, :cross_references], ) @info "Deploying Documentation" @@ -28,4 +30,4 @@ if CI devbranch = "main", ) end -println("Finished boulding and deploying docs.") +println("Finished building and deploying docs.") \ No newline at end of file From 9e63a0a76118de5bea9a3c99326760224a273043 Mon Sep 17 00:00:00 2001 From: tortar Date: Wed, 13 Aug 2025 02:07:45 +0200 Subject: [PATCH 4/5] fix3 --- docs/make.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/make.jl b/docs/make.jl index 4fea4301..0bae01e0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ using Literate @info "Building Documentation" makedocs( - sitename = "BeforeIT.jl", + sitename = "StreamSampling.jl", format = Documenter.HTML(prettyurls = false, size_threshold = 409600), pages = [ "StreamSampling.jl" => "index.md", From f5c9a234ae947d496affa8c1020e58bf182be11c Mon Sep 17 00:00:00 2001 From: tortar Date: Wed, 13 Aug 2025 02:12:41 +0200 Subject: [PATCH 5/5] fix4 --- docs/make.jl | 1 + docs/src/perf_tips.md | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 0bae01e0..b53cf71f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -18,6 +18,7 @@ makedocs( "Performance Tips" => "perf_tips.md", "Benchmarks" => "benchmark.md" ], + warnonly = [:doctest, :missing_docs, :cross_references], ) @info "Deploying Documentation" diff --git a/docs/src/perf_tips.md 
b/docs/src/perf_tips.md index dc2d33dd..8a97ed70 100644 --- a/docs/src/perf_tips.md +++ b/docs/src/perf_tips.md @@ -4,7 +4,10 @@ By default, a `ReservoirSampler` is mutable, however, it is also possible to use an immutable version which supports all the basic operations. It uses `Accessors.jl` under the -hood to update the reservoir: +hood to update the reservoir. + +Let's compare the performance of mutable and immutable samplers +with a simple benchmark ```julia using BenchmarkTools @@ -16,9 +19,16 @@ function fit_iter!(rs, iter) return rs end -iter = 1:10^7 +iter = 1:10^7; +``` + +Running with both version we get +```julia @btime fit_iter!(rs, $iter) setup=(rs = ReservoirSampler{Int}(10, AlgRSWRSKIP(); mutable = true)) +``` + +```julia @btime fit_iter!(rs, $iter) setup=(rs = ReservoirSampler{Int}(10, AlgRSWRSKIP(); mutable = false)) ```