-
Notifications
You must be signed in to change notification settings - Fork 4
/
eigh_impl.py
116 lines (96 loc) · 3.74 KB
/
eigh_impl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Imported from https://gist.github.com/jackd/99e012090a56637b8dd8bb037374900e
"""Versions based on 4.60 and 4.63 of https://arxiv.org/pdf/1701.00392.pdf."""
import jax
import jax.numpy as jnp
import numpy as np
def _T(x):
return jnp.swapaxes(x, -1, -2)
def _H(x):
return jnp.conj(_T(x))
def symmetrize(x):
    """Return the self-adjoint part (x + x.H) / 2, taken over the last two axes."""
    adjoint = jnp.conj(jnp.swapaxes(x, -1, -2))
    return (x + adjoint) / 2
def standardize_angle(w, b):
    """
    Fix the arbitrary sign/phase of each eigenvector column of `w`.

    Real `w`: each column is multiplied by the sign of its first entry, so that
    w[0, :] >= 0. Columns whose first entry is exactly zero are left unchanged
    rather than being multiplied by sign(0) == 0.

    Complex `w`: each column is first divided by the phase of the corresponding
    entry of b[0] @ w, which makes b[0] @ w real and non-negative. Then it is
    multiplied by the sign of the real part of its first entry. A column for which
    b[0] @ w == 0 (or whose first entry has zero real part) skips the respective
    step instead of producing NaN or an all-zero column.

    NOTE(review): the original comment said this mirrors scipy's convention --
    confirm before relying on exact agreement with scipy output.

    Args:
        w: [n, n] matrix whose columns are eigenvectors.
        b: [n, n] real or complex matrix; only its first row is used, and only
            when `w` is complex.

    Returns:
        `w` with each column multiplied by a unit-modulus factor.
    """

    def _nonzero_sign(x):
        # Like jnp.sign, but maps 0 -> 1 so no column is ever zeroed out.
        s = jnp.sign(x)
        return jnp.where(s == 0, 1, s)

    if jnp.isrealobj(w):
        return w * _nonzero_sign(w[0, :])

    bw = b[0] @ w
    abs_bw = jnp.abs(bw)
    nonzero = abs_bw > 0
    # Double `where`: the inner one keeps the unselected branch free of 0/0.
    phase = jnp.where(nonzero, bw / jnp.where(nonzero, abs_bw, 1), 1)
    w = w / phase[None, :]
    return w * _nonzero_sign(w.real[0])
@jax.custom_jvp  # jax.scipy.linalg.eigh doesn't support the general problem, i.e. b not None
def eigh2d(a, b):
    """
    Compute the solution to the symmetrized generalized eigenvalue problem.

        a_s @ w = b_s @ w @ np.diag(v)

    where a_s = (a + a.H) / 2, b_s = (b + b.H) / 2 are the symmetrized versions of the
    inputs and H is the Hermitian (conjugate transpose) operator.

    For self-adjoint inputs the solution should be consistent with `scipy.linalg.eigh`,
    i.e.

    ```python
    v, w = eigh2d(a, b)
    v_sp, w_sp = scipy.linalg.eigh(a, b)
    np.testing.assert_allclose(v, v_sp)
    np.testing.assert_allclose(w, standardize_angle(w_sp, b))
    ```

    Note this currently forms b_s^{-1} @ a_s with a Cholesky solve and passes it to
    `jnp.linalg.eig`, jitted for the CPU backend. This will be slow, and it is a
    generally inefficient way of solving the problem; future implementations should
    wrap cuda primitives. This implementation is provided primarily as a means to
    test `eigh_jvp_rule`.

    Args:
        a: [n, n] real or complex matrix; only its self-adjoint part a_s is used.
        b: [n, n] real or complex matrix; its self-adjoint part b_s must be positive
            definite, as required by the Cholesky factorization.

    Returns:
        v: real eigenvalues of the generalized problem in ascending order.
        w: eigenvectors of the generalized problem as columns (column j pairs with
            v[j]), each normalized so that w[:, j].H @ b_s @ w[:, j] == 1. For
            distinct eigenvalues this gives w.H @ b_s @ w == I. Real when both
            inputs are real. Phases/signs are fixed by `standardize_angle`.
    """
    a = symmetrize(a)
    b = symmetrize(b)
    # Reduce to the standard eigenproblem b^{-1} a using the Cholesky factor of b.
    b_inv_a = jax.scipy.linalg.cho_solve(jax.scipy.linalg.cho_factor(b), a)
    # `eig` has historically had no GPU implementation, hence the explicit CPU backend.
    # NOTE(review): `backend=` on jax.jit is deprecated in recent JAX releases --
    # confirm against the JAX version this project pins.
    v, w = jax.jit(jax.numpy.linalg.eig, backend="cpu")(b_inv_a)
    # `eig` returns complex output; the eigenvalues of this problem are real.
    v = v.real
    if jnp.isrealobj(a) and jnp.isrealobj(b):
        w = w.real
    # Reorder the eigenpairs so that v is ascending.
    order = jnp.argsort(v)
    v = v.take(order, axis=0)
    w = w.take(order, axis=1)
    # Renormalize each column so that w[:, j].H @ b @ w[:, j] == 1.
    norm2 = jax.vmap(lambda wi: (wi.conj() @ b @ wi).real, in_axes=1)(w)
    w = w / jnp.sqrt(norm2)
    w = standardize_angle(w, b)
    return v, w
