Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rrule for fill! #521

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Conversation

CarloLucibello
Copy link
Contributor

related to #515

@mcabbott
Copy link
Member

mcabbott commented Sep 1, 2021

I think the difficulty with allowing this is that it will cause any other rule which has captured x to give wrong answers:

julia> Zygote.gradient([1,2,3]) do x
         y = log.(x)
         fill!(x, 0)
         sum(y .+ x)
       end[1]
3-element Vector{Float64}:
 2.0
 1.5
 1.3333333333333333

julia> ForwardDiff.gradient([1,2,3]) do x
         y = log.(x)
         fill!(x, 0)
         sum(y .+ x)
       end
3-element Vector{Float64}:
 1.0
 0.5
 0.3333333333333333

Is there some clever way we might avoid this, or at least, make this example where x is used elsewhere an error, without making fill!(similar(x), y) an error? What if the gradient with respect to x is some KillerTangent(), which explodes on contact but is quietly thrown away by the rrule for similar, etc?

@CarloLucibello
Copy link
Contributor Author

Can't we trust users to use this in a safe way? We do it already for rand!.

@mcabbott
Copy link
Member

mcabbott commented Sep 1, 2021

It's harder for me to picture rand! going wrong in the wild, but it does have the same problem.

Seems to be from #252, without discussion.

@CarloLucibello
Copy link
Contributor Author

I don't see a general solution for mutating functions violating implicit assumptions made by other rrules.
Giving up entirely on them is a bit annoying, forces users into awkward alternative paths, or into using ignore blocks or into defining their own rrules. On the other hand, if we implement rules for mutating functions and people use them too freely they are going to shoot themselves in the foot.
Taking a general stance on this requires more thought. In this specific case though, given how much the fill!(similar(x), y) pattern appears in the wild, I would be more on the permissive side and go with something like this PR.

PS
I don't understand what the test failure means

@oxinabox
Copy link
Member

oxinabox commented Sep 1, 2021

What if the gradient with respect to x is some KillerTangent(), which explodes on contact but is quietly thrown away by the rrule for similar, etc?

We do have NotImplemented which kind of does that.
It poison's everything it touches making that also return NotImplemented with the same message:
Pretty much any time you pullback a NotImplemented you get the same NotImplemented.
And what would happen for

x = similar(...)
y  = fill!(x, a)

would be we call pullback_fill! and get x̄=NotImplemented(),
which we pass to pullback_similar but that returns NoTangent() for all it's inputs since it is not differentiable anyway.

And in the

julia> Zygote.gradient([1,2,3]) do x
         y = log.(x)
         fill!(x, 0)
         sum(y .+ x)
       end[1]

case
then the broadcast_pullback will get a NotImplemented which it will then pass on to the, until at the end the user gets a NotImplemented as the output.
And it should display some nice message explaining about mutation not being supported.
(we probably also want a way to turn NotImplemented construction into errors so they can workout when in their code they called e.g fill!)

@mcabbott
Copy link
Member

mcabbott commented Sep 1, 2021

Oh right, I guess NotImplemented must have roughly the rules I imagined.

However, my original example is trickier, since the pullback never gets called.

julia> function ChainRulesCore.rrule(::typeof(fill!), A::Vector, x::Number)
            project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
            fill!_pullback(Ȳ) = (NoTangent(), @not_implemented("nope"), project(sum(Ȳ)))
            return fill!(A, x), fill!_pullback
        end

julia> Zygote.gradient([1,2,3], 0) do x, s
         y = fill!(similar(x), s)
         sum(y)
       end
(nothing, 3.0)

julia> Zygote.gradient([1,2,3], 0) do x, s
         y = log.(x)
         z = fill!(x, s)
         sum(y .+ z)
       end
