In [1]:
import numpy as np
from pyuvdata import UVData, UVCal
from hera_cal import io
from pyuvdata.utils import polnum2str, polstr2num, jnum2str, jstr2num
from hera_cal.datacontainer import DataContainer

In [2]:
import operator
import collections
from collections import OrderedDict as odict
import copy
import os
import warnings

In [3]:
import unittest

# new utils

In [4]:
def split_pol(pol):
    '''Split a visibility polarization string into its two antenna polarizations.

    Arguments:
        pol: visibility polarization string whose first two characters each name
            an antenna polarization, e.g. 'xy'

    Returns:
        2-tuple of antenna polarization strings, obtained by passing each of the
        first two characters of pol through jstr2num and then jnum2str.

    Raises:
        ValueError: if polstr2num(pol) is positive (Stokes or pseudo-Stokes
            visibilities), since those cannot be split per antenna.
    '''
    if polstr2num(pol) > 0:  # positive polarization numbers are Stokes / pseudo-Stokes
        raise ValueError('Unable to split Stokes or pseudo-Stokes polarization ' + pol)
    return tuple(jnum2str(jstr2num(pol[index])) for index in (0, 1))

def conj_pol(pol):
    '''Given V_ij^(pol), return the polarization of V_ji^(conj pol) such that
    (V_ji^(conj pol))* = V_ij^(pol).

    Arguments:
        pol: visibility polarization string

    Returns:
        pol itself when polstr2num(pol) is positive (Stokes and pseudo-Stokes
        visibilities are unaffected); otherwise the reversed string, so that
        xy -> yx and yx -> xy. Case is left unmodified.
    '''
    is_stokes = polstr2num(pol) > 0  # positive numbers are Stokes / pseudo-Stokes
    return pol if is_stokes else pol[::-1]

In [5]:
# class Test_Pol_Ops(unittest.TestCase):
    
#     def test_split_pol(self):
#         self.assertEqual(split_pol('xx'),('jxx','jxx'))
#         self.assertEqual(split_pol('xy'),('jxx','jyy'))        
#         self.assertEqual(split_pol('XY'),('jxx','jyy'))
#         with self.assertRaises(ValueError):
#             split_pol('I')
#         with self.assertRaises(ValueError):    
#             split_pol('pV')           
            
#     def test_conj_pol(self):
#         self.assertEqual(conj_pol('xx'),'xx')
#         self.assertEqual(conj_pol('XX'),'XX')        
#         self.assertEqual(conj_pol('XY'),'YX')
#         self.assertEqual(conj_pol('yx'),'xy')
#         self.assertEqual(conj_pol('Q'),'Q')
#         self.assertEqual(conj_pol('pU'),'pU')
        
# if __name__ == '__main__':
#     unittest.main(argv=['first-arg-is-ignored'], exit=False)

# HERACal

In [6]:
class HERACal(UVCal):
    '''HERACal is a subclass of pyuvdata.UVCal meant to serve as an interface between
    pyuvdata-readable calfits files and dictionaries (the in-memory format for hera_cal)
    that map antennas and polarizations to gains, flags, and qualities. Supports standard
    UVCal functionality, along with read() and update() functionality for going back and
    forth to dictionaries. Upon read(), stores useful metadata internally.

    Does not support partial data loading or writing. Assumes a single spectral window.
    '''

    def __init__(self, input_cal):
        '''Instantiate a HERACal object. Currently only supports calfits files.

        Arguments:
            input_cal: string calfits file path or list of paths

        Raises:
            TypeError: if input_cal is an iterable containing non-string items
            ValueError: if input_cal is neither a string nor an iterable
        '''
        # collections.Iterable was removed in Python 3.10; fall back for Python 2
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        super(HERACal, self).__init__()

        # parse input_cal as filepath(s)
        if isinstance(input_cal, str):
            self.filepaths = [input_cal]
        elif isinstance(input_cal, Iterable):  # list loading
            # materialize first so that one-shot iterables are not exhausted by the type check
            paths = list(input_cal)
            if all(isinstance(path, str) for path in paths):
                self.filepaths = paths
            else:
                raise TypeError('If input_cal is a list, it must be a list of strings.')
        else:
            raise ValueError('input_cal must be a string or a list of strings.')

    def _extract_metadata(self):
        '''Extract and store useful metadata and array indexing dictionaries.

        Sets self.freqs, self.times, self.pols, and self.ants, plus the private lookup
        tables self._jnum_indices (jones number -> index along the last array axis) and
        self._antnum_indices (antenna number -> index along the first array axis).
        '''
        self.freqs = np.unique(self.freq_array)
        self.times = np.unique(self.time_array)
        self.pols = [jnum2str(j) for j in self.jones_array]
        self._jnum_indices = {jnum: i for i, jnum in enumerate(self.jones_array)}
        self.ants = [(ant, pol) for ant in self.ant_array for pol in self.pols]
        self._antnum_indices = {ant: i for i, ant in enumerate(self.ant_array)}

    def build_cal_dicts(self):
        '''Turns the calibration information currently loaded into the HERACal object
        into dictionaries that map antenna-pol tuples to calibration waterfalls. Computes
        and stores internally useful metadata in the process.

        Returns:
            gains: dict mapping antenna-pol keys to (Nint, Nfreq) complex gains arrays
            flags: dict mapping antenna-pol keys to (Nint, Nfreq) boolean flag arrays
            quals: dict mapping antenna-pol keys to (Nint, Nfreq) float qual arrays
            total_qual: dict mapping polarization to (Nint, Nfreq) float total quality
                array, or None if self.total_quality_array is None
        '''
        self._extract_metadata()
        gains, flags, quals = odict(), odict(), odict()

        # build dict of gains, flags, and quals; arrays are copied and transposed so
        # that the stored [..., freq, time] slices come out as (Nint, Nfreq)
        for (ant, pol) in self.ants:
            i, ip = self._antnum_indices[ant], self._jnum_indices[jstr2num(pol)]
            gains[(ant, pol)] = np.array(self.gain_array[i, 0, :, :, ip].T)
            flags[(ant, pol)] = np.array(self.flag_array[i, 0, :, :, ip].T)
            quals[(ant, pol)] = np.array(self.quality_array[i, 0, :, :, ip].T)

        # build dict of total_qual if available
        if self.total_quality_array is None:
            total_qual = None
        else:
            total_qual = odict()
            for pol in self.pols:
                ip = self._jnum_indices[jstr2num(pol)]
                total_qual[pol] = np.array(self.total_quality_array[0, :, :, ip].T)

        return gains, flags, quals, total_qual

    def read(self):
        '''Reads calibration information from file, computes useful metadata and returns
        dictionaries that map antenna-pol tuples to calibration waterfalls.

        Returns:
            gains: dict mapping antenna-pol keys to (Nint, Nfreq) complex gains arrays
            flags: dict mapping antenna-pol keys to (Nint, Nfreq) boolean flag arrays
            quals: dict mapping antenna-pol keys to (Nint, Nfreq) float qual arrays
            total_qual: dict mapping polarization to (Nint, Nfreq) float total quality
                array, or None if the file has no total quality array
        '''
        self.read_calfits(self.filepaths)
        return self.build_cal_dicts()

    def update(self, gains=None, flags=None, quals=None, total_qual=None):
        '''Update internal calibration arrays (gain_array, flag_array, quality_array,
        and total_quality_array) using dictionaries (if not left as None) in
        preparation for writing to disk.

        Arguments:
            gains: optional dict mapping antenna-pol to complex gains arrays
            flags: optional dict mapping antenna-pol to boolean flag arrays
            quals: optional dict mapping antenna-pol to float qual arrays
            total_qual: optional dict mapping polarization to float total quality array
        '''
        # the index lookups are normally built by read()/build_cal_dicts(); build them
        # here if the calibration was loaded through plain UVCal methods instead
        if not (hasattr(self, '_antnum_indices') and hasattr(self, '_jnum_indices')):
            self._extract_metadata()

        # loop over and update gains, flags, and quals
        data_arrays = [self.gain_array, self.flag_array, self.quality_array]
        for to_update, array in zip([gains, flags, quals], data_arrays):
            if to_update is not None:
                for (ant, pol) in to_update.keys():
                    i, ip = self._antnum_indices[ant], self._jnum_indices[jstr2num(pol)]
                    # waterfalls are (Nint, Nfreq); internal storage is [freq, time]
                    array[i, 0, :, :, ip] = to_update[(ant, pol)].T

        # update total_qual
        if total_qual is not None:
            for pol in total_qual.keys():
                ip = self._jnum_indices[jstr2num(pol)]
                self.total_quality_array[0, :, :, ip] = total_qual[pol].T

