Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

summarystats improvements #858

Merged
merged 9 commits into from
Jun 24, 2023
10 changes: 7 additions & 3 deletions src/scalarstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,7 @@ kldivergence(p::AbstractArray{<:Real}, q::AbstractArray{<:Real}, b::Real) =

struct SummaryStats{T<:Union{AbstractFloat,Missing}}
mean::T
sd::T
min::T
q25::T
median::T
Expand All @@ -871,14 +872,16 @@ end
summarystats(a)

Compute summary statistics for a real-valued array `a`. Returns a
`SummaryStats` object containing the mean, minimum, 25th percentile,
median, 75th percentile, and maxmimum.
`SummaryStats` object containing the number of observations,
number of missing observations, standard deviation, mean, minimum,
25th percentile, median, 75th percentile, and maximum.
"""
function summarystats(a::AbstractArray{T}) where T<:Union{Real,Missing}
# `mean` doesn't fail on empty input but rather returns `NaN`, so we can use the
# return type to populate the `SummaryStats` structure.
s = T >: Missing ? collect(skipmissing(a)) : a
m = mean(s)
stdev = std(s, mean=m)
R = typeof(m)
n = length(a)
ns = length(s)
Expand All @@ -889,7 +892,7 @@ function summarystats(a::AbstractArray{T}) where T<:Union{Real,Missing}
else
quantile(s, [0.00, 0.25, 0.50, 0.75, 1.00])
end
SummaryStats{R}(m, qs..., n, n - ns)
SummaryStats{R}(m, stdev, qs..., n, n - ns)
end

function Base.show(io::IO, ss::SummaryStats)
Expand All @@ -898,6 +901,7 @@ function Base.show(io::IO, ss::SummaryStats)
ss.nobs > 0 || return
@printf(io, "Missing Count: %i\n", ss.nmiss)
@printf(io, "Mean: %.6f\n", ss.mean)
@printf(io, "Std. Deviation: %.6f\n", ss.sd)
@printf(io, "Minimum: %.6f\n", ss.min)
@printf(io, "1st Quartile: %.6f\n", ss.q25)
@printf(io, "Median: %.6f\n", ss.median)
Expand Down
2 changes: 2 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ describe(io, collect(1:10))
Length: 10
Missing Count: 0
Mean: 5.500000
Std. Deviation: 3.027650
Minimum: 1.000000
1st Quartile: 3.250000
Median: 5.500000
Expand All @@ -68,6 +69,7 @@ describe(io, Union{Float32,Missing}[1.0, 4.5, missing, missing, 33.1])
Length: 5
Missing Count: 2
Mean: 12.866666
Std. Deviation: 17.609751
Minimum: 1.000000
1st Quartile: 2.750000
Median: 4.500000
Expand Down
10 changes: 10 additions & 0 deletions test/scalarstats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,30 +332,39 @@ s = summarystats(1:5)
@test isa(s, StatsBase.SummaryStats)
@test s.min == 1.0
@test s.max == 5.0
@test s.nobs == 5
@test s.nmiss == 0
@test s.mean ≈ 3.0
@test s.median ≈ 3.0
@test s.q25 ≈ 2.0
@test s.q75 ≈ 4.0
@test s.sd ≈ 1.5811388300841898

# Issue #631
s = summarystats([-2, -1, 0, 1, 2, missing])
@test isa(s, StatsBase.SummaryStats)
@test s.min == -2.0
@test s.max == 2.0
@test s.nobs == 6
@test s.nmiss == 1
@test s.mean ≈ 0.0
@test s.median ≈ 0.0
@test s.q25 ≈ -1.0
@test s.q75 ≈ +1.0
@test s.sd ≈ 1.5811388300841898

# Issue #631
s = summarystats(zeros(10))
@test isa(s, StatsBase.SummaryStats)
@test s.min == 0.0
@test s.max == 0.0
@test s.nobs == 10
@test s.nmiss == 0
@test s.mean ≈ 0.0
@test s.median ≈ 0.0
@test s.q25 ≈ 0.0
@test s.q75 ≈ 0.0
@test s.sd ≈ 0.0

# Issue #631
s = summarystats(Union{Float64,Missing}[missing, missing])
Expand All @@ -364,3 +373,4 @@ s = summarystats(Union{Float64,Missing}[missing, missing])
@test s.nmiss == 2
@test isnan(s.mean)
@test isnan(s.median)
@test isnan(s.sd)