
No adjoint for Base.Iterators.ProductIterator #421

NMUrban opened this issue Dec 6, 2019 · 10 comments · Fixed by #785


NMUrban commented Dec 6, 2019

It's not possible to differentiate through a matrix constructed from a multidimensional array comprehension. Such a comprehension iterates over Iterators.product, which builds a ProductIterator, and that type has no adjoint. MWE:

julia> using Zygote
julia> X = rand(4);
julia> Σ(λ) = [exp(-((x-x′)/λ)^2) for x in X, x′ in X]
Σ (generic function with 1 method)

julia> gradient(λ -> sum(Σ(λ)), 0.2)
ERROR: Need an adjoint for constructor Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}}. Gradient is of type Array{Tuple{Float64,Float64},2}

julia> gradient(λ -> sum([x^2/λ for x in X]), 0.2) # simple generator OK

Originally posted by @mcabbott in #377 (comment)
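As background for the error, a comprehension with two `for` clauses lowers to (roughly) collecting a generator over `Iterators.product`. That is why a `ProductIterator` constructor shows up in the backward pass. The plain-Julia sketch below needs no Zygote; `X` and `λ` are arbitrary illustrative values, not taken from the issue. It checks that both forms build the same matrix.

```julia
X = [0.1, 0.4, 0.7, 0.9]   # illustrative data
λ = 0.2

# Comprehension form, as in the MWE.
Σ_comp(λ) = [exp(-((x - x′) / λ)^2) for x in X, x′ in X]

# Explicit form: a generator over Iterators.product(X, X), collected into a matrix.
# Iterators.product(X, X) is a Base.Iterators.ProductIterator, the type named in the error.
Σ_prod(λ) = collect(exp(-((x - x′) / λ)^2) for (x, x′) in Iterators.product(X, X))

Σ_comp(λ) == Σ_prod(λ)   # true: same 4×4 matrix, same element order
```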


GodotMisogi commented Nov 15, 2020

Is there any workaround for this? Differentiating such constructs seems essential, especially for building matrices, since Zygote does not support array mutation.


GodotMisogi commented Nov 16, 2020

Thanks, but this does not seem to work with custom types in Iterators.product(). The pullback calls sum(y->y[n], dy; dims=dims), which needs zero() for the custom type, and no such method exists. I'm very new to the codebase, but I don't think custom types should have to implement zero().
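The sketch below shows, in plain Julia without Zygote, what the reduction `sum(y->y[n], dy; dims=dims)` computes in the two-vector case. The name `product_pullback` is hypothetical and not part of Zygote's API. `dy` holds one tuple of partial derivatives per point of the product grid. Each input's gradient sums its own component of those tuples over the other input's dimension. For a custom element type such as `Point`, this same dims-reduction is where `zero()` gets called.

```julia
# Hypothetical helper (not Zygote API): pullback of Iterators.product(a, b) for two vectors.
# dy has the shape of collect(Iterators.product(a, b)), i.e. (length(a), length(b)),
# and dy[i, j] = (∂/∂a[i], ∂/∂b[j]) contributed by grid point (i, j).
function product_pullback(a::AbstractVector, b::AbstractVector, dy::AbstractMatrix{<:Tuple})
    size(dy) == (length(a), length(b)) ||
        throw(DimensionMismatch("dy must have size (length(a), length(b))"))
    # a[i] appears in every tuple of row i, so sum its component over dim 2.
    da = vec(sum(t -> t[1], dy; dims=2))
    # b[j] appears in every tuple of column j, so sum its component over dim 1.
    db = vec(sum(t -> t[2], dy; dims=1))
    return (da, db)
end

# Example: f(a, b) = sum(x * y for (x, y) in Iterators.product(a, b)).
# The cotangent of each tuple (x, y) is (y, x).
a = [1.0, 2.0, 3.0]
b = [10.0, 20.0]
dy = [(y, x) for (x, y) in Iterators.product(a, b)]
da, db = product_pullback(a, b, dy)   # da == [30.0, 30.0, 30.0], db == [6.0, 6.0]
```

Zygote's version generalizes this by tracking how many dimensions each iterator contributes (the `_ndims` bookkeeping in the code below) and reshaping each sum back to `axes(xs[n])`.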

Consider the following example based on the documentation:

import Base: +, -, zero
import Base.Iterators
using StaticArrays

struct Point
    x::Float64
    y::Float64
end

width(p::Point) = p.x
height(p::Point) = p.y

a::Point + b::Point = Point(width(a) + width(b), height(a) + height(b))
a::Point - b::Point = Point(width(a) - width(b), height(a) - height(b))
dist(p::Point) = sqrt(width(p)^2 + height(p)^2)

using Zygote

@Zygote.adjoint (T :: Type{<:SVector})(xs :: Number ...) = T(xs...), dv -> (nothing, dv...)
@Zygote.adjoint (T :: Type{<:SVector})(x :: AbstractVector) = T(x), dv -> (nothing, dv)

@Zygote.adjoint enumerate(xs) = enumerate(xs), diys -> (map(last, diys),)

_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1

@Zygote.adjoint function Iterators.product(xs...)
    d = 1
    Iterators.product(xs...), dy -> ntuple(length(xs)) do n
        nd = _ndims(xs[n])
        # sum over every dimension of dy that does not belong to xs[n]
        dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
        d += nd
        func = sum(y->y[n], dy; dims=dims)
        ax = axes(xs[n])
        reshape(func, ax)
    end
end

@Zygote.adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),)
@Zygote.adjoint height(p::Point) = p.y, ȳ -> (Point(0, ȳ),)
@Zygote.adjoint Point(a, b) = Point(a, b), p̄ -> (p̄.x, p̄.y)
# zero(p :: Point) = Point(0, 0)   # left commented out: without it, the pullback's sum fails

xs = Point.(1:5, 5:9)

function something(xs)
    sum([ width(p1) + height(p2) for p1 in xs, p2 in xs ])
end

gradient(something, xs)

This yields the following error:

ERROR: LoadError: MethodError: no method matching zero(::Point)
Closest candidates are:
  zero(::Type{LibGit2.GitHash}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LibGit2\src\oid.jl:220  
  zero(::Type{Pkg.Resolve.VersionWeight}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Resolve\versionweights.jl:15
  zero(::Type{Dates.DateTime}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Dates\src\types.jl:404   
 [1] _reducedim_init(::var"#20#24"{Int64}, ::typeof(Base.add_sum), ::typeof(zero), ::typeof(sum), ::Array{Tuple{Point,Point},2}, ::Tuple{Int64}) at .\reducedim.jl:121
 [2] reducedim_init(::Function, ::typeof(Base.add_sum), ::Array{Tuple{Point,Point},2}, ::Tuple{Int64}) at .\reducedim.jl:109  
 [3] _mapreduce_dim(::Function, ::Function, ::NamedTuple{(),Tuple{}}, ::Array{Tuple{Point,Point},2}, ::Tuple{Int64}) at .\reducedim.jl:324
 [4] #mapreduce#620 at .\reducedim.jl:310 [inlined]
 [5] _sum at .\reducedim.jl:749 [inlined]
 [6] #sum#628 at .\reducedim.jl:723 [inlined]
 [7] (::var"#18#22"{Array{Tuple{Point,Point},2},Tuple{Array{Point,1},Array{Point,1}}})(::Int64) at d:\Academia\AeroMDAO\tests\zygote.jl:35
 [8] ntuple(::var"#18#22"{Array{Tuple{Point,Point},2},Tuple{Array{Point,1},Array{Point,1}}}, ::Int64) at .\ntuple.jl:18       
 [9] (::var"#17#21"{Tuple{Array{Point,1},Array{Point,1}}})(::Array{Tuple{Point,Point},2}) at d:\Academia\AeroMDAO\tests\zygote.jl:31
 [10] (::var"#187#back#25"{var"#17#21"{Tuple{Array{Point,1},Array{Point,1}}}})(::Array{Tuple{Point,Point},2}) at C:\Users\godot\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [11] something at d:\Academia\AeroMDAO\tests\zygote.jl:50 [inlined]
 [12] (::typeof(∂(something)))(::Float64) at C:\Users\godot\.julia\packages\Zygote\c0awc\src\compiler\interface2.jl:0
 [13] (::Zygote.var"#41#42"{typeof(∂(something))})(::Float64) at C:\Users\godot\.julia\packages\Zygote\c0awc\src\compiler\interface.jl:45
 [14] gradient(::Function, ::Array{Point,1}) at C:\Users\godot\.julia\packages\Zygote\c0awc\src\compiler\interface.jl:54      
 [15] top-level scope at d:\Academia\AeroMDAO\tests\zygote.jl:53
 [16] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1091
 [17] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at .\essentials.jl:710
 [18] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at .\essentials.jl:709
 [19] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:83
 [20] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:45
 [21] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\repl.jl:118
 [22] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:43
 [23] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\repl.jl:36
 [24] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:23
 [25] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\JSONRPC\src\typed.jl:66
 [26] macro expansion at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\VSCodeServer.jl:95 [inlined]
 [27] (::VSCodeServer.var"#61#63"{Bool,String})() at .\task.jl:356
in expression starting at d:\Academia\AeroMDAO\tests\zygote.jl:53


What's the objection to zero(::Point)? With @adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),) you are using this type for the tangent space, which must be a vector space, so it doesn't seem unreasonable that it must have + and zero defined.

Without those definitions, the default is to use a NamedTuple, and these also don't have zero. So maybe there's another way around this.
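The following plain-Julia sketch illustrates this suggestion without Zygote. Once `Point` has `+` and `zero`, the same kind of dims-reduction over a matrix of tuples that failed in the trace above goes through. The `Float64` field types follow the Zygote documentation example the code is based on. Defining `zero` on the type as well as on instances is a precaution, not something the thread requires.

```julia
struct Point
    x::Float64
    y::Float64
end

# Point is used as a tangent type, so treat it as a vector space: addition and an additive identity.
Base.:+(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
Base.zero(::Point) = Point(0, 0)
Base.zero(::Type{Point}) = Point(0, 0)

ps = Point.(1:3, 4:6)

# The same reduction the product pullback performs: a matrix of tuples,
# summing the n-th component over the other dimension.
dy = [(p, q) for p in ps, q in ps]
g1 = vec(sum(t -> t[1], dy; dims=2))   # each ps[i] summed over the 3 columns
g2 = vec(sum(t -> t[2], dy; dims=1))   # each ps[j] summed over the 3 rows
```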


You're right. I tried the zero(::Point) implementation and it does work (and makes sense in terms of the algebra). My original issue was that not every struct necessarily admits an algebraic structure even as weak as a vector space, but then I realised there's really nothing to differentiate then. Your implementation appears to be working as intended, thanks a lot!


cossio commented Feb 23, 2021

Bump. Hitting this issue right now. Is there a workaround?

julia> using Zygote, Flux
julia> X = randn(5); Y = randn(5);
julia> ps = Flux.params(X,Y);
julia> gs = gradient(ps) do
       sum([sin(x*y) for x in X, y in Y])
       end

ERROR: Need an adjoint for constructor Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}}. Gradient is of type Array{Tuple{Float64,Float64},2}
[1] error(::String) at ./error.jl:33
[2] (::Zygote.Jnew{Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}},Nothing,false})(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/lib/lib.jl:311
[3] (::Zygote.var"#1728#back#167"{Zygote.Jnew{Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}},Nothing,false}})(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] ProductIterator at ./iterators.jl:892 [inlined]
[5] (::typeof(∂(Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}})))(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[6] ProductIterator at ./iterators.jl:892 [inlined]
[7] (::typeof(∂(Base.Iterators.ProductIterator)))(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[8] product at ./iterators.jl:910 [inlined]
[9] (::typeof(∂(product)))(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[10] #3 at ./REPL[5]:2 [inlined]
[11] (::typeof(∂(#3)))(::Float64) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[12] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(#3))})(::Float64) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:172
[13] gradient(::Function, ::Params) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:49
[14] top-level scope at REPL[5]:1


Replacing the comprehension with a broadcast works:

gs = gradient(ps) do

Maybe a general solution can be derived from this.
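The broadcast body from that comment did not survive here. A broadcast that presumably matches is `sin.(X .* Y')`, sketched below in plain Julia. This is an assumption about the workaround, not a copy of it. It builds the same matrix as the comprehension without constructing a `ProductIterator`, which sidesteps the missing adjoint.

```julia
X = randn(5); Y = randn(5)

M_comp  = [sin(x * y) for x in X, y in Y]  # goes through Iterators.product
M_bcast = sin.(X .* Y')                    # column X times row Y': entry (i, j) is sin(X[i] * Y[j])

M_comp == M_bcast   # true
```

Under this assumption, the body of the `gradient(ps) do ... end` block would be `sum(sin.(X .* Y'))`.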


pfarndt commented Feb 25, 2021

You could borrow it from here:

This worked very well for me with a ProductIterator. Is there a reason this code hasn't made it into Zygote? I think a lot of people will run into this issue.


oschulz commented Sep 16, 2021

Bump, just ran into this, too.


If anyone wants to help fix up #785, or pull this part out as a smaller PR, that would be great.


Ran into it recently too, so I'll try to get a working patch out with smaller changes.


7 participants