Skip to content

Commit

Permalink
Simpler and more elegant implementation of Williamson (#366)
Browse files Browse the repository at this point in the history
* Simpler and more elegant implementation of Williamson

* Removes unnecesary import

* Reinstates necesary import

* Reinstates necesary import

* Updates CHANGELOG

* Attemp failed

* Update thewalrus/decompositions.py

Co-authored-by: Sebastián Duque Mesa <675763+sduquemesa@users.noreply.github.com>

* Even simpler logic

* Even better Williamson

---------

Co-authored-by: Nicolas Quesada <nquesada@pop-os.localdomain>
Co-authored-by: Sebastián Duque Mesa <675763+sduquemesa@users.noreply.github.com>
  • Loading branch information
3 people committed Jul 17, 2023
1 parent 6c4de00 commit 01fb78c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
10 changes: 9 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,25 @@

### New features

* Adds the Takagi decomposition [(#363)](https://github.com/XanaduAI/thewalrus/pull/338)

### Breaking changes

### Improvements

* Simplifies the internal working of Bloch-Messiah decomposition [(#363)](https://github.com/XanaduAI/thewalrus/pull/338).

* Simplifies the internal working of Williamson decomposition [(#366)](https://github.com/XanaduAI/thewalrus/pull/338).

### Bug fixes

### Documentation

### Contributors

This release contains contributions from (in alphabetical order):
This release contains contributions from (in alphabetical order):

Nicolas Quesada

---

Expand Down
47 changes: 21 additions & 26 deletions thewalrus/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,36 +70,31 @@ def williamson(V, rtol=1e-05, atol=1e-08):
omega = sympmat(n)
vals = np.linalg.eigvalsh(V)

for val in vals:
if val <= 0:
raise ValueError("Input matrix is not positive definite")
if not np.all(vals > 0):
raise ValueError("Input matrix is not positive definite")

Mm12 = sqrtm(np.linalg.inv(V)).real
r1 = Mm12 @ omega @ Mm12
s1, K = schur(r1)
X = np.array([[0, 1], [1, 0]])
I = np.identity(2)
seq = []

# In what follows I construct a permutation matrix p so that the Schur matrix has
# In what follows a permutation matrix perm1 is constructed so that the Schur matrix has
# only positive elements above the diagonal
# Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus I permute using perm
# Also the Schur matrix uses the x_1,p_1, ..., x_n,p_n ordering thus a permutation perm2 is used
# to go to the ordering x_1, ..., x_n, p_1, ... , p_n

perm1 = np.arange(2 * n)
for i in range(n):
if s1[2 * i, 2 * i + 1] > 0:
seq.append(I)
else:
seq.append(X)
perm = np.array([2 * i for i in range(n)] + [2 * i + 1 for i in range(n)])
p = block_diag(*seq)
Kt = K @ p
Ktt = Kt[:, perm]
s1t = p @ s1 @ p
dd = [1 / s1t[2 * i, 2 * i + 1] for i in range(n)]
Db = np.diag(dd + dd)
S = Mm12 @ Ktt @ sqrtm(Db)
return Db, np.linalg.inv(S).T
if s1[2 * i, 2 * i + 1] <= 0:
(perm1[2 * i], perm1[2 * i + 1]) = (perm1[2 * i + 1], perm1[2 * i])

perm2 = np.array([perm1[2 * i] for i in range(n)] + [perm1[2 * i + 1] for i in range(n)])

Ktt = K[:, perm2]
s1t = s1[:, perm1][perm1]

dd = np.array([1 / s1t[2 * i, 2 * i + 1] for i in range(n)])
dd = np.concatenate([dd, dd])
ddsqrt = np.sqrt(dd)
S = Mm12 @ Ktt * ddsqrt
return np.diag(dd), np.linalg.inv(S).T


def symplectic_eigenvals(cov):
Expand All @@ -112,8 +107,8 @@ def symplectic_eigenvals(cov):
(array): symplectic eigenvalues
"""
M = int(len(cov) / 2)
D, _ = williamson(cov)
return np.diag(D)[:M]
Omega = sympmat(M)
return np.real_if_close(-1j * np.linalg.eigvals(Omega @ cov))[::2]


def blochmessiah(S):
Expand Down Expand Up @@ -226,7 +221,7 @@ def takagi(A, svd_order=True):
U = u @ sqrtm((v @ np.conjugate(u)).T)
# The line above could be simplifed to the line below if the product v @ np.conjugate(u) is diagonal
# Which it should be according to Caves http://info.phys.unm.edu/~caves/courses/qinfo-s17/lectures/polarsingularAutonne.pdf
# U = u * np.sqrt(0j + np.diag(v @ np.conjugate(u)))
# U = u * np.sqrt(0j + np.diag(np.conjugate(u) @ v))
# This however breaks test_degenerate
if svd_order is False:
return d[::-1], U[:, ::-1]
Expand Down

0 comments on commit 01fb78c

Please sign in to comment.