(NotImplemented(Main, #= REPL[7]:3 =#, nope), 3.0)

julia> Zygote.gradient([1,2,3], 0) do x, s
         y = log.(x)
         fill!(x, s)  # still leads to silent errors
         sum(y .+ x)
       end
([2.0, 1.5, 1.3333333333333333], nothing)

@oxinabox
Copy link
Member

oxinabox commented Sep 1, 2021

Huh, is the difference about if it was assigned to z or not?
I didn't know Zygote reasons about variables in that way.
I wonder if that could be changed>

@mcabbott
Copy link
Member

mcabbott commented Sep 1, 2021

I don't know the internals, maybe this could be changed? It calls the rrule on the forward pass, but not the pullback, I think because what's returned by fill! isn't an input to any later function, hence there is no input for the pullback. (FWIW Diffractor fails on both of these, right now.)

@mcabbott
Copy link
Member

Here's an update.

  • Since the above, Zygote seems to have been taught to silently ignore NotImplemented, giving more ways to get wrong answers.
  • Diffractor does always call the pullback. When the return of fill! is not used, this will get Zero input. Perhaps the pullback should have a method for such cases.
  • Using NotImplemented to mark x as poisoned, without restoring its value, isn't broad enough. In the examples above, log.(x) captures the old value of x and uses this for x's gradient. But something like x * y captures both values, and uses the old value of x for y's gradient.
julia> using Zygote, ChainRulesCore, ForwardDiff, Diffractor

# New rule, with back(::Zero) method

julia> function ChainRulesCore.rrule(::typeof(fill!), A::Vector, x::Number)
         function back(dB)
           println("pullback for fill! got $dB")
           (NoTangent(), @not_implemented("arg is mutated"), sum(dB))
         end
         function back(dB::AbstractZero)
            println("pullback for fill! got $dB")
            (NoTangent(), @not_implemented("mutated"), @not_implemented("no input"))
          end
         fill!(A,x), back
       end

julia> Zygote.gradient([1,2], 3) do x, s  # easy case, works as desired
         y = fill!(similar(x), s)
         sum(abs2, y)
       end
pullback for fill! got [6, 6]
(nothing, 12.0)

# Example from above

julia> Zygote.gradient([1,2,3], 0) do x, s  # silently wrong, both args!
         y = log.(x)  # this needs x's value
         fill!(x, s)  # pullback is not called
         sum(y .+ x)
       end
([2.0, 1.5, 1.3333333333333333], nothing)

julia> ForwardDiff.gradient([1,2,3]) do x
         y = log.(x)
         fill!(x, 0)
         sum(y .+ x)
       end
3-element Vector{Float64}:
 1.0
 0.5
 0.3333333333333333

julia> Diffractor.gradient([1,2,3], 0) do x, s
         y = log.(x)  # this needs x's value
         fill!(x, s)  # poisons x, and s
         sum(y .+ x)
       end
pullback for fill! got ZeroTangent()
(NotImplemented(Main, #= REPL[22]:6 =#, mutated), NotImplemented(Main, #= REPL[22]:6 =#, no input))

# New example

julia> Zygote.gradient([1 2; 3 4], [5,6], 7) do x, y, z
         xy = x * y
         y2 = fill!(y, z)
         sum(xy .+ y2)
       end
pullback for fill! got [1.0, 1.0]
([7.0 7.0; 7.0 7.0], [4.0, 6.0], 2.0)

julia> Diffractor.gradient([1 2; 3 4], [5,6], 7) do x, y, z  # silently wrong about x
         xy = x * y       # x's gradient needs y's value, etc.
         y2 = fill!(y, z) # poisons y, but not x
         sum(xy .+ y2)
       end
pullback for fill! got [1.0, 1.0]
([7.0 7.0; 7.0 7.0], NotImplemented(Main, #= REPL[5]:4 =#, nope), 2.0)

julia> ForwardDiff.gradient([1 2; 3 4]) do x
         y, z = [5,6], 7
         xy = x * y
         y2 = fill!(y, z)
         sum(xy .+ y2)
       end
2×2 Matrix{Int64}:
 5  6
 5  6

@mcabbott
Copy link
Member

mcabbott commented Dec 26, 2022

If you overload _pullback, then this is always called, even with no return. This appears to give safe answers on the above examples, often NaN:

function Zygote._pullback(__context__::Zygote.AContext, ::typeof(fill!), x::Array, v)
  old = copy(x)  # could instead just have fill!(x, NaN) on the reverse?
  y = fill!(x, v)
  back(::Nothing) = begin
    copyto!(x, old)  # restore
    (nothing, Zygote.Fill(NaN, size(x)), NaN)  # since we didn't see the return, poison it
  end
  back(dy) = begin
    copyto!(x, old)
    (nothing, Zygote.Fill(NaN, size(x)), sum(dy))  # here we know dv
  end
  return (y, back)
end

Similar for setindex!:

function Zygote._pullback(__context__::Zygote.AContext, ::typeof(Base.setindex!), x::Array, v, ind::Integer...)
  old = x[ind...]
  y = setindex!(x, v, ind...)
  nots = map(_ -> nothing, ind)
  back(::Nothing) = begin
    x[ind...] = old
    (nothing, Zygote.Fill(NaN, size(x)), NaN, nots...)
  end
  back(dy) = begin
    x[ind...] = old
    (nothing, Zygote.Fill(NaN, size(x)), dy, nots...)  # setindex! returns the value
  end
  return (y, back)
end

Zygote.gradient([1,2,3.0], 4) do x, y
  x[1] = y^2
  sum(x .* y)
end  # should be ([0,4,4], 53), in fact all NaN

@oxinabox
Copy link
Member

Since the above, Zygote seems to have been taught to silently ignore NotImplemented, giving more ways to get wrong answers.

Damn it Zygote.
Do we have an issue open downstream to "Please don't do this"?

@ToucheSir
Copy link
Contributor

We had FluxML/Zygote.jl#1227 but it was closed, I've just re-opened it. The problem was that FluxML/Zygote.jl#1204 happened, which lead to FluxML/Zygote.jl#1205. As I mentioned in the issue, there doesn't seem to be a more incremental fix here than doing all of FluxML/Zygote.jl#603. Am I missing a better solution?

@ToucheSir
Copy link
Contributor

Thinking about it a bit more, could we get away with just switching over Zygote's zero types (i.e. nothing)? The biggest obstacle I can think of is getting rid of internal pullback calls in higher-order rules to avoid premature conversion.

@oxinabox
Copy link
Member

oxinabox commented Mar 1, 2023

@mzgubic (with my help) tried to switch over Zygote's types a few years ago.
It got hairy fast. Far more complex than you might think.
Though now that Zygote has fewer rules that might need changing it might be easier

@ToucheSir
Copy link
Contributor

In the spirit of baby steps, I've filed FluxML/Zygote.jl#1385 to provide a better base for future attempts at this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants