Skip to content

Commit

Permalink
Symbolic parameter support in TDM compiler (#625)
Browse files Browse the repository at this point in the history
* fix

* update strictness

* switch empty dict convention

* add test
  • Loading branch information
thisac committed Sep 22, 2021
1 parent 4254e59 commit c80cb33
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
15 changes: 11 additions & 4 deletions strawberryfields/tdm/tdmprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def reshape_samples(all_samples, modes, N, timebins):
idx_tracker = {i: 0 for i in mode_order}

# iterate backwards through all_samples and add them into the correct mode
new_samples = dict()
new_samples = {}
timebin_idx = 0
for i, mode in enumerate(mode_order):
mode_idx = modes[i % len(N)]
Expand Down Expand Up @@ -513,7 +513,7 @@ def compile(self, *, device=None, compiler=None):
program_param = self.rolled_circuit[i].op.p[k]

# make sure that hardcoded parameters in the device layout are correct
if not isinstance(param_name, str):
if not isinstance(param_name, str) and not par_is_symbolic(param_name):
if not program_param == param_name:
raise CircuitError(
"Program cannot be used with the device '{}' "
Expand All @@ -525,7 +525,14 @@ def compile(self, *, device=None, compiler=None):
continue

# Obtain the relevant parameter range from the device
param_range = device.gate_parameters[param_name]
param_range = device.gate_parameters.get(str(param_name))
if param_range is None:
raise CircuitError(
"Program cannot be used with the device '{}' "
"due to parameter '{}' not found in device specification.".format(
device.target, param_name
)
)
if par_is_symbolic(program_param):
# If it is a symbolic value go and lookup its corresponding list in self.tdm_params
local_p_vals = self.parameters.get(program_param.name, [])
Expand Down Expand Up @@ -677,7 +684,7 @@ def _unroll_program(self, shots):
for _ in range(shots):
# save previous mode index of a command to be able to check when modes
# are looped back to the start (not allowed when space-unrolling)
previous_mode_index = dict()
previous_mode_index = {}

for cmd in self.rolled_circuit:
previous_mode_index[cmd] = 0
Expand Down
14 changes: 14 additions & 0 deletions tests/frontend/test_tdmprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,20 @@ def test_tdm_wrong_parameters_symbolic(self):
with pytest.raises(CircuitError, match="due to incompatible parameter."):
prog.compile(device=device, compiler="TD2")

def test_tdm_parameters_not_in_devicespec(self):
"""Test the correct error is raised when the tdm circuit symbolic parameters are not found
in the device specification"""
spec = copy.deepcopy(device_spec)
# "p1" removed from device spec, but is still used in layout
del spec["gate_parameters"]["p1"]

c = 2
prog = singleloop_program(
0.5643, [np.pi / 4, 0] * c, [0, np.pi / 2] * c, [0, 0, np.pi / 2, np.pi / 2]
)
with pytest.raises(CircuitError, match="not found in device specification"):
prog.compile(device=DeviceSpec("TDM", spec, connection=None), compiler="TDM")

def test_tdm_inconsistent_temporal_modes(self):
"""Test the correct error is raised when the tdm circuit has too many temporal modes"""
sq_r = 0.5643
Expand Down

0 comments on commit c80cb33

Please sign in to comment.