Add rule for mean(f, x; dims) #85

Open
nickrobinson251 opened this issue Aug 16, 2019 · 3 comments · Fixed by #615

Comments

@nickrobinson251
Contributor

We have rules for mean, except for mean(f, x; dims), which is new as of Julia v1.3.
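For reference, here is a small sketch of what this method computes, using the Statistics standard library. The matrix x is only illustrative data:

```julia
using Statistics

x = [1.0 2.0; 3.0 4.0]

# mean(f, x; dims) applies f to each element and averages along dims,
# giving the same result as mean(f.(x); dims=dims).
mean(abs2, x; dims=1)  # 1×2 matrix: [5.0 10.0]
mean(abs2, x; dims=2)  # 2×1 matrix with entries 2.5 and 12.5
mean(abs2, x)          # 7.5 (the default dims=: case)
```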

nickrobinson251 changed the title from "Add rue for mean(f, x; dims)" to "Add rule for mean(f, x; dims)" on Aug 16, 2019
YingboMa added a commit that referenced this issue Dec 24, 2019
```julia
julia> using NaNMath, SpecialFunctions
┌ Warning: Error requiring NaNMath from ChainRules:
│ LoadError: UndefVarError: SpecialFunctions not defined
│ Stacktrace:
│  [1] include at ./boot.jl:328 [inlined]
│  [2] include_relative(::Module, ::String) at ./loading.jl:1105
│  [3] include at ./Base.jl:31 [inlined]
│  [4] include(::String) at /home/scheme/.julia/dev/ChainRules/src/ChainRules.jl:1
│  [5] top-level scope at /home/scheme/.julia/dev/ChainRules/src/ChainRules.jl:45
│  [6] eval at ./boot.jl:330 [inlined]
│  [7] eval at /home/scheme/.julia/dev/ChainRules/src/ChainRules.jl:1 [inlined]
│  [8] (::ChainRules.var"#867#873")() at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:67
│  [9] err(::ChainRules.var"#867#873", ::Module, ::String) at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:38
│  [10] #866 at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:66 [inlined]
│  [11] withpath(::ChainRules.var"#866#872", ::String) at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:28
│  [12] (::ChainRules.var"#865#871")() at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:65
│  [13] #invokelatest#1 at ./essentials.jl:709 [inlined]
│  [14] invokelatest at ./essentials.jl:708 [inlined]
│  [15] #3 at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:19 [inlined]
│  [16] iterate at ./generator.jl:47 [inlined]
│  [17] _collect(::Array{Function,1}, ::Base.Generator{Array{Function,1},Requires.var"#3#4"}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:635
│  [18] map at ./array.jl:564 [inlined]
│  [19] loadpkg(::Base.PkgId) at /home/scheme/.julia/packages/Requires/9Jse8/src/require.jl:19
│  [20] #invokelatest#1 at ./essentials.jl:709 [inlined]
│  [21] invokelatest at ./essentials.jl:708 [inlined]
│  [22] require(::Base.PkgId) at ./loading.jl:925
│  [23] require(::Module, ::Symbol) at ./loading.jl:917
│  [24] eval(::Module, ::Any) at ./boot.jl:330
│  [25] eval_user_input(::Any, ::REPL.REPLBackend) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.3/REPL/src/REPL.jl:86
│  [26] run_backend(::REPL.REPLBackend) at /home/scheme/.julia/packages/Revise/S7mrl/src/Revise.jl:1057
│  [27] (::Revise.var"#85#87"{REPL.REPLBackend})() at ./task.jl:333
│ in expression starting at /home/scheme/.julia/dev/ChainRules/src/rulesets/packages/NaNMath.jl:4
└ @ Requires ~/.julia/packages/Requires/9Jse8/src/require.jl:40
```
YingboMa added a commit that referenced this issue Dec 25, 2019
* Remove duplicated method definitions

This PR fixes the following warning:
```julia
[ Info: Precompiling ForwardDiff2 [994df76e-a4c1-5e1f-bd5c-23b9b5303d4f]
WARNING: Method definition rrule(typeof(Base.sum), AbstractArray{#s695, N} where N where #s695<:Real) in module ChainRules at /home/scheme/.julia/dev/ChainRules/src/rulesets/Base/mapreduce.jl:61 overwritten at /home/scheme/.julia/dev/ChainRules/src/rulesets/Base/mapreduce.jl:76.
  ** incremental compilation may be fatally broken for this module **

WARNING: Method definition rrule(typeof(Base.sum), AbstractArray{#s695, N} where N where #s695<:Real) in module ChainRules at /home/scheme/.julia/dev/ChainRules/src/rulesets/Base/mapreduce.jl:61 overwritten at /home/scheme/.julia/dev/ChainRules/src/rulesets/Base/mapreduce.jl:76.
  ** incremental compilation may be fatally broken for this module **
```

* Fix NaNMath and SpecialFunctions warnings

* New patch release

* Add NaNMath.tan rule back

* Revert "Fix NaNMath and SpecialFunctions warnings"

This reverts commit 40382c6.

* Remove glue modules
@mzgubic
Member

mzgubic commented Dec 1, 2021

The implementation could look a bit like sum(f, xs):

```julia
using ChainRulesCore
import ChainRulesCore: rrule  # needed to add methods to rrule

function rrule(
    config::RuleConfig{>:HasReverseMode}, ::typeof(sum), f, xs::AbstractArray; dims=:
)
    fx_and_pullbacks = map(x -> rrule_via_ad(config, f, x), xs)
    y = sum(first, fx_and_pullbacks; dims=dims)
    pullbacks = last.(fx_and_pullbacks)
    project = ProjectTo(xs)
    function sum_pullback(ȳ)
        call(f, x) = f(x)
        # if dims is :, then only a left-hand broadcast is needed
        broadcast_ȳ = dims isa Colon ? (ȳ,) : ȳ
        f̄_and_x̄s = call.(pullbacks, broadcast_ȳ)
        # no point thunking, as most of the work is in f̄_and_x̄s, which we need to compute for both
        f̄ = if fieldcount(typeof(f)) === 0  # then no need to worry about the derivative wrt f
            NoTangent()
        else
            sum(first, f̄_and_x̄s)
        end
        x̄s = map(unthunk ∘ last, f̄_and_x̄s)  # project does not support receiving InplaceableThunks
        return NoTangent(), f̄, project(x̄s)
    end
    return y, sum_pullback
end
```
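A minimal plain-Julia sketch of the mean analogue of the sum(f, xs) rule above, assuming a scalar f whose derivative df is supplied by hand in place of rrule_via_ad. mean_with_pullback is a hypothetical name, not part of ChainRules, and the sketch skips both the cotangent with respect to f and the ProjectTo step. The only change from the sum pullback is dividing by n, the number of elements averaged into each output entry:

```julia
using Statistics

# Hypothetical sketch, not the ChainRules API: value and pullback of
# mean(f, x; dims) for a scalar function f with hand-written derivative df.
function mean_with_pullback(f, df, x::AbstractArray; dims=:)
    y = mean(f, x; dims=dims)
    # Number of elements averaged into each output entry. For dims=:,
    # y is a scalar and length(y) == 1, so n == length(x).
    n = length(x) ÷ length(y)
    function mean_pullback(ȳ)
        # Same as the sum pullback, scaled by 1/n. When dims is not :,
        # ȳ has the shape of y and broadcasts over the reduced dimensions.
        return ȳ .* df.(x) ./ n
    end
    return y, mean_pullback
end

x = [1.0 2.0; 3.0 4.0]
y, back = mean_with_pullback(abs2, z -> 2z, x; dims=1)
```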

@mcabbott
Member

mcabbott commented Dec 1, 2021

Ideally it would share code with the sum rule, e.g. via a function which for mean gets scale=1/size(...) or something.

Xref #529, which is trying to rework that rule.
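One way the shared-code idea could look, sketched in plain Julia under simplifying assumptions (scalar f, derivative df passed by hand as a stand-in for AD). reduction_scale and reduce_with_pullback are hypothetical names; the point is only that sum uses a scale of 1 and mean uses 1/n, where n is the product of the reduced dimension sizes:

```julia
using Statistics

# Hypothetical helper: sum and mean differ only by a constant factor
# applied to the cotangent.
reduction_scale(::typeof(sum), x, dims) = 1
reduction_scale(::typeof(mean), x, ::Colon) = 1 / length(x)
reduction_scale(::typeof(mean), x, dims) = 1 / prod(size(x, d) for d in dims)

# Hypothetical sketch: one pullback body serving both reductions.
# df is a hand-written derivative of the scalar function f.
function reduce_with_pullback(op, f, df, x::AbstractArray; dims=:)
    y = op(f, x; dims=dims)
    scale = reduction_scale(op, x, dims)
    pullback(ȳ) = ȳ .* df.(x) .* scale
    return y, pullback
end
```

For example, reduce_with_pullback(mean, abs2, z -> 2z, x; dims=1) returns the value of mean(abs2, x; dims=1) together with a pullback that scales the cotangent by 1/size(x, 1).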

@oxinabox
Member

Reopened, as I had to revert the fix.
