Skip to content

Commit

Permalink
Flatten higher dimensional arrays and remove asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
bhazelton committed Jul 5, 2018
1 parent 9397424 commit adbd138
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 15 deletions.
102 changes: 94 additions & 8 deletions pyuvdata/tests/test_uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,20 @@ def test_select_blts():

nt.assert_true(np.all(selected_data == uv_object2.data_array))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(blt_inds=blt_inds[np.newaxis, :])
nt.assert_equal(len(blt_inds), uv_object2.Nblts)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific baseline-times using pyuvdata.',
uv_object2.history))
nt.assert_true(np.all(selected_data == uv_object2.data_array))

# check for errors associated with out of bounds indices
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(-10, -5))
nt.assert_raises(ValueError, uv_object.select, blt_inds=np.arange(
uv_object.Nblts + 1, uv_object.Nblts + 10))
nt.assert_raises(AssertionError, uv_object.select,
blt_inds=np.reshape(np.arange(-10, -5), (1, 5)))


def test_select_antennas():
Expand All @@ -430,6 +438,22 @@ def test_select_antennas():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(antenna_nums=ants_to_keep)

nt.assert_equal(len(ants_to_keep), uv_object2.Nants_data)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for ant in ants_to_keep:
nt.assert_true(
ant in uv_object2.ant_1_array or ant in uv_object2.ant_2_array)
for ant in np.unique(uv_object2.ant_1_array.tolist() + uv_object2.ant_2_array.tolist()):
nt.assert_true(ant in ants_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific antennas using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(antenna_nums=ants_to_keep[np.newaxis, :])

nt.assert_equal(len(ants_to_keep), uv_object2.Nants_data)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for ant in ants_to_keep:
Expand All @@ -454,13 +478,24 @@ def test_select_antennas():

nt.assert_equal(uv_object2, uv_object3)

# check that it also works with higher dimension array
uv_object3 = copy.deepcopy(uv_object)
ants_to_keep = np.array(sorted(list(ants_to_keep)))
ant_names = []
for a in ants_to_keep:
ind = np.where(uv_object3.antenna_numbers == a)[0][0]
ant_names.append(uv_object3.antenna_names[ind])

uv_object3.select(antenna_names=[ant_names])

nt.assert_equal(uv_object2, uv_object3)

# check for errors associated with antennas not included in data, bad names or providing numbers and names
nt.assert_raises(ValueError, uv_object.select,
antenna_nums=np.max(unique_ants) + np.arange(1, 3))
nt.assert_raises(ValueError, uv_object.select, antenna_names='test1')
nt.assert_raises(ValueError, uv_object.select,
antenna_nums=ants_to_keep, antenna_names=ant_names)
nt.assert_raises(AssertionError, uv_object.select, antenna_names=[['test1']])


def sort_bl(p):
Expand Down Expand Up @@ -603,6 +638,20 @@ def test_select_times():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(times=times_to_keep)

nt.assert_equal(len(times_to_keep), uv_object2.Ntimes)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for t in times_to_keep:
nt.assert_true(t in uv_object2.time_array)
for t in np.unique(uv_object2.time_array):
nt.assert_true(t in times_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific times using pyuvdata.',
uv_object2.history))
# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(times=times_to_keep[np.newaxis, :])

nt.assert_equal(len(times_to_keep), uv_object2.Ntimes)
nt.assert_equal(Nblts_selected, uv_object2.Nblts)
for t in times_to_keep:
Expand All @@ -617,7 +666,6 @@ def test_select_times():
# check for errors associated with times not included in data
nt.assert_raises(ValueError, uv_object.select, times=[
np.min(unique_times) - uv_object.integration_time])
nt.assert_raises(AssertionError, uv_object.select, times=times_to_keep[np.newaxis, :])


def test_select_frequencies():
Expand All @@ -632,6 +680,20 @@ def test_select_frequencies():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(frequencies=freqs_to_keep)

nt.assert_equal(len(freqs_to_keep), uv_object2.Nfreqs)
for f in freqs_to_keep:
nt.assert_true(f in uv_object2.freq_array)
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in freqs_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific frequencies using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(frequencies=freqs_to_keep[np.newaxis, :])

nt.assert_equal(len(freqs_to_keep), uv_object2.Nfreqs)
for f in freqs_to_keep:
nt.assert_true(f in uv_object2.freq_array)
Expand Down Expand Up @@ -672,7 +734,6 @@ def test_select_frequencies():
message='Selected frequencies are not contiguous')
nt.assert_raises(ValueError, uv_object2.write_uvfits, write_file_uvfits)
nt.assert_raises(ValueError, uv_object2.write_miriad, write_file_miriad)
nt.assert_raises(AssertionError, uv_object.select, times=freqs_to_keep[np.newaxis, :])


def test_select_freq_chans():
Expand All @@ -687,6 +748,20 @@ def test_select_freq_chans():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(freq_chans=chans_to_keep)

