diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 6d399f78ed5..fb43671bced 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -303,6 +303,9 @@ * ``qml.circuit_drawer.CircuitDrawer`` can accept a string for the ``charset`` keyword, instead of a ``CharSet`` object. [(#1640)](https://github.com/PennyLaneAI/pennylane/pull/1640) +* ``qml.math.sort`` will now return only the sorted torch tensor and not the corresponding indices, making sort consistent across interfaces. + [(#1691)](https://github.com/PennyLaneAI/pennylane/pull/1691) + * Operations can now have gradient recipes that depend on the state of the operation. [(#1674)](https://github.com/PennyLaneAI/pennylane/pull/1674) @@ -393,6 +396,5 @@ This release contains contributions from (in alphabetical order): - Utkarsh Azad, Olivia Di Matteo, Andrew Gardhouse, Josh Izaac, Christina Lee, Romain Moyard, -Maria Schuld, Ingrid Strandberg, Antal Száva, David Wierichs. +Carrie-Anne Rubidge, Maria Schuld, Ingrid Strandberg, Antal Száva, David Wierichs. \ No newline at end of file diff --git a/pennylane/math/single_dispatch.py b/pennylane/math/single_dispatch.py index ae0d6736c38..c660e9393d6 100644 --- a/pennylane/math/single_dispatch.py +++ b/pennylane/math/single_dispatch.py @@ -357,6 +357,15 @@ def _scatter_element_add_torch(tensor, index, value): ar.register_function("torch", "scatter_element_add", _scatter_element_add_torch) +def _sort_torch(tensor): + """Update handling of sort to return only values not indices.""" + sorted_tensor = _i("torch").sort(tensor) + return sorted_tensor.values + + +ar.register_function("torch", "sort", _sort_torch) + + # -------------------------------- JAX --------------------------------- # diff --git a/pennylane/transforms/decompositions/two_qubit_unitary.py b/pennylane/transforms/decompositions/two_qubit_unitary.py index 52a22b7754d..c641a53a120 100644 --- a/pennylane/transforms/decompositions/two_qubit_unitary.py +++ b/pennylane/transforms/decompositions/two_qubit_unitary.py @@ -132,10 +132,7 @@ def _compute_num_cnots(U): # To distinguish between 1/2 CNOT cases, we need to look at the eigenvalues evs = math.linalg.eigvals(gammaU) - if math.get_interface(u) == "torch": - sorted_evs, _ = math.sort(math.imag(evs)) - else: - sorted_evs = math.sort(math.imag(evs)) + sorted_evs = math.sort(math.imag(evs)) # Case: 1 CNOT, the trace is 0, and the eigenvalues of gammaU are [-1j, -1j, 1j, 1j] # Checking the eigenvalues is needed because of some special 2-CNOT cases that yield @@ -366,11 +363,7 @@ def _decomposition_2_cnots(U, wires): # some reason this case is not handled properly with the full algorithm, so # we treat it separately. - if math.get_interface(u) == "torch": - # Torch's sort function returns both the sorted values and the new order - sorted_evs, _ = math.sort(math.real(evs)) - else: - sorted_evs = math.sort(math.real(evs)) + sorted_evs = math.sort(math.real(evs)) if math.allclose(sorted_evs, [-1, -1, 1, 1]): interior_decomp = [ diff --git a/tests/math/test_functions.py b/tests/math/test_functions.py index bf8286508c7..c1b9d14f793 100644 --- a/tests/math/test_functions.py +++ b/tests/math/test_functions.py @@ -1630,3 +1630,25 @@ def cost_fn(*params): grad = qml.grad(cost_fn)(*values) assert res == {0, 1} + + +test_sort_data = [ + ([1, 3, 4, 2], [1, 2, 3, 4]), + (onp.array([1, 3, 4, 2]), onp.array([1, 2, 3, 4])), + (np.array([1, 3, 4, 2]), np.array([1, 2, 3, 4])), + (jnp.array([1, 3, 4, 2]), jnp.array([1, 2, 3, 4])), + (torch.tensor([1, 3, 4, 2]), torch.tensor([1, 2, 3, 4])), + (tf.Variable([1, 3, 4, 2]), tf.Variable([1, 2, 3, 4])), + (tf.constant([1, 3, 4, 2]), tf.constant([1, 2, 3, 4])), +] + + +class TestSortFunction: + """Test the sort function works across all interfaces""" + + @pytest.mark.parametrize("input, test_output", test_sort_data) + def test_sort(self, input, test_output): + """Test the sort method is outputting only sorted values not indices""" + result = fn.sort(input) + + assert all(result == test_output)