Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error when negative weights or zero sum are used when sampling #834

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,10 @@ Optionally specify a random number generator `rng` as the first argument
function sample(rng::AbstractRNG, wv::AbstractWeights)
1 == firstindex(wv) ||
throw(ArgumentError("non 1-based arrays are not supported"))
t = rand(rng) * sum(wv)
all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed"))
s = sum(wv)
s > 0 || throw(ArgumentError("sum of weights must be greater than 0"))
t = rand(rng) * s
n = length(wv)
i = 1
cw = wv[1]
Expand Down Expand Up @@ -621,6 +624,8 @@ function direct_sample!(rng::AbstractRNG, a::AbstractArray,
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed"))
sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0"))
for i = 1:length(x)
x[i] = a[sample(rng, wv)]
end
Expand Down Expand Up @@ -710,6 +715,8 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights,
throw(ArgumentError("non 1-based arrays are not supported"))
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed"))
sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0"))

# create alias table
ap = Vector{Float64}(undef, n)
Expand Down Expand Up @@ -749,6 +756,8 @@ function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
n = length(a)
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths."))
k = length(x)
all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed"))
sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0"))

w = Vector{Float64}(undef, n)
copyto!(w, wv)
Expand Down Expand Up @@ -795,6 +804,8 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
n = length(a)
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv)))."))
k = length(x)
all(>=(0), wv) || throw(ArgumentError("negative weights are not allowed"))
sum(wv) > 0 || throw(ArgumentError("sum of weights must be greater than 0"))

# calculate keys for all items
keys = randexp(rng, n)
Expand Down Expand Up @@ -845,22 +856,22 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
@inbounds for _s in 1:n
s = _s
w = wv.values[s]
w < 0 && error("Negative weight found in weight vector at index $s")
w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these efraimidis functions were actually one of my first contributions to Julia packages. Fun times but I'm not surprisied that some things can be improved and made more consistent 😄

if w > 0
i += 1
pq[i] = (w/randexp(rng) => s)
end
i >= k && break
end
i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)"))
i < k && throw(ArgumentError("wv must have at least $k strictly positive entries (got $i)"))
heapify!(pq)

# set threshold
@inbounds threshold = pq[1].first

@inbounds for i in s+1:n
w = wv.values[i]
w < 0 && error("Negative weight found in weight vector at index $i")
w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $i"))
w > 0 || continue
key = w/randexp(rng)

Expand Down Expand Up @@ -918,14 +929,14 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,
@inbounds for _s in 1:n
s = _s
w = wv.values[s]
w < 0 && error("Negative weight found in weight vector at index $s")
w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $s"))
if w > 0
i += 1
pq[i] = (w/randexp(rng) => s)
end
i >= k && break
end
i < k && throw(DimensionMismatch("wv must have at least $k strictly positive entries (got $i)"))
i < k && throw(ArgumentError("wv must have at least $k strictly positive entries (got $i)"))
heapify!(pq)

# set threshold
Expand All @@ -934,7 +945,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray,

@inbounds for i in s+1:n
w = wv.values[i]
w < 0 && error("Negative weight found in weight vector at index $i")
w < 0 && throw(ArgumentError("Negative weight found in weight vector at index $i"))
w > 0 || continue
X -= w
X <= 0 || continue
Expand Down Expand Up @@ -991,7 +1002,7 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs
end
end
else
k <= n || error("Cannot draw $k samples from $n samples without replacement.")
k <= n || throw(ArgumentError("Cannot draw $k samples from $n samples without replacement."))
efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered)
end
return x
Expand Down
42 changes: 37 additions & 5 deletions src/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@ abstract type AbstractWeights{S<:Real, T<:Real, V<:AbstractVector{T}} <: Abstrac
@weights name

