Skip to content

Commit

Permalink
Add ordered version for weighted sampling with replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 17, 2024
1 parent affed9c commit c41ae92
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 53 deletions.
2 changes: 1 addition & 1 deletion src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ abstract type AbstractOrdWrReservoirSampleMulti <: AbstractWrReservoirSampleMult
abstract type AbstractWeightedReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedWorReservoirSampleMulti <: AbstractReservoirSample end
abstract type AbstractWeightedWrReservoirSampleMulti <: AbstractReservoirSample end

# Abstract subtype of `AbstractReservoirSample`.
# NOTE(review): judging by the name this is presumably meant for ordered, weighted,
# with-replacement multi-element samplers; no concrete type visible in this change
# subtypes it — confirm whether `SampleMultiOrdAlgWRSWRSKIP` was meant to.
abstract type AbstractWeightedOrdWrReservoirSampleMulti <: AbstractReservoirSample end

abstract type ReservoirAlgorithm end

Expand Down
52 changes: 24 additions & 28 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@

mutable struct ResSampleMultiAlgR{T,R} <: AbstractWorReservoirSampleMulti
state::Int
mutable struct SampleMultiAlgR{T,R} <: AbstractWorReservoirSampleMulti
seen_k::Int
rng::R
value::Vector{T}
end

mutable struct OrdResSampleMultiAlgR{T,R} <: AbstractOrdWorReservoirSampleMulti
state::Int
mutable struct SampleMultiOrdAlgR{T,R} <: AbstractOrdWorReservoirSampleMulti
seen_k::Int
rng::R
value::Vector{T}
ord::Vector{Int}
end

mutable struct ResSampleMultiAlgL{T,R} <: AbstractWorReservoirSampleMulti
mutable struct SampleMultiAlgL{T,R} <: AbstractWorReservoirSampleMulti
state::Float64
skip_k::Int
seen_k::Int
rng::R
value::Vector{T}
end

mutable struct OrdResSampleMultiAlgL{T,R} <: AbstractOrdWorReservoirSampleMulti
mutable struct SampleMultiOrdAlgL{T,R} <: AbstractOrdWorReservoirSampleMulti
state::Float64
skip_k::Int
seen_k::Int
Expand All @@ -29,14 +29,14 @@ mutable struct OrdResSampleMultiAlgL{T,R} <: AbstractOrdWorReservoirSampleMulti
ord::Vector{Int}
end

mutable struct WrResSampleMulti{T,R} <: AbstractWrReservoirSampleMulti
mutable struct SampleMultiAlgRSWRSKIP{T,R} <: AbstractWrReservoirSampleMulti
skip_k::Int
seen_k::Int
rng::R
value::Vector{T}
end

mutable struct OrdWrResSampleMulti{T,R} <: AbstractOrdWrReservoirSampleMulti
mutable struct SampleMultiOrdAlgRSWRSKIP{T,R} <: AbstractOrdWrReservoirSampleMulti
skip_k::Int
seen_k::Int
rng::R
Expand All @@ -50,43 +50,43 @@ end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgL=algL; ordered = false)
value = Vector{T}(undef, n)
if ordered
return OrdResSampleMultiAlgL(0.0, 0, 0, rng, value, collect(1:n))
return SampleMultiOrdAlgL(0.0, 0, 0, rng, value, collect(1:n))
else
return ResSampleMultiAlgL(0.0, 0, 0, rng, value)
return SampleMultiAlgL(0.0, 0, 0, rng, value)
end
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgR; ordered = false)
value = Vector{T}(undef, n)
if ordered
return OrdResSampleMultiAlgR(0, rng, value, collect(1:n))
return SampleMultiOrdAlgR(0, rng, value, collect(1:n))
else
return ResSampleMultiAlgR(0, rng, value)
return SampleMultiAlgR(0, rng, value)
end
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgRSWRSKIP; ordered = false)
value = Vector{T}(undef, n)
if ordered
return OrdWrResSampleMulti(0, 0, rng, value, collect(1:n))
return SampleMultiOrdAlgRSWRSKIP(0, 0, rng, value, collect(1:n))
else
return WrResSampleMulti(0, 0, rng, value)
return SampleMultiAlgRSWRSKIP(0, 0, rng, value)
end
end

