From 02db5c8861a361b71e0b8de1a54d9b47a18b14f2 Mon Sep 17 00:00:00 2001 From: Shaunak Modak Date: Fri, 13 Jul 2018 16:29:52 -0600 Subject: [PATCH 1/6] added function to do greedy flagging --- hera_pspec/__init__.py | 2 +- hera_pspec/flags.py | 83 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 hera_pspec/flags.py diff --git a/hera_pspec/__init__.py b/hera_pspec/__init__.py index aef41fa7..c2ef7883 100644 --- a/hera_pspec/__init__.py +++ b/hera_pspec/__init__.py @@ -1,7 +1,7 @@ """ __init__.py file for hera_pspec """ -from hera_pspec import version, conversions, grouping, pspecbeam, plot, pstokes, testing +from hera_pspec import version, conversions, grouping, pspecbeam, plot, pstokes, testing, flags from hera_pspec import uvpspec_utils as uvputils from hera_pspec.uvpspec import UVPSpec diff --git a/hera_pspec/flags.py b/hera_pspec/flags.py new file mode 100644 index 00000000..a3c1e29e --- /dev/null +++ b/hera_pspec/flags.py @@ -0,0 +1,83 @@ +import numpy as np + +mask_generator(nsamples, flags, n_threshold, greedy=False, axis, greedy_threshold, retain_flags=True): + """ + Generates a greedy flags mask from input flags and nsamples arrays + + Parameters + ---------- + nsamples : numpy.ndarray + integer array with number of samples available for each frequency channel at a given LST angle + + flags : numpy.ndarray + binary array with 1 representing flagged, 0 representing unflagged + + n_threshold : int + minimum number of samples needed for a point to remain unflagged + + greedy : bool + greedy flagging is used if true (default is False) + + axis : int + which axis to flag first if greedy=True (1 is row-first, 0 is col-first) + + greedy_threshold : float + if greedy=True, the threshold used to flag rows or columns if axis=1 or 0, respectively + + retain_flags : bool + LST-Bin Flags are left flagged even if thresholds are not met (default is True) + + Returns + ------- + mask : numpy.ndarray + binary array of the new mask where 1 is flagged, 0 is unflagged + +""" + + shape = nsamples.shape + flags_output = np.zeros(shape) + + num_exactly_equal = 0 + + # comparing the number of samples to the threshold + + for i in range(shape[0]): + for j in range(shape[1]): + if nsamples[i, j] < n_threshold: + flags_output[i, j] = 1 + elif nsamples[i, j] > n_threshold: + if retain_flags and flags[i, j] == 1: + flags_output[i, j] = 1 + else: + flags_output[i, j] = 0 + elif nsamples[i, j] == n_threshold: + if retain_flags and flags[i, j] == 1: + flags_output[i, j] = 1 + else: + flags_output[i, j] = 0 + num_exactly_equal += 1 + + # the greedy part + + if axis == 0: + if greedy: + column_flags_counter = 0 + for j in range(shape[1]): + if np.sum(flags_output[:, j])/shape[0] > greedy_threshold: + flags_output[:, j] = np.ones([shape[0]]) + column_flags_counter += 1 + for i in range(shape[0]): + if np.sum(flags_output[i, :]) > column_flags_counter: + flags_output[i, :] = np.ones([shape[1]]) + elif axis == 1: + if greedy: + row_flags_counter = 0 + for i in range(shape[0]): + if np.sum(flags_output[i, :])/shape[1] > greedy_threshold: + flags_output[i, :] = np.ones([shape[1]]) + row_flags_counter += 1 + for j in range(shape[1]): + if np.sum(flags_output[:, j]) > row_flags_counter: + flags_output[:, j] = np.ones([shape[0]]) + + return flags_output From dcd400bf942cb1f7de1f6a9af5d5e0e8007c2a17 Mon Sep 17 00:00:00 2001 From: Shaunak Modak Date: Fri, 3 Aug 2018 14:06:55 -0600 Subject: [PATCH 2/6] additional flagging functions and tests --- hera_pspec/flags.py | 374 ++++++++++++++++++++++++----- hera_pspec/pspecdata.py | 97 +++----- hera_pspec/tests/test_flags.py | 205 ++++++++++++++++ hera_pspec/tests/test_pspecdata.py | 33 ++- 4 files changed, 574 insertions(+), 135 deletions(-) create mode 100644 hera_pspec/tests/test_flags.py diff --git a/hera_pspec/flags.py b/hera_pspec/flags.py index a3c1e29e..c5295482 100644 --- a/hera_pspec/flags.py +++ b/hera_pspec/flags.py @@ -1,83 +1,329 @@ +from __future__ import print_function, division import numpy as np +import matplotlib +#matplotlib.use('Agg') +from matplotlib import gridspec +import matplotlib.pyplot as plt +from pyuvdata import UVData +import copy -mask_generator(nsamples, flags, n_threshold, greedy=False, axis, greedy_threshold, retain_flags=True): + +def uvd_to_array(uvdlist, baseline): """ - Generates a greedy flags mask from input flags and nsamples arrays + Reads UVData objects and stores flags and nsamples arrays in a list + in preparation for stacking Parameters ---------- - nsamples : numpy.ndarray - integer array with number of samples available for each frequency channel at a given LST angle + uvdlist : list + a list of UVData objects + + baseline : tuple + specifying the baseline to look at in the form (ant1, ant2, pol), + for example (65, 66, 'xx') + + Returns + ------- + nsamples_list : list + a list of nsamples arrays from the input files + + flags_list : list + a list of flags arrays from the input files + + """ + if len(uvdlist) == 0: + raise ValueError("uvdlist must contain at least 1 UVData object") + elif not isinstance(uvdlist, list): + raise TypeError("uvdlist takes list inputs (for 1 UVData object, \ + add it to a list of length 1)") + # creating lists of flags and nsamples arrays of input UVData objects + flags_list = [uvd.get_flags(baseline) for uvd in uvdlist] + nsamples_list = [uvd.get_nsamples(baseline) for uvd in uvdlist] + return nsamples_list, flags_list + +def stacked_array(array_list): + """ + Generates a long stacked array for (waterfall plots) from a list of arrays - flags : numpy.ndarray - binary array with 1 representing flagged, 0 representing unflagged + Parameters + ---------- + array_list : list + list of numpy.ndarray objects to be stacked + + Returns + ------- + array_total : numpy.ndarray + array of all arrays in array_list stacked in list index order + """ + counter = 0 + if len(array_list) == 0: + raise ValueError("input array list cannot be empty") + # looping through all the arrays and stacking them up + for i in range(len(array_list)): + array_new = np.zeros(array_list[i].shape) + if counter == 0: + array_total = array_list[i] + elif counter != 0: + array_new = array_list[i] + array_total = np.vstack((array_total, array_new)) + counter += 1 + return array_total + +def construct_factorizable_mask(uvdlist, spw_ranges=[(0, 1024)], first='col', greedy_threshold=0.3, n_threshold = 1, + retain_flags=True, unflag=False, greedy=True, inplace=False): + """ + Generates a factorizable mask using a greedy flagging algorithm given a list + of UVData objects. First, flags are added to the mask based on the number of + samples available for the pixel. Next, in greedy flagging, based on the + "first" param, full columns (or rows) exceeding the greedy threshold are + flagged, & then any remaining flags have their full rows (or columns) + flagged. Unflagging the entire array is also an option. + + Parameters + ---------- + uvdlist : list + list of UVData objects to operate on + + spw_ranges : list + list of tuples of the form (min_channel, max_channel) defining which + spectral window (channel range) to flag - min_channel is inclusive, + but max_channel is exclusive + + first : str + either 'col' or 'row', defines which axis is flagged first based on + the greedy_threshold - default is 'col' - n_threshold : int - minimum number of samples needed for a point to remain unflagged + greedy_threshold : float + the flag fraction beyond which a given row or column is flagged in the + first stage of greedy flagging - greedy : bool - greedy flagging is used if true (default is False) + n_threshold : int + the number of samples needed for a pixel to remain unflagged + + retain_flags : bool + if True, then pixels flagged in the file will always remain flagged, even + if they meet the n_threshold (default is True) - axis : int - which axis to flag first if greedy=True (1 is row-first, 0 is col-first) + unflag : bool + if True, the entire mask is unflagged. default is False - greedy_threshold : float - if greedy=True, the threshold used to flag rows or columns if axis=1 or 0, respectively + greedy : bool + if True, greedy flagging takes place, & if False, only n_threshold flagging + is used (resulting mask will not be factorizable). default is True - retain_flags : bool - LST-Bin Flags are left flagged even if thresholds are not met (default is True) + inplace : bool + if True, then the input UVData objects' flag arrays are modified, and if + False, new UVData objects identical to the inputs but with updated flags + are created and returned Returns ------- - mask : numpy.ndarray - binary array of the new mask where 1 is flagged, 0 is unflagged - -""" - - shape = nsamples.shape - flags_output = np.zeros(shape) - - num_exactly_equal = 0 - - # comparing the number of samples to the threshold + uvdlist_updated : list + if inplace=False, a new list of UVData objects with updated flags + """ + # initialize a list to place output UVData objects in if inplace=False + uvdlist_updated = [] - for i in range(shape[0]): - for j in range(shape[1]): - if nsamples[i, j] < n_threshold: - flags_output[i, j] = 1 - elif nsamples[i, j] > n_threshold: - if retain_flags and flags[i, j] == 1: - flags_output[i, j] = 1 - else: - flags_output[i, j] = 0 - elif nsamples[i, j] == n_threshold: - if retain_flags and flags[i, j] == 1: - flags_output[i, j] = 1 - else: - flags_output[i, j] = 0 - num_exactly_equal += 1 + # iterate over datasets + for dset in uvdlist: + if not isinstance(dset, UVData): raise TypeError("uvdlist must be a list of UVData objects") + if not inplace: uvd_updated_i = copy.deepcopy(dset) + # iterate over spectral windows + for spw in spw_ranges: + if not isinstance(spw, tuple): raise TypeError("spw_ranges must be a list of tuples") + if unflag: + #unflag everything if unflag = True + if inplace: + dset.flag_array[:, :, spw[0]:spw[1], :] = False + continue + elif not inplace: + uvd_updated_i.flag_array[:, :, spw[0]:spw[1], :] = False + uvdlist_updated.append(uvd_updated_i) + continue + # conduct flagging: + # iterate over polarizations + for n in range(dset.Npols): + # iterate over unique baselines + ubl = np.unique(dset.baseline_array) + for bl in ubl: + # get baseline-times indices + bl_inds = np.where(np.in1d(dset.baseline_array, bl))[0] + # create a new array of flags with only those indices + flags = dset.flag_array[bl_inds, 0, :, n].copy() + nsamples = dset.nsample_array[bl_inds, 0, :, n].copy() + Ntimes = int(flags.shape[0]) + Nfreqs = int(flags.shape[1]) + narrower_flags_window = flags[:, spw[0]:spw[1]] + narrower_nsamples_window = nsamples[:, spw[0]:spw[1]] + flags_output = np.zeros(narrower_flags_window.shape) + if not (isinstance(greedy_threshold, float) or isinstance(n_threshold, int)): + raise TypeError("greedy_threshold must be a float, and n_threshold must be an int") + if greedy_threshold >= 1 or greedy_threshold <= 0: + raise ValueError("greedy_threshold must be between 0 & 1, exclusive") + # if retaining flags, an extra condition is added to the threshold filter + if retain_flags: + flags_output[(narrower_nsamples_window >= n_threshold) & (narrower_flags_window == False)] = False + flags_output[(narrower_nsamples_window < n_threshold) | (narrower_flags_window == True)] = True + else: + flags_output[(narrower_nsamples_window >= n_threshold)] = False + flags_output[(narrower_nsamples_window < n_threshold)] = True + # conducting the greedy flagging + if greedy: + if first != 'col' and first != 'row': + raise ValueError("first must be either 'row' or 'col'") + if first == 'col': + # flagging all columns that exceed the greedy_threshold + col_indices = np.where(np.sum(flags_output, axis = 0)/Ntimes > greedy_threshold) + flags_output[:, col_indices] = True + # flagging all remaining rows + remaining_rows = np.where(np.sum(flags_output, axis = 1) > len(list(col_indices[0]))) + flags_output[remaining_rows, :] = True + elif first == 'row': + # flagging all rows that exceed the greedy_threshold + row_indices = np.where(np.sum(flags_output, axis = 1)/(spw[1]-spw[0]) > greedy_threshold) + flags_output[row_indices, :] = True + # flagging all remaining columns + remaining_cols = np.where(np.sum(flags_output, axis = 0) > len(list(row_indices[0]))) + flags_output[:, remaining_cols] = True + # updating the UVData object's flag_array if inplace, or creating a new object if not + if inplace: + dset.flag_array[bl_inds, 0, spw[0]:spw[1], n] = flags_output + elif not inplace: + uvd_updated_i.flag_array[bl_inds, 0, spw[0]:spw[1], n] = flags_output + if not inplace: uvdlist_updated.append(uvd_updated_i) + # returning an updated list of UVData objects if not inplace + if not inplace: + return uvdlist_updated - # the greedy part +def long_waterfall(array_list, title, cmap='gray', starting_lst=[]): + """ + Generates a waterfall plot of flags or nsamples with axis sums from an + input array - if axis == 0: - if greedy: - column_flags_counter = 0 - for j in range(shape[1]): - if np.sum(flags_output[:, j])/shape[0] > greedy_threshold: - flags_output[:, j] = np.ones([shape[0]]) - column_flags_counter += 1 - for i in range(shape[0]): - if np.sum(flags_output[i, :]) > column_flags_counter: - flags_output[i, :] = np.ones([shape[1]]) - elif axis == 1: - if greedy: - row_flags_counter = 0 - for i in range(shape[0]): - if np.sum(flags_output[i, :])/shape[1] > greedy_threshold: - flags_output[i, :] = np.ones([shape[1]]) - row_flags_counter += 1 - for j in range(shape[1]): - if np.sum(flags_output[:, j]) > row_flags_counter: - flags_output[:, j] = np.ones([shape[0]]) + Parameters + ---------- + array_list : list + list of arrays to be stacked and displayed + + title : str + title of the plot + + cmap : str, optional + cmap parameter for the waterfall plot (default is 'gray') + + starting_lst : list, optional + list of starting lst to display in the plot + + Returns + ------- + main_waterfall : matplotlib.axes + Matplotlib Axes instance of the main plot + + freq_histogram : matplotlib.axes + Matplotlib Axes instance of the sum across times + + time_histogram : matplotlib.axes + Matplotlib Axes instance of the sum across freqs + + data : numpy.ndarray + A copy of the stacked_array output that is being displayed + """ + # creating the array to be displayed using stacked_array() + data = stacked_array(array_list) + # setting up the figure and grid + fig = plt.figure() + fig.suptitle(title, fontsize=30, horizontalalignment='center') + grid = gridspec.GridSpec(ncols=10, nrows=15) + main_waterfall = fig.add_subplot(grid[0:14, 0:8]) + freq_histogram = fig.add_subplot(grid[14:15, 0:8], sharex=main_waterfall) + time_histogram = fig.add_subplot(grid[0:14, 8:10], sharey=main_waterfall) + fig.set_size_inches(20, 80) + grid.tight_layout(fig) + counter = data.shape[0] // 60 + # waterfall plot + main_waterfall.imshow(data, aspect='auto', cmap=cmap, + interpolation='none') + main_waterfall.set_ylabel('Integration Number') + main_waterfall.set_yticks(np.arange(0, counter*60 + 1, 30)) + main_waterfall.set_ylim(60*(counter+1), 0) + #red lines separating files + for i in range(counter+1): + main_waterfall.plot(np.arange(data.shape[1]), + 60*i*np.ones(data.shape[1]), '-r') + for i in range(len(starting_lst)): + if not isinstance(starting_lst[i], str): + raise TypeError("starting_lst must be a list of strings") + # adding text of filenames + if len(starting_lst) > 0: + for i in range(counter): + short_name = 'first\nintegration LST:\n'+starting_lst[i] + plt.text(-20, 26 + i*60, short_name, rotation=-90, size='small', + horizontalalignment='center') + main_waterfall.set_xlim(0, 1024) + # frequency sum plot + counts_freq = np.sum(data, axis=0) + max_counts_freq = max(np.amax(counts_freq), data.shape[0]) + normalized_freq = 100 * counts_freq/max_counts_freq + freq_histogram.set_xticks(np.arange(0, 1024, 50)) + freq_histogram.set_yticks(np.arange(0, 101, 5)) + freq_histogram.set_xlabel('Channel Number (Frequency)') + freq_histogram.set_ylabel('Occupancy %') + freq_histogram.grid() + freq_histogram.plot(np.arange(0, 1024), normalized_freq, 'r-') + # time sum plot + counts_times = np.sum(data, axis=1) + max_counts_times = max(np.amax(counts_times), data.shape[1]) + normalized_times = 100 * counts_times/max_counts_times + time_histogram.plot(normalized_times, np.arange(data.shape[0]), 'k-', + label='all channels') + time_histogram.set_xticks(np.arange(0, 101, 10)) + time_histogram.set_xlabel('Flag %') + time_histogram.autoscale(False) + time_histogram.grid() + # returning the axes + return main_waterfall, freq_histogram, time_histogram, data - return flags_output +def flag_channels(uvdlist, spw_ranges, inplace=False): + """ + Flags a given range of channels entirely for a list of UVData objects + + Parameters + ---------- + uvdlist : list + list of UVData objects to be flagged + + spw_ranges : list + list of tuples of the form (min_channel, max_channel) defining which + channels to flag + + inplace : bool, optional + if True, then the input UVData objects' flag arrays are modified, + and if False, new UVData objects identical to the inputs but with + updated flags are created and returned (default is False) + + Returns: + ------- + uvdlist_updated : list + list of updated UVData objects + """ + uvdlist_updated = [] + for uvd in uvdlist: + if not isinstance(uvd, UVData): + raise TypeError("uvdlist must be a list of UVData objects") + if not inplace: + uvd_updated_i = copy.deepcopy(uvd) + for spw in spw_ranges: + if not isinstance(spw, tuple): + raise TypeError("spw_ranges must be a list of tuples") + for pol in range(uvd.Npols): + ubl = np.unique(uvd.baseline_array) + for bl in ubl: + bl_inds = np.where(np.in1d(uvd.baseline_array, bl))[0] + fully_flagged = np.ones(uvd.flag_array[bl_inds, 0, spw[0]:spw[1], pol].shape, dtype=bool) + if inplace: + uvd.flag_array[bl_inds, 0, spw[0]:spw[1], pol] = fully_flagged + elif not inplace: + uvd_updated_i.flag_array[bl_inds, 0, spw[0]:spw[1], pol] = fully_flagged + uvdlist_updated.append(uvd_updated_i) + if not inplace: + return uvdlist_updated \ No newline at end of file diff --git a/hera_pspec/pspecdata.py b/hera_pspec/pspecdata.py index 8246ace3..64312aae 100644 --- a/hera_pspec/pspecdata.py +++ b/hera_pspec/pspecdata.py @@ -5,6 +5,7 @@ from collections import OrderedDict as odict import hera_cal as hc from hera_pspec import uvpspec, utils, version, pspecbeam, container +from hera_pspec.flags import construct_factorizable_mask from hera_pspec import uvpspec_utils as uvputils from pyuvdata import utils as uvutils import datetime @@ -1454,39 +1455,44 @@ def cov_p_hat(self,M,q_cov): p_cov[tnum]=np.einsum('ab,cd,bd->ac',M,M,q_cov[tnum]) return p_cov - def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False): - """ - For each dataset in self.dset, update the flag_array such that - the flagging patterns are time-independent for each baseline given - a selection for spectral windows. + def broadcast_dset_flags(self, spw_ranges=None, first='col', greedy_threshold=0.3, n_threshold = 1, + retain_flags=True, unflag=False, greedy=True): - For each frequency pixel in a selected spw, if the fraction of flagged - times exceeds time_thresh, then all times are flagged. If it does not, - the specific integrations which hold flags in the spw are flagged across - all frequencies in the spw. + """ + Using the greedy flagging function construct_factorizable_mask from + flags.py, update the flags arrays of the datasets in the object + with masks that are independent in time and frequency. - Additionally, one can also unflag the flag_array entirely if desired. + Parameters + ---------- + spw_ranges : list + list of tuples of the form (min_channel, max_channel) defining which + spectral window (channel range) to flag - min_channel is inclusive, + but max_channel is exclusive - Note: although technically allowed, this function may give unexpected results - if multiple spectral windows in spw_ranges have frequency overlap. + first : str + either 'col' or 'row', defines which axis is flagged first based on + the greedy_threshold - default is 'col' - Note: it is generally not recommended to set time_thresh > 0.5, which - could lead to substantial amounts of data being flagged. + greedy_threshold : float + the flag fraction beyond which a given row or column is flagged in the + first stage of greedy flagging - Parameters - ---------- - spw_ranges : list of tuples - list of len-2 spectral window tuples, specifying the start (inclusive) - and stop (exclusive) index of the frequency array for each spw. - Default is to use the whole band + n_threshold : int + the number of samples needed for a pixel to remain unflagged - time_thresh : float - Fractional threshold of flagged pixels across time needed to flag all times - per freq channel. It is not recommend to set this greater than 0.5 + retain_flags : bool + if True, then pixels flagged in the file will always remain flagged, even + if they meet the n_threshold (default is True) unflag : bool - If True, unflag all data in the spectral window. + if True, the entire mask is unflagged. default is False + + greedy : bool + if True, greedy flagging takes place, & if False, only n_threshold flagging + is used (resulting mask will not be factorizable). default is True """ + # validate datasets self.validate_datasets() @@ -1494,45 +1500,14 @@ def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False): self.clear_cache() # spw type check - if spw_ranges is None: + if spw_ranges == None: spw_ranges = [(0, self.Nfreqs)] - assert isinstance(spw_ranges, list), "spw_ranges must be fed as a list of tuples" - - # iterate over datasets - for dset in self.dsets: - # iterate over spw ranges - for spw in spw_ranges: - self.set_spw(spw) - # unflag - if unflag: - # unflag for all times - dset.flag_array[:, :, self.spw_range[0]:self.spw_range[1], :] = False - continue - # enact time threshold on flag waterfalls - # iterate over polarizations - for i in range(dset.Npols): - # iterate over unique baselines - ubl = np.unique(dset.baseline_array) - for bl in ubl: - # get baseline-times indices - bl_inds = np.where(np.in1d(dset.baseline_array, bl))[0] - # get flag waterfall - flags = dset.flag_array[bl_inds, 0, :, i].copy() - Ntimes = float(flags.shape[0]) - Nfreqs = float(flags.shape[1]) - # get time- and freq-continguous flags - freq_contig_flgs = np.sum(flags, axis=1) / Nfreqs > 0.999999 - Ntimes_noncontig = np.sum(~freq_contig_flgs, dtype=np.float) - # get freq channels where non-contiguous flags exceed threshold - exceeds_thresh = np.sum(flags[~freq_contig_flgs], axis=0, dtype=np.float) / Ntimes_noncontig > time_thresh - # flag channels for all times that exceed time_thresh - dset.flag_array[bl_inds, :, np.where(exceeds_thresh)[0][:, None], i] = True - # for pixels that have flags but didn't meet broadcasting limit - # flag the integration within the spw - flags[:, np.where(exceeds_thresh)[0]] = False - flag_ints = np.max(flags[:, self.spw_range[0]:self.spw_range[1]], axis=1) - dset.flag_array[bl_inds[flag_ints], :, self.spw_range[0]:self.spw_range[1], i] = True + # using the construct_factorizable_mask function from flags.py to conduct the flagging and update the objects + construct_factorizable_mask(uvdlist=self.dsets, spw_ranges=spw_ranges, first=first, greedy_threshold=greedy_threshold, + n_threshold = n_threshold, retain_flags=retain_flags, unflag=unflag, greedy=greedy, + inplace=True) + def units(self, little_h=True): """ Return the units of the power spectrum. These are inferred from the diff --git a/hera_pspec/tests/test_flags.py b/hera_pspec/tests/test_flags.py new file mode 100644 index 00000000..8dff9c9b --- /dev/null +++ b/hera_pspec/tests/test_flags.py @@ -0,0 +1,205 @@ +from __future__ import print_function, division +import unittest +import nose.tools as nt +import numpy as np +from pyuvdata import UVData +import os +import sys +from hera_pspec.data import DATA_PATH +from hera_pspec.flags import uvd_to_array, stacked_array, construct_factorizable_mask, long_waterfall, flag_channels + +dfiles = ['zen.even.xx.LST.1.28828.uvOCRSA', 'zen.odd.xx.LST.1.28828.uvOCRSA'] +baseline = (38, 68, 'xx') + +class Test_Flags(unittest.TestCase): + + def setUp(self): + + # Load datafiles into UVData objects + self.d = [] + for dfile in dfiles: + _d = UVData() + _d.read_miriad(os.path.join(DATA_PATH, dfile)) + self.d.append(_d) + # data to use when testing the plotting function + self.data_list = [self.d[0].get_flags(38, 68, 'xx'), self.d[1].get_flags(38, 68, 'xx')] + + def tearDown(self): + pass + + def runTest(self): + pass + + def test_uvd_to_array(self): + """ + testing the uvd to array function + """ + nsamples, flags = uvd_to_array(self.d, baseline) + # making sure that length of lists is always equal + nt.assert_equal(len(nsamples), len(flags)) + # making sure an error comes up if len(uvdlist) = 0 + nt.assert_raises(ValueError, uvd_to_array, [], baseline) + # error if a UVData object is input instead of a list + nt.assert_raises(TypeError, uvd_to_array, self.d[0], baseline) + + def test_stacked_array(self): + """ + testing the array stacking function + """ + flags_list = uvd_to_array(self.d, baseline)[1] + long_array_flags = stacked_array(flags_list) + + # make sure # rows in output = sum of # rows in each input array + nt.assert_equal(long_array_flags.shape[0], sum([flag_array.shape[0] \ + for flag_array in flags_list])) + for flag_array in flags_list: + # ensuring that the number of columns is unchanged + nt.assert_equal(long_array_flags.shape[1], flag_array.shape[1]) + # ensuring that arrays are stacked in order as expected + nt.assert_true(np.array_equal( \ + long_array_flags[0 : flags_list[0].shape[0], :], flags_list[0])) + nt.assert_true(np.array_equal( \ + long_array_flags[ flags_list[0].shape[0] : flags_list[0].shape[0] + \ + flags_list[1].shape[0], :], flags_list[1])) + + def test_construct_factorizable_mask(self): + """ + testing mask generator function + """ + # testing unflagging + unflagged_uvdlist = construct_factorizable_mask(self.d, unflag=True, \ + inplace=False) + for uvd in unflagged_uvdlist: + unflagged_mask = uvd.get_flags((38, 68, 'xx')) + nt.assert_equal(np.sum(unflagged_mask), 0) + # ensuring that greedy flagging works as expected in extreme cases + allflagged_uvdlist = construct_factorizable_mask( \ + self.d, greedy_threshold=0.0001, first='row', inplace=False) + for uvd in allflagged_uvdlist: + flagged_mask = uvd.get_flags((38, 68, 'xx')) + # everything flagged since the greedy threshold is extremely low + nt.assert_equal(np.sum(flagged_mask), \ + np.sum(np.ones(flagged_mask.shape))) + # ensuring that n_threshold parameter works as expected in extreme cases + allflagged_uvdlist2 = construct_factorizable_mask( \ + self.d, n_threshold=35, first='row', inplace=False) + for uvd in allflagged_uvdlist2: + flagged_mask = uvd.get_flags((38, 68, 'xx')) + nt.assert_equal(np.sum(flagged_mask), \ + np.sum(np.ones(flagged_mask.shape))) + # ensuring that greedy flagging is occurring within the intended spw: + greedily_flagged_uvdlist = construct_factorizable_mask( \ + self.d, n_threshold = 6, greedy_threshold = 0.35, first='col', \ + spw_ranges=[(0, 300), (500, 700)], inplace=False) + for i in range(len(self.d)): + # checking that outside the spw range, flags are all equal + nt.assert_true(np.array_equal( \ + greedily_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, 300:500], \ + self.d[i].get_flags((38, 68, 'xx'))[:, 300:500])) + nt.assert_true(np.array_equal( \ + greedily_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, 700:], \ + self.d[i].get_flags((38, 68, 'xx'))[:, 700:])) + # flags are actually retained + original_flags_ind = np.where(self.d[i].get_flags((38, 68, 'xx')) == True) + new_flags = greedily_flagged_uvdlist[i].get_flags((38, 68, 'xx')) + old_flags = self.d[i].get_flags((38, 68, 'xx')) + nt.assert_true(np.array_equal( \ + new_flags[original_flags_ind], old_flags[original_flags_ind])) + # checking that inplace objects match in important areas + nt.assert_true(np.array_equal( \ + greedily_flagged_uvdlist[i].get_data((38, 68, 'xx')), \ + self.d[i].get_data((38, 68, 'xx')))) + nt.assert_true(np.array_equal( \ + greedily_flagged_uvdlist[i].get_nsamples((38, 68, 'xx')), \ + self.d[i].get_nsamples((38, 68, 'xx')))) + # making sure flags are actually independent in each spw + masks = [new_flags[:, 0:300], new_flags[:, 500:700]] + for mask in masks: + Nfreqs = mask.shape[1] + Ntimes = mask.shape[0] + N_flagged_rows = np.sum( \ + 1*(np.sum(mask, axis=1)/Nfreqs > 0.999999999)) + N_flagged_cols = np.sum( \ + 1*(np.sum(mask, axis=0)/Ntimes > 0.999999999)) + nt.assert_true(int(np.sum( \ + mask[np.where(np.sum(mask, axis=1)/Nfreqs < 0.99999999)]) \ + /(Ntimes-N_flagged_rows)) == N_flagged_cols) + + # copied from test_plot.py for testing the long_waterfall plotting function + def axes_contains(self, ax, obj_list): + """ + Check that a matplotlib.Axes instance contains certain elements. + + Parameters + ---------- + ax : matplotlib.Axes + Axes instance. + + obj_list : list of tuples + List of tuples, one for each type of object to look for. The tuple + should be of the form (matplotlib.object, int), where int is the + number of instances of that object that are expected. + """ + # Get plot elements + elems = ax.get_children() + # Loop over list of objects that should be in the plot + contains_all = False + for obj in obj_list: + objtype, num_expected = obj + num = 0 + for elem in elems: + if isinstance(elem, objtype): num += 1 + if num != num_expected: + return False + # Return True if no problems found + return True + + def test_long_waterfall(self): + """ + testing the long waterfall plotting function + """ + main_waterfall, freq_histogram, time_histogram, data = long_waterfall( \ + self.data_list, title='Flags Waterfall') + # making sure the main waterfall has the right number of dividing lines + main_waterfall_elems = [(matplotlib.lines.Line2D, \ + round(data.shape[0]/60, 0))] + nt.assert_true(self.axes_contains(main_waterfall, main_waterfall_elems)) + # making sure the time graph has the appropriate line element + time_elems = [(matplotlib.lines.Line2D, 1)] + nt.assert_true(self.axes_contains(time_histogram, time_elems)) + # making sure the freq graph has the appropriate line element + freq_elems = [(matplotlib.lines.Line2D, 1)] + nt.assert_true(self.axes_contains(freq_histogram, freq_elems)) + + def test_flag_channels(self): + """ + testing the channel-flagging function + """ + # ensuring that flagging is occurring: + column_flagged_uvdlist = flag_channels( \ + self.d, [(200, 451), (680, 881)], inplace=False) + for i in range(len(self.d)): + # checking that outside the spw ranges, flags are all equal + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, :200], \ + self.d[i].get_flags((38, 68, 'xx'))[:, :200])) + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, 451:680], \ + self.d[i].get_flags((38, 68, 'xx'))[:, 451:680])) + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, 881:], \ + self.d[i].get_flags((38, 68, 'xx'))[:, 881:])) + # checking that inside the ranges, everything is flagged + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, 200:451], \ + self.d[i].get_flags((38, 68, 'xx'))[:, 200:451])) + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_flags((38, 68, 'xx'))[:, 680:881], \ + self.d[i].get_flags((38, 68, 'xx'))[:, 680:881])) + # checking that inplace objects match in important areas + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_data((38, 68, 'xx')), \ + self.d[i].get_data((38, 68, 'xx')))) + nt.assert_true(np.array_equal( \ + column_flagged_uvdlist[i].get_nsamples((38, 68, 'xx')), \ + self.d[i].get_nsamples((38, 68, 'xx')))) \ No newline at end of file diff --git a/hera_pspec/tests/test_pspecdata.py b/hera_pspec/tests/test_pspecdata.py index 8bcfc2ed..1cd2532d 100644 --- a/hera_pspec/tests/test_pspecdata.py +++ b/hera_pspec/tests/test_pspecdata.py @@ -5,6 +5,7 @@ import os, copy, sys from scipy.integrate import simps, trapz from hera_pspec import pspecdata, pspecbeam, conversions, container, utils +from hera_pspec.flags import construct_factorizable_mask from hera_pspec.data import DATA_PATH from pyuvdata import UVData from hera_cal import redcal @@ -1094,38 +1095,50 @@ def test_broadcast_dset_flags(self): # test basic execution w/ a spw selection ds = pspecdata.PSpecData(dsets=[copy.deepcopy(uvd), copy.deepcopy(uvd)], wgts=[None, None]) - ds.broadcast_dset_flags(spw_ranges=[(400, 800)], time_thresh=0.2) - nt.assert_false(ds.dsets[0].get_flags(24, 25)[:, 550:650].any()) + ds.broadcast_dset_flags(spw_ranges=[(400, 600)], greedy_threshold=0.2) + nt.assert_false(ds.dsets[0].get_flags(24, 25)[5:7, 700:800].any()) # test w/ no spw selection ds = pspecdata.PSpecData(dsets=[copy.deepcopy(uvd), copy.deepcopy(uvd)], wgts=[None, None]) - ds.broadcast_dset_flags(spw_ranges=None, time_thresh=0.2) - nt.assert_true(ds.dsets[0].get_flags(24, 25)[:, 550:650].any()) - + ds.broadcast_dset_flags(spw_ranges=None, greedy_threshold=0.2) + nt.assert_true(ds.dsets[0].get_flags(24, 25)[5:7, 700:800].any()) + # test unflagging ds = pspecdata.PSpecData(dsets=[copy.deepcopy(uvd), copy.deepcopy(uvd)], wgts=[None, None]) - ds.broadcast_dset_flags(spw_ranges=None, time_thresh=0.2, unflag=True) + ds.broadcast_dset_flags(spw_ranges=None, greedy_threshold=0.2, unflag=True) nt.assert_false(ds.dsets[0].get_flags(24, 25)[:, :].any()) - + + # test retained flags + ds = pspecdata.PSpecData(dsets=[copy.deepcopy(uvd), copy.deepcopy(uvd)], wgts=[None, None]) + original_flags_ind = np.where(ds.dsets[0].get_flags(24, 25, 'xx') == True) + old_flags = ds.dsets[0].get_flags((24, 25, 'xx')) + ds.broadcast_dset_flags(spw_ranges=None, greedy_threshold=0.2, retain_flags=True) + new_flags = ds.dsets[0].get_flags((24, 25, 'xx')) + nt.assert_true(np.array_equal(new_flags[original_flags_ind], old_flags[original_flags_ind])) + # test single integration being flagged within spw ds = pspecdata.PSpecData(dsets=[copy.deepcopy(uvd), copy.deepcopy(uvd)], wgts=[None, None]) ds.dsets[0].flag_array[ds.dsets[0].antpair2ind(24, 25)[3], 0, 600, 0] = True - ds.broadcast_dset_flags(spw_ranges=[(400, 800)], time_thresh=0.25, unflag=False) + ds.broadcast_dset_flags(spw_ranges=[(400, 800)], greedy_threshold=0.25, unflag=False) nt.assert_true(ds.dsets[0].get_flags(24, 25)[3, 400:800].all()) nt.assert_false(ds.dsets[0].get_flags(24, 25)[3, :].all()) - + # test pspec run sets flagged integration to have zero weight uvd.flag_array[uvd.antpair2ind(24, 25)[3], 0, 400, :] = True ds = pspecdata.PSpecData(dsets=[copy.deepcopy(uvd), copy.deepcopy(uvd)], wgts=[None, None]) - ds.broadcast_dset_flags(spw_ranges=[(400, 450)], time_thresh=0.25) + ds.broadcast_dset_flags(spw_ranges=[(400, 450)], greedy_threshold=0.25) uvp = ds.pspec([(24, 25), (37, 38), (38, 39)], [(24, 25), (37, 38), (38, 39)], (0, 1), ('xx', 'xx'), spw_ranges=[(400, 450)], verbose=False) + # assert flag broadcast above hits weight arrays in uvp nt.assert_true(np.all(np.isclose(uvp.get_wgts(0, ((24, 25), (24, 25)), 'xx')[3], 0.0))) + # assert flag broadcast above hits integration arrays nt.assert_true(np.isclose(uvp.get_integrations(0, ((24, 25), (24, 25)), 'xx')[3], 0.0)) + # average spectra avg_uvp = uvp.average_spectra(blpair_groups=[sorted(np.unique(uvp.blpair_array))], time_avg=True, inplace=False) + # repeat but change data in flagged portion ds.dsets[0].data_array[uvd.antpair2ind(24, 25)[3], 0, 400:450, :] *= 100 uvp2 = ds.pspec([(24, 25), (37, 38), (38, 39)], [(24, 25), (37, 38), (38, 39)], (0, 1), ('xx', 'xx'), From 424c366a25ae09944e9fa5dc41e0981db424051f Mon Sep 17 00:00:00 2001 From: Shaunak Modak Date: Fri, 3 Aug 2018 14:11:11 -0600 Subject: [PATCH 3/6] travis and plot minor modifications --- .travis.yml | 2 +- hera_pspec/plot.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 8e878c8a..335fcda4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ install: before_script: - "export DISPLAY=:99.0" - "sh -e /etc/init.d/xvfb start" - - "export MPLBACKEND=agg" + - export MPLBACKEND=agg - sleep 3 script: nosetests hera_pspec --with-coverage --cover-package=hera_pspec --verbose diff --git a/hera_pspec/plot.py b/hera_pspec/plot.py index 4cc3fb97..d8d4133e 100644 --- a/hera_pspec/plot.py +++ b/hera_pspec/plot.py @@ -1,6 +1,7 @@ import numpy as np import pyuvdata from hera_pspec import conversions +import matplotlib import matplotlib.pyplot as plt import copy @@ -144,5 +145,4 @@ def delay_spectrum(uvp, blpairs, spw, pol, average_blpairs=False, ax.set_ylabel("$P(k_\parallel)$ $[%s]$" % psunits, fontsize=16) # Return Axes - return ax - + return ax \ No newline at end of file From e0fab29c688b7a581ebcb74bbd3e6feb7c204b57 Mon Sep 17 00:00:00 2001 From: Shaunak Modak Date: Mon, 6 Aug 2018 12:48:53 -0600 Subject: [PATCH 4/6] fixed some errors to pass more tests --- hera_pspec/flags.py | 1 - hera_pspec/pspecdata.py | 12 ++++++------ hera_pspec/tests/test_pspecdata.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/hera_pspec/flags.py b/hera_pspec/flags.py index c5295482..79de6eff 100644 --- a/hera_pspec/flags.py +++ b/hera_pspec/flags.py @@ -1,7 +1,6 @@ from __future__ import print_function, division import numpy as np import matplotlib -#matplotlib.use('Agg') from matplotlib import gridspec import matplotlib.pyplot as plt from pyuvdata import UVData diff --git a/hera_pspec/pspecdata.py b/hera_pspec/pspecdata.py index 4453ff32..016e651e 100644 --- a/hera_pspec/pspecdata.py +++ b/hera_pspec/pspecdata.py @@ -2447,7 +2447,7 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, exclude_auto_bls=False, exclude_permutations=True, Nblps_per_group=None, bl_len_range=(0, 1e10), bl_deg_range=(0, 180), bl_error_tol=1.0, beam=None, cosmo=None, rephase_to_dset=None, trim_dset_lsts=False, broadcast_dset_flags=True, - time_thresh=0.2, Jy2mK=False, overwrite=True, verbose=True, store_cov=False, history=''): + greedy_thresh=0.2, Jy2mK=False, overwrite=True, verbose=True, store_cov=False, history=''): """ Create a PSpecData object, run OQE delay spectrum estimation and write results to a PSpecContainer object. @@ -2566,9 +2566,9 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, non-overlapping LSTs. broadcast_dset_flags : boolean - If True, broadcast dset flags across time using fractional time_thresh. + If True, broadcast dset flags across time using fractional greedy_thresh. - time_thresh : float + greedy_thresh : float Fractional flagging threshold, above which a broadcast of flags across time is triggered (if broadcast_dset_flags == True). This is done independently for each baseline's visibility waterfall. @@ -2701,7 +2701,7 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, # broadcast flags if broadcast_dset_flags: - ds.broadcast_dset_flags(time_thresh=time_thresh) + ds.broadcast_dset_flags(greedy_thresh=greedy_thresh) # perform Jy to mK conversion if desired if Jy2mK: @@ -2817,8 +2817,8 @@ def list_of_tuple_tuples(v): a.add_argument("--cosmo", default=None, nargs='+', type=float, help="List of float values for [Om_L, Om_b, Om_c, H0, Om_M, Om_k].") a.add_argument("--rephase_to_dset", default=None, type=int, help="dset integer index to phase all other dsets to. Default is no rephasing.") a.add_argument("--trim_dset_lsts", default=False, action='store_true', help="Trim non-overlapping dset LSTs.") - a.add_argument("--broadcast_dset_flags", default=False, action='store_true', help="Broadcast dataset flags across time according to time_thresh.") - a.add_argument("--time_thresh", default=0.2, type=float, help="Fractional flagging threshold across time to trigger flag broadcast if broadcast_dset_flags is True") + a.add_argument("--broadcast_dset_flags", default=False, action='store_true', help="Broadcast dataset flags across time according to greedy_thresh.") + a.add_argument("--greedy_thresh", default=0.2, type=float, help="Fractional flagging threshold across time to trigger flag broadcast if broadcast_dset_flags is True") a.add_argument("--Jy2mK", default=False, action='store_true', help="Convert datasets from Jy to mK if a beam model is provided.") a.add_argument("--exclude_auto_bls", default=False, action='store_true', help='If blpairs is not provided, exclude all baselines paired with itself.') a.add_argument("--exclude_permutations", default=False, action='store_true', help='If blpairs is not provided, exclude a basline-pair permutations. Ex: if (A, B) exists, exclude (B, A).') diff --git a/hera_pspec/tests/test_pspecdata.py b/hera_pspec/tests/test_pspecdata.py index 17417baf..1c4eec49 100644 --- a/hera_pspec/tests/test_pspecdata.py +++ b/hera_pspec/tests/test_pspecdata.py @@ -1298,7 +1298,7 @@ def test_pspec_run(): rephase_to_dset=0, blpairs=[((37, 38), (37, 38)), ((37, 38), (52, 53))], pol_pairs=[('xx', 'xx'), ('xx', 'xx')], dset_labels=["foo", "bar"], dset_pairs=[(0, 0), (0, 1)], spw_ranges=[(50, 75), (120, 140)], - cosmo=cosmo, trim_dset_lsts=True, broadcast_dset_flags=True, time_thresh=0.1, + cosmo=cosmo, trim_dset_lsts=True, broadcast_dset_flags=True, greedy_thresh=0.1, store_cov=True) nt.assert_true("foo_bar" in psc.groups()) nt.assert_equal(psc.spectra('foo_bar'), [u'foo_x_bar', u'foo_x_foo']) From e176c5ad82f4ca5e0d1d7ed9dc3145f450404b3e Mon Sep 17 00:00:00 2001 From: Shaunak Modak Date: Tue, 7 Aug 2018 14:14:06 -0600 Subject: [PATCH 5/6] flags & pspecdata changes now pass all relevant tests --- hera_pspec/flags.py | 6 +++--- hera_pspec/pspecdata.py | 14 +++++++------- hera_pspec/tests/test_flags.py | 8 ++++++-- hera_pspec/tests/test_pspecdata.py | 2 +- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/hera_pspec/flags.py b/hera_pspec/flags.py index 79de6eff..00fe904f 100644 --- a/hera_pspec/flags.py +++ b/hera_pspec/flags.py @@ -258,17 +258,17 @@ def long_waterfall(array_list, title, cmap='gray', starting_lst=[]): short_name = 'first\nintegration LST:\n'+starting_lst[i] plt.text(-20, 26 + i*60, short_name, rotation=-90, size='small', horizontalalignment='center') - main_waterfall.set_xlim(0, 1024) + main_waterfall.set_xlim(0, data.shape[1]) # frequency sum plot counts_freq = np.sum(data, axis=0) max_counts_freq = max(np.amax(counts_freq), data.shape[0]) normalized_freq = 100 * counts_freq/max_counts_freq - freq_histogram.set_xticks(np.arange(0, 1024, 50)) + freq_histogram.set_xticks(np.arange(0, data.shape[1], 50)) freq_histogram.set_yticks(np.arange(0, 101, 5)) freq_histogram.set_xlabel('Channel Number (Frequency)') freq_histogram.set_ylabel('Occupancy %') freq_histogram.grid() - freq_histogram.plot(np.arange(0, 1024), normalized_freq, 'r-') + freq_histogram.plot(np.arange(0, data.shape[1]), normalized_freq, 'r-') # time sum plot counts_times = np.sum(data, axis=1) max_counts_times = max(np.amax(counts_times), data.shape[1]) diff --git a/hera_pspec/pspecdata.py b/hera_pspec/pspecdata.py index 016e651e..7df09c01 100644 --- a/hera_pspec/pspecdata.py +++ b/hera_pspec/pspecdata.py @@ -2350,7 +2350,7 @@ def rephase_to_dset(self, dset_index=0, inplace=True): # get blts indices of basline indices = dset.antpair2ind(*k[:2]) # get index in polarization_array for this polarization - polind = pol_list.index(hc.io.polstr2num[k[-1]]) + polind = pol_list.index(hc.io.polstr2num(k[-1])) # insert into dset dset.data_array[indices, 0, :, polind] = data[k] @@ -2447,7 +2447,7 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, exclude_auto_bls=False, exclude_permutations=True, Nblps_per_group=None, bl_len_range=(0, 1e10), bl_deg_range=(0, 180), bl_error_tol=1.0, beam=None, cosmo=None, rephase_to_dset=None, trim_dset_lsts=False, broadcast_dset_flags=True, - greedy_thresh=0.2, Jy2mK=False, overwrite=True, verbose=True, store_cov=False, history=''): + greedy_threshold=0.2, Jy2mK=False, overwrite=True, verbose=True, store_cov=False, history=''): """ Create a PSpecData object, run OQE delay spectrum estimation and write results to a PSpecContainer object. @@ -2566,9 +2566,9 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, non-overlapping LSTs. broadcast_dset_flags : boolean - If True, broadcast dset flags across time using fractional greedy_thresh. + If True, broadcast dset flags across time using fractional greedy_threshold. - greedy_thresh : float + greedy_threshold : float Fractional flagging threshold, above which a broadcast of flags across time is triggered (if broadcast_dset_flags == True). This is done independently for each baseline's visibility waterfall. @@ -2701,7 +2701,7 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, # broadcast flags if broadcast_dset_flags: - ds.broadcast_dset_flags(greedy_thresh=greedy_thresh) + ds.broadcast_dset_flags(greedy_threshold=greedy_threshold) # perform Jy to mK conversion if desired if Jy2mK: @@ -2817,8 +2817,8 @@ def list_of_tuple_tuples(v): a.add_argument("--cosmo", default=None, nargs='+', type=float, help="List of float values for [Om_L, Om_b, Om_c, H0, Om_M, Om_k].") a.add_argument("--rephase_to_dset", default=None, type=int, help="dset integer index to phase all other dsets to. Default is no rephasing.") a.add_argument("--trim_dset_lsts", default=False, action='store_true', help="Trim non-overlapping dset LSTs.") - a.add_argument("--broadcast_dset_flags", default=False, action='store_true', help="Broadcast dataset flags across time according to greedy_thresh.") - a.add_argument("--greedy_thresh", default=0.2, type=float, help="Fractional flagging threshold across time to trigger flag broadcast if broadcast_dset_flags is True") + a.add_argument("--broadcast_dset_flags", default=False, action='store_true', help="Broadcast dataset flags across time according to greedy_threshold.") + a.add_argument("--greedy_threshold", default=0.2, type=float, help="Fractional flagging threshold across time to trigger flag broadcast if broadcast_dset_flags is True") a.add_argument("--Jy2mK", default=False, action='store_true', help="Convert datasets from Jy to mK if a beam model is provided.") a.add_argument("--exclude_auto_bls", default=False, action='store_true', help='If blpairs is not provided, exclude all baselines paired with itself.') a.add_argument("--exclude_permutations", default=False, action='store_true', help='If blpairs is not provided, exclude a basline-pair permutations. Ex: if (A, B) exists, exclude (B, A).') diff --git a/hera_pspec/tests/test_flags.py b/hera_pspec/tests/test_flags.py index 8dff9c9b..6915f30f 100644 --- a/hera_pspec/tests/test_flags.py +++ b/hera_pspec/tests/test_flags.py @@ -3,6 +3,7 @@ import nose.tools as nt import numpy as np from pyuvdata import UVData +import matplotlib import os import sys from hera_pspec.data import DATA_PATH @@ -161,7 +162,10 @@ def test_long_waterfall(self): main_waterfall, freq_histogram, time_histogram, data = long_waterfall( \ self.data_list, title='Flags Waterfall') # making sure the main waterfall has the right number of dividing lines - main_waterfall_elems = [(matplotlib.lines.Line2D, \ + if round(data.shape[0]/60, 0) == 0: + main_waterfall_elems = [(matplotlib.lines.Line2D, 1)] + else: + main_waterfall_elems = [(matplotlib.lines.Line2D, \ round(data.shape[0]/60, 0))] nt.assert_true(self.axes_contains(main_waterfall, main_waterfall_elems)) # making sure the time graph has the appropriate line element @@ -170,7 +174,7 @@ def test_long_waterfall(self): # making sure the freq graph has the appropriate line element freq_elems = [(matplotlib.lines.Line2D, 1)] nt.assert_true(self.axes_contains(freq_histogram, freq_elems)) - + def test_flag_channels(self): """ testing the channel-flagging function diff --git a/hera_pspec/tests/test_pspecdata.py b/hera_pspec/tests/test_pspecdata.py index 1c4eec49..b8fa7642 100644 --- a/hera_pspec/tests/test_pspecdata.py +++ b/hera_pspec/tests/test_pspecdata.py @@ -1298,7 +1298,7 @@ def test_pspec_run(): rephase_to_dset=0, blpairs=[((37, 38), (37, 38)), ((37, 38), (52, 53))], pol_pairs=[('xx', 'xx'), ('xx', 'xx')], dset_labels=["foo", "bar"], dset_pairs=[(0, 0), (0, 1)], spw_ranges=[(50, 75), (120, 140)], - cosmo=cosmo, trim_dset_lsts=True, broadcast_dset_flags=True, greedy_thresh=0.1, + cosmo=cosmo, trim_dset_lsts=True, broadcast_dset_flags=True, greedy_threshold=0.1, store_cov=True) nt.assert_true("foo_bar" in psc.groups()) nt.assert_equal(psc.spectra('foo_bar'), [u'foo_x_bar', u'foo_x_foo']) From eedee0f55d390a30de21e401cf6e05b46d06f33b Mon Sep 17 00:00:00 2001 From: Paul La Plante Date: Tue, 25 Sep 2018 15:18:59 -0700 Subject: [PATCH 6/6] Make tests pass --- hera_pspec/pspecdata.py | 2 +- hera_pspec/tests/test_pspecdata.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hera_pspec/pspecdata.py b/hera_pspec/pspecdata.py index 4d3bfc99..6ab30fb1 100644 --- a/hera_pspec/pspecdata.py +++ b/hera_pspec/pspecdata.py @@ -2721,7 +2721,7 @@ def pspec_run(dsets, filename, dsets_std=None, groupname=None, dset_labels=None, # broadcast flags if broadcast_dset_flags: - ds.broadcast_dset_flags(greedy_threshold=greedy_threshold) + ds.broadcast_dset_flags(spw_ranges=spw_ranges, greedy_threshold=greedy_threshold) # perform Jy to mK conversion if desired if Jy2mK: diff --git a/hera_pspec/tests/test_pspecdata.py b/hera_pspec/tests/test_pspecdata.py index b378015c..a91517ed 100644 --- a/hera_pspec/tests/test_pspecdata.py +++ b/hera_pspec/tests/test_pspecdata.py @@ -1316,7 +1316,7 @@ def test_pspec_run(): # assert dset labeling propagated nt.assert_equal(set(uvp.labels), set(['bar', 'foo'])) # assert spw_ranges and n_dlys specification worked - np.testing.assert_array_equal(uvp.get_spw_ranges(), [(163476562.5, 165917968.75, 25, 20), (170312500.0, 172265625.0, 20, 20)]) + np.testing.assert_array_equal(uvp.get_spw_ranges(), [(163476562.5, 165917968.75, 25, 25), (170312500.0, 172265625.0, 20, 20)]) # get shifted UVDatas and test rephasing, flag broadcasting uvd = UVData() @@ -1331,7 +1331,7 @@ def test_pspec_run(): psc, ds = pspecdata.pspec_run([copy.deepcopy(uvd1), copy.deepcopy(uvd2)], "./out2.h5", blpairs=[((37, 38), (37, 38)), ((37, 38), (52, 53))], verbose=False, overwrite=True, spw_ranges=[(50, 100)], rephase_to_dset=0, - broadcast_dset_flags=True, time_thresh=0.3) + broadcast_dset_flags=True, greedy_threshold=0.3) # assert first integration flagged across entire spw nt.assert_true(ds.dsets[0].get_flags(37, 38)[0, 50:100].all()) # assert first integration flagged *ONLY* across spw