In [None]:
# Notebook setup: library imports, IPython configuration, and display tweaks.
import numpy as np
import matplotlib.pyplot as plt
from hera_cal import frf
import glob
import os
from copy import deepcopy
from hera_cal import redcal
from IPython.display import display, HTML
from hera_cal.io import HERAData
from matplotlib.colors import LogNorm
from hera_pspec import utils
# Turn off jedi-based completion in the IPython completer.
%config Completer.use_jedi = False
from scipy.interpolate import interp1d
from hera_pspec.container import PSpecContainer
import hera_pspec.plot as pspecplot
import copy
from matplotlib import cm as cmaps
from hera_pspec import grouping

import tqdm
# Render figures inline at retina resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Widen the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Pipeline configuration is passed in through environment variables.
JD = os.environ['JULIANDATE']
data_path = os.environ['DATA_PATH']
label = os.environ['LABEL']
spws = os.environ['SPWS'].split(',')
# LST_FIELDS is a comma-separated list of "start~stop" LST ranges; parse each into a float tuple.
lst_fields = []
for field_spec in os.environ['LST_FIELDS'].split(','):
    bounds = field_spec.split('~')
    lst_fields.append((float(bounds[0]), float(bounds[1])))
grp_skip = int(os.environ['GRP_SKIP'])
blp_skip = int(os.environ['BLP_SKIP'])
field_labels = os.environ['FIELD_LABELS'].split(',')
max_plots_per_row = int(os.environ['MAX_PLOTS_PER_ROW'])
spws = list(map(int, spws))
# Echo the parsed configuration so the rendered report records it.
for name, value in (('JD', JD), ('data_path', data_path), ('label', label), ('spws', spws),
                    ('lst_fields', lst_fields), ('grp_skip', grp_skip), ('blp_skip', blp_skip),
                    ('field_labels', field_labels), ('max_plots_per_row', max_plots_per_row)):
    print(f'{name} = "{value}"')

In [None]:
# Show the calendar date corresponding to the Julian date being processed.
from astropy.time import Time
utc = Time(JD, format='jd').datetime
print('Date: {}-{}-{}'.format(utc.month, utc.day, utc.year))

In [None]:
# Locate the time-averaged, crosstalk-filtered pseudo-Stokes power spectra for sums and diffs.
found_files = {}
for file_kind in ('sum', 'diff'):
    print(f'Looking for {file_kind} power-spectrum containers in', data_path, 'on JD', JD)
    pattern = os.path.join(data_path, f'zen.{JD}.?????.{file_kind}.{label}.xtalk_filtered_pstokes.tavg.pspec.h5')
    found_files[file_kind] = sorted(glob.glob(pattern))
    print('Found {} files.'.format(len(found_files[file_kind])))
psc_files_sum, psc_files_diff = found_files['sum'], found_files['diff']
# Noise (diff) comparisons are only possible when diff containers exist.
diffs = len(psc_files_diff) > 0

In [None]:
# Locate the time-averaged, foreground-filled auto-correlation power spectra for sums and diffs.
print('Looking for sum power-spectrum containers in', data_path, 'on JD', JD)
psca_files_sum = sorted(glob.glob(os.path.join(data_path, f'zen.{JD}.?????.sum.{label}.autos.foreground_filled_pstokes.tavg.pspec.h5')))
# Report the number of auto-spectrum sum files just found.
print('Found {} files.'.format(len(psca_files_sum)))
print('Looking for diff power-spectrum containers in', data_path, 'on JD', JD)
psca_files_diff = sorted(glob.glob(os.path.join(data_path, f'zen.{JD}.?????.diff.{label}.autos.foreground_filled_pstokes.tavg.pspec.h5')))
print('Found {} files.'.format(len(psca_files_diff)))
# Diffs are usable only if BOTH the cross- and auto-spectrum diff files exist.
# (`len(...) > 0 & diffs` would parse as `len(...) > (0 & diffs)` because `&`
# binds tighter than `>`, silently ignoring the earlier check.)
diffs = diffs and len(psca_files_diff) > 0



In [None]:
# Load every power-spectrum container. Cross- and auto-spectrum files (and their
# diff counterparts) are paired by position in their sorted file lists, so fail
# early with a clear message if a paired list is too short to index.
if len(psca_files_sum) < len(psc_files_sum):
    raise ValueError(f'Found {len(psc_files_sum)} cross-spectrum sum files but only '
                     f'{len(psca_files_sum)} auto-spectrum sum files.')
if diffs and min(len(psc_files_diff), len(psca_files_diff)) < len(psc_files_sum):
    raise ValueError(f'Found {len(psc_files_sum)} sum files but only {len(psc_files_diff)} cross-spectrum '
                     f'and {len(psca_files_diff)} auto-spectrum diff files.')


def _load_dset0_pspec(path):
    """Return the 'dset0_x_dset0' power spectrum stored under group 'dset0' in ``path``."""
    psc = PSpecContainer(path, keep_open=False)
    return psc.get_pspec(group='dset0', psname='dset0_x_dset0')


uvp_list = []
uvpd_list = []
uvpa_list = []
uvpda_list = []

for unum, pscf in enumerate(tqdm.tqdm(psc_files_sum)):
    uvp_list.append(_load_dset0_pspec(pscf))
    uvpa_list.append(_load_dset0_pspec(psca_files_sum[unum]))
    if diffs:
        uvpd_list.append(_load_dset0_pspec(psc_files_diff[unum]))
        uvpda_list.append(_load_dset0_pspec(psca_files_diff[unum]))

