diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl index 40db81a8ed..1c9e3f9f33 100644 --- a/ext/ReactantStatisticsExt.jl +++ b/ext/ReactantStatisticsExt.jl @@ -4,18 +4,22 @@ using Reactant: AnyTracedRArray using Reactant.TracedUtils: materialize_traced_array using Statistics: Statistics -function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N} - A = materialize_traced_array(A) +function Statistics._mean(f::F, A::AnyTracedRArray{T,N}, dims) where {F,T,N} denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) - return mapreduce(identity, +, A; dims) / denom + return mapreduce(f, +, A; dims) / denom end -function Statistics.var( - A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true +function Statistics._var( + A::AnyTracedRArray{T,N}, corrected::Bool, mean, ::Colon ) where {T,N} - A = materialize_traced_array(A) + mean === nothing && (mean = Statistics.mean(A)) + denom = length(A) - corrected + return mapreduce(abs2, +, A .- mean; dims=:) / denom +end + +function Statistics._var(A::AnyTracedRArray{T,N}, corrected::Bool, mean, dims) where {T,N} mean === nothing && (mean = Statistics.mean(A; dims)) - denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected + denom = prod(Base.Fix1(size, A), dims) - corrected return mapreduce(abs2, +, A .- mean; dims) / denom end diff --git a/test/basic.jl b/test/basic.jl index 1783a44c36..d05d3d968a 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -176,16 +176,22 @@ end mean_fn2(x) = mean(x; dims=1) mean_fn3(x) = mean(x; dims=(1, 2)) mean_fn4(x) = mean(x; dims=(1, 3)) + mean_f1abs2(x) = mean(abs2, x) + mean_f2abs2(x) = mean(abs2, x; dims=1) mean_fn1_compiled = @compile mean_fn1(x_ca) mean_fn2_compiled = @compile mean_fn2(x_ca) mean_fn3_compiled = @compile mean_fn3(x_ca) mean_fn4_compiled = @compile mean_fn4(x_ca) + mean_f1abs2_compiled = @compile mean_f1abs2(x_ca) + mean_f2abs2_compiled = @compile mean_f2abs2(x_ca) @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca) @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca) @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca) @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca) + @test mean_f1abs2(x) ≈ mean_f1abs2_compiled(x_ca) + @test mean_f2abs2(x) ≈ mean_f2abs2_compiled(x_ca) # XXX: @jit doesn't work with `;` # @test @jit(var(x_ca)) ≈ var(x)