diff --git a/hera_pspec/pspecdata.py b/hera_pspec/pspecdata.py index afb50769..bb28fbd5 100644 --- a/hera_pspec/pspecdata.py +++ b/hera_pspec/pspecdata.py @@ -66,7 +66,7 @@ def __init__(self, dsets=[], wgts=[], dsets_std=None, labels=None, beam=None): # Store the input UVData objects if specified if len(dsets) > 0: - self.add(dsets, wgts,dsets_std=dsets_std, labels=labels) + self.add(dsets, wgts, dsets_std=dsets_std, labels=labels) # Store a primary beam self.primary_beam = beam @@ -96,7 +96,6 @@ def add(self, dsets, wgts, labels=None, dsets_std=None): standard deviations (real and imaginary) of data to add to the collection. If dsets is a dict, will assume dsets_std is a dict and if dsets is a list, will assume dsets_std is a list. - """ # Check for dicts and unpack into an ordered list if found if isinstance(dsets, dict): @@ -119,7 +118,6 @@ def add(self, dsets, wgts, labels=None, dsets_std=None): _dsets_std = [dsets_std[key] for key in labels] dsets_std = _dsets_std - # Unpack dsets and wgts dicts labels = dsets.keys() _dsets = [dsets[key] for key in labels] @@ -144,7 +142,6 @@ def add(self, dsets, wgts, labels=None, dsets_std=None): raise TypeError("dsets, dsets_std, and wgts must be UVData" "or lists of UVData") - # Make sure enough weights were specified assert(len(dsets) == len(wgts)) assert(len(dsets_std) == len(dsets)) @@ -161,16 +158,28 @@ def add(self, dsets, wgts, labels=None, dsets_std=None): raise TypeError("Only UVData objects (or None) can be used as " "error sets") + # Store labels (if they were set) + if self.labels is None: + self.labels = [] + if labels is None: + labels = ["dset{:d}".format(i) for i in range(len(self.dsets), len(dsets)+len(self.dsets))] + self.labels += labels + # Append to list self.dsets += dsets self.wgts += wgts self.dsets_std += dsets_std - # Store labels (if they were set) - if labels is None: - self.labels = [None for d in dsets] - else: - self.labels += labels + # Check for repeated labels, and make them unique + for i, l in enumerate(self.labels): + ext = 1 + while True: + if l in self.labels[:i]: + l = self.labels[i] + "_{:d}".format(ext) + ext += 1 + else: + self.labels[i] = l + break # Store no. frequencies and no. times self.Nfreqs = self.dsets[0].Nfreqs @@ -2611,9 +2620,6 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, if dset_labels is None: dset_labels = ["dset{}".format(i) for i in range(Ndsets)] - else: - # enforce unique dset labels - assert len(set(dset_labels)) == len(dset_labels), "Found repeated dest labels: each one must be unique" # load data if fed as filepaths if isinstance(dsets[0], (str, np.str)): diff --git a/hera_pspec/tests/test_pspecdata.py b/hera_pspec/tests/test_pspecdata.py index 67427a98..13d047b2 100644 --- a/hera_pspec/tests/test_pspecdata.py +++ b/hera_pspec/tests/test_pspecdata.py @@ -159,6 +159,18 @@ def test_init(self): key = (0, (24, 25), 'xx') nt.assert_true(np.all(np.isclose(ds.x(key), ds.w(key)))) + # Test labels when adding dsets + uvd = self.uvd + ds = pspecdata.PSpecData() + nt.assert_equal(len(ds.labels), 0) + ds.add([uvd, uvd], [None, None]) + nt.assert_equal(len(ds.labels), 2) + ds.add(uvd, None, labels='foo') + nt.assert_equal(len(ds.dsets), len(ds.labels), 3) + nt.assert_equal(ds.labels, ['dset0', 'dset1', 'foo']) + ds.add(uvd, None) + nt.assert_equal(ds.labels, ['dset0', 'dset1', 'foo', 'dset3']) + # Test some exceptions ds = pspecdata.PSpecData() nt.assert_raises(ValueError, ds.get_G, key, key) @@ -1312,7 +1324,6 @@ def test_pspec_run(): nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", blpairs=(1, 2), verbose=False) nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", blpairs=[1, 2], verbose=False) nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", beam=1, verbose=False) - nt.assert_raises(AssertionError, pspecdata.pspec_run, fnames, "./out.hdf5", dset_labels=['hi', 'hi'], verbose=False) if os.path.exists("./out.hdf5"): os.remove("./out.hdf5") diff --git a/hera_pspec/uvpspec.py b/hera_pspec/uvpspec.py index f1c77b61..c6b10ccd 100644 --- a/hera_pspec/uvpspec.py +++ b/hera_pspec/uvpspec.py @@ -1535,14 +1535,12 @@ def combine_uvpspec(uvps, verbose=True): u.freq_array.extend(spw_freqs) u.dly_array.extend(spw_dlys) - # Convert to numpy arrays u.spw_array = np.array(u.spw_array) u.freq_array = np.array(u.freq_array) u.dly_array = np.array(u.dly_array) u.pol_array = np.array(new_pols) - # Number of spectral windows, delays etc. u.Nspws = Nspws u.Nblpairts = Nblpairts