# H6C LST-Binning Inspection Notebook

**Steven Murray & Josh Dillon**

This notebook provides a sense-check for H6C LST-binned data. It operates in one of two modes, on either redundantly-averaged or non-redundantly-averaged data. The mode is auto-detected, and some plots are included or omitted depending on which mode is in use.

## Table of Figures

* [Auto-Correlation Plot](#Figure:-Auto-Correlation-Plot)
* [Mean Excess Variance as a Function of Frequency](#Figure:-Mean-Excess-Variance-as-Function-of-Frequency)
* Distribution of Excess Variance:
  * [As function of Days Binned](#Figure:-Distribution-of-Excess-Variance-as-function-of-Days-Binned)
  * [Across Baseline Subsets and LSTs for Low- and High-Band](#Figure:-Distribution-of-Excess-Variance-Across-Baseline-Subsets-and-LSTs-for-Low--and-High-Band)
  * [Across LSTs and Bands for All Baselines](#Figure:-Distribution-of-Excess-Variance-across-LSTs-and-Bands-for-All-Baselines)
  * [As Function of Baseline Length and LST at 160 MHz](#Figure:-Distribution-of-Excess-Variance-with-Baseline-Length-and-LST-at-160-MHz)
  * [Between NS and EW baselines and N/E pols](#Figure:-Distribution-of-Excess-Variance-Between-NS-and-EW-baselines-and-pols)
  * [Across Redundant Group Size](#Figure:-Distribution-of-Excess-Variance-Across-Redundant-Group-Size)
* [Raw Visibilities over Nights for the Worst Cases of Excess Variance](#Figure:-Visibilities-Over-Nights)
* Distribution of Predicted Z-Scores:
  * [For a single Frequency/LST/Night](#Figure:-Histogram-of-Baseline-Z-Scores-at-single-Frequency-/-LST-/-Night)
  * [Per-night at 138 MHz](#Figure:-Box-Plot-of-Z-Scores-at-138-MHz)
  * [Per-night at 169 MHz](#Figure:-Box-Plot-of-Z-Scores-at-169-MHz)
* Sigma-Clipping
  * [Sigma-Clipped Fraction as a Function of Threshold, LST and Night](#Figure:-Sigma-Clipped-Fraction-As-Function-of-Threshold,-LST-and-Night)
  * [List of Most Sigma-Clipped LSTs, Nights and Baselines](#List-of-most-sigma-clipped-LSTs,-Nights-and-Baselines)
  * [List of Most-Sigma-Clipped Antennas](#List-of-Most-Sigma-Clipped-Antennas)
  * [List of Most-Sigma-Clipped Antenna-Nights](#List-of-Most-Sigma-Clipped-Antenna-Nights)
  * [Counts of Contiguous Flagged Region Sizes](#Figure:-Counts-of-Contiguous-Flagged-Region-Sizes)

## Configuration and Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from hera_cal.io import HERAData
from hera_cal.datacontainer import DataContainer
import glob
from hera_cal import utils, noise, redcal, lstbin
from hera_cal.lstbin_simple import lst_average
from hera_cal.abscal import match_times
from copy import deepcopy
import os
from IPython.display import display, HTML
import warnings
from pathlib import Path
from astropy.time import Time
from astropy import units as un
import matplotlib as mpl
from hera_cal import io, apply_cal
import toml
from hera_cal.io import HERADataFastReader
from collections import defaultdict
from matplotlib import patches
from hera_opm.mf_tools import get_lstbin_datafiles
from hera_cal.red_groups import RedundantGroups
from hera_cal.datacontainer import RedDataContainer
import yaml
from pyuvdata.uvdata import FastUVH5Meta
import re
import json
from functools import partial
import copy
import attrs
from functools import cached_property
from scipy.stats import gamma, norm
from collections import Counter

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [None]:
# Update the path below if running this notebook interactively
lstbin_path = Path(os.environ.get(
    "LSTBIN_PATH", 
    "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/nonavg-smoothcal/"
))

if not lstbin_path.exists():
    raise IOError(f"{lstbin_path} does not exist!")

In [None]:
# Get the baseline groups we want to look at
bl_groups_to_view = os.environ.get(
    "BLGROUPS",  # should be given as a JSON string of the form '[[1, 2, "ee"], [0, 0, "nn"]]'
    [
        (4, 7, 'ee'),
        (4, 7, 'nn'),
        (10, 22, 'nn'), 
        (20, 22, 'nn'), 
        (10, 47, 'nn'), 
        (81, 155, 'ee'), 
        (8, 61, 'ee'),
        (4, 6, 'ee'),
    ]
)

if isinstance(bl_groups_to_view, str):
    bl_groups_to_view=[tuple(x) for x in json.loads(bl_groups_to_view)]
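
As a minimal sketch of the `BLGROUPS` format expected above (the value shown is hypothetical, not taken from any real run), the string must be valid JSON, which means polarization strings need double quotes. Each inner list is then converted to a tuple so it can serve as an `(ant1, ant2, pol)` key:

```python
import json

# Hypothetical value, as it might be exported in a shell before running the notebook:
#   export BLGROUPS='[[1, 2, "ee"], [0, 0, "nn"]]'
raw = '[[1, 2, "ee"], [0, 0, "nn"]]'

# JSON has no tuple type, so convert each inner list to a tuple.
parsed = [tuple(x) for x in json.loads(raw)]
print(parsed)

# Single-quoted strings are not valid JSON and are rejected by json.loads.
try:
    json.loads("[[1, 2, 'ee']]")
    single_quotes_ok = True
except json.JSONDecodeError:
    single_quotes_ok = False
```

When the variable is unset, the default list of tuples is used directly and no parsing takes place.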
    

### Setup Data

In [None]:
config = toml.load(lstbin_path / "lstbin-config.toml")

In [None]:
with open(lstbin_path / 'file-config.yaml', 'r') as fl:
    lstbin_file_config = yaml.safe_load(fl)

In [None]:
dlst = lstbin_file_config['config_params']['dlst']

In [None]:
# Get a list of every file that goes into the LST-bin products
all_data_files = sum(sum(lstbin_file_config['matched_files'], start=[]), start=[])

In [None]:
# Get the JDs of the input data files from their filenames (i.e. the first time in each file)
data_jds = [float(re.findall(lstbin_file_config['config_params']['jd_regex'], fl)[0]) for fl in all_data_files]
data_jd_ints = sorted({int(jd) for jd in data_jds})

In [None]:
# Get a simple increasing list of jd-ints that cover the full observation (including missing days)
JDs = np.arange(int(min(data_jds)), int(max(data_jds)) + 1)

In [None]:
print(f"The dataset spans a range of days {JDs.min()} -- {JDs.max()}")

In [None]:
# Get the "golden" LSTs, i.e. those for which output files containing the full dataset were written, for ease of comparison.
GOLDEN_LSTs = [float(x) for x in config['LSTBIN_OPTS']['golden_lsts'].split(",")]

In [None]:
golden_files = sorted(lstbin_path.glob('*.GOLDEN.*'))
golden_hds = [io.HERADataFastReader(fl) for fl in golden_files]

In [None]:
meta = FastUVH5Meta(all_data_files[0], blts_are_rectangular=True)
dt = (meta.times[1] - meta.times[0])*3600*24
df = meta.freq_array[1] - meta.freq_array[0]
dlst = (meta.lsts[1] - meta.lsts[0])%(2*np.pi)
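
The modulo in the `dlst` line above guards against the first two integrations straddling the $2\pi$ wrap in LST. A small illustration with made-up LST values (radians):

```python
import math

# Hypothetical LSTs (radians) of two consecutive integrations either side of the wrap.
lst0 = 2 * math.pi - 0.0005
lst1 = 0.0005

naive = lst1 - lst0                      # large negative number, not a spacing
wrapped = (lst1 - lst0) % (2 * math.pi)  # the true spacing, about 0.001 rad

# Time spacing: Julian dates are in days, so multiply by 86400 to get seconds.
dt_seconds = (2459900.50012 - 2459900.50000) * 3600 * 24
```

Here `naive` is close to $-2\pi$, while `wrapped` recovers the small positive spacing.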

In [None]:
# Ensure that for each golden file, all the LSTs are really aligned in a single bin
for ghd in golden_hds:
    assert ghd.lsts.min() + dlst >= (ghd.lsts.max() - 1e-7), f"Got range of {ghd.lsts.max()  - ghd.lsts.min()} > {dlst}"

In [None]:
# Find the LST-bin output file index, and the time index within that file, that each
# "golden" LST corresponds to, so that they can be compared more easily.
lst_grid = lstbin_file_config['lst_grid']
lst_grid_flat = np.array(sum(lst_grid, start=[]))
dlst = lst_grid_flat[1] - lst_grid_flat[0]

lst_edges = [(lst - dlst/2, lst + dlst/2) for lst in lst_grid_flat]

# Note that the lst_edges might not all be contiguous, because some LSTs
# might not be observed in a dataset
            
print("Golden LSTs investigated in this notebook come from the file indices (and time indices within that file):")

golden_file_indices = {}
golden_time_indices = {}
new_golden_lsts = []
for j, glst in enumerate(GOLDEN_LSTs):
    for k, edges in enumerate(lst_edges):
        if glst < edges[0]:
            glst += 2*np.pi
            
        if glst < edges[1]:
            break
    else:
        continue
        
    lstbin_index = k
    glst = lst_grid_flat[lstbin_index]  # this is more exact than the golden LST gotten from file name

    for i, lsts in enumerate(lst_grid):
        if glst in lsts:
            idx = lsts.index(glst)
            golden_file_indices[glst] = i
            golden_time_indices[glst] = idx
            print(f"LST {glst*12/np.pi:6.3f} hr: file={i:>04}, time-idx={idx}")
            break

    new_golden_lsts.append(glst)
GOLDEN_LSTs = new_golden_lsts
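
The bin search in the cell above can be sketched in isolation as follows. The helper name `find_lst_bin` and the toy grid are hypothetical. The logic follows the loop: an LST below a bin's lower edge is shifted up by $2\pi$, and the first bin whose upper edge exceeds the (possibly shifted) LST is selected.

```python
import math

def find_lst_bin(lst, bin_centres, dlst):
    """Return the index of the bin containing `lst`, or None if no bin does."""
    for k, centre in enumerate(bin_centres):
        lo, hi = centre - dlst / 2, centre + dlst / 2
        if lst < lo:
            # The grid may run past 2*pi, so move the LST onto the same branch.
            lst += 2 * math.pi
        if lst < hi:
            return k
    return None

# Toy grid (radians) that starts late in the sidereal day and runs past 2*pi.
grid = [6.0, 6.1, 6.2, 6.3, 6.4]
print(find_lst_bin(0.05, grid, 0.1))  # wrapped onto the grid
print(find_lst_bin(6.12, grid, 0.1))  # already on the grid
print(find_lst_bin(3.0, grid, 0.1))   # not covered by any bin
```

An LST that matches no bin is skipped, as with the `else: continue` branch of the original loop.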

In [None]:
# Put our files and hds into the same kind of dicts
new_golden_files = {}
new_golden_hds = {}
for fl, hd in zip(golden_files, golden_hds):
    lstmean = np.mean(hd.lsts)
    idx = np.argmin(np.abs(np.array(GOLDEN_LSTs) - lstmean))
    new_golden_files[GOLDEN_LSTs[idx]] = fl
    new_golden_hds[GOLDEN_LSTs[idx]] = hd
    
golden_files = new_golden_files
golden_hds = new_golden_hds

In [None]:
golden_meta = next(iter(golden_hds.values()))

In [None]:
reds = RedundantGroups.from_antpos(golden_meta.antpos, pols=golden_meta.pols, include_autos=True)

In [None]:
# Determine if this is a redundant dataset or not
blgroups = {reds.get_ubl_key(bl) for bl in golden_meta.bls}
RED_DATA = len(blgroups) == len(golden_meta.bls)

In [None]:
if RED_DATA:
    print("This dataset is redundantly averaged")
else:
    print("This dataset is NOT redundantly averaged")
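
The auto-detection above rests on a counting argument: if every baseline in the file belongs to a different redundant group, the file holds one baseline per group and is therefore redundantly averaged. A toy version, using a hypothetical 1-D array and a hypothetical `ubl_key` helper in place of `RedundantGroups.get_ubl_key`:

```python
# Hypothetical 1-D antenna positions (metres) for a three-element toy array.
antpos = {0: 0, 1: 14, 2: 28}

def ubl_key(bl):
    """Label a baseline by its separation; redundant baselines share a label."""
    a, b = bl
    return antpos[b] - antpos[a]

def is_redundantly_averaged(bls):
    # One baseline per redundant group <=> as many groups as baselines.
    return len({ubl_key(bl) for bl in bls}) == len(bls)

averaged = is_redundantly_averaged([(0, 1), (0, 2)])              # one per group
not_averaged = is_redundantly_averaged([(0, 1), (1, 2), (0, 2)])  # (0, 1) and (1, 2) are redundant
print(averaged, not_averaged)
```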

In [None]:
def read_redundant(hd, reds=None, bls=None, **kw):
    if reds is None:
        reds = RedundantGroups.from_antpos(hd.antpos, hd.pols, include_autos=True)
    keyed = reds.keyed_on_bls(hd.bls)
    if bls is not None:
        bls = [keyed.get_ubl_key(bl) for bl in bls]
    d, f , n = hd.read(bls=bls, **kw)
    d = RedDataContainer(d, reds=reds)
    f = RedDataContainer(f, reds=reds)
    n = RedDataContainer(n, reds=reds)
    return d, f, n

In [None]:
golden_data = {}
golden_flags = {}
golden_nsamples = {}

for glst, hd in golden_hds.items():
    if RED_DATA:
        # We read all the baselines, because we need the golden data to get the expected variance for each baseline.
        golden_data[glst], golden_flags[glst], golden_nsamples[glst] = read_redundant(hd, reds=reds)
    else:
        # Read only autos for now
        golden_data[glst], golden_flags[glst], golden_nsamples[glst] = hd.read(bls=[bl for bl in hd.bls if (bl[0] == bl[1] and bl[2][0] == bl[2][1])])

In [None]:
# load LST-binned data for the bins that match the "GOLDEN" LSTs
lst_bin_files = {}
lstbin_hds = {}
for glst, fl_idx in golden_file_indices.items():
    lst_edge = lst_grid[fl_idx][0] - dlst/2
    fname = lstbin_path / "zen.{kind}.{lst:7.5f}.sum.uvh5".format(kind='LST', lst=lst_edge)
    lst_bin_files[glst] = fname
    lstbin_hds[glst] = HERADataFastReader(fname)

# We read all the baselines for the lst-binned data, because we want to do averages.
lstbin_data = {}
lstbin_flags = {}
lstbin_nsamples = {}
for glst in GOLDEN_LSTs:
    hd = lstbin_hds[glst]
    idx = golden_time_indices[glst]
    if RED_DATA:
        lstbin_data[glst], lstbin_flags[glst], lstbin_nsamples[glst] = read_redundant(hd, reds=reds, times=[hd.times[idx]])
    else:
        # here we're careful only to read data that is not fully flagged, otherwise RAM goes through the roof.
        _, lstbin_flags[glst], _ = hd.read(read_data=False, read_flags=True, pols=['ee', 'nn'])
        flagged =  [bl for bl in lstbin_flags[glst].bls() if np.all(lstbin_flags[glst][bl][idx])]
        del lstbin_flags[glst][flagged]
        lstbin_flags[glst].select_or_expand_times([hd.times[idx]], skip_bda_check=True)
        
        lstbin_data[glst], _, lstbin_nsamples[glst] = hd.read(list(lstbin_flags[glst].bls()), read_nsamples=True)
        
        lstbin_data[glst].select_or_expand_times([hd.times[idx]], skip_bda_check=True)
        lstbin_nsamples[glst].select_or_expand_times([hd.times[idx]], skip_bda_check=True)
        

In [None]:
# load night-to-night standard deviations
std_files = {glst: fl.parent / fl.name.replace('.LST.', '.STD.') for glst, fl in lst_bin_files.items()}
std_hds = {glst: HERADataFastReader(fl) for glst, fl in std_files.items()}

std_data = {}
std_flags = {}
std_nsamples = {}
for glst, hd in std_hds.items():
    idx = golden_time_indices[glst]
    if RED_DATA:
        std_data[glst], std_flags[glst], std_nsamples[glst] = read_redundant(hd, reds=reds, times=[hd.times[idx]])
    else:
        std_data[glst], _, _ = hd.read(bls=list(lstbin_data[glst].bls()))
        std_data[glst].select_or_expand_times([hd.times[idx]], skip_bda_check=True)

In [None]:
# Here, we just make double-sure that data/flags/nsamples are consistent
for glst in GOLDEN_LSTs:
    for bl in lstbin_data[glst].bls():
        lstf = lstbin_flags[glst][bl]
        lstn = lstbin_nsamples[glst][bl]
        lstd = lstbin_data[glst][bl]
        
        lstf |= (lstn==0)
        lstn[lstf] = 0  # flagged samples must have zero nsamples
        lstd[lstf] *= np.nan  # multiply by nan instead of setting to nan, to get the imaginary cmp.
        std_data[glst][bl][lstf] *= np.nan
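
The comment about multiplying by NaN, rather than assigning it, can be checked with a short NumPy example. Assigning a float NaN to a complex element sets only the real part to NaN, whereas multiplying propagates NaN into both parts:

```python
import numpy as np

assigned = np.array([1.0 + 2.0j])
assigned[0] = np.nan   # float NaN is cast to complex: real part NaN, imaginary part 0

multiplied = np.array([1.0 + 2.0j])
multiplied *= np.nan   # complex multiplication propagates NaN into both parts

print(assigned, multiplied)
```

With the multiplied form, NaN-aware reductions applied separately to the real and imaginary parts both treat the flagged samples as missing.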

In [None]:
# Now we re-sort the "golden lsts" such that we get them in one contiguous chunk
if len(GOLDEN_LSTs) < 24:
    firstidx = np.argmax(np.diff(sorted(GOLDEN_LSTs))) + 1
    GOLDEN_LSTs = np.roll(GOLDEN_LSTs, len(GOLDEN_LSTs) - firstidx)
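
A small worked example of the re-ordering above, with made-up LSTs. The largest gap between consecutive sorted LSTs is taken to mark the start of the observed block, and the list is rolled so that the block starts there:

```python
import numpy as np

lsts = np.array([0.5, 1.0, 5.5, 6.0])  # hypothetical LSTs in radians, already sorted

# Index of the first LST after the largest gap.
firstidx = np.argmax(np.diff(np.sort(lsts))) + 1
contiguous = np.roll(lsts, len(lsts) - firstidx)
print(contiguous)
```

Note that `np.diff` does not include the gap across the $2\pi$ wrap (from the last LST back to the first), so this approach assumes the largest gap is an interior one. It also rolls the list in its existing order, so the input is assumed to be sorted already.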

In [None]:
# Assign colors / line styles for days for the entire notebook, so we have consistency.
styles = {}
for i, jdint in enumerate(data_jd_ints):
    # Cycle through line styles too, so that more than 40 nights does not raise an IndexError.
    styles[jdint] = {'color': f"C{i%10}", 'ls': ['-', '--', ':', '-.'][(i//10) % 4]}

### Define Baseline Subsets

Here we define some functions for obtaining different subsets of baselines (e.g. long vs. short, EW vs. NS, different pols, intra- vs. inter-sector).

In [None]:
def get_all_antenna_sectors():
    antpos = next(iter(lstbin_data.values())).antpos
    zero_pos = np.mean([antpos[165], antpos[166], antpos[145]], axis=0)
    
    sectors = {}
    for ant, pos in antpos.items():
        rec = pos - zero_pos
        theta = np.arctan2(rec[1], rec[0])
        bllen = np.sqrt(rec[0]**2 + rec[1]**2)
        if bllen > 200:
            sectors[ant] = 4  # outrigger
        elif -np.pi / 3 <= theta < np.pi / 3:
            sectors[ant] = 1
        elif np.pi / 3 <= theta < np.pi:
            sectors[ant] = 2
        elif -np.pi <= theta < -np.pi/3:
            sectors[ant] = 3
    return sectors
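
The sector rule above can be exercised on a few made-up positions. The sketch below uses a hypothetical `sector_of` helper that takes coordinates relative to the array centre (named `x` and `y` here, since the axis convention is not stated in this notebook) and applies the same half-open angular ranges:

```python
import math

def sector_of(x, y, outrigger_radius=200.0):
    """Toy version of the sector rule for a position relative to the array centre."""
    theta = math.atan2(y, x)
    if math.hypot(x, y) > outrigger_radius:
        return 4  # outrigger
    if -math.pi / 3 <= theta < math.pi / 3:
        return 1
    if math.pi / 3 <= theta < math.pi:
        return 2
    if -math.pi <= theta < -math.pi / 3:
        return 3
    return None  # theta == pi exactly matches none of the half-open ranges

print(sector_of(10, 0), sector_of(0, 10), sector_of(0, -10), sector_of(300, 0))
print(sector_of(-10, 0.0))
```

The last case shows one edge of the rule: a position at exactly $\theta = \pi$ falls in no range, so such an antenna would be left without a sector. This is unlikely to matter for real antenna positions, but a lookup on a missing antenna would raise a `KeyError`.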

In [None]:
sectors = get_all_antenna_sectors()

In [None]:
def getbllen(a,b):
    return np.sqrt(np.sum(np.square(golden_meta.antpos[a] - golden_meta.antpos[b])))

In [None]:
all_ee = lambda bl: bl[2] == 'ee'
all_nn = lambda bl: bl[2] == 'nn'
short_bls = lambda bl: getbllen(bl[0], bl[1])<=60.0
long_bls = lambda bl: getbllen(bl[0], bl[1])>60.0
intersector_bls = lambda bl: sectors[bl[0]] != sectors[bl[1]]
intrasector_bls = lambda bl: sectors[bl[0]] == sectors[bl[1]]

subsets = {
    'all': lambda bl: True,
    'ee-only': all_ee,
    'nn-only': all_nn,
    'Short (<60 m) baselines': short_bls,
    'Long (>60 m) baselines': long_bls,
    'Inter-sector baselines': intersector_bls,
    "Intra-sector baselines": intrasector_bls,
}

In [None]:
def get_selected_bls(bls, days_binned, selectors=None, min_days: int=7):
    mindaysel = lambda bl: bl in days_binned and (np.median(days_binned[bl]) >= min_days)
    crossbl = lambda bl: bl[0] != bl[1] and bl[2][0] == bl[2][1]
    
    if selectors is None:
        selectors = [mindaysel, crossbl]
    elif callable(selectors):
        selectors = [mindaysel, crossbl, selectors]
    else:
        # Build a new list rather than extending in place, so the caller's list is not mutated.
        selectors = [*selectors, crossbl, mindaysel]
        
    select = lambda bl: all(sel(bl) for sel in selectors)
    
    return [bl for bl in bls if select(bl)]

## Autocorrelations

In [None]:
def plot_autos(
    data: DataContainer, flags: DataContainer, tidx=slice(None), fig=None, ax=None, 
    xlabel: bool=True, ylabel: bool = True, title: bool=True, legend=True, color=None,
    freq_step: int=1
):
    if fig is None:
        fig, ax = plt.subplots(
            1, 2, sharex=True, sharey=True, figsize=(15,7), 
            gridspec_kw={'wspace': 0.0}, constrained_layout=True
        )
    else:
        assert len(ax) == 2
    
    handles = []
    for i, pol in enumerate(data.pols()):
        if pol[0] != pol[1]:
            continue  # skip non-autos
        
        plt.sca(ax[i])    
        
        bls = [bl for bl in reds[(0,0,pol)] if bl in data]

        for j, bl in enumerate(bls):
            thisd = np.where(flags[bl][tidx], np.nan, np.abs(data[bl][tidx]))
                
            for k, spec in enumerate(thisd):
                jdint = int(data.times[tidx][k])
                
                if i==0 and j==0:
                    handles.append(mpl.lines.Line2D([0], [0], label=str(jdint), **styles[jdint]))
                    
                if np.all(np.isnan(spec)):
                    continue 
                    
                if color:
                    plt.plot(
                        data.freqs[::freq_step] / 1e6, 
                        spec[::freq_step], 
                        color=color
                    )
                else:
                    plt.plot(
                        data.freqs[::freq_step] / 1e6, 
                        spec[::freq_step], 
                        **styles[jdint]
                    )
                
        plt.yscale('log')
        if title:
            plt.title(f"{pol} Pol")

        if xlabel:
            plt.xlabel("Frequency [MHz]")
    if ylabel:
        ax[0].set_ylabel("|V| [Jy]")
    if legend:
        ax[0].legend(loc='lower left', ncols=3, handles=handles)

            
def plot_autos_multi_lst(datas: dict[float, DataContainer], flags: dict[float, DataContainer], 
                         tidx: int=slice(None), freq_step: int = 1):
    fig, ax = plt.subplots(
        len(datas), 2, figsize=(10, max(1.5*len(datas), 6)), 
        sharex=True, sharey=True, 
        gridspec_kw={'hspace': 0, 'wspace': 0}, constrained_layout=True
    )
    
    fig.suptitle("All Auto-Correlations")

    handles = []
    for jdint in data_jd_ints:
        handles.append(mpl.lines.Line2D([0], [0], label=str(jdint), **styles[jdint]))
    handles.append(mpl.lines.Line2D([0], [0], label="LST-average", color='k'))
    ax[0,0].legend(handles=handles, ncols=3)
    for i , key in enumerate(GOLDEN_LSTs):
        plot_autos(
            datas[key], flags[key], tidx=tidx, fig=fig, ax=ax[i], 
            xlabel=i==(len(GOLDEN_LSTs)-1), 
            title=i==0,
            legend=False,
            freq_step=freq_step
        )
        plot_autos(
            lstbin_data[key], 
            lstbin_flags[key], fig=fig, ax=ax[i], 
            xlabel=False, title=False, 
            color='k',
            freq_step=freq_step,
            legend=False
        )
        ax[i, 1].text(
            0.8, 0.8, 
            f"{key*12/np.pi:5.2f} hr", size=14, transform=ax[i,1].transAxes
        )
    return fig, ax

### Figure: Auto-Correlation Plot

In [None]:
plot_autos_multi_lst(golden_data, golden_flags, freq_step=10);

**Figure 1:** All autocorrelations in the night-to-night data going into the LST averages explored in this notebook. Each row is an LST bin, columns are polarizations. Colored lines represent different nights, and the black line represents the LST average.

## Distributions of Excess Variance

In this section, we explore the distribution of the night-to-night variance and variance of LST-averaged data. We derive theoretical distributions for these quantities in Memo #XXXX. We look at how these quantities vary with frequency, LST, various properties of the baselines, and things like autocorrelation magnitude and baseline group size (where applicable). 

We first set up some classes to deal with the theoretical and observed variance distributions.
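
As a minimal sketch of the idea behind a mixture distribution (the weights and numbers of nights below are made up), the CDF of a mixture is the weighted sum of its components' CDFs, and the same holds for the PDF and survival function. The gamma parameterization is the one used by `n2n_excess_var_distribution` in this notebook:

```python
import numpy as np
from scipy.stats import gamma

# Hypothetical mixture: 30% of samples binned over 5 nights, 70% over 12 nights.
ndays = np.array([5, 12])
weights = np.array([0.3, 0.7])
components = [gamma(a=(n - 1) / 2, scale=2 / (n - 1)) for n in ndays]

def mixture_cdf(x):
    """Weighted sum of the component CDFs."""
    return sum(w * c.cdf(x) for w, c in zip(weights, components))

print(mixture_cdf(0.5), mixture_cdf(1.0), mixture_cdf(2.0))
```

The weights must sum to one for the result to be a valid CDF, which is why the class normalizes them on construction.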

In [None]:
from scipy.stats import rv_continuous

class MixtureModel(rv_continuous):
    """A distribution model from mixing multiple models.
    
    Taken from https://stackoverflow.com/a/72315113/1467820
    """
    def __init__(self, submodels, *args, weights = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.submodels = submodels
        if weights is None:
            weights = [1 for _ in submodels]
        if len(weights) != len(submodels):
            raise ValueError(f'There are {len(submodels)} submodels and {len(weights)} weights, but they must be equal.')
        self.weights = [w / sum(weights) for w in weights]
        
    def _pdf(self, x):
        pdf = self.submodels[0].pdf(x) * self.weights[0]
        for submodel, weight in zip(self.submodels[1:], self.weights[1:]):
            pdf += submodel.pdf(x)  * weight
        return pdf
            
    def _sf(self, x):
        sf = self.submodels[0].sf(x) * self.weights[0]
        for submodel, weight in zip(self.submodels[1:], self.weights[1:]):
            sf += submodel.sf(x)  * weight
        return sf

    def _cdf(self, x):
        cdf = self.submodels[0].cdf(x) * self.weights[0]
        for submodel, weight in zip(self.submodels[1:], self.weights[1:]):
            cdf += submodel.cdf(x)  * weight
        return cdf

    def rvs(self, size):
        submodel_choices = np.random.choice(len(self.submodels), size=size, p = self.weights)
        submodel_samples = [submodel.rvs(size=size) for submodel in self.submodels]
        rvs = np.choose(submodel_choices, submodel_samples)
        return rvs

@attrs.define(slots=False)
class LSTBinStats:
    days_binned: DataContainer  = attrs.field()
    n2n_var_obs: DataContainer = attrs.field()
    lstavg_var_obs: DataContainer = attrs.field()
    lstavg_var_pred: DataContainer = attrs.field()
    per_night_var_pred: DataContainer = attrs.field()
        
    @classmethod
    def from_data(cls, *, 
        lstbin_data: DataContainer, lstbin_nsamples: DataContainer, lstbin_flags: DataContainer, 
        std_data: DataContainer, 
        data: DataContainer=None, nsamples: DataContainer=None, flags: DataContainer=None, 
        **kwargs
    ):
        """Get the observed and predicted variance metrics from observations in a particular LST bin."""
        days_binned = {}
        all_obs_var = {}
        all_predicted_var = {}
        all_interleaved_var = {}
        all_predicted_binned_var = {}
        excess_binned_var = {}
        excess_interleaved_var = {}
        per_night_var_pred = {}
        
        # Make sure we output correct types
        dcls = lstbin_data.__class__ # Either DataContainer or RedDataContainer
        REDAVG = dcls == RedDataContainer
        
        # Check the inputs first, so a missing `data` raises a clear error rather than an AttributeError.
        if REDAVG and (data is None or nsamples is None or flags is None):
            raise ValueError("If data is redundantly-averaged, you must provide data, nsamples and flags")

        if REDAVG:
            dcls = partial(dcls, reds=data.reds)
            
        for bl in lstbin_data.bls():
            lstd = lstbin_data[bl][0]
            lstn = lstbin_nsamples[bl][0]
            lstf = lstbin_flags[bl][0]
            stdd = std_data[bl][0]
            
            if np.all(lstf):
                continue

            splbl = utils.split_bl(bl)
            if splbl[0] == splbl[1]:  # don't use autos
                continue
            
            # Observed variances.
            all_obs_var[bl] = np.abs(np.where(lstf, np.nan, stdd**2))
            all_interleaved_var[bl] = noise.interleaved_noise_variance_estimate(
                np.atleast_2d(np.where(lstf, np.nan, lstd)), kernel=[[1, -2, 1]]
            )[0]
            # Set first and last frequency to NaN
            all_interleaved_var[bl][[0,-1]] = np.nan

            if REDAVG:
                # In the redundantly-averaged case we need to know the
                # nsamples (and autos) on each night, because they all have
                # different nsamples.
                
                # Ensure flagged data has zero samples
                gd = data[bl]
                gn = nsamples[bl].copy()
                gf = flags[bl]

                gn[gf] = 0

                # This might be slightly wrong because it gets a different variance
                # each night not just from the Nsamples but also the autos. In the 
                # sample variance calculation that goes in to the STD files, we
                # use only the nsamples.
                per_day_expected_var = noise.predict_noise_variance_from_autos(
                    bl, data, dt=dt, df=df, nsamples=nsamples
                )
                per_day_expected_var[gf] = np.inf
                per_night_var_pred[bl] = per_day_expected_var
            
                wgts_arr = np.where(gf, 0, per_day_expected_var**-1) 

                # compute ancillary statistics, see math above
                days_binned[bl] = np.sum(gn > 0, axis=0)
            
            
                all_predicted_binned_var[bl] = np.sum(wgts_arr, axis=0)**-1
                all_predicted_var[bl] = (days_binned[bl] - 1) * all_predicted_binned_var[bl]
            else:
                # Although the above code WOULD work for non-redundantly-averaged
                # data, it is highly inefficient, because we don't need to know 
                # the nsamples every night (since we know they're all uniform).
                expected_var = noise.predict_noise_variance_from_autos(
                    bl, lstbin_data, dt=dt, df=df,
                )[0]
                expected_var[lstf] = np.inf
                days_binned[bl] = lstn
                all_predicted_binned_var[bl] = expected_var / lstn
                all_predicted_var[bl] = all_predicted_binned_var[bl] * (lstn - 1)
                per_night_var_pred[bl] = expected_var[None, :]
                
            excess_binned_var[bl] = all_obs_var[bl] / all_predicted_var[bl]
            excess_interleaved_var[bl] = all_interleaved_var[bl] / all_predicted_binned_var[bl]

        return cls(
            days_binned=dcls(days_binned),
            n2n_var_obs = dcls(all_obs_var),
            lstavg_var_obs= dcls(all_interleaved_var),
            lstavg_var_pred= dcls(all_predicted_binned_var),
            per_night_var_pred = dcls(per_night_var_pred),
        )
    
    @cached_property
    def _cls(self):
        if isinstance(self.days_binned, RedDataContainer):
            return partial(RedDataContainer, reds=self.days_binned.reds)
        else:
            return DataContainer
        
    @cached_property
    def n2n_var_pred(self) -> DataContainer:
        return self._cls({bl: self.lstavg_var_pred[bl] * (self.days_binned[bl] - 1) for bl in self.bls()})
    
    @cached_property
    def n2n_excess_var(self) -> DataContainer:
        return self._cls({bl: self.n2n_var_obs[bl] / self.n2n_var_pred[bl] for bl in self.bls()})

    @cached_property
    def lstavg_excess_var(self) -> DataContainer:
        return self._cls({bl: self.lstavg_var_obs[bl] / self.lstavg_var_pred[bl] for bl in self.bls()})
    
    @classmethod
    def n2n_excess_var_distribution(cls, ndays_binned: int):
        return gamma(a=(ndays_binned-1)/2, scale=2/(ndays_binned-1))
    
    def n2n_excess_var_pred_dist(self, bls, freq_inds=slice(None), min_n: int = 1) -> rv_continuous:
        """Get a scipy distribution representing the theoretical distribution of excess variance.
        
        This will return a MixtureModel -- i.e. it will be the expected distribution of all frequencies
        and baselines asked for (not their average).
        
        """
        if not hasattr(bls[0], "__len__"):
            bls = [bls]
            
        all_ns = np.concatenate(tuple(self.days_binned[bl][freq_inds] for bl in bls))
        unique_days_binned, counts = np.unique(all_ns, return_counts=True)
        indx = np.argwhere(unique_days_binned >= min_n)[:, 0]
        unique_days_binned = unique_days_binned[indx]
        counts= counts[indx]
        
        return MixtureModel([self.n2n_excess_var_distribution(nn) for nn in unique_days_binned], weights=counts)

    def n2n_excess_var_avg_pred_dist(self, bls, freq_inds=slice(None), min_n: int = 1):
        """Get a scipy distribution representing the theoretical distribution of averaged excess variance.
        
        This will return the expected distribution of the averaged excess variance for the
        requested baselines and frequencies. Note this is NOT the excess averaged variance (i.e.
        we're averaging the mean-one excess over the baselines/frequencies, rather than averaging
        the observed variance and dividing by the averaged expected variance).
        
        This is exact for non-redundantly averaged data, and an approximation for red-avg data.
        Gotten from https://stats.stackexchange.com/a/191912/81338
        """
        if not hasattr(bls[0], "__len__"):
            bls = [bls]

        ndays_binned = np.concatenate(tuple(self.days_binned[bl][freq_inds] for bl in bls))
        ndays_binned = ndays_binned[ndays_binned >= min_n]
        
        M = len(ndays_binned)
        ksum = np.sum(M**2 / 2 / np.sum(1/(ndays_binned - 1)))
        thetasum = 1 / ksum
        
        return gamma(a=ksum, scale=thetasum)
        
    def bls(self):
        return self.days_binned.bls()
    
    def getmean(
        self,
        rdc: str | RedDataContainer | DataContainer, 
        bls = None,
        min_days: int = 7
    ):
        if isinstance(rdc, str):
            rdc = getattr(self, rdc)
        if bls is None:
            bls = self.bls()
            
        return np.nanmean([np.where(self.days_binned[bl] >= min_days, rdc[bl], np.nan) for bl in bls], axis=0)
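
The shape and scale used in `n2n_excess_var_avg_pred_dist` can be read as moment matching. Each excess variance has mean 1 and variance $2/(n_i - 1)$, so the mean of $M$ independent such variables has mean 1 and variance $M^{-2}\sum_i 2/(n_i - 1)$. A gamma distribution with those moments has shape $k = M^2 / \left(2\sum_i 1/(n_i - 1)\right)$ and scale $1/k$. A sketch, with a hypothetical helper name:

```python
import numpy as np

def averaged_excess_gamma_params(ndays_binned):
    """Shape and scale of a moment-matched gamma for the mean of mean-one gamma variables."""
    n = np.asarray(ndays_binned, dtype=float)
    M = n.size
    var_of_mean = np.sum(2.0 / (n - 1)) / M**2  # each term has variance 2/(n-1)
    shape = 1.0 / var_of_mean                   # mean**2 / variance, with mean = 1
    scale = var_of_mean                         # variance / mean
    return shape, scale

shape, scale = averaged_excess_gamma_params([9, 9, 9, 9])
print(shape, scale)
```

When all samples share the same number of nights, the shape reduces to $M(n-1)/2$, which is the shape of a sum of $M$ identical gamma variables.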

In [None]:
stats = {}
for glst in GOLDEN_LSTs:
    stats[glst] = LSTBinStats.from_data(
        data=golden_data[glst], nsamples=golden_nsamples[glst], flags=golden_flags[glst], 
        lstbin_data=lstbin_data[glst], lstbin_nsamples=lstbin_nsamples[glst], lstbin_flags=lstbin_flags[glst], 
        std_data=std_data[glst]
    )

In [None]:
def noise_comparison(glst, subsets: dict[str, callable], mean_of_ratios: bool = False, log: bool=False, min_days: int = 7):
    
    lstbin_hd = lstbin_hds[glst]
    
    stat = stats[glst]
    meanvar = stat.getmean('n2n_var_pred', min_days=min_days)
    
    if np.all(np.isnan(meanvar)):
        return
    
    fig, ax = plt.subplots(2,2, figsize=(16,8), sharex='col', gridspec_kw={'height_ratios': [2, 1]})
    plt.subplots_adjust(hspace=.0)
    ax=ax.flatten()

    ax[0].plot(
        lstbin_hd.freqs/1e6, 
        meanvar, 
        lw=2, 
        label='Predicted Variance from LST-Binned\nAutocorrelations', 
        color='k'
    )
    ax[0].set_ylabel('Nightly Visibility Variance (Jy$^2$) ')
    ax[0].set_title(
        f'Visibility Variance Across Nights at {glst*12/np.pi:5.3f} Hours LST'
        '\n(Mean Over Unflagged Times and Baselines)'
    )
    if log:
        ax[0].set_yscale('log')
    else:
        ax[0].set_ylim(-100, 6000)

    ax[1].plot(
        lstbin_hd.freqs/1e6, 
        stat.getmean('lstavg_var_pred', min_days=min_days), 
        lw=2,
        label='Predicted Variance from LST-Binned\nAutocorrelations and N$_{samples}$',
        color='k'
    )
    ax[1].set_ylabel('LST-Binned Visibility Variance (Jy$^2$)')
    ax[1].set_title(
        f'Variance of LST-Binned Visibilities at {glst*12/np.pi:5.3f} Hours LST'
        '\n(Mean Over Unflagged Times and Baselines, measured by frequency-interleaving)'
    )
    
    if log:
        ax[1].set_yscale('log')
    else:
        ax[1].set_ylim(-10, 200)
    
    for i, (name, subset) in enumerate(subsets.items()):
        bls = get_selected_bls(stat.bls(), days_binned=stat.days_binned, selectors=subset, min_days=min_days)
        if not bls:
            continue
            
        mean_obs_var = stat.getmean('n2n_var_obs', bls, min_days=min_days)
        mean_interleaved_var = stat.getmean('lstavg_var_obs', bls, min_days=min_days)
        
        if mean_of_ratios:
            mean_excess_var = stat.getmean('n2n_excess_var', bls, min_days=min_days)
            mean_excess_lstavg_var = stat.getmean('lstavg_excess_var', bls, min_days=min_days)
        else:
            mean_excess_var = mean_obs_var / stat.getmean("n2n_var_pred", bls, min_days=min_days)
            mean_excess_lstavg_var = mean_interleaved_var / stat.getmean("lstavg_var_pred", bls, min_days=min_days)
                        
        if i == 0:
            ax[0].plot(lstbin_hd.freqs/1e6, mean_obs_var, lw=1, label=name, color=f'C{i}')
            ax[1].plot(lstbin_hd.freqs/1e6, mean_interleaved_var, lw=1, color=f'C{i}')
        else:
            # Dummy plot to get a legend
            ax[0].plot(lstbin_hd.freqs/1e6, np.nan*np.ones(len(mean_obs_var)), lw=1, label=name, color=f'C{i}')
            
        ax[2].plot(lstbin_hd.freqs/1e6, mean_excess_var, color=f'C{i}', lw=1)
        favg_rat_n2n = np.nanmean(mean_excess_var)
        ax[2].plot(lstbin_hd.freqs/1e6, np.ones_like(lstbin_hd.freqs) * favg_rat_n2n, '--', color=f'C{i}', label=f'{favg_rat_n2n:.3f}')
    
        ax[3].plot(lstbin_hd.freqs/1e6, mean_excess_lstavg_var, color=f'C{i}', lw=1)
        favg_rat_lstavg = np.nanmean(mean_excess_lstavg_var)
        ax[3].plot(lstbin_hd.freqs/1e6, np.ones_like(lstbin_hd.freqs) * favg_rat_lstavg, '--', color=f'C{i}', label=f'{favg_rat_lstavg:.3f}')

    ax[0].legend()
    ax[1].legend()
    ax[2].set_xlabel('Frequency (MHz)')
    ax[2].set_xlim([40,250])
    # Scale each ratio panel by its own frequency-mean ratio (from the last subset plotted)
    ax[2].set_ylim([.9, 1.5 * favg_rat_n2n])
    ax[2].set_ylabel('Observed / Predicted')
    ax[2].legend(loc='upper right', title='Freq-Mean Ratios', ncols=3)

    ax[3].set_xlabel('Frequency (MHz)')
    ax[3].set_ylim([.9, 1.5 * favg_rat_lstavg])
    ax[3].set_xlim([40,250])
    ax[3].set_ylabel('Observed / Predicted')
    ax[3].legend(loc='upper right', title="Freq-Mean Ratios", ncols=3)

    plt.tight_layout()

### Figure: Mean Excess Variance as Function of Frequency

This figure compares the noise variance predicted from the autocorrelations (and $N_{samples}$) with the noise variance measured either from the scatter across nights or from frequency-interleaving.

Based on [Validation Test 4.0.0b](https://github.com/HERA-Team/hera-validation/blob/master/test-series/4/test-4.0.0b.ipynb) and [Aguirre et al. (2021) Figure 12](https://www.overleaf.com/project/5e7cdde364f7d40001749218) (the H1C IDR2 Validation paper).
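
As a rough illustration of the comparison, the sketch below predicts the visibility noise variance from autocorrelations with the standard radiometer equation and compares it with the variance measured across simulated nights. The function name `predict_noise_variance`, the channel width, the integration time and the flat toy autocorrelation are all assumptions made for this example; they are not the pipeline's own code or settings.

```python
import numpy as np

def predict_noise_variance(auto_i, auto_j, dnu_hz, dt_s, nsamples=1.0):
    """Radiometer-equation estimate of the variance of a complex visibility.

    auto_i, auto_j : real autocorrelation spectra of the two antennas
    dnu_hz, dt_s   : channel width [Hz] and integration time [s] (assumed values below)
    nsamples       : number of samples averaged into the visibility
    """
    return np.abs(auto_i * auto_j) / (dnu_hz * dt_s * nsamples)

rng = np.random.default_rng(0)
nfreq, nnights = 64, 4000
auto = np.full(nfreq, 100.0)          # flat toy autocorrelation (hypothetical)
dnu_hz, dt_s = 122e3, 10.0            # hypothetical channel width and integration time
pred_var = predict_noise_variance(auto, auto, dnu_hz, dt_s)

# Simulate complex noise whose total variance is pred_var (half real, half imaginary).
sigma = np.sqrt(pred_var / 2)
noise = rng.normal(0, sigma, (nnights, nfreq)) + 1j * rng.normal(0, sigma, (nnights, nfreq))

# Night-to-night estimate: sample variance across the night axis.
n2n_var = np.var(noise, axis=0, ddof=1)
excess_variance = n2n_var / pred_var  # scatters around 1 for pure thermal noise
print(f"mean excess variance = {excess_variance.mean():.3f}")
```

For pure noise the ratio sits near unity at every frequency; a ratio systematically above one is what the plots below call excess variance.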

In [None]:
for glst in GOLDEN_LSTs:
    noise_comparison(glst, subsets, log=True, min_days=7)

### Figure: Distribution of Excess Variance as function of Days Binned

In principle, the distribution of the excess variance depends only on the number of days being binned, and on no other variable (e.g. the variance of any particular night). The following plot compares these distributions with the theoretical prediction.
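
The theoretical curve used here is a Gamma distribution with shape $(N_{days}-1)/2$ and scale $2/(N_{days}-1)$. The following minimal sketch checks that form against a simulation, under the assumption of real-valued, independent Gaussian noise with the same variance on every night. The helper name `excess_variance_dist` is introduced only for this example.

```python
import numpy as np
from scipy.stats import gamma

def excess_variance_dist(ndays):
    """Distribution of (sample variance / true variance) for ndays Gaussian draws."""
    return gamma(a=(ndays - 1) / 2, scale=2 / (ndays - 1))

rng = np.random.default_rng(1)
ndays, ntrials = 7, 200_000
draws = rng.normal(0.0, 1.0, size=(ntrials, ndays))   # true variance is 1
simulated = np.var(draws, axis=1, ddof=1)             # unbiased sample variance per trial

theory = excess_variance_dist(ndays)
print(f"simulated mean = {simulated.mean():.3f}, theory mean = {theory.mean():.3f}")
print(f"simulated var  = {simulated.var():.3f}, theory var  = {theory.var():.3f}")
```

The distribution has unit mean for any number of days, while its variance, $2/(N_{days}-1)$, shrinks as more days are binned. This is why the panels narrow towards the bottom of the figure.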

In [None]:
stat = stats[GOLDEN_LSTs[5]]
all_nds = np.concatenate([stat.days_binned[bl] for bl in stat.bls()])
unique_nds = np.sort(np.unique(all_nds))
unique_nds = unique_nds[unique_nds > 2]

fig, ax = plt.subplots(len(unique_nds), 2, sharex=True, sharey=True,
                       gridspec_kw={"hspace": 0.0, "wspace": 0}, figsize=(15, len(unique_nds)*2))

x = np.linspace(0, 5, 200)
for i, nd in enumerate(unique_nds):
    for j, (name, select) in enumerate(subsets.items()):
        bls = get_selected_bls(list(stat.bls()), stat.days_binned, selectors=select, min_days=0)

        excess_low = np.concatenate([stat.n2n_excess_var[bl][(stat.days_binned[bl] == nd) & (golden_meta.freqs < 90e6)] for bl in bls])
        excess_high = np.concatenate([stat.n2n_excess_var[bl][(stat.days_binned[bl] == nd) & (golden_meta.freqs > 110e6)] for bl in bls])

        ax[i, 0].hist(excess_low, bins=np.linspace(0, 5, 101), histtype='step', density=True, color=f'C{j}', label=name)
        ax[i, 1].hist(excess_high, bins=np.linspace(0, 5, 101), histtype='step', density=True, color=f'C{j}', label=name)
        
    ax[i, 0].plot(x, gamma(a=(nd-1)/2, scale=2/(nd-1)).pdf(x), color='k')
    ax[i, 1].plot(x, gamma(a=(nd-1)/2, scale=2/(nd-1)).pdf(x), color='k')
    
    ax[i, 0].set_ylabel(f"{int(nd)} days binned")
ax[0,0].legend(ncols=3)
ax[0,0].set_title("Low Band (<90 MHz)")
ax[0,1].set_title("High Band (>110 MHz)");


In [None]:
def make_violin_plot(stat, bl_lists, coords, freq_masks=(slice(None),), min_days=7, xlabel=None, fig=None, ax=None, ylabel=True, topticks=False):
    if fig is None:
        fig, ax = plt.subplots(1, 1, figsize=(12,8))
      
    plt.sca(ax)
    
    # Expand freq_masks out to the length of bl_lists
    if len(freq_masks) == 1 and len(bl_lists) > 1:
        freq_masks = freq_masks * len(bl_lists)
    elif len(bl_lists) == 1 and len(freq_masks) > 1:
        bl_lists = bl_lists * len(freq_masks)
    elif len(bl_lists) != len(freq_masks):
        raise ValueError("bl_lists and freq_masks must be of the same length")
        
    evs = []
    dists = []
    for bls, freq in zip(bl_lists, freq_masks):
        if not bls:
            evs.append(None)
            dists.append(None)
            continue
        evs.append(np.concatenate([stat.n2n_excess_var[bl][stat.days_binned[bl] >= min_days][freq] for bl in bls]))
        try:
            dists.append(stat.n2n_excess_var_pred_dist(bls=bls, freq_inds=freq, min_n=min_days).rvs(size=5000))
        except ValueError:
            # If you get no samples, dist won't work here.
            dists.append(None)
            evs[-1] = None
            
    # Create labels for categorical-type coords based on the original number of
    # coords, even if they end up not being shown because they have no baselines.
    labels = None
    if isinstance(coords[0], str):
        labels = coords
        coords = np.arange(len(coords))
        plt.gca().set_xticks(np.arange(len(coords)), labels=labels)

        
    if topticks:
        ax.xaxis.set_tick_params(labeltop=True)
        
    # remove all the coordinates that have no values at all
    coords = [c for c, ev in zip(coords, evs) if ev is not None and not np.all(np.isnan(ev))]
    dists = [d for d, ev in zip(dists, evs) if d is not None and not np.all(np.isnan(ev))]
    evs = [ev for ev in evs if ev is not None and not np.all(np.isnan(ev))]
    
    if not evs:
        return evs
    
    # Remove huge outliers from evs because otherwise the KDE struggles...
    evs = [ev[ev < 10] for ev in evs]
        
    widths=[0.5 * (y-x) for x,y in zip(coords, coords[1:])] + [0.5 * (coords[-1] - coords[-2])]
    parts = plt.violinplot(
        dists, coords, showextrema=False, 
        widths=widths
    )
    plt.axhline(1, color='k', ls='--')
    
    for pc in parts['bodies']:
        pc.set_facecolor('black')
        pc.set_alpha(0.25)
        
    plt.violinplot(evs, coords, showextrema=False, showmeans=True, widths=widths)
    
    plt.legend(
        handles=[
            plt.Rectangle( (0, 0), 1, 1, facecolor='C1'), 
            plt.Rectangle((0,0), 1,1, facecolor='black', alpha=0.25)
        ],
        labels=['Data', 'Theory'],
        ncols=2
    )
    
    plt.ylim(0, 5)
    if ylabel:
        plt.ylabel("Excess Variance")
    
    if xlabel and labels is not None:
        plt.xlabel(xlabel)
        
    return evs

In [None]:
def make_full_violin_plot(selectors, fig=None, ax=None, min_days=7, suptitle=None, ylabel=True):
    if fig is None:
        fig, ax = plt.subplots(len(stats), 1, sharex=True, sharey=True, figsize=(15, 2*len(stats)), constrained_layout=True)

    for i, (glst, stat) in enumerate(stats.items()):
        bls = [
            get_selected_bls(list(stat.bls()), stat.days_binned, selectors=selector[0], min_days=min_days) for selector in selectors.values()
        ]
        freqmask = [selector[1] for selector in selectors.values()]
        make_violin_plot(
            stat, bls, 
            list(selectors.keys()), 
            fig=fig, ax=ax[i], 
            freq_masks=freqmask, 
            ylabel=False, 
            topticks=i==0
        )
        if ylabel:
            ax[i].set_ylabel(f"{glst*12/np.pi:.2f} hr")

    if suptitle:
        fig.suptitle(suptitle)
    return fig, ax

### Figure: Distribution of Excess Variance Across Baseline Subsets and LSTs for Low- and High-Band

In the following plot, the orange violins represent the observed distribution of excess variance (the horizontal orange line marks the mean of each), while the gray violins represent the theoretical distribution for that category. The data in each violin come from all baselines within the baseline subset and all frequencies within the specified band. No averaging is done; each baseline/LST/frequency sample is treated as its own datum. Only data with at least 7 contributing days in the LST-average are counted. Low and high band refer to frequencies below and above the FM band, respectively.

In [None]:
fig, ax = plt.subplots(len(stats), 2, sharex=True, sharey=True, figsize=(15, 2*len(stats)), constrained_layout=True)

make_full_violin_plot(
    selectors = {name.replace("baselines", ""): (sel, slice(None, 850)) for name, sel in subsets.items()},
    fig=fig, ax = ax[:, 0],
    suptitle="Distribution of Excess Variance across Subsets, LSTs and Bands"
)
make_full_violin_plot(
    selectors = {name.replace("baselines", ""): (sel, slice(850, None)) for name, sel in subsets.items()},
    fig=fig, ax = ax[:, 1], ylabel=False
)

ax[0,0].set_title("Low Band (< 90 MHz)")
ax[0,1].set_title("High Band (> 110 MHz)")


### Figure: Distribution of Excess Variance across LSTs and Bands for All Baselines

In the following plot, the orange violins represent the observed distribution of excess variance (the horizontal orange line marks the mean of each), while the gray violins represent the theoretical distribution for that category. The data in each violin come from all baselines and all frequencies within the specified band (200 channels each). No averaging is done; each baseline/LST/frequency sample is treated as its own datum. Only data with at least 7 contributing days in the LST-average are counted.

In [None]:
fig, ax = make_full_violin_plot(
    selectors = {golden_meta.freqs[ind] / 1e6: (lambda bl: True, slice(ind-100, ind+100)) for ind in range(100, 1535, 200)},
    suptitle="Distribution of Excess Variance across LSTs and Bands"
)
ax[-1].set_xlabel("Freq [MHz]")


### Figure: Distribution of Excess Variance with Baseline Length and LST at 160 MHz

In the following plot, the orange violins represent the observed distribution of excess variance (the horizontal orange line marks the mean of each), while the gray violins represent the theoretical distribution for that category. The data in each violin come from baselines within a given range of lengths (each bin is 14.6 m wide) and all frequencies within a 200-channel band centered around 160 MHz. The frequency range is chosen to capture the best-quality data in the spectrum. No averaging is done; each baseline/LST/frequency sample is treated as its own datum. Only data with at least 7 contributing days in the LST-average are counted.

In [None]:
bllen_grid = [(start, start + 14.6) for start in np.arange(7.0, 180.0, 14.6)]

fig, ax = make_full_violin_plot(
    selectors = {(edge[0] + edge[1])/2: (lambda bl, edge=edge: edge[0] <= getbllen(bl[0], bl[1]) < edge[1], slice(800, 1000)) for edge in bllen_grid},
    suptitle="Distribution of Excess Variance across Baseline Lengths at ~160 MHz"
)
ax[-1].set_xlabel("Baseline Length [m]")

### Figure: Distribution of Excess Variance Between NS and EW baselines and pols

In the following plot, the orange violins represent the observed distribution of excess variance (the horizontal orange line marks the mean of each), while the gray violins represent the theoretical distribution for that category. 
The data in each violin come from baselines that are North-South or East-West oriented (within 6 degrees), further subdivided by polarization. Data in each category are taken from all frequencies within a 200-channel band centered around 160 MHz. The frequency range is chosen to capture the best-quality data in the spectrum. No averaging is done; each baseline/LST/frequency sample is treated as its own datum. Only data with at least 7 contributing days in the LST-average are counted.
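
To make the orientation criterion concrete, here is a small sketch that classifies a baseline from the east and north components of its vector. The function `classify_orientation` and the example vectors are hypothetical; the slope limit of 1/10 matches the one used in the selectors and corresponds to roughly 6 degrees.

```python
import numpy as np

def classify_orientation(east_m, north_m, max_slope=0.1):
    """Return 'EW', 'NS' or None for a baseline vector given in metres.

    max_slope=0.1 allows a deviation of arctan(0.1) from the axis.
    """
    east, north = abs(east_m), abs(north_m)
    if east > 0 and north / east < max_slope:
        return "EW"
    if north > 0 and east / north < max_slope:
        return "NS"
    return None

print(f"angular tolerance: {np.degrees(np.arctan(0.1)):.2f} deg")
print(classify_orientation(14.6, 0.0))    # purely east-west vector
print(classify_orientation(0.5, 25.3))    # nearly north-south vector
print(classify_orientation(14.6, 12.6))   # diagonal vector, neither class
```

Guarding on the denominator before dividing keeps zero-length or axis-aligned vectors from triggering a division by zero.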

In [None]:
fig, ax = make_full_violin_plot(
    selectors = {
        "EW (ee)": (lambda bl: np.abs(bl[1]/bl[0]) < 1/10. and bl[2]=='ee', slice(800,1000)),
        "EW (nn)": (lambda bl: np.abs(bl[1]/bl[0]) < 1/10. and bl[2]=='nn', slice(800,1000)),
        "NS (ee)": (lambda bl: np.abs(bl[0]/bl[1]) < 1/10. and bl[2]=='ee', slice(800,1000)),
        "NS (nn)": (lambda bl: np.abs(bl[0]/bl[1]) < 1/10. and bl[2]=='nn', slice(800,1000)),
    },
    suptitle="Distribution of Excess Variance for EW vs NS Baselines at ~160 MHz"
)

### Figure: Distribution of Excess Variance Across Redundant Group Size

In the following plot, the orange violins represent the observed distribution of excess variance (the horizontal orange line marks the mean of each), while the gray violins represent the theoretical distribution for that category. 
The data in each violin come from baselines whose redundant groups fall within the specified size range. Data in each category are taken from all frequencies within a 200-channel band centered around 160 MHz. The frequency range is chosen to capture the best-quality data in the spectrum. No averaging is done; each baseline/LST/frequency sample is treated as its own datum. Only data with at least 7 contributing days in the LST-average are counted. **Note:** redundant-group size is highly correlated with baseline length.

In [None]:
group_size_bins = [(1, 10), (10, 20), (20, 40), (40, 80), (80, 200), (200, 500), (500, 1000), (1000, np.inf)]

fig, ax = make_full_violin_plot(
    selectors = {
        f"{g[0]}-{g[1]}": (lambda bl, g=g: (g[0] <= len(reds[bl]) < g[1]), slice(800, 1000)) for g in group_size_bins
    },
    suptitle="Distribution of Excess Variance across Redundant Group Size"
)

## Raw Visibilities

In this section, we plot some raw (calibrated) data in comparison to the LST-binned data, using the "Golden Data" output by the LST-binner. We focus on the baselines with the highest excess variance, so that we can more easily identify issues.

In [None]:
def mad(x, axis=0):
    med = np.nanmedian(x, axis=axis)
    return np.nanmedian(np.abs(x - med), axis=axis)*1.4826

In [None]:
def get_bl_coords(bl):
    sep_unit = np.abs(golden_meta.antpos[1][0] - golden_meta.antpos[0][0])
    return (np.abs(golden_meta.antpos[bl[0]] - golden_meta.antpos[bl[1]]) / sep_unit)[:2]

In [None]:
def plot_visibilities_per_type(
    types, 
    glst,
    freq_range=None, 
    label=None, yrange=None,
    alpha=0.5,
):
    all_figs = []
    
    lstbin_hd = lstbin_hds[glst]
    lststyle = dict(color='k', lw=3, alpha=0.2, zorder=10000)
    excess = stats[glst].n2n_excess_var
    
    for bltype in types:
        fig, ax = plt.subplots(
            4, 2, 
            sharex=True, figsize=(15, 8), 
            constrained_layout=True, gridspec_kw={'height_ratios': (2,1,2,1)}
        )
        
        nights_in = set()
        if freq_range is not None:
            mask = (lstbin_hd.freqs >= freq_range[0]) & (lstbin_hd.freqs < freq_range[1])
            freqs=lstbin_hd.freqs[mask]/1e6
        else:
            mask = slice(None, None, None)
            freqs = lstbin_hd.freqs/1e6
            
        
        bls = [bl for bl in reds[bltype] if bl in golden_data[glst].bls()]
        
        handles = []
        for jdint in data_jd_ints:
            handles.append(mpl.lines.Line2D([0], [0], label=str(jdint), alpha=alpha, **styles[jdint]))
        
        
        for j, bl in enumerate(bls):
            flgs = golden_flags[glst][bl][:, mask]
            datas = golden_data[glst][bl][:, mask]
            
            mag = np.where(flgs, np.nan, np.abs(datas))
            phs = np.where(flgs, np.nan, np.angle(datas))
            rl = np.where(flgs, np.nan, datas.real)
            im = np.where(flgs, np.nan, datas.imag)

            lstflg = lstbin_flags[glst][bl][0, mask]
            lstdata = lstbin_data[glst][bl][0, mask]

            rlmad = mad(rl)
            immad = mad(im)

            maglstbin = np.where(lstflg, np.nan, np.abs(lstdata))
            phslstbin = np.where(lstflg, np.nan, np.angle(lstdata))
            rllstbin = np.where(lstflg, np.nan, lstdata.real)
            imlstbin = np.where(lstflg, np.nan, lstdata.imag)
            
            for night in range(len(golden_data[glst].times)):
                jdint = int(golden_data[glst].times[night])
                style = copy.deepcopy(styles[jdint])
                style['alpha'] = alpha

                if np.all(flgs[night]):
                    continue
                    
                # Amplitude and Phase
                ax[0, 0].plot(freqs, mag[night], **style)
                ax[0, 0].plot(freqs, maglstbin, **lststyle)                
                ax[1, 0].plot(freqs, mag[night] - maglstbin, **style)
                
                ax[2, 0].plot(freqs, phs[night], **style)
                ax[2, 0].plot(freqs, phslstbin, **lststyle)
                phsdiff = phs[night] - phslstbin
                phsdiff[phsdiff < -np.pi] += 2*np.pi
                phsdiff[phsdiff > np.pi] -= 2*np.pi
                ax[3, 0].plot(freqs, phsdiff, **style)
                        
                # Real / Imag
                ax[0, 1].plot(freqs, rl[night], **style)
                ax[0, 1].plot(freqs, rllstbin, **lststyle)                
                ax[1, 1].plot(freqs, (rl[night] - rllstbin)/rlmad, **style)
                
                ax[2, 1].plot(freqs, im[night], **style)
                ax[2, 1].plot(freqs, imlstbin, **lststyle)
                ax[3, 1].plot(freqs, (im[night] - imlstbin)/immad, **style)
                
                if yrange:
                    ax[0, 0].set_ylim(yrange)
                    
            ax[1,1].axhline(4, color='gray', ls='--')
            ax[1,1].axhline(-4, color='gray', ls='--')
            
            ax[3,1].axhline(4, color='gray', ls='--')
            ax[3,1].axhline(-4, color='gray', ls='--')
            
        bl_coords = get_bl_coords(bltype)
        
        fig.suptitle(
            f"Baseline Type: {bltype} [{bl_coords[0]:.1f}-EW, {bl_coords[1]:.1f}-NS]. Size {len(bls)}. "
            f"LST = {glst*12/np.pi:5.3f} hr. Median Excess Var = {np.nanmedian(excess[bltype][mask]):.2f}"
        )
        ax[-1, 0].set_xlabel("Frequency [MHz]")
        ax[-1, 1].set_xlabel("Frequency [MHz]")
        
        ax[0, 0].set_ylabel("Magnitude")
        ax[0, 1].set_ylabel("Real Part")
        
        ax[1, 0].set_ylabel("Magnitude Diff")
        ax[1, 1].set_ylabel("Real Z-score")
        ax[1, 1].set_ylim(-7, 7)
        
        ax[2, 0].set_ylabel("Phase")
        ax[2, 1].set_ylabel("Imag Part")
        
        ax[3, 0].set_ylabel("Phase Diff")
        ax[3, 1].set_ylabel("Imag Z-score")
        ax[3, 1].set_ylim(-7, 7)
        ax[0, 0].legend(handles=handles, ncols=5)

        all_figs.append(fig)
        
    return all_figs 

### Figure: Visibilities Over Nights

In [None]:
def get_sorted_keys(stat, min_days=7):
    excess = []
    mask = (golden_meta.freqs>125e6) & (golden_meta.freqs<=230e6)
    bls = []
    for bl in stat.bls():
        if np.mean(stat.days_binned[bl][mask]) < min_days:
            continue
            
        median = np.nanmedian(stat.n2n_excess_var[bl][mask])
        if not np.isnan(median):
            excess.append(median)
            bls.append(bl)
    srt = [k for k, v in sorted(zip(bls, excess), key=lambda item: item[1])]
    return srt

In [None]:
for glst in GOLDEN_LSTs:
    # Sort keys from best to worst
    keys = get_sorted_keys(stats[glst])
    if keys:
        # Three worst, and single best.
        if RED_DATA:
            keys = keys[-1:-4:-1] + keys[:1]
        
        else:
            # We want keys from DIFFERENT red groups
            use_bls = []
            use_keys = []
            for key in reversed(keys):
                if key not in use_bls:
                    use_keys.append(key)
                    use_bls.extend(reds[key])
                if len(use_keys) == 4:
                    break
            use_bls.extend(reds[keys[0]])
            use_keys.append(keys[0])
            
            hd = golden_hds[glst]
            use_bls = [k for k in use_bls if (k in hd.bls or utils.reverse_bl(k) in hd.bls) and (k in lstbin_flags[glst])]
            use_bls = [k if k in hd.bls else utils.reverse_bl(k) for k in use_bls]
            
            # Read these particular baselines into golden_data.
            golden_data[glst], golden_flags[glst], _ = hd.read(bls=use_bls, read_flags=True)
            keys = use_keys
            
        figs = plot_visibilities_per_type(
            keys, glst, freq_range=(125e6, 230e6), alpha=0.75,
        );

## Distribution of Predicted Z-Scores

In [None]:
def get_baseline_zscores(freq_inds, stat, golden_data, golden_nsamples, golden_flags, lstbin_data):
    zscores = {}
            
    if len(golden_data.freqs) == len(freq_inds):
        golden_freq_inds = list(range(len(golden_data.freqs)))
    else:
        golden_freq_inds = freq_inds
        
    for bl in golden_data.bls():
        gd = golden_data[bl]
        gf = golden_flags[bl]
        
        if bl[0] == bl[1] or bl[2][0] != bl[2][1]:
            # skip autos
            continue

        if bl not in stat.per_night_var_pred:
            continue
            
        pred_var = stat.per_night_var_pred[bl][:, freq_inds]
        
        zscores[bl] = np.sqrt(2) * (gd[:, golden_freq_inds] - lstbin_data[bl][0, freq_inds]) / np.sqrt(pred_var)
        zscores[bl][gf[:, golden_freq_inds]] *= np.nan
        
    return RedDataContainer(zscores, reds=reds) if RED_DATA else DataContainer(zscores)

In [None]:
rfreq_hds = {glst: HERADataFastReader(str(lstbin_hds[glst].filepaths[0]).replace(".LST.", ".REDUCEDCHAN.")) for glst in GOLDEN_LSTs}
view_indices = [int(x) for x in config['LSTBIN_OPTS']['save_channels'].split(",")]

if not RED_DATA:
    # Replace the GOLDEN data with the reduced-chan data for now. 
    for glst, hd in rfreq_hds.items():
        golden_data[glst], golden_flags[glst], _ = hd.read(read_flags=True)
    


In [None]:
zscores_pred = {
    glst: get_baseline_zscores(
        view_indices, stats[glst], golden_data[glst], golden_nsamples[glst], golden_flags[glst], lstbin_data[glst]
    ) for glst in GOLDEN_LSTs
}

In [None]:
def plot_zscore_histogram(freq_index, glst, fig=None, ax=None, xlabel=True, legend=True):
    if fig is None:
        fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True, constrained_layout=True)
    
    bins=np.linspace(-10, 10, 101)
    
    for i, jd in enumerate(golden_data[glst].times):
        this = np.array([zscores_pred[glst][bl][i, freq_index] for bl in zscores_pred[glst].bls()])
        ax[0].hist(this.real, bins=bins, label=str(int(jd)), histtype='step', density=True, **styles[int(jd)])
        ax[1].hist(this.imag, bins=bins, label=str(int(jd)), histtype='step', density=True, **styles[int(jd)])
        
    if xlabel:
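        # freq_index indexes the reduced set of saved channels, so map it back to the full frequency axis
        ax[0].set_xlabel(f"Real Z-Score at {golden_meta.freqs[view_indices[freq_index]]/1e6:.1f} MHz")
        ax[1].set_xlabel(f"Imag Z-Score at {golden_meta.freqs[view_indices[freq_index]]/1e6:.1f} MHz")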
    
    x = np.linspace(-10, 10, 1000)
    y = np.exp(-(x**2)/2) / np.sqrt(2*np.pi)
    ax[0].plot(x, y, color='k', lw=3)
    ax[1].plot(x, y, color='k', lw=3)
    
    if legend:
        ax[0].legend(ncol=3, title='Night (JD)');
    ax[0].set_yscale('log')
    ax[1].set_yscale('log')
    ax[0].set_ylim(8e-4, 1)
    ax[1].set_ylim(8e-4, 1)
    
def plot_all_zscore_histograms(freq_index):
    fig, ax = plt.subplots(len(GOLDEN_LSTs), 2, figsize=(15, 2*len(GOLDEN_LSTs)), sharex=True, sharey=True, constrained_layout=True)
    fig.suptitle(f"Z-Scores (from predicted variance) across baselines per-night at {golden_meta.freqs[view_indices[freq_index]]/1e6:.1f} MHz ")
    
    ax[0,0].set_title("Real Part")
    ax[0,1].set_title("Imag Part")
    
    handles = {}
    for i, glst in enumerate(GOLDEN_LSTs):
        plot_zscore_histogram(freq_index, glst, fig=fig, ax=ax[i], xlabel=i==len(GOLDEN_LSTs)-1, legend=False)
        
        # Keep track of legend stuff.
        h, l = ax[i, 0].get_legend_handles_labels()
        handles.update(dict(zip(l, h)))
                
        ax[i, 0].text(0.8, 0.8, f"LST {glst*12/np.pi:5.3f} hr", transform=ax[i,0].transAxes, fontweight='bold')
       
    handles = [h for l, h in sorted(handles.items())]
    fig.legend(loc='upper left', handles=handles, ncols=3)

### Figure: Histogram of Baseline Z-Scores at single Frequency / LST / Night

This plot shows the distribution of Z-scores (with respect to the _predicted_ variance) over baselines for each night. Remember that the value for a *particular baseline* averaged over nights is mean-zero by construction (since the z-score for a baseline is defined as the visibility minus its mean over nights, divided by the predicted standard deviation), so the full distribution of everything in this plot should be mean-zero. However, any particular night is free to have a non-zero mean: all baselines could have been "bad" together on that night.
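
The toy simulation below illustrates that statement. It assumes the LST-binned value is a plain unweighted mean over nights (the real binner also accounts for flags and sample counts), uses a made-up predicted variance, and injects an offset on one hypothetical "bad" night.

```python
import numpy as np

rng = np.random.default_rng(2)
nnights, nbls = 10, 5000
pred_var = 4.0                        # assumed predicted variance of each complex visibility
sigma = np.sqrt(pred_var / 2)         # standard deviation of each real/imag component

vis = rng.normal(0, sigma, (nnights, nbls)) + 1j * rng.normal(0, sigma, (nnights, nbls))
vis[3] += 1.0                         # hypothetical bad night: same offset on every baseline

lst_avg = vis.mean(axis=0)            # stand-in for the LST-binned visibility
# The sqrt(2) rescales so that the real and imaginary parts each have roughly unit variance.
z = np.sqrt(2) * (vis - lst_avg) / np.sqrt(pred_var)

per_baseline_mean = z.mean(axis=0)    # zero for every baseline, by construction
per_night_mean = z.real.mean(axis=1)  # free to differ from zero on any given night

print("largest |per-baseline mean|:", np.abs(per_baseline_mean).max())
print("per-night mean of real z-scores:", np.round(per_night_mean, 2))
```

The per-baseline means vanish to floating-point precision, while the offset night stands out with a clearly positive mean and the remaining nights are pulled slightly negative to compensate.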

In [None]:
for i in range(len(view_indices)):
    plot_all_zscore_histograms(i)

In [None]:
def make_box_plot(freq_index, glst):
    
    plt.figure(figsize=(12, 5))
    z = zscores_pred[glst]
    zscore_arr = [np.concatenate([z[bl][i, freq_index][~np.isnan(z[bl][i, freq_index])].real for bl in z.bls()]) for i in range(z.shape[0])]
    jdsubs = [int(jd) - 2459800 for jd in golden_data[glst].times]
    plt.boxplot(
        zscore_arr, 
        positions=jdsubs
    )
    mean_var = np.array([np.nanmean(zscore**2) for zscore in zscore_arr])
    plt.scatter(jdsubs, mean_var, marker='*', s=75, color='r', zorder=5, 
                label=r'$\langle z^2 \rangle \approx$  Excess Variance ')
    plt.ylim(-12, 12)
    plt.xlabel("Night (JD - 2459800)")
    plt.ylabel("Predicted Z-Score")
    
    # Put on lines where box-plot markers should be if it were Gaussian
    plt.axhline(0.6745, ls='--', color='gray')
    plt.axhline(-0.6745, ls='--', color='gray')
    plt.axhline(2.698, ls='--', color='gray', alpha=0.5)
    plt.axhline(-2.698, ls='--', color='gray', alpha=0.5)
    plt.axhline(1, ls='-', color='r', lw=1)
    
    plt.title(f"Z-Scores Per-Night Across Baselines at {golden_meta.freqs[view_indices[freq_index]]/1e6:.1f} MHz and LST {glst*12/np.pi:5.3f} hr")
    plt.legend()


### Figure: Box-Plot of Z-Scores at 138 MHz

The following plot shows the distribution of Z-scores at 138 MHz, grouped by night, to highlight which nights (if any) are behaving poorly at each LST. Gray lines show the theoretical expectation for the box and whiskers respectively of the box plots. The red line is at unity and red stars indicate the estimate of the contribution to excess variance from that night (determined by the average of the squared z-score over baselines for that night). 
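
For reference, the expected positions of the box edges and whisker limits for unit-Gaussian z-scores can be derived as below. This assumes matplotlib's default whisker rule of 1.5 times the interquartile range; the whiskers actually drawn stop at the last datum inside that limit, so the gray lines are upper bounds.

```python
from scipy.stats import norm

q3 = norm.ppf(0.75)                    # upper quartile of a unit Gaussian
iqr = 2 * q3                           # interquartile range (the distribution is symmetric)
whisker = q3 + 1.5 * iqr               # whisker limit under the default whis=1.5 rule
outlier_frac = 2 * norm.sf(whisker)    # expected fraction of points drawn as fliers

print(f"box edges at +/-{q3:.4f}, whisker limits at +/-{whisker:.3f}")
print(f"expected fraction beyond the whiskers: {outlier_frac:.4f}")
```

Boxes or whiskers that extend well beyond roughly ±0.67 and ±2.70 therefore indicate a night whose scatter exceeds the predicted noise.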

In [None]:
for glst in GOLDEN_LSTs:
    make_box_plot(1, glst)

### Figure: Box-Plot of Z-Scores at 169 MHz

The following plot shows the distribution of Z-scores at 169 MHz (nominally a well-behaved frequency), grouped by night, to highlight which nights (if any) are behaving poorly at each LST. Gray lines show the theoretical expectation for the box and whiskers respectively of the box plots. The red line is at unity and red stars indicate the estimate of the contribution to excess variance from that night (determined by the average of the squared z-score over baselines for that night). 

In [None]:
for glst in GOLDEN_LSTs:
    make_box_plot(2, glst)

## Exploration of Sigma-Clipping

In this section, we attempt to understand the impact of sigma-clipping on the data. We form robust Z-scores from the GOLDEN data using the median absolute deviation (MAD), just as in the pipeline code itself, and then apply a range of thresholds to see how much data each one would clip.
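
A minimal sketch of a MAD-based z-score on toy real-valued data follows. The function `mad_zscore` and the injected outlier are illustrative only; in the notebook, the same operation is applied separately to the real and imaginary parts.

```python
import numpy as np

def mad_zscore(x, axis=0):
    """Robust z-score: (x - median) / (1.4826 * MAD), ignoring NaNs along `axis`."""
    med = np.nanmedian(x, axis=axis, keepdims=True)
    mad = np.nanmedian(np.abs(x - med), axis=axis, keepdims=True)
    # 1.4826 converts the MAD into a standard-deviation estimate for Gaussian data.
    return (x - med) / (1.4826 * mad)

rng = np.random.default_rng(3)
nights = rng.normal(0.0, 1.0, size=(15, 200))   # 15 toy nights, 200 channels
nights[4, 10] = 25.0                            # inject one hypothetical outlier

z = mad_zscore(nights, axis=0)
clipped = np.abs(z) > 4.0                       # 4-sigma clipping threshold
print("outlier flagged:", bool(clipped[4, 10]))
print("total samples clipped:", int(clipped.sum()))
```

Because the median and MAD are insensitive to a single extreme value, the outlier barely affects the scale estimate and is flagged cleanly. A mean and standard deviation computed over the same 15 nights would both be dragged towards it.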

In [None]:
def get_observed_zscores(glst, chunk_size: int=None, chunk: int = None, min_N: int=4):
    if chunk_size:
        # Read in a bit of data at a time. 
        gd, gf, _ = golden_hds[glst].read(bls=list(lstbin_data[glst].bls())[chunk*chunk_size:(chunk+1)*chunk_size], read_flags=True)
    else:
        gd = golden_data[glst]
        gf = golden_flags[glst]
    
    # We do this in a baseline-loop, which may be slower than it could be
    zscores = {}
    flags = {}
            
    for bl in gf.bls():
        # Ignore autos
        if bl[0] == bl[1] or bl[2][0] != bl[2][1]:
            continue
            
        flg = gf[bl].copy()
        
        this = np.zeros(flg.shape, dtype=complex)
        for part in ['real', 'imag']:
        
            d = getattr(gd[bl], part).copy()

            flg[np.isnan(d) | np.isinf(d)] = True

            d[flg] *= np.nan
            location = np.nanmedian(d, axis=0)
            mad = np.nanmedian(np.abs(d - location), axis=0) * 1.482579

            if part == "real":
                this += (d - location)/mad
            else:
                this += 1j * (d - location)/mad
            
        # Apply min_N criterion
        # the point of "flagging" here is just to be able to exclude 
        # these values from showing up in the computed fractions
        # that are flagged specifically because of their Z-score.
        ndays_binned = np.sum(~flg, axis=0)
        flg[:, ndays_binned < min_N] = True

        zscores[bl] = this
        flags[bl] = flg

    if RED_DATA:
        return RedDataContainer(zscores, reds=reds), RedDataContainer(flags, reds=reds)
    else:
        return DataContainer(zscores), DataContainer(flags)

In [None]:
def sigma_clip_fraction_plot(thresholds, sigma_clip_fracs):
    nglst = len(sigma_clip_fracs)
    ncols = 3
    nrows = nglst // ncols + 1
    fig, axx = plt.subplots(nrows, 3, sharey=True, sharex=True, constrained_layout=True, figsize=(15, 2*nrows))

    cdf = 2* norm().cdf(-np.array(thresholds))

    ax = axx.flatten()

    # get legend entries
    for jdint, style in styles.items():
        ax[0].plot([0], [np.nan], label=str(jdint), **style)
            
    for j, (glst, fracs) in enumerate(sigma_clip_fracs.items()):
        ax[j].plot(thresholds, cdf, color='k', label='Gaussian Theory' if j==0 else None)
        
        ax[j].text(0.5, 0.85, f"LST {glst*12/np.pi:.2f}", fontsize=14, transform=ax[j].transAxes)

        for i, jd in enumerate(golden_data[glst].times):
            for part in ['real', 'imag']:
                frac_cut = np.mean([fracs[bl][i] for bl in fracs.bls()], axis=0)
                intjd = int(jd)
                ax[j].plot(thresholds, frac_cut, alpha=0.5 if part=='imag' else 1.0, **styles[intjd])
    for axxx in ax[j+1:]:
        axxx.axis('off')
        
    fig.supxlabel("sigma clip threshold")

    fig.legend(loc='center', ncols=3, bbox_to_anchor=(0.85, 0.18), frameon=False)
    
    fig.supylabel("Fraction of Samples Clipped")


In [None]:
def get_true_chunk_sizes(x):
    return np.diff(np.where(np.concatenate(([x[0]], x[:-1] != x[1:], [True])))[0])[::2]

In [None]:
def get_sigma_clip_stats(min_N: int = 4, thresholds=(4,), ret_zscores=True):
    if not RED_DATA:
        chunk_size=None  # read in all baselines at once, but only one file at a time.
        nchunks = 1
    else:
        chunk_size=None
        nchunks = 1
    
    flag_fracs = {}
    contig_sizes = {}
    all_zscores = {}
    
    for glst in GOLDEN_LSTs:
        flag_frac_glst = {}
        contig_size_glst = {}
        
        if chunk_size:
            nchunks = len(lstbin_data[glst].bls()) // chunk_size + 1
    
        all_zscores[glst] = RedDataContainer({}, reds=reds) if RED_DATA else DataContainer({})
        
        for chunk in range(nchunks):
            print(f"Getting MAD Z-Scores for Baseline Chunk {chunk + 1} of {nchunks}")
            zscores, pre_flags = get_observed_zscores(glst, chunk_size=chunk_size, chunk=chunk, min_N=min_N)

            for bl in zscores.bls():
                flag_frac_glst[bl] = []
                contig_size_glst[bl] = []
                
                for i, (z, flg) in enumerate(zip(zscores[bl], pre_flags[bl])):
                    
                    fracs = []
                    contig_size = []
                    for thresh in thresholds:
                        clip_flags = (np.abs(z.real) > thresh) | (np.abs(z.imag) > thresh)
                        fracs.append(np.sum(clip_flags & (~flg)) / clip_flags.size)
                        ch = get_true_chunk_sizes(clip_flags)
                        # ch is empty when no samples exceed the threshold
                        if ch.size > 0:
                            contig_size.append(ch.max())
                        else:
                            contig_size.append(0)
                    
                    flag_frac_glst[bl].append(fracs)
                    contig_size_glst[bl].append(contig_size)
                    
                flag_frac_glst[bl] = np.array(flag_frac_glst[bl])
                contig_size_glst[bl] = np.array(contig_size_glst[bl])
                
            if ret_zscores:
                all_zscores[glst] += zscores
                
        flag_fracs[glst] = RedDataContainer(flag_frac_glst, reds=reds) if RED_DATA else DataContainer(flag_frac_glst)
        contig_sizes[glst] = RedDataContainer(contig_size_glst, reds=reds) if RED_DATA else DataContainer(contig_size_glst)
        
    return flag_fracs, contig_sizes, all_zscores

In [None]:
thresholds = (3.5, 4, 4.5, 5., 5.5, 6)

In [None]:
flag_fracs, contig_sizes, zscores_obs = get_sigma_clip_stats(
    min_N = config['LSTBIN_OPTS'].get("sigma_clip_min_N", 4),
    thresholds=thresholds, 
    ret_zscores=RED_DATA
)

### Figure: Sigma-Clipped Fraction As Function of Threshold, LST and Night

If the underlying data are Gaussian, the fraction sigma-clipped at a given threshold follows directly from the Gaussian CDF: it is the two-sided tail probability beyond that threshold. In reality, we expect the data to have more outliers than a true Gaussian. In the plot below, we show the fraction of samples (across baselines and frequencies) that are flagged specifically due to sigma-clipping, as a function of the sigma-clipping threshold. We split the plots between different LST bins and different nights. The real part is shown as the full-colour lines, while the imaginary part is shown as the 50% transparent lines of the same style. The black line is the theoretical expectation, given an underlying Gaussian distribution across nights.
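
As a rough cross-check on the theoretical expectation, the Gaussian tail fraction can be computed directly. The sketch below is not part of the notebook's plotting code: the function names are hypothetical, and it assumes the z-scores are normalized so that each of the real and imaginary parts is a standard normal variate, with the two parts independent.

```python
import math


def gaussian_clip_fraction(threshold: float) -> float:
    """Two-sided tail probability P(|z| > threshold) for a standard normal z."""
    return math.erfc(threshold / math.sqrt(2.0))


def either_component_clip_fraction(threshold: float) -> float:
    """Probability that the real OR the imaginary part exceeds the threshold.

    Assumes (for illustration) that the two parts are independent
    standard normal variates.
    """
    p = gaussian_clip_fraction(threshold)
    return 1.0 - (1.0 - p) ** 2


# Same threshold values as used for the sigma-clip statistics in this notebook.
for t in (3.5, 4, 4.5, 5.0, 5.5, 6):
    single = gaussian_clip_fraction(t)
    either = either_component_clip_fraction(t)
    print(f"{t:3.1f} sigma: single component {single:.2e}, either component {either:.2e}")
```

Measured fractions well above these values at a given threshold point to heavier-than-Gaussian tails in the data.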

In [None]:
sigma_clip_fraction_plot(thresholds, flag_fracs)

### List of most sigma-clipped LSTs, Nights and Baselines

One concern is that we might be flagging out large fractions of particular antennas with sigma-clipping. In this case, it would be ideal to identify the actual issue in a previous step, rather than arbitrarily sigma-clipping them at the LST-binning step. Here, we find the most sigma-clipped same-polarization baselines for any night/LST, and print all those for which more than 30% of channels are sigma-clipped (at 4-sigma). These are purely _sigma-clipped_ flags; pre-flagged data are not counted.

In [None]:
bad_ones = {}
threshidx = thresholds.index(4)
for glst, dc in flag_fracs.items():
    for bl in dc.bls():
        if bl[2][0] != bl[2][1]:
            continue
        for i, jd in enumerate(golden_data[glst].times):
            if dc[bl][i, threshidx] > 0.3:
                bad_ones[(glst*12/np.pi, int(jd), bl)] = dc[bl][i, threshidx]

In [None]:
for (glst, jd, bl), frac in sorted(bad_ones.items(), key=lambda item: item[1])[-1:-100:-1]:
    print(f"LST {glst:5.2f} hr on night {jd} for bl {bl} had {frac*100:.1f}% sigma-clip flags")


In [None]:
nclipped_ee = [len([x for key, x in bad_ones.items() if key[2][2]=='ee' and x > thresh]) for thresh in np.arange(0.3, 1, 0.1)]
nclipped_nn = [len([x for key, x in bad_ones.items() if key[2][2]=='nn' and x > thresh]) for thresh in np.arange(0.3, 1, 0.1)]

plt.bar(np.arange(0.3, 1., 0.1), nclipped_ee, width=0.03, color='C0', label='ee', align='edge')
plt.bar(np.arange(0.3, 1., 0.1), nclipped_nn, width=-0.03, color='C1', label='nn', align='edge')
plt.legend()
plt.yscale('log')
plt.title("Number of (LST, night, bl) combos flagged \n more than a certain fraction of channels")
plt.xlabel("Flag Fraction (over frequency)")
plt.ylabel("Number of (LST, night, bl) combos")

In [None]:
# antennas flagged
if not RED_DATA:
    antflags = {}
    threshes = np.arange(0.3, 1, 0.1)
    for (lst, jd, bl), frac in bad_ones.items():
        a, b = utils.split_bl(bl)
        if a not in antflags:
            antflags[a] = np.zeros_like(threshes)
        if b not in antflags:
            antflags[b] = np.zeros_like(threshes)

        antflags[a] += (frac >= threshes).astype(int)
        antflags[b] += (frac >= threshes).astype(int)
    

### List of Most-Sigma-Clipped Antennas

In [None]:
if not RED_DATA:
    print("Number of baseline-LST-night combos clipped more than given % across frequency, for given ant")
    print("==============================================================================================")
    print("Chans clipped: 30%     40%  50%  60%  70%  80%  90%")
    print("---------------------------------------------------")
    
    for ant, fracs in sorted(antflags.items(), key=lambda item: item[1][0])[-1:-25:-1]:
        print(f"{ant[0]:>3}{ant[1][-1]}:          {fracs.astype(int)}")

In [None]:
# ant-nights
if not RED_DATA:
    antnightflags = {}
    threshes = np.arange(0.3, 1, 0.1)
    for (lst, jd, bl), frac in bad_ones.items():
        a, b = utils.split_bl(bl)
        if (a, jd) not in antnightflags:
            antnightflags[(a, jd)] = np.zeros_like(threshes)
        if (b, jd) not in antnightflags:
            antnightflags[(b, jd)] = np.zeros_like(threshes)

        antnightflags[(a, jd)] += (frac >= threshes).astype(int)
        antnightflags[(b, jd)] += (frac >= threshes).astype(int)
    

### List of Most-Sigma-Clipped Antenna-Nights

In [None]:
if not RED_DATA:
    print("Number of baseline-LST combos clipped more than given % across frequency, for given night-ant")
    print("==============================================================================================")
    print("Chans clipped: 30%  40% 50% 60% 70% 80% 90%")
    print("-------------------------------------------------")
    
    for (ant, jd), fracs in sorted(antnightflags.items(), key=lambda item: item[1][0])[-1:-25:-1]:
        print(f"{jd}-{ant[0]:>3}{ant[1][-1]}: {fracs.astype(int)}")
    

### Figure: Counts of Contiguous Flagged Region Sizes

We might also worry about large contiguous chunks of frequency being flagged for a particular antenna. Here we plot the number of contiguous regions of a given size that are flagged:

In [None]:
all_contig_sizes = [ctg[bl][0] for ctg in contig_sizes.values() for bl in ctg.bls()]

plt.hist(all_contig_sizes, bins=np.arange(10, 1535, 10))

# plt.scatter(cnt.keys(), cnt.values(), label='Sigma-clip Flags')
# plt.scatter(cnt_nopreflags.keys(), cnt_nopreflags.values(), label="Not counting pre-flagged", marker='x')

# plt.legend()
plt.yscale('log')
plt.xscale('log')
plt.xlabel("Number of contiguous flags (in frequency)")
plt.ylabel("Number of occurrences");
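
For readers unfamiliar with the quantity being histogrammed, the following sketch shows one way to measure contiguous flagged regions in a boolean flag array. It is illustrative only: `true_run_lengths` and `longest_true_run` are hypothetical helpers written for this example, and they are not claimed to match the notebook's own `get_true_chunk_sizes`.

```python
from itertools import groupby


def true_run_lengths(flags):
    """Return the length of every contiguous run of True values in `flags`."""
    return [sum(1 for _ in run) for value, run in groupby(flags) if value]


def longest_true_run(flags):
    """Return the longest contiguous run of True values, or 0 if there is none."""
    return max(true_run_lengths(flags), default=0)


# Toy flag array over eight frequency channels (True = sigma-clipped).
example = [False, True, True, False, True, True, True, False]
print(true_run_lengths(example))
print(longest_true_run(example))
```

Returning 0 when no channel is clipped keeps the per-baseline result numeric, so it can be stored in an array alongside the non-empty cases.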

## Metadata

In [None]:
import hera_cal
import pyuvdata
print('hera_cal version: ', hera_cal.__version__)
print('pyuvdata version: ', pyuvdata.__version__)