# Fourth Round of Full Day RFI Flagging Using FRF-Filtered pI SNRs


**by Josh Dillon**, last updated February 19, 2026

# TODO: EDIT 
This notebook brings together the results of the [single-baseline pI FRF SNR notebook](https://github.com/HERA-Team/hera_notebook_templates/blob/master/notebooks/single_baseline_pI_FRF_SNR.ipynb) to make a set of flagging decisions after inpainting. This approach is iterative, and very similar to [Round 3 flagging](https://github.com/HERA-Team/hera_notebook_templates/blob/master/notebooks/full_day_rfi_round_3.ipynb), though it uses delay+fringe-rate filtered pseudo-Stokes pI SNRs rather than 2D-filtered SNRs, includes integration of Round 3 flags for watershed seeding, and detects persistent single-channel RFI. 

Here's a set of links to skip to particular figures and tables:
# [• Figure 1: Waterfall of pI z-Score Before Round 4 Flagging](#Figure-1:-Waterfall-of-pI-z-Score-Before-Round-4-Flagging)
# [• Figure 2: Histogram of z-Scores](#Figure-2:-Histogram-of-z-Scores)
# [• Figure 3: Waterfall of pI z-Score After Round 4 Flagging](#Figure-3:-Waterfall-of-pI-z-Score-After-Round-4-Flagging)
# [• Figure 4: Summary of Flags Before and After Round 4 Flagging](#Figure-4:-Summary-of-Flags-Before-and-After-Round-4-Flagging)

In [None]:
import time
# Record the wall-clock start time; the last cell uses it to report total run time.
tstart = time.time()
# IPython shell escape: print the name of the machine running this notebook.
!hostname

In [None]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
import yaml
import glob
import re
import matplotlib
from scipy.signal import convolve, convolve2d
from pyuvdata import UVFlag
from hera_qm import xrfi
from hera_cal import io, flag_utils
from hera_filters import dspec
import matplotlib.pyplot as plt
from astropy.coordinates import Angle
import astropy.constants as const
from hera_cal.utils import eq2top_m

from IPython.display import display, HTML
%matplotlib inline
display(HTML("<style>.container { width:100% !important; }</style>"))
_ = np.seterr(all='ignore')  # get rid of red warnings
%config InlineBackend.figure_format = 'retina'

In [None]:
# All settings are read from environment variables so the notebook can be run
# non-interactively; the second argument of each os.environ.get is the default.
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = '/lustre/aoc/projects/hera/h6c-analysis/IDR3/2459935/zen.2459935.25792.sum.smooth_calibrated.red_avg.uvh5'

# Only derive the corner-turn map path from RED_AVG_FILE when it was not given
# explicitly, so that RED_AVG_FILE is optional if the paths that depend on it are set.
CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML", None)
if CORNER_TURN_MAP_YAML is None:
    if RED_AVG_FILE is None:
        raise ValueError("Either RED_AVG_FILE or CORNER_TURN_MAP_YAML must be set in the environment.")
    CORNER_TURN_MAP_YAML = os.path.join(os.path.dirname(RED_AVG_FILE), "single_baseline_files/corner_turn_map.yaml")
SNR_SUFFIX = os.environ.get("SNR_SUFFIX", ".inpainted.pI_FRF_SNR.uvh5")
OUTFILE = os.environ.get("OUTFILE", None)
if OUTFILE is None:
    if RED_AVG_FILE is None:
        raise ValueError("Either RED_AVG_FILE or OUTFILE must be set in the environment.")
    # The first purely numeric dot-separated token of the file name is taken as the JD string;
    # the output keeps everything up to and including it and lives next to the corner-turn map.
    jdstr = [s for s in os.path.basename(RED_AVG_FILE).split('.') if s.isnumeric()][0]
    OUTFILE = os.path.basename(RED_AVG_FILE).split(jdstr)[0] + jdstr + '.flag_waterfall_round_4.h5'
    OUTFILE = os.path.join(os.path.dirname(CORNER_TURN_MAP_YAML), OUTFILE)

# By default the Round 3 flag file is assumed to sit next to OUTFILE with the same naming pattern.
ROUND_3_FLAG_FILE = os.environ.get("ROUND_3_FLAG_FILE", None)
if ROUND_3_FLAG_FILE is None:
    ROUND_3_FLAG_FILE = OUTFILE.replace('round_4', 'round_3')

MIN_SAMP_FRAC = float(os.environ.get("MIN_SAMP_FRAC", .15))
FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5)) # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0)) # in MHz

Z_THRESH = float(os.environ.get("Z_THRESH", 4))
WS_Z_THRESH = float(os.environ.get("WS_Z_THRESH", 2))
AVG_Z_THRESH = float(os.environ.get("AVG_Z_THRESH", 1))
MAX_FREQ_FLAG_FRAC = float(os.environ.get("MAX_FREQ_FLAG_FRAC", .25))
MAX_TIME_FLAG_FRAC = float(os.environ.get("MAX_TIME_FLAG_FRAC", .25))

