Autodiff for A -> A * A' does not give hermitian result for complex A #1456

Open
simsurace opened this issue May 20, 2024 · 4 comments

@simsurace
Contributor

simsurace commented May 20, 2024

This was found while writing tests for #1307, where the function below is composed with cholesky:

using Enzyme, EnzymeTestUtils, LinearAlgebra

square(A) = A * adjoint(A)

A = rand(ComplexF64, 5, 5)
ishermitian(square(A)) # true

dA = rand(ComplexF64, 5, 5)
S, dS = autodiff(Forward, square, Duplicated, Duplicated(A, dA))
S ≈ square(A) # true
ishermitian(S) # false

test_forward(square, Duplicated, (A, Duplicated)) # passes

So forward mode does not produce exactly the same result as calling square directly (the result is approximately equal but not exactly Hermitian), yet this is not caught by EnzymeTestUtils.

Something similar happens in reverse mode:

square(A) = A * adjoint(A)
square!(S, A) = (mul!(S, A, adjoint(A)); return nothing)

A = rand(ComplexF64, 5, 5)
dA = zeros(ComplexF64, 5, 5)
S = zeros(ComplexF64, 5, 5)
dS = ones(ComplexF64, 5, 5)

autodiff(Reverse, square!, Const, Duplicated(S, dS), Duplicated(A, dA))
S ≈ square(A) # true
ishermitian(S) # false

test_reverse(square!, Const, (S, Duplicated), (A, Duplicated)) # passes
simsurace changed the title from "Forward mode for A -> A * A' does not give hermitian result for complex A" to "Autodiff for A -> A * A' does not give hermitian result for complex A" on May 20, 2024
simsurace added a commit to simsurace/Enzyme.jl that referenced this issue May 20, 2024
@wsmoses
Member

wsmoses commented May 21, 2024

Looks like it's Hermitian up to floating-point precision.

julia> square(A)
5×5 Matrix{ComplexF64}:
 2.70575+0.0im        2.08828-0.483085im  2.06124+0.320394im   1.98721+0.052077im   2.05039+0.0833814im
 2.08828+0.483085im   2.44817+0.0im       2.11651+1.24836im     2.0555+0.458592im   2.04032+0.926832im
 2.06124-0.320394im   2.11651-1.24836im   3.83884+0.0im        2.57611-0.0681123im  2.66486-0.319575im
 1.98721-0.052077im    2.0555-0.458592im  2.57611+0.0681123im  2.99428+0.0im        2.49473+0.129998im
 2.05039-0.0833814im  2.04032-0.926832im  2.66486+0.319575im   2.49473-0.129998im   3.13973+0.0im

julia> S
5×5 Matrix{ComplexF64}:
 2.70575+3.6097e-17im  2.08828-0.483085im     2.06124+0.320394im     1.98721+0.052077im    2.05039+0.0833814im
 2.08828+0.483085im    2.44817-6.02588e-18im  2.11651+1.24836im       2.0555+0.458592im    2.04032+0.926832im
 2.06124-0.320394im    2.11651-1.24836im      3.83884+4.38384e-18im  2.57611-0.0681123im   2.66486-0.319575im
 1.98721-0.052077im     2.0555-0.458592im     2.57611+0.0681123im    2.99428+4.6523e-17im  2.49473+0.129998im
 2.05039-0.0833814im   2.04032-0.926832im     2.66486+0.319575im     2.49473-0.129998im    3.13973+1.95764e-17im

julia> square(A)-S
5×5 Matrix{ComplexF64}:
          0.0-3.6097e-17im            0.0+0.0im                   0.0-5.55112e-17im   2.22045e-16-3.46945e-17im  -4.44089e-16+1.249e-16im
          0.0+0.0im                   0.0+6.02588e-18im           0.0+0.0im          -4.44089e-16-5.55112e-17im   4.44089e-16+0.0im
          0.0+5.55112e-17im           0.0+0.0im           4.44089e-16-4.38384e-18im           0.0-1.11022e-16im  -4.44089e-16-5.55112e-17im
  2.22045e-16+3.46945e-17im  -4.44089e-16+5.55112e-17im           0.0+1.11022e-16im   4.44089e-16-4.6523e-17im            0.0-2.77556e-17im
 -4.44089e-16-1.249e-16im     4.44089e-16+0.0im          -4.44089e-16+5.55112e-17im           0.0+2.77556e-17im           0.0-1.95764e-17im

Per the warning

┌ Warning: Using fallback BLAS replacements for (["zgemm_64_", "zherk_64_"]), performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/kqxyC/src/utils.jl:59

it is known that the zgemm (and zherk) calls are replaced by fallback implementations, which here apparently produce numerical differences within a reasonable tolerance.
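To illustrate how two mathematically equivalent implementations can disagree in the last bits, here is a minimal sketch in plain Julia with no Enzyme involved. It assumes, without confirmation from this thread, that the fallback differs from the BLAS routine in summation order or in whether fused multiply-adds are used; the numbers are arbitrary.

```julia
# Floating-point addition is not associative: changing the summation order,
# as a different matmul implementation may do, changes the last bits.
x = (0.1 + 0.2) + 0.3
y = 0.1 + (0.2 + 0.3)
println(x == y)        # the two orders give different results
println(abs(x - y))    # difference on the order of eps(0.6)

# In exact arithmetic z * conj(z) is real. For z = a + b*im, its imaginary part
# is a*(-b) + b*a. Computed with two separately rounded products it is exactly
# zero, but with a fused multiply-add (assumed here as one possible mechanism)
# the rounding error of b*a is left behind as a tiny nonzero value, similar in
# kind to the imaginary residue on the diagonal of S above.
a, b = 0.1, 0.3
plain = a * (-b) + b * a     # both products round identically, so they cancel
fused = fma(a, -b, b * a)    # exact a*(-b) plus rounded b*a: nonzero residue
println(plain, " ", fused)
```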

@wsmoses
Member

wsmoses commented May 21, 2024

I actually think the behavior here is reasonable, and the test utils properly only check within a tolerance.
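The difference between an exact check and a tolerance-based one can be seen without Enzyme. The sketch below builds a Hermitian matrix, perturbs one entry by a few ulps to mimic the kind of discrepancy shown above (the matrix and the perturbation size are arbitrary choices, not taken from this thread), and compares `ishermitian`, which tests exact equality, with `isapprox` (`≈`), which uses a relative tolerance. It illustrates the idea only and is not EnzymeTestUtils' own comparison code.

```julia
using LinearAlgebra

# A fixed, nonsingular lower-triangular matrix, so M * M' is Hermitian positive definite.
M = ComplexF64[1.0 0 0; 0.5im 2.0 0; 0.3 -0.2im 1.5]
H = M * M'

# Mimic a tiny floating-point discrepancy in one off-diagonal entry.
S = copy(H)
S[1, 2] += 4 * eps(abs(S[1, 2]))

exact = ishermitian(S)          # exact comparison of S with S'
approx = S ≈ S'                 # relative-tolerance comparison
maxdev = maximum(abs, S - S')   # size of the deviation from Hermitian symmetry

println((exact, approx, maxdev))
```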

@simsurace is there a reason why this is problematic?

@simsurace
Contributor Author

Ah, makes sense. What would be needed to avoid relying on these fallbacks? At least for forward mode, why not just call the same function that is being passed to compute the primal? I understand that there must be a fallback from which to generate LLVM IR in order to compute the derivative, but could the primal still use the non-fallback implementation, or would that lead to problems?

Of course one can work around this, but composing this function with something that relies on the intermediate result being exactly Hermitian (such as cholesky) currently fails.
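A minimal sketch of two possible workarounds, assuming downstream code only needs a Hermitian matrix (the thread does not say which workaround, if any, was adopted): declare the structure explicitly with LinearAlgebra's `Hermitian` wrapper, which reads only one triangle, or symmetrize as `(S + S') / 2` before calling `cholesky`. The nearly Hermitian matrix here is constructed artificially as a stand-in for the autodiff result.

```julia
using LinearAlgebra

# Hypothetical stand-in for a nearly Hermitian primal result such as S above.
M = ComplexF64[1.0 0 0; 0.5im 2.0 0; 0.3 -0.2im 1.5]
S = M * M'
S[1, 2] += 4 * eps(abs(S[1, 2]))   # break exact Hermitian symmetry

# cholesky on a plain Matrix first checks ishermitian, which is exact, so the
# factorization is reported as unsuccessful (check = false avoids a throw).
C_plain = cholesky(S; check = false)

# Option 1: wrap in Hermitian so only the upper triangle is used.
C_wrapped = cholesky(Hermitian(S))

# Option 2: symmetrize explicitly; the result is exactly Hermitian.
S_sym = (S + S') / 2
C_sym = cholesky(S_sym)

println(issuccess(C_plain), " ", issuccess(C_wrapped), " ", ishermitian(S_sym))
```

The wrapper avoids an extra allocation, while symmetrizing keeps a plain `Matrix` for code that does not accept a `Hermitian` argument.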

@wsmoses
Member

wsmoses commented Jun 5, 2024

So now, in forward mode, real types no longer use the fallback; complex types don't have this yet.
