diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 4b4c08f..61b0b05 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -19,11 +19,7 @@ struct SeqIterWRSampler{R} n::Int end -@inline function Base.iterate(s::SeqIterWRSampler) - curmax = -log(Float64(s.N)) + randexp(s.rng)/s.n - return (s.N - ceil(Int, exp(-curmax)) + 1, (s.n-1, curmax)) -end -@inline function Base.iterate(s::SeqIterWRSampler, state) +@inline function Base.iterate(s::SeqIterWRSampler, state = (s.n, -log(Float64(s.N)))) state[1] == 0 && return nothing curmax = state[2] + randexp(s.rng)/state[1] return (s.N - ceil(Int, exp(-curmax)) + 1, (state[1]-1, curmax)) @@ -159,4 +155,4 @@ function ordmemory(n) ord = Memory{Int}(undef, n) for i in eachindex(ord) ord[i] = i end ord -end \ No newline at end of file +end diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 691d52a..aa3c173 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -278,4 +278,3 @@ function ordvalue(s::MultiOrdAlgRSWRSKIPSampler) end end - diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index c1bbf31..a711998 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -128,10 +128,18 @@ end s = @inline update_state!(s, w) if s.seen_k <= n @inbounds s.value[s.seen_k] = el - @inbounds s.weights[s.seen_k] = w + @inbounds s.weights[s.seen_k] = s.state if s.seen_k == n - s.value .= sample(s.rng, s.value, Weights(s.weights, s.state), n; - ordered = is_ordered(s)) + j, curx = 1, 0.0 + newvalues = similar(s.value) + @inbounds for i in n:-1:1 + curx += (1-exp(-randexp(s.rng)/i))*(1-curx) + while s.weights[j] < curx * s.state + j += 1 + end + newvalues[i] = s.value[j] + end + s.value .= newvalues s = @inline recompute_skip!(s, n) end return s @@ -302,8 +310,12 @@ function OnlineStatsBase.value(s::Union{MultiAlgAResSampler, MultiAlgAExpJSample end end function OnlineStatsBase.value(s::MultiAlgWRSWRSKIPSampler) + nobs(s) == 0 && return s.value[1:0] if nobs(s) < length(s.value) - return nobs(s) == 0 ? s.value[1:0] : sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value)) + weightsnew = Vector{Float64}(undef, nobs(s)) + weightsnew[1] = s.weights[1] + for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end + return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value)) else return s.value end @@ -318,8 +330,12 @@ function ordvalue(s::Union{MultiOrdAlgAResSampler, MultiOrdAlgAExpJSampler}) return first.(vals[sortperm(map(x -> x[2], vals))]) end function ordvalue(s::MultiOrdAlgWRSWRSKIPSampler) + nobs(s) == 0 && return s.value[1:0] if nobs(s) < length(s.value) - return sample(s.rng, s.value[1:nobs(s)], weights(s.weights[1:nobs(s)]), length(s.value); ordered=true) + weightsnew = Vector{Float64}(undef, nobs(s)) + weightsnew[1] = s.weights[1] + for i in 2:nobs(s) weightsnew[i] = s.weights[i] - s.weights[i-1] end + return sample(s.rng, s.value[1:nobs(s)], weights(weightsnew), length(s.value); ordered=true) else return s.value[sortperm(s.ord)] end