function update!(s::Union{ResSampleMultiAlgR, OrdResSampleMultiAlgR}, el)
function update!(s::Union{SampleMultiAlgR, SampleMultiOrdAlgR}, el)
n = length(s.value)
s.state += 1
if s.state <= n
s.value[s.state] = el
s.seen_k += 1
if s.seen_k <= n
s.value[s.seen_k] = el
else
j = rand(s.rng, 1:s.state)
j = rand(s.rng, 1:s.seen_k)
if j <= n
s.value[j] = el
update_order!(s, j)
end
end
return s
end
function update!(s::Union{ResSampleMultiAlgL, OrdResSampleMultiAlgL}, el)
function update!(s::Union{SampleMultiAlgL, SampleMultiOrdAlgL}, el)
n = length(s.value)
s.seen_k += 1
s.skip_k -= 1
Expand Down Expand Up @@ -162,13 +162,13 @@ function update_order!(s::AbstractOrdWorReservoirSampleMulti, j)
s.ord[j] = n_seen(s)
end

update_order_single!(s::WrResSampleMulti, r) = nothing
function update_order_single!(s::OrdWrResSampleMulti, r)
update_order_single!(s::SampleMultiAlgRSWRSKIP, r) = nothing
function update_order_single!(s::SampleMultiOrdAlgRSWRSKIP, r)
s.ord[r] = n_seen(s)
end

update_order_multi!(s::WrResSampleMulti, r, j) = nothing
function update_order_multi!(s::OrdWrResSampleMulti, r, j)
update_order_multi!(s::SampleMultiAlgRSWRSKIP, r, j) = nothing
function update_order_multi!(s::SampleMultiOrdAlgRSWRSKIP, r, j)
s.ord[r], s.ord[j] = s.ord[j], n_seen(s)
end

Expand Down Expand Up @@ -205,10 +205,6 @@ function ordered_value(s::AbstractOrdWrReservoirSampleMulti)
end
end

n_seen(s::Union{ResSampleMultiAlgR, OrdResSampleMultiAlgR}) = s.state
n_seen(s::Union{ResSampleMultiAlgL, OrdResSampleMultiAlgL}) = s.seen_k
n_seen(s::Union{OrdWrResSampleMulti, WrResSampleMulti}) = s.seen_k

# Convenience method: draws the default random number generator and forwards every
# argument, including the `ordered` keyword, to the `itsample` method that takes an
# explicit RNG as its first argument.
function itsample(iter, n::Int, method::ReservoirAlgorithm = algL; ordered = false)
    rng = Random.default_rng()
    return itsample(rng, iter, n, method; ordered = ordered)
end
Expand Down
20 changes: 10 additions & 10 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@

mutable struct ResSampleSingleAlgL{T,R} <: AbstractReservoirSample
mutable struct SampleSingleAlgL{T,R} <: AbstractReservoirSample
state::Float64
skip_k::Int
rng::R
value::T
ResSampleSingleAlgL{T,R}(state, skip_k, rng) where {T,R} = new{T,R}(state, skip_k, rng)
SampleSingleAlgL{T,R}(state, skip_k, rng) where {T,R} = new{T,R}(state, skip_k, rng)
end

