Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in select for single string antenna_name, flatten higher-dimensional arrays in select #390

Merged
merged 3 commits into from
Jul 6, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyuvdata/tests/test_uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,8 @@ def test_select_blts():
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(-10, -5))
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(
uv_object.Nblts + 1, uv_object.Nblts + 10))
nt.assert_raises(AssertionError, uv_object.select,
blt_inds=np.reshape(np.arange(-10, -5), (1, 5)))


def test_select_antennas():
Expand Down Expand Up @@ -458,6 +460,7 @@ def test_select_antennas():
nt.assert_raises(ValueError, uv_object.select, antenna_names='test1')
nt.assert_raises(ValueError, uv_object.select,
antenna_nums=ants_to_keep, antenna_names=ant_names)
nt.assert_raises(AssertionError, uv_object.select, antenna_names=[['test1']])


def sort_bl(p):
Expand Down Expand Up @@ -614,6 +617,7 @@ def test_select_times():
# check for errors associated with times not included in data
nt.assert_raises(ValueError, uv_object.select, times=[
np.min(unique_times) - uv_object.integration_time])
nt.assert_raises(AssertionError, uv_object.select, times=times_to_keep[np.newaxis, :])


def test_select_frequencies():
Expand Down Expand Up @@ -668,6 +672,7 @@ def test_select_frequencies():
message='Selected frequencies are not contiguous')
nt.assert_raises(ValueError, uv_object2.write_uvfits, write_file_uvfits)
nt.assert_raises(ValueError, uv_object2.write_miriad, write_file_miriad)
nt.assert_raises(AssertionError, uv_object.select, times=freqs_to_keep[np.newaxis, :])


def test_select_freq_chans():
Expand Down Expand Up @@ -706,6 +711,8 @@ def test_select_freq_chans():
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in uv_object.freq_array[0, all_chans_to_keep])

nt.assert_raises(AssertionError, uv_object.select, freq_chans=chans_to_keep[np.newaxis, :])


def test_select_polarizations():
uv_object = UVData()
Expand Down Expand Up @@ -737,6 +744,7 @@ def test_select_polarizations():
message='Selected polarization values are not evenly spaced')
write_file_uvfits = os.path.join(DATA_PATH, 'test/select_test.uvfits')
nt.assert_raises(ValueError, uv_object.write_uvfits, write_file_uvfits)
nt.assert_raises(AssertionError, uv_object.select, polarizations=[pols_to_keep])


def test_select():
Expand Down
12 changes: 10 additions & 2 deletions pyuvdata/uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,7 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
# test for blt_inds presence before adding inds from antennas & times
if blt_inds is not None:
blt_inds = uvutils.get_iterable(blt_inds)
assert(np.array(blt_inds).ndim == 1)
history_update_string += 'baseline-times'
n_selects += 1

Expand All @@ -1025,17 +1026,20 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
raise ValueError(
'Only one of antenna_nums and antenna_names can be provided.')

antenna_names = uvutils.get_iterable(antenna_names)
if not isinstance(antenna_names, (list, tuple, np.ndarray)):
antenna_names = (antenna_names,)
assert(np.array(antenna_names).ndim == 1)
antenna_nums = []
for s in antenna_names:
if s not in self.antenna_names:
raise ValueError(
'Antenna name {a} is not present in the antenna_names array'.format(a=s))
antenna_nums.append(self.antenna_numbers[np.where(
np.array(self.antenna_names) == s)[0]])
np.array(self.antenna_names) == s)][0])

if antenna_nums is not None:
antenna_nums = uvutils.get_iterable(antenna_nums)
assert(np.array(antenna_nums).ndim == 1)
if n_selects > 0:
history_update_string += ', antennas'
else:
Expand Down Expand Up @@ -1128,6 +1132,7 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if times is not None:
times = uvutils.get_iterable(times)
assert(np.array(times).ndim == 1)
if n_selects > 0:
history_update_string += ', times'
else:
Expand Down Expand Up @@ -1165,6 +1170,7 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if freq_chans is not None:
freq_chans = uvutils.get_iterable(freq_chans)
assert(np.array(freq_chans).ndim == 1)
if frequencies is None:
frequencies = self.freq_array[0, freq_chans]
else:
Expand All @@ -1174,6 +1180,7 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if frequencies is not None:
frequencies = uvutils.get_iterable(frequencies)
assert(np.array(frequencies).ndim == 1)
if n_selects > 0:
history_update_string += ', frequencies'
else:
Expand Down Expand Up @@ -1208,6 +1215,7 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if polarizations is not None:
polarizations = uvutils.get_iterable(polarizations)
assert(np.array(polarizations).ndim == 1)
if n_selects > 0:
history_update_string += ', polarizations'
else:
Expand Down