Skip to content

Commit

Permalink
Merge adbd138 into 94f1e81
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Jul 5, 2018
2 parents 94f1e81 + adbd138 commit 1db49e6
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
94 changes: 94 additions & 0 deletions pyuvdata/tests/test_uvdata.py
Expand Up @@ -404,6 +404,16 @@ def test_select_blts():

nt.assert_true(np.all(selected_data == uv_object2.data_array))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(blt_inds=blt_inds[np.newaxis, :])
nt.assert_equal(len(blt_inds), uv_object2.Nblts)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific baseline-times using pyuvdata.',
uv_object2.history))
nt.assert_true(np.all(selected_data == uv_object2.data_array))

# check for errors associated with out of bounds indices
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(-10, -5))
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(
Expand All @@ -428,6 +438,22 @@ def test_select_antennas():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(antenna_nums=ants_to_keep)

nt.assert_equal(len(ants_to_keep), uv_object2.Nants_data)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for ant in ants_to_keep:
nt.assert_true(
ant in uv_object2.ant_1_array or ant in uv_object2.ant_2_array)
for ant in np.unique(uv_object2.ant_1_array.tolist() + uv_object2.ant_2_array.tolist()):
nt.assert_true(ant in ants_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific antennas using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(antenna_nums=ants_to_keep[np.newaxis, :])

nt.assert_equal(len(ants_to_keep), uv_object2.Nants_data)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for ant in ants_to_keep:
Expand All @@ -452,6 +478,18 @@ def test_select_antennas():

nt.assert_equal(uv_object2, uv_object3)

# check that it also works with higher dimension array
uv_object3 = copy.deepcopy(uv_object)
ants_to_keep = np.array(sorted(list(ants_to_keep)))
ant_names = []
for a in ants_to_keep:
ind = np.where(uv_object3.antenna_numbers == a)[0][0]
ant_names.append(uv_object3.antenna_names[ind])

uv_object3.select(antenna_names=[ant_names])

nt.assert_equal(uv_object2, uv_object3)

# check for errors associated with antennas not included in data, bad names or providing numbers and names
nt.assert_raises(ValueError, uv_object.select,
antenna_nums=np.max(unique_ants) + np.arange(1, 3))
Expand Down Expand Up @@ -600,6 +638,20 @@ def test_select_times():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(times=times_to_keep)

nt.assert_equal(len(times_to_keep), uv_object2.Ntimes)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for t in times_to_keep:
nt.assert_true(t in uv_object2.time_array)
for t in np.unique(uv_object2.time_array):
nt.assert_true(t in times_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific times using pyuvdata.',
uv_object2.history))
# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(times=times_to_keep[np.newaxis, :])

nt.assert_equal(len(times_to_keep), uv_object2.Ntimes)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for t in times_to_keep:
Expand Down Expand Up @@ -628,6 +680,20 @@ def test_select_frequencies():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(frequencies=freqs_to_keep)

nt.assert_equal(len(freqs_to_keep), uv_object2.Nfreqs)
for f in freqs_to_keep:
nt.assert_true(f in uv_object2.freq_array)
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in freqs_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific frequencies using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(frequencies=freqs_to_keep[np.newaxis, :])

nt.assert_equal(len(freqs_to_keep), uv_object2.Nfreqs)
for f in freqs_to_keep:
nt.assert_true(f in uv_object2.freq_array)
Expand Down Expand Up @@ -682,6 +748,20 @@ def test_select_freq_chans():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(freq_chans=chans_to_keep)

nt.assert_equal(len(chans_to_keep), uv_object2.Nfreqs)
for chan in chans_to_keep:
nt.assert_true(uv_object.freq_array[0, chan] in uv_object2.freq_array)
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in uv_object.freq_array[0, chans_to_keep])

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific frequencies using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(freq_chans=chans_to_keep[np.newaxis, :])

nt.assert_equal(len(chans_to_keep), uv_object2.Nfreqs)
for chan in chans_to_keep:
nt.assert_true(uv_object.freq_array[0, chan] in uv_object2.freq_array)
Expand Down Expand Up @@ -719,6 +799,20 @@ def test_select_polarizations():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(polarizations=pols_to_keep)

nt.assert_equal(len(pols_to_keep), uv_object2.Npols)
for p in pols_to_keep:
nt.assert_true(p in uv_object2.polarization_array)
for p in np.unique(uv_object2.polarization_array):
nt.assert_true(p in pols_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific polarizations using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(polarizations=[pols_to_keep])

nt.assert_equal(len(pols_to_keep), uv_object2.Npols)
for p in pols_to_keep:
nt.assert_true(p in uv_object2.polarization_array)
Expand Down
20 changes: 18 additions & 2 deletions pyuvdata/uvdata.py
Expand Up @@ -1017,6 +1017,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
# test for blt_inds presence before adding inds from antennas & times
if blt_inds is not None:
blt_inds = uvutils.get_iterable(blt_inds)
if np.array(blt_inds).ndim > 1:
blt_inds = np.array(blt_inds).flatten()
history_update_string += 'baseline-times'
n_selects += 1

Expand All @@ -1025,17 +1027,22 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
raise ValueError(
'Only one of antenna_nums and antenna_names can be provided.')

antenna_names = uvutils.get_iterable(antenna_names)
if not isinstance(antenna_names, (list, tuple, np.ndarray)):
antenna_names = (antenna_names,)
if np.array(antenna_names).ndim > 1:
antenna_names = np.array(antenna_names).flatten()
antenna_nums = []
for s in antenna_names:
if s not in self.antenna_names:
raise ValueError(
'Antenna name {a} is not present in the antenna_names array'.format(a=s))
antenna_nums.append(self.antenna_numbers[np.where(
np.array(self.antenna_names) == s)[0]])
np.array(self.antenna_names) == s)][0])

if antenna_nums is not None:
antenna_nums = uvutils.get_iterable(antenna_nums)
if np.array(antenna_nums).ndim > 1:
antenna_nums = np.array(antenna_nums).flatten()
if n_selects > 0:
history_update_string += ', antennas'
else:
Expand Down Expand Up @@ -1128,6 +1135,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if times is not None:
times = uvutils.get_iterable(times)
if np.array(times).ndim > 1:
times = np.array(times).flatten()
if n_selects > 0:
history_update_string += ', times'
else:
Expand Down Expand Up @@ -1159,12 +1168,15 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
raise ValueError(
'blt_inds contains indices that are too large')
if min(blt_inds) < 0:
print(blt_inds)
raise ValueError('blt_inds contains indices that are negative')

blt_inds = list(sorted(set(list(blt_inds))))

if freq_chans is not None:
freq_chans = uvutils.get_iterable(freq_chans)
if np.array(freq_chans).ndim > 1:
freq_chans = np.array(freq_chans).flatten()
if frequencies is None:
frequencies = self.freq_array[0, freq_chans]
else:
Expand All @@ -1174,6 +1186,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if frequencies is not None:
frequencies = uvutils.get_iterable(frequencies)
if np.array(frequencies).ndim > 1:
frequencies = np.array(frequencies).flatten()
if n_selects > 0:
history_update_string += ', frequencies'
else:
Expand Down Expand Up @@ -1208,6 +1222,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if polarizations is not None:
polarizations = uvutils.get_iterable(polarizations)
if np.array(polarizations).ndim > 1:
polarizations = np.array(polarizations).flatten()
if n_selects > 0:
history_update_string += ', polarizations'
else:
Expand Down

0 comments on commit 1db49e6

Please sign in to comment.