@eigh2d.defjvp
def eigh_jvp_rule(primals, tangents):
    """
    Forward-mode (JVP) rule for `eigh2d`, real inputs only.

    Derivation based on Boeddeker et al., https://arxiv.org/pdf/1701.00392.pdf.
    Write dw = w @ C. Differentiating a_s @ w = b_s @ w @ diag(v) together with the
    normalization w.H @ b_s @ w = I gives

        dv_i = w_i.H @ (da - v_i * db) @ w_i
        C_ij = (w.H @ (da @ w - db @ w @ diag(v)))_ij / (v_j - v_i),   i != j
        C_ii = -0.5 * w_i.H @ db @ w_i

    Note the diagonal entries of C = w^{-1} dw/dt are generally non-zero (they come
    from the normalization constraint). The original author flagged this as
    differing from the paper; NOTE(review): confirm against the paper.

    Repeated eigenvalues make v_j - v_i == 0 for some i != j, so those entries of
    dw are inf/NaN.

    Args:
        primals: (a, b), as passed to `eigh2d`.
        tangents: (da, db), tangents of a and b; only their symmetrized parts are used.

    Returns:
        ((v, w), (dv, dw)): the primal outputs of `eigh2d` and their tangents.

    Raises:
        NotImplementedError: if any primal or tangent is complex.
    """
    a, b = primals
    da, db = tangents
    if not all(jnp.isrealobj(x) for x in (a, b, da, db)):
        raise NotImplementedError("jvp only implemented for real inputs.")
    da = symmetrize(da)
    db = symmetrize(db)
    v, w = eigh2d(a, b)

    # Eigenvalue tangents: only the diagonal of w.H @ (da - db v) @ w is needed.
    dv = jax.vmap(
        lambda vi, wi: -wi.conj() @ db @ wi * vi + wi.conj() @ da @ wi,
        in_axes=(0, 1),
    )(v, w)
    dv = dv.real

    # `bool`, not `np.bool`: the latter alias was removed in NumPy 1.24.
    eye = jnp.eye(a.shape[0], dtype=bool)
    # E_ij = v_j - v_i. Its diagonal is zero but unused; replace it with 1 so the
    # division below stays finite there. Otherwise inf/NaN inside the unselected
    # `where` branch would leak into higher-order derivatives.
    E = v[jnp.newaxis, :] - v[:, jnp.newaxis]
    E = jnp.where(eye, 1, E)
    # Diagonal of C: computed as a vector, then placed on the diagonal.
    diags = jnp.diag(-0.5 * jax.vmap(lambda wi: wi.conj() @ db @ wi, in_axes=1)(w))
    # Off-diagonal of C; `db @ w * v[None, :]` is db @ w @ diag(v).
    off_diags = (_H(w) @ (da @ w - db @ w * v[jnp.newaxis, :])) / E
    dw = w @ jnp.where(eye, diags, off_diags)
    return (v, w), (dv, dw)