No adjoint for Base.Iterators.ProductIterator #421

Closed
NMUrban opened this issue Dec 6, 2019 · 10 comments · Fixed by #785

Comments

NMUrban commented Dec 6, 2019

It's not possible to differentiate through a matrix built with a multidimensional array comprehension. The comprehension iterates over a Base.Iterators.ProductIterator, and Zygote has no adjoint for its constructor. MWE:

julia> using Zygote
julia> X = rand(4);
julia> Σ(λ) = [exp(-((x-x′)/λ)^2) for x in X, x′ in X]
Σ (generic function with 1 method)

julia> gradient(λ -> sum(Σ(λ)), 0.2)
ERROR: Need an adjoint for constructor Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}}. Gradient is of type Array{Tuple{Float64,Float64},2}

julia> gradient(λ -> sum([x^2/λ for x in X]), 0.2) # simple generator OK
(-46.45199007830025,)

Originally posted by @mcabbott in #377 (comment)

GodotMisogi commented Nov 15, 2020

Is there any workaround for this? I think differentiation of such constructs is definitely necessary, especially for matrix construction as array mutation is not supported.

@mcabbott
Member

GodotMisogi commented Nov 16, 2020

Thanks, but this does not seem to work with custom types for Iterators.product(), as zero() is not implemented for the custom type when calling sum(y->y[n], dy; dims=dims). I'm very new to the codebase, but I believe the implementation of zero() for custom types should not be necessary.
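
For reference, the reduction named above can be tried in isolation. The following is a minimal sketch in plain Julia (no Zygote), assuming a product of two vectors and a made-up cotangent array `dy` whose entries are tuples of `Float64`; the names `dy`, `dxs` and `dys` are illustrative only. In this case the reduction only needs `zero` for `Float64`, which exists.

```julia
# Made-up cotangent for Iterators.product(xs, ys) with length(xs) == 2 and
# length(ys) == 3: entry (i, j) holds one contribution for xs[i] and one for ys[j].
dy = [(1.0 * i, 10.0 * j) for i in 1:2, j in 1:3]

# Gradient w.r.t. xs: pick the first tuple component, sum over the ys dimension.
dxs = vec(sum(t -> t[1], dy; dims = 2))

# Gradient w.r.t. ys: pick the second tuple component, sum over the xs dimension.
dys = vec(sum(t -> t[2], dy; dims = 1))
```

With `Point` entries in place of `Float64`, the same call has to produce an initial `Point` for the accumulation, which is where `zero` comes in.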

Consider the following example based on the documentation:

##
import Base: +, -, zero
import Base.Iterators
using StaticArrays

struct Point
  x::Float64
  y::Float64
end

width(p::Point) = p.x
height(p::Point) = p.y

a::Point + b::Point = Point(width(a) + width(b), height(a) + height(b))
a::Point - b::Point = Point(width(a) - width(b), height(a) - height(b))
dist(p::Point) = sqrt(width(p)^2 + height(p)^2)

##
using Zygote

@Zygote.adjoint (T :: Type{<:SVector})(xs :: Number ...) = T(xs...), dv -> (nothing, dv...)
@Zygote.adjoint (T :: Type{<:SVector})(x :: AbstractVector) = T(x), dv -> (nothing, dv)

@Zygote.adjoint enumerate(xs) = enumerate(xs), diys -> (map(last, diys),)

_ndims(::Base.HasShape{d}) where {d} = d
_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1

@Zygote.adjoint function Iterators.product(xs...)
    d = 1
    Iterators.product(xs...), dy -> ntuple(length(xs)) do n
        nd = _ndims(xs[n])
        dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
        d += nd
        func = sum(y->y[n], dy; dims=dims)
        ax = axes(xs[n])
        reshape(func, ax)
    end
end

@Zygote.adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),)
@Zygote.adjoint height(p::Point) = p.y, ȳ -> (Point(0, ȳ),)
@Zygote.adjoint Point(a, b) = Point(a, b), p̄ -> (p̄.x, p̄.y)
# zero(p :: Point) = Point(0, 0)

##
xs = Point.(1:5, 5:9)

function something(xs) 
    sum([ width(p1) + height(p2) for p1 in xs, p2 in xs ])
end

gradient(something, xs)

This yields the following error:

ERROR: LoadError: MethodError: no method matching zero(::Point)
Closest candidates are:
  zero(::Type{LibGit2.GitHash}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LibGit2\src\oid.jl:220  
  zero(::Type{Pkg.Resolve.VersionWeight}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Pkg\src\Resolve\versionweights.jl:15
  zero(::Type{Dates.DateTime}) at D:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\Dates\src\types.jl:404   
  ...
Stacktrace:
 [1] _reducedim_init(::var"#20#24"{Int64}, ::typeof(Base.add_sum), ::typeof(zero), ::typeof(sum), ::Array{Tuple{Point,Point},2}, ::Tuple{Int64}) at .\reducedim.jl:121
 [2] reducedim_init(::Function, ::typeof(Base.add_sum), ::Array{Tuple{Point,Point},2}, ::Tuple{Int64}) at .\reducedim.jl:109  
 [3] _mapreduce_dim(::Function, ::Function, ::NamedTuple{(),Tuple{}}, ::Array{Tuple{Point,Point},2}, ::Tuple{Int64}) at .\reducedim.jl:324
 [4] #mapreduce#620 at .\reducedim.jl:310 [inlined]
 [5] _sum at .\reducedim.jl:749 [inlined]
 [6] #sum#628 at .\reducedim.jl:723 [inlined]
 [7] (::var"#18#22"{Array{Tuple{Point,Point},2},Tuple{Array{Point,1},Array{Point,1}}})(::Int64) at d:\Academia\AeroMDAO\tests\zygote.jl:35
 [8] ntuple(::var"#18#22"{Array{Tuple{Point,Point},2},Tuple{Array{Point,1},Array{Point,1}}}, ::Int64) at .\ntuple.jl:18       
 [9] (::var"#17#21"{Tuple{Array{Point,1},Array{Point,1}}})(::Array{Tuple{Point,Point},2}) at d:\Academia\AeroMDAO\tests\zygote.jl:31
 [10] (::var"#187#back#25"{var"#17#21"{Tuple{Array{Point,1},Array{Point,1}}}})(::Array{Tuple{Point,Point},2}) at C:\Users\godot\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49
 [11] something at d:\Academia\AeroMDAO\tests\zygote.jl:50 [inlined]
 [12] (::typeof(∂(something)))(::Float64) at C:\Users\godot\.julia\packages\Zygote\c0awc\src\compiler\interface2.jl:0
 [13] (::Zygote.var"#41#42"{typeof(∂(something))})(::Float64) at C:\Users\godot\.julia\packages\Zygote\c0awc\src\compiler\interface.jl:45
 [14] gradient(::Function, ::Array{Point,1}) at C:\Users\godot\.julia\packages\Zygote\c0awc\src\compiler\interface.jl:54      
 [15] top-level scope at d:\Academia\AeroMDAO\tests\zygote.jl:53
 [16] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1091
 [17] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at .\essentials.jl:710
 [18] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at .\essentials.jl:709
 [19] inlineeval(::Module, ::String, ::Int64, ::Int64, ::String; softscope::Bool) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:83
 [20] (::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool})() at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:45
 [21] withpath(::VSCodeServer.var"#43#45"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool}, ::String) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\repl.jl:118
 [22] (::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool})() at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:43
 [23] hideprompt(::VSCodeServer.var"#42#44"{VSCodeServer.ReplRunCodeRequestParams,String,Int64,Int64,String,Module,Bool,Bool}) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\repl.jl:36
 [24] repl_runcode_request(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.ReplRunCodeRequestParams) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\eval.jl:23
 [25] dispatch_msg(::VSCodeServer.JSONRPC.JSONRPCEndpoint{Base.PipeEndpoint,Base.PipeEndpoint}, ::VSCodeServer.JSONRPC.MsgDispatcher, ::Dict{String,Any}) at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\JSONRPC\src\typed.jl:66
 [26] macro expansion at c:\Users\godot\.vscode\extensions\julialang.language-julia-1.0.8\scripts\packages\VSCodeServer\src\VSCodeServer.jl:95 [inlined]
 [27] (::VSCodeServer.var"#61#63"{Bool,String})() at .\task.jl:356
in expression starting at d:\Academia\AeroMDAO\tests\zygote.jl:53

@mcabbott
Member

What's the objection to zero(::Point)? With @adjoint width(p::Point) = p.x, x̄ -> (Point(x̄, 0),) you are using this type for the tangent space, which must be a vector space, so it doesn't seem unreasonable that it must have + and zero defined.

Without those definitions, the default is to use a NamedTuple, and these also don't have zero. So maybe there's another way around this.

@GodotMisogi

You're right. I tried the zero(::Point) implementation and it does work (and makes sense in terms of the algebra). My original issue was that not every struct necessarily admits an algebraic structure even as weak as a vector space, but then I realised there's really nothing to differentiate then. Your implementation appears to be working as intended, thanks a lot!
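
The point about the tangent type needing `+` and `zero` can be checked without Zygote. The following is a minimal sketch, assuming a `Point` with the same fields as in the example above (redefined so the snippet runs on its own); defining both the instance and the type method of `zero` is a choice made here for safety, not something the thread prescribes.

```julia
struct Point
    x::Float64
    y::Float64
end

# Vector-space operations needed to accumulate values of this type.
Base.:+(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
Base.zero(::Type{Point}) = Point(0, 0)
Base.zero(::Point) = zero(Point)

pts = [Point(i, j) for i in 1:2, j in 1:3]   # 2×3 array of Points

# Reducing over one dimension needs a starting element, which Base obtains via zero.
colsums = sum(pts; dims = 1)
```

Without the `zero` methods, the `sum(...; dims = 1)` call is expected to fail with a `MethodError` like the one in the stack trace above.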

Contributor

cossio commented Feb 23, 2021

Bump. Hitting this issue right now. Is there a workaround?

julia> using Zygote, Flux
julia> X = randn(5); Y = randn(5);
julia> ps = Flux.params(X,Y);
julia> gs = gradient(ps) do
       sum([sin(x*y) for x in X, y in Y])
       end

ERROR: Need an adjoint for constructor Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}}. Gradient is of type Array{Tuple{Float64,Float64},2}
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.Jnew{Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}},Nothing,false})(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/lib/lib.jl:311
[3] (::Zygote.var"#1728#back#167"{Zygote.Jnew{Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}},Nothing,false}})(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[4] ProductIterator at ./iterators.jl:892 [inlined]
[5] (::typeof(∂(Base.Iterators.ProductIterator{Tuple{Array{Float64,1},Array{Float64,1}}})))(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[6] ProductIterator at ./iterators.jl:892 [inlined]
[7] (::typeof(∂(Base.Iterators.ProductIterator)))(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[8] product at ./iterators.jl:910 [inlined]
[9] (::typeof(∂(product)))(::Array{Tuple{Float64,Float64},2}) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[10] #3 at ./REPL[5]:2 [inlined]
[11] (::typeof(∂(#3)))(::Float64) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
[12] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(#3))})(::Float64) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:172
[13] gradient(::Function, ::Params) at /home/cossio/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:49
[14] top-level scope at REPL[5]:1

Update:

Replacing the comprehension by a broadcast works:

gs = gradient(ps) do
       sum(sin.(X*Y'))
end

Maybe a general solution can be derived from this.
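
One way to generalise the broadcast workaround, sketched here under assumptions: broadcast an arbitrary two-argument function over a column and a row, so that no `ProductIterator` is constructed. The helper name `outer` is hypothetical and is not part of Zygote or Base. The snippet only checks that the broadcast form matches the comprehension; it does not call Zygote.

```julia
# Hypothetical helper: apply f to every pair (x, y) by broadcasting the
# vector X against the 1×n row permutedims(Y).
outer(f, X::AbstractVector, Y::AbstractVector) = f.(X, permutedims(Y))

X = [0.1, 0.2, 0.3]
Y = [1.0, 2.0]

A = outer((x, y) -> sin(x * y), X, Y)   # broadcast form, size (3, 2)
B = [sin(x * y) for x in X, y in Y]     # comprehension form, same entries
```

With Zygote loaded, something like `gradient(X -> sum(outer((x, y) -> sin(x * y), X, Y)), X)` should then go through the broadcast adjoint instead of the `ProductIterator` constructor; that call was not run for this sketch.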

pfarndt commented Feb 25, 2021

You could borrow it from here: https://github.com/FluxML/Zygote.jl/pull/785/files#diff-a9e025ac90a30d27e7512546971c5d92ea7c3496ba759336ae6bf1cace6db4b2R240

This worked very well for me for a ProductIterator. Is there a reason why this code does not make it into Zygote? I think a lot of people will stumble into these issues.

oschulz commented Sep 16, 2021

Bump, just ran into this, too.

@mcabbott
Member

If anyone wants to help fix up #785, or pull this part out as a smaller PR, that would be great.

@DhairyaLGandhi
Member

Ran into it recently too, so I'll try to get a working patch out with smaller changes.