In [7]:
# from hera_cal.data import DATA_PATH
# import os

# class Test_HERACal(unittest.TestCase):
    
#     def setUp(self):
#         self.fname_xx = os.path.join(DATA_PATH, "test_input/zen.2457698.40355.xx.HH.uvc.omni.calfits")
#         self.fname_yy = os.path.join(DATA_PATH, "test_input/zen.2457698.40355.yy.HH.uvc.omni.calfits")
#         self.fname_both = os.path.join(DATA_PATH, "test_input/zen.2457698.40355.HH.uvcA.omni.calfits")
    
#     def test_init(self):
#         hc = HERACal(self.fname_xx)
#         self.assertEqual(hc.filepaths, [self.fname_xx])
#         hc = HERACal([self.fname_xx, self.fname_yy])
#         self.assertEqual(hc.filepaths, [self.fname_xx, self.fname_yy])
#         hc = HERACal((self.fname_xx, self.fname_yy))
#         self.assertEqual(hc.filepaths, [self.fname_xx, self.fname_yy])        
#         with self.assertRaises(TypeError):
#             hc = HERACal([0,1])
#         with self.assertRaises(ValueError):
#             hc = HERACal(None)
            
#     def test_read(self):
#         # test one file with both polarizations and a non-None total quality array
#         hc = HERACal(self.fname_both)
#         gains, flags, quals, total_qual = hc.read()
#         uvc = UVCal()
#         uvc.read_calfits(self.fname_both)
#         np.testing.assert_array_equal(uvc.gain_array[0, 0, :, :, 0].T, gains[9, 'jxx'])
#         np.testing.assert_array_equal(uvc.flag_array[0, 0, :, :, 0].T, flags[9, 'jxx'])        
#         np.testing.assert_array_equal(uvc.quality_array[0, 0, :, :, 0].T, quals[9, 'jxx'])                
#         np.testing.assert_array_equal(uvc.total_quality_array[0, :, :, 0].T, total_qual['jxx'])
#         np.testing.assert_array_equal(np.unique(uvc.freq_array), hc.freqs)
#         np.testing.assert_array_equal(np.unique(uvc.time_array), hc.times)        
#         self.assertEqual(hc.pols, ['jxx', 'jyy'])
#         self.assertEqual(set([ant[0] for ant in hc.ants]), set(uvc.ant_array))
        
#         # test list loading
#         hc = HERACal([self.fname_xx, self.fname_yy])
#         gains, flags, quals, total_qual = hc.read()
#         self.assertEqual(len(gains.keys()), 36)
#         self.assertEqual(len(flags.keys()), 36)
#         self.assertEqual(len(quals.keys()), 36)
#         self.assertEqual(hc.freqs.shape, (1024,))
#         self.assertEqual(hc.times.shape, (3,))
#         self.assertEqual(sorted(hc.pols), ['jxx', 'jyy'])
        
#     def test_write(self):
#         hc = HERACal(self.fname_both)
#         gains, flags, quals, total_qual = hc.read()
#         for key in gains.keys():
#             gains[key] *= 2.0 + 1.0j
#             flags[key] = np.logical_not(flags[key])
#             quals[key] *= 2.0
#         for key in total_qual.keys():
#             total_qual[key] *= 2
#         hc.update(gains=gains, flags=flags, quals=quals, total_qual=total_qual)
#         hc.write_calfits('test.calfits', clobber=True)
        
#         gains_in, flags_in, quals_in, total_qual_in = hc.read()
#         hc2 = HERACal('test.calfits')
#         gains_out, flags_out, quals_out, total_qual_out = hc2.read()
#         for key in gains_in.keys():
#             np.testing.assert_array_equal(gains_in[key] * (2.0 + 1.0j), gains_out[key])
#             np.testing.assert_array_equal(np.logical_not(flags_in[key]), flags_out[key])            
#             np.testing.assert_array_equal(quals_in[key] * (2.0), quals_out[key])
#         for key in total_qual.keys():
#             np.testing.assert_array_equal(total_qual_in[key] * (2.0), total_qual_out[key])        
        
#         os.remove('test.calfits')



# Run the unittest.TestCase classes currently defined in this notebook session.
# A placeholder argv is passed so unittest does not try to parse the kernel's own
# command-line arguments, and exit=False stops unittest from calling sys.exit()
# once the tests finish.
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)


----------------------------------------------------------------------
Ran 0 tests in 0.000s

OK


# HERAData