In [None]:
def discard_flagged_blpairs(uvpt_list):
    """Drop baseline-pairs that have no unflagged data in any of the given power spectra.

    A baseline-pair (taken from the first object's ``get_blpairs()``) is kept if at
    least one object has a nonzero ``integration_array[0][..., 0]`` entry for it
    (presumably spectral window 0, first polarization pair -- confirm against
    UVPSpec conventions).

    Parameters
    ----------
    uvpt_list : list
        UVPSpec-like objects supporting ``get_blpairs``, ``blpair_to_indices``,
        ``integration_array`` and ``select``.

    Returns
    -------
    list
        ``select(blpairs=..., inplace=False)`` results, one per input object and
        in the same order. An empty input returns an empty list.
    """
    if len(uvpt_list) == 0:
        return []
    bls_to_keep = []
    for blp in uvpt_list[0].get_blpairs():
        # Check each object separately so differing time counts per file are fine.
        has_data = any(
            not np.allclose(uvpt.integration_array[0][uvpt.blpair_to_indices(blp), 0], 0.0)
            for uvpt in uvpt_list
        )
        if has_data:
            bls_to_keep.append(blp)
        else:
            print(f'{blp} has no integrations! Discarding...')
    return [uvpt.select(blpairs=bls_to_keep, inplace=False) for uvpt in uvpt_list]
    

def discard_flagged_times(uvpt_list):
    """Keep only power spectra with no flagged integrations and no non-finite values.

    An object survives when (1) no entry of ``integration_array[0][:, 0]`` is
    (close to) zero, and (2) for every spw in ``spw_array`` its ``data_array`` is
    finite, and every entry of ``stats_array`` is finite too when that attribute
    exists. Both arrays are checked so that later finiteness checks on data and
    stats hold for everything returned.

    Parameters
    ----------
    uvpt_list : list
        UVPSpec-like objects.

    Returns
    -------
    list
        The surviving objects, in their original order.
    """
    def _is_clean(uvpt):
        # Any zero integration means (part of) this file is flagged.
        if np.any(np.isclose(uvpt.integration_array[0][:, 0], 0.0)):
            return False
        for spw in uvpt.spw_array:
            if not np.all(np.isfinite(uvpt.data_array[spw])):
                return False
            if hasattr(uvpt, 'stats_array'):
                for stat in uvpt.stats_array:
                    if not np.all(np.isfinite(uvpt.stats_array[stat][spw])):
                        return False
        return True

    return [uvpt for uvpt in uvpt_list if _is_clean(uvpt)]


def _filter_flagged(uvpt_list):
    """Drop fully flagged baseline-pairs, then whole files with flagged or non-finite content."""
    return discard_flagged_times(discard_flagged_blpairs(uvpt_list))


# Cross spectra (sum, then diff when available) followed by the autos in the same order.
uvp_list = _filter_flagged(uvp_list)
if diffs:
    uvpd_list = _filter_flagged(uvpd_list)
uvpa_list = _filter_flagged(uvpa_list)
if diffs:
    uvpda_list = _filter_flagged(uvpda_list)

In [None]:
# Sanity check: every spectrum that survived the flag filters must have finite data,
# and finite error statistics when a stats_array is attached.
for uvplist_t in [uvp_list, uvpd_list, uvpa_list, uvpda_list]:
    for uvpt in uvplist_t:
        for spw in uvpt.spw_array:
            assert np.all(np.isfinite(uvpt.data_array[spw])), f'non-finite data in spw {spw}'
            if hasattr(uvpt, 'stats_array'):
                for stat in uvpt.stats_array:
                    assert np.all(np.isfinite(uvpt.stats_array[stat][spw])), f'non-finite {stat} in spw {spw}'

In [None]:
# Average redundant cross-correlation baseline-pairs, file by file, for sums and diffs.
uvp_avg_list = []
uvpd_avg_list = []

for uvplist_t, uvpnewlist_t in ((uvp_list, uvp_avg_list), (uvpd_list, uvpd_avg_list)):
    for uvpt in tqdm.tqdm(uvplist_t):
        blp_groups = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)[0]
        # Exclude auto baseline-pairs (a with/without-autos comparison could be done later),
        # then drop any group left empty.
        blp_groups = [grp for grp in ([blp for blp in blpgrp if blp[0] != blp[1]] for blpgrp in blp_groups) if grp]
        # Discard baseline-pairs that do not belong to any averaging group.
        blps_to_keep = [blp for grp in blp_groups for blp in grp]
        uvpt = uvpt.select(blpairs=blps_to_keep, inplace=False)
        averaged = uvpt.average_spectra(blpair_groups=blp_groups, inplace=False, error_weights='P_N',
                                        error_field=list(uvpt.stats_array))
        uvpnewlist_t.append(averaged)

In [None]:
def merge_ps_list(ps_list):
    """Merge a list of power spectra into one with a balanced divide-and-conquer sum.

    The list is split in half recursively and the halves are combined with ``+``.
    Pieces of similar size get combined, instead of one small object being added
    to an ever-growing accumulator as in a left-to-right fold.

    Parameters
    ----------
    ps_list : sequence
        Objects supporting ``+`` (e.g. UVPSpec). Order is preserved: earlier
        elements are always on the left-hand side of ``+``.

    Returns
    -------
    object
        The merged result (the sole element itself when ``len(ps_list) == 1``).

    Raises
    ------
    ValueError
        If ``ps_list`` is empty; without this guard the recursion never terminates.
    """
    if len(ps_list) == 0:
        raise ValueError('merge_ps_list requires at least one power spectrum.')
    if len(ps_list) == 1:
        return ps_list[0]
    mid = len(ps_list) // 2
    return merge_ps_list(ps_list[:mid]) + merge_ps_list(ps_list[mid:])
        
        
