Further correcting grad_eigh to support hermitian matrices and the UPLO kwarg properly #527

Merged
6 commits merged into HIPS:master on Jul 29, 2019

Conversation

@momchilmm (Contributor) commented Jul 24, 2019

Edit: as discussed in the comments below, the issue with the complex eigenvectors is the gauge, which is arbitrary. However, this updated code should work for complex-valued matrices and functions that do not depend on the gauge. So for example, the test for the complex case uses np.abs(v).

What this update does:

  • fix the vjp computation for numpy.linalg.eigh so that it matches the behavior of the function, which only ever reads the upper or lower triangle of the matrix (a rough sketch of the triangle handling follows after this list)
  • fix the tests to use random matrices rather than random symmetric matrices, so that the upper/lower handling is actually exercised
  • fix the computation to work for Hermitian matrices as per this pull request, on which I've built
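
For concreteness, here is a rough sketch (illustrative only, not the exact diff in this PR) of what the triangle handling in the vjp has to do: since eigh(a, UPLO='L') only reads the lower triangle of a, the gradient computed for the full Hermitian matrix has to be folded back into that triangle, with each off-triangle contribution added (conjugated) onto its mirror entry and only the real part kept on the diagonal. fold_into_triangle is a hypothetical helper name used just for illustration.

import numpy as np

def fold_into_triangle(g_full, UPLO='L'):
    """Fold a full-matrix gradient into the triangle that eigh actually reads.

    g_full is the gradient one would get if eigh looked at the whole Hermitian
    matrix; this is a sketch of the idea, not the code merged in this PR.
    """
    N = g_full.shape[-1]
    # The diagonal of the input must be real, so only the real part survives.
    diag = np.real(np.diagonal(g_full)) * np.eye(N)
    if UPLO == 'L':
        mask = np.tril(np.ones((N, N)), -1)   # strictly lower triangle
    else:
        mask = np.triu(np.ones((N, N)), 1)    # strictly upper triangle
    # Each off-triangle entry equals the conjugate of its mirror entry, so its
    # gradient is added, conjugated, onto the mirror entry inside the triangle.
    return diag + (g_full + np.conj(g_full.T)) * mask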

However:

  • the gradient for Hermitian matrices works only for the eigenvalues and not (always) for the eigenvectors
  • so I've added a test, but it takes a random complex matrix and checks only the gradient flow through the eigenvalues
  • the problem with the eigenvectors probably has to do with their being complex; this has not been dealt with anywhere I looked (PyTorch, TensorFlow, or the original reference https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf, where the eigenvectors are also assumed to be real in the tests)

The gradient for the eigenvectors does not pass a general test. However, it works in some cases. For example, this code

import autograd.numpy as npa
from autograd import grad

def fn(a):
    # Define an array with some random operations
    mat = npa.array([[(1+1j)*a, 2, a],
                     [1j*a, 2 + npa.abs(a + 1j), 1],
                     [npa.conj(a), npa.exp(a), npa.abs(a)]])
    [eigs, vs] = npa.linalg.eigh(mat)
    return npa.abs(vs[0, 0])

a = 2.1 + 1.1j # Some random test value

# Compute the numerical gradient of fn(a)
grad_num = (fn(a + 1e-5)-fn(a))/1e-5 - 1j*(fn(a + 1e-5*1j)-fn(a))/1e-5

print('Autograd gradient:  ', grad(fn)(a))
print('Numerical gradient: ', grad_num)
print('Difference:         ', npa.linalg.norm(grad(fn)(a)-grad_num))

returns a difference smaller than 1e-6 for any individual component of vs that is put in the return statement. However, it breaks for a more complicated function, e.g. return npa.abs(vs[0, 0] + vs[1, 1]).

It would be great if someone could address this further. Still, for now this PR is a significant improvement in the behavior of the linalg.eigh function.

steinbrecher and others added 2 commits on January 12, 2019:

The version of grad_eigh used previously only supported real symmetric inputs to eigh. Changing v to conj(v) in two places makes this more general, allowing eigh to support arbitrary Hermitian matrices.
@GiggleLiu commented Jul 25, 2019

I think the loss you defined, npa.abs(vs[0, 0] + vs[1, 1]), is not well posed.
Notice that vs[:, i]*exp(1j*theta) with an arbitrary theta is also an eigenvector for eigs[i], so the gradient is ill-defined: this loss depends on theta. npa.abs(vs[0, 0]) works because it cancels the phase (i.e. theta). This is also known as the gauge problem.

The good news is that normally we don't use a loss with a phase dependency, since such a loss is not well defined in the first place. So this is only a question of how to test it; here is an example of such a test (in Julia though):
https://github.com/GiggleLiu/BackwardsLinalg.jl/blob/51869b0255052ad03bf8801088768cd1ff10df59/test/svd.jl#L24
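
To make the gauge freedom concrete, here is a minimal NumPy sketch (not part of the PR; the Hermitian matrix is an arbitrary example): each eigenvector column can be multiplied by a phase and remains a valid eigenvector, so only phase-invariant quantities have a well-defined gradient.

import numpy as np

# An arbitrary Hermitian example matrix.
A = np.array([[2.0,     1 - 1j,  0.5j],
              [1 + 1j,  3.0,     1.0 ],
              [-0.5j,   1.0,     1.5 ]])
w, v = np.linalg.eigh(A)

# Multiply each eigenvector column by an arbitrary phase: still eigenvectors.
theta = np.random.uniform(0, 2 * np.pi, size=3)
v_regauged = v * np.exp(1j * theta)
assert np.allclose(A @ v_regauged, v_regauged * w)

print(np.allclose(np.abs(v), np.abs(v_regauged)))              # True: gauge-invariant
print(np.isclose(np.abs(v[0, 0] + v[1, 1]),
                 np.abs(v_regauged[0, 0] + v_regauged[1, 1]))) # generally False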

@momchilmm (Contributor, Author)

I see, I did realize that the gauge might be a problem, but didn't investigate further. But I think there's something more to it. It doesn't work even with a loss like npa.abs(vs[1, 0] + vs[1, 1]), which should be gauge-independent. More broadly, it fails the autograd check_grads() function which is used in all the tests. There's an extra problem when trying to compare to a numerically computed derivative: the gauge could change between the two function calls in the finite-difference computation, leading to a spurious numerical derivative.

Not sure what more could be done about this; perhaps a warning to use at your own risk for complex-valued eigenvectors? That also reminds me that the gradient breaks when there are degenerate eigenvalues, so perhaps a warning is warranted there too?

@GiggleLiu

It doesn't work even with a loss like npa.abs(vs[1, 0] + vs[1, 1]), which should be gauge-independent.

v[1, 0] is the second element of the first eigenvector and v[1, 1] is the second element of the second eigenvector, so the norm of their sum is gauge dependent? Correct me if I am wrong.

@momchilmm (Contributor, Author)

Oh, dang, I mixed up the indexes, my bad. I meant npa.abs(vs[0, 1] + vs[1, 1]), which actually works! So maybe indeed it works for gauge-independent functions.

However, it's still not clear to me why the check_grads() test fails if I include the eigenvectors. It compares a vector-Jacobian-vector product against a numerically computed one. I'll think about this tomorrow; hopefully it's possible to write a test that passes.

@momchilmm (Contributor, Author)

Ok, I just made the test function return np.abs(v) to get rid of all the phases, and then the tests pass! It looks like this should indeed work when you have gauge-independent functions of the eigenvectors, which is great.
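
A sketch of the kind of test this allows (my reconstruction of the idea, not necessarily the exact test added in the PR): wrap eigh in a function that only exposes gauge-invariant quantities and run autograd's check_grads on a random complex matrix.

import numpy.random as npr
import autograd.numpy as anp
from autograd.test_util import check_grads

def fun(x):
    # eigh only reads one triangle of x, and anp.abs removes the arbitrary
    # phase of each eigenvector column, so this output is gauge-invariant.
    w, v = anp.linalg.eigh(x)
    return anp.abs(v)

D = 6
x = npr.randn(D, D) + 1j * npr.randn(D, D)   # random complex (not symmetrized) input
check_grads(fun, modes=['rev'])(x)           # raises if the vjp disagrees with numerics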

@j-towns (Collaborator) commented Jul 25, 2019

This looks good and I'm happy to merge. However, I'd like to be clear: is there still an issue here that we don't fully understand?

The fact that the function being differentiated needs to be gauge-independent is fine; this is analogous to e.g. the singular value decomposition, which also has multiple valid outputs. We assume that the user is knowledgeable enough to know that their loss function should be invariant to this 🙂.

@momchilmm (Contributor, Author)

No other issues that I'm aware of.

The last thing to be careful about is degenerate eigenvalues. However, this is once again a known problem, and we can hope the user is knowledgeable enough. That said, I just pushed one final improvement. The problem with degenerate eigenvalues is in the backprop of the eigenvector gradient: previously, even if your function depended only on the eigenvalues, degenerate ones would trigger a "division by zero" warning and NaN gradients. I changed it so that if there's no backprop signal in the eigenvectors (anp.sum(anp.abs(vg)) < 1e-10), that contribution to the vjp is not computed, and everything works fine even in a degenerate case.
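
A sketch of the guard described above (again illustrative, not the exact merged code; w and v are the eigh outputs and wg, vg the incoming gradients for them): the eigenvector term involves F[i, j] = 1/(w[j] - w[i]), which blows up for degenerate eigenvalues, so it is only formed when vg actually carries a backprop signal. The UPLO triangle folding sketched earlier would still be applied to the result.

import autograd.numpy as anp

def eigh_vjp_core_sketch(w, v, wg, vg):
    N = w.shape[0]
    vc = anp.conj(v)
    # Eigenvalue contribution: conj(V) diag(wg) V^T.
    vjp = anp.dot(vc * wg[anp.newaxis, :], v.T)
    # Eigenvector contribution, skipped when vg carries no signal
    # (later in the thread this check is simplified to anp.any(vg)).
    if anp.sum(anp.abs(vg)) > 1e-10:
        off_diag = anp.ones((N, N)) - anp.eye(N)
        F = off_diag / (w[anp.newaxis, :] - w[:, anp.newaxis] + anp.eye(N))
        vjp = vjp + anp.dot(anp.dot(vc, F * anp.dot(v.T, vg)), v.T)
    return vjp   # the UPLO triangle folding still has to be applied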

@GiggleLiu commented Jul 25, 2019

I changed it so that if there's no backprop signal in the eigenvectors (anp.sum(anp.abs(vg)) < 1e-10)

I used a similar smearing technique in my code, which is x/(x^2 + 1e-40). Is truncation better in your code, or not? @wangleiphy, maybe 1e-10 is too large.
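
For reference, the smearing GiggleLiu mentions would replace the hard skip with a regularized reciprocal inside F; a short sketch of that alternative (not what this PR does):

import numpy as np

def smeared_reciprocal(x, eps=1e-40):
    # Behaves like 1/x away from zero but stays finite as x -> 0.
    return x / (x**2 + eps)

# e.g. F[i, j] = smeared_reciprocal(w[j] - w[i]) off the diagonal, instead of
# dividing directly (which gives 0/0 for degenerate eigenvalues).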

@momchilmm (Contributor, Author)

Isn't there a danger that a wrong gradient is returned in that case, if the backprop eigenvector gradient is nonzero and there is a degeneracy? That is to say, the problem is not division by vg, but a term of the form F * vg where F has infinite entries in the degenerate case. So I just want to set that term to zero when vg is zero. If instead we apply what you suggest directly to the F term, everything becomes finite but probably wrong for nonzero vg. Or is there something else that can be done?

Otherwise, yeah, I can change the cutoff to 1e-20.

@GiggleLiu

Well, this is a known (maybe intrinsically) hard problem. Reducing the cutoff or making it a tunable parameter is what you can do.

@momchilmm (Contributor, Author)

OK, I just changed the check to simply use npa.any(vg), to be on the safe side.

@j-towns merged commit 7c09501 into HIPS:master on Jul 29, 2019
@j-towns (Collaborator) commented Jul 29, 2019

Thanks!
