Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid an extra call to f in Statistics.mean(f, A) #80

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,26 @@ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims)

_mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y)


function _promoted_sum(f, A::AbstractArray; dims) # calls f() length(A) + 1 times
kagalenko-m-b marked this conversation as resolved.
Show resolved Hide resolved
x1 = f(first(A)) / 1
result = sum(x -> _mean_promote(x1, f(x)), A; dims)
end
function _promoted_sum(f, A::AbstractVector; dims) # calls f() length(A) times
kagalenko-m-b marked this conversation as resolved.
Show resolved Hide resolved
x1 = f(first(A)) / 1
result = sum(x -> _mean_promote(x1, f(x)), @view A[begin+1:end]; dims,
init = x1)
kagalenko-m-b marked this conversation as resolved.
Show resolved Hide resolved
end

# ::Dims is there to force specializing on Colon (as it is a Function)
function _mean(f, A::AbstractArray, dims::Dims=:) where Dims
isempty(A) && return sum(f, A, dims=dims)/0
isempty(A) && return sum(f, A; dims)/0
if dims === (:)
n = length(A)
else
n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
end
x1 = f(first(A)) / 1
result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims)
result = _promoted_sum(f, A; dims)
if dims === (:)
return result / n
else
Expand Down Expand Up @@ -316,7 +326,7 @@ whereas the sum is scaled with `n` if `corrected` is

If `itr` is an `AbstractArray`, `dims` can be provided to compute the variance
over dimensions. In that case, `mean` must be an array with the same shape as
`mean(itr, dims=dims)` (additional trailing singleton dimensions are allowed).
`mean(itr; dims)` (additional trailing singleton dimensions are allowed).
kagalenko-m-b marked this conversation as resolved.
Show resolved Hide resolved

!!! note
If array contains `NaN` or [`missing`](@ref) values, the result is also
Expand All @@ -327,7 +337,7 @@ over dimensions. In that case, `mean` must be an array with the same shape as
varm(A::AbstractArray, m::AbstractArray; corrected::Bool=true, dims=:) = _varm(A, m, corrected, dims)

_varm(A::AbstractArray{T}, m, corrected::Bool, region) where {T} =
varm!(Base.reducedim_init(t -> abs2(t)/2, +, A, region), A, m; corrected=corrected)
varm!(Base.reducedim_init(t -> abs2(t)/2, +, A, region), A, m; corrected)

varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :)

Expand Down Expand Up @@ -356,7 +366,7 @@ If `itr` is an `AbstractArray`, `dims` can be provided to compute the variance
over dimensions.

