Incorrect sign for imaginary part of derivative of functions returning complex numbers #1308

Closed
giordano opened this issue Feb 25, 2024 · 0 comments · Fixed by #1309

giordano (Contributor) commented Feb 25, 2024

julia> using Enzyme

julia> ∂(f, z) = first(first(autodiff(Reverse, f, Active, Active(z))))
∂ (generic function with 1 method)

julia> v = rand(ComplexF64, 3);

julia> ∂.(exp, v)
3-element Vector{ComplexF64}:
 2.4742094004406776 - 0.6831994346902432im
 1.1529144463383987 - 0.31931177102730973im
 1.0208330780674268 - 1.295298037192365im

julia> exp.(v)
3-element Vector{ComplexF64}:
 2.4742094004406776 + 0.6831994346902432im
 1.1529144463383987 + 0.31931177102730973im
 1.0208330780674268 + 1.295298037192365im

julia> conj.(∂.(exp, v)) ≈ exp.(v)
true

Note that the sign of the imaginary part is always wrong: Enzyme basically always returns the complex conjugate of the correct result. The same happens with more complicated functions:

julia> f(z) = z ^ z
f (generic function with 1 method)

julia> f′(z) = f(z) * (log(z) + one(z))
f′ (generic function with 1 method)

julia> v = rand(ComplexF64, 1000);

julia> ∂.(f, v) ≈ f′.(v)
false

julia> conj.(∂.(f, v)) ≈ f′.(v)
true
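For reference, the correct derivatives can be cross-checked without Enzyme. The sketch below assumes only Base Julia: for a holomorphic function g, a central finite difference along the real axis, (g(z + h) - g(z - h)) / 2h, approximates g′(z), so it should agree with exp(z) and with the hand-derived f′(z), and not with their conjugates. The helper name `fd_derivative`, the step `h = 1e-6` and the sample points are illustrative choices, not part of Enzyme.

```julia
# Hypothetical helper: central finite difference along the real axis.
# For a holomorphic g the complex derivative is the same in every direction,
# so perturbing only the real part is enough to approximate g′(z).
fd_derivative(g, z::Complex; h = 1e-6) = (g(z + h) - g(z - h)) / (2h)

f(z) = z ^ z
f′(z) = f(z) * (log(z) + one(z))   # d/dz z^z = z^z (log z + 1)

# Points with a nonzero imaginary part, away from the branch cut of log
zs = [0.7 + 0.4im, 0.3 + 0.9im, 1.2 + 0.5im]

for z in zs
    # The correct derivatives match the finite differences...
    @assert isapprox(fd_derivative(exp, z), exp(z); rtol = 1e-6)
    @assert isapprox(fd_derivative(f, z), f′(z); rtol = 1e-6)
    # ...while their conjugates (what Enzyme returns above) do not.
    @assert !isapprox(fd_derivative(f, z), conj(f′(z)); rtol = 1e-6)
end
println("finite differences agree with exp and f′, not with their conjugates")
```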

BTW, I'm not sure #547 is still an issue: I could differentiate powers of complex numbers (except that the result is wrong because of the conjugation).

For the record:

julia> versioninfo()
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 8 × Apple M1
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
  Threads: 1 on 4 virtual cores

(jl_30Oveo) pkg> st -m Enzyme Enzyme_jll
Status `/private/var/folders/v2/hmy3kzgj4tb3xsy8qkltxd0r0000gn/T/jl_30Oveo/Manifest.toml`
  [7da242da] Enzyme v0.11.17
  [7cc45869] Enzyme_jll v0.0.102+0

Slightly related (although the underlying issue is likely different): when the active argument is a Complex{T<:Integer}, the derivative is always zero(Complex{T}):

julia> ∂(f, im)
Complex(false,false)

julia> ∂(f, complex(4))
0 + 0im

julia> ∂(exp, im)
Complex(false,false)

julia> ∂(exp, complex(2, 2))
0 + 0im

Results improve, but are still conjugated, when the numbers are converted to floating point:

julia> ∂(f, float(im))
0.20787957635076193 - 0.3265364749474561im

julia> ∂(f, complex(4.0))
256.0 + 0.0im

julia> ∂(exp, float(im))
0.5403023058681398 - 0.8414709848078965im

julia> ∂(exp, complex(2.0, 2.0))
-3.074932320639359 - 6.71884969742825im
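The conjugation can be confirmed directly from the numbers printed above. The sketch below copies three of the reported outputs verbatim (for `∂(exp, float(im))`, `∂(exp, complex(2.0, 2.0))` and `∂(f, float(im))`) and compares them with the analytic derivatives exp′(z) = exp(z) and f′(z) = z^z (log z + 1), evaluated with Base Julia only; the tuple layout and variable names are illustrative.

```julia
f(z) = z ^ z
f′(z) = f(z) * (log(z) + one(z))   # analytic derivative of z^z

# (analytic derivative, input, value printed by Enzyme above)
reported = [
    (exp, float(im),         0.5403023058681398 - 0.8414709848078965im),
    (exp, complex(2.0, 2.0), -3.074932320639359 - 6.71884969742825im),
    (f′,  float(im),         0.20787957635076193 - 0.3265364749474561im),
]

for (deriv, z, enzyme_value) in reported
    expected = deriv(z)
    @assert !isapprox(enzyme_value, expected)       # not the derivative...
    @assert isapprox(conj(enzyme_value), expected)  # ...but its complex conjugate
    println(z, ": expected ", expected, ", Enzyme returned ", enzyme_value)
end
```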

As far as I can tell from the output of

julia> Enzyme.Compiler.enzyme_code_llvm(exp, Active, Tuple{Active{Complex{Int}}})
;  @ complex.jl within `diffejulia_exp_6619_inner_2wrap`
; Function Attrs: alwaysinline nofree
define [1 x [1 x [2 x i64]]] @diffejulia_exp_6619_inner_2wrap([2 x i64] %0, [2 x double] %1) #2 {
entry:
  %2 = call {}*** inttoptr (i64 7000240380 to {}*** (i64)*)(i64 261) #6
  %ptls_field.i25.i = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field.i25.i to i64***
  %ptls_load.i2627.i = load i64**, i64*** %3, align 8
  %4 = getelementptr inbounds i64*, i64** %ptls_load.i2627.i, i64 2
  %safepoint.i.i = load i64*, i64** %4, align 8
  fence syncscope("singlethread") seq_cst
; ┌ @ complex.jl within `exp` @ complex.jl:692
   %5 = load volatile i64, i64* %safepoint.i.i, align 8
   fence syncscope("singlethread") seq_cst
   ret [1 x [1 x [2 x i64]]] zeroinitializer
; └
}

Enzyme just returns a tuple of zeros, whatever the input.
