Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/SamplingUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ struct SeqIterWRSampler{R}
n::Int
end

@inline function Base.iterate(s::SeqIterWRSampler)
curmax = -log(Float64(s.N)) + randexp(s.rng)/s.n
return (s.N - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax))
end
@inline function Base.iterate(s::SeqIterWRSampler, state)
@inline function Base.iterate(s::SeqIterWRSampler, state = (s.n, -log(Float64(s.N))))
state[1] == 0 && return nothing
curmax = state[2] + randexp(s.rng)/state[1]
return (s.N - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax))
Expand Down Expand Up @@ -159,4 +155,4 @@ function ordmemory(n)
ord = Memory{Int}(undef, n)
for i in eachindex(ord) ord[i] = i end
ord
end
end
1 change: 0 additions & 1 deletion src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,3 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler)
end
end


26 changes: 21 additions & 5 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,18 @@ end
s = @inline update_state!(s, w)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
@inbounds s.weights[s.seen_k] = w
@inbounds s.weights[s.seen_k] = s.state
if s.seen_k == n
s.value .= sample(s.rng, s.value, Weights(s.weights, s.state), n;
ordered = is_ordered(s))
j, curx = 1, 0.0
newvalues = similar(s.value)
@inbounds for i in n:-1:1
curx += (1-exp(-randexp(s.rng)/i))*(1-curx)
while s.weights[j] < curx * s.state
j += 1
end
newvalues[i] = s.value[j]
end
s.value .= newvalues
s = @inline recompute_skip!(s, n)
end
return s
Expand Down Expand Up @@ -302,8 +310,12 @@ function OnlineStatsBase.value(s::Union{MultiAlgAResSampler, MultiAlgAExpJSample
end
end
function OnlineStatsBase.value(s::MultiAlgWRSWRSKIPSampler)
nobs(s) == 0 && return s.value[1:0]
if nobs(s) < length(s.value)
return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value))
weightsnew = Vector{Float64}(undef, nobs(s))
weightsnew[1] = s.weights[1]
for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end
return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value))
else
return s.value
end
Expand All @@ -318,8 +330,12 @@ function ordvalue(s::Union{MultiOrdAlgAResSampler, MultiOrdAlgAExpJSampler})
return first.(vals[sortperm(map(x -> x[2], vals))])
end
function ordvalue(s::MultiOrdAlgWRSWRSKIPSampler)
nobs(s) == 0 && return s.value[1:0]
if nobs(s) < length(s.value)
return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value); ordered=true)
weightsnew = Vector{Float64}(undef, nobs(s))
weightsnew[1] = s.weights[1]
for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end
return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value); ordered=true)
else
return s.value[sortperm(s.ord)]
end
Expand Down
Loading