diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ccb68122d2b..6c89bfe1c45 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -254,6 +254,9 @@ * `qml.QutritDepolarizingChannel` has been added, allowing for depolarizing noise to be simulated on the `default.qutrit.mixed` device. [(#5502)](https://github.com/PennyLaneAI/pennylane/pull/5502) +* Implement support in `assert_equal` for `Operator`, `Controlled`, `Adjoint`, `Pow`, `Exp`, `SProd`, `ControlledSequence`, `Prod`, `Sum`, `Tensor` and `Hamiltonian` + [(#5780)](https://github.com/PennyLaneAI/pennylane/pull/5780) + * `qml.QutritChannel` has been added, enabling the specification of noise using a collection of (3x3) Kraus matrices on the `default.qutrit.mixed` device. [(#5793)](https://github.com/PennyLaneAI/pennylane/issues/5793) diff --git a/pennylane/ops/functions/equal.py b/pennylane/ops/functions/equal.py index dd8fa58e3d5..dcb8f07b5c6 100644 --- a/pennylane/ops/functions/equal.py +++ b/pennylane/ops/functions/equal.py @@ -20,6 +20,7 @@ from typing import Union import pennylane as qml +from pennylane import Hermitian from pennylane.measurements import MeasurementProcess from pennylane.measurements.classical_shadow import ShadowExpvalMP from pennylane.measurements.counts import CountsMP @@ -42,6 +43,10 @@ from pennylane.tape import QuantumTape from pennylane.templates.subroutines import ControlledSequence +OPERANDS_MISMATCH_ERROR_MESSAGE = "op1 and op2 have different operands because " + +BASE_OPERATION_MISMATCH_ERROR_MESSAGE = "op1 and op2 have different base operations because " + def equal( op1: Union[Operator, MeasurementProcess, QuantumTape], @@ -157,12 +162,8 @@ def equal( True """ - # types don't have to be the same type, they just both have to be Observables - if not isinstance(op2, type(op1)) and not isinstance(op1, Observable): - return False - if isinstance(op2, (Hamiltonian, Tensor)): - return _equal(op2, op1) + op1, op2 = op2, op1 dispatch_result = _equal( op1, @@ -240,10 +241,6 @@ def assert_equal( AssertionError: The hyperparameter unitary_matrix has different interfaces for op1 and op2. Got numpy and autograd. """ - if not isinstance(op2, type(op1)) and not isinstance(op1, Observable): - raise AssertionError( - f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}." - ) dispatch_result = _equal( op1, @@ -259,7 +256,6 @@ def assert_equal( raise AssertionError(f"{op1} and {op2} are not equal for an unspecified reason.") -@singledispatch def _equal( op1, op2, @@ -267,11 +263,33 @@ def _equal( check_trainability=True, rtol=1e-5, atol=1e-9, +) -> Union[bool, str]: # pylint: disable=unused-argument + if not isinstance(op2, type(op1)) and not isinstance(op1, Observable): + return f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}." + + return _equal_dispatch( + op1, + op2, + check_interface=check_interface, + check_trainability=check_trainability, + atol=atol, + rtol=rtol, + ) + + +@singledispatch +def _equal_dispatch( + op1, + op2, + check_interface=True, + check_trainability=True, + rtol=1e-5, + atol=1e-9, ) -> Union[bool, str]: # pylint: disable=unused-argument raise NotImplementedError(f"Comparison of {type(op1)} and {type(op2)} not implemented") -@_equal.register +@_equal_dispatch.register def _equal_circuit( op1: qml.tape.QuantumScript, op2: qml.tape.QuantumScript, @@ -314,7 +332,7 @@ def _equal_circuit( return True -@_equal.register +@_equal_dispatch.register def _equal_operators( op1: Operator, op2: Operator, @@ -327,7 +345,7 @@ def _equal_operators( if not isinstance( op2, type(op1) ): # clarifies cases involving PauliX/Y/Z (Observable/Operation) - return False + return f"op1 and op2 are of different types. Got {type(op1)} and {type(op2)}" if isinstance(op1, qml.Identity): # All Identities are equivalent, independent of wires. @@ -336,37 +354,52 @@ def _equal_operators( return True if op1.arithmetic_depth != op2.arithmetic_depth: - return False + return f"op1 and op2 have different arithmetic depths. Got {op1.arithmetic_depth} and {op2.arithmetic_depth}" if op1.arithmetic_depth > 0: # Other dispatches cover cases of operations with arithmetic depth > 0. # If any new operations are added with arithmetic depth > 0, a new dispatch # should be created for them. - return False - if not all( - qml.math.allclose(d1, d2, rtol=rtol, atol=atol) for d1, d2 in zip(op1.data, op2.data) - ): - return False + return f"op1 and op2 have arithmetic depth > 0. Got arithmetic depth {op1.arithmetic_depth}" + if op1.wires != op2.wires: - return False + return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}." if op1.hyperparameters != op2.hyperparameters: - return False + return ( + "The hyperparameters are not equal for op1 and op2.\n" + f"Got {op1.hyperparameters}\n and {op2.hyperparameters}." + ) + + if not all( + qml.math.allclose(d1, d2, rtol=rtol, atol=atol) for d1, d2 in zip(op1.data, op2.data) + ): + return f"op1 and op2 have different data.\nGot {op1.data} and {op2.data}" if check_trainability: - for params_1, params_2 in zip(op1.data, op2.data): - if qml.math.requires_grad(params_1) != qml.math.requires_grad(params_2): - return False + for params1, params2 in zip(op1.data, op2.data): + params1_train = qml.math.requires_grad(params1) + params2_train = qml.math.requires_grad(params2) + if params1_train != params2_train: + return ( + "Parameters have different trainability.\n " + f"{params1} trainability is {params1_train} and {params2} trainability is {params2_train}" + ) if check_interface: - for params_1, params_2 in zip(op1.data, op2.data): - if qml.math.get_interface(params_1) != qml.math.get_interface(params_2): - return False + for params1, params2 in zip(op1.data, op2.data): + params1_interface = qml.math.get_interface(params1) + params2_interface = qml.math.get_interface(params2) + if params1_interface != params2_interface: + return ( + "Parameters have different interfaces.\n " + f"{params1} interface is {params1_interface} and {params2} interface is {params2_interface}" + ) return True -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument, protected-access def _equal_prod_and_sum(op1: CompositeOp, op2: CompositeOp, **kwargs): """Determine whether two Prod or Sum objects are equal""" @@ -374,74 +407,104 @@ def _equal_prod_and_sum(op1: CompositeOp, op2: CompositeOp, **kwargs): return True if len(op1.operands) != len(op2.operands): - return False + return f"op1 and op2 have different number of operands. Got {len(op1.operands)} and {len(op2.operands)}" # organizes by wire indicies while respecting commutation relations sorted_ops1 = op1._sort(op1.operands) sorted_ops2 = op2._sort(op2.operands) - return all(equal(o1, o2, **kwargs) for o1, o2 in zip(sorted_ops1, sorted_ops2)) + for o1, o2 in zip(sorted_ops1, sorted_ops2): + op_check = _equal(o1, o2, **kwargs) + if isinstance(op_check, str): + return OPERANDS_MISMATCH_ERROR_MESSAGE + op_check + + return True -@_equal.register +@_equal_dispatch.register def _equal_controlled(op1: Controlled, op2: Controlled, **kwargs): """Determine whether two Controlled or ControlledOp objects are equal""" - # work wires and control_wire/control_value combinations compared here + if op1.arithmetic_depth != op2.arithmetic_depth: + return f"op1 and op2 have different arithmetic depths. Got {op1.arithmetic_depth} and {op2.arithmetic_depth}" + # op.base.wires compared in return - if [ - dict(zip(op1.control_wires, op1.control_values)), - op1.work_wires, - op1.arithmetic_depth, - ] != [ - dict(zip(op2.control_wires, op2.control_values)), - op2.work_wires, - op2.arithmetic_depth, - ]: - return False + if op1.work_wires != op2.work_wires: + return f"op1 and op2 have different work wires. Got {op1.work_wires} and {op2.work_wires}" + + # work wires and control_wire/control_value combinations compared here + op1_control_dict = dict(zip(op1.control_wires, op1.control_values)) + op2_control_dict = dict(zip(op2.control_wires, op2.control_values)) + if op1_control_dict != op2_control_dict: + return f"op1 and op2 have different control dictionaries. Got {op1_control_dict} and {op2_control_dict}" - return qml.equal(op1.base, op2.base, **kwargs) + base_equal_check = _equal(op1.base, op2.base, **kwargs) + if isinstance(base_equal_check, str): + return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check + + return True -@_equal.register +@_equal_dispatch.register def _equal_controlled_sequence(op1: ControlledSequence, op2: ControlledSequence, **kwargs): """Determine whether two ControlledSequences are equal""" - if [op1.wires, op1.arithmetic_depth] != [ - op2.wires, - op2.arithmetic_depth, - ]: - return False + if op1.wires != op2.wires: + return f"op1 and op2 have different wires. Got {op1.wires} and {op2.wires}." + if op1.arithmetic_depth != op2.arithmetic_depth: + return f"op1 and op2 have different arithmetic depths. Got {op1.arithmetic_depth} and {op2.arithmetic_depth}" - return qml.equal(op1.base, op2.base, **kwargs) + base_equal_check = _equal(op1.base, op2.base, **kwargs) + if isinstance(base_equal_check, str): + return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check + + return True -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_pow(op1: Pow, op2: Pow, **kwargs): """Determine whether two Pow objects are equal""" check_interface, check_trainability = kwargs["check_interface"], kwargs["check_trainability"] if check_interface: - if qml.math.get_interface(op1.z) != qml.math.get_interface(op2.z): - return False + interface1 = qml.math.get_interface(op1.z) + interface2 = qml.math.get_interface(op2.z) + if interface1 != interface2: + return ( + "Exponent have different interfaces.\n" + f"{op1.z} interface is {interface1} and {op2.z} interface is {interface2}" + ) if check_trainability: - if qml.math.requires_grad(op1.z) != qml.math.requires_grad(op2.z): - return False + grad1 = qml.math.requires_grad(op1.z) + grad2 = qml.math.requires_grad(op2.z) + if grad1 != grad2: + return ( + "Exponent have different trainability.\n" + f"{op1.z} interface is {grad1} and {op2.z} interface is {grad2}" + ) if op1.z != op2.z: - return False + return f"Exponent are different. Got {op1.z} and {op2.z}" + + base_equal_check = _equal(op1.base, op2.base, **kwargs) + if isinstance(base_equal_check, str): + return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check - return qml.equal(op1.base, op2.base, **kwargs) + return True -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_adjoint(op1: Adjoint, op2: Adjoint, **kwargs): """Determine whether two Adjoint objects are equal""" # first line of top-level equal function already confirms both are Adjoint - only need to compare bases - return qml.equal(op1.base, op2.base, **kwargs) + base_equal_check = _equal(op1.base, op2.base, **kwargs) + if isinstance(base_equal_check, str): + return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + base_equal_check + return True -@_equal.register + +@_equal_dispatch.register def _equal_conditional(op1: Conditional, op2: Conditional, **kwargs): """Determine whether two Conditional objects are equal""" # first line of top-level equal function already confirms both are Conditionaly - only need to compare bases and meas_val @@ -450,14 +513,14 @@ def _equal_conditional(op1: Conditional, op2: Conditional, **kwargs): ) -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_measurement_value(op1: MeasurementValue, op2: MeasurementValue, **kwargs): """Determine whether two MeasurementValue objects are equal""" return op1.measurements == op2.measurements -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_exp(op1: Exp, op2: Exp, **kwargs): """Determine whether two Exp objects are equal""" @@ -470,20 +533,35 @@ def _equal_exp(op1: Exp, op2: Exp, **kwargs): if check_interface: for params1, params2 in zip(op1.data, op2.data): - if qml.math.get_interface(params1) != qml.math.get_interface(params2): - return False + params1_interface = qml.math.get_interface(params1) + params2_interface = qml.math.get_interface(params2) + if params1_interface != params2_interface: + return ( + "Parameters have different interfaces.\n" + f"{params1} interface is {params1_interface} and {params2} interface is {params2_interface}" + ) + if check_trainability: for params1, params2 in zip(op1.data, op2.data): - if qml.math.requires_grad(params1) != qml.math.requires_grad(params2): - return False + params1_trainability = qml.math.requires_grad(params1) + params2_trainability = qml.math.requires_grad(params2) + if params1_trainability != params2_trainability: + return ( + "Parameters have different trainability.\n" + f"{params1} trainability is {params1_trainability} and {params2} trainability is {params2_trainability}" + ) if not qml.math.allclose(op1.coeff, op2.coeff, rtol=rtol, atol=atol): - return False + return f"op1 and op2 have different coefficients. Got {op1.coeff} and {op2.coeff}" + + equal_check = _equal(op1.base, op2.base, **kwargs) + if isinstance(equal_check, str): + return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + equal_check - return qml.equal(op1.base, op2.base, **kwargs) + return True -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_sprod(op1: SProd, op2: SProd, **kwargs): """Determine whether two SProd objects are equal""" @@ -496,47 +574,69 @@ def _equal_sprod(op1: SProd, op2: SProd, **kwargs): if check_interface: for params1, params2 in zip(op1.data, op2.data): - if qml.math.get_interface(params1) != qml.math.get_interface(params2): - return False + params1_interface = qml.math.get_interface(params1) + params2_interface = qml.math.get_interface(params2) + if params1_interface != params2_interface: + return ( + "Parameters have different interfaces.\n " + f"{params1} interface is {params1_interface} and {params2} interface is {params2_interface}" + ) + if check_trainability: for params1, params2 in zip(op1.data, op2.data): - if qml.math.requires_grad(params1) != qml.math.requires_grad(params2): - return False + params1_train = qml.math.requires_grad(params1) + params2_train = qml.math.requires_grad(params2) + if params1_train != params2_train: + return ( + "Parameters have different trainability.\n " + f"{params1} trainability is {params1_train} and {params2} trainability is {params2_train}" + ) if op1.pauli_rep is not None and (op1.pauli_rep == op2.pauli_rep): # shortcut check return True + if not qml.math.allclose(op1.scalar, op2.scalar, rtol=rtol, atol=atol): - return False + return f"op1 and op2 have different scalars. Got {op1.scalar} and {op2.scalar}" - return qml.equal(op1.base, op2.base, **kwargs) + equal_check = _equal(op1.base, op2.base, **kwargs) + if isinstance(equal_check, str): + return BASE_OPERATION_MISMATCH_ERROR_MESSAGE + equal_check + + return True -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_tensor(op1: Tensor, op2: Observable, **kwargs): """Determine whether a Tensor object is equal to a Hamiltonian/Tensor""" if not isinstance(op2, Observable): - return False + return f"{op2} is not of type Observable" - if isinstance(op2, (Hamiltonian, LinearCombination)): - return op2.compare(op1) + if isinstance(op2, (Hamiltonian, LinearCombination, Hermitian)): + if not op2.compare(op1): + return f"'{op1}' and '{op2}' are not same" if isinstance(op2, Tensor): - return op1._obs_data() == op2._obs_data() # pylint: disable=protected-access + if not op1._obs_data() == op2._obs_data(): # pylint: disable=protected-access + return "op1 and op2 have different _obs_data outputs" - return False + return True -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_hamiltonian(op1: Hamiltonian, op2: Observable, **kwargs): """Determine whether a Hamiltonian object is equal to a Hamiltonian/Tensor objects""" if not isinstance(op2, Observable): - return False - return op1.compare(op2) + return f"{op2} is not of type Observable" + + if not op1.compare(op2): + return f"'{op1}' and '{op2}' are not same" + + return True -@_equal.register +@_equal_dispatch.register def _equal_parametrized_evolution(op1: ParametrizedEvolution, op2: ParametrizedEvolution, **kwargs): # check times match if op1.t is None or op2.t is None: @@ -546,7 +646,8 @@ def _equal_parametrized_evolution(op1: ParametrizedEvolution, op2: ParametrizedE return False # check parameters passed to operator match - if not _equal_operators(op1, op2, **kwargs): + operator_check = _equal_operators(op1, op2, **kwargs) + if isinstance(operator_check, str): return False # check H.coeffs match @@ -557,7 +658,7 @@ def _equal_parametrized_evolution(op1: ParametrizedEvolution, op2: ParametrizedE return all(equal(o1, o2, **kwargs) for o1, o2 in zip(op1.H.ops, op2.H.ops)) -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_measurements( op1: MeasurementProcess, @@ -603,7 +704,7 @@ def _equal_measurements( return False -@_equal.register +@_equal_dispatch.register def _equal_mid_measure(op1: MidMeasureMP, op2: MidMeasureMP, **_): return ( op1.wires == op2.wires @@ -613,7 +714,7 @@ def _equal_mid_measure(op1: MidMeasureMP, op2: MidMeasureMP, **_): ) -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _(op1: VnEntropyMP, op2: VnEntropyMP, **kwargs): """Determine whether two MeasurementProcess objects are equal""" @@ -622,7 +723,7 @@ def _(op1: VnEntropyMP, op2: VnEntropyMP, **kwargs): return eq_m and log_base_match -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _(op1: MutualInfoMP, op2: MutualInfoMP, **kwargs): """Determine whether two MeasurementProcess objects are equal""" @@ -631,7 +732,7 @@ def _(op1: MutualInfoMP, op2: MutualInfoMP, **kwargs): return eq_m and log_base_match -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_shadow_measurements(op1: ShadowExpvalMP, op2: ShadowExpvalMP, **_): """Determine whether two ShadowExpvalMP objects are equal""" @@ -650,12 +751,12 @@ def _equal_shadow_measurements(op1: ShadowExpvalMP, op2: ShadowExpvalMP, **_): return wires_match and H_match and k_match -@_equal.register +@_equal_dispatch.register def _equal_counts(op1: CountsMP, op2: CountsMP, **kwargs): return _equal_measurements(op1, op2, **kwargs) and op1.all_outcomes == op2.all_outcomes -@_equal.register +@_equal_dispatch.register # pylint: disable=unused-argument def _equal_basis_rotation( op1: qml.BasisRotation, @@ -688,7 +789,7 @@ def _equal_basis_rotation( return True -@_equal.register +@_equal_dispatch.register def _equal_hilbert_schmidt( op1: qml.HilbertSchmidt, op2: qml.HilbertSchmidt, diff --git a/tests/ops/functions/test_equal.py b/tests/ops/functions/test_equal.py index be7d522a535..e8df77ce659 100644 --- a/tests/ops/functions/test_equal.py +++ b/tests/ops/functions/test_equal.py @@ -16,6 +16,7 @@ Tests are divided by number of parameters and wires different operators take. """ import itertools +import re # pylint: disable=too-many-arguments, too-many-public-methods from copy import deepcopy @@ -27,8 +28,14 @@ from pennylane import numpy as npp from pennylane.measurements import ExpectationMP from pennylane.measurements.probs import ProbabilityMP +from pennylane.operation import Operator from pennylane.ops import Conditional -from pennylane.ops.functions.equal import _equal, assert_equal +from pennylane.ops.functions.equal import ( + BASE_OPERATION_MISMATCH_ERROR_MESSAGE, + OPERANDS_MISMATCH_ERROR_MESSAGE, + _equal_dispatch, + assert_equal, +) from pennylane.ops.op_math import Controlled, SymbolicOp from pennylane.templates.subroutines import ControlledSequence @@ -337,7 +344,7 @@ def __init__(self): pass # pylint: disable=unused-argument - @_equal.register + @_equal_dispatch.register def _(op1: RandomType, op2, **_): """always returns false""" return False @@ -351,35 +358,62 @@ class TestEqual: def test_equal_simple_diff_op(self, ops): """Test different operators return False""" assert not qml.equal(ops[0], ops[1], check_trainability=False, check_interface=False) + with pytest.raises(AssertionError, match="op1 and op2 are of different types"): + assert_equal(ops[0], ops[1], check_trainability=False, check_interface=False) @pytest.mark.parametrize("op1", PARAMETRIZED_OPERATIONS) def test_equal_simple_same_op(self, op1): """Test same operators return True""" assert qml.equal(op1, op1, check_trainability=False, check_interface=False) + assert_equal(op1, op1, check_trainability=False, check_interface=False) @pytest.mark.parametrize("op1", PARAMETRIZED_OPERATIONS_1P_1W) def test_equal_simple_op_1p1w(self, op1): """Test changing parameter or wire returns False""" wire = 0 param = 0.123 + test_operator = op1(param, wires=wire) assert qml.equal( - op1(param, wires=wire), - op1(param, wires=wire), + test_operator, + test_operator, + check_trainability=False, + check_interface=False, + ) + assert_equal( + test_operator, + test_operator, check_trainability=False, check_interface=False, ) + + test_operator_diff_parameter = op1(param * 2, wires=wire) assert not qml.equal( - op1(param, wires=wire), - op1(param * 2, wires=wire), + test_operator, + test_operator_diff_parameter, check_trainability=False, check_interface=False, ) + with pytest.raises(AssertionError, match="op1 and op2 have different data."): + assert_equal( + test_operator, + test_operator_diff_parameter, + check_trainability=False, + check_interface=False, + ) + test_operator_diff_wire = op1(param, wires=wire + 1) assert not qml.equal( - op1(param, wires=wire), - op1(param, wires=wire + 1), + test_operator, + test_operator_diff_wire, check_trainability=False, check_interface=False, ) + with pytest.raises(AssertionError, match="op1 and op2 have different wires."): + assert_equal( + test_operator, + test_operator_diff_wire, + check_trainability=False, + check_interface=False, + ) @pytest.mark.all_interfaces @pytest.mark.parametrize("op1", PARAMETRIZED_OPERATIONS_1P_1W) @@ -427,6 +461,14 @@ def test_equal_op_1p1w(self, op1): check_interface=False, ) + with pytest.raises(AssertionError, match="Parameters have different trainability"): + assert_equal( + op1(param_qml, wires=wire), + op1(param_qml_1, wires=wire), + check_trainability=True, + check_interface=False, + ) + @pytest.mark.all_interfaces @pytest.mark.parametrize("op1", PARAMETRIZED_OPERATIONS_1P_2W) def test_equal_op_1p2w(self, op1): @@ -1007,6 +1049,15 @@ def test_equal_simple_op_remaining(self): check_trainability=False, check_interface=False, ) + with pytest.raises( + AssertionError, match="The hyperparameters are not equal for op1 and op2." + ): + assert_equal( + op1(param, "Y", wires=wire), + op1(param, "Z", wires=wire), + check_trainability=False, + check_interface=False, + ) wire = 0 param = np.eye(2) * 1j @@ -1155,11 +1206,22 @@ def jax_assertion_func(x, other_tensor): check_interface=False, ) + with pytest.raises(AssertionError, match="Parameters have different interfaces"): + assert_equal( + op1(pl_tensor, wires=wire), + op1(torch_tensor, wires=wire), + check_trainability=True, + check_interface=True, + ) + def test_equal_with_different_arithmetic_depth(self): """Test equal method with two operators with different arithmetic depth.""" - op1 = qml.RX(0.3, wires=0) - op2 = qml.prod(op1, qml.RY(0.25, wires=1)) - assert not qml.equal(op1, op2) + op1 = Operator(wires=0) + op2 = DepthIncreaseOperator(op1) + + assert qml.equal(op1, op2) is False + with pytest.raises(AssertionError, match="op1 and op2 have different arithmetic depths"): + assert_equal(op1, op2) def test_equal_with_unsupported_nested_operators_returns_false(self): """Test that the equal method with two operators with the same arithmetic depth (>0) returns @@ -1173,6 +1235,8 @@ def test_equal_with_unsupported_nested_operators_returns_false(self): assert op1.arithmetic_depth > 0 assert not qml.equal(op1, op2) + with pytest.raises(AssertionError, match="op1 and op2 have arithmetic depth > 0"): + assert_equal(op1, op2) # Measurements test cases @pytest.mark.parametrize("ops", PARAMETRIZED_MEASUREMENTS_COMBINATIONS) @@ -1402,6 +1466,10 @@ def test_hamiltonian_equal(self, H1, H2, res): assert qml.equal(H1, H2) == qml.equal(H2, H1) assert qml.equal(H1, H2) == res + if not res: + error_message_pattern = re.compile(r"'([^']+)' and '([^']+)' are not same") + with pytest.raises(AssertionError, match=error_message_pattern): + assert_equal(H1, H2) @pytest.mark.parametrize(("T1", "T2", "res"), equal_tensors) def test_tensors_equal(self, T1, T2, res): @@ -1409,6 +1477,13 @@ def test_tensors_equal(self, T1, T2, res): assert qml.equal(T1, T2) == qml.equal(T2, T1) assert qml.equal(T1, T2) == res + def test_tensors_not_equal(self): + """Tensors are not equal because of different observable data""" + op1 = qml.operation.Tensor(qml.X(0), qml.Y(1)) + op2 = qml.operation.Tensor(qml.Y(0), qml.X(1)) + with pytest.raises(AssertionError, match="op1 and op2 have different _obs_data outputs"): + assert_equal(op1, op2) + @pytest.mark.parametrize(("H", "T", "res"), equal_hamiltonians_and_tensors) def test_hamiltonians_and_tensors_equal(self, H, T, res): """Tests that equality can be checked between a Hamiltonian and a Tensor""" @@ -1435,6 +1510,8 @@ def test_hamiltonian_and_operation_not_equal(self): op2 = qml.RX(1.2, 0) assert qml.equal(op1, op2) is False assert qml.equal(op2, op1) is False + with pytest.raises(AssertionError, match="is not of type Observable"): + assert_equal(op1, op2) def test_tensor_and_operation_not_equal(self): """Tests that comparing a Tensor with an Operator that is not an Observable returns False""" @@ -1442,6 +1519,8 @@ def test_tensor_and_operation_not_equal(self): op2 = qml.RX(1.2, 0) assert qml.equal(op1, op2) is False assert qml.equal(op2, op1) is False + with pytest.raises(AssertionError, match="is not of type Observable"): + assert_equal(op1, op2) def test_tensor_and_unsupported_observable_returns_false(self): """Tests that trying to compare a Tensor to something other than another Tensor or a Hamiltonian returns False""" @@ -1449,6 +1528,9 @@ def test_tensor_and_unsupported_observable_returns_false(self): op2 = qml.Hermitian([[0, 1], [1, 0]], 0) assert not qml.equal(op1, op2) + error_message_pattern = re.compile(r"'([^']+)' and '([^']+)' are not same") + with pytest.raises(AssertionError, match=error_message_pattern): + assert_equal(op1, op2) def test_unsupported_object_type_not_implemented(self): dev = qml.device("default.qubit", wires=1) @@ -1576,7 +1658,12 @@ def test_controlled_base_operator_comparison(self, base1, base2, res): """Test that equal compares base operators for Controlled operators""" op1 = Controlled(base1, control_wires=2) op2 = Controlled(base2, control_wires=2) - assert qml.equal(op1, op2) == res + if res: + assert qml.equal(op1, op2) + else: + assert not qml.equal(op1, op2) + with pytest.raises(AssertionError, match=BASE_OPERATION_MISMATCH_ERROR_MESSAGE): + assert_equal(op1, op2) @pytest.mark.parametrize(("base1", "base2", "res"), BASES) def test_controlled_sequence_base_operator_comparison(self, base1, base2, res): @@ -1585,6 +1672,23 @@ def test_controlled_sequence_base_operator_comparison(self, base1, base2, res): op2 = ControlledSequence(base2, control=2) assert qml.equal(op1, op2) == res + def test_controlled_sequence_with_different_base_operator(self): + """Test controlled sequence operator with different base operators""" + op1 = ControlledSequence(qml.PauliX(0), control=2) + op2 = ControlledSequence(qml.PauliY(0), control=2) + with pytest.raises(AssertionError, match=BASE_OPERATION_MISMATCH_ERROR_MESSAGE): + assert_equal(op1, op2) + + def test_controlled_sequence_with_different_arithmetic_depth(self): + """The depths of controlled sequence operators are different due to nesting""" + base = qml.MultiRZ(1.23, [0, 1]) + depth_increased_base = DepthIncreaseOperator(base) + op1 = ControlledSequence(base, control=5) + op2 = ControlledSequence(depth_increased_base, control=5) + + with pytest.raises(AssertionError, match="op1 and op2 have different arithmetic depths."): + assert_equal(op1, op2) + @pytest.mark.parametrize(("wire1", "wire2", "res"), WIRES) def test_control_wires_comparison(self, wire1, wire2, res): """Test that equal compares control_wires for Controlled operators""" @@ -1603,7 +1707,15 @@ def test_control_values_comparison(self, controls1, controls2): op1 = qml.ops.op_math.Controlled(base1, control_wires=[1, 2], control_values=controls1) op2 = qml.ops.op_math.Controlled(base2, control_wires=[1, 2], control_values=controls2) - assert qml.equal(op1, op2) == np.allclose(controls1, controls2) + if np.allclose(controls1, controls2): + assert qml.equal(op1, op2) + assert_equal(op1, op2) + else: + assert not qml.equal(op1, op2) + with pytest.raises( + AssertionError, match="op1 and op2 have different control dictionaries." + ): + assert_equal(op1, op2) @pytest.mark.parametrize( ("wires1", "controls1", "wires2", "controls2", "res"), @@ -1632,6 +1744,9 @@ def test_control_sequence_wires_comparison(self, wires1, wires2, res): op1 = ControlledSequence(base1, control=wires1) op2 = ControlledSequence(base2, control=wires2) assert qml.equal(op1, op2) == res + if not res: + with pytest.raises(AssertionError, match="op1 and op2 have different wires."): + assert_equal(op1, op2) @pytest.mark.parametrize(("wire1", "wire2", "res"), WIRES) def test_controlled_work_wires_comparison(self, wire1, wire2, res): @@ -1640,7 +1755,22 @@ def test_controlled_work_wires_comparison(self, wire1, wire2, res): base2 = qml.MultiRZ(1.23, [0, 1]) op1 = Controlled(base1, control_wires=2, work_wires=wire1) op2 = Controlled(base2, control_wires=2, work_wires=wire2) - assert qml.equal(op1, op2) == res + if res: + assert qml.equal(op1, op2) == res + assert_equal(op1, op2) + else: + assert not qml.equal(op1, op2) + with pytest.raises(AssertionError, match="op1 and op2 have different work wires."): + assert_equal(op1, op2) + + def test_controlled_arithmetic_depth(self): + """The depths of controlled operators are different due to nesting""" + base = qml.MultiRZ(1.23, [0, 1]) + op1 = Controlled(base, control_wires=5) + op2 = Controlled(op1, control_wires=6) + + with pytest.raises(AssertionError, match="op1 and op2 have different arithmetic depths."): + assert_equal(op1, op2) @pytest.mark.parametrize("base", PARAMETRIZED_OPERATIONS) def test_adjoint_comparison(self, base): @@ -1651,6 +1781,8 @@ def test_adjoint_comparison(self, base): assert qml.equal(op1, op2) assert not qml.equal(op1, op3) + with pytest.raises(AssertionError, match=BASE_OPERATION_MISMATCH_ERROR_MESSAGE): + assert_equal(op1, op3) def test_adjoint_comparison_with_tolerance(self): """Test that equal compares the parameters within a provided tolerance of the Adjoint class.""" @@ -1755,6 +1887,14 @@ def test_pow_comparison(self, bases_bases_match, params_params_match): op2 = qml.pow(base2, param2) assert qml.equal(op1, op2) == (bases_match and params_match) + def test_diff_pow_comparison(self): + """Test different exponents""" + base = qml.PauliX(0) + op1 = qml.pow(base, 0.2) + op2 = qml.pow(base, 0.3) + with pytest.raises(AssertionError, match="Exponent are different."): + assert_equal(op1, op2) + def test_pow_comparison_with_tolerance(self): """Test that equal compares the parameters within a provided tolerance of the Pow class.""" op1 = qml.pow(qml.RX(1.2, wires=0), 2) @@ -1772,6 +1912,8 @@ def test_pow_comparison_with_interface(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + with pytest.raises(AssertionError, match="Exponent have different interfaces.\n"): + assert_equal(op1, op2, check_interface=True, check_trainability=False) def test_pow_comparison_with_trainability(self): """Test that equal compares the parameters within a provided trainability of the Pow class.""" @@ -1780,6 +1922,8 @@ def test_pow_comparison_with_trainability(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + with pytest.raises(AssertionError, match="Exponent have different trainability.\n"): + assert_equal(op1, op2, check_interface=True, check_trainability=True) def test_pow_base_op_comparison_with_interface(self): """Test that equal compares the parameters within a provided interface of the base operator of Pow class.""" @@ -1805,8 +1949,25 @@ def test_exp_comparison(self, bases_bases_match, params_params_match): param1, param2, params_match = params_params_match op1 = qml.exp(base1, param1) op2 = qml.exp(base2, param2) + assert qml.equal(op1, op2) == (bases_match and params_match) + def test_exp_with_different_coeffs(self): + """Test that assert_equal fails when coeffs are different""" + op1 = qml.exp(qml.X(0), 0.5j) + op2 = qml.exp(qml.X(0), 1.0j) + + with pytest.raises(AssertionError, match="op1 and op2 have different coefficients."): + assert_equal(op1, op2) + + def test_exp_with_different_base_operator(self): + """Test that assert_equal fails when base operators are different""" + op1 = qml.exp(qml.X(0), 0.5j) + op2 = qml.exp(qml.Y(0), 0.5j) + + with pytest.raises(AssertionError, match=BASE_OPERATION_MISMATCH_ERROR_MESSAGE): + assert_equal(op1, op2) + def test_exp_comparison_with_tolerance(self): """Test that equal compares the parameters within a provided tolerance of the Exp class.""" op1 = qml.exp(qml.PauliX(0), 0.12) @@ -1825,6 +1986,10 @@ def test_exp_comparison_with_interface(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + assert_equal(op1, op2, check_interface=False, check_trainability=False) + with pytest.raises(AssertionError, match="Parameters have different interface"): + assert_equal(op1, op2, check_interface=True, check_trainability=False) + def test_exp_comparison_with_trainability(self): """Test that equal compares the parameters within a provided trainability of the Exp class.""" op1 = qml.exp(qml.PauliX(0), npp.array(1.2, requires_grad=False)) @@ -1833,6 +1998,10 @@ def test_exp_comparison_with_trainability(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + assert_equal(op1, op2, check_interface=False, check_trainability=False) + with pytest.raises(AssertionError, match="Parameters have different trainability"): + assert_equal(op1, op2, check_interface=False, check_trainability=True) + def test_exp_base_op_comparison_with_interface(self): """Test that equal compares the parameters within a provided interface of the base operator of Exp class.""" op1 = qml.exp(qml.RX(0.5, wires=0), 1.2) @@ -1841,6 +2010,10 @@ def test_exp_base_op_comparison_with_interface(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + assert_equal(op1, op2, check_interface=False, check_trainability=False) + with pytest.raises(AssertionError, match="Parameters have different interface"): + assert_equal(op1, op2, check_interface=True, check_trainability=False) + def test_exp_base_op_comparison_with_trainability(self): """Test that equal compares the parameters within a provided trainability of the base operator of Exp class.""" op1 = qml.exp(qml.RX(npp.array(0.5, requires_grad=False), wires=0), 1.2) @@ -1866,6 +2039,27 @@ def test_s_prod_comparison(self, bases_bases_match, params_params_match): op2 = qml.s_prod(param2, base2) assert qml.equal(op1, op2) == (bases_match and params_match) + def test_s_prod_comparison_different_scalar(self): + """Test that equal compares two objects of the SProd class with different scalars""" + base = qml.PauliX(0) @ qml.PauliY(1) + op1 = qml.s_prod(0.2, base) + op2 = qml.s_prod(0.3, base) + + with pytest.raises( + AssertionError, match="op1 and op2 have different scalars. Got 0.2 and 0.3" + ): + assert_equal(op1, op2) + + def test_s_prod_comparison_different_operands(self): + """Test that equal compares two objects of the SProd class with different operands""" + base1 = qml.PauliX(0) @ qml.PauliY(1) + base2 = qml.PauliX(0) @ qml.PauliY(2) + op1 = qml.s_prod(0.2, base1) + op2 = qml.s_prod(0.2, base2) + + with pytest.raises(AssertionError, match=OPERANDS_MISMATCH_ERROR_MESSAGE): + assert_equal(op1, op2) + def test_s_prod_comparison_with_tolerance(self): """Test that equal compares the parameters within a provided tolerance of the SProd class.""" op1 = qml.s_prod(0.12, qml.PauliX(0)) @@ -1883,6 +2077,8 @@ def test_s_prod_comparison_with_interface(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=True, check_trainability=False) + with pytest.raises(AssertionError, match="Parameters have different interfaces."): + assert_equal(op1, op2) def test_s_prod_comparison_with_trainability(self): """Test that equal compares the parameters within a provided trainability of the SProd class.""" @@ -1891,6 +2087,8 @@ def test_s_prod_comparison_with_trainability(self): assert qml.equal(op1, op2, check_interface=False, check_trainability=False) assert not qml.equal(op1, op2, check_interface=False, check_trainability=True) + with pytest.raises(AssertionError, match="Parameters have different trainability."): + assert_equal(op1, op2) def test_s_prod_base_op_comparison_with_interface(self): """Test that equal compares the parameters within a provided interface of the base operator of SProd class.""" @@ -2092,6 +2290,28 @@ def test_sum_with_multi_wire_operations(self, base_list1, base_list2, res): op2 = qml.sum(*base_list2) assert qml.equal(op1, op2) == res + def test_sum_with_different_operands(self): + """Test sum equals with different operands""" + operands1 = [qml.PauliX(0), qml.PauliY(1)] + operands2 = [qml.PauliY(0), qml.PauliY(1)] + op1 = qml.sum(*operands1) + op2 = qml.sum(*operands2) + + with pytest.raises(AssertionError, match=OPERANDS_MISMATCH_ERROR_MESSAGE): + assert_equal(op1, op2) + + def test_sum_with_different_number_of_operands(self): + """Test sum equals with different number of operands""" + operands1 = [qml.PauliX(0), qml.PauliY(1)] + operands2 = [qml.PauliY(1)] + op1 = qml.sum(*operands1) + op2 = qml.sum(*operands2) + + with pytest.raises( + AssertionError, match="op1 and op2 have different number of operands. Got 2 and 1" + ): + assert_equal(op1, op2) + def test_sum_equal_order_invarient(self): """Test that the order of operations doesn't affect equality""" H1 = qml.prod(qml.PauliX(0), qml.PauliX(1)) @@ -2518,3 +2738,21 @@ def test_interface_and_trainability(self, op, other_op): assert qml.equal(op, other_op, check_interface=False) is False assert qml.equal(op, other_op, check_trainability=False) is False assert qml.equal(op, other_op, check_interface=False, check_trainability=False) is True + + +# pylint: disable=too-few-public-methods +class DepthIncreaseOperator(Operator): + """Dummy class which increases depth by one""" + + # pylint: disable=super-init-not-called + def __init__(self, op: Operator): + self._op = op + + @property + def arithmetic_depth(self) -> int: + """Arithmetic depth of the operator.""" + return 1 + self._op.arithmetic_depth + + @property + def wires(self): + return self._op.wires