FREQ_CONV_SIZE  = float(os.environ.get("FREQ_CONV_SIZE", 8.0)) # in MHz

SINGLE_CHAN_FLAG_FRAC = float(os.environ.get("SINGLE_CHAN_FLAG_FRAC", .25))
# An unset or empty SINGLE_CHAN_Z_THRESH falls back to Z_THRESH
_sczt = os.environ.get("SINGLE_CHAN_Z_THRESH", "")
SINGLE_CHAN_Z_THRESH = float(_sczt) if _sczt else Z_THRESH

PULSAR_RA = os.environ.get("PULSAR_RA", "06h30m49.3s")
PULSAR_DEC = os.environ.get("PULSAR_DEC", "-28d34m42.6s")
COHERENT_COMBINE = os.environ.get("COHERENT_COMBINE", "True").lower() in ('true', '1', 'yes')

# Echo the settings actually in use (string-valued ones quoted) so they are recorded in the output.
# The names are looked up in the notebook namespace directly rather than via eval().
for setting in ['RED_AVG_FILE', 'CORNER_TURN_MAP_YAML', 'SNR_SUFFIX', 'OUTFILE', 'ROUND_3_FLAG_FILE', 'PULSAR_RA', 'PULSAR_DEC']:
    print(f'{setting} = "{globals()[setting]}"')
for setting in ['MIN_SAMP_FRAC', 'FM_LOW_FREQ', 'FM_HIGH_FREQ', 'Z_THRESH', 'WS_Z_THRESH',
                'AVG_Z_THRESH', 'MAX_FREQ_FLAG_FRAC', 'MAX_TIME_FLAG_FRAC', 'FREQ_CONV_SIZE',
                'SINGLE_CHAN_FLAG_FRAC', 'SINGLE_CHAN_Z_THRESH', 'COHERENT_COMBINE']:
    print(f'{setting} = {globals()[setting]}')

In [None]:
# Load the corner-turn map; its 'files_to_outfiles_map' entry is used below to
# locate the single-baseline files.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects, so
# CORNER_TURN_MAP_YAML should only ever point at a trusted, pipeline-generated file.
with open(CORNER_TURN_MAP_YAML, 'r') as yaml_stream:
    corner_turn_map = yaml.unsafe_load(yaml_stream)

In [None]:
# Build the expected SNR file name for every single-baseline output file listed in
# the corner-turn map, then keep only the ones that actually exist on disk.
all_snr_files = [outfile.replace('.uvh5', SNR_SUFFIX)
                 for outfiles in corner_turn_map['files_to_outfiles_map'].values()
                 for outfile in outfiles]
extant_snr_files = [snr_file for snr_file in all_snr_files if os.path.exists(snr_file)]
if not extant_snr_files:
    # Everything downstream needs at least one SNR file, so fail here with a clear message
    raise FileNotFoundError(f'None of the {len(all_snr_files)} expected SNR files ending in '
                            f'"{SNR_SUFFIX}" listed in {CORNER_TURN_MAP_YAML} exist.')
print(f'Found {len(extant_snr_files)} SNR files, starting with {extant_snr_files[0]}')

In [None]:
# get autocorrelations
# TODO: generalize for not-inpainted data
all_outfiles = [outfile.replace('.uvh5', '.inpainted.uvh5')
                for outfiles in corner_turn_map['files_to_outfiles_map'].values()
                for outfile in outfiles]
for outfile in all_outfiles:
    # Look for a ".<i>_<j>." token in the file name; the first file with i == j is used
    # (presumably i and j are the antenna numbers, making this an autocorrelation file).
    match = re.search(r'\.(\d+)_(\d+)\.', os.path.basename(outfile))
    if match and match.group(1) == match.group(2):
        hd_autos = io.HERAData(outfile)
        _, _, auto_nsamples = hd_autos.read(polarizations=['ee', 'nn'])
        break
else:
    # Without this, a missing auto file only surfaces later as a NameError
    raise FileNotFoundError(f'No autocorrelation file found among the {len(all_outfiles)} '
                            f'single-baseline files listed in {CORNER_TURN_MAP_YAML}.')

# For pI data, use the mean of ee and nn auto nsamples as the reference
med_auto_nsamples_pI = np.mean([np.median(auto_nsamples[bl]) for bl in auto_nsamples])

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Read every available SNR file and store, per baseline key:
#   SNRs             - the SNR waterfall with flagged samples zeroed
#   SNR_counts       - 1 where a sample is unflagged, 0 where it is flagged
#   SNR_med_nsamples - the median nsamples over unflagged samples
# Note that hd, data, flags, and nsamples from the last file read are reused by later cells.
SNRs, SNR_counts, SNR_med_nsamples = {}, {}, {}