# Collapse each per-file list into a single power-spectrum object.
uvp_avg = merge_ps_list(uvp_avg_list)
uvp = merge_ps_list(uvp_list)
uvpa_avg = merge_ps_list(uvpa_list)

# Diff products exist only when diff files were found.
uvpd_avg = uvpda_avg = uvpd = None
if diffs:
    uvpd_avg = merge_ps_list(uvpd_avg_list)
    uvpda_avg = merge_ps_list(uvpda_list)
    uvpd = merge_ps_list(uvpd_list)

# Release the per-file lists; only the merged objects are used from here on.
del uvp_avg_list, uvpd_avg_list, uvpa_list, uvpda_list, uvp_list, uvpd_list

In [None]:
# plot waterfalls of autos
def plot_auto_power_spectra_waterfalls(uvpt, spw=0, nblp_per_row=8, dynamic_range=1e9, polpair=('pI', 'pI')):
    """Plot delay-vs-time waterfalls of Re P(tau) for every baseline-pair.

    Panels are laid out ``nblp_per_row`` per figure on a shared logarithmic colour
    scale. Its maximum is the data maximum rounded to the nearest power of ten, and
    its minimum is ``dynamic_range`` times smaller. The first panel of each figure
    is labelled in LST hours and carries the colour bar.

    Parameters
    ----------
    uvpt : UVPSpec
        Power spectra to plot (used here for the auto-correlation spectra).
    spw : int
        Spectral window index.
    nblp_per_row : int
        Number of baseline-pair panels per figure.
    dynamic_range : float
        Ratio between the colour-scale maximum and minimum.
    polpair : tuple
        Polarization pair to plot; defaults to pseudo-Stokes I.
    """
    # Map time to LST in hours (lst_avg_array * 12 / pi converts radians to hours).
    lst_func = interp1d(uvpt.time_avg_array, uvpt.lst_avg_array * 12 / np.pi)
    vmax = 10 ** np.round(np.log10(np.max(np.abs(uvpt.data_array[spw]))))
    vmin = vmax / dynamic_range
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    freq_label = f'{freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz'
    pol_label = f'({polpair[0]}, {polpair[1]})'
    dlys_ns = uvpt.get_dlys(spw) * 1e9
    # extent = [left, right, bottom, top]: time increases downward.
    extent_tdelay = [dlys_ns.min(), dlys_ns.max(), uvpt.time_avg_array.max(), uvpt.time_avg_array.min()]
    blpairs = uvpt.get_blpairs()  # loop invariant; fetch once
    for blpc in range(0, uvpt.Nblpairs, nblp_per_row):
        # squeeze=False keeps axarr indexable even when nblp_per_row == 1.
        fig, axarr = plt.subplots(1, nblp_per_row, squeeze=False)
        axarr = axarr[0]
        fig.set_size_inches(36, 6)
        cbax = fig.add_axes([0.915, 0.1, 0.005, 0.8])
        for i in range(nblp_per_row):
            blpind = blpc + i
            if blpind >= uvpt.Nblpairs:
                continue
            blp = blpairs[blpind]
            data = uvpt.get_data((spw, blp, polpair))
            plt.sca(axarr[i])
            cm = plt.imshow(data.real, norm=LogNorm(vmin, vmax), cmap='inferno', interpolation='nearest', aspect='auto', extent=extent_tdelay)
            # Keep only y ticks that fall inside the plotted time range.
            plt.gca().set_yticks([t for t in plt.gca().get_yticks() if extent_tdelay[-1] <= float(t) <= extent_tdelay[-2]])
            plt.xlim(-2000, 2000)
            plt.gca().tick_params(labelsize=12)
            if i == 0:
                # First panel shows LST tick labels and drives the shared colour bar.
                plt.ylabel('LST [hours]', fontsize=14)
                plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])
                cbar = fig.colorbar(cm, orientation='vertical', cax=cbax)
                cbax.yaxis.set_ticks_position('right')
                cbar.ax.set_ylabel(f'P($\\tau$) [${uvpt.units}$]', rotation=90, fontsize=14)
                cbax.tick_params(labelsize=16)
            else:
                plt.gca().set_yticklabels(['' for t in plt.gca().get_yticks()])
            plt.xlabel('$\\tau$ [ns]', fontsize=14)
            plt.title(f'{blp} \n {freq_label} \n {pol_label}', fontsize=16)

In [None]:
# Auto-correlation waterfalls, one set of figures per requested spectral window.
for spw in spws:
    plot_auto_power_spectra_waterfalls(uvpa_avg, spw=spw, nblp_per_row=max_plots_per_row, dynamic_range=1e10)

