From e4343f3cfb503ea1e5079e1239bebbbf66ffca50 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 4 Jan 2025 23:02:37 -0500 Subject: [PATCH 1/2] feat: missing mean(f, ...) dispatches --- ext/ReactantStatisticsExt.jl | 18 ++++++++++-------- test/basic.jl | 6 ++++++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl index 40db81a8ed..8f8b8e4f47 100644 --- a/ext/ReactantStatisticsExt.jl +++ b/ext/ReactantStatisticsExt.jl @@ -4,18 +4,20 @@ using Reactant: AnyTracedRArray using Reactant.TracedUtils: materialize_traced_array using Statistics: Statistics -function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N} - A = materialize_traced_array(A) +function Statistics._mean(f::F, A::AnyTracedRArray{T,N}, dims) where {F,T,N} denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) - return mapreduce(identity, +, A; dims) / denom + return mapreduce(f, +, A; dims) / denom end -function Statistics.var( - A::AnyTracedRArray{T,N}; dims=:, mean=nothing, corrected=true -) where {T,N} - A = materialize_traced_array(A) +function Statistics._var(A::AnyTracedRArray{T,N}, corrected::Bool, mean, ::Colon) where {T,N} + mean === nothing && (mean = Statistics.mean(A)) + denom = length(A) - corrected + return mapreduce(abs2, +, A .- mean; dims=:) / denom +end + +function Statistics._var(A::AnyTracedRArray{T,N}, corrected::Bool, mean, dims) where {T,N} mean === nothing && (mean = Statistics.mean(A; dims)) - denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected + denom = prod(Base.Fix1(size, A), dims) - corrected return mapreduce(abs2, +, A .- mean; dims) / denom end diff --git a/test/basic.jl b/test/basic.jl index 1783a44c36..d05d3d968a 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -176,16 +176,22 @@ end mean_fn2(x) = mean(x; dims=1) mean_fn3(x) = mean(x; dims=(1, 2)) mean_fn4(x) = mean(x; dims=(1, 3)) + mean_f1abs2(x) = mean(abs2, x) + mean_f2abs2(x) = mean(abs2, x; dims=1) mean_fn1_compiled = @compile mean_fn1(x_ca) mean_fn2_compiled = @compile mean_fn2(x_ca) mean_fn3_compiled = @compile mean_fn3(x_ca) mean_fn4_compiled = @compile mean_fn4(x_ca) + mean_f1abs2_compiled = @compile mean_f1abs2(x_ca) + mean_f2abs2_compiled = @compile mean_f2abs2(x_ca) @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca) @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca) @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca) @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca) + @test mean_f1abs2(x) ≈ mean_f1abs2_compiled(x_ca) + @test mean_f2abs2(x) ≈ mean_f2abs2_compiled(x_ca) # XXX: @jit doesn't work with `;` # @test @jit(var(x_ca)) ≈ var(x) From a01e73aaab39f98f06a6a0972800d6a9eacd3fa1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 5 Jan 2025 10:43:31 -0500 Subject: [PATCH 2/2] Update ext/ReactantStatisticsExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantStatisticsExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl index 8f8b8e4f47..1c9e3f9f33 100644 --- a/ext/ReactantStatisticsExt.jl +++ b/ext/ReactantStatisticsExt.jl @@ -9,7 +9,9 @@ function Statistics._mean(f::F, A::AnyTracedRArray{T,N}, dims) where {F,T,N} return mapreduce(f, +, A; dims) / denom end -function Statistics._var(A::AnyTracedRArray{T,N}, corrected::Bool, mean, ::Colon) where {T,N} +function Statistics._var( + A::AnyTracedRArray{T,N}, corrected::Bool, mean, ::Colon +) where {T,N} mean === nothing && (mean = Statistics.mean(A)) denom = length(A) - corrected return mapreduce(abs2, +, A .- mean; dims=:) / denom