for snr_file in tqdm(extant_snr_files):
    hd = io.HERADataFastReader(snr_file)
    data, flags, nsamples = hd.read()
    for bl in data:
        is_flagged = flags[bl]
        SNRs[bl] = np.where(is_flagged, 0, data[bl])
        SNR_counts[bl] = np.where(is_flagged, 0, 1)
        SNR_med_nsamples[bl] = np.median(nsamples[bl][~is_flagged])

In [None]:
# Combine per-baseline SNRs into a single waterfall, skipping baselines with too few
# samples relative to the autos and baselines whose length is not between 1 and 50
# (which also drops zero-length baselines, i.e. autocorrelations).
waterfall_shape = (len(hd.times), len(hd.freqs))

if COHERENT_COMBINE:
    # Direction to the source at each time, expressed in the topocentric frame
    # via hera_cal.utils.eq2top_m from the hour angle and the array latitude.
    ra_source = Angle(PULSAR_RA).radian
    dec_source = Angle(PULSAR_DEC).radian
    lat_rad = hd.info['latitude'] * np.pi / 180
    ha = np.array(hd.lsts) - ra_source
    s_eq = np.array([np.cos(dec_source), 0.0, np.sin(dec_source)])
    rot = eq2top_m(ha, lat_rad)
    s_enu = np.einsum('tij,j->ti', rot, s_eq)
    # Difference between the source direction and the [0, 0, 1] unit vector, divided by c,
    # so that its dot product with a baseline vector is a delay.
    s_diff_over_c = (s_enu - np.array([0., 0., 1.])) / const.c.value
    complex_SNR_sum = np.zeros(waterfall_shape, dtype=complex)

abs_SNR_sum = np.zeros(waterfall_shape, dtype=float)
abs_SNR_count = np.zeros(waterfall_shape, dtype=float)
bls_used = []
for bl in SNRs:
    # skip baselines without enough samples (a NaN median also fails this test)
    if not (np.median(SNR_med_nsamples[bl]) > MIN_SAMP_FRAC * med_auto_nsamples_pI):
        continue

    bl_vec = hd.antpos[bl[0]] - hd.antpos[bl[1]]
    bl_len = np.linalg.norm(bl_vec)
    if not (1 < bl_len < 50):
        continue

    bls_used.append(bl)
    if COHERENT_COMBINE:
        # rephase this baseline by its per-time delay before summing complex SNRs
        tau = np.einsum('ti,i->t', s_diff_over_c, bl_vec)
        phs = np.exp(-2j * np.pi * hd.freqs[np.newaxis, :] * tau[:, np.newaxis])
        complex_SNR_sum += SNRs[bl] * phs
    else:
        abs_SNR_sum += np.abs(SNRs[bl])
    abs_SNR_count += SNR_counts[bl]

if COHERENT_COMBINE:
    # in coherent mode the amplitude is taken only after summing over baselines
    abs_SNR_sum = np.abs(complex_SNR_sum)
    print(f'Coherently combined {len(bls_used)} baselines rephased to {PULSAR_RA}, {PULSAR_DEC}')
else:
    print(f'Incoherently combined {len(bls_used)} baselines')


In [None]:
# convert SNRs to a z-score
# For noise-only data each per-baseline |SNR| is taken to be Rayleigh-distributed with
# mean 1, i.e. Rayleigh parameter sqrt(2 / pi), and variance (4 - pi) / 2 * sigma**2.
if COHERENT_COMBINE:
    # |sum of N complex SNRs| / N is itself a single Rayleigh variate whose mean is 1 / sqrt(N).
    # Its Rayleigh parameter (sigma below) already carries the 1 / sqrt(N), so the variance
    # must NOT be divided by N a second time.
    predicted_mean = 1.0 / np.sqrt(np.where(abs_SNR_count > 0, abs_SNR_count, 1))
    n_averaged = 1.0
else:
    # The mean of N independent Rayleigh variates keeps mean 1 and has its variance reduced by N.
    predicted_mean = 1.0
    n_averaged = abs_SNR_count
sigma = predicted_mean * np.sqrt(2 / np.pi)
variance_expected = (4 - np.pi) / 2 * sigma**2 / n_averaged
zscore = (abs_SNR_sum / abs_SNR_count - predicted_mean) / variance_expected**.5
# pixels with no contributing baselines carry no information
zscore = np.where(abs_SNR_count == 0, np.nan, zscore)


In [None]:
# Read the Round 3 flag waterfall (used later for watershed seeding). A pixel counts as
# flagged in Round 3 only if it is flagged along the whole last axis of the flag array.
uvf_r3 = UVFlag(ROUND_3_FLAG_FILE)
round3_flags = np.all(uvf_r3.flag_array, axis=-1)
expected_shape = (len(hd.times), len(hd.freqs))
assert round3_flags.shape == expected_shape, \
    f"Round 3 flag shape {round3_flags.shape} doesn't match data shape {expected_shape}"