In [None]:
def plot_auto_power_spectra_tavg(uvpt, uvptd=None, spw=0, normalize=True, dynamic_range=1e11, label_outliers=True, outlier_delay=700., outlier_sigma=5.):
    """Plot time-averaged power spectra of every baseline-pair, one panel per polarization pair.

    Sum spectra are solid lines. If ``uvptd`` is given, the matching diff spectra are
    overplotted as dashed lines in the same colour. When ``label_outliers`` is set,
    baseline-pairs whose |P| at the delay nearest ``outlier_delay`` ns has a robust
    z-score (deviation from the median divided by the median absolute deviation) of
    at least ``outlier_sigma`` are labelled with ``blp[0]``.

    Parameters
    ----------
    uvpt : UVPSpec
        Sum power spectra. Not modified: a deep copy is time-averaged.
    uvptd : UVPSpec or None
        Optional diff power spectra. Not modified: a deep copy is time-averaged.
    spw : int
        Spectral window index.
    normalize : bool
        Divide each spectrum by its value at the delay closest to zero.
    dynamic_range : float
        Ratio between the y-axis maximum and minimum.
    label_outliers : bool
        Whether to annotate outlying baseline-pairs.
    outlier_delay : float
        Delay in ns at which outliers are evaluated.
    outlier_sigma : float
        Robust z-score threshold for labelling.
    """
    vmax = 10 ** np.round(np.log10(np.max(np.abs(uvpt.data_array[spw]))))
    vmin = vmax / dynamic_range
    # average_spectra is called without inplace=False here, i.e. it modifies the
    # object it is called on; work on copies so the caller's objects stay intact.
    uvpt = copy.deepcopy(uvpt)
    uvpt.average_spectra(time_avg=True)
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    if uvptd is not None:
        uvptd = copy.deepcopy(uvptd)
        uvptd.average_spectra(time_avg=True)
    dlys = uvpt.get_dlys(spw)
    dlys_ns = dlys * 1e9
    zero_dly_index = np.argmin(np.abs(dlys))
    outlier_index = np.argmin(np.abs(dlys_ns - outlier_delay))
    fig, axarr = plt.subplots(1, uvpt.Npols)
    fig.set_size_inches(12 * uvpt.Npols, 12)
    for p, pp in enumerate(uvpt.get_polpairs()):
        if uvpt.Npols > 1:
            plt.sca(axarr[p])
        distribution = []
        blpkeys = []
        colors = []
        norm_factor = 1.0
        for blp in uvpt.get_blpairs():
            k = (spw, blp, pp)
            data = uvpt.get_data(k).real[0]
            norm_factor = data[zero_dly_index] if normalize else 1.0
            # Out-of-place division: `.real` may be a view into uvpt's data.
            data = data / norm_factor
            l0 = plt.plot(dlys_ns, data)[0]
            colors.append(l0.get_color())
            blpkeys.append(k)
            distribution.append(np.abs(data[outlier_index]))
            if uvptd is not None:
                plt.plot(uvptd.get_dlys(spw) * 1e9, uvptd.get_data(k).real[0] / norm_factor, ls='--', color=l0.get_color())
        # label outliers from distribution.
        distribution = np.asarray(distribution)
        if label_outliers:
            # sqrt(median(|x - med|^2)) equals the median absolute deviation.
            mad = np.sqrt(np.median(np.abs(distribution - np.median(distribution)) ** 2.))
            zscore = (distribution - np.median(distribution)) / mad
            to_label = np.where(zscore >= outlier_sigma)[0]
            for i, label_ind in enumerate(to_label):
                plt.text(outlier_delay + 100 * i, distribution[label_ind], blpkeys[label_ind][1][0], color=colors[label_ind], fontsize=12)
        # NOTE(review): with normalize=True this scales by the last baseline-pair's P(0).
        plt.ylim(vmin / norm_factor, vmax / norm_factor)
        plt.xlim(-100, 2000)
        plt.yscale('log')
        plt.grid()
        plt.gca().tick_params(labelsize=16)
        plt.xlabel('$\\tau$ [ns]', fontsize=18)
        if not normalize:
            plt.ylabel(f'$P(\\tau) [{uvpt.units}]$', fontsize=18)
        else:
            plt.ylabel(f'$P(\\tau) / P(0)$ [unitless]', fontsize=18)

        plt.title(f'{freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz {pp}', fontsize=16)

Compare the auto power spectra (solid lines) with the noise estimated from the diff spectra (dashed lines). Note that the autos are not cross-multiplied between different baselines or times, so they carry a noise bias.

In [None]:
# Time-averaged autos against their diff (noise) spectra for each spectral window.
for spw in spws:
    plot_auto_power_spectra_tavg(uvpa_avg, uvptd=uvpda_avg, spw=spw, outlier_sigma=6, outlier_delay=700)

