Skip to content

Commit

Permalink
Add torch exception for sort to math's single_dispatch.py fixes #1660 (
Browse files Browse the repository at this point in the history
…#1691)

* fixed issue 1660 by adding a custom implementation of sort and then registering the new function

* Fixed issue 1660 by adding a custom implementation of sort.

* Fixed issue 1660 by adding a custom implementation of sort.

* Updated sort function unit test

* Fixed errors in previous commit

* Reversed accidental updates to Makefile and changelog

* Update doc/releases/changelog-dev.md

Co-authored-by: Olivia Di Matteo <2068515+glassnotes@users.noreply.github.com>
  • Loading branch information
RubidgeCarrie and glassnotes committed Oct 4, 2021
1 parent 6ffdba4 commit 192e810
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 11 deletions.
6 changes: 4 additions & 2 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@
* ``qml.circuit_drawer.CircuitDrawer`` can accept a string for the ``charset`` keyword, instead of a ``CharSet`` object.
[(#1640)](https://github.com/PennyLaneAI/pennylane/pull/1640)

* ``qml.math.sort`` will now return only the sorted torch tensor and not the corresponding indices, making sort consistent across interfaces.
[(#1691)](https://github.com/PennyLaneAI/pennylane/pull/1691)

* Operations can now have gradient recipes that depend on the state of the operation.
[(#1674)](https://github.com/PennyLaneAI/pennylane/pull/1674)

Expand Down Expand Up @@ -393,6 +396,5 @@

This release contains contributions from (in alphabetical order):


Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Romain Moyard,
Maria Schuld, Ingrid Strandberg, Antal Száva, David Wierichs.
Carrie-Anne Rubidge, Maria Schuld, Ingrid Strandberg, Antal Száva, David Wierichs.
9 changes: 9 additions & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,15 @@ def _scatter_element_add_torch(tensor, index, value):
ar.register_function("torch", "scatter_element_add", _scatter_element_add_torch)


def _sort_torch(tensor):
"""Update handling of sort to return only values not indices."""
sorted_tensor = _i("torch").sort(tensor)
return sorted_tensor.values


ar.register_function("torch", "sort", _sort_torch)


# -------------------------------- JAX --------------------------------- #


Expand Down
11 changes: 2 additions & 9 deletions pennylane/transforms/decompositions/two_qubit_unitary.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ def _compute_num_cnots(U):
# To distinguish between 1/2 CNOT cases, we need to look at the eigenvalues
evs = math.linalg.eigvals(gammaU)

if math.get_interface(u) == "torch":
sorted_evs, _ = math.sort(math.imag(evs))
else:
sorted_evs = math.sort(math.imag(evs))
sorted_evs = math.sort(math.imag(evs))

# Case: 1 CNOT, the trace is 0, and the eigenvalues of gammaU are [-1j, -1j, 1j, 1j]
# Checking the eigenvalues is needed because of some special 2-CNOT cases that yield
Expand Down Expand Up @@ -366,11 +363,7 @@ def _decomposition_2_cnots(U, wires):
# some reason this case is not handled properly with the full algorithm, so
# we treat it separately.

if math.get_interface(u) == "torch":
# Torch's sort function returns both the sorted values and the new order
sorted_evs, _ = math.sort(math.real(evs))
else:
sorted_evs = math.sort(math.real(evs))
sorted_evs = math.sort(math.real(evs))

if math.allclose(sorted_evs, [-1, -1, 1, 1]):
interior_decomp = [
Expand Down
22 changes: 22 additions & 0 deletions tests/math/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,3 +1630,25 @@ def cost_fn(*params):
grad = qml.grad(cost_fn)(*values)

assert res == {0, 1}


test_sort_data = [
([1, 3, 4, 2], [1, 2, 3, 4]),
(onp.array([1, 3, 4, 2]), onp.array([1, 2, 3, 4])),
(np.array([1, 3, 4, 2]), np.array([1, 2, 3, 4])),
(jnp.array([1, 3, 4, 2]), jnp.array([1, 2, 3, 4])),
(torch.tensor([1, 3, 4, 2]), torch.tensor([1, 2, 3, 4])),
(tf.Variable([1, 3, 4, 2]), tf.Variable([1, 2, 3, 4])),
(tf.constant([1, 3, 4, 2]), tf.constant([1, 2, 3, 4])),
]


class TestSortFunction:
"""Test the sort function works across all interfaces"""

@pytest.mark.parametrize("input, test_output", test_sort_data)
def test_sort(self, input, test_output):
"""Test the sort method is outputting only sorted values not indices"""
result = fn.sort(input)

assert all(result == test_output)

0 comments on commit 192e810

Please sign in to comment.