# Blank out previously flagged pixels so they are treated as flagged from here on
zscore[round3_flags] = np.nan
print(f'Loaded Round 3 flags from {ROUND_3_FLAG_FILE}: {np.mean(round3_flags):.3%} flagged.')

In [None]:
if COHERENT_COMBINE:
    # Diagnostic waterfall: mean coherent |SNR| divided by its noise-only expectation 1/sqrt(N)
    jd_offset = int(data.times[0])
    has_data = abs_SNR_count > 0
    safe_count = np.where(has_data, abs_SNR_count, 1)
    noise_floor = 1.0 / np.sqrt(safe_count)
    mean_SNR = abs_SNR_sum / safe_count
    to_plot = np.where(has_data, mean_SNR / noise_floor, np.nan)

    fig, ax = plt.subplots(figsize=(14, 10), dpi=200)
    extent = [data.freqs[0] / 1e6, data.freqs[-1] / 1e6,
              data.times[-1] - jd_offset, data.times[0] - jd_offset]
    im = ax.imshow(to_plot, aspect='auto', cmap='plasma',
                   interpolation='none', vmin=0, vmax=5, extent=extent)
    plt.colorbar(im, ax=ax, location='top', label='|coherent pI SNR| / noise expectation',
                 extend='max', aspect=40, pad=.02)
    ax.set_xlabel('Frequency (MHz)')
    ax.set_ylabel(f'JD - {jd_offset}')
    ax.set_title(f'Coherent combination rephased to {PULSAR_RA}, {PULSAR_DEC}')

    # Mark the integration whose LST is closest (modulo 24 hours) to the source RA
    transit_lst_hours = np.degrees(ra_source) / 15
    lsts_hours = np.array(hd.lsts) * 12 / np.pi
    wrapped_offset = (lsts_hours - transit_lst_hours + 12) % 24 - 12
    closest_t = np.argmin(np.abs(wrapped_offset))
    transit_jd = hd.times[closest_t] - int(hd.times[0])
    ax.axhline(transit_jd, color='w', ls='--', lw=1, label=f'Transit LST={transit_lst_hours:.2f}h')
    ax.legend(loc='upper right')
    print(f'Pulsar transit LST = {transit_lst_hours:.4f} hours, '
          f'closest time index = {closest_t}, JD = {hd.times[closest_t]:.5f}')
    plt.tight_layout()


In [None]:
# recenter z-scores separately below and above the FM band
# The cut is the center of the FM band; it must be in Hz to match data.freqs
# (FM_LOW_FREQ and FM_HIGH_FREQ are in MHz, hence the factor of 1e6).
_, (low_band, high_band) = flag_utils.get_minimal_slices(~np.isfinite(zscore), freqs=data.freqs,
                                                         freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
for band in [low_band, high_band]:
    zscore[:, band] -= np.nanmedian(zscore[:, band])

## Plotting Functions

In [None]:
def plot_z_score(zscore, flags=None, vmin=-7.5, vmax=7.5):
    '''Plot a z-score waterfall versus frequency (MHz) and JD, with an LST axis on the right.

    Arguments:
        zscore: 2D array of z-scores with shape (Ntimes, Nfreqs).
        flags: boolean array of the same shape; True pixels are left blank.
            If None, all non-finite z-scores are blanked.
        vmin, vmax: color scale limits.

    Reads the notebook-level data (freqs, times), hd (lsts), and COHERENT_COMBINE.
    '''
    if flags is None:
        flags = ~np.isfinite(zscore)
    plt.figure(figsize=(14,10), dpi=300)
    jd_offset = int(data.times[0])
    extent = [data.freqs[0] / 1e6, data.freqs[-1] / 1e6,
              data.times[-1] - jd_offset, data.times[0] - jd_offset]

    # label the colorbar according to how baselines were actually combined
    combination = 'Coherently Combined' if COHERENT_COMBINE else 'Incoherently Averaged'
    plt.imshow(np.where(flags, np.nan, zscore), aspect='auto',
               cmap='coolwarm', interpolation='none', vmin=vmin, vmax=vmax, extent=extent)
    plt.colorbar(location='top', label=f'pI z-score {combination} Across Baselines',
                 extend='both', aspect=40, pad=.02)
    plt.xlabel('Frequency (MHz)')
    plt.ylabel(f'JD - {jd_offset}')
    plt.tight_layout()

    # Add LST right axis with proper wrapping
    lst_grid = np.array(hd.lsts) * 12 / np.pi  # radians to hours (np.array as elsewhere in the notebook)
    # make the axis limits monotonic across the 24h wrap; tick labels are taken mod 24 below
    lst_grid[lst_grid > lst_grid[-1]] -= 24
    ax2 = plt.gca().twinx()
    ax2.set_ylim(lst_grid[-1], lst_grid[0])
    mod24 = lambda x, _: f"{x % 24:.1f}"
    ax2.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(mod24))
    ax2.set_ylabel('LST (hours)')

