Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix bug in select for single string antenna_name, flatten higher-dimensional arrays in select #390

Merged
merged 3 commits into from
Jul 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
94 changes: 94 additions & 0 deletions pyuvdata/tests/test_uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,16 @@ def test_select_blts():

nt.assert_true(np.all(selected_data == uv_object2.data_array))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(blt_inds=blt_inds[np.newaxis, :])
nt.assert_equal(len(blt_inds), uv_object2.Nblts)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific baseline-times using pyuvdata.',
uv_object2.history))
nt.assert_true(np.all(selected_data == uv_object2.data_array))

# check for errors associated with out of bounds indices
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(-10, -5))
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(
Expand All @@ -428,6 +438,22 @@ def test_select_antennas():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(antenna_nums=ants_to_keep)

nt.assert_equal(len(ants_to_keep), uv_object2.Nants_data)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for ant in ants_to_keep:
nt.assert_true(
ant in uv_object2.ant_1_array or ant in uv_object2.ant_2_array)
for ant in np.unique(uv_object2.ant_1_array.tolist() + uv_object2.ant_2_array.tolist()):
nt.assert_true(ant in ants_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific antennas using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(antenna_nums=ants_to_keep[np.newaxis, :])

nt.assert_equal(len(ants_to_keep), uv_object2.Nants_data)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for ant in ants_to_keep:
Expand All @@ -452,6 +478,18 @@ def test_select_antennas():

nt.assert_equal(uv_object2, uv_object3)

# check that it also works with higher dimension array
uv_object3 = copy.deepcopy(uv_object)
ants_to_keep = np.array(sorted(list(ants_to_keep)))
ant_names = []
for a in ants_to_keep:
ind = np.where(uv_object3.antenna_numbers == a)[0][0]
ant_names.append(uv_object3.antenna_names[ind])

uv_object3.select(antenna_names=[ant_names])

nt.assert_equal(uv_object2, uv_object3)

# check for errors associated with antennas not included in data, bad names or providing numbers and names
nt.assert_raises(ValueError, uv_object.select,
antenna_nums=np.max(unique_ants) + np.arange(1, 3))
Expand Down Expand Up @@ -600,6 +638,20 @@ def test_select_times():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(times=times_to_keep)

nt.assert_equal(len(times_to_keep), uv_object2.Ntimes)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for t in times_to_keep:
nt.assert_true(t in uv_object2.time_array)
for t in np.unique(uv_object2.time_array):
nt.assert_true(t in times_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific times using pyuvdata.',
uv_object2.history))
# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(times=times_to_keep[np.newaxis, :])

nt.assert_equal(len(times_to_keep), uv_object2.Ntimes)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for t in times_to_keep:
Expand Down Expand Up @@ -628,6 +680,20 @@ def test_select_frequencies():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(frequencies=freqs_to_keep)

nt.assert_equal(len(freqs_to_keep), uv_object2.Nfreqs)
for f in freqs_to_keep:
nt.assert_true(f in uv_object2.freq_array)
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in freqs_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific frequencies using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(frequencies=freqs_to_keep[np.newaxis, :])

nt.assert_equal(len(freqs_to_keep), uv_object2.Nfreqs)
for f in freqs_to_keep:
nt.assert_true(f in uv_object2.freq_array)
Expand Down Expand Up @@ -682,6 +748,20 @@ def test_select_freq_chans():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(freq_chans=chans_to_keep)

nt.assert_equal(len(chans_to_keep), uv_object2.Nfreqs)
for chan in chans_to_keep:
nt.assert_true(uv_object.freq_array[0, chan] in uv_object2.freq_array)
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in uv_object.freq_array[0, chans_to_keep])

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific frequencies using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(freq_chans=chans_to_keep[np.newaxis, :])

nt.assert_equal(len(chans_to_keep), uv_object2.Nfreqs)
for chan in chans_to_keep:
nt.assert_true(uv_object.freq_array[0, chan] in uv_object2.freq_array)
Expand Down Expand Up @@ -719,6 +799,20 @@ def test_select_polarizations():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(polarizations=pols_to_keep)

nt.assert_equal(len(pols_to_keep), uv_object2.Npols)
for p in pols_to_keep:
nt.assert_true(p in uv_object2.polarization_array)
for p in np.unique(uv_object2.polarization_array):
nt.assert_true(p in pols_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific polarizations using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(polarizations=[pols_to_keep])

nt.assert_equal(len(pols_to_keep), uv_object2.Npols)
for p in pols_to_keep:
nt.assert_true(p in uv_object2.polarization_array)
Expand Down
20 changes: 18 additions & 2 deletions pyuvdata/uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
# test for blt_inds presence before adding inds from antennas & times
if blt_inds is not None:
blt_inds = uvutils.get_iterable(blt_inds)
if np.array(blt_inds).ndim > 1:
blt_inds = np.array(blt_inds).flatten()
history_update_string += 'baseline-times'
n_selects += 1

Expand All @@ -1025,17 +1027,22 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
raise ValueError(
'Only one of antenna_nums and antenna_names can be provided.')

antenna_names = uvutils.get_iterable(antenna_names)
if not isinstance(antenna_names, (list, tuple, np.ndarray)):
antenna_names = (antenna_names,)
if np.array(antenna_names).ndim > 1:
antenna_names = np.array(antenna_names).flatten()
antenna_nums = []
for s in antenna_names:
if s not in self.antenna_names:
raise ValueError(
'Antenna name {a} is not present in the antenna_names array'.format(a=s))
antenna_nums.append(self.antenna_numbers[np.where(
np.array(self.antenna_names) == s)[0]])
np.array(self.antenna_names) == s)][0])

if antenna_nums is not None:
antenna_nums = uvutils.get_iterable(antenna_nums)
if np.array(antenna_nums).ndim > 1:
antenna_nums = np.array(antenna_nums).flatten()
if n_selects > 0:
history_update_string += ', antennas'
else:
Expand Down Expand Up @@ -1128,6 +1135,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if times is not None:
times = uvutils.get_iterable(times)
if np.array(times).ndim > 1:
times = np.array(times).flatten()
if n_selects > 0:
history_update_string += ', times'
else:
Expand Down Expand Up @@ -1159,12 +1168,15 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
raise ValueError(
'blt_inds contains indices that are too large')
if min(blt_inds) < 0:
print(blt_inds)
raise ValueError('blt_inds contains indices that are negative')

blt_inds = list(sorted(set(list(blt_inds))))

if freq_chans is not None:
freq_chans = uvutils.get_iterable(freq_chans)
if np.array(freq_chans).ndim > 1:
freq_chans = np.array(freq_chans).flatten()
if frequencies is None:
frequencies = self.freq_array[0, freq_chans]
else:
Expand All @@ -1174,6 +1186,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if frequencies is not None:
frequencies = uvutils.get_iterable(frequencies)
if np.array(frequencies).ndim > 1:
frequencies = np.array(frequencies).flatten()
if n_selects > 0:
history_update_string += ', frequencies'
else:
Expand Down Expand Up @@ -1208,6 +1222,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if polarizations is not None:
polarizations = uvutils.get_iterable(polarizations)
if np.array(polarizations).ndim > 1:
polarizations = np.array(polarizations).flatten()
if n_selects > 0:
history_update_string += ', polarizations'
else:
Expand Down