getindex for real input fails with complex sensitivity #376

Closed
sethaxen opened this issue Oct 18, 2019 · 6 comments

Comments

@sethaxen
Contributor

getindex's adjoint assumes the sensitivity is of the same type as the input. This causes InexactErrors when the input to getindex is real but the sensitivity is complex.

julia> using Zygote
[ Info: Precompiling Zygote [e88e6eb3-aa80-5325-afca-941959d7151f]

julia> x = randn(2)
2-element Array{Float64,1}:
 -1.5316563904446463
 -0.7974154569041462

julia> y, back = Zygote._pullback(x->x[1]*im, x)
(-0.0 - 1.5316563904446463im, (#7))

julia> back(1.0)
ERROR: InexactError: Float64(0.0 - 1.0im)
Stacktrace:
 [1] Real at ./complex.jl:37 [inlined]
 [2] convert at ./number.jl:7 [inlined]
 [3] setindex! at ./array.jl:780 [inlined]
 [4] #838 at /Users/saxen/.julia/packages/Zygote/lRotY/src/lib/array.jl:32 [inlined]
 [5] #2174#back at /Users/saxen/.julia/packages/ZygoteRules/Mmoki/src/adjoint.jl:49 [inlined]
 [6] literal_getindex at /Users/saxen/.julia/packages/Zygote/lRotY/src/lib/lib.jl:77 [inlined]
 [7] #7 at ./REPL[7]:1 [inlined]
 [8] (::typeof((#7)))(::Float64) at /Users/saxen/.julia/packages/Zygote/lRotY/src/compiler/interface2.jl:0
 [9] top-level scope at REPL[8]:1

The culprit appears to be these lines:

Δ′ = _zero(xs)
Δ′[i...] = Δ
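
To make the failure mode concrete, here is a minimal sketch in plain Julia (no Zygote) of what those two lines do when the input is real and the sensitivity is complex. It uses `zero(xs)` as a stand-in for Zygote's internal `_zero(xs)`, which is an assumption about that helper's behavior. The `widened` array is only an illustration of one possible fix, not necessarily what Zygote ended up doing.

```julia
xs = randn(2)            # real input, as in the report
Δ = 0.0 - 1.0im          # complex sensitivity arriving at getindex's pullback

# Same pattern as the quoted lines: a zero array with the input's element type.
# (zero(xs) is a stand-in for Zygote's internal _zero(xs).)
Δ′ = zero(xs)
failed = try
    Δ′[1] = Δ            # setindex! has to convert a Complex to Float64
    false
catch err
    err isa InexactError
end

# Hypothetical alternative: widen the element type before writing.
T = promote_type(eltype(xs), typeof(Δ))
widened = zeros(T, size(xs))
widened[1] = Δ
```

Here `failed` ends up `true`, while `widened` holds the complex sensitivity in its first slot and zero elsewhere.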

@MikeInnes
Member

There's no great reason to use setindex! here except that there's no convenient out-of-place setindex. This would be a good reason to implement that function.

@sethaxen
Contributor Author

Do you mean implementing it for arrays in Julia Base, or having our own _setindex for that purpose?

@MikeInnes
Member

Ideally, it would be in Base, but that will take too long. It'd be fine to have something in Zygote and switch over to Base if it ever gets added.

@sethaxen
Contributor Author

Okay, I guess a simple Zygote-friendly setindex default would be something like this:

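function Base.setindex(A::AbstractArray, X, inds...)
    # Widen the element type so that, e.g., a complex X fits into a real A.
    # promote_eltypeof is internal to Base and not exported, so qualify it.
    T = Base.promote_eltypeof(A, X)
    # Buffer permits mutation inside code that Zygote differentiates
    A′ = Zygote.Buffer(A, T)
    copyto!(A′, A)
    setindex!(A′, X, inds...)
    # copy turns the Buffer back into an ordinary array
    return copy(A′)
end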

This would certainly resolve the bug raised here, but I wonder whether it introduces new issues. For example, if inds... covers most of A, most elements end up being copied twice. Are there other obvious issues?

@MikeInnes
Member

That looks like the right direction. To avoid the double copy, you could just loop over eachindex. If it's easier, it would also be fine to write an explicit adjoint for setindex rather than using Buffer.
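
As a rough illustration of the single-pass idea, here is a sketch of an out-of-place setindex in plain Julia. The name `setindex_sketch` is hypothetical. It handles scalar integer indices only, and it mutates a freshly allocated array, so Zygote could not differentiate through it without `Buffer` or an explicit adjoint, as suggested above.

```julia
# Hypothetical helper: returns a copy of A with A[inds...] replaced by x,
# widening the element type when x does not fit into eltype(A).
function setindex_sketch(A::AbstractArray, x, inds::Int...)
    T = promote_type(eltype(A), typeof(x))
    out = similar(A, T)
    target = LinearIndices(A)[inds...]   # linear position of the updated slot
    for k in LinearIndices(A)
        # each element is written exactly once: either the new value or the old one
        out[k] = k == target ? x : A[k]
    end
    return out
end

A = [1.0, 2.0]
B = setindex_sketch(A, 0.0 - 1.0im, 1)   # A is left untouched
```

Each element is written once, and a real `A` with a complex `x` yields a complex result rather than an InexactError.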

@sethaxen
Contributor Author

sethaxen commented May 11, 2020

This seems to have been fixed at some point:

julia> using Zygote

julia> x = randn(2)
2-element Array{Float64,1}:
 -1.4516208598004956
 -0.3941745254755234

julia> y, back = Zygote._pullback(x->x[1]*im, x)
(-0.0 - 1.4516208598004956im, (#10))

julia> back(1.0)
(nothing, Complex{Float64}[0.0 - 1.0im, 0.0 + 0.0im])
