Skip to content

Commit

Permalink
pspecdata test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nkern committed Jul 10, 2019
1 parent 41615b9 commit 1cbf5b9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 51 deletions.
41 changes: 18 additions & 23 deletions hera_pspec/pspecdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,34 +123,29 @@ def add(self, dsets, wgts, labels=None, dsets_std=None, cals=None, cal_flag=True
"specified.")
labels = list(dsets.keys())

if not isinstance(wgts, dict):
if wgts is None:
wgts = dict([(l, None) for l in labels])
elif not isinstance(wgts, dict):
raise TypeError("If 'dsets' is a dict, 'wgts' must also be "
"a dict")

if not isinstance(dsets_std, dict):
if dsets_std is None:
dsets_std = [None for m in range(len(dsets))]
else:
raise TypeError("If 'dsets' is a dict, 'dsets_std' must"
"also be a dict")
else:
_dsets_std = [dsets_std[key] for key in labels]
dsets_std = _dsets_std
if dsets_std is None:
dsets_std = dict([(l, None) for l in labels])
elif not isinstance(dsets_std, dict):
raise TypeError("If 'dsets' is a dict, 'dsets_std' must also be "
"a dict")

if not isinstance(cals, dict):
if cals is None:
cals = [None for m in range(len(dsets))]
else:
raise TypeError("If 'dsets' is a dict, 'cals' must"
"also be a dict")
if cals is None:
cals = dict([(l, None) for l in labels])
elif not isinstance(cals, dict):
raise TypeError("If 'cals' is a dict, 'cals' must also be "
"a dict")

# Unpack dsets and wgts dicts
_dsets = [dsets[key] for key in labels]
_wgts = [wgts[key] for key in labels]
_cals = [cals[key] for key in labels]
dsets = _dsets
wgts = _wgts
cals = _cals
dsets = [dsets[key] for key in labels]
dsets_std = [dsets_std[key] for key in labels]
wgts = [wgts[key] for key in labels]
cals = [cals[key] for key in labels]

# Convert input args to lists if possible
if isinstance(dsets, UVData): dsets = [dsets,]
Expand Down Expand Up @@ -203,7 +198,7 @@ def add(self, dsets, wgts, labels=None, dsets_std=None, cals=None, cal_flag=True
self.labels = []
if labels is None:
labels = ["dset{:d}".format(i)
for i in range(len(self.dsets), len(dsets)+len(self.dsets))]
for i in range(len(self.dsets), len(dsets) + len(self.dsets))]

# Apply calibration if provided
for dset, dset_std, cal in zip(dsets, dsets_std, cals):
Expand Down
68 changes: 40 additions & 28 deletions hera_pspec/tests/test_pspecdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,28 @@ def test_init(self):

def test_add_data(self):
"""
Test adding non UVData object.
Test PSpecData add()
"""
uv = self.d[0]
# test adding non list objects
nt.assert_raises(TypeError, self.ds.add, 1, 1)
#test TypeError if dsets is dict but dsets_std is not
nt.assert_raises(TypeError,self.ds.add,{'d':0},{'w':0},None,[0])
nt.assert_raises(TypeError,self.ds.add,{'d':0},{'w':0},None,{'e':0})
nt.assert_raises(TypeError,self.ds.add,{'d':0},[0],None,{'e':0})
# test adding non UVData objects
nt.assert_raises(TypeError, self.ds.add, [1], None)
nt.assert_raises(TypeError, self.ds.add, [uv], [1])
nt.assert_raises(TypeError, self.ds.add, [uv], None, dsets_std=[1])
# test adding non UVCal for cals
nt.assert_raises(TypeError, self.ds.add, [uv], None, cals=[1])
# test TypeError if dsets is dict but other inputs are not
nt.assert_raises(TypeError, self.ds.add, {'d':uv}, [0])
nt.assert_raises(TypeError, self.ds.add, {'d':uv}, {'d':uv}, dsets_std=[0])
nt.assert_raises(TypeError, self.ds.add, {'d':uv}, {'d':uv}, cals=[0])
# specifying labels when dsets is a dict is a ValueError
nt.assert_raises(ValueError, self.ds.add, {'d':uv}, None, labels=['d'])
# use lists, but not appropriate lengths
nt.assert_raises(AssertionError, self.ds.add, [uv], [uv, uv])
nt.assert_raises(AssertionError, self.ds.add, [uv], None, dsets_std=[uv, uv])
nt.assert_raises(AssertionError, self.ds.add, [uv], None, cals=[None, None])
nt.assert_raises(AssertionError, self.ds.add, [uv], None, labels=['foo', 'bar'])