nt.assert_equal(len(chans_to_keep), uv_object2.Nfreqs)
for chan in chans_to_keep:
nt.assert_true(uv_object.freq_array[0, chan] in uv_object2.freq_array)
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in uv_object.freq_array[0, chans_to_keep])

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific frequencies using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(freq_chans=chans_to_keep[np.newaxis, :])

nt.assert_equal(len(chans_to_keep), uv_object2.Nfreqs)
for chan in chans_to_keep:
nt.assert_true(uv_object.freq_array[0, chan] in uv_object2.freq_array)
Expand All @@ -711,8 +786,6 @@ def test_select_freq_chans():
for f in np.unique(uv_object2.freq_array):
nt.assert_true(f in uv_object.freq_array[0, all_chans_to_keep])

nt.assert_raises(AssertionError, uv_object.select, freq_chans=chans_to_keep[np.newaxis, :])


def test_select_polarizations():
uv_object = UVData()
Expand All @@ -726,6 +799,20 @@ def test_select_polarizations():
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(polarizations=pols_to_keep)

nt.assert_equal(len(pols_to_keep), uv_object2.Npols)
for p in pols_to_keep:
nt.assert_true(p in uv_object2.polarization_array)
for p in np.unique(uv_object2.polarization_array):
nt.assert_true(p in pols_to_keep)

nt.assert_true(uvutils.check_histories(old_history + ' Downselected to '
'specific polarizations using pyuvdata.',
uv_object2.history))

# check that it also works with higher dimension array
uv_object2 = copy.deepcopy(uv_object)
uv_object2.select(polarizations=[pols_to_keep])

nt.assert_equal(len(pols_to_keep), uv_object2.Npols)
for p in pols_to_keep:
nt.assert_true(p in uv_object2.polarization_array)
Expand All @@ -744,7 +831,6 @@ def test_select_polarizations():
message='Selected polarization values are not evenly spaced')
write_file_uvfits = os.path.join(DATA_PATH, 'test/select_test.uvfits')
nt.assert_raises(ValueError, uv_object.write_uvfits, write_file_uvfits)
nt.assert_raises(AssertionError, uv_object.select, polarizations=[pols_to_keep])


def test_select():
Expand Down
22 changes: 15 additions & 7 deletions pyuvdata/uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
# test for blt_inds presence before adding inds from antennas & times
if blt_inds is not None:
blt_inds = uvutils.get_iterable(blt_inds)
assert(np.array(blt_inds).ndim == 1)
if np.array(blt_inds).ndim > 1:
blt_inds = np.array(blt_inds).flatten()
history_update_string += 'baseline-times'
n_selects += 1

Expand All @@ -1028,7 +1029,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if not isinstance(antenna_names, (list, tuple, np.ndarray)):
antenna_names = (antenna_names,)
assert(np.array(antenna_names).ndim == 1)
if np.array(antenna_names).ndim > 1:
antenna_names = np.array(antenna_names).flatten()
antenna_nums = []
for s in antenna_names:
if s not in self.antenna_names:
Expand All @@ -1039,7 +1041,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if antenna_nums is not None:
antenna_nums = uvutils.get_iterable(antenna_nums)
assert(np.array(antenna_nums).ndim == 1)
if np.array(antenna_nums).ndim > 1:
antenna_nums = np.array(antenna_nums).flatten()
if n_selects > 0:
history_update_string += ', antennas'
else:
Expand Down Expand Up @@ -1132,7 +1135,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if times is not None:
times = uvutils.get_iterable(times)
assert(np.array(times).ndim == 1)
if np.array(times).ndim > 1:
times = np.array(times).flatten()
if n_selects > 0:
history_update_string += ', times'
else:
Expand Down Expand Up @@ -1164,13 +1168,15 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,
raise ValueError(
'blt_inds contains indices that are too large')
if min(blt_inds) < 0:
print(blt_inds)
raise ValueError('blt_inds contains indices that are negative')

blt_inds = list(sorted(set(list(blt_inds))))

if freq_chans is not None:
freq_chans = uvutils.get_iterable(freq_chans)
assert(np.array(freq_chans).ndim == 1)
if np.array(freq_chans).ndim > 1:
freq_chans = np.array(freq_chans).flatten()
if frequencies is None:
frequencies = self.freq_array[0, freq_chans]
else:
Expand All @@ -1180,7 +1186,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if frequencies is not None:
frequencies = uvutils.get_iterable(frequencies)
assert(np.array(frequencies).ndim == 1)
if np.array(frequencies).ndim > 1:
frequencies = np.array(frequencies).flatten()
if n_selects > 0:
history_update_string += ', frequencies'
else:
Expand Down Expand Up @@ -1215,7 +1222,8 @@ def _select_preprocess(self, antenna_nums, antenna_names, ant_str, bls,

if polarizations is not None:
polarizations = uvutils.get_iterable(polarizations)
assert(np.array(polarizations).ndim == 1)
if np.array(polarizations).ndim > 1:
polarizations = np.array(polarizations).flatten()
if n_selects > 0:
history_update_string += ', polarizations'
else:
Expand Down

0 comments on commit adbd138

Please sign in to comment.