
broadcast involving NTuple errors no method matching unbroadcast(::Tuple{… AND possible fix? #975

Closed
manuelbb-upb opened this issue May 19, 2021 · 1 comment · Fixed by #977


manuelbb-upb commented May 19, 2021

Hi,
I have an MWE (minimal working example) that looks something like this:

import Zygote
f = x -> prod( x .^ (1, 2) )
Zygote.gradient(f, rand(2))

i.e., the exponents for f are provided as an NTuple.
This results in the error

MethodError: no method matching unbroadcast(::Tuple{Int64, Int64}, ::Vector{Float64})
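For reference, the MWE computes f(x) = x_1^1 x_2^2, so the gradient expected for the vector argument is the one below. The derivative itself is straightforward; judging by the method signature in the error, the failure appears to happen when Zygote tries to reduce the broadcast gradient back to the shape of the exponent tuple (1, 2).

```latex
f(x) = x_1^{1} \, x_2^{2}, \qquad
\nabla f(x) = \left( x_2^{2},\; 2 x_1 x_2 \right)
```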

I took a look at src/lib/unbroadcast.jl.
Although I don't fully understand what is going on there, the following method seems to fix the issue (?):

Zygote.unbroadcast( x :: NTuple{N, <:Any} where N, x̄ ) = (Zygote.accum_sum(x̄),)

Edit: There may also be an issue with the return type.
When I call

Zygote.gradient(f, [1,2])

the gradient is a Vector{Float64}, even though the input is a vector of integers. Is this expected?

mcabbott (Member) commented May 19, 2021

This isn't quite right: the method you added affects the gradient with respect to the tuple, which your example discards. Keeping that argument:

julia> gradient((x,p) -> sum(x .^ p), [1, 2], [1,2])
([1.0, 4.0], [0.0, 2.772588722239781])

julia> gradient((x,p) -> sum(x .^ p), [1, 2], (1,2))
([1.0, 4.0], (2.772588722239781,))
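The expected numbers follow from the partial derivatives of sum(x .^ p). With x = (1, 2) and p = (1, 2), the gradient for p should have one entry per exponent. The single-entry tuple returned with the proposed fix equals the sum of those two entries, 0 + 4 ln 2, which is consistent with the whole gradient vector being collapsed into one number.

```latex
\frac{\partial}{\partial x_i} \sum_j x_j^{p_j} = p_i \, x_i^{p_i - 1}
  \;\Rightarrow\; (1 \cdot 1^{0},\; 2 \cdot 2^{1}) = (1,\; 4)

\frac{\partial}{\partial p_i} \sum_j x_j^{p_j} = x_i^{p_i} \ln x_i
  \;\Rightarrow\; (1 \cdot \ln 1,\; 4 \ln 2) \approx (0,\; 2.772588722239781)
```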

I think what you want is something equivalent to the following, but without the collect, i.e. to sum dy over the dimensions it has that the tuple doesn't.

julia> @eval Zygote unbroadcast(x::Tuple, dy) = Tuple(unbroadcast(collect(x), dy))
unbroadcast (generic function with 6 methods)

julia> gradient((x,p) -> sum(x .^ p), [1, 2], (1,2))
([1.0, 4.0], (0.0, 2.772588722239781))

But yes, promotion to floats is expected.
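A minimal sketch of the "sum dy over the dimensions the tuple doesn't have" idea, written as a standalone function so it runs without Zygote. The name unbroadcast_tuple is hypothetical and not part of Zygote's API. The sketch assumes a tuple broadcasts like a vector, i.e. along the first dimension only, and it is not necessarily how the fix in #977 was implemented.

```julia
# Hypothetical helper, not part of Zygote: reduce the broadcast gradient `dy`
# back to the shape of a tuple argument `x`, without calling `collect(x)`.
function unbroadcast_tuple(x::Tuple, dy::AbstractArray)
    N = length(x)
    if N == 1
        # A length-1 tuple is expanded along every dimension, so sum everything.
        return (sum(dy),)
    end
    # Assumption: a tuple broadcasts like a vector, i.e. along dimension 1 only.
    size(dy, 1) == N || throw(DimensionMismatch("size(dy, 1) must equal length(x)"))
    # Sum over the dimensions dy has that the tuple doesn't (2, 3, ...).
    extra = ntuple(d -> d + 1, ndims(dy) - 1)
    reduced = isempty(extra) ? dy : sum(dy; dims = extra)
    # Repackage the reduced values as a tuple with the same length as x.
    return ntuple(i -> reduced[i], N)
end

# Gradient for p from the example above: values unchanged, now a 2-tuple.
unbroadcast_tuple((1, 2), [0.0, 2.772588722239781])  # (0.0, 2.772588722239781)

# If the broadcast had produced a 2×3 array, each row is summed into one entry.
unbroadcast_tuple((1, 2), [1.0 2.0 3.0; 4.0 5.0 6.0])  # (6.0, 15.0)
```

Unlike the accum_sum-based method in the original report, this keeps one gradient entry per tuple element and only sums over the extra dimensions.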