In [None]:
def plot_histogram():
    '''Plot the density histogram of the notebook-level zscore waterfall on a log scale,
    overlaid with a unit Gaussian and vertical lines at WS_Z_THRESH and Z_THRESH.'''
    z_min, z_max = -50, 100
    bins = np.arange(z_min, z_max, .1)

    plt.figure(figsize=(14,4), dpi=100)
    densities, _, _ = plt.hist(np.ravel(zscore), bins=bins, density=True, label='z-scores', alpha=.5)
    gaussian = (2*np.pi)**-.5 * np.exp(-bins**2 / 2)
    plt.plot(bins, gaussian, 'k:', label='Gaussian approximate\nnoise-only distribution')
    plt.axvline(WS_Z_THRESH, c='r', ls='--', label='Watershed z-score')
    plt.axvline(Z_THRESH, c='r', ls='-', label='Threshold z-score')

    # set the y-range from the occupied bins only, since empty bins cannot be shown on a log axis
    plt.yscale('log')
    occupied = densities[densities > 0]
    plt.ylim(np.min(occupied) / 2, np.max(occupied) * 2)
    plt.xlim([z_min, z_max])
    plt.legend()
    plt.xlabel('z-score')
    plt.ylabel('Density')
    plt.tight_layout()

In [None]:
def summarize_flagging(zscore, flags):
    '''Plot a three-level waterfall comparing flags before and after this round.

    Arguments:
        zscore: 2D z-score waterfall; non-finite pixels are shown as flagged before this round.
        flags: boolean waterfall of the same shape; True pixels with finite z-score are
            shown as newly flagged in this round.

    Reads the notebook-level data (freqs, times) for the axis extent.
    '''
    plt.figure(figsize=(14,10), dpi=200)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still supported.
    # Building a list also avoids depending on whether .colors is a tuple or a list.
    set2_colors = plt.get_cmap("Set2").colors
    cmap = matplotlib.colors.ListedColormap([(0, 0, 0), *set2_colors[0:2]])
    jd_offset = int(data.times[0])
    extent = [data.freqs[0] / 1e6, data.freqs[-1] / 1e6,
              data.times[-1] - jd_offset, data.times[0] - jd_offset]
    # 0 = unflagged, 1 = non-finite z-score (flagged coming in), 2 = flagged in this round
    flag_level = np.where(~np.isfinite(zscore), 1, np.where(flags, 2, 0))
    plt.imshow(flag_level, aspect='auto', cmap=cmap, interpolation='none', extent=extent)
    plt.clim([-.5, 2.5])
    cbar = plt.colorbar(location='top', aspect=40, pad=.02)
    cbar.set_ticks([0, 1, 2])
    cbar.set_ticklabels(['Unflagged', 'Flagged After Round 3', 'Flagged After Round 4'])
    plt.xlabel('Frequency (MHz)')
    plt.ylabel(f'JD - {jd_offset}')
    plt.tight_layout()

# Figure 1: Waterfall of pI z-Score Before Round 4 Flagging

This figure shows the pI z-score derived from delay+fringe-rate-filtered pseudo-Stokes I SNRs. Dotted lines in the high band show TV allocations, which receive special treatment. Large positive excursions are problematic and likely need flagging. Note that the bands below and above FM are handled separately and may have different levels of post-filter residuals.

In [None]:
# Figure 1: no flags are passed, so only non-finite z-scores are blanked.
plot_z_score(zscore)

# Figure 2: Histogram of z-Scores

Shows a comparison of the histogram of pI z-scores to a Gaussian approximation of what one might expect from thermal noise. The underlying SNR is the absolute value of complex delay+fringe-rate-filtered pseudo-Stokes I data, which should follow a Rayleigh distribution for noise-only data. To make the z-scores more reliable, a single per-band median is subtracted from the waterfall. Any points beyond the solid red line are flagged. Any points neighboring a flag beyond the dashed red line are also flagged (watershed, seeded with Round 3 flags). Channels with persistent single-channel outliers are also flagged entirely.

In [None]:
# Figure 2: histogram of the current z-score waterfall with both thresholds marked.
plot_histogram()

## Flagging functions