A pre-computed `mean` may be provided. When `dims` is specified, `mean` must be
an array with the same shape as `mean(itr, dims=dims)` (additional trailing
an array with the same shape as `mean(itr; dims)` (additional trailing
singleton dimensions are allowed).

!!! note
Expand All @@ -369,16 +379,16 @@ var(A::AbstractArray; corrected::Bool=true, mean=nothing, dims=:) = _var(A, corr

function _var(A::AbstractArray, corrected::Bool, mean, dims)
if mean === nothing
mean = Statistics.mean(A, dims=dims)
mean = Statistics.mean(A; dims)
end
return varm(A, mean; corrected=corrected, dims=dims)
return varm(A, mean; corrected, dims)
end

function _var(A::AbstractArray, corrected::Bool, mean, ::Colon)
if mean === nothing
mean = Statistics.mean(A)
end
return real(varm(A, mean; corrected=corrected))
return real(varm(A, mean; corrected))
end

varm(iterable, m; corrected::Bool=true) = _var(iterable, corrected, m)
Expand Down Expand Up @@ -419,8 +429,7 @@ function sqrt!(A::AbstractArray)
A
end

stdm(A::AbstractArray, m; corrected::Bool=true) =
sqrt.(varm(A, m; corrected=corrected))
stdm(A::AbstractArray, m; corrected::Bool=true) = sqrt.(varm(A, m; corrected))

"""
std(itr; corrected::Bool=true, mean=nothing[, dims])
Expand All @@ -440,7 +449,7 @@ If `itr` is an `AbstractArray`, `dims` can be provided to compute the standard d
over dimensions, and `mean` may contain means for each dimension of `itr`.

A pre-computed `mean` may be provided. When `dims` is specified, `mean` must be
an array with the same shape as `mean(itr, dims=dims)` (additional trailing
an array with the same shape as `mean(itr; dims)` (additional trailing
singleton dimensions are allowed).

!!! note
Expand All @@ -452,19 +461,19 @@ singleton dimensions are allowed).
std(A::AbstractArray; corrected::Bool=true, mean=nothing, dims=:) = _std(A, corrected, mean, dims)

_std(A::AbstractArray, corrected::Bool, mean, dims) =
sqrt.(var(A; corrected=corrected, mean=mean, dims=dims))
sqrt.(var(A; corrected, mean, dims))

_std(A::AbstractArray, corrected::Bool, mean, ::Colon) =
sqrt.(var(A; corrected=corrected, mean=mean))
sqrt.(var(A; corrected, mean))

_std(A::AbstractArray{<:AbstractFloat}, corrected::Bool, mean, dims) =
sqrt!(var(A; corrected=corrected, mean=mean, dims=dims))
sqrt!(var(A; corrected, mean, dims))

_std(A::AbstractArray{<:AbstractFloat}, corrected::Bool, mean, ::Colon) =
sqrt.(var(A; corrected=corrected, mean=mean))
sqrt.(var(A; corrected, mean))

std(iterable; corrected::Bool=true, mean=nothing) =
sqrt(var(iterable, corrected=corrected, mean=mean))
sqrt(var(iterable; corrected, mean))

"""
stdm(itr, mean; corrected::Bool=true)
Expand All @@ -482,16 +491,15 @@ whereas the sum is scaled with `n` if `corrected` is

If `itr` is an `AbstractArray`, `dims` can be provided to compute the standard deviation
over dimensions. In that case, `mean` must be an array with the same shape as
`mean(itr, dims=dims)` (additional trailing singleton dimensions are allowed).
`mean(itr; dims)` (additional trailing singleton dimensions are allowed).

!!! note
If array contains `NaN` or [`missing`](@ref) values, the result is also
`NaN` or `missing` (`missing` takes precedence if array contains both).
Use the [`skipmissing`](@ref) function to omit `missing` entries and compute the
standard deviation of non-missing values.
"""
stdm(iterable, m; corrected::Bool=true) =
std(iterable, corrected=corrected, mean=m)
stdm(iterable, mean; corrected::Bool=true) = std(iterable; corrected, mean)


###### covariance ######
Expand Down Expand Up @@ -553,13 +561,13 @@ end
## Use map(t -> t - xmean, x) instead of x .- xmean to allow for Vector{Vector}
## which can't be handled by broadcast
covm(x::AbstractVector, xmean; corrected::Bool=true) =
covzm(map(t -> t - xmean, x); corrected=corrected)
covzm(map(t -> t - xmean, x); corrected)
covm(x::AbstractMatrix, xmean, vardim::Int=1; corrected::Bool=true) =
covzm(x .- xmean, vardim; corrected=corrected)
covzm(x .- xmean, vardim; corrected)
covm(x::AbstractVector, xmean, y::AbstractVector, ymean; corrected::Bool=true) =
covzm(map(t -> t - xmean, x), map(t -> t - ymean, y); corrected=corrected)
covzm(map(t -> t - xmean, x), map(t -> t - ymean, y); corrected)
covm(x::AbstractVecOrMat, xmean, y::AbstractVecOrMat, ymean, vardim::Int=1; corrected::Bool=true) =
covzm(x .- xmean, y .- ymean, vardim; corrected=corrected)
covzm(x .- xmean, y .- ymean, vardim; corrected)

# cov (API)
"""
Expand All @@ -568,7 +576,7 @@ covm(x::AbstractVecOrMat, xmean, y::AbstractVecOrMat, ymean, vardim::Int=1; corr
Compute the variance of the vector `x`. If `corrected` is `true` (the default) then the sum
is scaled with `n-1`, whereas the sum is scaled with `n` if `corrected` is `false` where `n = length(x)`.
"""
cov(x::AbstractVector; corrected::Bool=true) = covm(x, mean(x); corrected=corrected)
cov(x::AbstractVector; corrected::Bool=true) = covm(x, mean(x); corrected)

"""
cov(X::AbstractMatrix; dims::Int=1, corrected::Bool=true)
Expand All @@ -578,7 +586,7 @@ is `true` (the default) then the sum is scaled with `n-1`, whereas the sum is sc
if `corrected` is `false` where `n = size(X, dims)`.
"""
cov(X::AbstractMatrix; dims::Int=1, corrected::Bool=true) =
covm(X, _vmean(X, dims), dims; corrected=corrected)
covm(X, _vmean(X, dims), dims; corrected)

"""
cov(x::AbstractVector, y::AbstractVector; corrected::Bool=true)
Expand All @@ -589,7 +597,7 @@ default), computes ``\\frac{1}{n-1}\\sum_{i=1}^n (x_i-\\bar x) (y_i-\\bar y)^*``
`false`, computes ``\\frac{1}{n}\\sum_{i=1}^n (x_i-\\bar x) (y_i-\\bar y)^*``.
"""
cov(x::AbstractVector, y::AbstractVector; corrected::Bool=true) =
covm(x, mean(x), y, mean(y); corrected=corrected)
covm(x, mean(x), y, mean(y); corrected)

"""
cov(X::AbstractVecOrMat, Y::AbstractVecOrMat; dims::Int=1, corrected::Bool=true)
Expand All @@ -599,7 +607,7 @@ Compute the covariance between the vectors or matrices `X` and `Y` along the dim
the sum is scaled with `n` if `corrected` is `false` where `n = size(X, dims) = size(Y, dims)`.
"""
cov(X::AbstractVecOrMat, Y::AbstractVecOrMat; dims::Int=1, corrected::Bool=true) =
covm(X, _vmean(X, dims), Y, _vmean(Y, dims), dims; corrected=corrected)
covm(X, _vmean(X, dims), Y, _vmean(Y, dims), dims; corrected)

##### correlation #####

Expand Down Expand Up @@ -986,9 +994,9 @@ end
require_one_based_indexing(v)

n = length(v)

@assert n > 0 # this case should never happen here

m = alpha + p * (one(alpha) - alpha - beta)
aleph = n*p + oftype(p, m)
j = clamp(trunc(Int, aleph), 1, n-1)
Expand All @@ -1001,7 +1009,7 @@ end
a = v[j]
b = v[j + 1]
end

if isfinite(a) && isfinite(b)
return a + γ*(b-a)
else
Expand Down
11 changes: 10 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ end
≈ float(typemax(Int)))
end
let x = rand(10000) # mean should use sum's accurate pairwise algorithm
@test mean(x) == sum(x) / length(x)
@test mean(x) == sum((@view x[begin + 1:end]), init=x[1]) / length(x)
end
@test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array
@test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im
Expand All @@ -162,6 +162,15 @@ end
@test (@inferred mean(Iterators.filter(x -> true, Int[]))) === 0/0
@test (@inferred mean(Iterators.filter(x -> true, Float32[]))) === 0.f0/0
@test (@inferred mean(Iterators.filter(x -> true, Float64[]))) === 0/0
# Check that mean does not call its function argument an extra time
let _cnt = 0, N = 100, x = rand(Int, N)
kagalenko-m-b marked this conversation as resolved.
Show resolved Hide resolved
f(x) = begin; _cnt += 1; x; end
kagalenko-m-b marked this conversation as resolved.
Show resolved Hide resolved
@test mean(1:N) == mean(f, 1:N)
@test _cnt == N
_cnt = 0
@test mean(x) == mean(f, x)
@test _cnt == N
end
end

@testset "mean/median for ranges" begin
Expand Down