
Commit

Add methods to add new data, improve validation
philbull committed Jan 13, 2018
1 parent 26bcf1e commit 918126d
Showing 1 changed file with 68 additions and 37 deletions.
105 changes: 68 additions & 37 deletions hera_pspec/pspecdata.py
@@ -1,58 +1,99 @@
import numpy as np
import aipy
import pyuvdata
from .utils import hash, cov
#from utils import hash, cov

class PSpecData(object):

-    def __init__(self, dsets=None, wgts=None, bls=None): #, lmin=None, lmode=None):
+    def __init__(self, dsets=[], wgts=[]):
        """
        Object to store multiple sets of UVData visibilities and perform
        operations such as power spectrum estimation on them.
        Parameters
        ----------
-        dsets : List of UVData objects
+        dsets : List of UVData objects, optional
            List of UVData objects containing the data that will be used to
-            compute the power spectrum.
+            compute the power spectrum. Default: Empty list.
-        wgts : List of UVData objects
-            List of UVData objects containing weights for the input data.
-        bls : list of tuple, optional
-            List of baselines (antenna pairs) that should be used in the power
-            spectrum calculation. If not set, all baselines will be used.
-            Default: None.
+        wgts : List of UVData objects, optional
+            List of UVData objects containing weights for the input data.
+            Default: Empty list.
        """
        self.clear_cov_cache() # Covariance matrix cache
+        self.dsets = []; self.wgts = []
+        self.Nfreq = None
+
+        # Store the input UVData objects if specified
+        if len(dsets) > 0:
+            self.add(dsets, wgts)

+    def add(self, dsets, wgts):
+        """
+        Add a dataset to the collection in this PSpecData object.
+        Parameters
+        ----------
+        dsets : UVData or list
+            UVData object or list of UVData objects containing data to add to
+            the collection.
+        wgts : UVData or list
+            UVData object or list of UVData objects containing weights to add
+            to the collection. Must be the same length as dsets.
+        """
+        # Convert input args to lists if possible
+        if isinstance(dsets, pyuvdata.UVData): dsets = [dsets,]
+        if isinstance(wgts, pyuvdata.UVData): wgts = [wgts,]
+        if isinstance(dsets, tuple): dsets = list(dsets)
+        if isinstance(wgts, tuple): wgts = list(wgts)
+
+        # Only allow UVData or lists
+        if not isinstance(dsets, list) or not isinstance(wgts, list):
+            raise TypeError("dsets and wgts must be UVData or lists of UVData")
+
+        # Make sure enough weights were specified
+        assert(len(dsets) == len(wgts))
+
+        # Check that everything is a UVData object
+        for d, w in zip(dsets, wgts):
+            if not isinstance(d, pyuvdata.UVData) \
+              or not isinstance(w, pyuvdata.UVData):
+                raise TypeError("Only UVData objects can be used as datasets "
+                                "or weights.")
+
+        # Append to list
+        self.dsets += dsets
+        self.wgts += wgts
+
+        # Store no. frequencies
+        self.Nfreq = self.dsets[0].Nfreqs
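
A minimal usage sketch of the new interface, assuming the miriad file names are hypothetical placeholders (any files readable by pyuvdata would do):

    import pyuvdata
    from hera_pspec.pspecdata import PSpecData

    # Load two visibility datasets and matching weights (hypothetical files)
    d1 = pyuvdata.UVData(); d1.read_miriad('even.uv')
    d2 = pyuvdata.UVData(); d2.read_miriad('odd.uv')
    w1 = pyuvdata.UVData(); w1.read_miriad('even.wgts.uv')
    w2 = pyuvdata.UVData(); w2.read_miriad('odd.wgts.uv')

    # Pass everything to the constructor in one go...
    ds = PSpecData(dsets=[d1, d2], wgts=[w1, w2])

    # ...or build the collection up incrementally with add()
    ds = PSpecData()
    ds.add(d1, w1)
    ds.add(d2, w2)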

+    def validate_datasets(self):
+        """
+        Validate stored datasets and weights to make sure they are consistent
+        with one another (e.g. have the same shape, baselines etc.).
+        """
        # Sanity checks on input data
-        assert len(dsets) > 1
-        assert len(wgts) > 0
-        assert len(dsets) == len(wgts)
        # FIXME: Should allow one set of weights to be specified for all data
+        assert len(self.dsets) > 1
+        assert len(self.dsets) == len(self.wgts)

-        # Check if data are all the same shape
-        self.Nfreq = None
-        if dsets is not None:
-            nfreqs = [d.Nfreqs for d in dsets]
-            assert np.all_equal(nfreqs)
-            self.Nfreq = nfreqs[0]
-
-        # Store the UVData
-        self.dsets = dsets
-        self.wgts = wgts
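
Validation is assertion-based, so a collection with fewer than two datasets now fails fast. A short sketch, reusing the objects from the example above:

    ds = PSpecData()
    ds.add(d1, w1)
    ds.validate_datasets()   # AssertionError: at least two datasets required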

    def clear_cov_cache(self, keys=None):
        """
        Clear stored covariance data (or some subset of it).
        Parameters
        ----------
-        keys : TODO, optional
-            TODO. Default: None.
+        keys : list, optional
+            List of keys to remove from covariance matrix cache. If 'None', all
+            keys will be removed. Default: None.
        """
-        raise NotImplementedError() # FIXME
+        if keys is None:
+            self._C, self._Cempirical, self._I, self._iC = {}, {}, {}, {}
+            self._iCt = {}
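
The cache is a set of plain dicts, so the keys=None path shown above flushes every cached product by re-binding them. The per-key branch is collapsed in this view, so its behaviour below is an assumption:

    ds.clear_cov_cache()                  # drop all cached C, I and C^-1 products
    ds.clear_cov_cache(keys=[some_key])   # per-key removal (branch not shown here)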
@@ -226,17 +267,6 @@ def iC(self, k):
        # FIXME: Is series of dot products quicker?
        self.set_iC({k:np.einsum('ij,j,jk', V.T, 1./S, U.T)})
        return self._iC[k]

-        #
-        # If t is provided, calculate iC for the provided time index, including flagging
-        # XXX this does not respect manual setting of iC with ds.set_iC
-        #UserWarning("This does not respect manual setting of iC with ds.set_iC")
-        #w = self.w[k][:,t:t+1]
-        #m = hash(w)
-        #if not self._iCt.has_key(k): self._iCt[k] = {}
-        #if not self._iCt[k].has_key(m):
-        #    self._iCt[k][m] = np.linalg.pinv(self.C(k,t), rcond=rcond)
-        #return self._iCt[k][m]
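
For reference, the einsum call retained in iC() assembles a pseudo-inverse from an SVD. A self-contained sketch of the same computation, keeping the U, S, V naming from the diff (note that np.linalg.svd actually returns V-transposed as its third output):

    import numpy as np

    C = np.cov(np.random.randn(8, 32))          # stand-in covariance matrix
    U, S, V = np.linalg.svd(C)                  # here V is really V^T
    iC = np.einsum('ij,j,jk', V.T, 1./S, U.T)   # V^T diag(1/S) U^T

    # The series of dot products mentioned in the FIXME gives the same answer
    iC_dot = np.dot(V.T, np.dot(np.diag(1./S), U.T))
    assert np.allclose(iC, iC_dot)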

    def set_iC(self, d):
        """
@@ -281,9 +311,6 @@ def q_hat(self, k1, k2, use_identity=True, use_fft=True):
        q_hat : array_like
            Unnormalized bandpowers
        """
-        # FIXME: Should perform sanity checks to make sure keys exist in both
-        # datasets
-
        # Whether to use look-up fn. for identity or inverse covariance matrix
        icov_fn = self.I if use_identity else self.iC
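
The estimator body is collapsed in this view. The conventional unnormalized quadratic estimator it computes has the form q_a = 0.5 * (R x1)^dag Q_a (R x2), with R either the identity or the inverse covariance as selected above; the exact convention here is an assumption. A brute-force sketch under that assumption, taking Q_a to be the Fourier-mode outer product used in delay-spectrum estimators:

    import numpy as np

    def q_hat_slow(x1, x2, R1, R2):
        # Brute-force unnormalized bandpowers for one pair of spectra,
        # assuming q_a = 0.5 * (R1 x1)^dag Q_a (R2 x2).
        nfreq = x1.size
        z1, z2 = np.dot(R1, x1), np.dot(R2, x2)
        q = np.empty(nfreq, dtype=complex)
        for a in range(nfreq):
            e = np.exp(2j * np.pi * a * np.arange(nfreq) / nfreq)  # mode a
            Q_a = np.outer(e, e.conj())          # separable response matrix
            q[a] = 0.5 * np.dot(z1.conj(), np.dot(Q_a, z2))
        return q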

@@ -562,6 +589,10 @@ def pspec(self, keys, weights='none'):
            and baselines specified in 'keys'.
        """
        #FIXME: Define sensible grouping behaviors.
+        #FIXME: Check that requested keys exist in all datasets
+
+        # Validate the input data to make sure it's sensible
+        self.validate_datasets()
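
With validation wired in here, an inconsistent collection now fails at the top of pspec() rather than deep inside the estimator. An end-to-end sketch, where the antenna-pair key format is an assumption (the key convention is not visible in this diff):

    keys = [(24, 25), (37, 38)]           # hypothetical baseline keys
    pvs = ds.pspec(keys, weights='none')  # validate_datasets() runs first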

        pvs = []
        for k, key1 in enumerate(keys):