mutable struct ResSampleSingleAlgR{T,R} <: AbstractReservoirSample
mutable struct SampleSingleAlgR{T,R} <: AbstractReservoirSample
state::Int
rng::R
value::T
ResSampleSingleAlgR{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
SampleSingleAlgR{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
end

function value(s::ResSampleSingleAlgL)
function value(s::SampleSingleAlgL)
s.state === 1.0 && return nothing
return s.value
end
function value(s::ResSampleSingleAlgR)
function value(s::SampleSingleAlgR)
s.state === 0 && return nothing
return s.value
end
Expand All @@ -27,21 +27,21 @@ function ReservoirSample(T, method::ReservoirAlgorithm = algL)
return ReservoirSample(Random.default_rng(), T, method)
end
function ReservoirSample(rng::AbstractRNG, T, method::AlgL = algL)
return ResSampleSingleAlgL{T, typeof(rng)}(1.0, 0, rng)
return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, rng)
end
function ReservoirSample(rng::AbstractRNG, T, method::AlgR)
return ResSampleSingleAlgR{T, typeof(rng)}(0, rng)
return SampleSingleAlgR{T, typeof(rng)}(0, rng)
end

function update!(s::ResSampleSingleAlgR, el)
function update!(s::SampleSingleAlgR, el)
s.state += 1
if rand(s.rng) <= 1/s.state
s.value = el
end
return s
end

function update!(s::ResSampleSingleAlgL, el)
function update!(s::SampleSingleAlgL, el)
if s.skip_k > 0
s.skip_k -= 1
else
Expand Down
46 changes: 37 additions & 9 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ mutable struct SampleMultiAlgWRSWRSKIP{T,R} <: AbstractWeightedWrReservoirSample
value::Vector{T}
end

# Sampler state for the weighted, with-replacement skip algorithm (WRSWR-SKIP) in its
# order-tracking variant: same fields as `SampleMultiAlgWRSWRSKIP` plus `ord`.
# NOTE(review): this subtypes `AbstractWeightedWrReservoirSampleMulti`, not
# `AbstractWeightedOrdWrReservoirSampleMulti`; the `value` method dispatches on the
# former, so confirm which supertype is intended before changing it.
mutable struct SampleMultiOrdAlgWRSWRSKIP{T,R} <: AbstractWeightedWrReservoirSampleMulti
# Running sum of the weights passed to `update!`.
state::Float64
# Skip threshold; `recompute_skip!` sets it to `state / q` with `q = rand(rng)^(1/n)`.
skip_w::Float64
# Number of elements passed to `update!` so far.
seen_k::Int
# Random number generator used for every random draw.
rng::R
# Weights of the first `n` elements; emptied once the reservoir has been filled.
weights::Vector{Float64}
# The reservoir itself.
value::Vector{T}
# Per-slot stamp: initialised to 1:n by the constructor, then overwritten with the
# current `n_seen(s)` whenever a slot is replaced; `ordered_value` sorts by it.
ord::Vector{Int}
end

function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgAExpJ; ordered = false)
value = BinaryHeap(Base.By(last), Pair{T, Float64}[])
sizehint!(value, n)
Expand All @@ -46,7 +56,8 @@ function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgWRSWRSKIP;
value = Vector{T}(undef, n)
weights = Vector{Float64}(undef, n)
if ordered
error("Not implemented yet")
ord = collect(1:n)
return SampleMultiOrdAlgWRSWRSKIP(0.0, 0.0, 0, rng, weights, value, ord)
else
return SampleMultiAlgWRSWRSKIP(0.0, 0.0, 0, rng, weights, value)
end
Expand Down Expand Up @@ -83,15 +94,15 @@ function update!(s::SampleMultiAlgAExpJ, el, w)
end
return s
end
function update!(s::SampleMultiAlgWRSWRSKIP, el, w)
function update!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el, w)
n = length(s.value)
s.seen_k += 1
s.state += w
if s.seen_k <= n
s.value[s.seen_k] = el
s.weights[s.seen_k] = w
if s.seen_k == n
s.value = sample(s.rng, s.value, weights(s.weights), n)
s.value = sample(s.rng, s.value, weights(s.weights), n; ordered = is_ordered(s))
@inline recompute_skip!(s, n)
empty!(s.weights)
end
Expand All @@ -104,11 +115,13 @@ function update!(s::SampleMultiAlgWRSWRSKIP, el, w)
if k == 1
r = rand(s.rng, 1:n)
s.value[r] = el
update_order_single!(s, r)
else
for j in 1:k
r = rand(s.rng, j:n)
s.value[r] = el
s.value[r], s.value[j] = s.value[j], s.value[r]
update_order_multi!(s, r, j)
end
end
end
Expand All @@ -126,11 +139,24 @@ function recompute_skip!(s::SampleMultiAlgAExpJ)
s.min_priority = last(first(s.value))
s.state = -randexp(s.rng)/log(s.min_priority)
end
function recompute_skip!(s::SampleMultiAlgWRSWRSKIP, n)
function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, n)
q = rand(s.rng)^(1/n)
s.skip_w = s.state/q
end