In [None]:
# plot waterfalls of cross-power spectra
def plot_cross_power_spectra_waterfalls(uvpt, spw=0, nblp_per_row=8, dynamic_range=1e9, blp_skip=1, polpair=('pI', 'pI')):
    """Plot delay-vs-time waterfalls of cross power spectra for every ``blp_skip``-th baseline-pair.

    Each panel shows Re P(tau) on a shared log colour scale (masked values drawn black).
    Dashed white lines sit at +/- |b| / 0.3 ns (presumably the horizon delay, with
    c ~ 0.3 m/ns). Dotted white lines mark the times closest to the bounds of each
    LST field in the module-level ``lst_fields``, labelled with ``field_labels``.

    Parameters
    ----------
    uvpt : UVPSpec
        Cross power spectra to plot.
    spw : int
        Spectral window index.
    nblp_per_row : int
        Panels per figure.
    dynamic_range : float
        Ratio between the colour-scale maximum (data max rounded to a power of ten) and minimum.
    blp_skip : int
        Plot only every ``blp_skip``-th baseline-pair.
    polpair : tuple
        Polarization pair to plot; also shown in panel titles.
    """
    lst_hours = uvpt.lst_avg_array * 12 / np.pi
    lst_func = interp1d(uvpt.time_avg_array, lst_hours)
    vmax = 10 ** np.round(np.log10(np.max(np.abs(uvpt.data_array[spw]))))
    vmin = vmax / dynamic_range
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    freq_label = f'{freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz'
    pol_label = f'({polpair[0]}, {polpair[1]})'
    dlys_ns = uvpt.get_dlys(spw) * 1e9
    extent_tdelay = [dlys_ns.min(), dlys_ns.max(), uvpt.time_avg_array.max(), uvpt.time_avg_array.min()]
    # Copy the colormap so set_bad does not alter matplotlib's shared inferno instance.
    cmap = copy.copy(cmaps.inferno)
    cmap.set_bad(color='k')
    blpairs = uvpt.get_blpairs()
    blvecs = uvpt.get_blpair_blvecs()
    # Times nearest each field's LST bounds; identical for every panel, so compute once.
    field_time_bounds = [
        (flabel, [uvpt.time_avg_array[np.argmin(np.abs(lst0 - lst_hours))] for lst0 in field])
        for flabel, field in zip(field_labels, lst_fields)
    ]
    n_shown = uvpt.Nblpairs // blp_skip
    for blpc in range(0, n_shown, nblp_per_row):
        # squeeze=False keeps axarr indexable even when nblp_per_row == 1.
        fig, axarr = plt.subplots(1, nblp_per_row, squeeze=False)
        axarr = axarr[0]
        fig.set_size_inches(36, 6)
        cbax = fig.add_axes([0.915, 0.1, 0.005, 0.8])
        for i in range(nblp_per_row):
            blpind = blpc + i
            if blpind >= n_shown:
                continue
            blp = blpairs[blpind * blp_skip]
            bl_vec = blvecs[blpind * blp_skip]
            bl_dly = np.linalg.norm(bl_vec) / .3
            data = uvpt.get_data((spw, blp, polpair))
            plt.sca(axarr[i])
            cm = plt.imshow(data.real, norm=LogNorm(vmin, vmax), interpolation='nearest', aspect='auto', extent=extent_tdelay, cmap=cmap)
            plt.axvline(bl_dly, ls='--', color='w')
            plt.axvline(-bl_dly, ls='--', color='w')
            # Keep only y ticks that fall inside the plotted time range.
            plt.gca().set_yticks([t for t in plt.gca().get_yticks() if extent_tdelay[-1] <= float(t) <= extent_tdelay[-2]])
            plt.xlim(-2000, 2000)
            plt.gca().tick_params(labelsize=12)
            for flabel, time_bounds in field_time_bounds:
                plt.axhline(time_bounds[0], ls=':', color='w')
                plt.axhline(time_bounds[1], ls=':', color='w')
                plt.text(0, np.mean(time_bounds), f'Field {flabel}', color='w', ha='center')
            if i == 0:
                plt.ylabel('LST [hours]', fontsize=14)
                plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])
                cbar = fig.colorbar(cm, orientation='vertical', cax=cbax)
                cbax.yaxis.set_ticks_position('right')
                cbar.ax.set_ylabel(f'P($\\tau$) [${uvpt.units}$]', rotation=90, fontsize=14)
                cbax.tick_params(labelsize=16)
            else:
                plt.gca().set_yticklabels(['' for t in plt.gca().get_yticks()])
            plt.xlabel('$\\tau$ [ns]', fontsize=14)
            plt.title(f'{bl_vec[0]:.1f} m, {bl_vec[1]:.1f} m \n {freq_label} \n {pol_label}', fontsize=16)

In [None]:
# Cross-power waterfalls for every blp_skip-th baseline-pair in each spectral window.
for spw in spws:
    plot_cross_power_spectra_waterfalls(uvp_avg, spw=spw, blp_skip=blp_skip, nblp_per_row=max_plots_per_row, dynamic_range=1e10)