Generates a new generic weight type with specified `name`, which subtypes `AbstractWeights`
and stores the `values` (`V<:RealVector`) and `sum` (`S<:Real`).
and stores the `values` (`V<:RealVector`), the pre-computed `sum` (`S<:Real`) and
whether all values are `positive`.
"""
macro weights(name)
return quote
mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V}
values::V
sum::S
function $(esc(name)){S, T, V}(values, sum) where {S<:Real, T<:Real, V<:AbstractVector{T}}
positive::Union{Bool, Missing}
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
function $(esc(name)){S, T, V}(values, sum, positive) where {S<:Real, T<:Real, V<:AbstractVector{T}}
isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values"))
return new{S, T, V}(values, sum)
return new{S, T, V}(values, sum, positive)
end
end
$(esc(name))(values::AbstractVector{T}, sum::S) where {S<:Real, T<:Real} = $(esc(name)){S, T, typeof(values)}(values, sum)
$(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values))
$(esc(name))(values::AbstractVector{T},
sum::S=Base.sum(values),
positive::Union{Bool, Missing}=missing) where {S<:Real, T<:Real} =
$(esc(name)){S, T, typeof(values)}(values, sum, positive)
end
end

Expand Down Expand Up @@ -53,9 +57,34 @@ Base.getindex(wv::W, ::Colon) where {W <: AbstractWeights} = W(copy(wv.values),
isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values"))
wv.values[i] = v
wv.sum = sum
wv.positive = missing
devmotion marked this conversation as resolved.
Show resolved Hide resolved
v
end

function Base.all(f::Base.Fix2{typeof(>=)}, wv::AbstractWeights)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be overkill, but unfortunately that's the standard way of checking whether all entries in a vector in Julia, so that's the only solution if we want external code to be able to use this feature, without exporting a new ispositive function. One advantage of defining this is that if we rework the API to take a weights keyword argument, sampling functions will be able to allow any AbstractArray and code will automatically work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. It's easy to not hit this function, e.g., when using all(x -> x >= 0, wv) etc. And, of course, it covers also things like all(>=(2), wv). One might also want to check non-negativity by e.g. !any(<(0), wv).

So I think a dedicated separate function would be cleaner and less ambiguous. If one wants to support AbstractArrays one could also at some point just define a fallback

isnonneg(x::AbstractArray{<:Real})  = all(>=(0), x)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What annoys me is that there's no reason why this very basic function would live in and be exported by StatsBase. And anyway if users are not aware of the fast path (be it all(>=(0), x) or nonneg(x)), they won't use it and get the slow one, so I'm not sure choosing one syntax or the other makes a difference.

We could keep this internal for now -- though defining an internal function wouldn't be better than all(>=(0), wv) and !any(<(0), wv) as we would be sure users wouldn't be able to use it. ;-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I would suggest only defining an internal function - that seems sufficient as it's only used in the argument checks internally. Something like isnnonneg seems to focus on what we actually want to check whereas defining all or any catches also other, in principle undesired, cases and hence requires a more complicated implementation with multiple checks.

if iszero(f.x)
if ismissing(wv.positive)
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
# sum is significantly faster than all when no entries are negative
wv.positive = sum(<(0), wv.values) == 0
end
return wv.positive
else
return all(f, wv.values)
end
end

function Base.any(f::Base.Fix2{typeof(<)}, wv::AbstractWeights)
if iszero(f.x)
if ismissing(wv.positive)
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
# sum is significantly faster than all when no entries are negative
wv.positive = sum(<(0), wv.values) == 0
end
return !wv.positive
else
return any(f, wv.values)
nalimilan marked this conversation as resolved.
Show resolved Hide resolved
end
end

"""
varcorrection(n::Integer, corrected=false)

Expand Down Expand Up @@ -333,6 +362,9 @@ end

Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len)

Base.all(f::Base.Fix2{typeof(>=)}, wv::UnitWeights{T}) where {T} = one(T) >= f.x
Base.any(f::Base.Fix2{typeof(<)}, wv::UnitWeights{T}) where {T} = one(T) < f.x

"""
uweights(s::Integer)
uweights(::Type{T}, s::Integer) where T<:Real
Expand Down
40 changes: 39 additions & 1 deletion test/weights.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ weight_funcs = (weights, aweights, fweights, pweights)

@test_throws ArgumentError f([0.1, Inf])
@test_throws ArgumentError f([0.1, NaN])

end

@testset "$f, setindex!" for f in weight_funcs
Expand Down Expand Up @@ -125,6 +124,45 @@ end
@test Base.dataids(wv) == ()
end

@testset "Fast-path all(<=, wv) and any(<, wv)" begin
for f in weight_funcs
@test all(>=(0), f([1, 2]))
@test all(>=(0), f([-0.0, 0.0]))
@test !all(>=(0), f([1, -2]))
@test !all(>=(0), f([1, NaN]))
@test !any(<(0), f([1, 2]))
@test !any(<(0), f([-0.0, 0.0]))
@test any(<(0), f([1, -2]))
@test !any(<(0), f([1, NaN]))
@test any(<(0), f([-1, NaN]))

@test all(>=(1), [2, 3, 4])
@test !all(>=(1), [0, 1, 2])
@test any(<(3), [2, 3, 4])
@test !any(<(1), [1, 2, 3])
nalimilan marked this conversation as resolved.
Show resolved Hide resolved

wv = f([1.0, 2.0, 3.0])
@test all(>=(0), wv)
@test !any(<(0), wv)
wv[2] = -0.0
@test all(>=(0), wv)
@test !any(<(0), wv)
wv[2] = -1.0
@test !all(>=(0), wv)
@test any(<(0), wv)
wv[2] = 1.0
@test all(>=(0), wv)
@test !any(<(0), wv)
end

@test all(>=(0), uweights(2))
@test !any(<(0), uweights(2))
@test all(>=(1), uweights(2))
@test !any(<(1), uweights(2))
@test !all(>=(2), uweights(2))
@test any(<(2), uweights(2))
end

## wsum

@testset "wsum" begin
Expand Down
15 changes: 15 additions & 0 deletions test/wsampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,5 +161,20 @@ end
# This corner case should theoretically succeed
# but it currently fails as Base.mightalias is not smart enough
@test_broken f(y, weights(view(x, 5:6)), view(x, 2:4))

# Check that negative weights are not allowed
if f === efraimidis_ares_wsample_norep! || f === efraimidis_aexpj_wsample_norep!
y[3] = -0.0
@test_throws ArgumentError f(x, weights(y), z)
else
y[3] = -0.0
f(x, weights(y), z)
end
y[3] = -1.0
@test_throws ArgumentError f(x, weights(y), z)

# Check that sum of weights cannot be zero
@test_throws ArgumentError f(x, weights(fill(0.0, 10)), z)
@test_throws ArgumentError f(x, weights(fill(-0.0, 10)), z)
end
end