In [None]:
def iteratively_flag_on_averaged_zscore(flags, zscore, avg_func=np.nanmean, avg_z_thresh=AVG_Z_THRESH, verbose=True):
    '''Flag whole integrations or channels based on average z-score. This is done
    iteratively to prevent bad times affecting channel averages or vice versa.

    Arguments:
        flags: boolean waterfall of shape (Ntimes, Nfreqs). Modified in place.
        zscore: z-score waterfall of the same shape.
        avg_func: NaN-aware averaging function accepting an axis keyword (e.g. np.nanmean).
        avg_z_thresh: channels/integrations whose averaged z-score is at or above this get flagged.
        verbose: if True, print how many integrations and channels were newly flagged.

    The bands below and above the center of the FM band are processed independently.
    '''
    _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
    flagged_chan_count = 0
    # Counts are keyed by band name rather than by slice: slice objects are not hashable
    # before Python 3.12, and two identical slices would collide as dict keys.
    flagged_int_count = {'low': 0, 'high': 0}
    for band_name, band in (('low', low_band), ('high', high_band)):
        while True:
            masked_zscore = np.where(flags, np.nan, zscore)[:, band]
            zspec = avg_func(masked_zscore, axis=0)
            ztseries = avg_func(masked_zscore, axis=1)
            max_zspec, max_ztseries = np.nanmax(zspec), np.nanmax(ztseries)

            # A NaN maximum means nothing unflagged is left in this band; every comparison
            # below would then be False and the loop would never terminate.
            if np.isnan(max_zspec) or np.isnan(max_ztseries):
                break

            if (max_zspec < avg_z_thresh) and (max_ztseries < avg_z_thresh):
                break

            # Flag along whichever axis holds the worst offender, taking with it everything
            # on that axis that is both above threshold and at least as bad as the worst
            # on the other axis.
            if max_zspec >= max_ztseries:
                bad_chans = (zspec >= max_ztseries) & (zspec >= avg_z_thresh)
                flagged_chan_count += np.sum(bad_chans)
                flags[:, band][:, bad_chans] = True
            else:
                bad_ints = (ztseries >= max_zspec) & (ztseries >= avg_z_thresh)
                flagged_int_count[band_name] += np.sum(bad_ints)
                flags[bad_ints, band] = True

    # Also flag low-band integrations above threshold (strictly) whenever the
    # high band is already completely flagged at that time.
    ztseries_low = avg_func(np.where(flags, np.nan, zscore)[:, low_band], axis=1)
    flags[(ztseries_low > avg_z_thresh) & np.all(flags[:, high_band], axis=1), low_band] = True

    if verbose:
        if (flagged_int_count['low'] > 0) or (flagged_int_count['high'] > 0) or (flagged_chan_count > 0):
            print(f'\tFlagging an additional {flagged_int_count["low"]} low-band integrations, '
                  f'{flagged_int_count["high"]} high-band integrations, and {flagged_chan_count} channels.')

def impose_max_chan_flag_frac(flags, max_flag_frac=MAX_FREQ_FLAG_FRAC, verbose=True):
    '''Completely flag (in place) every channel whose flagged fraction is at least
    max_flag_frac, computed only over times that are not fully flagged within the band.
    The bands below and above the center of the FM band are handled separately.'''
    fm_center = (FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6
    _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[fm_center])
    for band in (low_band, high_band):
        band_flags = flags[:, band]  # a view, so writes below land in flags
        usable_times = ~np.all(band_flags, axis=1)
        chan_flag_frac = np.mean(band_flags[usable_times], axis=0)
        chans_over_limit = chan_flag_frac >= max_flag_frac
        if verbose:
            # only report channels that were not already entirely flagged
            n_new = np.sum(chans_over_limit) - np.sum(np.all(band_flags, axis=0))
            if n_new > 0:
                print(f'\tFlagging {n_new} channels previously flagged {max_flag_frac:.2%} or more.')
        band_flags[:, chans_over_limit] = True
        
def impose_max_time_flag_frac(flags, max_flag_frac=MAX_TIME_FLAG_FRAC, verbose=True):
    '''Completely flag (in place) every time whose flagged fraction is at least
    max_flag_frac, computed only over channels that are not fully flagged within the band.
    The bands below and above the center of the FM band are handled separately.'''
    fm_center = (FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6
    _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[fm_center])
    for name, band in (('low', low_band), ('high', high_band)):
        band_flags = flags[:, band]  # a view, so writes below land in flags
        usable_chans = ~np.all(band_flags, axis=0)
        time_flag_frac = np.mean(band_flags[:, usable_chans], axis=1)
        times_over_limit = time_flag_frac >= max_flag_frac
        if verbose:
            # only report times that were not already entirely flagged
            n_new = np.sum(times_over_limit) - np.sum(np.all(band_flags, axis=1))
            if n_new > 0:
                print(f'\tFlagging {n_new} {name}-band times previously flagged {max_flag_frac:.2%} or more.')
        band_flags[times_over_limit, :] = True

