Skip to content

Commit

Permalink
Fix validation in TdmProgram.compile (#605)
Browse files Browse the repository at this point in the history
* fix tdm compile

* changelog

* whitespace

* whitespace

* implement codeFactor's suggestion

* suggestion from codefactor

* simplify getting parameters

* add test

* remove whitespace

Co-authored-by: Theodor Isacsson <theodor@xanadu.ai>
  • Loading branch information
lneuhaus and thisac committed Jul 20, 2021
1 parent afde299 commit ff2c8ff
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@
by storing them as `blackbird.RegRefTransforms` in the resulting Blackbird program.
[(#596)](https://github.com/XanaduAI/strawberryfields/pull/596)

* Fixed a bug in the validation step of `strawberryfields.tdm.TdmProgram.compile` which almost always
used the wrong set of allowed gate parameter ranges to validate the parameters in a program.
[(#605)](https://github.com/XanaduAI/strawberryfields/pull/605)

<h3>Documentation</h3>

Expand Down
7 changes: 2 additions & 5 deletions strawberryfields/tdm/tdmprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,6 @@ def compile(self, *, device=None, compiler=None):
"due to incompatible parameter.".format(device.target)
)
# Now we will check explicitly if the parameters in the program match
num_symbolic_param = 0 # counts the number of symbolic variables, which are labelled consecutively by the context method

for k, param_name in enumerate(param_names):
# Obtain the value of the corresponding parameter in the program
program_param = self.rolled_circuit[i].op.p[k]
Expand All @@ -509,7 +507,7 @@ def compile(self, *, device=None, compiler=None):
param_range = device.gate_parameters[param_name]
if sf.parameters.par_is_symbolic(program_param):
# If it is a symbolic value go and lookup its corresponding list in self.tdm_params
local_p_vals = self.tdm_params[num_symbolic_param]
local_p_vals = self.parameters.get(program_param.name, [])

for x in local_p_vals:
if not x in param_range:
Expand All @@ -520,7 +518,6 @@ def compile(self, *, device=None, compiler=None):
device.target, x, param_range
)
)
num_symbolic_param += 1

else:
# If it is a numerical value check directly
Expand Down Expand Up @@ -580,7 +577,7 @@ def unroll(self, shots):
q = self.register

sm = []
for i in range(len(self.N)):
for i, _ in enumerate(self.N):
start = sum(self.N[:i])
stop = sum(self.N[:i]) + self.N[i]
sm.append(slice(start, stop))
Expand Down
84 changes: 75 additions & 9 deletions tests/frontend/test_tdmprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import copy
from collections.abc import Iterable

import inspect
from strawberryfields.program_utils import CircuitError
import pytest
import numpy as np

Expand Down Expand Up @@ -555,7 +557,7 @@ def test_tdm_wrong_layout(self):
ops.MeasureHomodyne(p[2]) | q[0]
eng = sf.Engine("gaussian")
with pytest.raises(
sf.program_utils.CircuitError,
CircuitError,
match="The gates or the order of gates used in the Program",
):
prog.compile(device=device, compiler="TD2")
Expand All @@ -575,7 +577,7 @@ def test_tdm_wrong_modes(self):
ops.MeasureHomodyne(p[2]) | q[0]
eng = sf.Engine("gaussian")
with pytest.raises(
sf.program_utils.CircuitError, match="due to incompatible mode ordering."
CircuitError, match="due to incompatible mode ordering."
):
prog.compile(device=device, compiler="TD2")

Expand All @@ -587,7 +589,7 @@ def test_tdm_wrong_parameters_explicit(self):
phi = [0, np.pi / 2] * c
theta = [0, 0] + [np.pi / 2, np.pi / 2]
prog = singleloop_program(sq_r, alpha, phi, theta)
with pytest.raises(sf.program_utils.CircuitError, match="due to incompatible parameter."):
with pytest.raises(CircuitError, match="due to incompatible parameter."):
prog.compile(device=device, compiler="TD2")

def test_tdm_wrong_parameters_explicit_in_list(self):
Expand All @@ -601,7 +603,7 @@ def test_tdm_wrong_parameters_explicit_in_list(self):
phi = [0, np.pi / 2] * c
theta = [0, 0] + [np.pi / 2, np.pi / 2]
prog = singleloop_program(sq_r, alpha, phi, theta)
with pytest.raises(sf.program_utils.CircuitError, match="due to incompatible parameter."):
with pytest.raises(CircuitError, match="due to incompatible parameter."):
prog.compile(device=device, compiler="TD2")

def test_tdm_wrong_parameter_second_argument(self):
Expand All @@ -620,7 +622,7 @@ def test_tdm_wrong_parameter_second_argument(self):
ops.Rgate(p[1]) | q[1]
ops.MeasureHomodyne(p[2]) | q[0]
eng = sf.Engine("gaussian")
with pytest.raises(sf.program_utils.CircuitError, match="due to incompatible parameter."):
with pytest.raises(CircuitError, match="due to incompatible parameter."):
prog.compile(device=device, compiler="TD2")

def test_tdm_wrong_parameters_symbolic(self):
Expand All @@ -631,7 +633,7 @@ def test_tdm_wrong_parameters_symbolic(self):
phi = [0, np.pi / 2] * c
theta = [0, 0] + [np.pi / 2, np.pi / 2]
prog = singleloop_program(sq_r, alpha, phi, theta)
with pytest.raises(sf.program_utils.CircuitError, match="due to incompatible parameter."):
with pytest.raises(CircuitError, match="due to incompatible parameter."):
prog.compile(device=device, compiler="TD2")

def test_tdm_inconsistent_temporal_modes(self):
Expand All @@ -642,7 +644,7 @@ def test_tdm_inconsistent_temporal_modes(self):
phi = [0, np.pi / 2] * c
theta = [0, 0] * c
prog = singleloop_program(sq_r, alpha, phi, theta)
with pytest.raises(sf.program_utils.CircuitError, match="temporal modes, but the device"):
with pytest.raises(CircuitError, match="temporal modes, but the device"):
prog.compile(device=device, compiler="TD2")

def test_tdm_inconsistent_concurrent_modes(self):
Expand All @@ -658,7 +660,7 @@ def test_tdm_inconsistent_concurrent_modes(self):
phi = [0, np.pi / 2] * c
theta = [0, 0] * c
prog = singleloop_program(sq_r, alpha, phi, theta)
with pytest.raises(sf.program_utils.CircuitError, match="concurrent modes, but the device"):
with pytest.raises(CircuitError, match="concurrent modes, but the device"):
prog.compile(device=device1, compiler="TD2")

def test_tdm_inconsistent_spatial_modes(self):
Expand All @@ -674,7 +676,7 @@ def test_tdm_inconsistent_spatial_modes(self):
phi = [0, np.pi / 2] * c
theta = [0, 0] * c
prog = singleloop_program(sq_r, alpha, phi, theta)
with pytest.raises(sf.program_utils.CircuitError, match="spatial modes, but the device"):
with pytest.raises(CircuitError, match="spatial modes, but the device"):
prog.compile(device=device1, compiler="TD2")

class TestTDMProgramFunctions:
Expand Down Expand Up @@ -756,3 +758,67 @@ def test_shots_passed(self):
results = eng.run(prog, shots=2)
assert results.samples.shape[0] == 2
assert prog.run_options["shots"] == 5


class TestTDMValidation:
"""Test the validation of TDMProgram against the device specs"""
@pytest.fixture(scope="class")
def device(self):
target = "TD2"
tm = 4
layout = f"""
name template_tdm
version 1.0
target {target} (shots=1)
type tdm (temporal_modes=2)
float array p0[1, {tm}] =
{{rs_array}}
float array p1[1, {tm}] =
{{r_array}}
float array p2[1, {tm}] =
{{bs_array}}
float array p3[1, {tm}] =
{{m_array}}
Sgate(p0) | 1
Rgate(p1) | 0
BSgate(p2, 0) | (0, 1)
MeasureHomodyne(p3) | 0
"""
device_spec = {
"layout": inspect.cleandoc(layout),
"modes": {"concurrent": 2, "spatial": 1, "temporal_max": 100},
"compiler": [target],
"gate_parameters": {
"p0": [-1],
"p1": [1],
"p2": [2],
"p3": [3],
},
}
return DeviceSpec("TD2", device_spec, connection=None)

@staticmethod
def compile_test_program(device, args=(-1, 1, 2, 3)):
"""Compiles a test program with the given gate arguments."""
alpha = [args[1]]
beta = [args[2]]
gamma = [args[3]]
prog = tdmprogram.TDMProgram(N=2)
with prog.context(alpha, beta, gamma) as (p, q):
ops.Sgate(args[0]) | q[1] # Note that the Sgate has a second parameter that is non-zero
ops.Rgate(p[0]) | q[0]
ops.BSgate(p[1]) | (q[0], q[1])
ops.MeasureHomodyne(p[2]) | q[0]
prog.compile(device=device, compiler=device.compiler)

def test_validation_correct_args(self, device):
"""Test that no error is raised when the tdm circuit explicit parameters within the allowed ranges"""
self.compile_test_program(device, args=(-1, 1, 2, 3))

@pytest.mark.parametrize("incorrect_index", list(range(4)))
def test_validation_incorrect_args(self, device, incorrect_index):
"""Test the correct error is raised when the tdm circuit explicit parameters are not within the allowed ranges"""
args = [-1, 1, 2, 3]
args[incorrect_index] = -999
with pytest.raises(CircuitError, match="Parameter has value '-999' while its valid range is "):
self.compile_test_program(device, args=args)

0 comments on commit ff2c8ff

Please sign in to comment.