Skip to content

Commit

Permalink
Add StructuredSampler.reformat_parameters() method
Browse files Browse the repository at this point in the history
  • Loading branch information
arcondello committed Jan 24, 2022
1 parent 901ee64 commit 6c97803
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/reference/solver.rst
Expand Up @@ -29,6 +29,7 @@ Methods
StructuredSolver.sample_ising
StructuredSolver.sample_qubo
StructuredSolver.sample_bqm
StructuredSolver.reformat_parameters

UnstructuredSolver.sample_ising
UnstructuredSolver.sample_qubo
Expand Down
85 changes: 75 additions & 10 deletions dwave/cloud/solver.py
Expand Up @@ -28,8 +28,10 @@
"""

import copy
import json
import logging
import typing
import warnings
from collections.abc import Mapping

Expand All @@ -51,6 +53,11 @@
except ImportError:
_numpy = False

try:
import dimod
except ImportError:
dimod = None

__all__ = [
'BaseSolver', 'StructuredSolver',
'BaseUnstructuredSolver', 'UnstructuredSolver',
Expand All @@ -60,6 +67,19 @@
logger = logging.getLogger(__name__)


if typing.TYPE_CHECKING:
# do a bit of fiddling
try:
_Type = typing.Literal['qubo', 'ising']
except ImportError: # Python < 3.8
_Type = str

try:
_Vartype = typing.Union[_Type, dimod.typing.VartypeLike]
except AttributeError: # dimod not installed or too old
_Vartype = _Type


class BaseSolver(object):
"""Base class for a general D-Wave solver.
Expand Down Expand Up @@ -1043,19 +1063,62 @@ def _sample(self, type_, linear, quadratic, offset, params,

return computation

# kept for internal backwards compatibility and in case it's being
# used externally anywhere.
def _format_params(self, type_, params):
"""Reformat some of the parameters for sapi."""
if 'initial_state' in params:
# NB: at this moment the error raised when initial_state does not match lin/quad (in
# active qubits) is not very informative, but there is also no clean way to check here
# that they match because lin can be either a list or a dict. In the future it would be
# good to check.
initial_state = params['initial_state']
if isinstance(initial_state, Mapping):
self.reformat_parameters(type_, params, self.properties, inplace=True)

@staticmethod
def reformat_parameters(vartype: '_Vartype',
parameters: typing.MutableMapping[str, typing.Any],
properties: typing.Mapping[str, typing.Any],
inplace: bool = False,
) -> typing.MutableMapping[str, typing.Any]:
"""Reformat some solver parameters for SAPI.
initial_state_list = [3]*self.properties['num_qubits']
Currently the only reformatted parameter is ``initial_state``. This
method allows ``initial_state`` to be submitted as a dictionary
mapping the qubits to their initial value.
low = -1 if type_ == 'ising' else 0
Args:
vartype: One of ``'ising'`` or ``'qubo'``. If :mod:`dimod` is
installed, this can also be any
:class:`~dimod.typing.VartypeLike`.
parameters: The parameters to submit to ths solver.
properties: The solver's properties. Note that this will
work with either :attr:`StructuredSolver.properties`
or :attr:`dwave.systems.DWaveSampler.properties`.
Returns:
The reformatted solver parameters.
If ``inplace`` this will be the ``parameters``, modified.
If ``not inplace`` then this will be a deep copy of ``parameters``,
with the relevant fields updated.
"""
# whether to copy or not
parameters = parameters if inplace else copy.deepcopy(parameters)

# handle the vartype
if vartype not in ('ising', 'qubo'):
try:
vartype = 'ising' if dimod.as_vartype(vartype) is dimod.SPIN else 'qubo'
except (TypeError, AttributeError):
msg = "expected vartype to be one of: 'ising', 'qubo'"
if dimod:
msg += ", 'BINARY', 'SPIN', dimod.BINARY, dimod.SPIN"
msg += f"; {vartype!r} provided"
raise ValueError(msg) from None

# update the parameters
if 'initial_state' in parameters:
initial_state = parameters['initial_state']
if isinstance(initial_state, typing.Mapping):

initial_state_list = [3]*properties['num_qubits']

low = -1 if vartype == 'ising' else 0

for v, val in initial_state.items():
if val == 3:
Expand All @@ -1065,9 +1128,11 @@ def _format_params(self, type_, params):
else:
initial_state_list[v] = 1

params['initial_state'] = initial_state_list
parameters['initial_state'] = initial_state_list
# else: support old format

return parameters

def check_problem(self, linear, quadratic):
"""Test if an Ising model matches the graph provided by the solver.
Expand Down
@@ -0,0 +1,6 @@
---
features:
- |
Add ``StructuredSolver.reformat_parameters()`` method. This method can
be used to format solver parameters for SAPI.
See `#465 <https://github.com/dwavesystems/dwave-cloud-client/issues/465>`_.
48 changes: 48 additions & 0 deletions tests/test_solver_utils.py
Expand Up @@ -64,3 +64,51 @@ def test_legacy_format(self):
bqm = dimod.generators.ran_r(1, list(self.solver.edges))
h = list(bqm.linear.values()) # h as list is still supported
self.assertTrue(self.solver.check_problem(h, bqm.quadratic))


class TestReformatParameters(unittest.TestCase):
def test_empty(self):
self.assertEqual(StructuredSolver.reformat_parameters('ising', {}, {}), {})

def test_initial_states(self):
doc = {'initial_state': {0: 0, 4: 1}}

self.assertEqual(StructuredSolver.reformat_parameters('ising', doc, dict(num_qubits=9)),
dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3]))
self.assertEqual(StructuredSolver.reformat_parameters('qubo', doc, dict(num_qubits=9)),
dict(initial_state=[0, 3, 3, 3, 1, 3, 3, 3, 3]))

if dimod:
self.assertEqual(StructuredSolver.reformat_parameters('SPIN', doc, dict(num_qubits=9)),
dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3]))
self.assertEqual(StructuredSolver.reformat_parameters('BINARY', doc, dict(num_qubits=9)),
dict(initial_state=[0, 3, 3, 3, 1, 3, 3, 3, 3]))

self.assertEqual(doc, {'initial_state': {0: 0, 4: 1}})

def test_initial_states_inplace(self):
doc = {'initial_state': {0: 0, 4: 1}}
StructuredSolver.reformat_parameters('ising', doc, dict(num_qubits=9), inplace=True)
self.assertEqual(doc, dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3]))

def test_initial_states_sequence(self):
doc = {'initial_state': [-1, 3, 3, 3, 1, 3, 3, 3, 3]}
self.assertEqual(StructuredSolver.reformat_parameters('ising', doc, dict(num_qubits=9)),
dict(initial_state=[-1, 3, 3, 3, 1, 3, 3, 3, 3]))

def test_vartype_smoke(self):
for vt in StructuredSolver._handled_problem_types:
StructuredSolver.reformat_parameters(vt, {}, {})

with self.assertRaises(ValueError):
StructuredSolver.reformat_parameters('not a type', {}, {})

@unittest.skipUnless(dimod, "dimod not installed")
def test_vartype_dimod_smoke(self):
StructuredSolver.reformat_parameters('SPIN', {}, {})
StructuredSolver.reformat_parameters('BINARY', {}, {})
StructuredSolver.reformat_parameters(dimod.BINARY, {}, {})
StructuredSolver.reformat_parameters(dimod.SPIN, {}, {})

with self.assertRaises(ValueError):
StructuredSolver.reformat_parameters(dimod.INTEGER, {}, {})

0 comments on commit 6c97803

Please sign in to comment.