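The `wirtinger` function below is taken from the Zygote complex-differentiation docs (Zygote/complex). For a holomorphic function such as a product, the first returned value should be the ordinary derivative and the second should be zero. `prod` returns the conjugate of the correct derivative, while the equivalent `x[1] * x[2]` and a hand-written loop (`myprod`) give the right answer: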
```julia
using Zygote

# Pullbacks of f at x, seeded with 1 and with im.
function jacobi(f, x)
    y, back = Zygote.pullback(f, x)
    back(1)[1], back(im)[1]
end

# Returns (∂f/∂z, ∂f/∂z̄) assembled from the two pullbacks.
function wirtinger(f, x)
    du, dv = jacobi(f, x)
    (du' + im*dv')/2, (du + im*dv)/2
end

println(wirtinger(x -> prod(x), [1+2im, 3+4im]))
# (Complex{Float64}[3.0 - 4.0im 1.0 - 2.0im], Complex{Float64}[0.0 + 0.0im, 0.0 + 0.0im])  <- conjugated, wrong

println(wirtinger(x -> x[1] * x[2], [1+2im, 3+4im]))
# (Complex{Float64}[3.0 + 4.0im 1.0 + 2.0im], Complex{Float64}[0.0 + 0.0im, 0.0 + 0.0im])

# Hand-written product, for comparison.
function myprod(x)
    res = one(eltype(x))
    for v in x
        res *= v
    end
    return res
end

println(wirtinger(x -> myprod(x), [1+2im, 3+4im]))
# (Complex{Float64}[3.0 + 4.0im 1.0 + 2.0im], Complex{Float64}[0.0 + 0.0im, 0.0 + 0.0im])
```
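As a sanity check on which output is correct: the product $p = \prod_j z_j$ is holomorphic in each $z_k$, so its Wirtinger derivatives follow from the ordinary product rule (the $p / z_k$ form assumes $z_k \neq 0$):

$$
\frac{\partial p}{\partial z_k} = \prod_{j \neq k} z_j = \frac{p}{z_k},
\qquad
\frac{\partial p}{\partial \bar z_k} = 0 .
$$

For $z = (1+2i,\ 3+4i)$ this gives $\partial p/\partial z = (3+4i,\ 1+2i)$ and $\partial p/\partial \bar z = (0,\ 0)$, which matches the `x[1] * x[2]` and `myprod` outputs above; the `prod` output is its complex conjugate.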
Yeah, that's probably it. Both `p` and `xs` need to be conjugated, i.e. you need:
```julia
using Zygote: @adjoint

# Adjoint for `prod` with both `p` and `xs` conjugated.
@adjoint function prod(xs::AbstractArray; dims = :)
    p = prod(xs; dims = dims)
    p, Δ -> (conj.(p ./ xs) .* Δ,)
end
```
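To check the proposed rule without touching Zygote, here is a minimal sketch that reproduces the `wirtinger` construction from the issue but feeds it a hand-written `prod` pullback. `prod_pullback`, `wirtinger_prod`, and the `fixed` keyword are hypothetical names for this illustration only; `fixed = false` mimics the unconjugated rule, `fixed = true` applies the conjugation suggested above. It assumes a vector input and ignores `dims`.

```julia
# Hypothetical helper (not Zygote's API): the `prod` pullback in isolation.
# `fixed = true` conjugates as proposed; `fixed = false` mimics the old rule.
function prod_pullback(xs::AbstractArray{<:Union{Real,Complex}}; fixed::Bool = true)
    p = prod(xs)
    back(Δ) = fixed ? conj.(p ./ xs) .* Δ : p ./ xs .* Δ
    return p, back
end

# Same construction as `jacobi`/`wirtinger` in the issue, driven by the pullback above.
function wirtinger_prod(xs; fixed::Bool = true)
    _, back = prod_pullback(xs; fixed = fixed)
    du, dv = back(1), back(im)
    return (du' + im * dv') / 2, (du + im * dv) / 2
end

xs = [1 + 2im, 3 + 4im]
∂z_old, ∂zbar_old = wirtinger_prod(xs; fixed = false)  # conjugated, as reported in the issue
∂z_new, ∂zbar_new = wirtinger_prod(xs; fixed = true)   # agrees with x -> x[1] * x[2] and myprod
println(∂z_old)
println(∂z_new)
```

Both versions give a zero $\partial/\partial \bar z$ part; only the conjugation of the $\partial/\partial z$ part differs, which is exactly the symptom in the issue.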
But you probably also want to restrict the type of `xs` to `xs::AbstractArray{<:Union{Real,Complex}}`, because that rule assumes multiplication is commutative, which fails for `xs::AbstractArray{<:AbstractMatrix}` or `xs::AbstractArray{<:Quaternion}`.
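A small illustration of why commutativity matters, assuming two arbitrary non-commuting 2×2 matrices: the rule relies on "product of the other factors = `p / x_k`", and for matrices only division on the correct side recovers the other factor. `CommutingScalar` is a hypothetical alias for the suggested element-type restriction, not something defined by Zygote.

```julia
using LinearAlgebra

# Two non-commuting 2x2 matrices (arbitrary illustrative values).
A = [1.0 2.0; 3.0 4.0]
B = [0.0 1.0; 1.0 0.0]
p = prod([A, B])     # same as A * B for matrix-valued elements

# For scalars p / x1 == x2; for matrices the side matters.
@assert A \ p ≈ B        # left division recovers B
@assert !(p / A ≈ B)     # right division (the scalar-style formula) does not

# Hypothetical alias for the element types where the `prod` rule is valid.
const CommutingScalar = Union{Real,Complex}
@assert [1 + 2im, 3 + 4im] isa AbstractArray{<:CommutingScalar}
@assert !([A, B] isa AbstractArray{<:CommutingScalar})
```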
If someone puts together a PR, I'm happy to review.