In [13]:
class HERAData(UVData):
    '''HERAData is a subclass of pyuvdata.UVData meant to serve as an interface between 
    pyuvdata-compatible data formats on disk (especially uvh5) and DataContainers,
    the in-memory format for visibilities used in hera_cal. In addition to standard 
    UVData functionality, HERAData supports read() and update() functions that interface
    between internal UVData data storage and DataContainers, which contain visibility
    data in a dictionary-like format, along with some useful metadata. read() supports
    partial data loading, though only the most useful subset of selection modes from 
    pyuvdata (and not all modes for all data types).
    
    When using uvh5, HERAData supports additional useful functionality:
    * Upon __init__(), the most useful metadata describing the entire file is loaded into
      the object (everything in HERAData_metas; see get_metadata_dict() for details).
    * Partial writing using partial_write(), which will initialize a new file with the
      same metadata and write to disk using DataContainers by assuming that the user is
      writing to the same part of the data as the most recent read().
    * Generators that enable iterating over baseline, frequency, or time in chunks (see 
      iterate_over_bls(), iterate_over_freqs(), and iterate_over_times() for details).
      
    Assumes a single spectral window. Assumes that data for a given baseline is regularly
    spaced in the underlying data_array.
    '''
    
    # static list of useful metadata to calculate and save
    HERAData_metas = ['ants', 'antpos', 'freqs', 'times', 'lsts', 'pols', 
                      'antpairs', 'bls', 'times_by_bl', 'lsts_by_bl']

    def __init__(self, input_data, filetype='uvh5'):
        '''Instantiate a HERAData object. If the filetype is uvh5, read in and store
        useful metadata (see get_metadata_dict()), either as object attributes or,
        if input_data is a list, as dictionaries mapping string paths to metadata.

        Arguments:
            input_data: string data file path or list of string data file paths
            filetype: supports 'uvh5' (default), 'miriad', 'uvfits'

        Raises:
            TypeError: if input_data is an iterable containing non-string items
            ValueError: if input_data is neither a string nor an iterable
            IOError: if any of the paths does not exist
            NotImplementedError: if filetype is not one of the supported types
        '''
        # collections.Iterable was removed in Python 3.10; fall back for Python 2
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable

        # initialize as empty UVData object
        super(HERAData, self).__init__()

        # parse input_data as filepath(s)
        if isinstance(input_data, str):
            self.filepaths = [input_data]
        elif isinstance(input_data, Iterable):  # list loading
            # materialize first so that one-shot iterables are not exhausted by the type check
            paths = list(input_data)
            if all(isinstance(path, str) for path in paths):
                self.filepaths = paths
            else:
                raise TypeError('If input_data is a list, it must be a list of strings.')
        else:
            raise ValueError('input_data must be a string or a list of strings.')
        for f in self.filepaths:
            if not os.path.exists(f):
                raise IOError('Cannot find file ' + f)

        # load metadata from file
        self.filetype = filetype
        if self.filetype == 'uvh5':  # compare by value, not identity
            # read all UVData metadata from the first file only, by temporarily
            # pointing self.filepaths at that single path
            all_paths = self.filepaths
            self.filepaths = all_paths[0]
            self.read(read_data=False)
            self.filepaths = all_paths

            if len(self.filepaths) > 1:  # save HERAData_metas in dicts keyed by path
                for meta in self.HERAData_metas:
                    setattr(self, meta, {})
                for path in self.filepaths:
                    # the metadata must come from the per-file object, not from self
                    # (self only holds the metadata of the first file)
                    hd = HERAData(path, filetype='uvh5')
                    meta_dict = hd.get_metadata_dict()
                    for meta in self.HERAData_metas:
                        getattr(self, meta)[path] = meta_dict[meta]
            else:  # save HERAData_metas as attributes
                self.writers = {}  # output path -> HERAData writer, used by partial_write()
                for key, value in self.get_metadata_dict().items():
                    setattr(self, key, value)

        elif self.filetype in ['miriad', 'uvfits']:
            for meta in self.HERAData_metas:
                setattr(self, meta, None)  # no pre-loading of metadata
        else:
            raise NotImplementedError('Filetype ' + self.filetype + ' has not been implemented.')
    
    def clear(self):
        '''Reset the standard UVData attributes by re-running the parent class
        initializer on this object. Attributes set only by HERAData (e.g. filepaths,
        filetype) are not touched by this call.'''
        super(HERAData, self).__init__()
    
    def get_metadata_dict(self):
        ''' Produces a dictionary of the most useful metadata. Used as object
        attributes and as metadata to store in DataContainers. Requires that
        self._blt_slices has already been built (see _determine_blt_slicing()).

        Returns:
            metadata_dict: dictionary of all items in self.HERAData_metas
        '''
        antpos, ants = self.get_ENU_antpos()
        # pair each antenna with its position BEFORE sorting the antenna list;
        # sorting first would attach positions to the wrong antennas whenever
        # get_ENU_antpos() returns antennas out of numerical order
        antpos = dict(zip(ants, antpos))
        ants = sorted(ants)

        freqs = np.unique(self.freq_array)
        times = np.unique(self.time_array)
        # unique LSTs kept in order of first appearance rather than sorted by value
        flat_lsts = self.lst_array.ravel()
        lst_indices = np.unique(flat_lsts, return_index=True)[1]
        lsts = flat_lsts[np.sort(lst_indices)]
        pols = [polnum2str(polnum) for polnum in self.polarization_array]
        antpairs = self.get_antpairs()
        bls = [antpair + (pol,) for antpair in antpairs for pol in pols]

        times_by_bl = {antpair: np.array(self.time_array[self._blt_slices[antpair]])
                       for antpair in antpairs}
        lsts_by_bl = {antpair: np.array(self.lst_array[self._blt_slices[antpair]])
                      for antpair in antpairs}

        # explicit mapping instead of eval() over locals()
        available = {'ants': ants, 'antpos': antpos, 'freqs': freqs, 'times': times,
                     'lsts': lsts, 'pols': pols, 'antpairs': antpairs, 'bls': bls,
                     'times_by_bl': times_by_bl, 'lsts_by_bl': lsts_by_bl}
        return {meta: available[meta] for meta in self.HERAData_metas}

    def _determine_blt_slicing(self):
        '''Build self._blt_slices, a dict mapping each antenna pair to the slice of
        the baseline-time axis of the data arrays that holds that pair.

        Raises:
            NotImplementedError: if the blt indices of a pair are not evenly spaced.
        '''
        self._blt_slices = blt_slices = {}
        for ant1, ant2 in self.get_antpairs():
            indices = self.antpair2ind(ant1, ant2)
            if len(indices) == 1:
                # single integration: one-element slice (the step is never actually taken)
                start = indices[0]
                blt_slices[(ant1, ant2)] = slice(start, start + 1, self.Nblts)
                continue
            spacings = set(np.ediff1d(indices))
            if len(spacings) != 1:
                raise NotImplementedError('UVData objects with non-regular spacing of ' +
                                          'baselines in its baseline-times are not supported.')
            step = indices[1] - indices[0]
            blt_slices[(ant1, ant2)] = slice(indices[0], indices[-1] + 1, step)

            
    def _determine_pol_indexing(self):
        '''Build self._polnum_indices, a dict mapping each polarization number in
        self.polarization_array to its position along the polarization axis.'''
        self._polnum_indices = {polnum: index
                                for index, polnum in enumerate(self.polarization_array)}

            
    def _get_slice(self, data_array, key):
        '''Return the waterfall or waterfalls for a given key. Abstracts away both
        baseline ordering (by applying complex conjugation) and polarization capitalization.

        Arguments:
            data_array: numpy array of shape (Nblts, 1, Nfreq, Npol)
            key: if of the form (0,1,'xx'), return a numpy array.
                 if of the form (0,1), return a dict mapping pol strings to waterfalls.
                 if of the form 'xx', return a dict mapping ant-pair tuples to waterfalls.

        Raises:
            KeyError: if key is not one of the recognized forms, or if neither the
                baseline nor its reverse can be found.
        '''
        is_tuple = isinstance(key, tuple)
        if is_tuple and len(key) == 3:  # asking for bl-pol
            try:
                blt_slice = self._blt_slices[key[0:2]]
                pol_index = self._polnum_indices[polstr2num(key[2])]
            except KeyError:
                # fall back to the reversed baseline with the conjugate polarization
                rev_slice = self._blt_slices[key[1::-1]]
                rev_pol_index = self._polnum_indices[polstr2num(conj_pol(key[2]))]
                return np.conj(np.squeeze(data_array[rev_slice, 0, :, rev_pol_index]))
            # NOTE(review): np.squeeze drops any length-1 axis, so a single-time or
            # single-channel selection comes back 1-D rather than 2-D -- confirm intended
            return np.array(np.squeeze(data_array[blt_slice, 0, :, pol_index]))

        if is_tuple and len(key) == 2:  # asking for antpair
            pols = np.array([polnum2str(polnum) for polnum in self.polarization_array])
            by_pol = {}
            for pol in pols:
                by_pol[pol] = self._get_slice(data_array, key + (pol,))
            return by_pol

        if isinstance(key, str):  # asking for a pol
            by_antpair = {}
            for antpair in self.get_antpairs():
                by_antpair[antpair] = self._get_slice(data_array, antpair + (key,))
            return by_antpair

        raise KeyError('Unrecognized key type for slicing data.')

    def _set_slice(self, data_array, key, value):
        '''Write waterfall(s) into data_array in place. Abstracts away both baseline
        ordering (by applying complex conjugation) and polarization capitalization.

        Arguments:
            data_array: numpy array of shape (Nblts, 1, Nfreq, Npol)
            key: baseline (e.g. (0,1,'xx')), ant-pair tuple (e.g. (0,1)), or pol str (e.g. 'xx')
            value: if key is a baseline, must be an (Nint, Nfreq) numpy array;
                   if key is an ant-pair tuple, must be a dict mapping pol strings to waterfalls;
                   if key is a pol str, must be a dict mapping ant-pair tuples to waterfalls

        Raises:
            KeyError: if key is not one of the recognized forms, or if neither the
                baseline nor its reverse can be found.
        '''
        is_tuple = isinstance(key, tuple)
        if is_tuple and len(key) == 3:  # providing bl-pol
            try:
                blt_slice = self._blt_slices[key[0:2]]
                pol_index = self._polnum_indices[polstr2num(key[2])]
            except KeyError:
                # baseline is stored in the opposite order: write the conjugate there
                conjugated = np.conj(value)
                rev_slice = self._blt_slices[key[1::-1]]
                rev_pol_index = self._polnum_indices[polstr2num(conj_pol(key[2]))]
                data_array[rev_slice, 0, :, rev_pol_index] = conjugated
            else:
                data_array[blt_slice, 0, :, pol_index] = value
            return

        if is_tuple and len(key) == 2:  # providing antpair with all pols
            for pol in value.keys():
                self._set_slice(data_array, (key + (pol,)), value[pol])
            return

        if isinstance(key, str):  # providing pol with all antpairs
            for antpair in value.keys():
                self._set_slice(data_array, (antpair + (key,)), value[antpair])
            return

        raise KeyError('Unrecognized key type for slicing data.')
     
    def build_datacontainers(self):
        '''Convert the data currently loaded in this object into DataContainers.
        Each returned DataContainer also carries metadata describing the data it
        actually holds (which may be a subset of the file): antenna positions,
        frequencies, all times, all lsts, and times and lsts by baseline.

        Returns:
            data: DataContainer mapping baseline keys to complex visibility waterfalls
            flags: DataContainer mapping baseline keys to boolean flag waterfalls
            nsamples: DataContainer mapping baseline keys to integer Nsamples waterfalls
        '''
        meta = self.get_metadata_dict()

        # one ordered dict of waterfalls per internal array, filled baseline by baseline
        sources = (self.data_array, self.flag_array, self.nsample_array)
        waterfalls = [odict() for _ in sources]
        for bl in meta['bls']:
            for store, array in zip(waterfalls, sources):
                store[bl] = self._get_slice(array, bl)
        containers = [DataContainer(store) for store in waterfalls]

        # store useful metadata inside the DataContainers
        shared_attrs = ('antpos', 'freqs', 'times', 'lsts', 'times_by_bl', 'lsts_by_bl')
        for container in containers:
            for attr in shared_attrs:
                setattr(container, attr, meta[attr])

        data, flags, nsamples = containers
        return data, flags, nsamples
    
    def read(self, bls=None, polarizations=None, times=None,
             frequencies=None, freq_chans=None, read_data=True):
        '''Reads data from file. Supports partial data loading. Default: read all data in file.

        Arguments:
            bls: A list of antenna number tuples (e.g. [(0,1), (3,2)]) or a list of
                baseline 3-tuples (e.g. [(0,1,'xx'), (2,3,'yy')]) specifying baselines
                to keep in the object. For length-2 tuples, the ordering of the numbers
                within the tuple does not matter. For length-3 tuples, the polarization
                string is in the order of the two antennas. If length-3 tuples are provided,
                the polarizations argument below must be None. Ignored if read_data is False.
            polarizations: The polarizations to include when reading data into
                the object. Ignored if read_data is False.
            times: The times to include when reading data into the object.
                Ignored if read_data is False. Miriad will load then select on this axis.
            frequencies: The frequencies to include when reading data. Ignored if read_data
                is False. Miriad will load then select on this axis.
            freq_chans: The frequency channel numbers to include when reading data. Ignored
                if read_data is False. Miriad will load then select on this axis.
            read_data: Read in the visibility and flag data. If set to false, only the
                basic metadata will be read in and nothing will be returned. Results in an
                incompletely defined object (check will not pass). Default True.

        Returns:
            data: DataContainer mapping baseline keys to complex visibility waterfalls
            flags: DataContainer mapping baseline keys to boolean flag waterfalls
            nsamples: DataContainer mapping baseline keys to integer Nsamples waterfalls

        Raises:
            NotImplementedError: if read_data is False for a filetype other than uvh5
        '''
        # save last read parameters so partial_write() can address the same selection
        self.last_read_kwargs = {'bls': bls, 'polarizations': polarizations, 'times': times,
                                 'frequencies': frequencies, 'freq_chans': freq_chans}

        # load data (filetypes are compared by value; identity comparison of strings
        # only works by accident of interning)
        if self.filetype == 'uvh5':
            self.read_uvh5(self.filepaths, bls=bls, polarizations=polarizations, times=times,
                           frequencies=frequencies, freq_chans=freq_chans, read_data=read_data)
        else:
            if not read_data:
                raise NotImplementedError('reading only metadata is not implemented for ' + self.filetype)
            if self.filetype == 'miriad':
                self.read_miriad(self.filepaths, bls=bls, polarizations=polarizations)
                if any(sel is not None for sel in (times, frequencies, freq_chans)):
                    warnings.warn('miriad does not support partial loading for times and frequencies. '
                                  'Loading the file first and then performing select.')
                self.select(times=times, frequencies=frequencies, freq_chans=freq_chans)
            elif self.filetype == 'uvfits':
                self.read_uvfits(self.filepaths, bls=bls, polarizations=polarizations,
                                 times=times, frequencies=frequencies, freq_chans=freq_chans)

        # process data into DataContainers; uvh5 metadata-only reads still need the
        # index lookups so that get_metadata_dict() works
        if read_data or self.filetype == 'uvh5':
            self._determine_blt_slicing()
            self._determine_pol_indexing()
        if read_data:
            return self.build_datacontainers()

    def __getitem__(self, key):
        '''Shortcut for reading a single visibility waterfall: calls read() restricted
        to the baseline tuple key and returns that key's entry from the data container.'''
        containers = self.read(bls=key)
        data = containers[0]
        return data[key]

    def update(self, data=None, flags=None, nsamples=None):
        '''Overwrite the internal arrays (data_array, flag_array, and nsample_array)
        with the contents of the supplied DataContainers, in preparation for writing
        to disk. Any argument left as None leaves its array untouched.

        Arguments:
            data: Optional DataContainer mapping baselines to complex visibility waterfalls
            flags: Optional DataContainer mapping baselines to boolean flag waterfalls
            nsamples: Optional DataContainer mapping baselines to integer Nsamples waterfalls
        '''
        pairings = ((data, 'data_array'), (flags, 'flag_array'), (nsamples, 'nsample_array'))
        for container, array_name in pairings:
            if container is None:
                continue
            target = getattr(self, array_name)
            for bl in container.keys():
                self._set_slice(target, bl, container[bl])
                
    def partial_write(self, output_path, data=None, flags=None, nsamples=None, clobber=False):
        '''Writes part of a uvh5 file using DataContainers whose shape matches the most recent
        call to HERAData.read() in this object. Does not work for other filetypes or when the
        HERAData object is initialized with a list of files.

        Arguments:
            output_path: path to file to write uvh5 file to
            data: Optional DataContainer mapping baselines to complex visibility waterfalls
            flags: Optional DataContainer mapping baselines to boolean flag waterfalls
            nsamples: Optional DataContainer mapping baselines to integer Nsamples waterfalls
            clobber: if True, overwrites existing file at output_path

        Raises:
            NotImplementedError: if the filetype is not uvh5 or the object was list-loaded
            ValueError: if output_path is not a string
        '''
        # Type verifications (compare the filetype by value, not identity)
        if self.filetype != 'uvh5':
            raise NotImplementedError('Partial writing for filetype ' + self.filetype + ' has not been implemented.')
        if len(self.filepaths) > 1:
            raise NotImplementedError('Partial writing for list-loaded HERAData objects has not been implemented.')
        if not isinstance(output_path, str):
            raise ValueError('output_path must be a string path file to write.')

        # get writer or initialize new writer if necessary; the writer is a second
        # HERAData object holding the full-file metadata, cached per output path
        if output_path in self.writers:
            hd_writer = self.writers[output_path]
        else:
            hd_writer = HERAData(self.filepaths[0])
            hd_writer.initialize_uvh5_file(output_path, clobber=clobber)
            self.writers[output_path] = hd_writer

        # make a copy of this object and then update the relevant arrays using DataContainers,
        # so that the arrays held by self are not modified by the write
        this = copy.deepcopy(self)
        this.update(data=data, flags=flags, nsamples=nsamples)
        hd_writer.write_uvh5_part(output_path, this.data_array, this.flag_array,
                                  this.nsample_array, **self.last_read_kwargs)
        
    def iterate_over_bls(self, Nbls=1, bls=None):
        '''Produces a generator that iteratively yields successive calls to
        HERAData.read() by baseline or group of baselines.

        Arguments:
            Nbls: number of baselines to load at once.
            bls: optional user-provided list of baselines to iterate over.
                Default: use self.bls (which only works for uvh5).

        Yields:
            data, flags, nsamples: DataContainers (see HERAData.read() for more info).

        Raises:
            NotImplementedError: if bls is None and the filetype is not uvh5
        '''
        if bls is None:
            if self.filetype != 'uvh5':  # compare by value, not identity
                raise NotImplementedError('Baseline iteration for filetype ' + self.filetype +
                                          ' without setting bls has not been implemented.')
            bls = self.bls
            if isinstance(bls, dict):  # multiple files: take the union over all files
                bls = list(set(bl for file_bls in bls.values() for bl in file_bls))
            bls = sorted(bls)
        for i in range(0, len(bls), Nbls):
            yield self.read(bls=bls[i:i + Nbls])

            
    def iterate_over_freqs(self, Nchans=1, freqs=None):
        '''Produces a generator that iteratively yields successive calls to
        HERAData.read() by frequency channel or group of contiguous channels.

        Arguments:
            Nchans: number of frequencies to load at once.
            freqs: optional user-provided list of frequencies to iterate over.
                Default: use self.freqs (which only works for uvh5).

        Yields:
            data, flags, nsamples: DataContainers (see HERAData.read() for more info).

        Raises:
            NotImplementedError: if freqs is None and the filetype is not uvh5
        '''
        if freqs is None:
            if self.filetype != 'uvh5':  # compare by value, not identity
                raise NotImplementedError('Frequency iteration for filetype ' + self.filetype +
                                          '  without setting freqs has not been implemented.')
            freqs = self.freqs
            if isinstance(freqs, dict):  # multiple files: unique frequencies over all files
                # concatenate explicitly: dict.values() is a view on Python 3, and the
                # per-file arrays need not all have the same length
                freqs = np.unique(np.concatenate(list(freqs.values())))
        for i in range(0, len(freqs), Nchans):
            yield self.read(frequencies=freqs[i:i + Nchans])

    def iterate_over_times(self, Nints=1, times=None):
        '''Produces a generator that iteratively yields successive calls to
        HERAData.read() by time or group of contiguous times.

        Arguments:
            Nints: number of integrations to load at once.
            times: optional user-provided list of times to iterate over.
                Default: use self.times (which only works for uvh5).

        Yields:
            data, flags, nsamples: DataContainers (see HERAData.read() for more info).

        Raises:
            NotImplementedError: if times is None and the filetype is not uvh5
        '''
        if times is None:
            if self.filetype != 'uvh5':  # compare by value, not identity
                raise NotImplementedError('Time iteration for filetype ' + self.filetype +
                                          '  without setting times has not been implemented.')
            times = self.times
            if isinstance(times, dict):  # multiple files: unique times over all files
                # concatenate explicitly: dict.values() is a view on Python 3, and the
                # per-file arrays need not all have the same length
                times = np.unique(np.concatenate(list(times.values())))
        for i in range(0, len(times), Nints):
            yield self.read(times=times[i:i + Nints])