# Unordered sampler: there are no order stamps to maintain, so this is a no-op.
function update_order_single!(s::SampleMultiAlgWRSWRSKIP, r)
    return nothing
end

# Ordered sampler: stamp reservoir slot `r` with the current number of seen elements.
update_order_single!(s::SampleMultiOrdAlgWRSWRSKIP, r) = (s.ord[r] = n_seen(s))

# Unordered sampler: there are no order stamps to maintain, so this is a no-op.
function update_order_multi!(s::SampleMultiAlgWRSWRSKIP, r, j)
    return nothing
end

# Ordered sampler: slot `r` inherits the stamp previously held by slot `j`, and slot
# `j` is stamped with the current number of seen elements. Both right-hand values are
# read before either slot is written, so the case `r == j` ends with the new stamp.
function update_order_multi!(s::SampleMultiOrdAlgWRSWRSKIP, r, j)
    stamps = (s.ord[j], n_seen(s))
    s.ord[r], s.ord[j] = stamps
    return stamps
end

# Trait-style predicate telling `update!` whether the initial fill must be sampled
# with `ordered = true`; decided purely by the sampler's type.
is_ordered(::SampleMultiAlgWRSWRSKIP) = false
is_ordered(::SampleMultiOrdAlgWRSWRSKIP) = true

function value(s::AbstractWeightedWorReservoirSampleMulti)
if n_seen(s) < s.n
return first.(s.value.valtree)[1:n_seen(s)]
Expand All @@ -146,13 +172,15 @@ function value(s::AbstractWeightedWrReservoirSampleMulti)
end
end

function ordered_value(s::AbstractWeightedReservoirSampleMulti)
error("Not implemented yet")
# Return the reservoir contents of an ordered WRSWR-SKIP sampler.
#
# Once at least `length(s.value)` elements have been seen, the reservoir is returned
# permuted by ascending order stamp. Before that, an ordered weighted sample of
# `length(s.value)` draws is taken from the elements stored so far.
function ordered_value(s::SampleMultiOrdAlgWRSWRSKIP)
    seen = n_seen(s)
    capacity = length(s.value)
    seen >= capacity && return s.value[sortperm(s.ord)]
    stored = s.value[1:seen]
    stored_weights = weights(s.weights[1:seen])
    return sample(s.rng, stored, stored_weights, capacity; ordered=true)
end

n_seen(s::SampleMultiAlgARes) = s.seen_k
n_seen(s::SampleMultiAlgAExpJ) = s.seen_k
n_seen(s::SampleMultiAlgWRSWRSKIP) = s.seen_k
# Number of elements the sampler has been updated with, read from its `seen_k` field.
# NOTE(review): this fallback applies to every `AbstractReservoirSample`, but the
# single-element samplers shown in this change (`SampleSingleAlgL`, `SampleSingleAlgR`,
# `SampleSingleAlgAExpJ`) declare no `seen_k` field, so calling `n_seen` on them would
# throw — confirm that is intended.
n_seen(s::AbstractReservoirSample) = s.seen_k

function itsample(iter, wv::Function, n::Int,
method::ReservoirAlgorithm=algAExpJ; ordered = false)
Expand Down
10 changes: 5 additions & 5 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@

mutable struct WeightedResSampleSingle{T,R} <: AbstractReservoirSample
mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractReservoirSample
state::Float64
skip_w::Float64
rng::R
value::T
WeightedResSampleSingle{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
SampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end

function ReservoirSample(rng::R, T, method::AlgAExpJ) where {R<:AbstractRNG}
return WeightedResSampleSingle{T,R}(0.0, 0.0, rng)
return SampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
end

function value(s::WeightedResSampleSingle)
function value(s::SampleSingleAlgAExpJ)
s.state === 0.0 && return nothing
return s.value
end

function update!(s::WeightedResSampleSingle, el, weight)
function update!(s::SampleSingleAlgAExpJ, el, weight)
s.state += weight
if s.skip_w <= s.state
s.value = el
Expand Down

0 comments on commit c41ae92

Please sign in to comment.