
Zygote compatibility does not work for Julia 1.10+ #70

Open
mrazomej opened this issue Feb 22, 2024 · 5 comments

Comments

@mrazomej

I raised this as an issue in the Zygote.jl repository, but it might belong here:

Zygote fails to use the rrules defined by TaylorDiff when run with Julia 1.10+.

In Julia 1.9.4:

import Zygote
import TaylorDiff

TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works

Zygote.withgradient([1.0, 2.0, 3.0]) do x
    TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # works, returning (val = 4.0, grad = ([0.0, 2.0, 0.0],))

In Julia 1.10+:

import Zygote
import TaylorDiff

TaylorDiff.derivative(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [0.0, 1.0, 0.0], :1) # works

Zygote.withgradient([1.0, 2.0, 3.0]) do x
    TaylorDiff.derivative(x -> sum(x .^ 2), x, [0.0, 1.0, 0.0], :1)
end # doesn't work

The last call fails with the following error:

ERROR: Need an adjoint for constructor TaylorDiff.TaylorScalar{Float64, 2}. Gradient is of type TaylorDiff.TaylorScalar{Float64, 2}
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] (::Zygote.Jnew{TaylorDiff.TaylorScalar{Float64, 2}, Nothing, false})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:330
  [3] (::Zygote.var"#2210#back#313"{Zygote.Jnew{}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [4] TaylorScalar
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/scalar.jl:17 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [6] TaylorScalar
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/scalar.jl:22 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/primitive.jl:143 [inlined]
  [8] ^
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/primitive.jl:128 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [10] literal_pow
    @ ./intfuncs.jl:351 [inlined]
 [11] (::Zygote.var"#1368#1374")(::Tuple{…}, ȳ₁::TaylorDiff.TaylorScalar{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:218
 [12] #4
    @ ./generator.jl:36 [inlined]
 [13] iterate(g::Base.Generator, s::Vararg{Any})
    @ Base ./generator.jl:47 [inlined]
 [14] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Base.var"#4#5"{Zygote.var"#1368#1374"}})
    @ Base ./array.jl:834
 [15] map
    @ ./abstractarray.jl:3406 [inlined]
 [16] (::Zygote.var"#∇broadcasted#1373"{})(ȳ::FillArrays.Fill{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:218
 [17] #4117#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [18] #291
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206 [inlined]
 [19] #2169#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [20] broadcasted
    @ ./broadcast.jl:1347 [inlined]
 [21] #8
    @ ./REPL[4]:2 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::TaylorDiff.TaylorScalar{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [23] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:37 [inlined]
 [24] derivative
    @ ~/.julia/packages/TaylorDiff/zNnz2/src/derivative.jl:23 [inlined]
 [25] #7
    @ ./REPL[4]:2 [inlined]
 [26] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [28] withgradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:0
 [29] top-level scope
    @ REPL[4]:1
@mBarreau
Contributor

I can confirm this issue. Have you found a workaround?

@tansongchen
Member

Sorry, I haven't had a chance to look at the breaking changes in v1.10. It seems strange, since rrules are just function overloading and shouldn't depend much on the language core...

@YichengDWu
Contributor

@ToucheSir Could you take a look?

@ToucheSir

@mrazomej when cross-posting issues, please link back to the original so that readers have some context. In this case, there's plenty of background in FluxML/Zygote.jl#1502. We also have a Slack discussion in the #autodiff channel about ideas to fix this. It may be that changes made in the Julia compiler need to be reverted.

@mBarreau
Contributor

Hi all,

If we want to solve this issue, we will need to put more effort into it.
From the Slack discussion, it appears that we need to provide the simplest possible MWE. If we call Zygote.gradient on TaylorDiff.derivative, we do get the error. How can we remove the dependency on TaylorDiff? I don't understand the package well enough to do this. @tansongchen, can you do that?

Then we can open an issue in Zygote.
