In [51]:
%matplotlib notebook
import numpy as np
from hera_cal import omni, utils
reload(utils)
import hera_qm.ant_metrics as ant_metrics
reload(ant_metrics)
from hera_cal.data import DATA_PATH
from hera_cal.redcal import get_pos_reds
import sys
from pyuvdata import UVData

In [52]:
def red_corr_metrics(data, pols, antpols, ants, reds, xants=[],
                     rawMetric=False, crossPol=False):
    """Calculate the modified Z-Score over all redundant groups for each antenna.

    Calculates the extent to which baselines involving an antenna do not correlate
    with others they are nominmally redundant with.
    Arguments:
    data -- data for all polarizations in a format that can support data.get_data(i,j,pol)
    pols -- List of visibility polarizations (e.g. ['xx','xy','yx','yy']).
    antpols -- List of antenna polarizations (e.g. ['x', 'y'])
    ants -- List of all antenna indices.
    reds -- List of lists of tuples of antenna numbers that make up redundant baseline groups.
    xants -- list of antennas in the (ant,antpol) format that should be ignored.
    rawMetric -- return the raw power correlations instead of the modified z-score
    crossPol -- return results only when the two visibility polarizations differ by a single flip

    Returns:
    powerRedMetric -- a dictionary indexed by (ant,antpol) of the modified z-scores of the mean
    power correlations inside redundant baseline groups that the antenna participates in.
    Very small numbers are probably bad antennas.
    """
    # Compute power correlations and assign them to each antenna
    autoPower = compute_median_auto_power_dict(data, pols, reds)
    antCorrs = {(ant, antpol): 0.0 for ant in ants for antpol in antpols if
                (ant, antpol) not in xants}
    antCounts = deepcopy(antCorrs)
    for pol0 in pols:
        for pol1 in pols:
            iscrossed_i = (pol0[0] != pol1[0])
            iscrossed_j = (pol0[1] != pol1[1])
            onlyOnePolCrossed = (iscrossed_i ^ iscrossed_j)
            # This function can instead record correlations for antennas whose counterpart are pol-swapped
            if (not crossPol and (pol0 is pol1)) or (crossPol and onlyOnePolCrossed):
                for bls in reds:
                    data_shape = data.get_data(bls[0][0], bls[0][1], pol0).shape
                    data_array_shape = (len(bls), data_shape[0], data_shape[1])
                    # correlation_array = np.zeros(corr_shape, dtype=np.complex128)
                    data_array = np.zeros(data_array_shape, np.complex128)
                    data_array1 = np.zeros(data_array_shape, np.complex128)
                    antpols1, antopols2 = [], []
                    for n, (ant0_i, ant0_j) in enumerate(bls):
                        data_array[n] = data.get_data(ant0_i, ant0_j, pol0)
                        data_array1[n] = data.get_data(ant0_i, ant0_j, pol1)
                        antpols1.append((ant0_i, pol0[0]))
                        antpols1.append((ant0_j, pol0[1]))
                        antpols2.append((ant0_i, pol1[0]))
                        antpols2.append((ant0_j, pol1[1]))
                    # Take the tensor dot over the times axis, data_arry is (nbls, ntimes, nfreqs)
                    corr_array = np.tensordot(data_array, data_array1.conj(), axes=[[0],[0]]).reshape(0,2,1,3)
                    corr_array = np.median(corr_array, axis=(2,3))
                    autos = np.sqrt(np.diagonal(corr_array, axis1=0, axis2=1).copy())
                    corr_array /= autos[:, None]
                    corr_array /= autos[None, :]
