Skip to content

Commit

Permalink
Fix cat state function array type without doing casting (#556)
Browse files Browse the repository at this point in the history
* Fixed cat state function array type (#554)

* Fixed cat state function array type (#554)

* Update ops.py

* Update states.py

* multiply by 1.0

* Apply suggestions from code review

Co-authored-by: antalszava <antalszava@gmail.com>

Co-authored-by: tguillaume <tguillaume506@gmail.com>
Co-authored-by: antalszava <antalszava@gmail.com>
  • Loading branch information
3 people committed Mar 12, 2021
1 parent d8528d4 commit 6fd7aaa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions strawberryfields/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,8 +867,9 @@ def _apply(self, reg, backend, **kwargs):
N = temp / pf.sqrt(2 * (1 + pf.cos(phi) * temp ** 4))

# coherent states
c1 = (alpha ** l) / np.sqrt(ssp.factorial(l))
c2 = ((-alpha) ** l) / np.sqrt(ssp.factorial(l))
# Need to cast alpha to float before exponentiation to avoid overflow
c1 = ((1.0 * alpha) ** l) / np.sqrt(ssp.factorial(l))
c2 = ((-1.0 * alpha) ** l) / np.sqrt(ssp.factorial(l))
# add them up with a relative phase
ket = (c1 + pf.exp(1j * phi) * c2) * N

Expand Down
5 changes: 3 additions & 2 deletions strawberryfields/utils/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,9 @@ def cat_state(a, p=0, fock_dim=5):

# coherent states
k = np.arange(fock_dim)
c1 = (a ** k) / np.sqrt(fac(k))
c2 = ((-a) ** k) / np.sqrt(fac(k))
# Need to cast a to float before exponentiation to avoid overflow
c1 = ((1.0 * a) ** k) / np.sqrt(fac(k))
c2 = ((-1.0 * a) ** k) / np.sqrt(fac(k))

# add them up with a relative phase
ket = (c1 + np.exp(1j * phi) * c2) * N
Expand Down

0 comments on commit 6fd7aaa

Please sign in to comment.