Skip to content

Commit

Permalink
updated label handling in PSpecData.add()
Browse files Browse the repository at this point in the history
  • Loading branch information
nkern committed Jul 7, 2018
1 parent 084c1d8 commit 4751b92
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
30 changes: 18 additions & 12 deletions hera_pspec/pspecdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, dsets=[], wgts=[], dsets_std=None, labels=None, beam=None):

# Store the input UVData objects if specified
if len(dsets) > 0:
self.add(dsets, wgts,dsets_std=dsets_std, labels=labels)
self.add(dsets, wgts, dsets_std=dsets_std, labels=labels)

# Store a primary beam
self.primary_beam = beam
Expand Down Expand Up @@ -96,7 +96,6 @@ def add(self, dsets, wgts, labels=None, dsets_std=None):
standard deviations (real and imaginary) of data to add to the
collection. If dsets is a dict, will assume dsets_std is a dict
and if dsets is a list, will assume dsets_std is a list.
"""
# Check for dicts and unpack into an ordered list if found
if isinstance(dsets, dict):
Expand All @@ -119,7 +118,6 @@ def add(self, dsets, wgts, labels=None, dsets_std=None):
_dsets_std = [dsets_std[key] for key in labels]
dsets_std = _dsets_std


# Unpack dsets and wgts dicts
labels = dsets.keys()
_dsets = [dsets[key] for key in labels]
Expand All @@ -144,7 +142,6 @@ def add(self, dsets, wgts, labels=None, dsets_std=None):
raise TypeError("dsets, dsets_std, and wgts must be UVData"
"or lists of UVData")


# Make sure enough weights were specified
assert(len(dsets) == len(wgts))
assert(len(dsets_std) == len(dsets))
Expand All @@ -161,16 +158,28 @@ def add(self, dsets, wgts, labels=None, dsets_std=None):
raise TypeError("Only UVData objects (or None) can be used as "
"error sets")

# Store labels (if they were set)
if self.labels is None:
self.labels = []
if labels is None:
labels = ["dset{:d}".format(i) for i in range(len(self.dsets), len(dsets)+len(self.dsets))]
self.labels += labels

# Append to list
self.dsets += dsets
self.wgts += wgts
self.dsets_std += dsets_std

# Store labels (if they were set)
if labels is None:
self.labels = [None for d in dsets]
else:
self.labels += labels
# Check for repeated labels, and make them unique
for i, l in enumerate(self.labels):
ext = 1
while True:
if l in self.labels[:i]:
l = self.labels[i] + "_{:d}".format(ext)
ext += 1
else:
self.labels[i] = l
break

# Store no. frequencies and no. times
self.Nfreqs = self.dsets[0].Nfreqs
Expand Down Expand Up @@ -2611,9 +2620,6 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None,

if dset_labels is None:
dset_labels = ["dset{}".format(i) for i in range(Ndsets)]
else:
# enforce unique dset labels
assert len(set(dset_labels)) == len(dset_labels), "Found repeated dest labels: each one must be unique"

# load data if fed as filepaths
if isinstance(dsets[0], (str, np.str)):
Expand Down
13 changes: 12 additions & 1 deletion hera_pspec/tests/test_pspecdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ def test_init(self):
key = (0, (24, 25), 'xx')
nt.assert_true(np.all(np.isclose(ds.x(key), ds.w(key))))

# Test labels when adding dsets
uvd = self.uvd
ds = pspecdata.PSpecData()
nt.assert_equal(len(ds.labels), 0)
ds.add([uvd, uvd], [None, None])
nt.assert_equal(len(ds.labels), 2)
ds.add(uvd, None, labels='foo')
nt.assert_equal(len(ds.dsets), len(ds.labels), 3)
nt.assert_equal(ds.labels, ['dset0', 'dset1', 'foo'])
ds.add(uvd, None)
nt.assert_equal(ds.labels, ['dset0', 'dset1', 'foo', 'dset3'])

# Test some exceptions
ds = pspecdata.PSpecData()
nt.assert_raises(ValueError, ds.get_G, key, key)
Expand Down Expand Up @@ -1312,7 +1324,6 @@ def test_pspec_run():
nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", blpairs=(1, 2), verbose=False)
nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", blpairs=[1, 2], verbose=False)
nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", beam=1, verbose=False)
nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", dset_labels=['hi', 'hi'], verbose=False)

if os.path.exists("./out.hdf5"):
os.remove("./out.hdf5")
Expand Down
2 changes: 0 additions & 2 deletions hera_pspec/uvpspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,14 +1535,12 @@ def combine_uvpspec(uvps, verbose=True):
u.freq_array.extend(spw_freqs)
u.dly_array.extend(spw_dlys)


# Convert to numpy arrays
u.spw_array = np.array(u.spw_array)
u.freq_array = np.array(u.freq_array)
u.dly_array = np.array(u.dly_array)
u.pol_array = np.array(new_pols)


# Number of spectral windows, delays etc.
u.Nspws = Nspws
u.Nblpairts = Nblpairts
Expand Down

0 comments on commit 4751b92

Please sign in to comment.