In [None]:
# plot waterfalls of cross-power spectra, expanding redundant groups in each row.
def plot_redundant_cross_power_spectra_waterfalls(uvpt, spw=0, max_nblp_per_row=16, dynamic_range=1e9, red_skip=10, polpair=('pI', 'pI'), exclude_autos=True):
    """Plot waterfalls of each baseline-pair in a redundant group, followed by the group average.

    Redundant groups come from ``utils.get_blvec_reds(uvpt, bl_error_tol=1.0)``. Every
    ``red_skip``-th group gets one figure. It holds up to ``max_nblp_per_row`` individual
    baseline-pair panels, plus a final panel averaging *all* baseline-pairs in the
    group (weighted uniformly, carrying ``P_N``). Panel decorations match
    ``plot_cross_power_spectra_waterfalls``.

    Parameters
    ----------
    uvpt : UVPSpec
        Un-averaged cross power spectra.
    spw : int
        Spectral window index.
    max_nblp_per_row : int
        Maximum number of individual baseline-pair panels per group.
    dynamic_range : float
        Ratio between colour-scale maximum and minimum.
    red_skip : int
        Plot only every ``red_skip``-th redundant group.
    polpair : tuple
        Polarization pair to plot; also shown in panel titles.
    exclude_autos : bool
        Drop auto baseline-pairs and keep only groups with more than one remaining pair.
    """
    lst_hours = uvpt.lst_avg_array * 12 / np.pi
    lst_func = interp1d(uvpt.time_avg_array, lst_hours)
    vmax = 10 ** np.round(np.log10(np.max(np.abs(uvpt.data_array[spw]))))
    vmin = vmax / dynamic_range
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    freq_label = f'{freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz'
    pol_label = f'({polpair[0]}, {polpair[1]})'
    dlys_ns = uvpt.get_dlys(spw) * 1e9
    extent_tdelay = [dlys_ns.min(), dlys_ns.max(), uvpt.time_avg_array.max(), uvpt.time_avg_array.min()]
    blp_groups, _, _, _ = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)
    if exclude_autos:
        blp_groups = [[blp for blp in blpgrp if blp[0] != blp[1]] for blpgrp in blp_groups]
        blp_groups = [blpgrp for blpgrp in blp_groups if len(blpgrp) > 1]
    # Copy the colormap so set_bad does not alter matplotlib's shared inferno instance.
    cmap = copy.copy(cmaps.inferno)
    cmap.set_bad(color='k')
    # Times nearest each field's LST bounds; identical for every panel, so compute once.
    field_time_bounds = [
        (flabel, [uvpt.time_avg_array[np.argmin(np.abs(lst0 - lst_hours))] for lst0 in field])
        for flabel, field in zip(field_labels, lst_fields)
    ]
    for blpc in range(0, len(blp_groups), red_skip):
        blps = blp_groups[blpc]
        uvpt_chunk = uvpt.select(blpairs=blps, inplace=False)
        # One panel per baseline-pair (capped) plus one for the group average.
        ncols = min(max_nblp_per_row + 1, len(blps) + 1)
        fig, axarr = plt.subplots(1, ncols, squeeze=False)
        axarr = axarr[0]
        fig.set_size_inches(36, 6)
        cbax = fig.add_axes([0.915, 0.1, 0.005, 0.8])
        # Fetched before the in-place average on the last panel changes uvpt_chunk.
        chunk_blvecs = uvpt_chunk.get_blpair_blvecs()
        for i in range(ncols):
            plt.sca(axarr[i])
            if i < ncols - 1:
                blp = blps[i]
                bl_vec = chunk_blvecs[i]
                title_head = f'{blp}'
            else:
                # Last panel: average the whole group in place (uvpt_chunk is not reused afterwards).
                uvpt_chunk.average_spectra(blpair_groups=[blps], error_field='P_N')
                blp = uvpt_chunk.get_blpairs()[0]
                bl_vec = uvpt_chunk.get_blpair_blvecs()[0]
                title_head = f'Averaging {len(blps)} blpairs'
            bl_dly = np.linalg.norm(bl_vec) / .3
            data = uvpt_chunk.get_data((spw, blp, polpair))
            title = f'{title_head} \n {bl_vec[0]:.1f} m, {bl_vec[1]:.1f} m \n {freq_label} \n {pol_label}'
            cm = plt.imshow(data.real, norm=LogNorm(vmin, vmax), cmap=cmap, interpolation='nearest', aspect='auto', extent=extent_tdelay)
            plt.axvline(bl_dly, ls='--', color='w')
            plt.axvline(-bl_dly, ls='--', color='w')
            plt.gca().set_yticks([t for t in plt.gca().get_yticks() if extent_tdelay[-1] <= float(t) <= extent_tdelay[-2]])
            plt.xlim(-2000, 2000)
            plt.gca().tick_params(labelsize=12)
            for flabel, time_bounds in field_time_bounds:
                plt.axhline(time_bounds[0], ls=':', color='w')
                plt.axhline(time_bounds[1], ls=':', color='w')
                plt.text(0, np.mean(time_bounds), f'Field {flabel}', color='w', ha='center')
            if i == 0:
                plt.ylabel('LST [hours]', fontsize=14)
                plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])
                cbar = fig.colorbar(cm, orientation='vertical', cax=cbax)
                cbax.yaxis.set_ticks_position('right')
                cbar.ax.set_ylabel(f'P($\\tau$) [${uvpt.units}$]', rotation=90, fontsize=14)
                cbax.tick_params(labelsize=16)
            else:
                plt.gca().set_yticklabels(['' for t in plt.gca().get_yticks()])
            plt.xlabel('$\\tau$ [ns]', fontsize=14)
            plt.title(title, fontsize=16)

In [None]:
# Redundant-group waterfalls (every grp_skip-th group) for each spectral window.
for spw in spws:
    plot_redundant_cross_power_spectra_waterfalls(uvp, spw=spw, max_nblp_per_row=max_plots_per_row, red_skip=grp_skip, dynamic_range=1e10)

Time for some wedge plots: one panel per LST field, one for the union of all fields, and one averaged over all LSTs.

In [None]:
def select_field(uvpt, fields, time_average=True):
    """Select the times whose LST falls inside any of the given LST fields.

    Parameters
    ----------
    uvpt : UVPSpec
        Power spectra to select from. ``lst_avg_array`` is converted to hours via ``* 12 / pi``.
    fields : sequence of (float, float)
        Inclusive ``(lst_min, lst_max)`` bounds in hours. The selection is the union over all fields.
    time_average : bool
        If True, average the selected spectra in time, weighting by ``P_N`` and
        carrying every stat in ``uvpt.stats_array``.

    Returns
    -------
    UVPSpec
        A new object; ``uvpt`` itself is not modified (all calls use ``inplace=False``).

    Raises
    ------
    ValueError
        If ``fields`` is empty.
    """
    if len(fields) == 0:
        raise ValueError('select_field requires at least one (lst_min, lst_max) field.')
    lst_hours = uvpt.lst_avg_array * 12 / np.pi
    time_selection = np.zeros(np.shape(lst_hours), dtype=bool)
    for field in fields:
        time_selection |= (lst_hours >= field[0]) & (lst_hours <= field[1])

    output = uvpt.select(times=uvpt.time_avg_array[time_selection], inplace=False)
    if time_average:
        output = output.average_spectra(time_avg=True, inplace=False, error_weights='P_N',
                                        error_field=list(uvpt.stats_array.keys()))
    return output