def watershed_flag(flags, zscore, ws_z_thresh=WS_Z_THRESH, round3_flags=None):
    '''Wrapper around xrfi._ws_flag_waterfall to be performed separately above and below FM.

    Arguments:
        flags: boolean waterfall of shape (Ntimes, Nfreqs). Modified in place.
        zscore: either a single 2D z-score waterfall (as used in this notebook) or a
            mapping from polarization to 2D waterfalls (the earlier per-polarization form).
        ws_z_thresh: threshold passed through to xrfi._ws_flag_waterfall.
        round3_flags: optional boolean waterfall that is OR-ed with the current flags
            to form the seeds handed to the watershed.

    The watershed is repeated until the total number of flags stops changing.
    '''
    # A bare array cannot be iterated "per polarization", so wrap it in a one-element list.
    if hasattr(zscore, 'keys'):
        zscore_waterfalls = [zscore[pol] for pol in zscore]
    else:
        zscore_waterfalls = [zscore]

    while True:
        nflags = np.sum(flags)
        _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
        for band in [low_band, high_band]:
            for zscore_waterfall in zscore_waterfalls:
                seeds = flags[:, band]
                if round3_flags is not None:
                    seeds = seeds | round3_flags[:, band]
                flags[:, band] |= xrfi._ws_flag_waterfall(zscore_waterfall[:, band], seeds, ws_z_thresh)
        if np.sum(flags) == nflags:
            break

# def flag_single_channel_repeat_outliers(flags, zscore, z_thresh=None, flag_frac=SINGLE_CHAN_FLAG_FRAC, verbose=True):
#     '''Iteratively flag entire channels where a disproportionate fraction of unflagged times
#     are individually above z_thresh. This catches persistent narrow-band RFI that appears
#     at a single channel across many times.'''
#     if z_thresh is None:
#         z_thresh = SINGLE_CHAN_Z_THRESH
#     _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
#     total_chans_flagged = 0
#     for band in [low_band, high_band]:
#         while True:
#             above_thresh = np.any([zscore[pol][:, band] > z_thresh for pol in zscore], axis=0)
#             unflagged = ~flags[:, band]
#             n_unflagged = np.sum(unflagged, axis=0).astype(float)
#             n_outliers = np.sum(above_thresh & unflagged, axis=0).astype(float)
#             with np.errstate(divide='ignore', invalid='ignore'):
#                 outlier_frac = np.where(n_unflagged > 0, n_outliers / n_unflagged, 0)
#             chans_to_flag = (outlier_frac >= flag_frac) & (n_unflagged > 0)
#             if not np.any(chans_to_flag):
#                 break
#             total_chans_flagged += np.sum(chans_to_flag)
#             flags[:, band][:, chans_to_flag] = True
#     if verbose and total_chans_flagged > 0:
#         print(f'\tFlagging {total_chans_flagged} channels with > {flag_frac:.1%} of times individually above z = {z_thresh}.')

# def iterative_freq_conv_flagging(flags, zscore, conv_size=FREQ_CONV_SIZE, one_chan_thresh=Z_THRESH, full_kernel_thresh=AVG_Z_THRESH):
#     '''Looks for stretches of increasing size that fit a decreasing threshold. At conv_size (in MHz), it flags 
#     stretches with average z-score above full_kernel_thresh. At one pixel, it uses one_chan_thresh.
#     In between, it interpolates logarithmically.'''
#     _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
#     df_MHz = np.median(np.diff(data.freqs)) / 1e6
#     widths = np.array([int(w) + 1 for w in 2**np.arange(1, np.ceil(np.log2(conv_size / df_MHz) + np.finfo(float).eps))])
    
#     # prevent any widths from being so big that they mix high and low bands
#     max_width = (high_band.start - low_band.stop) * 2
#     widths[widths > max_width] = max_width
#     widths = np.unique(widths)

#     # Create cuts that get more strict as the kernel gets bigger
#     cuts = one_chan_thresh * (full_kernel_thresh / one_chan_thresh)**((widths - 1) / (conv_size / df_MHz - 1))

#     for width, cut in zip(widths, cuts):
#         kernel = np.ones((1, int(width)), dtype=float)
#         mask = ~(np.isnan(zscore) | flags)
#         filled_data = np.where(mask, zscore, 0.0)
#         conv_data = convolve2d(filled_data, kernel, mode='same')
#         conv_mask = convolve2d(mask.astype(float), kernel, mode='same')
#         with np.errstate(divide='ignore', invalid='ignore'):
#             result = conv_data / conv_mask

#         for band in [low_band, high_band]:
#             above_cut = (result[:, band] > cut)
#             flags[:, band] |= (convolve2d(above_cut.astype(float), kernel, mode='same') > 0)
        
#         print(f'{np.mean(flags):.3%} of waterfall flagged after {width}-channel convolution-based flagging with z-scores above {cut:.3f}.')

## Main Flagging Routine

# TODO: UPDATE

In [None]:
# Start from everything that is already non-finite in the z-score waterfall
flags = ~np.isfinite(zscore)
# Split at the center of the FM band (in Hz), matching the cut used by the flagging functions
_, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
print(f'{np.mean(flags):.3%} of waterfall flagged to start.')

