Skip to content

Commit

Permalink
Test catstate high cutoff case (#557)
Browse files Browse the repository at this point in the history
* parametrize cat_state test

* parametrize odd cat_state test; add comment
  • Loading branch information
antalszava committed Mar 13, 2021
1 parent 6fd7aaa commit 5947864
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tests/frontend/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,16 @@ def test_fock_state(self):
state = utils.fock_state(n, fock_dim=cutoff)
assert np.all(state == np.eye(1, cutoff, n))

def test_even_cat_state(self, tol):
@pytest.mark.parametrize("a, cutoff", [(0.212, 10), (4, 50)])
def test_even_cat_state(self, a, cutoff, tol):
"""test correct even cat state returned"""
a = 0.212
cutoff = 10
p = 0

state = utils.cat_state(a, p, fock_dim=cutoff)

# For the analytic expression, cast the integer parameter to float so
# that there's no overflow
a = float(a)
n = np.arange(cutoff)
expected = np.exp(-0.5 * np.abs(a) ** 2) * a ** n / np.sqrt(fac(n)) + np.exp(
-0.5 * np.abs(-a) ** 2
Expand All @@ -234,14 +236,16 @@ def test_even_cat_state(self, tol):

assert np.allclose(state, expected, atol=tol, rtol=0)

def test_odd_cat_state(self, tol):
"""test correct even cat state returned"""
a = 0.212
cutoff = 10
@pytest.mark.parametrize("a, cutoff", [(0.212, 10), (4, 50)])
def test_odd_cat_state(self, a, cutoff, tol):
"""test correct odd cat state returned"""
p = 1

state = utils.cat_state(a, p, fock_dim=cutoff)

# For the analytic expression, cast the integer parameter to float so
# that there's no overflow
a = float(a)
n = np.arange(cutoff)
expected = np.exp(-0.5 * np.abs(a) ** 2) * a ** n / np.sqrt(fac(n)) - np.exp(
-0.5 * np.abs(-a) ** 2
Expand Down

0 comments on commit 5947864

Please sign in to comment.