Incorrect Wirtinger for prod operation #744

Closed
ghost opened this issue Jul 27, 2020 · 2 comments
Labels
ChainRules adjoint -> rrule, and further integration

ghost commented Jul 27, 2020

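The `wirtinger` function below is taken from the Zygote documentation on complex differentiation (Zygote/complex). For the same input, `prod` gives a different result from the equivalent `x[1] * x[2]` and from a hand-written loop.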

using Zygote

function jacobi(f, x)
    y, back = Zygote.pullback(f, x)
    back(1)[1], back(im)[1]
end

function wirtinger(f, x)
    du, dv = jacobi(f, x)
    (du' + im*dv')/2, (du + im*dv)/2
end

println(wirtinger(x -> prod(x), [1+2im, 3+4im]))
# (Complex{Float64}[3.0 - 4.0im 1.0 - 2.0im], Complex{Float64}[0.0 + 0.0im, 0.0 + 0.0im])

println(wirtinger(x -> x[1] * x[2], [1+2im, 3+4im]))
# (Complex{Float64}[3.0 + 4.0im 1.0 + 2.0im], Complex{Float64}[0.0 + 0.0im, 0.0 + 0.0im])


function myprod(x)
   res = one(eltype(x))
   for v in x
       res *= v
   end
   return res
end

println(wirtinger(x -> myprod(x), [1+2im, 3+4im]))
# (Complex{Float64}[3.0 + 4.0im 1.0 + 2.0im], Complex{Float64}[0.0 + 0.0im, 0.0 + 0.0im])
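
For reference, here is a short sketch of which of the three outputs is the expected one. It assumes the first component returned by `wirtinger` is meant to be $\partial f/\partial x$ and the second $\partial f/\partial \bar{x}$; the symbols $x_1, x_2$ are shorthand introduced here for the two entries of the input.

```latex
f(x) = x_1 x_2, \qquad x = (1 + 2i,\; 3 + 4i)

\frac{\partial f}{\partial x} = (x_2,\; x_1) = (3 + 4i,\; 1 + 2i),
\qquad
\frac{\partial f}{\partial \bar{x}} = (0,\; 0)
```

Under that reading, the `x[1] * x[2]` and `myprod` outputs match the analytic derivative, while the `prod` output is its complex conjugate.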
mcabbott (Member) commented

It's entirely possible that the gradient for prod needs to conjugate somewhere:

https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L253
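
A rough sketch of why a conjugate would be expected, assuming the convention that the pullback of a holomorphic function $f$ maps a cotangent $\Delta$ to $\overline{f'(x)}\,\Delta$ (the symbols $p$, $x_i$, $\bar{x}_i$ are notation introduced here, and every $x_i$ is assumed nonzero):

```latex
p = \prod_j x_j
\quad\Longrightarrow\quad
\frac{\partial p}{\partial x_i} = \prod_{j \neq i} x_j = \frac{p}{x_i}

\bar{x}_i = \overline{\left(\frac{\partial p}{\partial x_i}\right)}\,\Delta
          = \overline{\left(\frac{p}{x_i}\right)}\,\Delta
```

For real inputs the conjugate is a no-op, which would explain why a rule without it goes unnoticed until complex arrays are used.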

sethaxen (Contributor) commented

> It's entirely possible that the gradient for prod needs to conjugate somewhere:
>
> https://github.com/FluxML/Zygote.jl/blob/master/src/lib/array.jl#L253

Yeah, that's probably it. Both p and xs need to be conjugated, i.e., you need:

@adjoint function prod(xs::AbstractArray; dims = :)
  p = prod(xs; dims = dims)
  p, Δ -> (conj.(p ./ xs) .* Δ,)
end

But you probably also want to restrict the type of xs to xs::AbstractArray{<:Union{Real,Complex}}, because that rule assumes multiplication is commutative, an assumption that fails for xs::AbstractArray{<:AbstractMatrix} or xs::AbstractArray{<:Quaternion}.
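
As a quick sanity check of the rule above, here is a minimal sketch that does not use Zygote at all. The helper names `prod_pullback_conj`, `prod_pullback_noconj`, and `reference_pullback` are hypothetical and exist only for this illustration; the reference is the product rule for `x[1] * x[2]` written by hand, conjugated under the assumption that a holomorphic pullback multiplies Δ by the conjugated derivative.

```julia
# Hypothetical standalone helpers for illustration; these are not Zygote functions.

# Pullback proposed above: conjugate p ./ xs before scaling by Δ.
prod_pullback_conj(xs, Δ) = conj.(prod(xs) ./ xs) .* Δ

# The same expression without the conjugate, for comparison.
prod_pullback_noconj(xs, Δ) = (prod(xs) ./ xs) .* Δ

# Hand-written reference for f(x) = x[1] * x[2]:
# ∂f/∂x[1] = x[2] and ∂f/∂x[2] = x[1], conjugated per the assumed convention.
reference_pullback(xs, Δ) = [conj(xs[2]) * Δ, conj(xs[1]) * Δ]

xs = [1 + 2im, 3 + 4im]

# Check both seeds used by `jacobi` in the original report.
for Δ in (1, im)
    @assert prod_pullback_conj(xs, Δ) ≈ reference_pullback(xs, Δ)
    @assert !(prod_pullback_noconj(xs, Δ) ≈ reference_pullback(xs, Δ))
end
```

The check only covers scalar complex entries; it says nothing about matrix- or quaternion-valued elements, where the commutativity caveat applies.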

If someone puts together a PR, I'm happy to review.

mcabbott added the "ChainRules adjoint -> rrule, and further integration" label Jul 22, 2022