# flag whole integrations or channels using outliers in median
# (the three steps feed each other, so repeat until the number of flags stops changing)
while True:
    nflags = np.sum(flags)
    iteratively_flag_on_averaged_zscore(flags, zscore, avg_func=np.nanmedian, avg_z_thresh=AVG_Z_THRESH, verbose=True)
    impose_max_chan_flag_frac(flags, max_flag_frac=MAX_FREQ_FLAG_FRAC, verbose=True)
    impose_max_time_flag_frac(flags, max_flag_frac=MAX_TIME_FLAG_FRAC, verbose=True)
    if np.sum(flags) == nflags:
        break
print(f'{np.mean(flags):.3%} of waterfall flagged after flagging whole times and channels with median z > {AVG_Z_THRESH}.')

# # flag largest outliers
# _, (low_band, high_band) = flag_utils.get_minimal_slices(flags, freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
# for band in [low_band, high_band]:
#     flags[:, band] |= (zscore[:, band] > Z_THRESH) 
# print(f'{np.mean(flags):.3%} of waterfall flagged after flagging z > {Z_THRESH} outliers.')

# # # flag channels with persistent single-channel outliers
# # flag_single_channel_repeat_outliers(flags, zscore, z_thresh=SINGLE_CHAN_Z_THRESH, flag_frac=SINGLE_CHAN_FLAG_FRAC)
# # print(f'{np.mean(flags):.3%} of waterfall flagged after single-channel repeat outlier detection.')

# # watershed flagging (with Round 3 flag seeding)
# watershed_flag(flags, zscore, ws_z_thresh=WS_Z_THRESH, round3_flags=round3_flags)
# print(f'{np.mean(flags):.3%} of waterfall flagged after watershed flagging on z > {WS_Z_THRESH} neighbors of prior flags (seeded with Round 3 flags).')

# # # iterative frequency-convolved flagging
# # iterative_freq_conv_flagging(flags, zscore, conv_size=FREQ_CONV_SIZE, one_chan_thresh=Z_THRESH, full_kernel_thresh=AVG_Z_THRESH)
# # print(f'{np.mean(flags):.3%} of waterfall flagged after channel convolution flagging.')

# # watershed flagging again (with Round 3 flag seeding)
# watershed_flag(flags, zscore, ws_z_thresh=WS_Z_THRESH, round3_flags=round3_flags)
# print(f'{np.mean(flags):.3%} of waterfall flagged after watershed flagging again on z > {WS_Z_THRESH} neighbors of prior flags.')

# # flag whole integrations or channels using outliers in mean
# while True:
#     nflags = np.sum(flags)
#     iteratively_flag_on_averaged_zscore(flags, zscore, avg_func=np.nanmean, avg_z_thresh=AVG_Z_THRESH, verbose=True)
#     impose_max_chan_flag_frac(flags, max_flag_frac=MAX_FREQ_FLAG_FRAC, verbose=True)
#     impose_max_time_flag_frac(flags, max_flag_frac=MAX_TIME_FLAG_FRAC, verbose=True)
#     if np.sum(flags) == nflags:
#         break  
# print(f'{np.mean(flags):.3%} of waterfall flagged after flagging whole times and channels with average z > {AVG_Z_THRESH}.')

# # watershed flagging one last time (with Round 3 flag seeding)
# watershed_flag(flags, zscore, ws_z_thresh=WS_Z_THRESH, round3_flags=round3_flags)
# print(f'{np.mean(flags):.3%} of waterfall flagged after watershed flagging one last time on z > {WS_Z_THRESH} neighbors of prior flags.')

# Figure 3: Waterfall of pI z-Score After Round 4 Flagging

Same as [Figure 1](#Figure-1:-Waterfall-of-pI-z-Score-Before-Round-4-Flagging) above, but now with additional flagging from this round.

In [None]:
# Figure 3: same waterfall, now blanking everything flagged by the routine above.
plot_z_score(zscore, flags=flags)

# Figure 4: Summary of Flags Before and After Round 4 Flagging

This plot shows which times and frequencies were flagged before and after this notebook. It is directly comparable to Figure 4 of the [full_day_rfi_round_3](https://github.com/HERA-Team/hera_notebook_templates/blob/master/notebooks/full_day_rfi_round_3.ipynb) notebook.


In [None]:
# Figure 4: non-finite z-scores are shown as flagged coming in, other flags as new in this round.
summarize_flagging(zscore, flags)

## Save results

In [None]:
# Build a waterfall-mode flag object from the autocorrelation file read earlier,
# copy the final flag waterfall into every index along its third axis, and write it out.
uvf = UVFlag(hd_autos, mode='flag', waterfall=True)
uvf.flag_array[:, :, :] = flags[:, :, np.newaxis]

uvf.write(OUTFILE, clobber=True)

## Metadata

In [None]:
# Print the version of each key package for provenance. The modules are looked up by
# name with importlib instead of exec-ing a generated import statement.
import importlib
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata']:
    print(f'{repo}: {importlib.import_module(repo).__version__}')

In [None]:
# Report total wall-clock time since tstart was set in the first code cell.
print(f'Finished execution in {(time.time() - tstart) / 60:.2f} minutes.')