Skip to content

Commit

Permalink
Simulon device spec mode validation (#545)
Browse files Browse the repository at this point in the history
* add checks

* add docstring

* remove extra check

* add alt names to measurements

* add tests

* remove Type import (not sure why is was there)

* update error text

* Update strawberryfields/program.py

Co-authored-by: antalszava <antalszava@gmail.com>

* Apply suggestions from code review

Co-authored-by: antalszava <antalszava@gmail.com>

* fix ifs

* run black

* separate errors

* add unit tests

* add docstrings to test

* remove comments

* Update tests/frontend/test_program.py

Co-authored-by: antalszava <antalszava@gmail.com>

* add test for wrong entry error

* add max to method

* remove hard-coded simulon

* update tests

* minor docstring update

Co-authored-by: antalszava <antalszava@gmail.com>
  • Loading branch information
thisac and antalszava committed Feb 26, 2021
1 parent 57c4242 commit 50ae64a
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 6 deletions.
66 changes: 62 additions & 4 deletions strawberryfields/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,13 @@ def _linked_copy(self):
return p

def assert_number_of_modes(self, device):
"""Check that the number of modes in the program is valid for the given device."""
"""Check that the number of modes in the program is valid for the given device.
Args:
device (~strawberryfields.api.DeviceSpec): Device specification object to use.
``device.modes`` must be an integer, containing the allowed number of modes
for the target.
"""
# Program subsystems may be created and destroyed during execution. The length
# of the program registers represents the total number of modes that has ever existed.
modes_total = len(self.reg_refs)
Expand All @@ -464,6 +469,53 @@ def assert_number_of_modes(self, device):
f"only supports a {device.modes}-mode program."
)

def assert_max_number_of_measurements(self, device):
"""Check that the number of measurements in the circuit doesn't exceed the number of allowed
measurements according to the device specification.
Args:
device (~strawberryfields.api.DeviceSpec): Device specification object to use.
``device.modes`` must be a dictionary, containing the maximum number of allowed
measurements for the specified target.
"""
num_pnr, num_homodyne, num_heterodyne = 0, 0, 0

try:
max_pnr = device.modes["max"]["pnr"]
max_homodyne = device.modes["max"]["homodyne"]
max_heterodyne = device.modes["max"]["heterodyne"]
except (KeyError, TypeError) as e:
raise KeyError(
"Device specification must contain an entry for the maximum allowed number "
"of measurments. Have you specified the correct target?"
) from e

for c in self.circuit:
op_name = str(c.op)
if "MeasureFock" in op_name:
num_pnr += len(c.reg)
elif "MeasureHomodyne" in op_name or "MeasureX" in op_name or "MeasureP" in op_name:
num_homodyne += len(c.reg)
elif "MeasureHeterodyne" in op_name or "MeasureHD" in op_name:
num_heterodyne += len(c.reg)

if num_pnr > max_pnr:
raise CircuitError(
f"This program contains {num_pnr} fock measurements. "
f"A maximum of {max_pnr} fock measurements are supported."
)
if num_homodyne > max_homodyne:
raise CircuitError(
f"This program contains {num_homodyne} homodyne measurements. "
f"A maximum of {max_homodyne} homodyne measurements are supported."
)
if num_heterodyne > max_heterodyne:
raise CircuitError(
f"This program contains {num_heterodyne} heterodyne measurements. "
f"A maximum of {max_heterodyne} heterodyne measurements are supported."
)