In [23]:
from hera_cal.data import DATA_PATH
import os

class Test_HERAData(unittest.TestCase):
    '''Unit tests for HERAData: initialization, metadata, slicing, reading, updating,
    partial writing, and iteration over baselines, frequencies, and times.'''

    def setUp(self):
        '''Build the paths to the uvh5, miriad, uvfits, and four-pol miriad test files.'''
        self.uvh5_1 = os.path.join(DATA_PATH, "zen.2458116.61019.xx.HH.h5XRS_downselected")
        self.uvh5_2 = os.path.join(DATA_PATH, "zen.2458116.61765.xx.HH.h5XRS_downselected")
        self.miriad_1 = os.path.join(DATA_PATH, "zen.2458043.12552.xx.HH.uvORA")
        self.miriad_2 = os.path.join(DATA_PATH, "zen.2458043.13298.xx.HH.uvORA")
        self.uvfits = os.path.join(DATA_PATH, 'zen.2458043.12552.xx.HH.uvA.vis.uvfits')
        self.four_pol = [os.path.join(DATA_PATH, 'zen.2457698.40355.{}.HH.uvcA'.format(pol))
                         for pol in ['xx', 'yy', 'xy', 'yx']]

    def test_init(self):
        '''Metadata is set on init for uvh5, left as None for miriad/uvfits; bad inputs raise.'''
        # single uvh5 file
        hd = HERAData(self.uvh5_1)
        self.assertEqual(hd.filepaths, [self.uvh5_1])
        for meta in hd.HERAData_metas:
            self.assertIsNotNone(getattr(hd, meta))
        self.assertEqual(len(hd.freqs), 1024)
        self.assertEqual(len(hd.bls), 3)
        self.assertEqual(len(hd.times), 60)
        self.assertEqual(len(hd.lsts), 60)
        self.assertEqual(hd.writers, {})

        # multiple uvh5 files: metadata becomes a dict keyed by file path
        files = [self.uvh5_1, self.uvh5_2]
        hd = HERAData(files)
        self.assertEqual(hd.filepaths, files)
        for meta in hd.HERAData_metas:
            self.assertIsNotNone(getattr(hd, meta))
        for f in files:
            self.assertEqual(len(hd.freqs[f]), 1024)
            self.assertEqual(len(hd.bls[f]), 3)
            self.assertEqual(len(hd.times[f]), 60)
            self.assertEqual(len(hd.lsts[f]), 60)
        self.assertFalse(hasattr(hd, 'writers'))

        # miriad
        hd = HERAData(self.miriad_1, filetype='miriad')
        self.assertEqual(hd.filepaths, [self.miriad_1])
        for meta in hd.HERAData_metas:
            self.assertIsNone(getattr(hd, meta))

        # uvfits
        hd = HERAData(self.uvfits, filetype='uvfits')
        self.assertEqual(hd.filepaths, [self.uvfits])
        for meta in hd.HERAData_metas:
            self.assertIsNone(getattr(hd, meta))

        # test errors
        with self.assertRaises(TypeError):
            hd = HERAData([1, 2])
        with self.assertRaises(ValueError):
            hd = HERAData(None)
        with self.assertRaises(NotImplementedError):
            hd = HERAData(self.uvh5_1, 'not a real type')
        with self.assertRaises(IOError):
            hd = HERAData('fake path')

    def test_clear(self):
        '''clear() drops the data/flag/nsample arrays but keeps file paths and metadata.'''
        hd = HERAData(self.uvh5_1)
        hd.read()
        hd.clear()
        self.assertIsNone(hd.data_array)
        self.assertIsNone(hd.flag_array)
        self.assertIsNone(hd.nsample_array)
        self.assertEqual(hd.filepaths, [self.uvh5_1])
        for meta in hd.HERAData_metas:
            self.assertIsNotNone(getattr(hd, meta))
        self.assertEqual(len(hd.freqs), 1024)
        self.assertEqual(len(hd.bls), 3)
        self.assertEqual(len(hd.times), 60)
        self.assertEqual(len(hd.lsts), 60)
        self.assertEqual(hd.writers, {})

    def test_get_metadata_dict(self):
        '''get_metadata_dict() returns every HERAData_metas entry with the expected sizes.'''
        hd = HERAData(self.uvh5_1)
        metas = hd.get_metadata_dict()
        for meta in hd.HERAData_metas:
            self.assertIn(meta, metas)
        self.assertEqual(len(metas['freqs']), 1024)
        self.assertEqual(len(metas['bls']), 3)
        self.assertEqual(len(metas['times']), 60)
        self.assertEqual(len(metas['lsts']), 60)
        np.testing.assert_array_equal(metas['times'], np.unique(list(metas['times_by_bl'].values())))
        np.testing.assert_array_equal(metas['lsts'], np.unique(list(metas['lsts_by_bl'].values())))

    def test_determine_blt_slicing(self):
        '''_blt_slices holds slices matching each baseline's rows; irregular spacing raises.'''
        hd = HERAData(self.uvh5_1)
        for s in hd._blt_slices.values():
            self.assertIsInstance(s, slice)
        for bl, s in hd._blt_slices.items():
            np.testing.assert_array_equal(np.arange(180)[np.logical_and(hd.ant_1_array == bl[0],
                                          hd.ant_2_array == bl[1])], np.arange(180)[s])
        # test check for non-regular spacing
        hd.ant_1_array = hd.ant_2_array
        with self.assertRaises(NotImplementedError):
            hd._determine_blt_slicing()

    def test_determine_pol_indexing(self):
        '''_polnum_indices maps polarization numbers to indices for one- and four-pol data.'''
        hd = HERAData(self.uvh5_1)
        self.assertEqual(hd._polnum_indices, {-5: 0})
        hd = HERAData(self.four_pol, filetype='miriad')
        hd.read(bls=[(53, 53)])
        self.assertEqual(hd._polnum_indices, {-8: 3, -7: 2, -6: 1, -5: 0})

    def test_get_slice(self):
        '''_get_slice() agrees with get_data() for baseline, antenna-pair, and pol keys.'''
        hd = HERAData(self.uvh5_1)
        hd.read()
        for bl in hd.bls:
            np.testing.assert_array_equal(hd._get_slice(hd.data_array, bl), hd.get_data(bl))
        np.testing.assert_array_equal(hd._get_slice(hd.data_array, (54, 53, 'XX')),
                                      hd.get_data((54, 53, 'XX')))
        np.testing.assert_array_equal(hd._get_slice(hd.data_array, (53, 54))['XX'],
                                      hd.get_data((53, 54, 'XX')))
        np.testing.assert_array_equal(hd._get_slice(hd.data_array, 'XX')[(53, 54)],
                                      hd.get_data((53, 54, 'XX')))
        with self.assertRaises(KeyError):
            hd._get_slice(hd.data_array, None)

        hd = HERAData(self.four_pol, filetype='miriad')
        d, f, n = hd.read(bls=[(80, 81)])
        for p in d.pols():
            np.testing.assert_array_almost_equal(hd._get_slice(hd.data_array, (80, 81, p)),
                                                 hd.get_data((80, 81, p)).flatten())
            try:
                np.testing.assert_array_almost_equal(hd._get_slice(hd.data_array, (81, 80, p)),
                                                     hd.get_data((81, 80, p)).flatten())
            except Exception:  # this is only here until pyuvdata fixes issue #398
                # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
                np.testing.assert_array_almost_equal(hd._get_slice(hd.data_array, (81, 80, p)),
                                                     hd.get_data((81, 80, p[::-1])).flatten())

    def test_set_slice(self):
        '''_set_slice() writes visibilities for baseline, antenna-pair, and pol keys.'''
        hd = HERAData(self.uvh5_1)
        hd.read()
        np.random.seed(21)

        for bl in hd.bls:
            new_vis = np.random.randn(60, 1024) + np.random.randn(60, 1024) * 1.0j
            hd._set_slice(hd.data_array, bl, new_vis)
            np.testing.assert_array_almost_equal(new_vis, hd.get_data(bl))

        # setting the reversed baseline should store the conjugate
        new_vis = np.random.randn(60, 1024) + np.random.randn(60, 1024) * 1.0j
        hd._set_slice(hd.data_array, (54, 53, 'xx'), new_vis)
        np.testing.assert_array_almost_equal(np.conj(new_vis), hd.get_data((53, 54, 'xx')))

        new_vis = np.random.randn(60, 1024) + np.random.randn(60, 1024) * 1.0j
        hd._set_slice(hd.data_array, (53, 54), {'xx': new_vis})
        np.testing.assert_array_almost_equal(new_vis, hd.get_data((53, 54, 'xx')))

        new_vis = np.random.randn(60, 1024) + np.random.randn(60, 1024) * 1.0j
        to_set = {(53, 54): new_vis, (54, 54): 2 * new_vis, (53, 53): 3 * new_vis}
        hd._set_slice(hd.data_array, 'XX', to_set)
        np.testing.assert_array_almost_equal(new_vis, hd.get_data((53, 54, 'xx')))

        with self.assertRaises(KeyError):
            hd._set_slice(hd.data_array, None, None)

    def test_build_datacontainers(self):
        '''read() returns DataContainers whose contents and metadata match the HERAData object.'''
        hd = HERAData(self.uvh5_1)
        d, f, n = hd.read()
        for bl in hd.bls:
            np.testing.assert_array_almost_equal(d[bl], hd.get_data(bl))
            np.testing.assert_array_almost_equal(f[bl], hd.get_flags(bl))
            np.testing.assert_array_almost_equal(n[bl], hd.get_nsamples(bl))
        for dc in [d, f, n]:
            self.assertIsInstance(dc, DataContainer)
            for k in dc.antpos.keys():
                self.assertTrue(np.all(dc.antpos[k] == hd.antpos[k]))
            self.assertTrue(np.all(dc.freqs == hd.freqs))
            self.assertTrue(np.all(dc.times == hd.times))
            self.assertTrue(np.all(dc.lsts == hd.lsts))
            for k in dc.times_by_bl.keys():
                self.assertTrue(np.all(dc.times_by_bl[k] == hd.times_by_bl[k]))
                self.assertTrue(np.all(dc.lsts_by_bl[k] == hd.lsts_by_bl[k]))

    def test_read(self):
        '''read() honors baseline/frequency/time selections for uvh5, miriad, and uvfits.'''
        # uvh5
        hd = HERAData(self.uvh5_1)
        d, f, n = hd.read(bls=(53, 54, 'xx'), frequencies=hd.freqs[0:100], times=hd.times[0:10])
        self.assertEqual(hd.last_read_kwargs['bls'], (53, 54, 'xx'))
        self.assertEqual(hd.last_read_kwargs['polarizations'], None)
        for dc in [d, f, n]:
            self.assertEqual(len(dc), 1)
            # list(...) keeps the comparison valid whether keys() is a list or a view
            self.assertEqual(list(dc.keys()), [(53, 54, 'XX')])
            self.assertEqual(dc[53, 54, 'xx'].shape, (10, 100))
        with self.assertRaises(ValueError):
            d, f, n = hd.read(polarizations=['xy'])

        # miriad
        hd = HERAData(self.miriad_1, filetype='miriad')
        d, f, n = hd.read()
        hd = HERAData(self.miriad_1, filetype='miriad')
        with warnings.catch_warnings(record=True) as w:
            d, f, n = hd.read(bls=(52, 53), polarizations=['XX'], frequencies=d.freqs[0:30], times=d.times[0:10])
            self.assertEqual(len(w), 1)
        self.assertEqual(hd.last_read_kwargs['polarizations'], ['XX'])
        for dc in [d, f, n]:
            self.assertEqual(len(dc), 1)
            self.assertEqual(list(dc.keys()), [(52, 53, 'XX')])
            self.assertEqual(dc[52, 53, 'xx'].shape, (10, 30))
        with self.assertRaises(NotImplementedError):
            d, f, n = hd.read(read_data=False)

        # uvfits
        hd = HERAData(self.uvfits, filetype='uvfits')
        d, f, n = hd.read(bls=(0, 1, 'xx'), freq_chans=range(10))
        self.assertEqual(hd.last_read_kwargs['freq_chans'], range(10))
        for dc in [d, f, n]:
            self.assertEqual(len(dc), 1)
            self.assertEqual(list(dc.keys()), [(0, 1, 'XX')])
            self.assertEqual(dc[0, 1, 'xx'].shape, (60, 10))
        with self.assertRaises(NotImplementedError):
            d, f, n = hd.read(read_data=False)

    def test_getitem(self):
        '''Indexing a HERAData object by baseline matches get_data().'''
        hd = HERAData(self.uvh5_1)
        hd.read()
        for bl in hd.bls:
            np.testing.assert_array_almost_equal(hd[bl], hd.get_data(bl))

    def test_update(self):
        '''update() pushes modified data, flags, and nsamples back into the object.'''
        hd = HERAData(self.uvh5_1)
        d, f, n = hd.read()
        for bl in hd.bls:
            d[bl] *= (2.0 + 1.0j)
            f[bl] = np.logical_not(f[bl])
            n[bl] += 1
        hd.update(data=d, flags=f, nsamples=n)
        d2, f2, n2 = hd.build_datacontainers()
        for bl in hd.bls:
            np.testing.assert_array_almost_equal(d[bl], d2[bl])
            np.testing.assert_array_equal(f[bl], f2[bl])
            np.testing.assert_array_equal(n[bl], n2[bl])

    def test_partial_write(self):
        '''partial_write() builds an output file baseline by baseline; unsupported uses raise.'''
        hd = HERAData(self.uvh5_1)
        self.assertEqual(hd.writers, {})
        d, f, n = hd.read(bls=hd.bls[0])
        self.assertEqual(hd.last_read_kwargs['bls'], (53, 53, 'XX'))
        d[(53, 53, 'XX')] *= (2.0 + 1.0j)
        hd.partial_write('out.h5', data=d, clobber=True)
        self.assertTrue('out.h5' in hd.writers)
        self.assertIsInstance(hd.writers['out.h5'], HERAData)
        writer = hd.writers['out.h5']
        for meta in hd.HERAData_metas:
            expected, written = getattr(hd, meta), getattr(writer, meta)
            # Mapping-like metadata is compared entry by entry; everything else is
            # compared directly. This replaces a bare except used for the same dispatch.
            if hasattr(expected, 'keys'):
                for k in expected.keys():
                    np.testing.assert_array_equal(expected[k], written[k])
            else:
                np.testing.assert_array_equal(expected, written)

        d, f, n = hd.read(bls=hd.bls[1])
        self.assertEqual(hd.last_read_kwargs['bls'], (53, 54, 'XX'))
        d[(53, 54, 'XX')] *= (2.0 + 1.0j)
        hd.partial_write('out.h5', data=d, clobber=True)

        d, f, n = hd.read(bls=hd.bls[2])
        self.assertEqual(hd.last_read_kwargs['bls'], (54, 54, 'XX'))
        d[(54, 54, 'XX')] *= (2.0 + 1.0j)
        hd.partial_write('out.h5', data=d, clobber=True)

        # the finished file should hold the scaled data with untouched flags/nsamples
        hd = HERAData(self.uvh5_1)
        d, f, n = hd.read()
        hd2 = HERAData('out.h5')
        d2, f2, n2 = hd2.read()
        for bl in hd.bls:
            np.testing.assert_array_almost_equal(d[bl] * (2.0 + 1.0j), d2[bl])
            np.testing.assert_array_equal(f[bl], f2[bl])
            np.testing.assert_array_equal(n[bl], n2[bl])
        os.remove('out.h5')

        # test errors
        hd = HERAData(self.miriad_1, filetype='miriad')
        with self.assertRaises(NotImplementedError):
            hd.partial_write('out.uv')
        hd = HERAData([self.uvh5_1, self.uvh5_2])
        with self.assertRaises(NotImplementedError):
            hd.partial_write('out.h5')
        hd = HERAData(self.uvh5_1)
        with self.assertRaises(ValueError):
            hd.partial_write(None)

    def test_iterate_over_bls(self):
        '''iterate_over_bls() yields the expected baseline groups and array shapes.'''
        hd = HERAData(self.uvh5_1)
        for (d, f, n) in hd.iterate_over_bls(Nbls=2):
            for dc in (d, f, n):
                self.assertTrue(len(dc.keys()) == 1 or len(dc.keys()) == 2)
                self.assertEqual(list(dc.values())[0].shape, (60, 1024))

        hd = HERAData([self.uvh5_1, self.uvh5_2])
        for (d, f, n) in hd.iterate_over_bls():
            for dc in (d, f, n):
                # check each container in turn (data, flags, nsamples), not just the data
                self.assertEqual(len(dc.keys()), 1)
                self.assertEqual(list(dc.values())[0].shape, (120, 1024))

        hd = HERAData(self.miriad_1, filetype='miriad')
        d, f, n = next(hd.iterate_over_bls(bls=[(52, 53, 'xx')]))
        self.assertEqual(list(d.keys()), [(52, 53, 'XX')])
        with self.assertRaises(NotImplementedError):
            next(hd.iterate_over_bls())

    def test_iterate_over_freqs(self):
        '''iterate_over_freqs() yields the expected channel groups and array shapes.'''
        hd = HERAData(self.uvh5_1)
        for (d, f, n) in hd.iterate_over_freqs(Nchans=256):
            for dc in (d, f, n):
                self.assertEqual(len(dc.keys()), 3)
                self.assertEqual(list(dc.values())[0].shape, (60, 256))

        hd = HERAData([self.uvh5_1, self.uvh5_2])
        for (d, f, n) in hd.iterate_over_freqs(Nchans=512):
            for dc in (d, f, n):
                self.assertEqual(len(dc.keys()), 3)
                self.assertEqual(list(dc.values())[0].shape, (120, 512))

        hd = HERAData(self.uvfits, filetype='uvfits')
        d, f, n = hd.read()
        d, f, n = next(hd.iterate_over_freqs(Nchans=2, freqs=d.freqs[0:2]))
        for value in d.values():
            self.assertEqual(value.shape, (60, 2))
        # without explicit freqs, non-uvh5 frequency iteration is not implemented
        with self.assertRaises(NotImplementedError):
            next(hd.iterate_over_freqs())

    def test_iterate_over_times(self):
        '''iterate_over_times() yields the expected integration groups and array shapes.'''
        hd = HERAData(self.uvh5_1)
        for (d, f, n) in hd.iterate_over_times(Nints=20):
            for dc in (d, f, n):
                self.assertEqual(len(dc.keys()), 3)
                self.assertEqual(list(dc.values())[0].shape, (20, 1024))

        # split the file in frequency, then iterate over times across both halves
        hd.read(frequencies=hd.freqs[0:512])
        hd.write_uvh5('out1.h5', clobber=True)
        hd.read(frequencies=hd.freqs[512:])
        hd.write_uvh5('out2.h5', clobber=True)
        hd = HERAData(['out1.h5', 'out2.h5'])
        for (d, f, n) in hd.iterate_over_times(Nints=30):
            for dc in (d, f, n):
                self.assertEqual(len(dc.keys()), 3)
                self.assertEqual(list(dc.values())[0].shape, (30, 1024))
        os.remove('out1.h5')
        os.remove('out2.h5')

        hd = HERAData(self.uvfits, filetype='uvfits')
        d, f, n = hd.read()
        d, f, n = next(hd.iterate_over_times(Nints=2, times=d.times[0:2]))
        for value in d.values():
            self.assertEqual(value.shape, (2, 64))
        with self.assertRaises(NotImplementedError):
            next(hd.iterate_over_times())


        
# Run the suite in-process: the placeholder argv stops unittest from parsing the
# host process's own command-line arguments, and exit=False stops it from calling
# sys.exit() once the tests finish.
if __name__ == '__main__':
    unittest.main(exit=False, argv=['first-arg-is-ignored'])

...............
----------------------------------------------------------------------
Ran 15 tests in 5.460s

OK
