<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Imports" data-toc-modified-id="Imports-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Imports</a></span></li><li><span><a href="#Configuration" data-toc-modified-id="Configuration-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Configuration</a></span></li><li><span><a href="#Do-the-binning" data-toc-modified-id="Do-the-binning-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Do the binning</a></span></li><li><span><a href="#LST-bin-the-Autos" data-toc-modified-id="LST-bin-the-Autos-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>LST-bin the Autos</a></span><ul class="toc-item"><li><span><a href="#In-painted-Mode" data-toc-modified-id="In-painted-Mode-4.1"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>In-painted Mode</a></span></li><li><span><a href="#Flagged-Mode" data-toc-modified-id="Flagged-Mode-4.2"><span class="toc-item-num">4.2&nbsp;&nbsp;</span>Flagged-Mode</a></span></li><li><span><a href="#Plot" data-toc-modified-id="Plot-4.3"><span class="toc-item-num">4.3&nbsp;&nbsp;</span>Plot</a></span></li></ul></li><li><span><a href="#Cross-Pairs" data-toc-modified-id="Cross-Pairs-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Cross-Pairs</a></span></li></ul></div>

# LST-Bin

**by Steven Murray**, last updated 27th Mar, 2024.

This notebook performs LST-binning, producing a **single output file**. The inputs to this notebook are two configuration files and one index:

1. A `fileconf`, which is *produced* by `hera_cal.lstbin_simple.make_lst_bin_config_file()` run over a set of raw files. This file lists the raw files that contribute to each LST bin, which makes it quick for this notebook to read them in.
2. A binning configuration file, `config`, that specifies all the parameters to use when performing the binning itself.
3. The file index that corresponds to the LST bins that will be saved to the output file in _this_ notebook.

The notebook then proceeds to do essentially the same thing as `hera_cal.lstbin_simple.lst_bin_files_single_outfile`, but with extra plotting and inspection stops along the way.

## Imports

In [None]:
import os
import sys
import yaml
import toml
import inspect
from pathlib import Path
from functools import partial
from typing import Literal
from datetime import datetime
from time import time as _time
import resource

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from astropy import units
import matplotlib as mpl
import h5py
import attrs
from scipy.stats import chi2
import psutil

from pyuvdata.uvdata import FastUVH5Meta
from pyuvdata import UVData
from hera_cal import lst_stack as lstbin
from hera_cal.lst_stack.config import LSTConfig
from hera_cal import abscal
from hera_cal.red_groups import RedundantGroups
from hera_cal.lst_stack import metrics as lstmet

In [None]:
# Wall-clock start time of this notebook run (seconds since the epoch, from `time.time`).
start_time = _time()

## Configuration

In [None]:
# Path to the file-configuration HDF5 file (loaded with `LSTConfig.from_file` below).
fileconf: str = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/redavg-smoothcal-notebook/file-config.h5"
# Index of the output file (set of LST bins) handled by this notebook run.
fileidx: int = 380

# The following are defaults that can be overwritten at execution time (preferably by a YAML file)
make_plots: bool = True        # gates all of the plotting cells
save_lstbin_data: bool = True  # NOTE(review): not referenced in the visible notebook code
save_metric_data: bool = True  # write the high-|Z| ("HIGHZ") metric file
plot_n_worst: int = 5          # NOTE(review): not referenced in the visible notebook code

outdir: str = "."              # output directory (created if missing)
bl_chunk_size: int = 0         # cross ant-pairs per chunk; <= 0 means all in one chunk
rephase: bool = True           # NOTE(review): not referenced directly in the visible notebook code
vis_units: str = "Jy"          # passed to the output-file creation as `vis_units`
fname_format: str = '{inpaint_mode}/zen.{kind}.{lst:7.5f}.sum.uvh5'  # output filename template
overwrite: bool = True         # passed to the output-file creation as `overwrite`
write_med_mad: bool = False    # also write median ("MED") and MAD ("MAD") output files
do_flagged_mode: bool = False  # use flagged mode even when inpainted files are available
freq_min: float = 0.0          # lower frequency limit; <= 0 means "no limit"
freq_max: float = 0.0          # upper frequency limit; <= 0 means "no limit"
history: str = ""              # history string passed to the output-file creation

In [None]:
# Parameter changes for typing
outdir = Path(outdir)
if freq_max <= 0.0:
    freq_max = None
if freq_min <= 0.0:
    freq_min = None
if bl_chunk_size <= 0:
    bl_chunk_size = None

In [None]:
# Resolve the file-configuration path and fail early with a clear error if it is not a file.
fileconf = Path(fileconf)
if not fileconf.is_file():
    # An explicit raise (instead of `assert`) still fires when Python runs with -O.
    raise FileNotFoundError(f"The input file-configuration file is not a file: {fileconf}")

In [None]:
# Load the LST-stacking configuration from the file-configuration file.
stackconf = LSTConfig.from_file(fileconf)

In [None]:
print("The LST grid was configured with these parameters: \n")
grid_params = attrs.asdict(stackconf.config)
for param_name in grid_params:
    # Skip the data-file list; everything else is printed as "name: value".
    if param_name == 'data_files':
        continue
    print(f"  {param_name:>36}: {grid_params[param_name]}")

In [None]:
# Echo the raw-file properties stored in the configuration.
print("The raw files have the following properties: \n")
for prop_name, prop_value in stackconf.properties.items():
    print(f"  {prop_name:>25}: {prop_value}")

In [None]:
# Restrict the configuration to the single output file selected by `fileidx`.
stackconf = stackconf.at_single_outfile(fileidx)

In [None]:
# Show the LST bin edges handled by this notebook run.
print(
    f"LST bin edges considered in this notebook (file index {fileidx}):\n"
    f"  {stackconf.lst_grid_edges}"
)

In [None]:
# List the names of the raw files matched to this output file.
print("Raw files used in this notebook (for all bins): \n")
for raw_file in stackconf.matched_files:
    print(raw_file.name)

In [None]:
# Summarise the data volume: number of ant-pairs (autos + crosses) and the polarizations.
print(f"The data has {len(stackconf.autos + stackconf.antpairs)} ant-pairs, and {stackconf.pols} polarizations.")

In [None]:
# Inpainted mode is used unless flagged mode was requested or no inpainted files exist.
if do_flagged_mode or stackconf.inpaint_files is None:
    inpaint_mode = False
else:
    inpaint_mode = True

In [None]:
# Report which binning mode (inpainted vs. flagged) was selected above.
print(f"We will use {'inpaint' if inpaint_mode else 'flagged'} mode in this notebook.")

In [None]:
# Ensure the output directory exists. With exist_ok=True this is a no-op for an existing
# directory, but it still raises FileExistsError if the path exists and is NOT a directory,
# surfacing the problem now rather than when the first output file is written.
outdir = Path(outdir)
outdir.mkdir(parents=True, exist_ok=True)

In [None]:
# Report where the output files will be written.
print(f"Writing output files to: \n  {outdir}")

In [None]:
# Split up the baselines into chunks that will be LST-binned together.
# This is just to save on RAM. `bl_chunk_size is None` means "all cross ant-pairs in a
# single chunk". The size is clamped to >= 1 so that an empty list of ant-pairs yields
# zero chunks instead of a ZeroDivisionError.
if bl_chunk_size is None:
    bl_chunk_size = len(stackconf.antpairs)
else:
    bl_chunk_size = min(bl_chunk_size, len(stackconf.antpairs))
bl_chunk_size = max(bl_chunk_size, 1)

n_bl_chunks = int(np.ceil(len(stackconf.antpairs) / bl_chunk_size))

In [None]:
# Output filename for this file index, built from `fname_format` using the first LST bin
# edge. It is later combined with a `kind` (e.g. "LST", "STD", "HIGHZ") for each product.
out_fname = lstbin.format_outfile_name(
    fname_format=fname_format, lst=stackconf.lst_grid_edges[0], inpaint_mode=inpaint_mode,
    pols=stackconf.pols, lst_branch_cut=stackconf.properties["lst_branch_cut"],
)

## Define Stacking/Averaging Functions

Define and initialize the output files that we will write in this notebook:

In [None]:
# One output file per product: the LST-binned data ("LST") and its standard deviation
# ("STD") always; the median ("MED") and MAD ("MAD") only when requested.
kinds = ["LST", "STD"] + (["MED", "MAD"] if write_med_mad else [])
out_files = {}
for kind in kinds:
    # Create the file we'll write this product to.
    out_files[kind] = lstbin.io.create_lstbin_output_file(
        fname=out_fname, kind=kind, lsts=stackconf.lst_grid, outdir=outdir,
        file_list=stackconf.matched_metas,
        antpairs=stackconf.autos + stackconf.antpairs,
        start_jd=stackconf.properties['first_jd'],
        lst_branch_cut=stackconf.properties["lst_branch_cut"],
        freq_min=freq_min, freq_max=freq_max,
        vis_units=vis_units, history=history, overwrite=overwrite,
    )

Now, define a function that uses the configuration we've established and performs LST-binning for a subset of baselines.

In [None]:
def write_baseline_chunk(rdc: dict, nbls_so_far: int):
    """Write one reduced baseline chunk into every open output file.

    Parameters
    ----------
    rdc
        Reduced LST-bin products for the chunk (as returned by ``reduce_lst_bins``);
        ``rdc['data']`` carries the baselines on axis 1.
    nbls_so_far
        Offset of this chunk along the baseline axis of the output files.
    """
    first = nbls_so_far
    baseline_slice = slice(first, first + rdc['data'].shape[1])

    # (output kind, key of the data array, key of the nsamples array), in write order.
    products = [("LST", "data", "nsamples"), ("STD", "std", "days_binned")]
    if write_med_mad:
        products += [("MED", "median", "nsamples"), ("MAD", "mad", "days_binned")]

    for out_kind, data_key, nsamples_key in products:
        lstbin.write_baseline_slc_to_file(
            fl=out_files[out_kind],
            slc=baseline_slice,
            data=rdc[data_key],
            flags=rdc["flags"],
            nsamples=rdc[nsamples_key],
        )

In [None]:
def stack_blchunk(
    bl_chunk: int | str,
    nbls_so_far: int,
):
    """Stack, reduce and write a single chunk of baselines.

    Parameters
    ----------
    bl_chunk
        Chunk to load, passed through as ``bl_chunk_to_load`` (``'autos'`` or an
        integer cross-baseline chunk index).
    nbls_so_far
        Baseline offset at which this chunk is written into the output files.

    Returns
    -------
    stacks
        The per-LST-bin ``UVData`` stacks returned by ``lst_bin_files_from_config``.
    rdc
        The reduced products from ``reduce_lst_bins``.
    chunk_size
        Number of baselines in THIS chunk (not a running total; add it to
        ``nbls_so_far`` to get the next chunk's offset).
    """
    sig = inspect.signature(lstbin.binning.lst_bin_files_from_config)
    params = list(sig.parameters)

    # Names supplied explicitly below must not also arrive via **kw (that would be a
    # "multiple values for argument" TypeError).
    explicit = {"bl_chunk_to_load", "nbl_chunks"}
    if params:
        explicit.add(params[0])  # first positional parameter receives `stackconf`

    # Forward every notebook-level setting whose name matches a named parameter of
    # `lst_bin_files_from_config`. This mapping used to be built but never passed,
    # so such settings were silently ignored.
    notebook_globals = globals()
    kw = {
        name: notebook_globals[name]
        for name, param in sig.parameters.items()
        if name in notebook_globals
        and name not in explicit
        and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    }

    stacks: list[UVData] = lstbin.binning.lst_bin_files_from_config(
        stackconf,
        bl_chunk_to_load=bl_chunk,
        nbl_chunks=n_bl_chunks,
        **kw,
    )

    chunk_size = stacks[0].Nbls
    # Each stack is reshaped to (Ntimes, Nbls, Nfreqs, Npols) for the reducer.
    dshape = (chunk_size, stacks[0].Nfreqs, stacks[0].Npols)

    rdc = lstbin.averaging.reduce_lst_bins(
        [uvd.data_array.reshape((uvd.Ntimes,) + dshape) for uvd in stacks],
        [uvd.flag_array.reshape((uvd.Ntimes,) + dshape) for uvd in stacks],
        [uvd.nsample_array.reshape((uvd.Ntimes,) + dshape) for uvd in stacks],
        inpainted_mode=inpaint_mode,
        get_mad=True,
    )
    write_baseline_chunk(rdc, nbls_so_far)

    return stacks, rdc, chunk_size

## Plotting Style Setup

In [None]:
# Sorted, de-duplicated integer parts of the first timestamp of each matched raw file.
data_jd_ints = sorted(set(int(m.times[0]) for m in stackconf.matched_metas))

In [None]:
# Plot style for each night: cycle through the 10 default colours and change line style
# every 10 nights. The line style wraps around instead of raising IndexError once there
# are more than 40 nights.
styles = {}

