From 40d40ce22b7f38048976eb809f7edb3b868be1ca Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 02:09:02 +0200 Subject: [PATCH 1/8] Improve reservoir with replacement algorithms --- src/WeightedSamplingMulti.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index c1bbf31..817267d 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -130,8 +130,17 @@ end @inbounds s.value[s.seen_k] = el @inbounds s.weights[s.seen_k] = w if s.seen_k == n - s.value .= sample(s.rng, s.value, Weights(s.weights, s.state), n; - ordered = is_ordered(s)) + randexps = randexp(s.rng, n) + ratio = s.state/(sum(randexps) + randexp(s.rng)) + j, csweights, limit = 1, first(s.weights), 0.0 + for i in eachindex(s.value, s.weights, randexps) + limit += randexps[i] * ratio + while csweights < limit + j += 1 + csweights += s.weights[j] + end + s.value[i] = j + end s = @inline recompute_skip!(s, n) end return s From df2ec3830288463bc3bb6fcc62f707ae143ccb80 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 02:14:56 +0200 Subject: [PATCH 2/8] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 691d52a..3fa8ccb 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -105,7 +105,9 @@ end @inbounds s.value[s.seen_k] = el if s.seen_k === n s = @inline recompute_skip!(s, n) - s.value .= sample(s.rng, s.value, n, ordered=is_ordered(s)) + for (i, j) in enumerate(SeqIterWRSampler(s.rng, n, n)) + @inbounds s.value[i] = j + end end return s end @@ -279,3 +281,4 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) end + From c2f2f611cc8c2b9b1bd48060cc86077ff1ab296c Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 02:15:31 +0200 Subject: [PATCH 3/8] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 3fa8ccb..9f0473d 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -279,6 +279,3 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) return s.value[sortperm(s.ord)] end end - - - From 83e4dd256b488dfba37fb54656aa9f5f91695161 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 02:22:24 +0200 Subject: [PATCH 4/8] Update SamplingUtils.jl --- src/SamplingUtils.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 4b4c08f..61b0b05 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -19,11 +19,7 @@ struct SeqIterWRSampler{R} n::Int end -@inline function Base.iterate(s::SeqIterWRSampler) - curmax = -log(Float64(s.N)) + randexp(s.rng)/s.n - return (s.N - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax)) -end -@inline function Base.iterate(s::SeqIterWRSampler, state) +@inline function Base.iterate(s::SeqIterWRSampler, state = (s.n, -log(Float64(s.N)))) state[1] == 0 && return nothing curmax = state[2] + randexp(s.rng)/state[1] return (s.N - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax)) @@ -159,4 +155,4 @@ function ordmemory(n) ord = Memory{Int}(undef, n) for i in eachindex(ord) ord[i] = i end ord -end \ No newline at end of file +end From 7e7c214f61fb1f026c1b7359750877dd42e29fa5 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 02:25:27 +0200 Subject: [PATCH 5/8] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 9f0473d..259b442 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -250,7 +250,7 @@ function OnlineStatsBase.value(s::Union{MultiAlgRSampler, MultiAlgLSampler}) end end function OnlineStatsBase.value(s::MultiAlgRSWRSKIPSampler) - if nobs(s) < length(s.value) + if nobs(s) <= length(s.value) if nobs(s) == 0 return s.value[1:0] else @@ -279,3 +279,4 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) return s.value[sortperm(s.ord)] end end + From 663ac0fb987271cc664a9a95f337cf355ce89c77 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 02:26:27 +0200 Subject: [PATCH 6/8] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 259b442..a0e47b2 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -250,7 +250,7 @@ function OnlineStatsBase.value(s::Union{MultiAlgRSampler, MultiAlgLSampler}) end end function OnlineStatsBase.value(s::MultiAlgRSWRSKIPSampler) - if nobs(s) <= length(s.value) + if nobs(s) < length(s.value) if nobs(s) == 0 return s.value[1:0] else @@ -280,3 +280,4 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) end end + From 9096e7cdacabe6f1ec907ae1a991e5fddd34db0b Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 16:16:12 +0200 Subject: [PATCH 7/8] fix --- src/WeightedSamplingMulti.jl | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index 817267d..a711998 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -128,19 +128,18 @@ end s = @inline update_state!(s, w) if s.seen_k <= n @inbounds s.value[s.seen_k] = el - @inbounds s.weights[s.seen_k] = w + @inbounds s.weights[s.seen_k] = s.state if s.seen_k == n - randexps = randexp(s.rng, n) - ratio = s.state/(sum(randexps) + randexp(s.rng)) - j, csweights, limit = 1, first(s.weights), 0.0 - for i in eachindex(s.value, s.weights, randexps) - limit += randexps[i] * ratio - while csweights < limit + j, curx = 1, 0.0 + newvalues = similar(s.value) + @inbounds for i in n:-1:1 + curx += (1-exp(-randexp(s.rng)/i))*(1-curx) + while s.weights[j] < curx * s.state j += 1 - csweights += s.weights[j] end - s.value[i] = j + newvalues[i] = s.value[j] end + s.value .= newvalues s = @inline recompute_skip!(s, n) end return s @@ -311,8 +310,12 @@ function OnlineStatsBase.value(s::Union{MultiAlgAResSampler, MultiAlgAExpJSample end end function OnlineStatsBase.value(s::MultiAlgWRSWRSKIPSampler) + nobs(s) == 0 && return s.value[1:0] if nobs(s) < length(s.value) - return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value)) + weightsnew = Vector{Float64}(undef, nobs(s)) + weightsnew[1] = s.weights[1] + for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end + return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value)) else return s.value end @@ -327,8 +330,12 @@ function ordvalue(s::Union{MultiOrdAlgAResSampler, MultiOrdAlgAExpJSampler}) return first.(vals[sortperm(map(x -> x[2], vals))]) end function ordvalue(s::MultiOrdAlgWRSWRSKIPSampler) + nobs(s) == 0 && return s.value[1:0] if nobs(s) < length(s.value) - return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value); ordered=true) + weightsnew = Vector{Float64}(undef, nobs(s)) + weightsnew[1] = s.weights[1] + for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end + return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value); ordered=true) else return s.value[sortperm(s.ord)] end From 50c2ee3ae2f6df53438a67c08075de3836b1c6a4 Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sun, 17 Aug 2025 20:00:40 +0200 Subject: [PATCH 8/8] Update UnweightedSamplingMulti.jl --- src/UnweightedSamplingMulti.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index a0e47b2..aa3c173 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -105,9 +105,7 @@ end @inbounds s.value[s.seen_k] = el if s.seen_k === n s = @inline recompute_skip!(s, n) - for (i, j) in enumerate(SeqIterWRSampler(s.rng, n, n)) - @inbounds s.value[i] = j - end + s.value .= sample(s.rng, s.value, n, ordered=is_ordered(s)) end return s end @@ -280,4 +278,3 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) end end -