def test_labels(self):
"""
Expand All @@ -208,11 +223,11 @@ def test_labels(self):
# Check specifying labels using dicts
dsdict = {'a':self.d[0], 'b':self.d[1]}
psd = pspecdata.PSpecData(dsets=dsdict, wgts=dsdict)
self.assertRaises(ValueError, pspecdata.PSpecData, dsets=dsdict,
nt.assert_raises(ValueError, pspecdata.PSpecData, dsets=dsdict,
wgts=dsdict, labels=['a', 'b'])

# Check that invalid labels raise errors
self.assertRaises(KeyError, psd.x, ('green', 24, 38))
nt.assert_raises(KeyError, psd.x, ('green', 24, 38))

def test_parse_blkey(self):
# make a double-pol UVData
Expand Down Expand Up @@ -241,10 +256,8 @@ def test_str(self):
print(ds) # print empty psd
ds.add(self.uvd, None)
print(ds) # print populated psd


def test_get_Q_alt(self):

"""
Test the Q = dC/dp function.
"""
Expand Down Expand Up @@ -534,7 +547,6 @@ def test_get_unnormed_V(self):
for j in range(self.ds.spw_Ndlys):
self.assertLessEqual(frac_non_herm[i,j], tol)


def test_get_MW(self):
n = 17
random_G = generate_pos_def_all_pos(n)
Expand Down Expand Up @@ -905,7 +917,6 @@ def test_scalar_delay_adjustment(self):
adjustment = self.ds.scalar_delay_adjustment(key1, key2, sampling=True)
self.assertAlmostEqual(adjustment, 1.0)


def test_scalar(self):
self.ds = pspecdata.PSpecData(dsets=self.d, wgts=self.w, beam=self.bm)

Expand Down Expand Up @@ -1456,6 +1467,7 @@ def test_validate_blpairs(self):
blpairs = [((24, 25), (24, 38))]
pspecdata.validate_blpairs(blpairs, uvd, uvd)


def test_pspec_run():
fnames = [os.path.join(DATA_PATH, d)
for d in ['zen.even.xx.LST.1.28828.uvOCRSA',
Expand Down Expand Up @@ -1483,23 +1495,23 @@ def test_pspec_run():
cosmo = conversions.Cosmo_Conversions(Om_L=0.0)
if os.path.exists("./out.h5"):
os.remove("./out.h5")
psc, ds = pspecdata.pspec_run(fnames, "./out.h5",
dsets_std=fnames_std,
Jy2mK=True,
beam=beamfile,
blpairs=[((37, 38), (37, 38)),
((37, 38), (52, 53))],
verbose=False,
overwrite=True,
pol_pairs=[('xx', 'xx'), ('xx', 'xx')],
dset_labels=["foo", "bar"],
dset_pairs=[(0, 0), (0, 1)],
spw_ranges=[(50, 75), (120, 140)],
n_dlys=[20, 20],
cosmo=cosmo,
trim_dset_lsts=False,
broadcast_dset_flags=False,
store_cov=True)
ds = pspecdata.pspec_run(fnames, "./out.h5",
dsets_std=fnames_std,
Jy2mK=True,
beam=beamfile,
blpairs=[((37, 38), (37, 38)),
((37, 38), (52, 53))],
verbose=False,
overwrite=True,
pol_pairs=[('xx', 'xx'), ('xx', 'xx')],
dset_labels=["foo", "bar"],
dset_pairs=[(0, 0), (0, 1)],
spw_ranges=[(50, 75), (120, 140)],
n_dlys=[20, 20],
cosmo=cosmo,
trim_dset_lsts=False,
broadcast_dset_flags=False,
store_cov=True)

# assert groupname is dset1_dset2
psc = container.PSpecContainer('./out.h5')
Expand Down

0 comments on commit 1cbf5b9

Please sign in to comment.