for i, jdint in enumerate(data_jd_ints):
    styles[jdint] = {'color': f"C{i%10}", 'ls': ['-', '--', ':', '-.'][(i // 10) % 4]}

## Define Subsets of Data to Consider

### Bands

In [None]:
# Channel ranges (start, stop) used to summarise the metrics: consecutive 200-channel
# sub-bands up to channel 1536, followed by the low, high and full bands.
bands_considered = [(lo, min(lo + 200, 1536)) for lo in range(0, 1536, 200)] + [
    (0, 450),     # low band
    (450, 1536),  # high band
    (0, 1536),    # full band
]

### Baselines

In [None]:
def get_all_antenna_sectors(antpos=None, center_ants=(165, 166, 145), outrigger_dist=200):
    """Assign every antenna to an array "sector" from its position.

    The reference point is the mean position of ``center_ants``. Antennas whose
    horizontal distance (first two coordinates) from it exceeds ``outrigger_dist``
    are sector 4 (outriggers). The rest are split by ``theta = arctan2(dy, dx)``:
    sector 1 for -pi/3 <= theta < pi/3, sector 2 for pi/3 <= theta <= pi and
    sector 3 for -pi <= theta < -pi/3.

    Parameters
    ----------
    antpos
        Sequence/array of antenna positions indexed by antenna number. Defaults to
        ``stackconf.config.datameta.antenna_positions``.
    center_ants
        Antenna numbers whose mean position defines the reference point.
    outrigger_dist
        Horizontal distance (units of ``antpos``) beyond which an antenna is an outrigger.

    Returns
    -------
    dict
        Mapping ``antenna index -> sector`` (1-4); every antenna receives a sector.
    """
    if antpos is None:
        antpos = stackconf.config.datameta.antenna_positions
    zero_pos = np.mean([antpos[a] for a in center_ants], axis=0)

    sectors = {}
    for ant, pos in enumerate(antpos):
        rec = pos - zero_pos
        theta = np.arctan2(rec[1], rec[0])
        bllen = np.sqrt(rec[0]**2 + rec[1]**2)
        if bllen > outrigger_dist:
            sectors[ant] = 4  # outrigger
        elif -np.pi / 3 <= theta < np.pi / 3:
            sectors[ant] = 1
        elif np.pi / 3 <= theta <= np.pi:
            # Inclusive upper bound: arctan2 returns exactly +pi when dy == 0 and dx < 0,
            # which previously matched no branch and left the antenna without a sector.
            sectors[ant] = 2
        else:
            # Remaining range: -pi <= theta < -pi/3.
            sectors[ant] = 3
    return sectors

# Sector number (1-4) of every antenna; used by the inter-/intra-sector baseline selectors.
sectors = get_all_antenna_sectors()

In [None]:
def getblvec(a, b):
    """Return the baseline vector ``pos[a] - pos[b]`` using the first auto stack's antenna positions."""
    positions = auto_stacks[0].antenna_positions
    return positions[a] - positions[b]


def getbllen(a, b):
    """Return the Euclidean length of the baseline between antennas ``a`` and ``b``."""
    vec = getblvec(a, b)
    return np.sqrt(np.sum(vec * vec))

In [None]:
# Baseline selectors: each takes a baseline key ``(ant1, ant2, pol)`` and returns a bool.
def all_ee(bl):
    return bl[2] == 'ee'


def all_nn(bl):
    return bl[2] == 'nn'


def short_bls(bl):
    return getbllen(bl[0], bl[1]) <= 60.0


def long_bls(bl):
    return getbllen(bl[0], bl[1]) > 60.0


def intersector_bls(bl):
    return sectors[bl[0]] != sectors[bl[1]]


def intrasector_bls(bl):
    return sectors[bl[0]] == sectors[bl[1]]


def _select_all(bl):
    return True


# Named baseline subsets used to split up the metrics below.
subsets = {
    'all': _select_all,
    'ee-only': all_ee,
    'nn-only': all_nn,
    'Short (<60 m) baselines': short_bls,
    'Long (>60 m) baselines': long_bls,
    'Inter-sector baselines': intersector_bls,
    "Intra-sector baselines": intrasector_bls,
}

## LST-bin the Autos

In [None]:
def make_auto_plot(auto_stacks: list[UVData], lstbin: dict):
    """Plot each night's auto-correlation spectra with the LST-binned result overlaid.

    One row per (auto ant-pair, polarization), one column per LST bin. Each night
    is drawn with its entry in ``styles`` (keyed by integer JD); the LST-binned
    spectrum is drawn in thick black.

    Parameters
    ----------
    auto_stacks
        Per-LST-bin stacks of the auto-correlations.
    lstbin
        Reduced products for the autos; ``'data'`` and ``'flags'`` are indexed as
        ``[lst_bin, antpair, freq, pol]``. (This parameter shadows the ``lstbin``
        module inside this function only.)
    """
    npols = len(stackconf.pols)
    fig, ax = plt.subplots(
        len(stackconf.autos) * npols, len(auto_stacks),
        sharex=True, sharey=True, squeeze=False, constrained_layout=True,
        figsize=(12, 6),
    )

    for i, stack in enumerate(auto_stacks):
        freqs_mhz = stack.freq_array / 1e6
        night_times = stack.time_array[::stack.Nbls]
        for j, autopair in enumerate(stackconf.autos):
            for p, pol in enumerate(stackconf.pols):
                axx = ax[j * npols + p, i]
                key = autopair + (pol,)
                # Fetch each baseline once, rather than once per night.
                flags = stack.get_flags(key)
                data = stack.get_data(key)

                for k, t in enumerate(night_times):
                    axx.plot(
                        freqs_mhz,
                        np.where(flags[k], np.nan, data[k].real),
                        label=f"{int(t)}" if not p else None,
                        **styles[int(t)],
                    )

                # Panel properties only need setting once (and are now set even when
                # a stack contains no nights).
                axx.set_yscale('log')
                axx.set_title(f"Pair {autopair}, pol={pol}, LST {stackconf.lst_grid[i]*12/np.pi:.3f} hr")

                # plot the mean
                axx.plot(
                    freqs_mhz,
                    np.where(lstbin['flags'][i, j, :, p], np.nan, lstbin['data'][i, j, :, p].real),
                    label='LSTBIN',
                    color='k', lw=2,
                )

    ax[0, 0].legend(ncols=3)

In [None]:
# Stack, reduce and write the auto-correlations at baseline offset 0. The third return
# value is the autos chunk size, i.e. the baseline offset for the cross-correlations.
auto_stacks, autos_lstavg, nbls_so_far = stack_blchunk('autos', 0)

### Plot

In [None]:
if make_plots:
    make_auto_plot(auto_stacks, autos_lstavg);

## Cross-Pairs

In [None]:
# Stack, reduce and write cross-baseline chunk 0, starting right after the autos.
# NOTE(review): only chunk 0 is processed although `n_bl_chunks` may be > 1, and the
# value re-assigned to `nbls_so_far` is this chunk's size, not a running total.
cross_stacks, cross_lstavg, nbls_so_far = stack_blchunk(0, nbls_so_far)

### Calculate Metrics

In [None]:
# Median integration time and channel width of the stacked crosses, as astropy quantities.
dt = np.median(cross_stacks[0].integration_time) * units.s
df = np.median(np.diff(cross_stacks[0].freq_array)) * units.Hz

In [None]:
# Redundant-baseline groups built from the ENU antenna positions (keyed by antenna
# index) for the polarizations present in the data.
reds_with_pols = RedundantGroups.from_antpos(
    antpos=dict(enumerate(stackconf.config.datameta.antpos_enu)),
    pols=stackconf.pols,
)

In [None]:
# LST-bin statistics for the autos, built from their reduced products.
auto_stats = lstmet.LSTBinStats.from_reduced_data(stackconf.autos, stackconf.pols, autos_lstavg, reds=reds_with_pols)

In [None]:
# LST-bin statistics for the cross-correlations, built from their reduced products.
cross_stats = lstmet.LSTBinStats.from_reduced_data(stackconf.antpairs, stackconf.pols, cross_lstavg, reds=reds_with_pols)

In [None]:
# Z^2 metric for the stacked cross visibilities, from the auto/cross statistics.
# Presumably one UVData-like object per LST bin (it is indexed as zscores[lst] below).
zscores = lstmet.get_zscores(auto_stats, cross_stats, dt, df, cross_stacks)

In [None]:
# Predicted distribution of Z^2; its .pdf/.cdf/.ppf methods are used by the plots below.
zdist_pred = lstmet.zsquare_predicted_dist()

### Distributions of $Z^2$

#### Simple Histogram

In [None]:
if make_plots:
    fig, ax = plt.subplots(1, 2, figsize=(14, 5))

    # Labels for every LST bin: the first two keep their descriptive names, any further
    # bins are numbered (the cell previously assumed exactly two LST bins).
    lst_labels = [
        ("First LST Bin", "Second LST Bin")[k] if k < 2 else f"LST Bin {k}"
        for k in range(len(zscores))
    ]

    # --- PDF: histogram of Z^2 on log-spaced bins vs. the predicted density ---
    zsq_edges = np.logspace(-5, 7, 200)
    zsq_centres = 10**((np.log10(zsq_edges[1:]) + np.log10(zsq_edges[:-1]))/2)
    dndzsq = zdist_pred.pdf(zsq_centres)

    for zuv, lst_label in zip(zscores, lst_labels):
        ax[0].hist(zuv.data_array.real.flatten(), bins=zsq_edges, label=lst_label, density=True, histtype='step')
    ax[0].plot(zsq_centres, dndzsq, label='Predicted')
    ax[0].set_xscale('log')
    ax[0].set_yscale('log')
    ax[0].set_ylim(1e-12, 1e4)
    ax[0].legend()
    # The axis is log-scaled, so its tick values are Z^2 itself (not log10 Z^2).
    ax[0].set_xlabel(r"$Z^2$")
    ax[0].set_title("PDF of $Z^2$")

    # --- CDF: fraction of non-NaN Z^2 values strictly below each threshold ---
    cdf_x = np.linspace(0, 100, 100)
    for zuv, lst_label in zip(zscores, lst_labels):
        n_valid = np.sum(~np.isnan(zuv.data_array))
        # NaNs sort to the end, so searchsorted(side='left') counts the values < x,
        # equivalent to `np.sum(real < x)` but for all thresholds at once.
        sorted_real = np.sort(zuv.data_array.real, axis=None)
        ax[1].plot(cdf_x, np.searchsorted(sorted_real, cdf_x, side='left') / n_valid, label=lst_label)
    ax[1].plot(cdf_x, zdist_pred.cdf(cdf_x), label='Predicted')
    ax[1].set_xlabel(r"$Z^2$")
    ax[1].set_title("CDF of $Z^2$")
    ax[1].legend()

#### Get list of bads

In [None]:
def consecutive(data: np.ndarray, stepsize: int=1) -> list[tuple[int, int]]:
    """Group sorted indices into runs of consecutive values.

    Adapted from https://stackoverflow.com/a/46606745/1467820.

    Parameters
    ----------
    data
        1D array of sorted integer indices, e.g. ``np.nonzero(mask)[0]``.
    stepsize
        Difference between neighbouring elements that counts as "consecutive".

    Returns
    -------
    list of (start, stop)
        One half-open range per run (``stop = last + 1``), so ``arr[start:stop]``
        selects the whole run. An empty input gives an empty list.
    """
    data = np.asarray(data)
    if data.size == 0:
        # np.split would return one empty piece, and indexing it would raise IndexError.
        return []

    runs = np.split(data, np.where(np.diff(data) != stepsize)[0] + 1)
    # Use the same half-open convention for every run. Previously multi-element runs
    # returned an inclusive end, so slicing with it dropped the run's last element.
    return [(run[0], run[-1] + 1) for run in runs]

In [None]:
# Collect contiguous channel ranges with Z^2 > 9 (|Z| > 3) and ranges that were inpainted
# (negative nsamples), keyed by (lst_index, ant1, ant2, pol_index, jd, start, stop).
allbad = {}
inpainted_regions = {}
pol_indices = {'ee': 0, 'nn': 1}

# The loop variable is `lstidx`, not `lstbin`: re-binding `lstbin` would shadow the
# `lstbin` module imported at the top, breaking any later use (e.g. re-running stacking).
for lstidx, zuv in enumerate(zscores):
    night_jds = zuv.time_array[::zuv.Nbls]  # one time per night in this LST bin
    for bl, zsq in zuv.antpairpol_iter():
        a, b, polstr = bl
        nsamps = zuv.get_nsamples((a, b, polstr))
        pol = pol_indices[polstr]

        for night, zsqn in enumerate(zsq):
            nsampsn = nsamps[night]
            # NOTE(review): despite the name, this is the time from time_array (not int()).
            jdint = night_jds[night]

            badfreqs = np.nonzero(zsqn.real > 9)[0]
            if len(badfreqs) > 0:
                for start, stop in consecutive(badfreqs):
                    allbad[(lstidx, a, b, pol, jdint, start, stop)] = zsqn[start:stop]

            paintedfreqs = np.nonzero(nsampsn < 0)[0]
            if len(paintedfreqs) > 0:
                for start, stop in consecutive(paintedfreqs):
                    inpainted_regions[(lstidx, a, b, pol, jdint, start, stop)] = nsampsn[start:stop]

In [None]:
# Length (in channels) of each contiguous high-|Z| chunk found above.
chunk_lengths = [stop - start for *_, start, stop in allbad]
if chunk_lengths:
    print("Biggest Frequency Chunk With |Z|>3: ", np.max(chunk_lengths))
else:
    # np.max on an empty list would raise ValueError.
    print("No frequency chunks with |Z|>3 were found.")

In [None]:
if save_metric_data:
    # Write out the "bad" (|Z| > 3) chunks: one index row per chunk,
    # (lst_index, ant1, ant2, pol_index, jd, start, stop), and the chunks' Z^2 values
    # concatenated in the same order.
    fname = out_fname.format(kind="HIGHZ")
    if allbad:
        indices = np.array(list(allbad.keys()))
        zsq_values = np.concatenate(tuple(allbad.values()))
    else:
        # np.concatenate raises on an empty sequence: write empty datasets instead
        # (dtype assumed complex, like the Z^2 arrays whose .real is taken above).
        indices = np.empty((0, 7))
        zsq_values = np.empty(0, dtype=complex)

    outpath = outdir / fname
    outpath.parent.mkdir(parents=True, exist_ok=True)  # fname may include a sub-directory
    with h5py.File(outpath, 'w') as fl:
        # NOTE: the jd column comes straight from time_array (not cast to int), so this
        # array is presumably float rather than integer.
        fl['indices'] = indices
        fl['zsq'] = zsq_values


#### Histogram of freq-chunk size

In [None]:
if make_plots and chunk_lengths:
    # Integer-centred bins (edges at n - 0.5), so each chunk length, including the
    # maximum, gets its own bin. Previously the last bin merged max-1 and max.
    plt.hist(chunk_lengths, bins=np.arange(np.min(chunk_lengths), np.max(chunk_lengths) + 2) - 0.5)
    plt.yscale('log')
    plt.xlabel("Channel-Chunk Length with |Z|>3")
    plt.ylabel("Number of Occurrences");

#### Histogram of Inpainted Region size

In [None]:
if make_plots and inpainted_regions:
    # Length of every inpainted chunk: stop - start of its channel range.
    inpsize = [k[-1] - k[-2] for k in inpainted_regions]
    # Edges -0.5, 0.5, ..., max + 0.5: the previous top edge (max - 0.5) silently
    # excluded the chunks of maximum length.
    plt.hist(inpsize, bins=np.arange(np.max(inpsize) + 2) - 0.5)
    plt.xlabel("Inpainted-Chunk Length")
    plt.ylabel("Number of chunks")
    plt.yscale('log')

#### BoxPlots of Z^2 across axis chunks

In [None]:
def _set_boxplot_ax_props(nboxes: int, ax):
    """Decorate a box-plot axis with the predicted Z^2 quartiles and shared styling.

    Draws the predicted median (solid grey) and interquartile band (shaded grey), a
    dashed reference line at Z^2 = 1, and applies a log y-axis spanning all boxes.
    """
    left, right = -0.5, nboxes - 0.5
    median = zdist_pred.ppf(0.5)
    lower_q, upper_q = zdist_pred.ppf(0.25), zdist_pred.ppf(0.75)

    ax.axhline(median, ls='-', color='gray')
    ax.fill_between([left, right], [lower_q] * 2, [upper_q] * 2, color='gray', alpha=0.3)
    ax.axhline(1, ls='--', color='C3', lw=1)
    ax.set_ylim(1e-1, None)
    ax.set_xlim(left, right)

    ax.set_yscale('log')
    ax.set_ylabel(r"$Z^2$")

In [None]:
def _style_boxplot(bplot, style):
    """Colour a single-box ``boxplot`` result using one night's colour/linestyle."""
    color, ls = style['color'], style['ls']
    for artist in (bplot['boxes'][0], bplot['whiskers'][0], bplot['whiskers'][1]):
        artist.set_color(color)
        artist.set_linestyle(ls)
    bplot['caps'][0].set_color(color)
    bplot['caps'][1].set_color(color)

    mean = bplot['means'][0]
    mean.set_marker("*")
    mean.set_markerfacecolor(color)
    mean.set_markeredgecolor(color)
    mean.set_markersize(10)


def box_plot_all_groups(zscores):
    """Box-plot the Z^2 distribution per band and per night, for every baseline subset.

    One row per entry of ``subsets``. Each band in ``bands_considered`` gets a group
    of boxes, one per night (styled per ``styles``). Whiskers span the full data
    range and means are drawn as stars.

    Returns
    -------
    numpy.ndarray of matplotlib Axes, one per subset.
    """
    nights = data_jd_ints
    # Offset between successive nights' boxes within a band group: 0.05 (the original
    # layout) for up to 12 nights, compressed beyond that so groups never overlap.
    night_offset = min(0.05, 0.6 / max(len(nights), 1))

    fig, axx = plt.subplots(
        len(subsets), 1, sharex=True, squeeze=False,
        figsize=(12, 3*len(subsets)), layout='constrained',
    )
    axx = axx[:, 0]  # squeeze=False keeps indexing valid even for a single subset

    for j, (name, selector) in enumerate(subsets.items()):
        ax = axx[j]
        is_bottom_row = j == len(subsets) - 1

        for i, band in enumerate(bands_considered):
            for n, night in enumerate(nights):
                allz = lstmet.get_data_subset(zscores, band=band, nights=night, selector=selector)

                # Only the middle box of each group, on the bottom row, carries the band label.
                tick_label = f"chs {band[0]}-{band[1]}" if (n == len(nights)//2 and is_bottom_row) else ""
                bplot = ax.boxplot(
                    allz, positions=[i - 0.3 + night_offset*n],
                    showfliers=False, whis=(0, 100), showmeans=True,
                    labels=[tick_label],
                )
                _style_boxplot(bplot, styles[night])

                if i == 0 and j == 0:
                    # Dummy lines for legend
                    ax.plot([1, 2], [np.nan, np.nan], **styles[night], label=str(night))

        _set_boxplot_ax_props(len(bands_considered), ax)
        ax.set_ylabel(name.replace(" baselines", ""))

    axx[0].legend(ncols=3)

    return axx


In [None]:
# Box plots of Z^2 per band, night and baseline subset (`;` suppresses the returned axes).
if make_plots:
    box_plot_all_groups(zscores);

### Mean Z^2 Over Different Axes

In [None]:
# Container for all of the reduced-Z^2 summary metrics computed below.
metrics = {}

In [None]:
# Mean Z^2 reduced over the channels of each band, keyed by band.
metrics['band_reduced_mean'] = {
    band: lstmet.get_data_subset_mean(zscores, band=band, mean_over='band')
    for band in bands_considered
}

In [None]:
# Mean Z^2 reduced over the baselines of each subset, keyed by subset name.
# NOTE(review): `allbls` is not used by this cell.
allbls = [(a, b, p) for a, b in stackconf.antpairs for p in stackconf.pols]
metrics['bl_reduced_mean'] = {
    subset_name: lstmet.get_data_subset_mean(zscores, selector=subset_selector, mean_over='bls')
    for subset_name, subset_selector in subsets.items()
}

In [None]:
# Mean Z^2 reduced over nights only.
metrics['night_reduced_mean'] = lstmet.get_data_subset_mean(
    zscores, mean_over='nights'
)

In [None]:
# Mean Z^2 reduced over nights and over the baselines of each subset, keyed by subset name.
metrics['night_and_bl_reduced_mean'] = {
    subset_name: lstmet.get_data_subset_mean(
        zscores, selector=subset_selector, mean_over=('nights', 'bls')
    )
    for subset_name, subset_selector in subsets.items()
}

In [None]:
# Mean Z^2 reduced over nights and over the channels of each band, keyed by band.
metrics['night_and_band_reduced_mean'] = {
    band: lstmet.get_data_subset_mean(zscores, band=band, mean_over=('nights', 'band'))
    for band in bands_considered
}

In [None]:
# Mean Z^2 reduced over a subset's baselines and a band's channels, keyed by (band, subset name).
metrics['bl_and_band_reduced_mean'] = {
    (band, subset_name): lstmet.get_data_subset_mean(
        zscores, band=band, selector=subset_selector, mean_over=('bls', 'band')
    )
    for subset_name, subset_selector in subsets.items()
    for band in bands_considered
}

In [None]:
# Mean Z^2 reduced over baselines, band channels and nights, keyed by (band, subset name).
metrics['all_reduced_mean'] = {
    (band, subset_name): lstmet.get_data_subset_mean(
        zscores, band=band, selector=subset_selector, mean_over=('bls', 'band', 'nights')
    )
    for subset_name, subset_selector in subsets.items()
    for band in bands_considered
}

#### Plot Totally Reduced

In [None]:
# Plot style per baseline subset: its own default colour, switching line style every 4 subsets.
subset_styles = {
    subset_name: {
        'color': f"C{k % len(subsets)}",
        'ls': ['-', '--', ':', '-.'][k // 4],
    }
    for k, subset_name in enumerate(subsets)
}

In [None]:
if make_plots:
    # One point per (band, subset): fully-reduced mean Z^2 of the first LST bin, with a
    # horizontal error bar spanning the band. Marker: 'o' for 200-channel bands, 'x' for
    # other bands narrower than 1500 channels, '*' otherwise.
    done = set()
    for (band, subset_name), means in metrics['all_reduced_mean'].items():
        lo, hi = band
        mid = np.mean(band)
        width = hi - lo
        if width == 200:
            marker = 'o'
        elif width < 1500:
            marker = 'x'
        else:
            marker = '*'

        # Only the first point of each subset gets a legend entry.
        legend_label = None if subset_name in done else subset_name.replace("baselines", "")
        plt.errorbar(
            [mid], means[0], xerr=[[mid - lo]], marker=marker, markersize=8,
            **subset_styles[subset_name], label=legend_label,
        )
        done.add(subset_name)
    plt.legend(ncols=2)
    plt.yscale('log')

#### Plot Reduced over Nights + Bands

In [None]:
def make_baseline_zsq_plot():
    """Scatter the mean Z^2 of every baseline (over nights and band channels) in the uv-plane.

    One row per band in ``bands_considered`` that is at most 200 channels wide, one
    column per polarization (the first two of ``stackconf.pols``). Uses the first LST
    bin of ``metrics['night_and_band_reduced_mean']``.
    """
    # TODO: need a better cmap to easily see what's "good" and "bad"
    narrow_bands = [band for band in bands_considered if band[1] - band[0] <= 200]

    # Copy: the slices of `uvw_array` are views, so flipping them in place would
    # otherwise modify `zscores[0].uvw_array` itself.
    uvws = zscores[0].uvw_array[:zscores[0].Nbls, :2].copy()
    # Fold onto the v >= 0 half-plane.
    uvws[uvws[:, 1] < 0] *= -1

    # Rows are sized from the bands actually plotted (previously assumed exactly three
    # wide bands, all at the end of the list); squeeze=False keeps 2D indexing.
    fig, axx = plt.subplots(
        len(narrow_bands), 2, sharey=True, squeeze=False,
        figsize=(24, 5*len(narrow_bands)), layout='constrained',
    )

    cmap = mpl.colors.ListedColormap(["C0", "C1", "C3"])
    for ax, band in zip(axx, narrow_bands):
        mean_zsq = metrics['night_and_band_reduced_mean'][band][0]

        for p in range(2):
            scatter = ax[p].scatter(
                uvws[:, 0], uvws[:, 1], c=mean_zsq[:, p].real,
                norm=mpl.colors.LogNorm(vmin=1, vmax=1000), marker='H', s=60, cmap=cmap,
            )
            ax[p].set_title(stackconf.pols[p])
            ax[p].set_aspect("equal", 'datalim')
            ax[p].set_xlim(-200, 200)
            ax[p].grid(True)

        ax[0].set_ylabel(str(band))

        plt.colorbar(scatter, ax=ax)

In [None]:
# Per-baseline mean Z^2 in the uv-plane, one row per narrow band.
if make_plots:
    make_baseline_zsq_plot()

#### Plot Reduced over Nights and bls

In [None]:
def plot_excess_variance_wrt_freq():
    """Plot mean Z^2 against frequency, one line per baseline subset.

    Starts from ``metrics['night_and_bl_reduced_mean']`` and additionally averages
    over LST bins (axis 0), ignoring NaNs.
    """
    freqs_mhz = stackconf.config.datameta.freq_array / 1e6
    for subset_name, per_lst in metrics['night_and_bl_reduced_mean'].items():
        lst_mean = np.nanmean(per_lst, axis=0)
        plt.plot(
            freqs_mhz, lst_mean,
            label=subset_name.replace("baselines", ""), **subset_styles[subset_name],
        )

    plt.xlabel("Freq [MHz]")
    plt.ylabel(r"Mean $Z^2$ across Nights, LSTs and Baselines")
    plt.legend(ncols=2)
    plt.ylim(7e-1, 100)
    plt.yscale('log')

In [None]:
# Mean Z^2 versus frequency for each baseline subset.
if make_plots:
    plot_excess_variance_wrt_freq()

#### Plot Reduced over Bls

In [None]:
def plot_reduced_over_bls():
    """Image the subset-mean Z^2 as night x frequency, one panel per baseline subset.

    Uses the first LST bin of ``metrics['bl_reduced_mean']``. Rows follow the global
    night list ``data_jd_ints``; nights missing from this LST bin are left as NaN.
    """
    # Row of the global (sorted) night list that each of this LST bin's nights maps to.
    jdidx = [data_jd_ints.index(jd) for jd in zscores[0].time_array[::zscores[0].Nbls].astype(int)]

    images = {}
    for subset, zsqs in metrics['bl_reduced_mean'].items():
        rows = np.asarray(zsqs[0])
        # Scatter row r of this LST bin into global row jdidx[r]. The previous
        # `zsqs[0][jdidx]` gathered instead, picking the wrong rows (or raising
        # IndexError) whenever a night was missing or out of order.
        img = np.full((len(data_jd_ints),) + rows.shape[1:], np.nan, dtype=np.result_type(rows, float))
        img[jdidx] = rows
        images[subset] = img

    if not images:
        return

    nrows = int(np.ceil(len(images) / 3))
    fig, ax = plt.subplots(nrows, 3, sharex=True, sharey=True, layout='constrained', figsize=(14, 3*nrows))
    flat_axes = ax.flatten()

    cmap = mpl.colors.ListedColormap(["C0", "C1", "C3"])
    extent = (
        zscores[0].freq_array.min()/1e6,
        zscores[0].freq_array.max()/1e6,
        0,
        len(data_jd_ints),
    )

    for i, (key, img) in enumerate(images.items()):
        axx = flat_axes[i]
        plt.sca(axx)

        image = plt.imshow(
            img, norm=mpl.colors.LogNorm(vmin=1, vmax=1000),
            origin='lower',
            extent=extent,
            cmap=cmap, aspect='auto',
            interpolation='none',
        )

        # One labelled tick per night, centred on its row.
        axx.yaxis.set_ticks(np.arange(img.shape[0]) + 0.5)
        axx.yaxis.set_ticklabels(data_jd_ints)

        axx.set_title(key.replace("baselines", ""), pad=-3)

        if i < 3:
            axx.tick_params('x', labeltop=True, labelbottom=False, top=True)

    # Hide any unused panels in the grid.
    for spare_ax in flat_axes[len(images):]:
        spare_ax.axis('off')

    cbar = plt.colorbar(image, ax=ax)
    cbar.set_label(r"Mean $Z^2$ over bl subset")
    cbar.set_label(r"Mean $Z^2$ over bl subset")

In [None]:
# Night x frequency images of the subset-mean Z^2.
if make_plots:
    plot_reduced_over_bls()

### Plot Selection of the Worst Visibilities

In [None]:
def plot_visibilities_per_type(
    lstbin_blpols: list[tuple[int, tuple[int, int, str]]],
    stacks: list[UVData],
    stats: lstmet.LSTBinStats,
    comments: list[str],
    freq_range: None | tuple[float, float] | list[tuple[int, int]] = None,
    label=None,
    yrange=None,
    alpha=0.5,
):
    """Plot per-night raw visibilities against their LST-binned statistics.

    For each ``(lst_index, (ant1, ant2, pol))`` a 4x2 figure is drawn. The left column
    shows magnitude, magnitude minus the LST-binned magnitude, and Z^2. The right
    column shows real part, real Z-score, imaginary part and imaginary Z-score; the
    Z-scores are (value - LST-binned mean) / MAD. Channels with negative nsamples
    (inpainted) are shaded. Entries whose LST-binned flags are all set are skipped
    (printing "ALL FLAGGED").

    Parameters
    ----------
    lstbin_blpols
        ``(lst_index, blpol)`` pairs to plot.
    stacks
        Per-LST-bin stacked raw data.
    stats
        LST-binned statistics; ``flags``, ``mean`` and ``mad`` are used, keyed by blpol.
    comments
        Text shown in the corner of each figure, one per entry of ``lstbin_blpols``.
    freq_range
        ``None`` (default) for the full band; a ``(fmin, fmax)`` tuple in the units of
        ``freq_array`` applied to every figure; or a list with one ``(start, stop)``
        channel range per entry, padded by 100 channels on each side. (The default
        used to be the *type expression* itself rather than ``None``; it behaved as
        "full band", which ``None`` preserves.)
    label
        Unused.
    yrange
        Optional y-limits for the magnitude panel.
    alpha
        Transparency of the legend handles.

    Returns
    -------
    list of matplotlib Figure
    """
    all_figs = []

    lststyle = dict(color='k', lw=3, zorder=-1)
    meta = stackconf.config.datameta
    nfreqs = len(meta.freq_array)

    if isinstance(freq_range, tuple):
        mask = (meta.freq_array >= freq_range[0]) & (meta.freq_array < freq_range[1])
        freqs = meta.freq_array[mask] / 1e6
    else:
        mask = slice(None)
        freqs = meta.freq_array / 1e6

    handles = [
        mpl.lines.Line2D([0], [0], label=str(jdint), alpha=alpha, **style)
        for jdint, style in styles.items()
    ]

    def _shade(axis, x, y1, y2, ranges, color):
        # Shade each inpainted channel range between y1 and y2 (y2 may be a scalar).
        for lo, hi in ranges:
            bottom = y2 if np.isscalar(y2) else y2[lo:hi]
            axis.fill_between(x[lo:hi], y1[lo:hi], bottom, color=color, alpha=0.2)

    for i, (comment, (lstidx, blpol)) in enumerate(zip(comments, lstbin_blpols)):
        if isinstance(freq_range, list):
            start, stop = freq_range[i]
            # pad the range a bit
            mask = slice(max(start - 100, 0), min(stop + 100, nfreqs))
            freqs = meta.freq_array[mask] / 1e6

        stack = stacks[lstidx]
        zscore = zscores[lstidx]

        rawd = stack.get_data(blpol)[:, mask]
        rawf = stack.get_flags(blpol)[:, mask]
        rawn = stack.get_nsamples(blpol)[:, mask]
        inp = rawn < 0  # inpainted samples carry negative nsamples

        lstf = stats.flags[blpol][lstidx, mask]
        lstd = stats.mean[blpol][lstidx, mask]
        lstmed = stats.mean[blpol][lstidx, mask]  # the mean (not median) is the Z-score centre
        lstmad = stats.mad[blpol][lstidx, mask]

        zsq = zscore.get_data(blpol)[:, mask].real

        if np.all(lstf):
            print("ALL FLAGGED")
            continue

        fig, ax = plt.subplots(
            4, 2,
            sharex=True, figsize=(15, 8),
            constrained_layout=True, gridspec_kw={'height_ratios': (2, 1, 2, 1)}
        )

        mag = np.where(rawf, np.nan, np.abs(rawd))
        rl = np.where(rawf, np.nan, rawd.real)
        im = np.where(rawf, np.nan, rawd.imag)

        maglstbin = np.where(lstf, np.nan, np.abs(lstd))
        rllstbin = np.where(lstf, np.nan, lstd.real)
        imlstbin = np.where(lstf, np.nan, lstd.imag)

        rllstbin_med = np.where(lstf, np.nan, lstmed.real)
        imlstbin_med = np.where(lstf, np.nan, lstmed.imag)
        rllstbin_mad = np.where(lstf, np.nan, lstmad.real)
        imlstbin_mad = np.where(lstf, np.nan, lstmad.imag)

        ax[0, 0].plot(freqs, maglstbin, **lststyle)
        ax[0, 1].plot(freqs, rllstbin, **lststyle)
        ax[2, 1].plot(freqs, imlstbin, **lststyle)

        for jdidx, jd in enumerate(stack.time_array[::stack.Nbls]):
            style = styles[int(jd)]
            color = style['color']

            if np.all(rawf[jdidx]):
                continue

            # Runs of inpainted channels for this night; guarded because a night with no
            # inpainted channels gives an empty index array.
            inp_idx = np.nonzero(inp[jdidx])[0]
            inp_ranges = consecutive(inp_idx) if len(inp_idx) > 0 else []

            magdiff = mag[jdidx] - maglstbin
            rldiff = (rl[jdidx] - rllstbin_med) / rllstbin_mad
            imdiff = (im[jdidx] - imlstbin_med) / imlstbin_mad

            # Amplitude
            ax[0, 0].plot(freqs, mag[jdidx], **style)
            _shade(ax[0, 0], freqs, mag[jdidx], maglstbin, inp_ranges, color)
            ax[1, 0].plot(freqs, magdiff, **style)
            _shade(ax[1, 0], freqs, magdiff, 0, inp_ranges, color)
            ax[2, 0].plot(freqs, zsq[jdidx], **style)
            _shade(ax[2, 0], freqs, zsq[jdidx], 0, inp_ranges, color)

            # Real / Imag
            ax[0, 1].plot(freqs, rl[jdidx], **style)
            _shade(ax[0, 1], freqs, rl[jdidx], rllstbin, inp_ranges, color)
            ax[1, 1].plot(freqs, rldiff, **style)
            _shade(ax[1, 1], freqs, rldiff, 0, inp_ranges, color)
            ax[2, 1].plot(freqs, im[jdidx], **style)
            _shade(ax[2, 1], freqs, im[jdidx], imlstbin, inp_ranges, color)
            ax[3, 1].plot(freqs, imdiff, **style)
            _shade(ax[3, 1], freqs, imdiff, 0, inp_ranges, color)

        if yrange:
            ax[0, 0].set_ylim(yrange)

        for zax in (ax[1, 1], ax[3, 1]):
            zax.axhline(4, color='gray', ls='--')
            zax.axhline(-4, color='gray', ls='--')

        bl_coords = meta.antpos_enu[blpol[0]] - meta.antpos_enu[blpol[1]]

        # Title uses the LST of the bin actually plotted (previously always bin 0).
        fig.suptitle(
            f"Baseline: {blpol} [{bl_coords[0]:.1f}-EW, {bl_coords[1]:.1f}-NS]. "
            f"LST = {stackconf.lst_grid[lstidx]*12/np.pi:.5f} hr."
        )
        ax[-1, 0].set_xlabel("Frequency [MHz]")
        ax[-1, 1].set_xlabel("Frequency [MHz]")

        ax[0, 0].set_ylabel("Magnitude")
        ax[0, 1].set_ylabel("Real Part")

        ax[1, 0].set_ylabel("Magnitude Diff")
        ax[1, 1].set_ylabel("Real Z-score")
        ax[1, 1].set_ylim(-7, 7)

        ax[2, 0].set_ylabel(r"$Z^2$")
        ax[2, 0].set_yscale('log')
        ax[2, 0].set_ylim(1e-1,)

        ax[2, 1].set_ylabel("Imag Part")

        ax[3, 1].set_ylabel("Imag Z-score")
        ax[3, 1].set_ylim(-7, 7)
        ax[0, 0].legend(handles=handles, ncols=5)

        ax[0, 1].text(0.95, 0.95, comment, transform=ax[0, 1].transAxes, ha='right', va='top')

        # Grey guide lines every 200 channels (over the full band, not a hard-coded 1536).
        for axx in ax.flatten():
            for line in range(0, nfreqs, 200):
                axx.axvline(meta.freq_array[line]/1e6, color='gray', alpha=0.4)
            axx.set_xlim(freqs[0], freqs[-1])

        all_figs.append(fig)

    return all_figs

In [None]:
def get_worst_mean_over_each_band(n=1):
    """Find the ``n`` worst (LST bin, night, baseline-pol) entries in each band.

    For every band in ``metrics['band_reduced_mean']``, the per-night
    band-averaged Z^2 values of all LST bins are scattered onto the common
    night axis ``data_jd_ints`` and the ``n`` largest measured values are
    selected. Missing (NaN) entries are never reported.

    Relies on the notebook globals ``zscores``, ``data_jd_ints``,
    ``stackconf`` and ``metrics``.

    Parameters
    ----------
    n : int
        Number of worst entries to report per band. Values <= 0 report nothing.

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, zsq, label, band)`` tuples, ordered worst-first per band.
    """
    bad_fellas = {}

    # For each LST bin, the position of each of its nights within data_jd_ints.
    nights = [
        [data_jd_ints.index(jd) for jd in zs.time_array[::zs.Nbls].astype(int)]
        for zs in zscores
    ]

    shape = (len(zscores), len(data_jd_ints), len(stackconf.antpairs), len(stackconf.pols))
    newmeans = {band: np.full(shape, np.nan) for band in metrics['band_reduced_mean']}

    for band, zsqs in metrics['band_reduced_mean'].items():
        # zsqs has one entry per LST bin, each of shape (nights, antpairs, pols).
        # Different LST bins can hold different nights, so scatter each onto the
        # common night axis; nights absent from a bin stay NaN.
        for lst, (night_idx, zsq_lst) in enumerate(zip(nights, zsqs)):
            newmeans[band][lst, night_idx] = zsq_lst

    # Labels in the same C-order as the flattened (lst, night, antpair, pol) array.
    lst_night_bl_pols = [
        (lst, jd, bl + (pol,))
        for lst in range(len(zscores))
        for jd in data_jd_ints
        for bl in stackconf.antpairs
        for pol in stackconf.pols
    ]

    for band, zsq in newmeans.items():
        # Missing entries get -inf so they can never rank as "worst".
        zsq = np.where(np.isnan(zsq), -np.inf, zsq).ravel()

        # argpartition raises if kth is out of range, and n=0 would select everything.
        nworst = min(n, zsq.size)
        if nworst <= 0:
            continue

        worst_idx = np.argpartition(zsq, -nworst)[-nworst:]
        # Sort worst-first; values are looked up from the index so they stay paired.
        worst_idx = worst_idx[np.argsort(-zsq[worst_idx])]

        for idx in worst_idx:
            z = zsq[idx]
            if np.isneginf(z):
                # Fewer than n measured entries in this band; the rest are missing.
                break
            lst, jd, bl = lst_night_bl_pols[idx]
            bad_fellas.setdefault((lst, bl), []).append(
                (jd, z, fr"Worst $Z^2$ in band {band[0]}-{band[1]}", band)
            )

    return bad_fellas


In [None]:
def get_worst_mean_for_continuously_bad_stuff(n=1):
    """Find the ``n`` worst contiguous bad-channel chunks per chunk-width class.

    Entries of the notebook global ``allbad`` are keyed by
    ``(lst, ant1, ant2, pol_index, jd_int, low, high)``. Chunks are grouped by
    their width ``high - low`` into the ``[lo, hi)`` classes of ``chsizes``.
    Within each class, the ``n`` chunks with the largest NaN-ignoring mean of
    their values are reported. All-NaN chunks are never reported.

    Parameters
    ----------
    n : int
        Number of worst chunks to report per width class. Values <= 0 report nothing.

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, mean_zsq, label, (low, high))`` tuples, worst-first per class.
    """
    bad_fellas = {}

    # Width classes [lo, hi). Single-channel chunks are skipped below, so the
    # (1, 2) class never receives entries.
    chsizes = [(1, 2), (2, 10), (10, 20), (20, 50), (50, 100), (100, 1536)]
    sized = {ch: {} for ch in chsizes}
    for k, v in allbad.items():
        s = k[-1] - k[-2]  # width of the chunk in channels
        if s == 1:
            continue
        for ch in chsizes:
            if ch[0] <= s < ch[1]:
                sized[ch][k] = v

    for chsize, thesebads in sized.items():
        keys = list(thesebads)

        # argpartition cannot select more elements than exist (and the (1, 2)
        # class is always empty), so clip n to the class size.
        nworst = min(n, len(keys))
        if nworst <= 0:
            continue

        meanz = np.array([np.nanmean(v) for v in thesebads.values()])
        # An all-NaN chunk has NaN mean, which argpartition would rank highest.
        meanz = np.where(np.isnan(meanz), -np.inf, meanz)

        worst_idx = np.argpartition(meanz, -nworst)[-nworst:]
        # Sort worst-first; values are looked up from the index so they stay paired.
        worst_idx = worst_idx[np.argsort(-meanz[worst_idx])]

        for idx in worst_idx:
            z = meanz[idx]
            if np.isneginf(z):
                break
            lst, a, b, pol, jdint, low, high = keys[idx]
            bl = (a, b, stackconf.pols[pol])
            bad_fellas.setdefault((lst, bl), []).append(
                (int(jdint), z, fr"Worst $Z^2$ over {chsize[0]}-{chsize[1]} channels", (low, high))
            )

    return bad_fellas

In [None]:
def get_worst_continuous_bad_zscore(n=1):
    """Find the ``n`` entries whose band-averaged Z^2 is high in every small band.

    For each (LST bin, night, baseline-pol), take the NaN-ignoring minimum of
    ``metrics['band_reduced_mean']`` over all bands in ``bands_considered``
    that are at most 200 channels wide. A large minimum means the entry is
    bad in all of those bands, not just one. The ``n`` largest measured
    minima are reported.

    Relies on the notebook globals ``zscores``, ``data_jd_ints``,
    ``stackconf``, ``metrics`` and ``bands_considered``.

    Parameters
    ----------
    n : int
        Number of worst entries to report. Values <= 0 report nothing.

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, min_zsq, label, (0, 1536))`` tuples, worst-first.
    """
    bad_fellas = {}

    smallbands = [b for b in bands_considered if b[1] - b[0] <= 200]
    if not smallbands:
        # np.nanmin over an empty band axis would raise.
        return bad_fellas

    # For each LST bin, the position of each of its nights within data_jd_ints.
    nights = [
        [data_jd_ints.index(jd) for jd in zs.time_array[::zs.Nbls].astype(int)]
        for zs in zscores
    ]

    newmeans = np.full(
        (len(smallbands), len(zscores), len(data_jd_ints), len(stackconf.antpairs), len(stackconf.pols)),
        np.nan,
    )

    for i, band in enumerate(smallbands):
        zsqs = metrics['band_reduced_mean'][band]
        # zsqs has one entry per LST bin, each of shape (nights, antpairs, pols).
        # Scatter each onto the common night axis; absent nights stay NaN.
        for lst, (night_idx, zsq_lst) in enumerate(zip(nights, zsqs)):
            newmeans[i, lst, night_idx] = zsq_lst

    # Labels in the same C-order as the flattened (lst, night, antpair, pol) array.
    lst_night_bl_pols = [
        (lst, jd, bl + (pol,))
        for lst in range(len(zscores))
        for jd in data_jd_ints
        for bl in stackconf.antpairs
        for pol in stackconf.pols
    ]

    zsq = np.nanmin(newmeans, axis=0)
    # Missing entries get -inf so they can never rank as "worst".
    zsq = np.where(np.isnan(zsq), -np.inf, zsq).ravel()

    nworst = min(n, zsq.size)
    if nworst <= 0:
        return bad_fellas

    worst_idx = np.argpartition(zsq, -nworst)[-nworst:]
    # Sort worst-first; values are looked up from the index so they stay paired.
    worst_idx = worst_idx[np.argsort(-zsq[worst_idx])]

    for idx in worst_idx:
        z = zsq[idx]
        if np.isneginf(z):
            break
        lst, jd, bl = lst_night_bl_pols[idx]
        bad_fellas.setdefault((lst, bl), []).append(
            (jd, z, fr"Worst min($Z^2$) over entire band", (0, 1536))
        )

    return bad_fellas


In [None]:
def get_bad_inpaints(n=1):
    """Find the ``n`` worst in-painted chunks per chunk-width class.

    Keys of the notebook global ``inpainted_regions`` are
    ``(lst, ant1, ant2, pol_index, jd_int, low, high)``. Chunks are grouped by
    width ``high - low`` into the ``[lo, hi)`` classes of ``chsizes``. For
    each chunk, the NaN-ignoring mean of ``zscores[lst]`` over the chunk's
    channels on its night is computed, and the ``n`` largest per class are
    reported. All-NaN chunks are never reported.

    Parameters
    ----------
    n : int
        Number of worst chunks to report per width class. Values <= 0 report nothing.

    Returns
    -------
    dict
        Maps ``(lst_index, (ant1, ant2, pol))`` to a list of
        ``(jd_int, mean_z, label, (low, high))`` tuples, worst-first per class.
    """
    bad_fellas = {}

    # For each LST bin, the integer JDs of its nights in data order (used to
    # find the row of a given night in that bin's z-scores).
    nights = [zs.time_array[::zs.Nbls].astype(int).tolist() for zs in zscores]

    # Width classes [lo, hi); single-channel chunks fall below the first class.
    chsizes = [(2, 5), (5, 10), (10, 20)]
    sized = {ch: {} for ch in chsizes}
    for k, v in inpainted_regions.items():
        s = k[-1] - k[-2]  # width of the chunk in channels
        for ch in chsizes:
            if ch[0] <= s < ch[1]:
                sized[ch][k] = v

    for chsize, bads in sized.items():
        keys = list(bads)

        # argpartition cannot select more elements than exist, so clip n.
        nworst = min(n, len(keys))
        if nworst <= 0:
            continue

        meanz = np.array([
            np.nanmean(
                zscores[lst].get_data((a, b, stackconf.pols[pol]))[nights[lst].index(int(jdint)), low:high]
            )
            for (lst, a, b, pol, jdint, low, high) in keys
        ])
        # An all-NaN chunk has NaN mean, which argpartition would rank highest.
        meanz = np.where(np.isnan(meanz), -np.inf, meanz)

        worst_idx = np.argpartition(meanz, -nworst)[-nworst:]
        # Sort worst-first; values are looked up from the index so they stay paired.
        worst_idx = worst_idx[np.argsort(-meanz[worst_idx])]

        for idx in worst_idx:
            z = meanz[idx]
            if np.isneginf(z):
                break
            lst, a, b, pol, jdint, low, high = keys[idx]
            bl = (a, b, stackconf.pols[pol])
            bad_fellas.setdefault((lst, bl), []).append(
                (int(jdint), z, fr"Worst inpainted $Z^2$ for {chsize[0]}-{chsize[1]} chans", (low, high))
            )

    return bad_fellas

In [None]:
if make_plots:
    # Worst per-band mean Z^2 entries, keyed by (LST-bin index, baseline-pol).
    worst_mean_over_each_band = get_worst_mean_over_each_band(
        n=plot_n_worst,
    )

In [None]:
if make_plots:
    # Worst contiguous bad-channel chunks, one ranking per chunk-width class.
    worst_mean_for_continously_bad = get_worst_mean_for_continuously_bad_stuff(
        n=plot_n_worst,
    )

In [None]:
if make_plots:
    # Entries whose band-averaged Z^2 is high in every small band.
    worst_minimum_zscores_over_bands = get_worst_continuous_bad_zscore(
        n=plot_n_worst,
    )

In [None]:
if make_plots:
    # Worst in-painted chunks, one ranking per chunk-width class.
    worst_inpainted_regions = get_bad_inpaints(
        n=plot_n_worst,
    )

In [None]:
if make_plots:
    # Combine the four candidate mappings into one
    # {(lst_index, baseline-pol): [entry, ...]} dict; keys and entries keep
    # the order in which they are first encountered.
    badstuff = {}
    for dct in (
        worst_mean_over_each_band,
        worst_mean_for_continously_bad,
        worst_minimum_zscores_over_bands,
        worst_inpainted_regions,
    ):
        for k, v in dct.items():
            badstuff.setdefault(k, []).extend(v)

In [None]:
if make_plots:
    # For each flagged (LST bin, baseline-pol), pass the (min, max) over all
    # range bounds recorded in its entries (the last element of each entry).
    freq_ranges = [
        (
            min(edge for entry in v for edge in entry[-1]),
            max(edge for entry in v for edge in entry[-1]),
        )
        for v in badstuff.values()
    ]

    plot_visibilities_per_type(
        lstbin_blpols=list(badstuff),
        stacks=cross_stacks,
        stats=cross_stats,
        # One line per entry: "<label>: <jd>".
        comments=[
            "\n".join(f"{entry[-2]}: {entry[0]}" for entry in v)
            for v in badstuff.values()
        ],
        freq_range=freq_ranges,
        alpha=0.5,
    );

### Write Out Metrics

In [None]:
# Output filename for the LST-bin metrics file (written only if save_metric_data).
fname = out_fname.format(
    kind='LSTBIN-METRICS',
)

In [None]:
if save_metric_data:
    def write_metric(grp, metric: dict[object, list[np.ndarray]]):
        """Write a ``{key: [array per LST bin]}`` metric into HDF5 group ``grp``.

        Each key gets a sub-group named ``str(key)``, holding one dataset
        ``zsqmean-{i}`` per LST bin ``i``.
        """
        for key, lstbins in metric.items():
            subgrp = grp.create_group(str(key))
            for i, lstbin in enumerate(lstbins):
                subgrp[f'zsqmean-{i}'] = lstbin

    def _as_hdf5_array(values):
        """Return ``values`` as an array that h5py can store.

        HDF5 has no fixed-width unicode type, so h5py rejects NumPy ``'U'``
        arrays (e.g. a list of polarization strings). Those are encoded to
        UTF-8 bytes; any other dtype passes through unchanged.
        """
        arr = np.asarray(values)
        if arr.dtype.kind == 'U':
            return np.char.encode(arr, 'utf-8')
        return arr

    # Group handles use names distinct from notebook globals. Rebinding `meta`
    # (read by the plotting code) to an h5py group would leave it pointing at a
    # closed file once this `with` block exits.
    with h5py.File(outdir / fname, 'w') as fl:
        meta_grp = fl.create_group("meta")
        meta_grp['pols'] = _as_hdf5_array(stackconf.pols)
        meta_grp['ant1'] = np.array([a for a, b in stackconf.antpairs])
        meta_grp['ant2'] = np.array([b for a, b in stackconf.antpairs])
        meta_grp['freqs'] = stackconf.config.datameta.freq_array

        # Integer JDs of the nights present in each LST bin.
        nights_grp = meta_grp.create_group("nights")
        for i, nght in enumerate(zscores):
            nights_grp[str(i)] = nght.time_array[::nght.Nbls].astype(int)

        mgrp = fl.create_group("metrics")
        for name, mtrc in metrics.items():
            if name == 'night_reduced_mean':
                # night reduced mean is different -- simply one array per LST bin
                nrm = mgrp.create_group("night_reduced_mean")
                for i, val in enumerate(mtrc):
                    nrm[f'zsqmean-{i}'] = val
            else:
                write_metric(mgrp.create_group(name), mtrc)
        
        

## Notebook Metadata and Software Versions

In [None]:
import importlib

# Report package versions. Importing by name avoids exec()-ing generated code
# and no longer leaves a global `__version__` bound to the last package.
for repo in ['numpy', 'scipy', 'astropy', 'hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata']:
    _repo_version = getattr(importlib.import_module(repo), '__version__', 'unknown')
    print(f'{repo:>25}: {_repo_version}')

In [None]:
import getpass

# Report the invoking user via the standard library rather than a `whoami`
# shell subprocess, whose output is written by a child process rather than
# through Python's sys.stdout.
print(f"Run by: {getpass.getuser()}")

In [None]:
# Timestamp of this run (local, naive datetime).
print("Run on", datetime.now())

In [None]:
# Wall-clock time since `start_time`, reported in minutes.
print("Execution of notebook took: {:.2f} minutes".format((_time() - start_time) / 60.0))

In [None]:
import sys

# getrusage reports ru_maxrss in kilobytes on Linux but in bytes on macOS;
# convert to bytes first so the GB figure is right on both.
_maxrss_to_bytes = 1 if sys.platform == "darwin" else 1024
print(f"Peak memory in this notebook run: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * _maxrss_to_bytes / 1024**3:.2f} GB")