In [None]:
import numpy as np
import matplotlib.pyplot as plt
from hera_cal import frf
import glob
import os
from copy import deepcopy
from hera_cal import redcal
from IPython.display import display, HTML
from hera_cal.io import HERAData
from matplotlib.colors import LogNorm
from hera_pspec import utils
%config Completer.use_jedi = False
from scipy.interpolate import interp1d
from hera_pspec.container import PSpecContainer
import hera_pspec.plot as pspecplot
import copy
from matplotlib import cm as cmaps
from hera_pspec import grouping
import gc

import tqdm
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
#os.environ = {
#    'JULIANDATE': 'LST.0.00000',
#    'DATA_PATH': '/lustre/aoc/projects/hera/H4C/postprocessing/lstbin/after_filtering_before_red_average/all-bands/',
#    'LABEL' : 'all-bands-allbls',
#    'SPWS' : "0,1,2,3,4,5,6,7",
#    'LST_FIELDS': "1~3,4.2~6.2",
#    'GRP_SKIP': '10',
#    'BLP_SKIP': '2',
#    'FIELD_LABELS': '1,2',
#    'MAX_PLOTS_PER_ROW': '10',
#    'EXTS' : 'foreground_filled~foreground_res.filled_flags~foreground_model.filled_flags'

#}


In [None]:
# Read the run configuration from environment variables (see the commented
# example above for the expected format of each variable).
JD = os.environ['JULIANDATE']
data_path = os.environ['DATA_PATH']
label = os.environ['LABEL']
spws = os.environ['SPWS'].split(',')
# LST_FIELDS is a comma-separated list of 'start~stop' ranges; each range is
# stored as a (start, stop) tuple of floats.
lst_fields = [
    (float(bounds[0]), float(bounds[1]))
    for bounds in (entry.split('~') for entry in os.environ['LST_FIELDS'].split(','))
]

# File-name extensions are separated by '~' because they may contain dots.
exts = os.environ['EXTS'].split('~')
grp_skip = int(os.environ['GRP_SKIP'])
blp_skip = int(os.environ['BLP_SKIP'])
field_labels = os.environ['FIELD_LABELS'].split(',')
max_plots_per_row = int(os.environ['MAX_PLOTS_PER_ROW'])
spws = list(map(int, spws))
# Any JULIANDATE containing 'LST' is collapsed to the plain prefix 'LST'.
JD = 'LST' if 'LST' in JD else JD

# Echo the parsed configuration so that it is recorded in the notebook output.
for name, value in (
    ('JD', JD),
    ('data_path', data_path),
    ('label', label),
    ('spws', spws),
    ('lst_fields', lst_fields),
    ('grp_skip', grp_skip),
    ('blp_skip', blp_skip),
    ('field_labels', field_labels),
    ('max_plots_per_row', max_plots_per_row),
    ('exts', exts),
):
    print(f'{name} = "{value}"')

In [None]:
from astropy.time import Time

# Report the calendar date corresponding to JD when JD is a real Julian date.
# This is best effort: for LST-binned data JD has been collapsed to the string
# 'LST', which cannot be parsed, so we only print a notice instead of failing.
try:
    utc = Time(JD, format='jd').datetime
    print(f'Date: {utc.month}-{utc.day}-{utc.year}')
except Exception:
    # Catch Exception (not a bare except) so that KeyboardInterrupt/SystemExit
    # still propagate. The message is an f-string so the value of JD is shown.
    print(f'Could not parse JD={JD}')

In [None]:
# One slot per extension, initialised to None; each slot is replaced by a
# sorted list of file names in the search cell that follows.
psc_files_sum = dict.fromkeys(exts)
psc_files_diff = dict.fromkeys(exts)

In [None]:
# Locate the cross power-spectrum containers (sum and diff) for every extension.
for ext in psc_files_sum:
    print('Looking for sum power-spectrum containers in', data_path, 'on JD', JD)
    psc_files_sum[ext] = sorted(glob.glob(os.path.join(data_path, f'zen.{JD}.*.sum.{label}.{ext}.xtalk_filtered_pstokes.tavg.pspec.h5')))
    print('Found {} files.'.format(len(psc_files_sum[ext])))
    print('Looking for diff power-spectrum containers in', data_path, 'on JD', JD)
    psc_files_diff[ext] = sorted(glob.glob(os.path.join(data_path, f'zen.{JD}.*.diff.{label}.{ext}.xtalk_filtered_pstokes.tavg.pspec.h5')))
    print('Found {} files.'.format(len(psc_files_diff[ext])))

# diffs is True only when diff containers were found for every extension.
# The check must look at the per-extension file lists: the dict itself always
# has one entry per extension, so its length says nothing about found files.
diffs = len(psc_files_diff) > 0 and all(len(psc_files_diff[ext]) > 0 for ext in psc_files_diff)

In [None]:
# Locate the auto power-spectrum containers (sum and diff) for every extension.
psca_files_sum = {ext: None for ext in exts}
psca_files_diff = {ext: None for ext in exts}
for ext in psca_files_sum:
    print('Looking for sum power-spectrum containers in', data_path, 'on JD', JD)
    psca_files_sum[ext] = sorted(glob.glob(os.path.join(data_path, f'zen.{JD}.*.sum.{label}.autos.{ext}_pstokes.tavg.pspec.h5')))
    print('Found {} files.'.format(len(psca_files_sum[ext])))
    print('Looking for diff power-spectrum containers in', data_path, 'on JD', JD)
    psca_files_diff[ext] = sorted(glob.glob(os.path.join(data_path, f'zen.{JD}.*.diff.{label}.autos.{ext}_pstokes.tavg.pspec.h5')))
    print('Found {} files.'.format(len(psca_files_diff[ext])))
    # Keep diffs True only if it was already True (set by the cross-spectrum
    # search) and this extension has auto diff files as well. The comparison is
    # made on the file list of this extension; `x > 0 & diffs` would bind as
    # `x > (0 & diffs)` and silently drop the previous value of diffs.
    diffs = diffs and len(psca_files_diff[ext]) > 0



In [None]:
def get_bls_to_keep(uvpt_list):
    """Return the baseline pairs that have unflagged data in at least one file.

    Parameters
    ----------
    uvpt_list : list
        Power-spectrum objects providing ``get_blpairs()``,
        ``blpair_to_indices(blp)`` and ``integration_array``. The baseline
        pairs that are examined are taken from the first element.

    Returns
    -------
    list
        Baseline pairs whose integrations (``integration_array[0][..., 0]``)
        are not all close to zero in at least one element of ``uvpt_list``.
        An empty input list gives an empty result.
    """
    if len(uvpt_list) == 0:
        return []
    bls_to_keep = []
    for blp in uvpt_list[0].get_blpairs():
        # Test each file separately: files may hold different numbers of
        # integrations, and stacking ragged arrays into one array fails.
        has_data = any(
            not np.allclose(uvpt.integration_array[0][uvpt.blpair_to_indices(blp), 0], 0.0)
            for uvpt in uvpt_list
        )
        if has_data:
            bls_to_keep.append(blp)
        else:
            print(f'{blp} has no integrations! Discarding...')
    return bls_to_keep
            
def get_files_to_keep(uvpt_list):
    """Return the indices of the entries of ``uvpt_list`` that are usable.

    An entry is kept when none of its integrations
    (``integration_array[0][:, 0]``) is close to zero and its values are all
    finite. Finiteness is checked on ``stats_array`` when the entry has one,
    otherwise on ``data_array``.

    Parameters
    ----------
    uvpt_list : list
        Power-spectrum objects, one per file.

    Returns
    -------
    list of int
        Positions in ``uvpt_list`` of the entries to keep.
    """
    files_to_keep = []
    # enumerate supplies the position of each entry; a manual counter that is
    # never advanced would report index 0 for every kept file.
    for findex, uvpt in enumerate(tqdm.tqdm_notebook(uvpt_list)):
        if np.any(np.isclose(uvpt.integration_array[0][:, 0], 0.0)):
            continue
        if hasattr(uvpt, 'stats_array'):
            finite_stat = all(
                np.all(np.isfinite(uvpt.stats_array[stat][spw]))
                for stat in uvpt.stats_array
                for spw in uvpt.spw_array
            )
        else:
            finite_stat = all(
                np.all(np.isfinite(uvpt.data_array[spw])) for spw in uvpt.data_array
            )
        if finite_stat:
            files_to_keep.append(findex)
    return files_to_keep




In [None]:
def discard_flagged_blpairs(uvpt_list):
    """Drop, in place, baseline pairs that have no unflagged data in any file.

    The candidate baseline pairs are taken from the first element. A pair is
    kept when its integrations (``integration_array[0][..., 0]``) are not all
    close to zero in at least one element of ``uvpt_list``.

    Parameters
    ----------
    uvpt_list : list
        Power-spectrum objects; each one is modified through
        ``select(blpairs=..., inplace=True)``.

    Returns
    -------
    list
        The same list object that was passed in. An empty list is returned
        unchanged.
    """
    if len(uvpt_list) == 0:
        return uvpt_list
    bls_to_keep = []
    for blp in uvpt_list[0].get_blpairs():
        # Test each file separately: files may hold different numbers of
        # integrations, and stacking ragged arrays into one array fails.
        has_data = any(
            not np.allclose(uvpt.integration_array[0][uvpt.blpair_to_indices(blp), 0], 0.0)
            for uvpt in uvpt_list
        )
        if has_data:
            bls_to_keep.append(blp)
        else:
            print(f'{blp} has no integrations! Discarding...')
    for uvpt in tqdm.tqdm_notebook(uvpt_list):
        uvpt.select(blpairs=bls_to_keep, inplace=True)
    return uvpt_list

def discard_flagged_times(uvpt_list):
    """Remove unusable entries from ``uvpt_list`` in place and return it.

    An entry is removed when any of its integrations
    (``integration_array[0][:, 0]``) is close to zero, or when it contains a
    non-finite value. Finiteness is checked on ``stats_array`` when the entry
    has one, otherwise on ``data_array``.
    """
    rejected = []
    for position, uvpt in enumerate(tqdm.tqdm_notebook(uvpt_list)):
        usable = not np.any(np.isclose(uvpt.integration_array[0][:, 0], 0.0))
        if usable:
            if hasattr(uvpt, 'stats_array'):
                arrays = (
                    uvpt.stats_array[stat][spw]
                    for stat in uvpt.stats_array
                    for spw in uvpt.spw_array
                )
            else:
                arrays = (uvpt.data_array[spw] for spw in uvpt.data_array)
            usable = all(np.all(np.isfinite(values)) for values in arrays)
        if not usable:
            rejected.append(position)
    # Delete from the back so that the remaining indices stay valid.
    for position in tqdm.tqdm_notebook(rejected[::-1]):
        del uvpt_list[position]
    return uvpt_list
        
    
def merge_ps_list(ps_list):
    """Combine all elements of ``ps_list`` with ``+`` using a binary tree.

    The list is split in half recursively and the two halves are added, so the
    depth of the merge is logarithmic in the number of elements. This saves
    quite a bit of time by reducing the number of merge calls on large objects.

    Parameters
    ----------
    ps_list : sequence
        Objects supporting ``+``; the order of the elements is preserved.

    Returns
    -------
    object
        The single element for a one-element list, otherwise the sum.

    Raises
    ------
    ValueError
        If ``ps_list`` is empty (halving an empty list would recurse forever).
    """
    if len(ps_list) == 0:
        raise ValueError('merge_ps_list requires at least one power spectrum.')
    if len(ps_list) == 1:
        return ps_list[0]
    mid = len(ps_list) // 2
    return merge_ps_list(ps_list[:mid]) + merge_ps_list(ps_list[mid:])
        

In [None]:
# NOTE(review): pdb is not referenced in the visible cells; possibly leftover debugging.
import pdb
from joblib import Parallel, delayed
import multiprocessing
# Number of CPUs reported by the system. In the visible cells it is only
# referenced by the commented-out parallel averaging in extract_spectra.
ncpu = multiprocessing.cpu_count()

def perform_average(uvpt, error_weights):
    """Average the spectra of ``uvpt`` within redundant baseline-pair groups.

    Groups come from ``utils.get_blvec_reds`` with ``bl_error_tol=1.0``. Pairs
    whose two members are equal are removed from every group, and groups left
    empty are dropped. ``uvpt`` is restricted in place to the remaining pairs
    before averaging.

    Returns the result of ``average_spectra(..., inplace=False)``, called with
    every key of ``uvpt.stats_array`` as ``error_field``.
    """
    redundant_groups, _, _, _ = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)
    # Exclude pairs of a baseline with itself; a comparison between including
    # and excluding them can be done later.
    cross_groups = []
    for group in redundant_groups:
        cross_only = [pair for pair in group if pair[0] != pair[1]]
        if len(cross_only) > 0:
            cross_groups.append(cross_only)
    # Discard baseline pairs that are not part of any averaging group.
    grouped_pairs = [pair for group in cross_groups for pair in group]
    uvpt.select(blpairs=grouped_pairs, inplace=True)
    stat_names = list(uvpt.stats_array)
    return uvpt.average_spectra(
        blpair_groups=cross_groups,
        inplace=False,
        error_weights=error_weights,
        error_field=stat_names,
    )
        
def load_select_and_average(fname, blps_to_keep, error_weights):
    """Load one container, keep ``blps_to_keep`` and average redundant pairs.

    Reads the spectra stored under group 'dset0' / name 'dset0_x_dset0' of the
    container ``fname``, restricts them to ``blps_to_keep``, asserts that the
    data (and every statistic, when a ``stats_array`` is present) are finite,
    and returns the output of ``perform_average``.
    """
    container = PSpecContainer(fname, keep_open=False)
    spectra = container.get_pspec(group='dset0', psname='dset0_x_dset0')
    spectra.select(blpairs=blps_to_keep, inplace=True)
    for spw in spectra.spw_array:
        assert np.isfinite(spectra.data_array[spw]).all()
        if hasattr(spectra, 'stats_array'):
            for stat_name in spectra.stats_array:
                assert np.isfinite(spectra.stats_array[stat_name][spw]).all()
    return perform_average(spectra, error_weights=error_weights)


def extract_spectra(filelist, avg_baselines=True, error_weights="P_N", ncpus_per_job=2):
    """Load, clean, optionally average and merge the spectra in ``filelist``.

    Steps: read the spectra stored under group 'dset0' / name 'dset0_x_dset0'
    of every container, discard flagged baseline pairs and flagged files,
    assert that all remaining values are finite, optionally average within
    redundant groups (``perform_average``) and merge everything into a single
    object with ``merge_ps_list``.

    Parameters
    ----------
    filelist : list of str
        Paths of the power-spectrum containers to read.
    avg_baselines : bool
        If True, average redundant baseline pairs before merging.
    error_weights : str or None
        Passed through to ``perform_average``.
    ncpus_per_job : int
        Currently unused; only referenced by the disabled parallel variant of
        the averaging step.

    Returns
    -------
    object
        The merged power-spectrum object.

    Raises
    ------
    ValueError
        If ``filelist`` is empty, or if every file is discarded as flagged.
    """
    if len(filelist) == 0:
        raise ValueError('extract_spectra received an empty file list.')
    print('extracting spectra...')
    uvp_list = []
    for pscf in tqdm.tqdm_notebook(filelist):
        psc = PSpecContainer(pscf, keep_open=False)
        uvp_list.append(psc.get_pspec(group='dset0', psname='dset0_x_dset0'))
    print('discarding flagged blpairs')
    uvp_list = discard_flagged_blpairs(uvp_list)
    print('discarding flagged times')
    uvp_list = discard_flagged_times(uvp_list)
    if len(uvp_list) == 0:
        # Fail with a clear message instead of recursing forever in the merge.
        raise ValueError('extract_spectra: every file was discarded as flagged or non-finite.')
    print('checking for NaNs')
    for uvpt in tqdm.tqdm_notebook(uvp_list):
        for spw in uvpt.spw_array:
            assert np.all(np.isfinite(uvpt.data_array[spw]))
            if hasattr(uvpt, 'stats_array'):
                for stat in uvpt.stats_array:
                    assert np.all(np.isfinite(uvpt.stats_array[stat][spw]))
    if avg_baselines:
        print('averaging baselines...')
        # A joblib-parallel version of this loop (ncpu // ncpus_per_job jobs)
        # was tried and is disabled; the averaging runs serially.
        for i in range(len(uvp_list)):
            uvp_list[i] = perform_average(uvp_list[i], error_weights)
    print('merging spectra...')
    uvp = merge_ps_list(uvp_list)
    # Release the per-file objects before returning the merged result.
    del uvp_list
    gc.collect()
    return uvp



In [None]:
# Auto power spectra from the sum files: one extract_spectra job per
# extension, without redundant averaging, collected into a dict keyed by
# extension.
exts = list(psc_files_sum)
auto_sum_jobs = (
    delayed(extract_spectra)(psca_files_sum[ext], avg_baselines=False, error_weights=None, ncpus_per_job=2)
    for ext in exts
)
uvpa = dict(zip(exts, Parallel(n_jobs=2)(auto_sum_jobs)))
gc.collect()

In [None]:
# Auto power spectra from the diff files: one extract_spectra job per
# extension, without redundant averaging, collected into a dict keyed by
# extension.
exts = list(psc_files_sum)
auto_diff_jobs = (
    delayed(extract_spectra)(psca_files_diff[ext], avg_baselines=False, error_weights=None, ncpus_per_job=2)
    for ext in exts
)
uvpda = dict(zip(exts, Parallel(n_jobs=2)(auto_diff_jobs)))
gc.collect()

In [None]:
# Extensions in the order in which they are stored in psc_files_sum.
exts = list(psc_files_sum)

In [None]:
from hera_pspec.uvpspec import UVPSpec

In [None]:

# Redundantly averaged cross power spectra from the sum files, cached on disk:
# read the cached files when all of them exist, otherwise compute the spectra
# and write the cache.
fnames = [f'{data_path}/pspec_12_days_sum_{label}_{ext}.hdf5' for ext in exts]
if all(os.path.exists(fn) for fn in fnames):
    uvp_avg = {ext: UVPSpec() for ext in exts}
    for ext, fn in zip(exts, fnames):
        uvp_avg[ext].read_hdf5(fn)
else:
    sum_spectra = Parallel(n_jobs=2)(
        delayed(extract_spectra)(psc_files_sum[ext], avg_baselines=True, error_weights='P_N', ncpus_per_job=2)
        for ext in exts
    )
    uvp_avg = dict(zip(exts, sum_spectra))
    gc.collect()
    # Write each spectrum to its cache path. Pair the extensions with fnames:
    # iterating over the uvp_avg dict yields its keys (the extension names),
    # which would be used as output paths and never be found again as cache.
    for ext, fname in zip(exts, fnames):
        uvp_avg[ext].write_hdf5(fname)

In [None]:
# Redundantly averaged cross power spectra from the diff files, cached on
# disk: read the cached files when all of them exist, otherwise compute the
# spectra and write the cache.
fnames = [f'{data_path}/pspec_12_days_diff_{label}_{ext}.hdf5' for ext in exts]
cache_complete = np.all([os.path.exists(fn) for fn in fnames])
if cache_complete:
    uvpd_avg = {}
    for ext in exts:
        uvpd_avg[ext] = UVPSpec()
    for ext, fn in zip(exts, fnames):
        uvpd_avg[ext].read_hdf5(fn)
else:
    diff_jobs = (
        delayed(extract_spectra)(psc_files_diff[ext], avg_baselines=True, error_weights='P_N', ncpus_per_job=2)
        for ext in exts
    )
    uvpd_avg = dict(zip(exts, Parallel(n_jobs=2)(diff_jobs)))
    gc.collect()
    for ext, fname in zip(exts, fnames):
        uvpd_avg[ext].write_hdf5(fname)

In [None]:
# plot waterfalls of autos
def plot_auto_power_spectra_waterfalls(uvpt, spw=0, nblp_per_row=8, dynamic_range=1e9):
    """Plot delay-vs-time waterfalls of the (pI, pI) spectra, one panel per blpair.

    One figure is created for every ``nblp_per_row`` baseline pairs. The real
    part of the data is shown on a logarithmic colour scale that is shared by
    all panels.

    Parameters
    ----------
    uvpt : object
        Power-spectrum object to plot.
    spw : int
        Spectral window to plot.
    nblp_per_row : int
        Number of panels per figure (values of 1 are supported).
    dynamic_range : float
        Ratio between the upper and lower limit of the colour scale. The upper
        limit is the maximum absolute value in the spectral window, rounded to
        the nearest power of ten.
    """
    # Converts the values on the time axis to LST in hours for the tick labels.
    lst_func = interp1d(uvpt.time_avg_array, uvpt.lst_avg_array * 12 / np.pi)
    vmax = 10 ** np.round(np.log10(np.max(np.abs(uvpt.data_array[spw]))))
    vmin = vmax / dynamic_range
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    dlys_ns = uvpt.get_dlys(spw) * 1e9
    # [left, right, bottom, top]: the time axis runs downwards.
    extent_tdelay = [dlys_ns.min(), dlys_ns.max(), uvpt.time_avg_array.max(), uvpt.time_avg_array.min()]
    blpairs = uvpt.get_blpairs()
    for blpc in range(0, uvpt.Nblpairs, nblp_per_row):
        # squeeze=False always returns a 2-D array of axes, so indexing also
        # works when there is a single panel per row.
        fig, axgrid = plt.subplots(1, nblp_per_row, squeeze=False)
        axarr = axgrid[0]
        fig.set_size_inches(36, 6)
        cbax = fig.add_axes([0.915, 0.1, 0.005, 0.8])
        for i in range(nblp_per_row):
            blpind = blpc + i
            if blpind >= uvpt.Nblpairs:
                # Hide the unused panels of the last figure.
                axarr[i].axis('off')
                continue
            blp = blpairs[blpind]
            k = (spw, blp, ('pI', 'pI'))
            data = uvpt.get_data(k)
            plt.sca(axarr[i])
            cm = plt.imshow(data.real, norm=LogNorm(vmin, vmax), cmap='inferno', interpolation='nearest', aspect='auto', extent=extent_tdelay)
            # Keep only the ticks inside the time range so that lst_func is
            # never evaluated outside of its interpolation interval.
            plt.gca().set_yticks([t for t in plt.gca().get_yticks() if float(t) >= extent_tdelay[-1] and float(t) <= extent_tdelay[-2]])
            plt.xlim(-2000, 2000)
            plt.gca().tick_params(labelsize=12)
            if i == 0:
                plt.ylabel('LST [hours]', fontsize=14)
                plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])
                cbar = fig.colorbar(cm, orientation='vertical', cax=cbax)
                cbax.yaxis.set_ticks_position('right')
                cbar.ax.set_ylabel(f'P($\\tau$) [${uvpt.units}$]', rotation=90, fontsize=14)
                cbax.tick_params(labelsize=16)
            else:
                plt.gca().set_yticklabels(['' for t in plt.gca().get_yticks()])
            plt.xlabel('$\\tau$ [ns]', fontsize=14)
            plt.title(f'{blp} \n {freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz \n (pI, pI)', fontsize=16)

In [None]:
# Waterfalls of the auto spectra for every requested spectral window.
for spw in spws:
    plot_auto_power_spectra_waterfalls(
        uvpa['foreground_filled'],
        spw=spw,
        nblp_per_row=max_plots_per_row,
        dynamic_range=1e10,
    )

In [None]:
def plot_auto_power_spectra_tavg(uvpt, uvptd=None, spw=0, normalize=True, dynamic_range=1e11, label_outliers=True, outlier_delay=700., outlier_sigma=5.):
    """Plot time-averaged spectra versus delay, one line per baseline pair.

    Both inputs are copied before they are time-averaged, so the objects
    passed by the caller are left unchanged.

    Parameters
    ----------
    uvpt : object
        Power-spectrum object to plot (solid lines).
    uvptd : object, optional
        Second object plotted with dashed lines in matching colours, using the
        same normalisation as ``uvpt``.
    spw : int
        Spectral window to plot.
    normalize : bool
        If True, divide every spectrum by its value at the delay closest to
        zero.
    dynamic_range : float
        Ratio between the upper and lower y-axis limit.
    label_outliers : bool
        If True, annotate baseline pairs whose absolute value at
        ``outlier_delay`` deviates from the median by at least
        ``outlier_sigma`` median absolute deviations.
    outlier_delay : float
        Delay in ns at which the outlier statistic is evaluated.
    outlier_sigma : float
        Threshold of the outlier statistic.
    """
    vmax = 10 ** np.round(np.log10(np.max(np.abs(uvpt.data_array[spw]))))
    vmin = vmax / dynamic_range
    uvpt = copy.deepcopy(uvpt)
    uvpt.average_spectra(time_avg=True)
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    if uvptd is not None:
        # Copy first, exactly as for uvpt: averaging the caller's object in
        # place would destroy its time axis for all later cells.
        uvptd = copy.deepcopy(uvptd)
        uvptd.average_spectra(time_avg=True)
    fig, axarr = plt.subplots(1, uvpt.Npols)
    fig.set_size_inches(12 * uvpt.Npols, 12)
    dlys_ns = uvpt.get_dlys(spw) * 1e9
    zero_delay_index = np.argmin(np.abs(uvpt.get_dlys(spw)))
    outlier_index = np.argmin(np.abs(dlys_ns - outlier_delay))
    for p, pp in enumerate(uvpt.get_polpairs()):
        if uvpt.Npols > 1:
            plt.sca(axarr[p])
        distribution = []
        blpkeys = []
        colors = []
        # Defined up front so the y-limits below also work without blpairs.
        norm_factor = 1.0
        for blp in uvpt.get_blpairs():
            k = (spw, blp, pp)
            data = uvpt.get_data(k).real[0]
            if normalize:
                norm_factor = data[zero_delay_index]
            else:
                norm_factor = 1.0
            # Out-of-place division: never write back into the stored data.
            data = data / norm_factor
            l0 = plt.plot(dlys_ns, data)[0]
            colors.append(l0.get_color())
            blpkeys.append(k)
            distribution.append(np.abs(data[outlier_index]))
            if uvptd is not None:
                plt.plot(uvptd.get_dlys(spw) * 1e9, uvptd.get_data(k).real[0] / norm_factor, ls='--', color=l0.get_color())
        # label outliers from distribution.
        distribution = np.asarray(distribution)
        if label_outliers:
            mad = np.sqrt(np.median(np.abs(distribution - np.median(distribution)) ** 2.))
            zscore = (distribution - np.median(distribution)) / mad
            to_label = np.where(zscore >= outlier_sigma)[0]
            for i, label_ind in enumerate(to_label):
                # Shift successive labels by 100 ns so that they do not overlap.
                plt.text(outlier_delay + 100 * i, distribution[label_ind], blpkeys[label_ind][1][0], color=colors[label_ind], fontsize=12)
        # NOTE(review): the limits are scaled by the normalisation of the last
        # plotted baseline pair -- confirm that this is intended.
        plt.ylim(vmin / norm_factor, vmax / norm_factor)
        plt.xlim(-100, 2000)
        plt.yscale('log')
        plt.grid()
        plt.gca().tick_params(labelsize=16)
        plt.xlabel('$\\tau$ [ns]', fontsize=18)
        if not normalize:
            plt.ylabel(f'$P(\\tau) [{uvpt.units}]$', fontsize=18)
        else:
            plt.ylabel(f'$P(\\tau) / P(0)$ [unitless]', fontsize=18)

        plt.title(f'{freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz {pp}', fontsize=16)

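Compare the time-averaged auto power spectra (solid lines) with the noise, shown by the corresponding spectra from the diff files (dashed lines). Note that the autos are not cross-multiplied in baseline and time, so they have a noise bias.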

In [None]:
# Time-averaged auto spectra (sum: solid, diff: dashed) for every spectral window.
for spw in spws:
    plot_auto_power_spectra_tavg(
        uvpt=uvpa['foreground_filled'],
        uvptd=uvpda['foreground_filled'],
        spw=spw,
        outlier_sigma=6,
        outlier_delay=700,
    )

Time for some wedge plots. For each spectral window, plot each field separately, then the average over all fields, and finally the average over all LSTs.

In [None]:
def select_field(uvpt, fields, time_average=True):
    """Select the times of ``uvpt`` whose LST falls inside any of ``fields``.

    Parameters
    ----------
    uvpt : object
        Power-spectrum object; it is not modified.
    fields : sequence of (start, stop)
        LST ranges in hours, both ends inclusive. The union of all ranges is
        selected.
    time_average : bool
        If True, the selection is averaged in time with 'P_N' error weights,
        propagating every statistic in ``uvpt.stats_array``.

    Returns
    -------
    object
        A new object holding the selected (and optionally averaged) spectra.

    Raises
    ------
    ValueError
        If ``fields`` is empty.
    """
    if len(fields) == 0:
        raise ValueError('select_field requires at least one (start, stop) LST range.')
    # lst_avg_array is in radians; convert to hours once.
    lst_hours = np.asarray(uvpt.lst_avg_array) * 12 / np.pi
    time_selection = np.zeros(lst_hours.shape, dtype=bool)
    for field in fields:
        time_selection |= (lst_hours >= field[0]) & (lst_hours <= field[1])

    output = uvpt.select(times=uvpt.time_avg_array[time_selection], inplace=False)
    if time_average:
        output = output.average_spectra(time_avg=time_average, inplace=False, error_weights='P_N', error_field=list(uvpt.stats_array.keys()))
    return output

In [None]:
import copy
def plot_wedges(uvpt, polpair=('pI', 'pI'), spw=0, dynamic_range=1e10, max_bl_angle=75., label=''):
    """Plot delay wedges for each LST field, for all fields and for all LSTs.

    One figure with ``len(lst_fields) + 2`` panels is created: one panel per
    entry of the module-level ``lst_fields`` / ``field_labels``, one for the
    union of all fields and one for the time average over all LSTs.

    Parameters
    ----------
    uvpt : object
        Power-spectrum object to plot; it is not modified.
    polpair : tuple
        Polarization pair that is plotted and shown in the titles.
    spw : int
        Spectral window to plot.
    dynamic_range : float
        Ratio between the upper and lower limit of the colour scale.
    max_bl_angle : float
        Only redundant groups whose angle, as returned by
        ``utils.get_blvec_reds``, is at most this value are plotted.
    label : str
        Prefix of every panel title.
    """
    # Colour limits are in log10 units because delay_wedge is called with log10=True.
    vmax = np.round(np.log10(np.max(np.abs(uvpt.data_array[spw])))) - 1
    vmin = vmax - np.log10(dynamic_range)
    fig, axarr = plt.subplots(1, len(lst_fields) + 2)
    fig.set_size_inches(36, 9)
    # Copy the colormap before editing it so the shared instance is untouched.
    cmap = copy.copy(cmaps.inferno)
    cmap.set_bad(color='k')
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    blp_groups, _, red_bl_ang, _ = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)
    blps_to_keep = []
    for blg, ang in zip(blp_groups, red_bl_ang):
        if ang <= max_bl_angle:
            blps_to_keep.extend(blg)

    def draw_panel(ax, uvp_panel, panel_name):
        # Draw one wedge of uvp_panel (a private copy) on ax.
        title = f'{label}: {panel_name} \n {freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz \n {polpair}'
        # only select baselines with angle less than max_bl_angle.
        uvp_panel.select(blpairs=blps_to_keep, inplace=True)
        plt.sca(ax)
        # pol follows the polpair argument, so the plot matches its title.
        pspecplot.delay_wedge(uvp=uvp_panel, ax=ax, pol=polpair, spw=spw, component='real', vmin=vmin, vmax=vmax, error_weights='P_N',
                              fold=True, rotate=True, cmap=cmap, log10=True, colorbar=True, title=title, horizon_lines=True)
        ax.set_ylim(0, 2000)

    for fieldnum, (fieldlabel, field) in enumerate(zip(field_labels, lst_fields)):
        draw_panel(axarr[fieldnum], select_field(uvpt, [field]), f'Field {fieldlabel}')
    # now select the union of all fields
    draw_panel(axarr[len(lst_fields)], select_field(uvpt, lst_fields), 'All Fields')
    # Now do all LSTs
    # NOTE(review): unlike select_field, this average does not pass error_field -- confirm that this is intended.
    draw_panel(axarr[len(lst_fields) + 1],
               uvpt.average_spectra(time_avg=True, inplace=False, error_weights='P_N'),
               'All LSTs')

In [None]:
# One row of wedge plots per spectral window and per extension.
for spw in spws:
    for ext, spectra in uvp_avg.items():
        plot_wedges(spectra, label=ext, spw=spw)

In [None]:
def plot_spherical_power_spectra(uvpt, uvptd=None, polpair=('pI', 'pI'), spw=0, dynamic_range=1e10, max_bl_angle=75., k_bin_multiplier=2, delta_sq=False, axarr=None,
                                 fig=None, scolor='k', dcolor='grey', label='', legend=False):
    """Plot spherically averaged power spectra per LST field, for all fields and for all LSTs.

    The panels follow the module-level ``lst_fields`` / ``field_labels``: one
    panel per field, one for the union of all fields and one for all LSTs.
    Positive values are drawn as filled markers and non-positive values as
    open markers (absolute value); the dashed grey line is the 'P_SN'
    statistic multiplied by two.

    Parameters
    ----------
    uvpt : object
        Power-spectrum object (sum data); it is not modified.
    uvptd : object, optional
        Second object (diff data), drawn in ``dcolor`` with a small offset in k.
    polpair : tuple
        Polarization pair to plot.
    spw : int
        Spectral window to plot.
    dynamic_range : float
        Ratio between the upper and lower y-axis limit.
    max_bl_angle : float
        Only redundant groups whose angle, as returned by
        ``utils.get_blvec_reds``, is at most this value are used.
    k_bin_multiplier : int
        Every ``k_bin_multiplier``-th value of the upper half of the k_parallel
        axis is used as a k-bin centre.
    delta_sq : bool
        If True, multiply the values by k^3 / (2 pi^2).
    axarr, fig : optional
        Existing axes and figure to draw into; created when either is None.
    scolor, dcolor : str
        Colours of the ``uvpt`` and ``uvptd`` points.
    label : str
        Prefix of the legend entries.
    legend : bool
        If True, draw a legend in every panel.

    Returns
    -------
    (fig, axarr)
        The figure and axes that were drawn into.
    """
    vmax = 10 ** (np.round(np.log10(np.max(np.abs(uvpt.data_array[spw])))) - 1)
    vmin = vmax / dynamic_range
    if fig is None or axarr is None:
        fig, axarr = plt.subplots(1, len(lst_fields) + 2)
        fig.set_size_inches(36, 9)
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    blp_groups, _, red_bl_ang, _ = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)
    blps_to_keep = []
    for blg, ang in zip(blp_groups, red_bl_ang):
        if ang <= max_bl_angle:
            blps_to_keep.extend(blg)
    # SET K BINS
    dlys = uvpt.get_dlys(spw)
    kparas = uvpt.get_kparas(spw)
    # Conversion factor from k to delay in ns, used for the upper x-axis.
    cfactor = np.nanmean(dlys / kparas) * 1e9
    kbins = kparas[len(kparas) // 2::k_bin_multiplier]
    kbin_widths = np.full(len(kbins), np.mean(np.diff(kbins)) / 2)
    # dsq multiplier
    if delta_sq:
        dsq = kbins ** 3 / (2 * np.pi ** 2)
        vmax *= dsq[dsq > 0.0].min()
        vmin *= dsq[dsq > 0.0].min()
    else:
        dsq = 1.0
    # noise multiplier
    dsqn = dsq * 2

    # Redshift of the 1420 MHz line at the two band edges, used in the titles.
    z_lo = 1420e6 / freq_range[1] - 1
    z_hi = 1420e6 / freq_range[0] - 1

    def make_title(suffix):
        # Panel title; only the last line differs between panels.
        return (f'SPW {spw}, {polpair}, z={.5*z_lo + .5*z_hi:.2f}'
                f'({z_lo:.2f}-{z_hi:.2f})\n $\\nu$ = {freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz \n {suffix}')

    def single_plot(uvpt, uvptd, title):
        # Draw one panel on the current axes; uvpt / uvptd are private copies.
        # only select baselines with angle less than max_bl_angle.
        uvpt.select(blpairs=blps_to_keep, inplace=True)
        # set wedge to have infinite variance.
        # NOTE(review): m=1/0.299, b=200 presumably mean the horizon delay
        # (ns per metre) plus a 200 ns buffer -- confirm.
        uvpt.set_stats_slice('P_N', m=1/0.299, b=200, above=False, val=1e40)
        # spherical average
        uvpt = grouping.spherical_average(uvpt, kbins=kbins, bin_widths=kbin_widths,
                                          error_weights='P_N')
        to_plot = [uvpt]
        if uvptd is not None:
            uvptd.select(blpairs=blps_to_keep, inplace=True)
            uvptd.set_stats_slice('P_N', m=1/0.299, b=200, above=False, val=1e40)
            uvptd = grouping.spherical_average(uvptd, kbins=kbins, bin_widths=kbin_widths,
                                               error_weights='P_N')
            to_plot += [uvptd]
        # The diff points are shifted slightly in k so that they stay visible.
        offsets = [0., np.mean(np.diff(kbins) / 10.)]
        # Each data set gets its own legend suffix ('sum' or 'diff').
        for color, kind, uvpplot, offset in zip([scolor, dcolor], ['sum', 'diff'], to_plot, offsets):
            k = (spw, uvpplot.get_blpairs()[0], polpair)
            d = np.real(uvpplot.get_data(k)[0])
            ltz = d <= 0.
            gtz = d > 0.
            plt.errorbar(kbins[gtz] + offset, np.abs(d*dsq)[gtz], (dsqn*uvpplot.get_stats('P_SN', k)[0])[gtz],
                         ls='none', marker='o', color=color, label=f'{label}: {kind}', capsize=2)
            plt.errorbar(kbins[ltz] + offset, np.abs(dsq*d)[ltz], (dsqn*uvpplot.get_stats('P_SN', k)[0])[ltz],
                         ls='none', marker='o', markerfacecolor='none', color=color, capsize=2)
            plt.plot(kbins, (dsqn*uvpplot.get_stats('P_SN', k)[0]), color='grey', ls='--')
        plt.yscale('log')
        plt.grid()
        plt.ylim(vmin, vmax)
        plt.title(title, fontsize=18)
        plt.xlabel('k [$h$Mpc$^{-1}$]', fontsize=16)
        if delta_sq:
            plt.ylabel('$\\Delta^2$ [mK$^2$]', fontsize=16)
        else:
            plt.ylabel(f'${uvpt.units}$')
        plt.gca().tick_params(labelsize=14)
        plt.xlim(0, 1.4)
        if legend:
            plt.legend()
        # Secondary x-axis on top showing the delay that corresponds to each k tick.
        ax1 = plt.gca()
        ax2 = ax1.twiny()
        dticks = ax1.get_xticks()
        dticklabels = [f'{int(cfactor * k)}' for k in dticks]
        ax2.set_xticks(dticks)
        ax2.set_xticklabels(dticklabels)
        # The twin axis has independent limits; match them to the k axis so
        # that the delay ticks line up with the k ticks they were derived from.
        ax2.set_xlim(ax1.get_xlim())
        ax2.set_xlabel('$\\tau$ [ns]')

    for fieldnum, (fieldlabel, field) in enumerate(zip(field_labels, lst_fields)):
        # select fields
        plt.sca(axarr[fieldnum])
        uvpt_select = select_field(uvpt, [field])
        if uvptd is not None:
            uvptd_select = select_field(uvptd, [field])
        else:
            uvptd_select = None
        single_plot(uvpt_select, uvptd_select, make_title(f'Field {fieldlabel}'))

    # now select the union of all fields
    plt.sca(axarr[len(lst_fields)])
    uvpt_select = select_field(uvpt, lst_fields)
    if uvptd is not None:
        uvptd_select = select_field(uvptd, lst_fields)
    else:
        uvptd_select = None
    single_plot(uvpt_select, uvptd_select, make_title(' All Fields'))

    # Now do all LSTs: a single field spanning the full LST range of uvpt
    # (the same range is applied to uvptd).
    plt.sca(axarr[len(lst_fields) + 1])
    all_lsts = [(uvpt.lst_avg_array.min() * 12 / np.pi, uvpt.lst_avg_array.max() * 12 / np.pi)]
    uvpt_select = select_field(uvpt, all_lsts)
    if uvptd is not None:
        uvptd_select = select_field(uvptd, all_lsts)
    else:
        uvptd_select = None
    single_plot(uvpt_select, uvptd_select, make_title(' All LSTs'))
    return fig, axarr

In [None]:
# Map the short names used by the plotting cells to the file extensions that
# are present; two naming schemes are supported.
if 'foreground_filled' in exts:
    ext_labels = dict(
        filled='foreground_filled',
        res='foreground_res.filled_flags',
        model='foreground_model.filled_flags',
    )
else:
    ext_labels = dict(
        filled='foreground_filled.res_flags.filled',
        res='foreground_res.filled',
        model='foreground_model.res_flags.filled',
    )

In [None]:
# Spherical power spectra of the filled data (sum and diff) per spectral window.
for spw in spws:
    filled_ext = ext_labels['filled']
    plot_spherical_power_spectra(uvp_avg[filled_ext], uvptd=uvpd_avg[filled_ext], spw=spw)

In [None]:
# Same as the previous cell, with the delta_sq=True scaling.
for spw in spws:
    filled_ext = ext_labels['filled']
    plot_spherical_power_spectra(uvp_avg[filled_ext], uvptd=uvpd_avg[filled_ext], spw=spw, delta_sq=True)

In [None]:

# Overlay the filled data, the foreground residual and the foreground model in
# the same panels. The first call creates the figure (fig and axarr are None);
# the later calls draw into it.
overlay_specs = (
    ('filled', dict(label='Interpolated')),
    ('res', dict(scolor='r', label='Foreground Resid')),
    ('model', dict(scolor='orange', label='Foreground Model', legend=True)),
)
for spw in spws:
    fig, axarr = None, None
    for ext_key, plot_kwargs in overlay_specs:
        fig, axarr = plot_spherical_power_spectra(uvp_avg[ext_labels[ext_key]], spw=spw, fig=fig, axarr=axarr, **plot_kwargs)
    #for ax in axarr:
    #    ax.set_ylim(1e8, 1e13)

In [None]:
# Same overlay as the previous cell, with the delta_sq=True scaling. The first
# call creates the figure (fig and axarr are None); the later calls draw into it.
overlay_specs_dsq = (
    ('filled', dict(label='Interpolated')),
    ('res', dict(scolor='r', label='Foreground Resid')),
    ('model', dict(scolor='orange', label='Foreground Model', legend=True)),
)
for spw in spws:
    fig, axarr = None, None
    for ext_key, plot_kwargs in overlay_specs_dsq:
        fig, axarr = plot_spherical_power_spectra(uvp_avg[ext_labels[ext_key]], spw=spw, fig=fig, axarr=axarr, delta_sq=True, **plot_kwargs)
    #for ax in axarr:
    #    ax.set_ylim(1e3, 1e11)