diff --git a/src/weights.jl b/src/weights.jl index e5df6b738..a1389c0c7 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -289,6 +289,11 @@ end UnitWeights{T}(length(i)) end +function Base.getindex(wv::UnitWeights{T}, i::AbstractArray{Bool}) where T + length(wv) == length(i) || throw(DimensionMismatch()) + UnitWeights{T}(count(i)) +end + Base.getindex(wv::UnitWeights{T}, ::Colon) where {T} = UnitWeights{T}(wv.len) """ diff --git a/test/weights.jl b/test/weights.jl index 9f071483e..7735e04f7 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -112,6 +112,7 @@ end @test isequal(wv, uweights(3)) @test wv != fweights(fill(1.0, 3)) @test wv == uweights(3) + @test wv[[true, false, false]] == uweights(Float64, 1) end ## wsum