# Goal: Calibrate a File as Fast as Possible
The goal of this notebook is to demonstrate techniques for quickly getting as much diagnostic information about a file as possible. This involves running a full suite of calibration, using shortcuts where possible to converge quickly to absolute calibration.

Current calibration shortcuts include:
- using autocorrelations for first-pass relative gain calibration
- using an empirically determined scaling applied to autocorrelations to approximate absolute gain calibration
- using empirically determined RFI station headings to approximate absolute phase calibration
- using DPSS filters to derive RFI flagging from autocorrelations

Important speed-ups come from:
- general I/O speed-ups from hera_cal.io.read_hera_hdf5 (~20x)
- speed-ups to DPSS fitting from hera_filters (~100x)
- speed-up of firstcal by using RFI channels and fewer channels (~10x)
- obtaining sufficient firstcal accuracy to skip logcal (saves 15s)
- capping Omnical iterations at 100 (~20x) relative to 10,000 used in current pipeline. Obtains equivalent $\chi^2$.

With plotting turned off, the notebook currently runs in ~60 s.

In [1]:
# Master switch for every diagnostic plot in this notebook; all plotting
# cells are guarded by `if PLOT:`.
PLOT = False
if PLOT:
    # IPython magic: only valid when run inside a Jupyter/IPython session.
    %matplotlib notebook
    #%matplotlib inline

import glob
# Alternative input files, kept for reference:
#filenames = glob.glob('/lustre/aoc/projects/hera/aparsons/2459114/zen.2459114.6*.sum.uvh5')[100:101]
#filenames = glob.glob('/lustre/aoc/projects/hera/aparsons/2459639/zen.2459639.49617.sum.uvh5')[:1]
# glob returns a (possibly empty) list; [:1] keeps at most the first match.
filenames = glob.glob('/lustre/aoc/projects/hera/aparsons/2459743/zen.2459743.55*.sum.uvh5')[:1]

In [2]:
import numpy as np
import hera_cal
from hera_cal.utils import split_bl, join_bl
import uvtools
import matplotlib.pyplot as plt
import time
import hera_filters
import linsolve
from copy import deepcopy
_ = np.seterr(all='ignore')  # get rid of red warnings

In [3]:
class Timer:
    '''Keep track of run-time through various stages and print nicely
    formatted deltas.

    Every call to clock() appends a (name, timestamp) pair to self.order.
    str(timer) reports the elapsed time between the first and the most
    recent stage and, once more than two stages have been recorded, also
    the time between the two most recent stages.
    '''

    def __init__(self):
        # Per-instance list. A class-level list would be shared by every
        # Timer, so a second Timer() would inherit the first one's stages.
        self.order = []

    def clock(self, name):
        '''Record the current time (time.time()) under the label `name`.'''
        self.order.append((name, time.time()))

    def __str__(self):
        if not self.order:
            # Nothing clocked yet: avoid an IndexError on self.order[-1].
            return 'no stages clocked'
        first_name, first_t = self.order[0]
        last_name, last_t = self.order[-1]
        t_full = '%5.2f s' % (last_t - first_t)
        s = f'{first_name}->{last_name}: {t_full}'
        if len(self.order) <= 2:
            # With one or two stages the "last step" delta adds nothing new.
            return s
        prev_name, prev_t = self.order[-2]
        t_last = '%5.2f s' % (last_t - prev_t)
        return s + f', {prev_name}->{last_name}: {t_last}'

In [4]:
# Start the notebook-wide stopwatch; later cells add further named stages
# with timer.clock(...) and print the deltas.
timer = Timer()
timer.clock('start')

In [5]:
# Pick an input file and get header information
print('FILE:', filenames)
hc = hera_cal.io.HERADataFastReader(filenames)
# Header-only pass: data, flags and nsamples are all switched off here.
_ = hc.read(read_data=False, read_flags=False, read_nsamples=False)

print('NANTS:', len(hc.data_ants))
print('NFREQS:', len(hc.freqs), (hc.freqs[0], hc.freqs[-1]))
print('NTIMES:', len(hc.times), (hc.times[0], hc.times[-1]))
print('LSTS:', (hc.lsts[0], hc.lsts[-1]))
print('NPOLS:', len(hc.pols), hc.pols)

# Median spacing of hc.times scaled by 24 * 3600, i.e. days -> seconds
# (assumes hc.times are Julian dates -- TODO confirm).
inttime = 24 * 3600 * np.median(np.diff(hc.times))  # XXX get directly
# Median spacing of hc.freqs (channel width).
chan_res = np.median(np.diff(hc.freqs))  # XXX get directly
intcnt = int(inttime * chan_res)  # number of samples per integration in correlator

FILE: ['/lustre/aoc/projects/hera/aparsons/2459743/zen.2459743.55144.sum.uvh5']
NANTS: 150
NFREQS: 1536 (46920776.3671875, 234298706.0546875)
NTIMES: 2 (2459743.551380498, 2459743.551492346)
LSTS: (5.2580245360019715, 5.258729222607931)
NPOLS: 4 ['nn', 'ee', 'ne', 'en']


# Check Autocorrelation Levels

In [6]:
# Read autocorrelations from the file, indexing by (antenna, pol) rather than baseline
auto_pols = ['ee', 'nn']
# One (i, i, pol) key per antenna in the data, for each file polarization
# that is listed in auto_pols.
auto_bls = [(i, i, pol) for i in hc.data_ants for pol in hc.pols if pol in auto_pols]
# Only element 0 of what read() returns is kept; it is used below as a
# mapping from baseline key to array.
autos = hc.read(bls=auto_bls, read_data=True, read_flags=False, read_nsamples=False)[0]
autos = {split_bl(k)[0]: v for k, v in autos.items()}  # index by ant, not bl
# Keep positions only for antennas that actually appear in the data.
antpos = {k: pos for k, pos in hc.antpos.items() if k in hc.data_ants}

## Sort antennas based on autocorrelation spectra

Cuts are made on:
- absolute power range, dividing out by correlator accumulation to get the 4b real/4b imag RMS levels. Should nominally be ~10 for well-trimmed RF inputs
- spectral slope across band, computed by medians on either side of the center frequency. Deviations from flatness are signs of antennas not seeing sky emission
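- RFI occupancy. Each channel is compared with the mean of its two neighbors (a (-0.5, 1, -0.5) convolving kernel); channels whose absolute deviation exceeds the specified fraction of the channel's own level are counted as RFI, and the cut is made on the fraction of channels counted this way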

In [7]:
def within(val, bounds):
    '''Return whether val lies in the closed interval [bounds[0], bounds[1]].'''
    return bounds[0] <= val and val <= bounds[1]

class Bounds:
    '''Bin keys into good/suspect/bad sets by testing a value against two ranges.

    A value outside `absolute` is bad; inside `absolute` but outside `good`
    is suspect; inside both is good.
    '''

    def __init__(self, absolute, good):
        self.abs_bound, self.good_bound = absolute, good
        self.clear()

    def classify(self, k, val):
        '''Add key k to the good, suspect or bad set according to val.'''
        if within(val, self.abs_bound):
            if within(val, self.good_bound):
                target = self.good
            else:
                target = self.suspect
        else:
            target = self.bad
        target.add(k)

    def clear(self):
        '''Reset the good/suspect/bad sets to empty.'''
        self.bad, self.suspect, self.good = set(), set(), set()

def _antenna_str(ants):
    '''Render (ant, pol) keys as a sorted, comma-separated string.

    Each key becomes the antenna number followed by the last character of
    its polarization string.
    '''
    labels = []
    for key in sorted(ants):
        number, pol = key[0], key[1]
        labels.append('%d%s' % (number, pol[-1]))
    return ','.join(labels)

class AntennaClassification:
    '''Ingests Bounds to create sets of good/suspect/bad antennas.

    An antenna ends up in the worst category any of the ingested bounds
    assigned to it: bad beats suspect, and suspect beats good.
    '''

    def __init__(self, *bounds_list):
        self.clear()
        for bound in bounds_list:
            self.add_bounds(bound)

    def clear(self):
        '''Reset the good/suspect/bad sets to empty.'''
        self.bad, self.suspect, self.good = set(), set(), set()

    def add_bounds(self, bound):
        '''Merge the sets of one Bounds object, then demote any antenna
        that also appears in a worse category.'''
        self.bad.update(bound.bad)
        self.suspect.update(bound.suspect)
        self.suspect.difference_update(self.bad)  # bad overrides suspect
        self.good.update(bound.good)
        self.good.difference_update(self.bad, self.suspect)  # bad/suspect override good

    def __str__(self):
        sections = [
            f'Good ({len(self.good)}): {_antenna_str(self.good)}',
            f'Suspect ({len(self.suspect)}): {_antenna_str(self.suspect)}',
            f'Bad ({len(self.bad)}): {_antenna_str(self.bad)}',
        ]
        return '\n\n'.join(sections)

    def is_good(self, k):
        '''Return True if k is currently classified as good.'''
        return k in self.good

    def is_bad(self, k):
        '''Return True if k is currently classified as bad.'''
        return k in self.bad

In [8]:
# First-pass antenna classification based on auto levels

CEN_FREQ = 136e6  # Hz
RFI_THRESH = 1e-2  # fraction of mean

# Bounds(absolute, good): a value outside `absolute` is bad, inside
# `absolute` but outside `good` is suspect, otherwise good.
pwr_bound = Bounds(absolute=(1, 80), good=(5, 30))
slope_bound = Bounds(absolute=(-0.2, 0.2), good=(-0.12, 0.12))
rfi_bound = Bounds(absolute=(0, 0.15), good=(0, 0.1))

for k, v in autos.items():
    # Spectrum averaged over axis 0 and divided by the per-integration
    # sample count.
    mean = np.mean(v, axis=0) / intcnt
    # Median level above / at-or-below the band center.
    hi_pwr = np.median(mean[hc.freqs > CEN_FREQ])
    lo_pwr = np.median(mean[hc.freqs <= CEN_FREQ])
    pwr = 0.5 * (hi_pwr + lo_pwr)
    # Fractional high-minus-low difference; left at 0 when pwr is zero so
    # that no division by zero occurs.
    slope = 0
    if pwr != 0:
        slope = (hi_pwr - lo_pwr) / pwr 
    # Absolute deviation of each interior channel from the mean of its two
    # neighbors, relative to the channel's own level.
    rfi = np.abs(mean[1:-1] - 0.5 * (mean[:-2] + mean[2:])) / mean[1:-1]
    # Fraction of interior channels whose deviation exceeds RFI_THRESH.
    rfi_frac = np.mean(np.where(rfi > RFI_THRESH, 1, 0))
    #print(k, pwr, slope, rfi_frac)
    pwr_bound.classify(k, pwr)
    slope_bound.classify(k, slope)
    rfi_bound.classify(k, rfi_frac)

# Combine the three cuts; an antenna takes the worst category it received.
ant_class = AntennaClassification(pwr_bound, slope_bound, rfi_bound)
print(ant_class)

Good (24): 10e,20e,40e,40n,55e,56n,57e,70n,73n,99e,99n,142e,144e,144n,145e,161e,162e,178e,181e,181n,183e,183n,191n,329e

Suspect (94): 7e,7n,8e,8n,9e,9n,10n,19e,19n,20n,21e,21n,31e,31n,33e,33n,41e,41n,42e,42n,45e,46e,46n,54e,54n,55n,56e,57n,69e,69n,70e,71e,71n,72e,72n,73e,81e,82e,82n,119n,135e,135n,136e,136n,138e,138n,140e,140n,141e,143e,143n,145n,160e,162n,163e,163n,164e,164n,165e,166e,168e,168n,169e,169n,170e,170n,176e,176n,177e,177n,178n,179e,179n,182e,182n,184e,184n,185e,185n,186e,186n,187e,187n,189e,189n,191e,321e,321n,323n,324e,324n,329n,333e,333n

Bad (182): 0e,0n,1e,1n,2e,2n,3e,3n,4e,4n,5e,5n,11e,11n,12e,12n,13e,13n,14e,14n,15e,15n,16e,16n,17e,17n,18e,18n,23e,23n,24e,24n,25e,25n,26e,26n,27e,27n,28e,28n,29e,29n,30e,30n,32e,32n,36e,36n,37e,37n,38e,38n,39e,39n,45n,50e,50n,51e,51n,52e,52n,53e,53n,65e,65n,66e,66n,67e,67n,68e,68n,81n,83e,83n,84e,84n,85e,85n,86e,86n,87e,87n,88e,88n,89e,89n,90e,90n,91e,91n,92e,92n,93e,93n,94e,94n,98e,98n,100e,100n,101e,101n,102e,102n,103e,103n,104e,104

In [9]:
# Plot shape of autos versus frequency
if PLOT:
    # One panel per polarization; autos keys are (ant, pol), so k[-1]
    # selects the panel.
    plot_pols = ['Jee', 'Jnn']
    fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=len(plot_pols), sharex=True)

    for k, v in autos.items():
        if ant_class.is_bad(k):
            continue
        ax = axes[plot_pols.index(k[-1])]
        # Spectrum averaged over axis 0, normalized by the sample count.
        ax.semilogy(hc.freqs / 1e6, np.mean(v / intcnt, axis=0), label=str(k[0]))

    for cnt, pol in enumerate(plot_pols):
        ax = axes[plot_pols.index(pol)]
        ax.set_title(f'Polarization: {pol}')
        ax.set_ylabel('Power')
        ax.grid()
        #ax.legend(ncol=3)
    _ = axes[-1].set_xlabel('Frequency [MHz]')    

# Determine RFI Flagging

In [10]:
# First-pass min RFI flagging done with channel differencing
#
# For every non-bad auto, each channel is compared with the weighted mean of
# its currently unflagged neighbors, using windows of half-width 1, 2 and 4
# channels in turn. A channel that exceeds that mean by more than SIG_THRESH
# times the expected noise gets weight 0. The per-antenna weights are then
# averaged into an array-wide flag fraction and thresholded.

SIG_THRESH = 10
ARRAY_FLAG_THRESH = 0.05

rfi1_wgts = {}

for k, v in autos.items():
    if ant_class.is_bad(k):
        continue
    sig = v / np.sqrt(intcnt / 2)  # factor of 2 for autos
    w = np.ones(v.shape)

    # A priori RFI flags
    w[:, np.logical_and(222e6 < hc.freqs, hc.freqs < 224e6)] = 0
    #w[:, np.logical_and(88e6 < hc.freqs, hc.freqs < 108e6)] = 0  # manually flag FM

    # First pass: difference with average on either side and flag positive outliers
    for width in (1, 2, 4):
        ker = np.ones(2 * width + 1)
        # Noise of (channel - mean of 2*width neighbors) in units of sig.
        # It depends only on width, so it is computed once per window size.
        ker_std = np.sqrt(1**2 + 1 / (2 * width))
        wv = v * w
        f_res = np.zeros_like(v)
        for t in range(v.shape[0]):
            # Windowed sums with the central channel removed: d1 is the sum
            # of weighted neighbor data, w1 the sum of neighbor weights.
            d1 = np.convolve(wv[t], ker, mode='valid') - wv[t, width:-width]
            w1 = np.convolve( w[t], ker, mode='valid') -  w[t, width:-width]
            # clip(1, inf) avoids dividing by zero when every neighbor is
            # flagged. np.inf is used because the np.Inf alias was removed
            # in NumPy 2.0.
            f_res[t, width:-width] = wv[t, width:-width] - d1 / w1.clip(1, np.inf)
            w[t, width:-width] = np.where(f_res[t, width:-width] > sig[t, width:-width] * ker_std * SIG_THRESH, 0, w[t, width:-width])

    rfi1_wgts[k] = w

# Fraction of non-bad antennas flagging each (time, channel) sample.
flags1 = sum([1 - v for k, v in rfi1_wgts.items()]) / len(rfi1_wgts)
# Array-wide weights: 0 where more than ARRAY_FLAG_THRESH of antennas flag.
data1_wgts = np.where(flags1 > ARRAY_FLAG_THRESH, 0, 1)

In [11]:
# # Plot shape of rfi_wgt autos versus frequency
# if PLOT:
#     plot_pols = ['Jee', 'Jnn']
#     fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=len(plot_pols), sharex=True)
#     mask = np.where(data1_wgts[0])  # for an array-averaged mask
    
#     for k, v in autos.items():
#         if ant_class.is_bad(k):
#             continue
#         ax = axes[plot_pols.index(k[-1])]
#         #mask = np.where(rfi1_wgts[k][0])  # for the individual mask of this antenna
#         ax.semilogy(hc.freqs[mask] / 1e6, v[0][mask] / np.median(v[0][mask]), label=str(k[0]))

#     for cnt, pol in enumerate(plot_pols):
#         ax = axes[plot_pols.index(pol)]
#         ax.set_title(f'Polarization: {pol}')
#         ax.set_ylabel('Power')
#         ax.grid()
#     _ = axes[-1].set_xlabel('Frequency [MHz]')    

In [12]:
%%time
# Second-pass RFI flagging done with DPSS filters

def dpss_filter(y, amat, fmat):
    '''Apply the provided DPSS filter matrices to data.

    For each row i of y, computes amat @ (fmat[i] @ y[i]) and stacks the
    results; only the real part of the stacked model is returned.
    '''
    rows = [amat @ (fmat[i] @ row) for i, row in enumerate(y)]
    return np.array(rows).real

# NOTE: this rebinds the SIG_THRESH used by the first-pass flagging cell, so
# re-running that cell afterwards would pick up this value instead.
SIG_THRESH = 4
#FILTER_WIDTH = 60e-9
#FILTER_WIDTH = 250e-9
FILTER_WIDTH = 200e-9
# Three filter windows of equal half-width: one at zero and a symmetric pair
# at +/-2700e-9 (presumably delays in seconds -- TODO confirm).
CENTERS = [0, 2700e-9, -2700e-9]

filter_kwargs = {'filter_centers': CENTERS,
                 'filter_half_widths': [FILTER_WIDTH] * len(CENTERS),
                 'eigenval_cutoff': [1e-9] * len(CENTERS)} 

amat, _ = hera_filters.dspec.dpss_operator(hc.freqs, **filter_kwargs)
# One solution matrix per row of data1_wgts, built from a dense diagonal
# weight matrix each time.
# XXX this doesn't scale well for many times
fmat = np.array([hera_filters.dspec.fit_solution_matrix(np.diag(w), amat) for w in data1_wgts])

rfi2_wgts = {}

for k, v in autos.items():
    if ant_class.is_bad(k):
        continue
    # Smooth model fitted to the auto with first-pass flagged samples zeroed.
    mdl = dpss_filter(v * data1_wgts, amat, fmat)
    #fmat = np.array([hera_filters.dspec.fit_solution_matrix(np.diag(w), amat) for w in rfi_wgts[k]])
    #mdl = dpss_filter(v * rfi_wgts[k], amat, fmat)
    # Noise estimate derived from the model rather than from the data.
    sig = mdl / np.sqrt(intcnt / 2)
    # Weight 0 where the data exceed the model by more than SIG_THRESH sigma
    # (positive excursions only).
    rfi2_wgts[k] = np.where(v - mdl > sig * SIG_THRESH, 0, 1)

# Fraction of non-bad antennas flagging each sample.
flags2 = sum([1 - v for k, v in rfi2_wgts.items()]) / len(rfi2_wgts)

# Array-wide RFI weights
data2_wgts = np.where(flags2 > ARRAY_FLAG_THRESH, 0, 1)

CPU times: user 2.37 s, sys: 229 ms, total: 2.59 s
Wall time: 2.6 s


In [13]:
# Plot flagging fraction versus frequency
if PLOT:
    t = 0  # index along axis 0 of the flag arrays to display
    plt.figure()
    # Per-antenna flag fractions and the resulting array-wide flags, for
    # both the first (channel-differencing) and second (DPSS) pass.
    plt.bar(hc.freqs / 1e6, flags1[t])
    plt.bar(hc.freqs / 1e6, 1-data1_wgts[t], alpha=0.5)
    plt.bar(hc.freqs / 1e6, flags2[t], alpha=0.5)
    plt.bar(hc.freqs / 1e6, 1-data2_wgts[t], alpha=0.5)
    # Dotted line marks the array-wide flagging threshold.
    plt.plot(hc.freqs / 1e6, np.ones_like(hc.freqs) * ARRAY_FLAG_THRESH, 'k:')
    plt.yscale('log')
    plt.xlabel('Frequency [MHz]')
    plt.ylabel('Flag Fraction')
    plt.title('RFI Pass 1')

In [14]:
# Plot shape of rfi_wgt autos versus frequency
if PLOT:
    plot_pols = ['Jee', 'Jnn']
    fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=len(plot_pols), sharex=True)
    #mask = np.where(data2_wgts[0])  # for an array-averaged mask
    
    for k, v in autos.items():
        if ant_class.is_bad(k):
            continue
        ax = axes[plot_pols.index(k[-1])]
        mask = np.where(rfi2_wgts[k][0])  # for the individual mask of this antenna
        # Unflagged channels of row 0 only, normalized by their median.
        ax.semilogy(hc.freqs[mask] / 1e6, v[0][mask] / np.median(v[0][mask]), label=str(k[0]))

    for cnt, pol in enumerate(plot_pols):
        ax = axes[plot_pols.index(pol)]
        ax.set_title(f'Polarization: {pol}')
        ax.set_ylabel('Power')
        ax.grid()
    _ = axes[-1].set_xlabel('Frequency [MHz]')    

In [15]:
# # Plot RFI flags versus frequency
# if PLOT:
#     plt.figure(figsize=(8,4))
#     plt.plot(hc.freqs / 1e6, flags[0], 'c.')
#     plt.plot(hc.freqs / 1e6, flags[1], 'm.')
#     plt.fill_between(hc.freqs / 1e6, 1-data2_wgts[0], color='c', alpha=0.5)
#     plt.fill_between(hc.freqs / 1e6, 1-data2_wgts[1], color='m', alpha=0.5)

#     plt.plot(hc.freqs / 1e6, np.ones(hc.freqs.size) * ARRAY_FLAG_THRESH, 'k:')
#     plt.grid()
#     plt.ylabel('Antenna Fraction')
#     plt.title('RFI Flags')
#     _ = plt.xlabel('Frequency [MHz]')

In [16]:
# Second-pass antenna classification based on RFI
#
# Each non-bad auto is re-fitted with a smooth DPSS model using the
# array-wide second-pass weights. The fraction of otherwise-unflagged samples
# in which an antenna sits more than RFI2_SIG_THRESH sigma above its model is
# turned into a robust z-score across antennas, and that z-score is what is
# classified.

RFI2_SIG_THRESH = 3  # sigma threshold for flagging RFI
# One-sided cut: only an unusually high flag fraction is penalized.
# np.inf is used because the np.Inf alias was removed in NumPy 2.0.
rfi2_bound = Bounds(absolute=(-np.inf, 5), good=(-np.inf, 1))

# Rebuild from the first-pass cuts only, so re-running this cell does not
# compound an earlier rfi2_bound result.
ant_class = AntennaClassification(pwr_bound, slope_bound, rfi_bound)

amat, _ = hera_filters.dspec.dpss_operator(hc.freqs, **filter_kwargs)
fmat = np.array([hera_filters.dspec.fit_solution_matrix(np.diag(w), amat) for w in data2_wgts])

smooth_mdl = {}
rfi_wgts = {}

for k, v in autos.items():
    if ant_class.is_bad(k):
        continue
    smooth_mdl[k] = dpss_filter(v * data2_wgts, amat, fmat)
    sig = smooth_mdl[k] / np.sqrt(intcnt / 2)
    rfi_wgts[k] = np.where(v - smooth_mdl[k] > sig * RFI2_SIG_THRESH, 0, 1)

# Per-antenna flag fraction, counted only over samples that are not already
# flagged array-wide.
flag_frac = {k: np.sum((1 - v) * data2_wgts) / np.sum(data2_wgts) for k, v in rfi_wgts.items()}
ff = np.array(list(flag_frac.values()))
ff_median = np.median(ff)
# Median absolute deviation scaled by 1 / 0.675 as a robust sigma estimate.
ff_std = np.median(np.abs(ff - ff_median)) / 0.675
#print(ff_median, ff_std)
    
for k, v in autos.items():
    if ant_class.is_bad(k):
        # Carry earlier bad verdicts into this bound without a z-score.
        rfi2_bound.bad.add(k)
    else:
        zscore = (flag_frac[k] - ff_median) / ff_std
        #print(k, zscore, flag_frac[k])
        rfi2_bound.classify(k, zscore)

ant_class = AntennaClassification(pwr_bound, slope_bound, rfi_bound, rfi2_bound)
print(ant_class)

Good (13): 20e,40n,55e,56n,73n,144e,144n,161e,178e,181e,181n,183e,191n

Suspect (77): 7n,8n,9e,9n,10n,19e,19n,20n,21e,21n,31e,31n,40e,41e,41n,42n,45e,46e,54e,54n,55n,56e,57n,69e,69n,70e,70n,71e,71n,72e,72n,73e,82e,82n,119n,135e,135n,136e,138e,138n,140e,140n,141e,143e,143n,145n,162e,162n,163e,163n,164e,164n,165e,168e,168n,169n,170n,177e,177n,178n,179e,179n,182e,183n,184e,184n,185e,185n,186e,186n,187n,189e,189n,191e,321n,324n,333n

Bad (210): 0e,0n,1e,1n,2e,2n,3e,3n,4e,4n,5e,5n,7e,8e,10e,11e,11n,12e,12n,13e,13n,14e,14n,15e,15n,16e,16n,17e,17n,18e,18n,23e,23n,24e,24n,25e,25n,26e,26n,27e,27n,28e,28n,29e,29n,30e,30n,32e,32n,33e,33n,36e,36n,37e,37n,38e,38n,39e,39n,42e,45n,46n,50e,50n,51e,51n,52e,52n,53e,53n,57e,65e,65n,66e,66n,67e,67n,68e,68n,81e,81n,83e,83n,84e,84n,85e,85n,86e,86n,87e,87n,88e,88n,89e,89n,90e,90n,91e,91n,92e,92n,93e,93n,94e,94n,98e,98n,99e,99n,100e,100n,101e,101n,102e,102n,103e,103n,104e,104n,105e,105n,106e,106n,107e,107n,108e,108n,109e,109n,110e,110n,111e,111n,112e,112n,116

In [17]:
# Plot RMS residuals of good high-passed autos versus frequency
if PLOT:
    plot_pols = ['Jee', 'Jnn']
    fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=len(plot_pols), sharex=True)

    for k, v in autos.items():
        # Only antennas classified as bad are skipped, so suspect antennas
        # are plotted too.
        if ant_class.is_bad(k):
            continue
        ax = axes[plot_pols.index(k[-1])]
        mdl = smooth_mdl[k][0]
        sig = mdl / np.sqrt(intcnt / 2)
        # Data-minus-model for row 0, in units of the model-derived noise.
        residual = (v[0] - mdl) / sig
        mask = np.where(data2_wgts[0])
        ax.plot(hc.freqs[mask] / 1e6, residual[mask], 'k', label=str(k[0]), alpha=0.1)

    for cnt, pol in enumerate(plot_pols):
        ax = axes[plot_pols.index(pol)]
        ax.set_title(f'Good Antennas, Polarization: {pol}')
        ax.set_ylabel('Z Score')
        ax.set_ylim(-5, 5)
    _ = axes[-1].set_xlabel('Frequency [MHz]')

In [18]:
# Plot antenna positions with good and bad antennas highlighted
if PLOT:
    fig, axes = plt.subplots(figsize=(7,14), ncols=1, nrows=2)
    for cnt, pol in enumerate(('Jee', 'Jnn')):
        plt.sca(axes[cnt])
        # Bad antennas are passed as ex_ants and good ones as hl_ants;
        # suspect antennas are in neither list.
        ex_ants = [k[0] for k in ant_class.bad if k[-1] == pol]
        hl_ants = [k[0] for k in ant_class.good if k[-1] == pol]
        uvtools.plot.plot_antpos(antpos, ex_ants=ex_ants, hl_ants=hl_ants)
        plt.title(f'Polarization: {pol}')
    # Fixed axis limits applied to both panels.
    for ax in axes:
        ax.set_xlim(-200, 100)
        ax.set_ylim(-200, 100)

# Estimate Absolute Amplitude from Autocorrelations

In [19]:
# First estimate of antenna gains from autos

# polynomial fit to x=log10(freq) y=log10(abscal_gain / auto) for H4C
abscal_loglog_poly = np.array([159.60511509346617, -6411.680706063783, 102993.61331972879, -826937.7537351248, 3318643.7541476665, -5325564.542530925])
auto_scalar = 10**np.polyval(abscal_loglog_poly, np.log10(hc.freqs))
# In-paint: keep the measured auto where the array-wide weights are 1 and
# substitute the smooth model where they are 0. The previous version
# iterated over smooth_mdl.items() and blended the model with itself, so the
# result was identically the smooth model and the data were never used.
inpainted_autos = {k: data2_wgts * autos[k] + (1 - data2_wgts) * mdl for k, mdl in smooth_mdl.items()}
# Per-sample noise estimate derived from the smooth model.
noise_mdl = {k: v / np.sqrt(intcnt / 2) for k, v in smooth_mdl.items()}
#mean_pwr = np.mean([v for k, v in inpainted_autos.items() if k in good_ants], axis=0)
#auto_gains = {k: np.sqrt(auto_scalar * v / mean_pwr) for k, v in inpainted_autos.items()}
# Gain amplitude estimate: square root of the empirically scaled auto.
auto_gains = {k: np.sqrt(auto_scalar * v) for k, v in inpainted_autos.items()}

In [20]:
# Plot shape of in-painted autos versus frequency
if PLOT:
    plot_pols = ['Jee', 'Jnn']
    fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=len(plot_pols), sharex=True)

    for k, v in inpainted_autos.items():
        if ant_class.is_bad(k):
            continue
        ax = axes[plot_pols.index(k[-1])]
        # Spectrum averaged over axis 0, normalized by the overall mean so
        # that only the spectral shape is compared between antennas.
        ax.plot(hc.freqs / 1e6, np.mean(v, axis=0) / np.mean(v), label=str(k[0]))

    for cnt, pol in enumerate(plot_pols):
        ax = axes[plot_pols.index(pol)]
        ax.set_title(f'Polarization: {pol}')
        ax.set_ylabel('Calibrated Power')
        ax.grid()
    _ = axes[-1].set_xlabel('Frequency [MHz]')

In [21]:
# Mark the end of the autocorrelation-based flagging/classification stage.
timer.clock('auto_flags')
print(timer)

start->auto_flags:  8.29 s


# Firstcal Delays from Stable RFI Transmitters

In [22]:
# Read visibility data with no baseline selection, skipping flags and
# nsamples; only element 0 of what read() returns is kept.
data = hc.read(read_data=True, read_flags=False, read_nsamples=False)[0]

In [23]:
# Dict of FM Radio Headings, ch: (fq, ang, chisq)
# Keys are channel indices into hc.freqs; `ang` is passed straight to
# np.cos / np.sin below, so it is in radians. Commented-out rows are
# stations excluded from the fit.
phs_sol = {
     #359: ( 90744018.5546875, 0.785398, 23.39),
     360: ( 90866088.8671875, 0.785398, 10.85),
     369: ( 91964721.6796875, 0.106814, 34.81),
     386: ( 94039916.9921875, 0.785398, 18.12),
     #391: ( 94650268.5546875, 3.581415, 47.47),
     392: ( 94772338.8671875, 3.587698, 40.78),
     #399: ( 95626831.0546875, 6.063273, 36.57),
     400: ( 95748901.3671875, 6.063273, 24.07),
     441: (100753784.1796875, 0.785398, 21.72),
     447: (101486206.0546875, 3.587698, 43.82),
     #455: (102462768.5546875, 6.063273, 18.87),
     456: (102584838.8671875, 6.063273, 8.811),
     471: (104415893.5546875, 0.785398, 13.39),
     477: (105148315.4296875, 3.587698, 19.82),
     485: (106124877.9296875, 6.063273, 4.041),
    1182: (191207885.7421875, 0.785398, 27.06),
#    1444: (223190307.6171875, 1.426283, 54.68),
#    1445: (223312377.9296875, 2.607521, 52.55),
#    1494: (229293823.2421875, 5.560618, 51.34),
}

# Channel indices in ascending order; every per-station array below follows
# this order.
chs = np.array(sorted(list(phs_sol.keys())))
ch_wgts = np.where(hc.freqs[chs] > 150e6, 10, 1)  # upwgt high-band station to offset FM overrepresentation
sum_ch_wgts = np.sum(ch_wgts)
# Wavelength of each station, taken from the file's own channel frequency
# (not the fq stored in phs_sol) and an approximate speed of light of 3e8.
lams = 3e8 / hc.freqs[chs]
_angs = np.array([phs_sol[ch][1] for ch in chs])
# Unit heading vectors (cos, sin, 0), stacked as a (3, n_stations) array.
rfi_headings = np.array([np.cos(_angs), np.sin(_angs), np.zeros_like(_angs)])

In [24]:
# Build redundancy lists from antenna position and filter out bad antennas
reds = hera_cal.redcal.get_reds(antpos, pols=['ee','nn'], pol_mode='2pol')
# ex_ants receives the (ant, pol) keys currently classified as bad.
freds = hera_cal.redcal.filter_reds(reds, ex_ants=ant_class.bad)

In [25]:
# Attempt to geometrically phase baselines to RFI channels, and toss out
# baselines that don't phase (a sign of broken cross-correlation)

ant_class = AntennaClassification(pwr_bound, slope_bound, rfi_bound, rfi2_bound)

# One-sided cuts on chi-square (per baseline) and on the fraction of badly
# phased baselines (per antenna).
# np.inf is used because the np.Inf alias was removed in NumPy 2.0.
_phs_bl_bound = Bounds(absolute=(-np.inf, 0.4), good=(-np.inf, 0.15))
phs_bound = Bounds(absolute=(-np.inf, 0.4), good=(-np.inf, 0.15))

# Because freq sampling is sparse, a brute-force search for best delay is
# both faster and more robust
DLY0_RNG = 150  # maximum delay to try, in ns
# 4 * DLY0_RNG + 1 trial delays across +/-DLY0_RNG ns, as a column vector.
dlys_try = np.linspace(-DLY0_RNG, DLY0_RNG, 4 * DLY0_RNG + 1) * 1e-9
dlys_try.shape = (-1, 1)
# RFI channel frequencies as a row vector, so fqs * dlys_try broadcasts to
# shape (n_delays, n_channels).
fqs = hc.freqs[chs]
fqs.shape = (1, -1)

phasor = np.exp(-2j * np.pi * fqs * dlys_try)  # brute force RFI phasors by delay

ant_cnt = {}  # counts how many times an antenna appears in baselines
bl_dly = {}  # stores best-fit delay for each baseline

for grp in freds:
    bl = grp[0]
    # generate predicted geometric phases for each RFI station
    # (computed once per redundant group from its first baseline)
    bl_xyz = antpos[bl[1]] - antpos[bl[0]]
    rfi_phs = np.exp(-2j * np.pi * np.dot(bl_xyz, rfi_headings) / lams)

    for bl in grp:
        a_i, a_j = split_bl(bl)
        if ant_class.is_bad(a_i) or ant_class.is_bad(a_j):
            continue
        # Sum over axis 0 at the RFI channels, then apply the geometric phasor.
        d_phs = np.sum(data[bl][:,chs], axis=0) * rfi_phs
        ant_cnt[a_i] = ant_cnt.get(a_i, 0) + 1
        ant_cnt[a_j] = ant_cnt.get(a_j, 0) + 1
        # Normalize the amplitude; clip(1, inf) leaves values with
        # amplitude below 1 unscaled instead of dividing by a tiny number.
        d_phs /= np.abs(d_phs).clip(1, np.inf)
        d_phs = d_phs * phasor
        _chi    = np.sum(ch_wgts * np.abs(d_phs - 1)**2, axis=1) / sum_ch_wgts
        _chi180 = np.sum(ch_wgts * np.abs(d_phs + 1)**2, axis=1) / sum_ch_wgts # 180-deg phase offset
        i = np.argmin(_chi)
        j = np.argmin(_chi180)
        #print(bl, _chi[i], _chi180[j])
        if _chi180[j] < 0.5 * _chi[i]:
            # if dipoles are swapped, 180-deg phasor is best fit
            i, _chi = j, _chi180
        _phs_bl_bound.classify(bl, _chi[i])  # classify baselines based on best chisq
        bl_dly[bl] = dlys_try[i, 0]

# Two passes: antennas found bad in the first pass are excluded when the
# bad-baseline counts are recomputed in the second.
for _cnt in range(2):
    # Calculate fraction with bad chisq on phasing
    phs_bound.clear()
    phs_bound.bad = ant_class.bad.copy()
    bad_cnt = {}
    for bl in _phs_bl_bound.bad:
        a_i, a_j = split_bl(bl)
        if ant_class.is_bad(a_i) or ant_class.is_bad(a_j):
            continue
        bad_cnt[a_i] = bad_cnt.get(a_i, 0) + 1
        bad_cnt[a_j] = bad_cnt.get(a_j, 0) + 1

    # Classify antennas based on fraction of baselines that have bad phasing
    for ant, cnt in bad_cnt.items():
        bad_frac = cnt / ant_cnt[ant]
        #print(ant, bad_frac)
        phs_bound.classify(ant, bad_frac)

    ant_class = AntennaClassification(pwr_bound, slope_bound, rfi_bound, rfi2_bound, phs_bound)
print(ant_class)

Good (12): 20e,40n,56n,73n,144e,144n,161e,178e,181e,181n,183e,191n

Suspect (69): 7n,8n,9e,9n,10n,19e,19n,20n,21e,21n,31e,31n,40e,41e,41n,42n,45e,46e,54e,54n,56e,57n,69e,69n,72e,72n,73e,82e,82n,119n,135e,135n,136e,138e,138n,140e,140n,141e,143e,143n,145n,162e,162n,163e,163n,164e,164n,165e,168e,168n,169n,170n,177e,177n,178n,179e,179n,182e,183n,184e,184n,185e,185n,186e,186n,187n,189e,189n,191e

Bad (219): 0e,0n,1e,1n,2e,2n,3e,3n,4e,4n,5e,5n,7e,8e,10e,11e,11n,12e,12n,13e,13n,14e,14n,15e,15n,16e,16n,17e,17n,18e,18n,23e,23n,24e,24n,25e,25n,26e,26n,27e,27n,28e,28n,29e,29n,30e,30n,32e,32n,33e,33n,36e,36n,37e,37n,38e,38n,39e,39n,42e,45n,46n,50e,50n,51e,51n,52e,52n,53e,53n,55e,55n,57e,65e,65n,66e,66n,67e,67n,68e,68n,70e,70n,71e,71n,81e,81n,83e,83n,84e,84n,85e,85n,86e,86n,87e,87n,88e,88n,89e,89n,90e,90n,91e,91n,92e,92n,93e,93n,94e,94n,98e,98n,99e,99n,100e,100n,101e,101n,102e,102n,103e,103n,104e,104n,105e,105n,106e,106n,107e,107n,108e,108n,109e,109n,110e,110n,111e,111n,112e,112n,116e,116n,119e,120

In [26]:
# %%time
# # 2min 35s
# freds = hera_cal.redcal.filter_reds(reds, ex_ants=bad_ants)
# info = hera_cal.redcal.RedundantCalibrator(freds)
# meta0, sol0 = info.firstcal(data, hc.freqs)

In [27]:
# Solve firstcal delays for non-bad antennas

# Re-filter with the antenna classification as updated by the phasing cut.
freds = hera_cal.redcal.filter_reds(reds, ex_ants=ant_class.bad)

# Encode equations for non-bad antennas
eqs1 = {}
for bl, dly in bl_dly.items():
    a_i, a_j = split_bl(bl)
    if ant_class.is_bad(a_i) or ant_class.is_bad(a_j):
        continue
    # a_i and a_j are (ant, pol) tuples; concatenating them fills the four
    # format slots, giving one equation "dly_i - dly_j = baseline delay".
    eqs1['dly_%d_%s - dly_%d_%s' % (a_i + a_j)] = dly

ls = linsolve.LinearSolver(eqs1)
_sol1 = ls.solve()
# Per-antenna delays for every antenna counted in the phasing step that is
# not bad.
# NOTE(review): looks like this raises KeyError for a non-bad antenna whose
# baselines were all dropped from eqs1 -- confirm.
dlys = {ai: _sol1['dly_%d_%s' % ai] for ai in ant_cnt.keys() if not ant_class.is_bad(ai)}

In [28]:
# # Fine-tune FM delays after first round of per-ant solutions
# # XXX probably unnecessary

# DLY1_RNG = 15  # # maximum delay to try, in ns

# # fine-tuning loop. multiple passes appear unnecessary
# dlys_try = np.linspace(-DLY1_RNG, DLY1_RNG, 20 * DLY1_RNG + 1) * 1e-9
# dlys_try.shape = (-1, 1)
# fqs = hc.freqs[chs]
# fqs.shape = (1, -1)
# gains = {ai: np.exp(2j * np.pi * fqs * dly) for ai, dly in dlys.items()}
# phasor = np.exp(-2j * np.pi * fqs * dlys_try)  # brute force RFI phasors by delay

# _eqs2 = {}
# for grp in freds:
#     bl = grp[0]
#     # generate predicted geometric phases for each RFI station
#     bl_xyz = antpos[bl[1]] - antpos[bl[0]]
#     rfi_phs = np.exp(-2j * np.pi * np.dot(bl_xyz, rfi_headings) / lams)

#     for bl in grp:
#         a_i, a_j = split_bl(bl)
#         #g_ij = gains[a_i][:,chs] * gains[a_j][:,chs].conj()
#         g_ij = gains[a_i] * gains[a_j].conj()
#         d_phs = np.sum(data[bl][:,chs] / g_ij, axis=0) * rfi_phs
#         d_phs /= np.abs(d_phs).clip(1, np.Inf)
#         d_phs = d_phs * phasor
#         _chi    = np.sum(ch_wgts * np.abs(d_phs - 1)**2, axis=1) / sum_ch_wgts
#         _chi180 = np.sum(ch_wgts * np.abs(d_phs + 1)**2, axis=1) / sum_ch_wgts # 180-deg phase offset
#         i = np.argmin(_chi)
#         j = np.argmin(_chi180)
#         #print(bl, _chi[i], _chi180[j])
#         if _chi180[j] < 0.5 * _chi[i]:
#             # if dipoles are swapped, 180-deg phasor is best fit
#             i, _chi = j, _chi180
#         _eqs2['dly_%d_%s - dly_%d_%s' % (a_i + a_j)] = dlys_try[i, 0]

# ls = linsolve.LinearSolver(_eqs2)
# _sol2 = ls.solve()
# dlys = {ai: dly + _sol2['dly_%d_%s' % ai] for ai, dly in dlys.items()}

# Finalize Firstcal Delays from Sky

In [29]:
# Final polish on delays using full-band (non-RFI) data

DLY2_RNG = 15  # maximum delay to try, in ns
CH_STEP = 100  # every nth channel to include in fit
_chs = np.arange(0, hc.freqs.size, CH_STEP)
# Row-vector copy of the frequency axis, so hc.freqs keeps its own shape.
freqs = hc.freqs.copy()
freqs.shape = (1, -1)
fqs = freqs[:,_chs]
bl_swapped = {}

# fine-tuning loop. multiple passes appear unnecessary
# 20 * DLY2_RNG + 1 trial delays across +/-DLY2_RNG ns, as a column vector.
dlys_try = np.linspace(-DLY2_RNG, DLY2_RNG, 20 * DLY2_RNG + 1) * 1e-9
dlys_try.shape = (-1, 1)

phasor = np.exp(-2j * np.pi * fqs * dlys_try)  # brute force phasors by delay
# Full-band delay-only gains from the RFI-based firstcal solution.
gains = {ai: np.exp(2j * np.pi * freqs * dly) for ai, dly in dlys.items()}

_eqs = {}
for grp in freds:
    # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
    _chisq = np.inf
    bl = grp[0]
    # generate predicted geometric phases for each RFI station
    bl_xyz = antpos[bl[1]] - antpos[bl[0]]
    rfi_phs = np.exp(-2j * np.pi * np.dot(bl_xyz, rfi_headings) / lams)
    # pick a representative baseline for the redundant group: the one whose
    # delay-corrected RFI phases are flattest, i.e. closest to the ideal
    # geometric phase solution
    for bl in grp:
        a_i, a_j = split_bl(bl)
        g_ij = gains[a_i][:, chs] * gains[a_j][:, chs].conj()
        d_phs = np.sum(data[bl][:, chs] / g_ij, axis=0) * rfi_phs
        #g_ij = gains[a_i][:,_chs] * gains[a_j][:,_chs].conj()
        #d_phs = np.sum(data[bl][:,_chs] / g_ij, axis=0) * rfi_phs
        d_phs /= np.abs(d_phs).clip(1, np.inf)
        _chi = np.mean(np.abs(d_phs - 1)**2)
        if _chi < _chisq:
            min_bl = bl
            #phs = d_phs.conj()
            # Conjugate of the delay-corrected reference visibility on the
            # sparse channel grid; multiplying by it removes the group's
            # common sky phase.
            phs = data[bl][:,_chs] / (gains[a_i][:,_chs] * gains[a_j][:,_chs].conj())
            phs = phs.conj()
            _chisq = _chi
    # compute phase relative to representative baseline and encode 4-point phase equation
    ma_i, ma_j = split_bl(min_bl)
    for bl in grp:
        a_i, a_j = split_bl(bl)
        g_ij = gains[a_i][:,_chs] * gains[a_j][:,_chs].conj()
        d_phs = np.sum(data[bl][:,_chs] * phs / g_ij, axis=0)
        d_phs /= np.abs(d_phs).clip(1, np.inf)
        d_phs = d_phs * phasor
        _chi    = np.sum(np.abs(d_phs - 1)**2, axis=1)
        _chi180 = np.sum(np.abs(d_phs + 1)**2, axis=1) # 180-deg phase offset
        i = np.argmin(_chi)
        j = np.argmin(_chi180)
        #print(bl, _chi[i], _chi180[j])
        bl_swapped[bl] = 0
        if _chi180[j] < 0.5 * _chi[i]:
            # if dipoles are swapped, 180-deg phasor is best fit
            bl_swapped[bl] = 1
            i, _chi = j, _chi180
        # Residual delay of this baseline relative to the reference one:
        # (dly_i - dly_j) - (dly_mi - dly_mj).
        _eqs['dly_%d_%s - dly_%d_%s - dly_%d_%s + dly_%d_%s' % (a_i + a_j + ma_i + ma_j)] = dlys_try[i, 0]

ls = linsolve.LinearSolver(_eqs)
_sol = ls.solve()
# Add the fine corrections to the RFI-based delays.
dlys = {ai: dly + _sol['dly_%d_%s' % ai] for ai, dly in dlys.items()}
#print(len([bl for bl, swapped in bl_swapped.items() if swapped]))

In [30]:
# Identify antennas that have 180-deg dipole rotations; they will be corrected
# Both ranges share the same upper limit, so an antenna is either good
# (swap fraction <= 0.25) or bad; nothing lands in suspect.
# np.inf is used because the np.Inf alias was removed in NumPy 2.0.
_swap_bound = Bounds(absolute=(-np.inf, 0.25), good=(-np.inf, 0.25))
# Calculate fraction with swapped dipoles
ant_cnt = {}
swap_cnt = {}
for bl, swapped in bl_swapped.items():
    a_i, a_j = split_bl(bl)
    if ant_class.is_bad(a_i) or ant_class.is_bad(a_j):
        continue
    for ant in (a_i, a_j):
        ant_cnt[ant] = ant_cnt.get(ant, 0) + 1
        swap_cnt[ant] = swap_cnt.get(ant, 0) + swapped

for ant, cnt in swap_cnt.items():
    swap_frac = cnt / ant_cnt[ant]
    _swap_bound.classify(ant, swap_frac)

# restrict swap antennas to where both polarizations show it
print(f'Maybe Reversed?: {_antenna_str(_swap_bound.bad)}')
print(f'Maybe Not Reversed?: {_antenna_str(_swap_bound.good)}')
# Antenna numbers flagged in Jee, intersected with those flagged in Jnn.
swap_ants = set(k[0] for k in _swap_bound.bad if k[-1] == 'Jee')
swap_ants.intersection_update([k[0] for k in _swap_bound.bad if k[-1] == 'Jnn'])
# Expand back to (ant, pol) keys for both polarizations, then drop bad ones.
swap_ants = set((k, pol) for k in swap_ants for pol in ('Jee', 'Jnn'))
swap_ants.difference_update(ant_class.bad)
#swap_ants.clear()
print(f'Reversed: {_antenna_str(swap_ants)}')

Maybe Reversed?: 72e,143e,178n
Maybe Not Reversed?: 7n,8n,9e,9n,10n,19e,19n,20e,20n,21e,21n,31e,31n,40e,40n,41e,41n,42n,45e,46e,54e,54n,56e,56n,57n,69e,69n,72n,73e,73n,82e,82n,119n,135e,135n,136e,138e,138n,140e,140n,141e,143n,144e,144n,145n,161e,162e,162n,163e,163n,164e,164n,165e,168e,168n,169n,170n,177e,177n,178e,179e,179n,181e,181n,182e,183e,183n,184e,184n,185e,185n,186e,186n,187n,189e,189n,191e,191n
Reversed: 


In [31]:
# Given final firstcal gain solutions, compute unique baseline solutions
# Per-antenna gain: auto-derived amplitude times a pure delay phase.
sol0 = {ai: auto_gains[ai] * np.exp(2j * np.pi * freqs * dly) for ai, dly in dlys.items()}
# Antennas identified as having swapped dipoles get a sign flip.
for ai in swap_ants:
    sol0[ai] *= -1
# Every ordered pair of solved antennas; built before sol0 is extended with
# baseline keys below.
all_bls = set(hera_cal.utils.join_bl(ai, aj) for ai in sol0.keys() for aj in sol0.keys())
info = hera_cal.redcal.RedundantCalibrator(freds)
#sol0.update(info.compute_ubls(data, sol0))  # takes 2 s
# below is much faster
ubl_sols = {} 
for grp in info.reds:
    ubl_sum = 0
    for bl in grp:
        ai, aj = split_bl(bl)
        # Calibrated visibility for this baseline.
        ubl_sum += data[bl] / (sol0[ai] * sol0[aj].conj())
    # Unweighted mean over the group, keyed by the group's first baseline.
    ubl_sols[grp[0]] = ubl_sum / len(grp)
    
# sol0 now holds both antenna gains and unique-baseline visibilities.
sol0.update(ubl_sols)

In [32]:
# Plot firstcal calibrated visibilities for a redundant group

def plot_red_gp(data, sol, gp, t=0, title=None):
    '''Plot the calibrated visibilities of one redundant group on top of the
    group's solved visibility (phase in the upper panel, amplitude in the lower).

    data: visibilities keyed by baseline; each entry is indexed by t first.
    sol: solutions holding both antenna gains and one visibility per group.
    gp: baselines making up the redundant group.
    t: index selecting which row of each array to plot.
    title: title for the upper panel; defaults to the group's key in sol.

    Reads the notebook globals hc (for freqs) and data2_wgts (amplitudes are
    drawn only where data2_wgts[t] is nonzero). Raises IndexError if no
    baseline of gp is a key of sol.'''
    _, (ax_phs, ax_amp) = plt.subplots(figsize=(8,6), ncols=1, nrows=2, sharex=True)
    # the group may be keyed in sol by any one of its baselines
    ubl = [bl for bl in gp if bl in sol][0]
    mask = np.where(data2_wgts[t])
    u = sol[ubl][t]

    # one trace per baseline, with the antenna gains divided out
    for bl in gp:
        a_i, a_j = split_bl(bl)
        g_ij = sol[a_i][t] * sol[a_j][t].conj()
        _cal = data[bl][t] / g_ij
        ax_phs.plot(hc.freqs / 1e6, np.angle(_cal), label=str(bl))
        ax_amp.semilogy(hc.freqs[mask] / 1e6, np.abs(_cal)[mask], label=str(bl))

    # the solved visibility for the group, drawn as a thick black line
    ax_phs.plot(hc.freqs / 1e6, np.angle(u), 'k', linewidth=3, label=str(ubl))
    ax_amp.semilogy(hc.freqs[mask] / 1e6, np.abs(u)[mask], 'k', linewidth=3, label=str(ubl))

    ax_phs.set_title(str(ubl) if title is None else title)
    ax_phs.set_ylabel('Phase')
    ax_amp.set_ylabel('Amplitude')
    ax_amp.set_xlabel('Frequency [MHz]')
    ax_amp.grid()

# Diagnostic only: show the first redundant group calibrated with the firstcal solution
if PLOT:
    plot_red_gp(data, sol0, freds[0], title='Firstcal')

In [33]:
# Timestamp the end of the firstcal stage, then print the timing summary
timer.clock('firstcal')
print(timer)

start->firstcal: 14.06 s, auto_flags->firstcal:  5.77 s


In [34]:
# Skipping logcal as an unnecessary step
# %%time
# roughly 15s
#meta1, sol1 = info.logcal(data, {k: v for k, v in sol0.items() if len(k) == 2})
#sol1 = info.remove_degen(sol1, degen_sol=sol0)

In [35]:
#if PLOT:
#   plot_red_gp(data, sol1, freds[0], title='Logcal')

# Omnical

In [36]:
# Inverse-variance weights for every baseline, estimated from smoothed
# autocorrelations (noise_mdl holds one noise model per antenna)

wgts = {}
for bl in all_bls:
    a_i, a_j = split_bl(bl)
    noise = np.sqrt(noise_mdl[a_i] * noise_mdl[a_j])
    # crosses have 1/2 variance of autos, hence the sqrt(2) on the std-dev
    _cross_sigma = noise / np.sqrt(2)
    wgts[bl] = 1 / _cross_sigma**2
wgts = hera_cal.io.DataContainer(wgts)

In [37]:
%%time
# Run Omnical, starting from the firstcal solution, for a capped number of iterations
NITER = 100

# Reference configurations, kept for comparison:
#
# standard pipeline, takes 15.5 min
#meta2, sol2 = info.omnical(use_data, deepcopy(sol1), wgts=use_wgts, conv_crit=1e-10, gain=.4, maxiter=10000,
#                         check_after=500, check_every=100)
#
# wgt func reduces sensitivity to outliers; unclear what impact is for preflagged antennas
#def wgt_func(abs2):
#    return np.where(abs2 > 0, 5 * np.tanh(abs2 / 5) / abs2, 1)
#meta2, sol2 = info.omnical(use_data, deepcopy(sol1), wgts=use_wgts, conv_crit=1e-5, gain=.4, maxiter=100,
#                         check_after=50, check_every=10, wgt_func=wgt_func)

# check_after is set to NITER so that it is hardcoded to run NITER iterations w/o checking;
# sol0 is deep-copied so the firstcal solution stays available afterwards.
meta2, sol2 = info.omnical(
    data,
    deepcopy(sol0),
    wgts=wgts,
    conv_crit=1e-5,
    gain=.4,
    maxiter=NITER,
    check_after=NITER,
    check_every=10,
)

CPU times: user 20.2 s, sys: 831 ms, total: 21 s
Wall time: 21.2 s


In [38]:
# Replace degeneracies in omnical solutions with firstcal degeneracies, which
# inherited a nominal absolute calibration from H4C

# sol0 is passed as the reference (degen_sol); the result replaces sol2.
sol2 = info.remove_degen(sol2, degen_sol=sol0)
# Slow
# (disabled: chi-square via hera_cal.redcal.normalized_chisq, after splitting the
# solution into visibility entries (3-element keys) and gain entries (2-element keys))
#vis_sols = {k: v for k, v in sol2.items() if len(k) == 3}
#gain_sols = {k: v for k, v in sol2.items() if len(k) == 2}
#chisq2_pol, chisq2_per_ant = hera_cal.redcal.normalized_chisq(use_data, use_wgts, freds, vis_sols, gain_sols)
#chisq2 = 0.5 * (chisq2_pol['Jee'] + chisq2_pol['Jnn'])

In [39]:
# Plot omnical calibrated visibilities for a redundant group
# (same group, freds[0], as the firstcal plot so the two can be compared)
if PLOT:
    plot_red_gp(data, sol2, freds[0], title='Omnical')

In [40]:
# # Examine how redundant solution for a group changed between firstcal & omnical
# if PLOT:
#     fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=2, sharex=True)
#     gp = freds[0]
#     ubl = [bl for bl in gp if bl in sol0][0]
#     mask = np.where(data_wgts[0])
#     for cnt, sol in enumerate((sol0, sol2)):
#         u = sol[ubl][0]
#         axes[0].plot(hc.freqs / 1e6, np.angle(u), label=f'sol{2*cnt}')
#         axes[1].semilogy(hc.freqs[mask] / 1e6, np.abs(u[mask]), label=f'sol{2*cnt}')
#     title = str(ubl)
#     axes[0].set_title(title)
#     axes[0].set_ylabel('Phase')
#     axes[0].legend()
#     axes[1].set_ylabel('Amplitude')
#     axes[1].set_xlabel('Frequency [MHz]')
#     axes[1].legend()
#     axes[1].grid()

In [41]:
# Plot chisq and # of iterations from omnical

def calc_chisq(data, wgts, reds, sol):
    '''Calculate chi-square averaged over baselines, overall and per-antenna.

    For every baseline in every redundant group the term
    |data - g_i * conj(g_j) * u|^2 * wgts is formed, where g_i, g_j are the
    antenna entries of sol and u is the group's visibility entry of sol.

    data, wgts: keyed by baseline.
    reds: redundant groups (each a sequence of baselines).
    sol: solutions holding antenna gains and one visibility per group.

    Returns (chisq, chisq_ant): the mean of the terms over all baselines, and a
    dict mapping each antenna to the mean over the baselines it takes part in.
    NOTE(review): the normalization is a baseline count, not degrees of freedom.

    Raises IndexError if a group has no baseline keyed in sol, and
    ZeroDivisionError if reds contains no baselines at all.'''
    chisq_sum, chisq_wgt = 0, 0
    sum_by_ant, cnt_by_ant = {}, {}
    for gp in reds:
        # the group may be keyed in sol by any one of its baselines
        ubl = [bl for bl in gp if bl in sol][0]
        u = sol[ubl]
        for bl in gp:
            a_i, a_j = split_bl(bl)
            resid = data[bl] - sol[a_i] * sol[a_j].conj() * u
            term = np.abs(resid)**2 * wgts[bl]
            chisq_sum += term
            chisq_wgt += 1
            # each baseline counts once towards each of its two antennas
            for ant in (a_i, a_j):
                sum_by_ant[ant] = sum_by_ant.get(ant, 0) + term
                cnt_by_ant[ant] = cnt_by_ant.get(ant, 0) + 1
    chisq_ant = {ant: tot / cnt_by_ant[ant] for ant, tot in sum_by_ant.items()}
    return chisq_sum / chisq_wgt, chisq_ant

# Compare chi-square of the omnical and firstcal solutions for the first
# integration, and overlay the per-channel omnical iteration count.
if PLOT:
    plt.figure()
    chisq, _ = calc_chisq(data, wgts, freds, sol2)
    plt.semilogy(hc.freqs / 1e6, chisq[0], label='Omnical')
    chisq, _ = calc_chisq(data, wgts, freds, sol0)
    plt.semilogy(hc.freqs / 1e6, chisq[0], label='Firstcal')
    plt.semilogy(hc.freqs / 1e6, meta2['iter'][0], label='Iterations')
    plt.legend()
    plt.xlabel('Frequency [MHz]')
    # raw string: '\c' is not a valid escape sequence and warns on newer Pythons
    plt.ylabel(r'$\chi_r^2$')
    plt.ylim(3e-1, 1e3)
    plt.grid()

In [42]:
# Per-antenna chi-square of the omnical solution for the first integration,
# one faint trace per antenna.
if PLOT:
    plt.figure()
    chisq, chisq_per_ant = calc_chisq(data, wgts, freds, sol2)
    # the channel selection is the same for every antenna, so build it once
    mask = np.where(data2_wgts[0])
    for k, _chi in chisq_per_ant.items():
        #print(k, np.median(_chi[0][mask]))
        plt.semilogy(hc.freqs[mask] / 1e6, _chi[0][mask], alpha=0.2)
    plt.xlabel('Frequency [MHz]')
    # raw strings: '\c' is not a valid escape sequence and warns on newer Pythons
    plt.ylabel(r'$\chi^2_r$')
    plt.title(r'$\chi^2_r$ per Antenna')
    plt.grid()

In [43]:
# Timestamp the end of the omnical stage, then print the timing summary
timer.clock('omnical')
print(timer)

start->omnical: 36.18 s, firstcal->omnical: 22.12 s


In [44]:
# if PLOT:
#     plt.figure()
#     #hist, bins = np.histogram([np.where(_chi > 2)[0].size / _chi.size for _chi in chisq_per_ant.values()])
#     hist, bins = np.histogram([np.median(_chi) for _chi in chisq_per_ant.values()], bins=20)
#     plt.plot(0.5 * (bins[1:] + bins[:-1]), hist)

In [45]:
# Estimate gains for antennas that have no omnical solution ("bad" antennas),
# using baselines that pair them with a solved antenna in a group whose
# visibility is already solved. Model: data[bl] = g_i * conj(g_j) * u.

gsum = {}  # antenna -> weighted sum of data * conj(model without the unknown gain)
gwgt = {}  # antenna -> weighted sum of |model without the unknown gain|^2
sol3 = deepcopy(sol2)

for grp in reds:
    # groups without any solved visibility cannot constrain anything
    _solved = [bl for bl in grp if bl in sol3]
    if not _solved:
        continue
    ubl = _solved[0]
    u = sol3[ubl]
    for bl in grp:
        a_i, a_j = split_bl(bl)
        noise = np.sqrt(autos[a_i] * autos[a_j])
        wgt = 1 / (noise / np.sqrt(2))**2 # crosses have 1/2 variance of autos
        _has_i, _has_j = a_i in sol3, a_j in sol3
        # usable only when exactly one of the two gains is missing
        if _has_i == _has_j:
            continue
        if not _has_i:
            # solve for g_i from data = g_i * (u * conj(g_j))
            _ant, _ref = a_i, a_j
            _vis, _mdl = data[bl], u * sol3[a_j].conj()
        else:
            # solve for g_j from conj(data) = g_j * (conj(u) * conj(g_i))
            _ant, _ref = a_j, a_i
            _vis, _mdl = data[bl].conj(), u.conj() * sol3[a_i].conj()
        gsum[_ant] = gsum.get(_ant, 0) + _vis * _mdl.conj() * wgt
        gwgt[_ant] = gwgt.get(_ant, 0) + np.abs(u)**2 * np.abs(sol3[_ref])**2 * wgt

# weighted least-squares estimate; nan_to_num cleans up non-finite results
sol3.update({k: np.nan_to_num(num / gwgt[k]) for k, num in gsum.items()})

In [46]:
# Timestamp the end of the bad-antenna calibration stage, then print the timing summary
timer.clock('badcal')
print(timer)

start->badcal: 37.49 s, omnical->badcal:  1.31 s


In [47]:
# Scatter of the solved gain amplitude against the first-guess gain from the
# auto-correlations (first integration, channels where data2_wgts[0] is nonzero)
if PLOT:
    plt.figure()
    mask = np.where(data2_wgts[0])
    for k, v in auto_gains.items():
        # only antennas that ended up with a solution can be compared
        if k in sol3:
            plt.plot(v[0][mask], np.abs(sol3[k][0][mask]), ',', alpha=0.2)

    plt.xlabel('Auto Gain')
    plt.ylabel('Omnical Gain')
    plt.grid()

In [48]:
# Gain amplitudes of every antenna in sol3 (good and bad), first integration:
# 'Jee' in the upper panel, everything else in the lower panel
if PLOT:
    fig, axes = plt.subplots(figsize=(8,6), ncols=1, nrows=2, sharex=True)
    mask = np.where(data2_wgts[0])
    for k, gain in sol3.items():
        # 3-element keys are baseline visibilities, not antenna gains
        if len(k) == 3:
            continue
        ax = axes[0] if k[-1] == 'Jee' else axes[1]
        ax.semilogy(hc.freqs[mask] / 1e6, np.abs(gain[0][mask]), label=str(k))
    axes[0].set_title('Omnical Gain Solutions')
    for _panel in (axes[0], axes[1]):
        _panel.grid()
        _panel.set_ylabel('Gain')
    plt.xlabel('Frequency [MHz]')