Skip to content

Commit

Permalink
Merge pull request #176 from devmotion/weighted_sampling_no_replacement
Browse files Browse the repository at this point in the history
Weighted sampling without replacement
  • Loading branch information
andreasnoack committed Dec 15, 2016
2 parents 8181351 + c682416 commit af2c127
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/StatsBase.jl
Expand Up @@ -8,6 +8,7 @@ module StatsBase
import Base: rand, rand!
import Base.LinAlg: BlasReal, BlasFloat
import Base.Cartesian: @nloops, @nref, @nextract
import Base.Collections: heapify!, heappop!, percolate_down!

## tackle compatibility issues

Expand Down
125 changes: 124 additions & 1 deletion src/sampling.jl
Expand Up @@ -470,6 +470,127 @@ function naive_wsample_norep!(a::AbstractArray, wv::WeightVec, x::AbstractArray)
return x
end

# Weighted sampling without replacement
#
# Algorithm A from:
# Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir."
# Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190.
# doi:10.1016/j.ipl.2005.11.003.
# URL http://www.sciencedirect.com/science/article/pii/S002001900500298X
#
# Instead of keys u^(1/w) where u = random(0,1) keys w/v where v = randexp(1) are used.
function efraimidis_a_wsample_norep!(a::AbstractArray, wv::WeightVec, x::AbstractArray)
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)

# calculate keys for all items
keys = randexp(n)
for i in 1:n
@inbounds keys[i] = wv.values[i]/keys[i]
end

# return items with largest keys
index = sortperm(keys; alg = PartialQuickSort(k), rev = true)
for i in 1:k
@inbounds x[i] = a[index[i]]
end
return x
end

# Weighted sampling without replacement
#
# Algorithm A-Res from:
# Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir."
# Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190.
# doi:10.1016/j.ipl.2005.11.003.
# URL http://www.sciencedirect.com/science/article/pii/S002001900500298X
#
# Instead of keys u^(1/w) where u = random(0,1) keys w/v where v = randexp(1) are used.
function efraimidis_ares_wsample_norep!(a::AbstractArray, wv::WeightVec, x::AbstractArray)
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
k > 0 || return x

# initialize priority queue
pq = Array{Pair{Float64,Int}}(k)
@inbounds for i in 1:k
pq[i] = (wv.values[i]/randexp() => i)
end
heapify!(pq)

# set threshold
@inbounds threshold = pq[1].first

@inbounds for i in k+1:n
key = wv.values[i]/randexp()

# if key is larger than the threshold
if key > threshold
# update priority queue
pq[1] = (key => i)
percolate_down!(pq, 1)

# update threshold
threshold = pq[1].first
end
end

# fill output array with items in descending order
@inbounds for i in k:-1:1
x[i] = a[heappop!(pq).second]
end
return x
end

# Weighted sampling without replacement
#
# Algorithm A-ExpJ from:
# Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir."
# Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190.
# doi:10.1016/j.ipl.2005.11.003.
# URL http://www.sciencedirect.com/science/article/pii/S002001900500298X
#
# Instead of keys u^(1/w) where u = random(0,1) keys w/v where v = randexp(1) are used.
function efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::WeightVec, x::AbstractArray)
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
k > 0 || return x

# initialize priority queue
pq = Array{Pair{Float64,Int}}(k)
@inbounds for i in 1:k
pq[i] = (wv.values[i]/randexp() => i)
end
heapify!(pq)

# set threshold
@inbounds threshold = pq[1].first
X = threshold*randexp()

@inbounds for i in k+1:n
w = wv.values[i]
X -= w
X <= 0 || continue

# update priority queue
t = exp(-w/threshold)
pq[1] = (-w/log(t+rand()*(1-t)) => i)
percolate_down!(pq, 1)

# update threshold
threshold = pq[1].first
X = threshold * randexp()
end

# fill output array with items in descending order
@inbounds for i in k:-1:1
x[i] = a[heappop!(pq).second]
end
return x
end

function sample!(a::AbstractArray, wv::WeightVec, x::AbstractArray;
replace::Bool=true, ordered::Bool=false)
Expand All @@ -492,7 +613,9 @@ function sample!(a::AbstractArray, wv::WeightVec, x::AbstractArray;
end
end
else
naive_wsample_norep!(a, wv, x)
k <= n || error("Cannot draw $n samples from $k samples without replacement.")

efraimidis_aexpj_wsample_norep!(a, wv, x)
if ordered
sort!(x)
end
Expand Down
27 changes: 26 additions & 1 deletion test/wsampling.jl
Expand Up @@ -74,7 +74,8 @@ function check_wsample_norep(a::AbstractArray, vrgn, wv::WeightVec, ptol::Real;
end
end

import StatsBase: naive_wsample_norep!
import StatsBase: naive_wsample_norep!, efraimidis_a_wsample_norep!,
efraimidis_ares_wsample_norep!, efraimidis_aexpj_wsample_norep!

n = 10^5
wv = weights([0.2, 0.8, 0.4, 0.6])
Expand All @@ -84,3 +85,27 @@ for j = 1:n
naive_wsample_norep!(4:7, wv, view(a,:,j))
end
check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false)

a = zeros(Int, 3, n)
for j = 1:n
efraimidis_a_wsample_norep!(4:7, wv, view(a,:,j))
end
check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false)

a = zeros(Int, 3, n)
for j = 1:n
efraimidis_ares_wsample_norep!(4:7, wv, view(a,:,j))
end
check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false)

a = zeros(Int, 3, n)
for j = 1:n
efraimidis_aexpj_wsample_norep!(4:7, wv, view(a,:,j))
end
check_wsample_norep(a, (4, 7), wv, 5.0e-3; ordered=false)

a = sample(4:7, wv, 3; replace=false, ordered=false)
check_wsample_norep(a, (4, 7), wv, -1; ordered=false)

a = sample(4:7, wv, 3; replace=false, ordered=true)
check_wsample_norep(a, (4, 7), wv, -1; ordered=true)

0 comments on commit af2c127

Please sign in to comment.