no method matching unbroadcast(::Tuple{…
Hi, I have an MWE that looks something like this:
```julia
import Zygote

f = x -> prod(x .^ (1, 2))
Zygote.gradient(f, rand(2))
```
i.e., the exponents for `f` are provided as an `NTuple`. This results in an error:
```
MethodError: no method matching unbroadcast(::Tuple{Int64, Int64}, ::Vector{Float64})
```
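For reference, the forward computation is not the problem. Base broadcasting treats the tuple `(1, 2)` like a length-2 vector, so `x .^ (1, 2)` is `[x[1]^1, x[2]^2]` and `f(x) = x[1] * x[2]^2`. The error names `unbroadcast`, which, as far as the message suggests, is the gradient-reduction step. The following plain-Julia check uses no Zygote. `grad_f` is a hypothetical hand-derived gradient written only for this comparison. It shows what `Zygote.gradient(f, x)` would be expected to return:

```julia
# The MWE's function: broadcasting over the tuple (1, 2) behaves like over [1, 2].
f = x -> prod(x .^ (1, 2))

# Hypothetical hand-derived gradient of f(x) = x[1]^1 * x[2]^2,
# used only to check what an AD result should be.
grad_f(x) = [x[2]^2, 2 * x[1] * x[2]]

x = [1.0, 2.0]
println(x .^ (1, 2))   # [1.0, 4.0]
println(f(x))          # 4.0
println(grad_f(x))     # [4.0, 4.0]
```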
I have taken a look at `src/lib/unbroadcast.jl`. Although I don't really understand what is going on there, this seems to fix the issue (?):
```julia
Zygote.unbroadcast(x::NTuple{N,<:Any} where N, x̄) = (Zygote.accum_sum(x̄),)
```
Edit: There still seems to be an issue with the return type. When I do

```julia
Zygote.gradient(f, [1, 2])
```

the gradient is a `Vector{Float64}`. Is this expected?
This isn't quite right. That method affects the gradient with respect to the tuple itself, which your example discards. Keeping it:
```julia
julia> gradient((x,p) -> sum(x .^ p), [1, 2], [1,2])
([1.0, 4.0], [0.0, 2.772588722239781])

julia> gradient((x,p) -> sum(x .^ p), [1, 2], (1,2))
([1.0, 4.0], (2.772588722239781,))
```
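To see which information the patch loses, differentiate by hand. For `s(x, p) = sum(x .^ p)`, the partial derivative with respect to each exponent is `x[i]^p[i] * log(x[i])`. The plain-Julia check below uses `dsdp`, a hypothetical helper name that does not come from Zygote. It reproduces the per-element result of the `Vector` call. Summing that result into a single entry reproduces the 1-tuple the patched tuple call returns:

```julia
# Hypothetical hand-written partials of s(x, p) = sum(x .^ p) with respect to
# each exponent: d/dp_i (x_i^p_i) = x_i^p_i * log(x_i).
dsdp(x, p) = x .^ p .* log.(x)

x = [1.0, 2.0]
p = (1, 2)

per_element = dsdp(x, p)          # one entry per exponent
collapsed = (sum(per_element),)   # everything summed into a single 1-tuple

println(per_element)  # [0.0, 2.772588722239781]
println(collapsed)    # (2.772588722239781,)
```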
I think what you want is something equivalent to this, but without the `collect`, i.e. summing `dy` over the dimensions it has that the tuple doesn't.
```julia
julia> @eval Zygote unbroadcast(x::Tuple, dy) = Tuple(unbroadcast(collect(x), dy))
unbroadcast (generic function with 6 methods)

julia> gradient((x,p) -> sum(x .^ p), [1, 2], (1,2))
([1.0, 4.0], (0.0, 2.772588722239781))
```
But yes, promotion to floats is expected.
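The reduction described above sums `dy` over the dimensions the tuple lacks and avoids `collect`. Below is a minimal sketch of it as a standalone plain-Julia function, not as a method of Zygote's internal `unbroadcast`. The name `unbroadcast_tuple` is hypothetical. The sketch assumes `dy` is an array whose first dimension has the tuple's length whenever the tuple has more than one element:

```julia
# Hypothetical helper (not Zygote API): reduce a broadcast cotangent `dy`
# to the shape of a tuple argument `x`, returning a Tuple of the same length.
function unbroadcast_tuple(x::NTuple{N,Any}, dy::AbstractArray) where {N}
    if N == 1
        # A 1-tuple is broadcast along every dimension of dy: sum everything.
        return (sum(dy),)
    end
    # A length-N tuple (N > 1) occupies dimension 1 only, so sum dy over
    # its trailing dimensions, which the tuple does not have.
    v = ndims(dy) == 1 ? dy : vec(sum(dy; dims = 2:ndims(dy)))
    return ntuple(i -> v[i], N)
end

# Cotangent for the exponents p = (1, 2) at x = [1, 2]: x .^ p .* log.(x)
println(unbroadcast_tuple((1, 2), [0.0, 4 * log(2)]))   # (0.0, 2.772588722239781)
println(unbroadcast_tuple((1, 2), [1.0 2.0; 3.0 4.0]))  # (3.0, 7.0)
println(unbroadcast_tuple((3,), [1.0, 2.0, 3.0]))       # (6.0,)
```

The first call matches the tuple gradient from the `@eval` version. The second shows a matrix `dy` being summed along its columns, with no intermediate `Vector` built from the tuple.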