Skip to content

Commit

Permalink
Tf engine dtype (#562)
Browse files Browse the repository at this point in the history
* first commit

* dtype as kwarg start

* fix formatting

* fix tests

* proof of concept line change

* ready for testing

* passing tests

* formatted

* starting docs changes

* lint fix

* minor cleanup

* resolve lint

* fixes based on review and cleanup

* undo cleanup for tests

* more undo for tests

* test parameterization
  • Loading branch information
Aaron-Robertson committed Apr 6, 2021
1 parent ebcca23 commit a0f8b5b
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 126 deletions.
21 changes: 21 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,29 @@
plt.contourf(xvec, pvec, wigner)
plt.show()
```

* The `fock` backend now supports the `GKP` preparation [(#553)](https://github.com/XanaduAI/strawberryfields/pull/553)

* The `tf` backend now accepts the Tensor DType as argument.
[(#562)](https://github.com/XanaduAI/strawberryfields/pull/562)

Allows high cutoff dimension to give numerically correct calculations:

```python
prog = sf.Program(2)
eng = sf.Engine("tf", backend_options={"cutoff_dim": 50, "dtype": tf.complex128})
with prog.context as q:
Sgate(0.8) | q[0]
Sgate(0.8) | q[1]
BSgate(0.5,0.5) | (q[0], q[1])
BSgate(0.5,0.5) | (q[0], q[1])
state = eng.run(prog).state
N0, N0var = state.mean_photon(0)
N1, N1var = state.mean_photon(1)
print(N0)
print(N1)
print("analytical:" ,np.sinh(0.8)**2)
```

<h3>Breaking Changes</h3>

Expand Down
3 changes: 2 additions & 1 deletion strawberryfields/backends/tfbackend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
* the number of modes :math:`n` actively being simulated
* the cutoff dimension :math:`D` for the Fock basis
* whether the circuit is operating in *batched mode* (with batch size :math:`B`)
* the Tensor DType passed to the circuit
When not operating in batched mode, the state tensor corresponds to a single multimode quantum system. If the
representation is a pure state, the state tensor has shape
Expand Down Expand Up @@ -127,6 +128,7 @@
mean_photon
batched
cutoff_dim
dtype
Code details
Expand Down Expand Up @@ -169,4 +171,3 @@ def excepthook(type, value, traceback):


from .backend import TFBackend
from .ops import def_type as tf_complex_type
16 changes: 14 additions & 2 deletions strawberryfields/backends/tfbackend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,13 @@ def begin_circuit(self, num_subsystems, **kwargs):
For each mode, the simulator can represent the Fock states :math:`\ket{0}, \ket{1}, \ldots, \ket{\text{cutoff_dim}-1}`.
pure (bool): If True (default), use a pure state representation (otherwise will use a mixed state representation).
batch_size (None or int): Size of the batch-axis dimension. If None, no batch-axis will be used.
dtype (tf.DType): Complex Tensorflow Tensor type representation, either ``tf.complex64`` (default) or ``tf.complex128``.
Note, ``tf.complex128`` will increase memory usage substantially.
"""
cutoff_dim = kwargs.get("cutoff_dim", None)
pure = kwargs.get("pure", True)
batch_size = kwargs.get("batch_size", None)
dtype = kwargs.get("dtype", tf.complex64)

if cutoff_dim is None:
raise ValueError("Argument 'cutoff_dim' must be passed to the TensorFlow backend")
Expand All @@ -96,14 +99,16 @@ def begin_circuit(self, num_subsystems, **kwargs):
raise ValueError("Argument 'cutoff_dim' must be a positive integer")
if not isinstance(pure, bool):
raise ValueError("Argument 'pure' must be either True or False")
if not dtype in (tf.complex64, tf.complex128):
raise ValueError("Argument 'dtype' must be a complex Tensorflow DType")
if batch_size == 1:
raise ValueError(
"batch_size of 1 not supported, please use different batch_size or set batch_size=None"
)

with tf.name_scope("Begin_circuit"):
self._modemap = ModeMap(num_subsystems)
circuit = Circuit(num_subsystems, cutoff_dim, pure, batch_size)
circuit = Circuit(num_subsystems, cutoff_dim, pure, batch_size, dtype)

self._init_modes = num_subsystems
self.circuit = circuit
Expand Down Expand Up @@ -231,6 +236,7 @@ def state(self, modes=None, **kwargs):
pure = self.circuit.state_is_pure
num_modes = self.circuit.num_modes
batched = self.circuit.batched
dtype = self.circuit.dtype

# reduce rho down to specified subsystems
if modes is None:
Expand Down Expand Up @@ -271,7 +277,13 @@ def state(self, modes=None, **kwargs):

modenames = ["q[{}]".format(i) for i in np.array(self.get_modes())[modes]]
state_ = FockStateTF(
s, len(modes), pure, self.circuit.cutoff_dim, batched=batched, mode_names=modenames
s,
len(modes),
pure,
self.circuit.cutoff_dim,
batched=batched,
mode_names=modenames,
dtype=dtype,
)
return state_

Expand Down

0 comments on commit a0f8b5b

Please sign in to comment.