In [None]:
def plot_wedges(uvpt, polpair=('pI', 'pI'), spw=0, dynamic_range=1e10, max_bl_angle=75.):
    """Plot folded delay wedges per LST field, for the union of fields, and for all LSTs.

    There is one panel per entry of the module-level ``lst_fields`` (titled with
    ``field_labels``), then an "All Fields" panel and an "All LSTs" panel. Only
    redundant groups whose angle from ``utils.get_blvec_reds`` is at most
    ``max_bl_angle`` are shown. The colour scale is in log10 units: its top is one
    decade below the rounded data maximum, and it spans ``dynamic_range``.

    Parameters
    ----------
    uvpt : UVPSpec
        Power spectra (time-resolved; averaged in time per panel).
    polpair : tuple
        Polarization pair to plot.
    spw : int
        Spectral window index.
    dynamic_range : float
        Ratio between colour-scale maximum and minimum.
    max_bl_angle : float
        Maximum redundant-group baseline angle to include.
    """
    # log10 colour limits, matching log10=True in delay_wedge.
    vmax = np.round(np.log10(np.max(np.abs(uvpt.data_array[spw])))) - 1
    vmin = vmax - np.log10(dynamic_range)
    fig, axarr = plt.subplots(1, len(lst_fields) + 2)
    fig.set_size_inches(36, 9)
    # Copy the colormap so set_bad does not alter matplotlib's shared inferno instance.
    cmap = copy.copy(cmaps.inferno)
    cmap.set_bad(color='k')
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    freq_label = f'{freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz'
    blp_groups, _, red_bl_ang, _ = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)
    blps_to_keep = [blp for blg, ang in zip(blp_groups, red_bl_ang) if ang <= max_bl_angle for blp in blg]

    def _draw_wedge(ax, uvpt_select, panel_label):
        """Keep only low-angle baseline-pairs and draw one folded wedge on ``ax``."""
        # only select baselines with angle less than max_bl_angle.
        uvpt_select.select(blpairs=blps_to_keep, inplace=True)
        plt.sca(ax)
        pspecplot.delay_wedge(uvp=uvpt_select, ax=ax, pol=polpair, spw=spw, component='real', vmin=vmin, vmax=vmax, error_weights='P_N',
                              fold=True, rotate=True, cmap=cmap, log10=True, colorbar=True,
                              title=f' {panel_label} \n {freq_label} \n {polpair}', horizon_lines=True)
        plt.ylim(0, 2000)

    for fieldnum, (fieldlabel, field) in enumerate(zip(field_labels, lst_fields)):
        _draw_wedge(axarr[fieldnum], select_field(uvpt, [field]), f'Field {fieldlabel}')
    # now select the union of all fields
    _draw_wedge(axarr[len(lst_fields)], select_field(uvpt, lst_fields), 'All Fields')
    # Now do all LSTs, averaging the same stats that select_field carries.
    all_lsts = uvpt.average_spectra(time_avg=True, inplace=False, error_weights='P_N',
                                    error_field=list(uvpt.stats_array.keys()))
    _draw_wedge(axarr[len(lst_fields) + 1], all_lsts, 'All LSTs')

In [None]:
# Delay wedges of the redundantly averaged cross spectra, per spectral window.
for spw in spws:
    plot_wedges(uvpt=uvp_avg, spw=spw)