#                         for (ant1_i, ant1_j) in bls[n + 1:]:
#                             data1 = data.get_data(ant1_i, ant1_j, pol1)
#                             corr = np.median(np.abs(np.mean(data0 * data1.conj(),
#                                                             axis=0)))
#                             corr /= np.sqrt(autoPower[ant0_i, ant0_j, pol0] *
#                                             autoPower[ant1_i, ant1_j, pol1])
#                             antsInvolved = [(ant0_i, pol0[0]), (ant0_j, pol0[1]),
#                                             (ant1_i, pol1[0]), (ant1_j, pol1[1])]
#                             if not np.any([(ant, antpol) in xants for ant, antpol
#                                            in antsInvolved]):
#                                 # Only record the crossed antenna if i or j is crossed
#                                 if crossPol and iscrossed_i:
#                                     antsInvolved = [(ant0_i, pol0[0]),
#                                                     (ant1_i, pol1[0])]
#                                 elif crossPol and iscrossed_j:
#                                     antsInvolved = [(ant0_j, pol0[1]), (ant1_j, pol1[1])]
#                                 for ant, antpol in antsInvolved:
#                                     antCorrs[(ant, antpol)] += corr
#                                     antCounts[(ant, antpol)] += 1

    # Compute average and return
    for key, count in antCounts.items():
        if count > 0:
            antCorrs[key] /= count
        else:
            # Was not found in reds, should not have a valid metric.
            antCorrs[key] = np.NaN

In [53]:
verbose = True
pols = ['xx','xy','yx','yy']
JD = '2457757.47316'
dataFileList = [DATA_PATH + '/zen.2457698.40355.xx.HH.uvcA',
                DATA_PATH + '/zen.2457698.40355.yy.HH.uvcA',
                DATA_PATH + '/zen.2457698.40355.xy.HH.uvcA',
                DATA_PATH + '/zen.2457698.40355.yx.HH.uvcA']
freqs = np.arange(.1,.2,.1/1024)
sys.path.append(DATA_PATH)

uvd = UVData()
uvd.read_miriad(dataFileList[0])
aa = utils.get_aa_from_uv(uvd)
info = omni.aa_to_info(aa, pols=[pols[-1][0]], crosspols=[pols[-1]])
reds = info.get_reds()
metricsJSONFilename = JD+'.metrics.json'

In [54]:
am = ant_metrics.AntennaMetrics(dataFileList, reds, fileformat='miriad')

In [55]:
%prun red_corr = am.red_corr_metrics(rawMetric=True)

 

In [60]:
am.data.get_data(reds[0][0][0], reds[0][0][1], 'xx').shape

(1, 1024)

In [61]:
am.data.data_array.shape

(190, 1, 1024, 4)

In [None]:
pol0='xx'
pol1='xx'
for bls in [reds[0]]:
    data_shape = am.data.get_data(bls[0][0], bls[0][1], pol0).shape
    data_array_shape = (len(bls), data_shape[0], data_shape[1])
    # correlation_array = np.zeros(corr_shape, dtype=np.complex128)
    data_array = np.zeros(data_array_shape, np.complex128)
    data_array1 = np.zeros(data_array_shape, np.complex128)
    antpols1, antpols2 = [], []
    for n, (ant0_i, ant0_j) in enumerate(bls):
        data_array[n] = am.data.get_data(ant0_i, ant0_j, pol0)
        data_array1[n] = am.data.get_data(ant0_i, ant0_j, pol1)
        antpols1.append((ant0_i, pol0[0]))
        antpols1.append((ant0_j, pol0[1]))
        antpols2.append((ant0_i, pol1[0]))
        antpols2.append((ant0_j, pol1[1]))
    corr_array = np.tensordot(data_array, data_array1.conj(), axes=[[1],[1]]).transpose([0,2,1,3])
    corr_array = np.median(corr_array, axis=(2,3))
    autos = np.sqrt(np.diagonal(corr_array, axis1=0, axis2=1).copy())
    corr_array /= autos[:, None]
    corr_array /= autos[None, :]

divide by zero encountered in divide
invalid value encountered in divide
divide by zero encountered in divide
invalid value encountered in divide


In [89]:
c

(9, 1, 1024)