Skip to content

Commit

Permalink
Fix bug in Dgate, Coherent, and DisplacedSqueezed (#507)
Browse files Browse the repository at this point in the history
* Support TF tensors in batch form

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Add test

* Add to changelog

* Update PR in changelog

* New line

* Extend to other gates

* Update changelog

* Update

* Update test

* Update .github/CHANGELOG.md

Co-authored-by: Josh Izaac <josh146@gmail.com>

* Fix typo

* Fix

Co-authored-by: Josh Izaac <josh146@gmail.com>
  • Loading branch information
trbromley and josh146 committed Dec 21, 2020
1 parent b75b398 commit c7173f9
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
7 changes: 6 additions & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@

<h3>Bug fixes</h3>

* Fixes a bug where `Dgate`, `Coherent`, and `DisplacedSqueezed` do not support TensorFlow tensors
if the tensor has an added dimension due to the existence of batching.
[(#507)](https://github.com/XanaduAI/strawberryfields/pull/507)

* Fixed issue with `reshape_samples` where the samples were sometimes
reshaped in the wrong way.
[(#489)](https://github.com/XanaduAI/strawberryfields/pull/489)
Expand All @@ -134,7 +138,8 @@

This release contains contributions from (in alphabetical order):

Jack Brown, Theodor Isacsson, Josh Izaac, Fabian Laudenbach, Nicolas Quesada, Antal Száva.
Tom Bromley, Jack Brown, Theodor Isacsson, Josh Izaac, Fabian Laudenbach, Nicolas Quesada,
Antal Száva.

# Release 0.16.0 (current release)

Expand Down
17 changes: 9 additions & 8 deletions strawberryfields/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,10 @@ def _apply(self, reg, backend, **kwargs):
r = par_evaluate(self.p[0])
phi = par_evaluate(self.p[1])

tf_complex = any(hasattr(arg, "numpy") and np.iscomplex(arg.numpy()) for arg in [r, phi])
np_args = [arg.numpy() if hasattr(arg, "numpy") else arg for arg in [r, phi]]
is_complex = any([np.iscomplexobj(np.real_if_close(arg)) for arg in np_args])

if (np.iscomplex([r, phi])).any() or tf_complex:
if is_complex:
raise ValueError("The arguments of Coherent(r, phi) cannot be complex")

backend.prepare_coherent_state(r, phi, *reg)
Expand Down Expand Up @@ -722,11 +723,10 @@ def __init__(self, r_d=0.0, phi_d=0.0, r_s=0.0, phi_s=0.0):
def _apply(self, reg, backend, **kwargs):
p = par_evaluate(self.p)

tf_complex = any(
hasattr(arg, "numpy") and np.iscomplex(arg.numpy()) for arg in [p[0], p[1], p[2], p[3]]
)
np_args = [arg.numpy() if hasattr(arg, "numpy") else arg for arg in p]
is_complex = any([np.iscomplexobj(np.real_if_close(arg)) for arg in np_args])

if (np.iscomplex([p[0], p[1], p[2], p[3]])).any() or tf_complex:
if is_complex:
raise ValueError(
"The arguments of DisplacedSqueezed(r_d, phi_d, r_s, phi_s) cannot be complex"
)
Expand Down Expand Up @@ -1337,9 +1337,10 @@ def __init__(self, r, phi=0.0):
def _apply(self, reg, backend, **kwargs):
r, phi = par_evaluate(self.p)

tf_complex = any(hasattr(arg, "numpy") and np.iscomplex(arg.numpy()) for arg in [r, phi])
np_args = [arg.numpy() if hasattr(arg, "numpy") else arg for arg in [r, phi]]
is_complex = any([np.iscomplexobj(np.real_if_close(arg)) for arg in np_args])

if (np.iscomplex([r, phi])).any() or tf_complex:
if is_complex:
raise ValueError("The arguments of Dgate(r, phi) cannot be complex")

backend.displacement(r, phi, *reg)
Expand Down
39 changes: 39 additions & 0 deletions tests/frontend/test_ops_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,42 @@ def test_merge_measured_pars():
# gates that have different p[1] parameters
with pytest.raises(MergeFailure, match="Don't know how to merge these gates."):
assert D.merge(G)


@pytest.mark.parametrize("gate", [ops.Dgate, ops.Coherent, ops.DisplacedSqueezed])
def test_tf_batch_in_gates_previously_supporting_complex(gate):
"""Test if gates that previously accepted complex arguments support the input of TF tensors in
batch form"""
tf = pytest.importorskip("tensorflow")

batch_size = 2
prog = Program(1)
eng = Engine(backend="tf", backend_options={"cutoff_dim": 3, "batch_size": batch_size})

theta = prog.params("theta")
_theta = tf.Variable([0.1] * batch_size)

with prog.context as q:
gate(theta) | q[0]

eng.run(prog, args={"theta": _theta})


@pytest.mark.parametrize("gate", [ops.Dgate, ops.Coherent, ops.DisplacedSqueezed])
def test_tf_batch_complex_raise(gate):
"""Test if an error is raised if complex TF tensors with a batch dimension are input for gates
that previously accepted complex arguments"""
tf = pytest.importorskip("tensorflow")

batch_size = 2
prog = Program(1)
eng = Engine(backend="tf", backend_options={"cutoff_dim": 3, "batch_size": batch_size})

theta = prog.params("theta")
_theta = tf.Variable([0.1j] * batch_size)

with prog.context as q:
gate(theta) | q[0]

with pytest.raises(ValueError, match="cannot be complex"):
eng.run(prog, args={"theta": _theta})

0 comments on commit c7173f9

Please sign in to comment.