Skip to content

Commit

Permalink
Fix bug found by new tests!
Browse files Browse the repository at this point in the history
  • Loading branch information
philbull committed Mar 5, 2019
1 parent e100f91 commit 4efe498
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
2 changes: 1 addition & 1 deletion hera_pspec/pspecdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ def scalar(self, polpair, little_h=True, num_steps=2000, beam=None,
"""
# make sure polarizations are the same
if isinstance(polpair, int):
polpair = uvutils.polpair_int2tuple(polpair)
polpair = uvputils.polpair_int2tuple(polpair)
if polpair[0] != polpair[1]:
raise NotImplementedError(
"Polarizations don't match. Beam scalar can only be "
Expand Down
18 changes: 17 additions & 1 deletion hera_pspec/tests/test_pspecdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def test_str(self):
print(ds) # print empty psd
ds.add(self.uvd, None)
print(ds) # print populated psd


def test_get_Q_alt(self):

Expand Down Expand Up @@ -485,7 +486,11 @@ def test_get_MW(self):
# Test that the norm matrix is diagonal
M, W = self.ds.get_MW(random_G, random_H, mode=mode)
self.assertEqual(diagonal_or_not(M), True)

elif mode == 'L^-1':
# Test that Cholesky mode is disabled
nt.assert_raises(NotImplementedError,
self.ds.get_MW, random_G, random_H, mode=mode)

# Test sizes for everyone
self.assertEqual(M.shape, (n,n))
self.assertEqual(W.shape, (n,n))
Expand Down Expand Up @@ -807,7 +812,9 @@ def test_scalar(self):

# Check normal execution
scalar = self.ds.scalar(('xx','xx'))
scalar = self.ds.scalar(1515) # polpair-integer = ('xx', 'xx')
scalar = self.ds.scalar(('xx','xx'), taper_override='none')
scalar = self.ds.scalar(('xx','xx'), beam=gauss)
nt.assert_raises(NotImplementedError, self.ds.scalar, ('xx','yy'))

# Precomputed results in the following test were done "by hand"
Expand Down Expand Up @@ -861,6 +868,12 @@ def test_validate_datasets(self):

# test polarization
ds.validate_pol((0,1), ('xx', 'xx'))

# test channel widths
uvd2.channel_width *= 2.
ds2 = pspecdata.PSpecData(dsets=[uvd, uvd2], wgts=[None, None])
nt.assert_raises(ValueError, ds2.validate_datasets)


def test_rephase_to_dset(self):
# generate two uvd objects w/ different LST grids
Expand Down Expand Up @@ -961,6 +974,9 @@ def test_check_in_dset(self):
nt.assert_false(ds.check_key_in_dset((24, 26, 'yy'), 0))
# check exception
nt.assert_raises(KeyError, ds.check_key_in_dset, (1,2,3,4,5), 0)

# test dset_idx
nt.assert_raises(TypeError, ds.dset_idx, (1,2))

def test_pspec(self):
# generate ds
Expand Down

0 comments on commit 4efe498

Please sign in to comment.