In [None]:
def plot_spherical_power_spectra(uvpt, uvptd=None, polpair=('pI', 'pI'), spw=0, dynamic_range=1e10, max_bl_angle=75., k_bin_multiplier=2, delta_sq=False):
    """Plot spherically averaged power spectra per LST field, for all fields, and for all LSTs.

    For each panel the selected spectra are restricted to redundant groups with angle
    <= ``max_bl_angle``. ``P_N`` inside the wedge is set to 1e40 so it carries
    negligible weight, and ``grouping.spherical_average`` then bins the result onto
    ``kbins``. Positive values are drawn as filled markers and non-positive values as
    open markers (both as |P|). The dashed grey curve is 2 x ``P_SN``, scaled like the
    data. With ``uvptd`` the diff spectra are grey and the sum spectra black;
    otherwise the sum spectra are grey.

    Parameters
    ----------
    uvpt : UVPSpec
        Sum power spectra. LST fields come from module-level ``lst_fields``/``field_labels``.
    uvptd : UVPSpec or None
        Optional diff power spectra, processed identically.
    polpair : tuple
        Polarization pair to plot.
    spw : int
        Spectral window; also determines the k_parallel bins.
    dynamic_range : float
        Ratio between the y-axis maximum and minimum.
    max_bl_angle : float
        Maximum redundant-group baseline angle to include.
    k_bin_multiplier : int
        Use every ``k_bin_multiplier``-th k_parallel value (from k ~ 0 upward) as a bin centre.
    delta_sq : bool
        Plot Delta^2 = k^3 P / (2 pi^2) instead of P.
    """
    vmax = 10 ** (np.round(np.log10(np.max(np.abs(uvpt.data_array[spw])))) - 1)
    vmin = vmax / dynamic_range
    fig, axarr = plt.subplots(1, len(lst_fields) + 2)
    fig.set_size_inches(36, 9)
    freq_range = uvpt.get_spw_ranges(spw)[0][:2]
    blp_groups, _, red_bl_ang, _ = utils.get_blvec_reds(uvpt, bl_error_tol=1.0)
    blps_to_keep = [blp for blg, ang in zip(blp_groups, red_bl_ang) if ang <= max_bl_angle for blp in blg]
    # SET K BINS from this spw's k_parallel values, starting at the entry closest to
    # k = 0 (assumes k increases from there, as with fftshifted delays -- confirm).
    kparas = uvpt.get_kparas(spw)
    kbins = kparas[np.argmin(np.abs(kparas))::k_bin_multiplier]
    kbin_widths = np.full(len(kbins), np.mean(np.diff(kbins)) / 2)
    # dsq multiplier: Delta^2 = k^3 / (2 pi^2) * P when requested.
    if delta_sq:
        dsq = kbins ** 3 / (2 * np.pi ** 2)
        vmax *= dsq[dsq > 0.0].min()
        vmin *= dsq[dsq > 0.0].min()
    else:
        dsq = 1.0
    # noise multiplier: error bars / noise curve drawn at twice P_SN.
    dsqn = dsq * 2

    z_lo = 1420e6 / freq_range[1] - 1
    z_hi = 1420e6 / freq_range[0] - 1

    def _title(suffix):
        """Panel title; ``suffix`` carries its own leading spaces (e.g. ' Field A')."""
        return (f'SPW {spw}, {polpair}, z={.5*z_lo + .5*z_hi:.2f}'
                f'({z_lo:.2f}-{z_hi:.2f})\n $\\nu$ = {freq_range[0]/1e6:.1f}-{freq_range[1]/1e6:.1f} MHz \n{suffix}')

    def _select(fields):
        """Return (sum, diff) selections for ``fields``; diff is None without ``uvptd``."""
        uvpt_select = select_field(uvpt, fields)
        uvptd_select = select_field(uvptd, fields) if uvptd is not None else None
        return uvpt_select, uvptd_select

    def single_plot(uvp_sel, uvpd_sel, title):
        """Spherically average one selection (and its diff) and draw it on the current axes."""
        # only select baselines with angle less than max_bl_angle.
        uvp_sel.select(blpairs=blps_to_keep, inplace=True)
        # set wedge to have (effectively) infinite variance; see hera_pspec set_stats_slice
        # for the meaning of m/b (presumably m ~ 1/c in ns/m and b = 350 ns buffer).
        uvp_sel.set_stats_slice('P_N', m=1/0.299, b=350, above=False, val=1e40)
        uvp_sph = grouping.spherical_average(uvp_sel, kbins=kbins, bin_widths=kbin_widths,
                                             error_weights='P_N')
        to_plot = [uvp_sph]
        if uvpd_sel is not None:
            uvpd_sel.select(blpairs=blps_to_keep, inplace=True)
            uvpd_sel.set_stats_slice('P_N', m=1/0.299, b=350, above=False, val=1e40)
            uvpd_sph = grouping.spherical_average(uvpd_sel, kbins=kbins, bin_widths=kbin_widths,
                                                  error_weights='P_N')
            to_plot += [uvpd_sph]
        for color, uvpplot in zip(['grey', 'k'], to_plot[::-1]):
            key = (spw, uvpplot.get_blpairs()[0], polpair)
            d = uvpplot.get_data(key)[0]
            p_sn = dsqn * uvpplot.get_stats('P_SN', key)[0]
            ltz = d <= 0.
            gtz = d > 0.
            plt.errorbar(kbins[gtz], np.abs(d*dsq)[gtz], p_sn[gtz],
                         ls='none', marker='o', color=color, label='sum', capsize=2)
            plt.errorbar(kbins[ltz], np.abs(dsq*d)[ltz], p_sn[ltz],
                         ls='none', marker='o', markerfacecolor='none', color=color, capsize=2)
            plt.plot(kbins, p_sn, color='grey', ls='--', label='P_SN sum')
        plt.yscale('log')
        plt.grid()
        plt.ylim(vmin, vmax)
        plt.title(title, fontsize=18)
        plt.xlabel('k [$h$Mpc$^{-1}$]', fontsize=16)
        if delta_sq:
            plt.ylabel('$\\Delta^2$ [mK$^2$]', fontsize=16)
        else:
            plt.ylabel(f'${uvp_sph.units}$')
        plt.gca().tick_params(labelsize=14)
        plt.xlim(0, 1.4)

    for fieldnum, (fieldlabel, field) in enumerate(zip(field_labels, lst_fields)):
        plt.sca(axarr[fieldnum])
        single_plot(*_select([field]), _title(f' Field {fieldlabel}'))

    # now select the union of all fields
    plt.sca(axarr[len(lst_fields)])
    single_plot(*_select(lst_fields), _title('  All Fields'))

    # Now do all LSTs: a single field spanning the full observed LST range.
    plt.sca(axarr[len(lst_fields) + 1])
    all_lst_field = [(uvpt.lst_avg_array.min() * 12 / np.pi, uvpt.lst_avg_array.max() * 12 / np.pi)]
    single_plot(*_select(all_lst_field), _title('  All LSTs'))

In [None]:
# Spherically averaged P(k), sums versus diffs, for each spectral window.
for spw in spws:
    plot_spherical_power_spectra(uvp_avg, uvptd=uvpd_avg, spw=spw)

In [None]:
# Same spherical averages, shown as the dimensionless Delta^2(k).
for spw in spws:
    plot_spherical_power_spectra(uvp_avg, uvptd=uvpd_avg, spw=spw, delta_sq=True)