def compile(self, *, device=None, compiler=None, **kwargs):
"""Compile the program given a Strawberry Fields photonic compiler, or
hardware device specification.
Expand Down Expand Up @@ -543,9 +595,15 @@ def _get_compiler(compiler_or_name):
else:
compiler = _get_compiler(compiler)

# TODO: add validation for device specs that provide a dictionary for `device.modes`.
if device.modes is not None and isinstance(device.modes, int):
self.assert_number_of_modes(device)
if device.modes is not None:
if isinstance(device.modes, int):
# check that the number of modes is correct, if device.modes
# is provided as an integer
self.assert_number_of_modes(device)
else:
# check that the number of measurements is within the allowed
# limits for each measurement type; device.modes will be a dictionary
self.assert_max_number_of_measurements(device)

else:
compiler = _get_compiler(compiler)
Expand Down
114 changes: 112 additions & 2 deletions tests/frontend/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,69 @@ def test_bind_params(self, prog):
assert x.val == 2.0
assert y.val is None

def test_assert_number_of_modes(self):
"""Check that the correct error is raised when calling `prog.assert_number_of_modes`
with the incorrect number of modes."""
device_dict = {"modes": 2, "layout": None, "gate_parameters": None, "compiler": [None]}
spec = sf.api.DeviceSpec(target=None, connection=None, spec=device_dict)

prog = sf.Program(3)
with prog.context as q:
ops.S2gate(0.6) | [q[0], q[1]]
ops.S2gate(0.6) | [q[1], q[2]]

with pytest.raises(program.CircuitError, match="program contains 3 modes, but the device 'None' only supports a 2-mode program"):
prog.assert_number_of_modes(spec)

@pytest.mark.parametrize(
"measure_op, measure_name", [
(ops.MeasureFock(), "fock"), # MeasureFock
(ops.MeasureHomodyne(phi=0), "homodyne"), # MeasureX
(ops.MeasureHomodyne(phi=42), "homodyne"), # MeasureHomodyne
(ops.MeasureHomodyne(phi=np.pi/2), "homodyne"), # MeasureP
(ops.MeasureHeterodyne(), "heterodyne"), # MeasureHD
(ops.MeasureHeterodyne(select=0), "heterodyne"), # MeasureHeterodyne
],
)
def test_assert_max_number_of_measurements(self, measure_op, measure_name):
"""Check that the correct error is raised when calling `prog.assert_number_of_measurements`
with the incorrect number of measurements in the circuit."""
# set maximum number of measurements to 2, and measure 3 in prog below
device_dict = {
"modes": {
"max": {
"pnr": 2,
"homodyne": 2,
"heterodyne": 2
}
},
"layout": None, "gate_parameters": {}, "compiler": [None]
}
spec = sf.api.DeviceSpec(target="simulon", connection=None, spec=device_dict)

prog = sf.Program(3)
with prog.context as q:
for reg in q:
measure_op | reg

with pytest.raises(
program.CircuitError, match=f"contains 3 {measure_name} measurements"
):
prog.assert_max_number_of_measurements(spec)

def test_assert_max_number_of_measurements_wrong_entry(self):
"""Check that the correct error is raised when calling `prog.assert_number_of_measurements`
with the incorrect type of device spec mode entry."""
device_dict = {"modes": 2, "layout": None, "gate_parameters": None, "compiler": [None]}
spec = sf.api.DeviceSpec(target="simulon", connection=None, spec=device_dict)

prog = sf.Program(3)
with prog.context as q:
ops.S2gate(0.6) | [q[0], q[1]]
ops.S2gate(0.6) | [q[1], q[2]]

with pytest.raises(KeyError, match="Have you specified the correct target?"):
prog.assert_max_number_of_measurements(spec)

class TestRegRefs:
"""Testing register references."""
Expand Down Expand Up @@ -381,11 +444,13 @@ def test_disconnected_circuit(self):
with pytest.warns(UserWarning, match='The circuit consists of 2 disconnected components.'):
new_prog = prog.compile(compiler='fock')

# TODO: move this test into an integration tests folder (a similar test for the
# `prog.assert_number_of_modes` method can be found above), under `test_assert_number_of_modes`.
def test_incorrect_modes(self):
"""Test that an exception is raised if the compiler
is called with a device spec with an incorrect number of modes"""

class DummyCircuit(Compiler):
class DummyCompiler(Compiler):
"""A circuit with 2 modes"""
interactive = True
primitives = {'S2gate', 'Interferometer'}
Expand All @@ -400,7 +465,52 @@ class DummyCircuit(Compiler):
ops.S2gate(0.6) | [q[1], q[2]]

with pytest.raises(program.CircuitError, match="program contains 3 modes, but the device 'None' only supports a 2-mode program"):
new_prog = prog.compile(device=spec, compiler=DummyCircuit())
new_prog = prog.compile(device=spec, compiler=DummyCompiler())

# TODO: move this test into an integration tests folder (a similar test for the
# `prog.assert_number_of_measurements` method can be found above), named `test_assert_number_of_measurements`.
@pytest.mark.parametrize(
"measure_op, measure_name", [
(ops.MeasureFock(), "fock"), # MeasureFock
(ops.MeasureHomodyne(phi=0), "homodyne"), # MeasureX
(ops.MeasureHomodyne(phi=42), "homodyne"), # MeasureHomodyne
(ops.MeasureHomodyne(phi=np.pi/2), "homodyne"), # MeasureP
(ops.MeasureHeterodyne(), "heterodyne"), # MeasureHD
(ops.MeasureHeterodyne(select=0), "heterodyne"), # MeasureHeterodyne
],
)
def test_incorrect_number_of_measurements(self, measure_op, measure_name):
"""Test that an exception is raised if the compiler is called with a
device spec with an incorrect number of measurements"""

class DummyCompiler(Compiler):
"""A circuit with 2 modes"""
interactive = True
primitives = {'MeasureHomodyne', 'MeasureHeterodyne', 'MeasureFock'}
decompositions = set()

# set maximum number of measurements to 2, and measure 3 in prog below
device_dict = {
"modes": {
"max": {
"pnr": 2,
"homodyne": 2,
"heterodyne": 2
}
},
"layout": None, "gate_parameters": {}, "compiler": [None]
}
spec = sf.api.DeviceSpec(target="simulon", connection=None, spec=device_dict)

prog = sf.Program(3)
with prog.context as q:
for reg in q:
measure_op | reg

with pytest.raises(
program.CircuitError, match=f"contains 3 {measure_name} measurements"
):
prog.compile(device=spec, compiler=DummyCompiler())

def test_no_default_compiler(self):
"""Test that an exception is raised if the DeviceSpec has no compilers
Expand Down

0 comments on commit 50ae64a

Please sign in to comment.