diff --git a/docs/reference/solver.rst b/docs/reference/solver.rst index 6decd21b..e718cfb0 100644 --- a/docs/reference/solver.rst +++ b/docs/reference/solver.rst @@ -29,6 +29,7 @@ Methods StructuredSolver.sample_ising StructuredSolver.sample_qubo StructuredSolver.sample_bqm + StructuredSolver.reformat_parameters UnstructuredSolver.sample_ising UnstructuredSolver.sample_qubo diff --git a/dwave/cloud/solver.py b/dwave/cloud/solver.py index 4cbed3c5..79577765 100644 --- a/dwave/cloud/solver.py +++ b/dwave/cloud/solver.py @@ -28,8 +28,10 @@ """ +import copy import json import logging +import typing import warnings from collections.abc import Mapping @@ -51,6 +53,11 @@ except ImportError: _numpy = False +try: + import dimod +except ImportError: + dimod = None + __all__ = [ 'BaseSolver', 'StructuredSolver', 'BaseUnstructuredSolver', 'UnstructuredSolver', @@ -60,6 +67,19 @@ logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + # do a bit of fiddling + try: + _Type = typing.Literal['qubo', 'ising'] + except ImportError: # Python < 3.8 + _Type = str + + try: + _Vartype = typing.Union[_Type, dimod.typing.VartypeLike] + except AttributeError: # dimod not installed or too old + _Vartype = _Type + + class BaseSolver(object): """Base class for a general D-Wave solver. @@ -1043,19 +1063,62 @@ def _sample(self, type_, linear, quadratic, offset, params, return computation + # kept for internal backwards compatibility and in case it's being + # used externally anywhere. def _format_params(self, type_, params): """Reformat some of the parameters for sapi.""" - if 'initial_state' in params: - # NB: at this moment the error raised when initial_state does not match lin/quad (in - # active qubits) is not very informative, but there is also no clean way to check here - # that they match because lin can be either a list or a dict. In the future it would be - # good to check. - initial_state = params['initial_state'] - if isinstance(initial_state, Mapping): + self.reformat_parameters(type_, params, self.properties, inplace=True) + + @staticmethod + def reformat_parameters(vartype: '_Vartype', + parameters: typing.MutableMapping[str, typing.Any], + properties: typing.Mapping[str, typing.Any], + inplace: bool = False, + ) -> typing.MutableMapping[str, typing.Any]: + """Reformat some solver parameters for SAPI. - initial_state_list = [3]*self.properties['num_qubits'] + Currently the only reformatted parameter is ``initial_state``. This + method allows ``initial_state`` to be submitted as a dictionary + mapping the qubits to their initial value. - low = -1 if type_ == 'ising' else 0 + Args: + vartype: One of ``'ising'`` or ``'qubo'``. If :mod:`dimod` is + installed, this can also be any + :class:`~dimod.typing.VartypeLike`. + parameters: The parameters to submit to ths solver. + properties: The solver's properties. Note that this will + work with either :attr:`StructuredSolver.properties` + or :attr:`dwave.systems.DWaveSampler.properties`. + + Returns: + The reformatted solver parameters. + If ``inplace`` this will be the ``parameters``, modified. + If ``not inplace`` then this will be a deep copy of ``parameters``, + with the relevant fields updated. + + """ + # whether to copy or not + parameters = parameters if inplace else copy.deepcopy(parameters) + + # handle the vartype + if vartype not in ('ising', 'qubo'): + try: + vartype = 'ising' if dimod.as_vartype(vartype) is dimod.SPIN else 'qubo' + except (TypeError, AttributeError): + msg = "expected vartype to be one of: 'ising', 'qubo'" + if dimod: + msg += ", 'BINARY', 'SPIN', dimod.BINARY, dimod.SPIN" + msg += f"; {vartype!r} provided" + raise ValueError(msg) from None + + # update the parameters + if 'initial_state' in parameters: + initial_state = parameters['initial_state'] + if isinstance(initial_state, typing.Mapping): + + initial_state_list = [3]*properties['num_qubits'] + + low = -1 if vartype == 'ising' else 0 for v, val in initial_state.items(): if val == 3: @@ -1065,9 +1128,11 @@ def _format_params(self, type_, params): else: initial_state_list[v] = 1 - params['initial_state'] = initial_state_list + parameters['initial_state'] = initial_state_list # else: support old format + return parameters + def check_problem(self, linear, quadratic): """Test if an Ising model matches the graph provided by the solver. diff --git a/releasenotes/notes/StructureSolver.reformat_parameters-a3ca02b725495496.yaml b/releasenotes/notes/StructureSolver.reformat_parameters-a3ca02b725495496.yaml new file mode 100644 index 00000000..cefdff0f --- /dev/null +++ b/releasenotes/notes/StructureSolver.reformat_parameters-a3ca02b725495496.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Add ``StructuredSolver.reformat_parameters()`` method. This method can + be used to format solver parameters for SAPI. + See `#465 `_. diff --git a/tests/test_solver_utils.py b/tests/test_solver_utils.py index 27b45bb9..be97f882 100644 --- a/tests/test_solver_utils.py +++ b/tests/test_solver_utils.py @@ -64,3 +64,51 @@ def test_legacy_format(self): bqm = dimod.generators.ran_r(1, list(self.solver.edges)) h = list(bqm.linear.values()) # h as list is still supported self.assertTrue(self.solver.check_problem(h, bqm.quadratic)) + + +class TestReformatParameters(unittest.TestCase): + def test_empty(self): + self.assertEqual(StructuredSolver.reformat_parameters('ising', {}, {}), {}) + + def test_initial_states(self): + doc = {'initial_state': {0: 0, 4: 1}} + + self.assertEqual(StructuredSolver.reformat_parameters('ising', doc, dict(num_qubits=9)), + dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3])) + self.assertEqual(StructuredSolver.reformat_parameters('qubo', doc, dict(num_qubits=9)), + dict(initial_state=[0, 3, 3, 3, 1, 3, 3, 3, 3])) + + if dimod: + self.assertEqual(StructuredSolver.reformat_parameters('SPIN', doc, dict(num_qubits=9)), + dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3])) + self.assertEqual(StructuredSolver.reformat_parameters('BINARY', doc, dict(num_qubits=9)), + dict(initial_state=[0, 3, 3, 3, 1, 3, 3, 3, 3])) + + self.assertEqual(doc, {'initial_state': {0: 0, 4: 1}}) + + def test_initial_states_inplace(self): + doc = {'initial_state': {0: 0, 4: 1}} + StructuredSolver.reformat_parameters('ising', doc, dict(num_qubits=9), inplace=True) + self.assertEqual(doc, dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3])) + + def test_initial_states_sequence(self): + doc = {'initial_state': [-1, 3, 3, 3, 1, 3, 3, 3, 3]} + self.assertEqual(StructuredSolver.reformat_parameters('ising', doc, dict(num_qubits=9)), + dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3])) + + def test_vartype_smoke(self): + for vt in StructuredSolver._handled_problem_types: + StructuredSolver.reformat_parameters(vt, {}, {}) + + with self.assertRaises(ValueError): + StructuredSolver.reformat_parameters('not a type', {}, {}) + + @unittest.skipUnless(dimod, "dimod not installed") + def test_vartype_dimod_smoke(self): + StructuredSolver.reformat_parameters('SPIN', {}, {}) + StructuredSolver.reformat_parameters('BINARY', {}, {}) + StructuredSolver.reformat_parameters(dimod.BINARY, {}, {}) + StructuredSolver.reformat_parameters(dimod.SPIN, {}, {}) + + with self.assertRaises(ValueError): + StructuredSolver.reformat_parameters(dimod.INTEGER, {}, {})