diff --git a/hera_pspec/pspecdata.py b/hera_pspec/pspecdata.py
index ebc42e7d..fcbdc421 100644
--- a/hera_pspec/pspecdata.py
+++ b/hera_pspec/pspecdata.py
@@ -1,47 +1,87 @@
 import numpy as np
 import aipy
+import pyuvdata
 from .utils import hash, cov
 #from utils import hash, cov
 
 class PSpecData(object):
 
-    def __init__(self, dsets=None, wgts=None, bls=None): #, lmin=None, lmode=None):
+    def __init__(self, dsets=[], wgts=[]):
         """
         Object to store multiple sets of UVData visibilities and perform
         operations such as power spectrum estimation on them.
 
         Parameters
         ----------
-        dsets : List of UVData objects
+        dsets : List of UVData objects, optional
             List of UVData objects containing the data that will be used to
-            compute the power spectrum.
+            compute the power spectrum. Default: Empty list.
 
-        wgts : List of UVData objects
-            List of UVData objects containing weights for the input data.
-
-        bls : list of tuple, optional
-            List of baselines (antenna pairs) that should be used in the power
-            spectrum calculation. If not set, all baselines will be used.
-            Default: None.
+        wgts : List of UVData objects, optional
+            List of UVData objects containing weights for the input data.
+            Default: Empty list.
         """
         self.clear_cov_cache() # Covariance matrix cache
+        self.dsets = []; self.wgts = []
+        self.Nfreq = None
+
+        # Store the input UVData objects if specified
+        if len(dsets) > 0:
+            self.add(dsets, wgts)
+
+    def add(self, dsets, wgts):
+        """
+        Add one or more datasets to the collection in this PSpecData object.
+
+        Parameters
+        ----------
+        dsets : UVData or list of UVData
+            UVData object or list of UVData objects containing data to add to
+            the collection.
+
+        wgts : UVData or list of UVData
+            UVData object or list of UVData objects containing weights to add
+            to the collection. Must be the same length as dsets.
+        """
+        # Convert input args to lists if possible
+        if isinstance(dsets, pyuvdata.UVData): dsets = [dsets,]
+        if isinstance(wgts, pyuvdata.UVData): wgts = [wgts,]
+        if isinstance(dsets, tuple): dsets = list(dsets)
+        if isinstance(wgts, tuple): wgts = list(wgts)
+
+        # Only allow UVData or lists
+        if not isinstance(dsets, list) or not isinstance(wgts, list):
+            raise TypeError("dsets and wgts must be UVData or lists of UVData")
+
+        # Make sure the number of weights matches the number of datasets
+        assert len(dsets) == len(wgts)
+
+        # Check that everything is a UVData object
+        for d, w in zip(dsets, wgts):
+            if not isinstance(d, pyuvdata.UVData) \
+               or not isinstance(w, pyuvdata.UVData):
+                raise TypeError("Only UVData objects can be used as datasets "
+                                "or weights.")
+
+        # Append to list
+        self.dsets += dsets
+        self.wgts += wgts
+
+        # Store the number of frequency channels
+        self.Nfreq = self.dsets[0].Nfreqs
+
+    def validate_datasets(self):
+        """
+        Validate stored datasets and weights to make sure they are consistent
+        with one another (e.g. that they have the same shape, baselines etc.).
+        """
         # Sanity checks on input data
-        assert len(dsets) > 1
-        assert len(wgts) > 0
-        assert len(dsets) == len(wgts)
-        # FIXME: Should allow one set of weights to be specified for all data
+        assert len(self.dsets) > 1
+        assert len(self.dsets) == len(self.wgts)
 
         # Check if data are all the same shape
-        self.Nfreq = None
-        if dsets is not None:
-            nfreqs = [d.Nfreqs for d in dsets]
-            assert np.all_equal(nfreqs)
-            self.Nfreq = nfreqs[0]
-
-        # Store the UVData
-        self.dsets = dsets
-        self.wgts = wgts
+        nfreqs = [d.Nfreqs for d in self.dsets]
+        assert np.all(np.array(nfreqs) == nfreqs[0])
@@ -49,10 +89,10 @@ def clear_cov_cache(self, keys=None):
 
         Parameters
         ----------
-        keys : TODO, optional
-            TODO. Default: None.
+        keys : list, optional
+            List of keys to remove from the covariance matrix cache. If None,
+            all keys are removed. Default: None.
         """
-        raise NotImplementedError() # FIXME
         if keys is None:
             self._C, self._Cempirical, self._I, self._iC = {}, {}, {}, {}
             self._iCt = {}
@@ -226,17 +266,6 @@ def iC(self, k):
             # FIXME: Is series of dot products quicker?
             self.set_iC({k:np.einsum('ij,j,jk', V.T, 1./S, U.T)})
         return self._iC[k]
-
-        #
-        # If t is provided, calculate iC for the provided time index, including flagging
-        # XXX this does not respect manual setting of iC with ds.set_iC
-        #UserWarning("This does not respect manual setting of iC with ds.set_iC")
-        #w = self.w[k][:,t:t+1]
-        #m = hash(w)
-        #if not self._iCt.has_key(k): self._iCt[k] = {}
-        #if not self._iCt[k].has_key(m):
-        #    self._iCt[k][m] = np.linalg.pinv(self.C(k,t), rcond=rcond)
-        #return self._iCt[k][m]
 
     def set_iC(self, d):
         """
@@ -281,9 +310,6 @@ def q_hat(self, k1, k2, use_identity=True, use_fft=True):
         q_hat : array_like
             Unnormalized bandpowers
         """
-        # FIXME: Should perform sanity checks to make sure keys exist in both
-        # datasets
-
         # Whether to use look-up fn. for identity or inverse covariance matrix
         icov_fn = self.I if use_identity else self.iC
 
@@ -562,6 +588,10 @@ def pspec(self, keys, weights='none'):
             and baselines specified in 'keys'.
         """
         #FIXME: Define sensible grouping behaviors.
+        #FIXME: Check that requested keys exist in all datasets
+
+        # Validate the input data to make sure it's sensible
+        self.validate_datasets()
 
         pvs = []
         for k, key1 in enumerate(keys):
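
For context, a minimal usage sketch of the workflow this patch sets up: construct a `PSpecData`, append datasets with `add()`, and rely on `validate_datasets()` (which `pspec()` now calls) to catch inconsistent inputs. The file paths are hypothetical, and passing deep copies of the data as weights is only a placeholder for real weight objects:

```python
import copy
import pyuvdata
from hera_pspec.pspecdata import PSpecData

# Load two datasets from hypothetical miriad files; read_miriad matches the
# pyuvdata API of this era (newer versions also offer a generic read()).
uvd1 = pyuvdata.UVData()
uvd1.read_miriad("zen.even.uv")   # hypothetical file path
uvd2 = pyuvdata.UVData()
uvd2.read_miriad("zen.odd.uv")    # hypothetical file path

# Construct with initial datasets; weights here are placeholder copies
ds = PSpecData(dsets=[uvd1, uvd2],
               wgts=[copy.deepcopy(uvd1), copy.deepcopy(uvd2)])

# Further datasets can be appended after construction, e.g.:
# ds.add(uvd3, wgts3)

# Raises AssertionError/TypeError if the collection is inconsistent;
# pspec() invokes this automatically before computing bandpowers.
ds.validate_datasets()
```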