
gradient() fails on array mutation for mean(f, x; dims) #1128

@staticfloat

Description


If you pass both an element-wise function `f` and a `dims` keyword, `mean()` apparently mutates an array internally, which breaks Zygote's ability to differentiate it:

```julia
julia> using Zygote, Statistics
       x = randn(3, 3)
       Zygote.gradient(Params([x])) do
           sum(mean(abs2, x, dims=1))
       end
ERROR: Mutating arrays is not supported -- called copyto!(::Matrix{Float64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#441#442"{Matrix{Float64}})(#unused#::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/array.jl:74
  [3] (::Zygote.var"#2330#back#443"{Zygote.var"#441#442"{Matrix{Float64}}})(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:894 [inlined]
  [5] Pullback
    @ ./broadcast.jl:891 [inlined]
  [6] Pullback
    @ ./broadcast.jl:887 [inlined]
  [7] (::typeof((materialize!)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
  [8] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:181 [inlined]
  [9] (::typeof((_mean)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [10] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
 [11] (::typeof((#mean#1)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [12] Pullback
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Statistics/src/Statistics.jl:104 [inlined]
 [13] (::typeof((mean##kw)))(Δ::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [14] Pullback
    @ ./REPL[14]:4 [inlined]
 [15] (::typeof((#17)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#89#90"{Params, typeof((#17)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:356
 [17] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
 [18] top-level scope
    @ REPL[14]:3
```
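Frames [4] to [8] of the trace show the pullback failing inside `materialize!`, called from Statistics.jl:181. In Julia, `materialize!` is what an in-place dotted update such as `a ./= n` lowers to, and it writes the result with `copyto!`, the call named in the error. The trace does not show the exact expression on Statistics.jl:181, so linking it to an in-place division is an inference. The plain-Julia sketch below (no Zygote needed; `r` and `n` are made-up illustration values) contrasts the out-of-place and in-place forms:

```julia
# Out-of-place vs in-place broadcasting, in plain Julia.
r = [3.0 6.0 9.0]   # e.g. a 1×3 row of column sums, like sum(abs2, x; dims=1)
n = 3               # number of rows being averaged over

q = r ./ n          # out-of-place: allocates a new array, r is left untouched
@assert r == [3.0 6.0 9.0]

# The dotted update `r ./= n` lowers to Base.materialize!(r, Base.broadcasted(/, r, n)),
# which writes into r via copyto!, the kind of mutation Zygote refuses to differentiate.
lowered = Meta.@lower r ./= n
@assert occursin("materialize!", string(lowered))

r ./= n             # in-place: r itself now holds the result
@assert r == q == [1.0 2.0 3.0]
```

Both forms produce the same values; only the in-place one goes through `copyto!`.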

Looking through the adjoints for mean() defined in lib/array.jl, my guess is that passing abs2 as f means none of Zygote's mean adjoints apply, so Zygote traces into the Statistics implementation itself. With dims set, that implementation goes through an in-place broadcast (the materialize! frames called from Statistics.jl:181 in the trace above), and the resulting copyto! is the mutation Zygote rejects. I was going to submit a PR adding an @adjoint definition for the method that takes f, but I don't know how to obtain the adjoint of a user-provided function inside it.
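One way to avoid needing the derivative of a user-provided `f` explicitly is to let Zygote differentiate the broadcast `f.(x)` itself (it can do this for arbitrary Julia functions, including closures) and then call the one-argument `mean(xs; dims)`, for which Zygote already has an adjoint. The sketch below assumes that adjoint still intercepts the call; `mean_nomutate` is a hypothetical helper name, not part of Statistics or Zygote, and this is a workaround rather than the fix that eventually landed:

```julia
using Zygote, Statistics

# Hypothetical helper (not part of Statistics or Zygote): same value as
# mean(f, x; dims), but computed as "broadcast f, then take the plain mean".
# Zygote differentiates the broadcast f.(x) for any Julia function f,
# then uses its existing adjoint for the one-argument mean(xs; dims).
mean_nomutate(f, x::AbstractArray; dims = :) = mean(f.(x); dims = dims)

x = randn(3, 3)

# Forward values agree with the Statistics method.
@assert mean_nomutate(abs2, x; dims = 1) ≈ mean(abs2, x; dims = 1)

# The original reproducer, rewritten with the helper (explicit-argument style).
grad_abs2 = Zygote.gradient(x -> sum(mean_nomutate(abs2, x; dims = 1)), x)[1]
# Each column mean of x.^2 over 3 rows has derivative 2x / 3.
@assert grad_abs2 ≈ 2 .* x ./ 3

# A user-provided closure works the same way, with no hand-written derivative.
h = t -> 2 * sin(t)
grad_sin = Zygote.gradient(x -> sum(mean_nomutate(h, x; dims = 1)), x)[1]
@assert grad_sin ≈ 2 .* cos.(x) ./ 3
```

The cost is one extra allocation for the intermediate `f.(x)`, in exchange for never writing into an array that Zygote is tracking.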
