# Single Baseline 2D-Informed 1D DPSS Inpainting

**by Josh Dillon**, last updated April 25, 2025

This notebook performs single-baseline, full-day DPSS inpainting. Excluded are fully-flagged edge channels, FM, and times that are fully flagged either below or above FM. Inpainting is done by first iteratively forming a 2D DPSS model, then using that model in place of the data in 1D DPSS fits wherever we have flags, but with reduced weight that increases to near unity far from unflagged channels. If desired, it also performs a subsequent 1D DPSS notch filter in time around FR=0.

Here's a set of links to skip to particular figures and tables:
# [• Figure 1: 4-Pol Phase and Amplitude Waterfalls Before and After Inpainting](#Figure-1:-4-Pol-Phase-and-Amplitude-Waterfalls-Before-and-After-Inpainting)

In [None]:
# record the start time (reported at the end of the notebook) and print the machine this is running on
import time
tstart = time.time()
!hostname

In [None]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
import yaml
import glob
import copy
import re
from astropy import units
from scipy import interpolate, optimize, constants

from pyuvdata import UVFlag
from hera_cal import io, flag_utils, utils
from hera_cal.frf import sky_frates
from hera_cal.smooth_cal import solve_2D_DPSS
from hera_filters.dspec import dpss_operator, sparse_linear_fit_2D, fourier_filter
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import display
%matplotlib inline

In [None]:
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = '/lustre/aoc/projects/hera/jsdillon/H6C/IDR3/2459866/zen.2459866.25359.sum.smooth_calibrated.red_avg.uvh5' # 3 unit EW
if RED_AVG_FILE is None:
    # the default paths below are derived from RED_AVG_FILE; fail clearly rather than with a TypeError from os.path
    raise ValueError('RED_AVG_FILE must be set (e.g. via the RED_AVG_FILE environment variable).')

# by default, the corner turn map lives in single_baseline_files/ next to RED_AVG_FILE
CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML", 
                                        os.path.join(os.path.dirname(RED_AVG_FILE), "single_baseline_files/corner_turn_map.yaml"))

# by default, the round 3 flag file is zen.<JD>.flag_waterfall_round_3.h5 next to the corner turn map,
# where <JD> is the first all-numeric '.'-separated piece of RED_AVG_FILE's basename
R3_FLAG_FILE = os.environ.get("R3_FLAG_FILE", None)
if R3_FLAG_FILE is None:
    jdstr = [s for s in os.path.basename(RED_AVG_FILE).split('.') if s.isnumeric()][0]
    R3_FLAG_FILE = os.path.basename(RED_AVG_FILE).split(jdstr)[0] + jdstr + '.flag_waterfall_round_3.h5'
    R3_FLAG_FILE = os.path.join(os.path.dirname(CORNER_TURN_MAP_YAML), R3_FLAG_FILE)

for setting in ['RED_AVG_FILE', 'CORNER_TURN_MAP_YAML', 'R3_FLAG_FILE']:
    print(f'{setting} = "{globals()[setting]}"')

In [None]:
FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5)) # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0)) # in MHz

AUTO_INPAINT_DELAY = float(os.environ.get("AUTO_INPAINT_DELAY", 100)) # in ns
INPAINT_DELAY = float(os.environ.get("INPAINT_DELAY", 1000)) # in ns
ITERATIVE_DELAY_DELTA = float(os.environ.get("ITERATIVE_DELAY_DELTA", 25)) # in ns
EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))
CG_TOL = float(os.environ.get("CG_TOL", 1e-6))

INPAINTED_EXTENSION = os.environ.get("INPAINTED_EXTENSION", ".inpainted.uvh5")
WHERE_INPAINTED_EXTENSION = os.environ.get("WHERE_INPAINTED_EXTENSION", ".where_inpainted.h5")

INPAINT_WIDTH_FACTOR = float(os.environ.get("INPAINT_WIDTH_FACTOR", 0.5))
INPAINT_ZERO_DIST_WEIGHT = float(os.environ.get("INPAINT_ZERO_DIST_WEIGHT", 1e-2))

# configurable like every other setting; defaults to the previously hard-coded cache location
AUTO_FR_SPECTRUM_FILE = os.environ.get("AUTO_FR_SPECTRUM_FILE",
                                       '/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5')
GAUSS_FIT_BUFFER_CUT = float(os.environ.get("GAUSS_FIT_BUFFER_CUT", 1e-5))

FR0_FILTER = os.environ.get("FR0_FILTER", "TRUE").upper() == "TRUE"
FR0_FILTER_EXTENSION = os.environ.get("FR0_FILTER_EXTENSION", ".inpainted.FR0_filtered.uvh5")
FR0_HALFWIDTH = float(os.environ.get("FR0_HALFWIDTH", 0.01))  # in mHz

# sanity-check settings that would otherwise fail obscurely (or hang) downstream
if ITERATIVE_DELAY_DELTA <= 0:
    # the iterative 2D DPSS loop only reaches INPAINT_DELAY if the delay strictly increases
    raise ValueError(f'ITERATIVE_DELAY_DELTA must be positive, got {ITERATIVE_DELAY_DELTA}.')
if not 0 < INPAINT_ZERO_DIST_WEIGHT < 1:
    # feathering uses log(1 / INPAINT_ZERO_DIST_WEIGHT - 1), which is only finite for values in (0, 1)
    raise ValueError(f'INPAINT_ZERO_DIST_WEIGHT must be strictly between 0 and 1, got {INPAINT_ZERO_DIST_WEIGHT}.')
if INPAINT_WIDTH_FACTOR <= 0:
    # the feathering width (proportional to this factor) appears in a denominator
    raise ValueError(f'INPAINT_WIDTH_FACTOR must be positive, got {INPAINT_WIDTH_FACTOR}.')

for setting in ['AUTO_FR_SPECTRUM_FILE', 'INPAINTED_EXTENSION', 'WHERE_INPAINTED_EXTENSION', 'FR0_FILTER_EXTENSION']:
    print(f'{setting} = "{globals()[setting]}"')
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'AUTO_INPAINT_DELAY', 'INPAINT_DELAY', 'ITERATIVE_DELAY_DELTA', 
                'EIGENVAL_CUTOFF', 'CG_TOL', 'INPAINT_WIDTH_FACTOR', 'INPAINT_ZERO_DIST_WEIGHT',
                'GAUSS_FIT_BUFFER_CUT', 'FR0_FILTER', 'FR0_HALFWIDTH']:
    print(f'{setting} = {globals()[setting]}')

In [None]:
# History string appended to every output file: a header, then the conda/mamba environment export between rules.
add_to_history = (f"Produced by single_baseline_2D_informed_inpaint notebook with the following environment:\n{'=' * 65}\n"
                  f"{os.popen('mamba env export').read()}{'=' * 65}")

## Preliminaries

In [None]:
if False:
    # Interactive-testing branch (e.g. for exploring new algorithms) that avoids the corner turn logic.
    # Change the condition above to True to use it.

    # EDIT THIS TO PICK A DIFFERENT JD/BASELINE FROM THE LIST BELOW (no need to edit the rest, in theory)
    RED_AVG_FILE = '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459866.baseline.0_4.sum.smooth_calibrated.red_avg.uvh5'

    files = [
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459861.baseline.0_0.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459861.baseline.0_1.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459861.baseline.0_4.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459861.baseline.1_61.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459866.baseline.0_0.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459866.baseline.0_1.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459866.baseline.0_4.sum.smooth_calibrated.red_avg.uvh5',
        '/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.2459866.baseline.1_61.sum.smooth_calibrated.red_avg.uvh5',
    ]
    # in test mode, each file maps only to itself
    outfile_map = {test_file: [test_file] for test_file in files}
    corner_turn_map = {'files_to_outfiles_map': outfile_map}

    # the JD is the first all-numeric '.'-separated piece of the file name
    numeric_pieces = [piece for piece in os.path.basename(RED_AVG_FILE).split('.') if piece.isnumeric()]
    jdstr = numeric_pieces[0]
    R3_FLAG_FILE = f'/users/jsdillon/lustre/H6C/2D_inpainting_battletest/zen.{jdstr}.flag_waterfall_round_3.h5'
else:
    # NOTE(review): yaml.unsafe_load can construct arbitrary Python objects; only use it on trusted,
    # pipeline-generated corner turn maps.
    with open(CORNER_TURN_MAP_YAML, 'r') as map_file:
        corner_turn_map = yaml.unsafe_load(map_file)

In [None]:
# get round 3 flags
# Load the round 3 flag waterfall and collapse the last axis of its flag array (presumably polarization — confirm):
# a sample counts as flagged only if it is flagged everywhere along that axis.
print(f'Loading {R3_FLAG_FILE} for additional flags to add to the data.')
uvf = UVFlag(R3_FLAG_FILE)
round_3_flags = uvf.flag_array.all(axis=-1)

## Functions for main loop

In [None]:
def get_FR_buffers_from_spectra(antpair, jds, freqs, bands, gauss_fit_buffer_cut=GAUSS_FIT_BUFFER_CUT):
    '''Compute buffers for the fringe-rate (FR) half-widths of each band from the autocorrelation FR spectrum.

    These buffers pad the widths given by hera_cal.frf.sky_frates() for the basic calculation. The cached
    m-mode power spectrum of the autocorrelation (read from AUTO_FR_SPECTRUM_FILE) is mapped onto the fringe
    rates resolvable over the given times, evaluated at the highest frequency (and thus widest FR range) in
    each band, and fit with a Gaussian. The buffer is the FR at which that Gaussian fit drops to
    gauss_fit_buffer_cut relative to its peak.

    NOTE: the returned dict is keyed by slices, which requires Python >= 3.12 (slices are unhashable before that).

    Arguments:
        antpair: 2-tuple of antennas. Currently unused: the spectrum of the baseline group containing
            the (0, 0) autocorrelation is always used.
        jds: times in units of days
        freqs: frequencies in Hz
        bands: list of slices into freqs (None entries are skipped)
        gauss_fit_buffer_cut: level, relative to the peak of the Gaussian fit, at which to cut off the fit
            to figure out the halfwidth buffer

    Returns:
        fr_buffers: dict from band slice to FR buffers in mHz
    '''
    with h5py.File(AUTO_FR_SPECTRUM_FILE, "r") as h5f:
        metadata = h5f["metadata"]
        # map each antpair to the index of the baseline group that contains it
        bl_to_index_map = {tuple(ap): int(index) for index, antpairs in metadata["baseline_groups"].items() for ap in antpairs}
        spectrum_freqs = metadata["frequencies_MHz"][()] * 1e6  # MHz -> Hz
        m_modes = metadata["erh_mode_integer_index"][()]
        mmode_spectrum = h5f["erh_mode_power_spectrum"][:, :, bl_to_index_map[0, 0]]  # use autocorrelation
    
    # convert to fringe rate, accounting for the fact that we have less than 24 hours of LST
    def m2f(m_modes):
        # Convert m-modes (cycles per sidereal day) to fringe-rates in mHz (cycles per ks).
        return m_modes / units.sday.to(units.ks)
    # times in ks, offset so that the first time is one (median) integration after zero
    times_ks = (jds - jds[0] + np.median(np.diff(jds))) * units.day.to(units.ks)
    m2f_phasors = np.exp(2j * np.pi * m2f(m_modes)[None, :] * times_ks[:, None])
    # FT along the time axis: response of each m-mode's phasor at each FR over the limited time range
    m2f_mixer = np.fft.fftshift(np.fft.fft(np.fft.ifftshift(m2f_phasors, axes=0), axis=0), axes=0)
    # f is fringe rate, m is m-mode, n is nu (i.e. frequency)
    # equivalent to sum over m of |m2f_mixer[f, m]|^2 * mmode_spectrum[m, n]
    fr_spectrum = np.abs(np.einsum('fm,mn,mf->fn', m2f_mixer, mmode_spectrum, m2f_mixer.T.conj()))

    # create interpolator as a function of frequency 
    fr_spec_interpolator = interpolate.interp1d(spectrum_freqs, fr_spectrum, kind='cubic', fill_value='extrapolate')
    # FR axis matching the FFT above; times are in ks, so FRs are in 1/ks = mHz
    frates = np.fft.fftshift(np.fft.fftfreq(len(times_ks), d=np.median(np.diff(times_ks))))

    # loop over bands to produce results
    fr_buffers = {}
    for band in bands:
        if band is None:
            continue
        band_top_fr_spectrum = fr_spec_interpolator(freqs[band][-1])  # take top frequency (widest FR) per band
        band_top_fr_spectrum /= np.max(band_top_fr_spectrum)

        # fit gaussian to get a decent estimate of the width without being too sensitive to the FT of the limited time range
        gaussian = lambda x, a, sigma: a * np.exp(-(x**2) / (2 * sigma**2))
        # initial width guess: spread of the FRs where the normalized spectrum exceeds 1e-2
        initial_guess = [1.0, np.std(frates[band_top_fr_spectrum > 1e-2])]
        popt, _ = optimize.curve_fit(gaussian, frates, band_top_fr_spectrum, p0=initial_guess)
        # solve exp(-x^2 / (2 sigma^2)) = gauss_fit_buffer_cut for x
        fr_buffers[band] = np.abs(2 * popt[1]**2 * np.log(gauss_fit_buffer_cut))**.5  # how far out in the gaussian fit should we go
        
    return fr_buffers

In [None]:
def get_fr_centers_and_hws(antpair, hd, times, freqs, bands):
    '''Figure out the range of FRs in Hz spanned by each band, buffered by the size of the autocorrelation FR kernel.

    Arguments:
        antpair: 2-tuple of antennas (passed through to get_FR_buffers_from_spectra)
        hd: HERAData object; sky FRs are computed from its selection to each band's frequencies
        times: times in days used to compute the FR buffers
        freqs: frequencies in Hz
        bands: list of slices into freqs (None entries are skipped)

    Returns:
        fr_centers: dict mapping band slice to the sky FR center in Hz
        fr_hws: dict mapping band slice to the sky FR half-width plus the autocorrelation buffer, in Hz
    '''
    fr_buffers = get_FR_buffers_from_spectra(antpair, times, freqs, bands)
    fr_centers = {}
    fr_hws = {}
    for band in bands:
        if band is None:
            continue
        hd_here = hd.select(inplace=False, frequencies=hd.freqs[band])
        # call sky_frates() once per band; its first two outputs are dicts of centers and half-widths
        # (in mHz) for the single baseline in hd_here
        band_frates = sky_frates(hd_here)
        fr_centers[band] = list(band_frates[0].values())[0] / 1e3  # converts to Hz
        fr_hws[band] = (list(band_frates[1].values())[0] + fr_buffers[band]) / 1e3
    return fr_centers, fr_hws

In [None]:
def get_ip_nsamples(nsamples, tslice, bands, flags=None):
    '''Put in reasonable values for nsamples in totally flagged integrations (used only for 2D DPSS fitting).

    Within tslice and each band, integrations flagged at every frequency of the band get their nsamples
    replaced by the median nsamples of that band. The input nsamples is not modified.

    Arguments:
        nsamples: dict-like mapping baselines to nsamples waterfalls (time x freq)
        tslice: slice into the time axis
        bands: iterable of slices into the frequency axis; None entries are skipped
        flags: dict-like mapping baselines to boolean flag waterfalls with the same keys as nsamples.
            Defaults to the global ``flags`` (the behavior of the main loop).

    Returns:
        ip_nsamples: deep copy of nsamples with totally-flagged integrations filled in

    Raises:
        ValueError: if, for any baseline, ip_nsamples is not constant across frequency within a band
    '''
    if flags is None:
        flags = globals()['flags']
    bands = list(bands)  # iterated more than once below

    ip_nsamples = copy.deepcopy(nsamples)
    for bl in nsamples:
        for band in bands:
            if band is None:
                continue
            med = np.median(nsamples[bl][tslice, band])
            all_flagged = np.all(flags[bl][tslice, band], axis=1)
            # assumes tslice and band are slices, so this indexes a view that writes into ip_nsamples[bl]
            ip_nsamples[bl][tslice, band][all_flagged] = med

    # check that all ip_nsamples are constant across frequency within a band, for every baseline
    for bl in ip_nsamples:
        for band in bands:
            if band is None:
                continue
            band_nsamples = ip_nsamples[bl][tslice, band]
            if not np.all(band_nsamples == band_nsamples[:, 0:1]):
                raise ValueError(f'nsamples for {bl} are not constant across frequency within band {band}.')

    return ip_nsamples

In [None]:
def get_weights_for_2D_inpainting(data, flags, ip_autos, auto_flags, ip_nsamples=None, dt=None, df=None):
    '''Get inverse noise variance weights for inpainting. These come in two flavors:
        * weights_before_ip: has 0s wherever the data or autos are flagged (or the autos are non-finite)
        * weights_after_ip: uses inpainted autos for "noise," so only has 0s wherever the
            autos weren't inpainted (i.e. are non-finite) or the whole baseline is flagged
            (in practice, nowhere in the bands/tslice of interest).
    Each set of weights is normalized to a mean of 1 over its nonzero entries.

    The noise model is sigma = sqrt(|auto_1 * auto_2| / (ip_nsamples * dt * df)), where auto_1 and auto_2
    are the autos in ip_autos whose polarization matches each antenna's polarization.

    Arguments:
        data: datacontainer of visibilities (only its keys are used)
        flags: datacontainer of flags for data
        ip_autos: datacontainer of inpainted/modeled autocorrelations
        auto_flags: flags for the autocorrelations, keyed like ip_autos
        ip_nsamples: nsamples with totally-flagged integrations filled in. Defaults to the global ``ip_nsamples``.
        dt: integration time in seconds. Defaults to the global ``dt``.
        df: channel width in Hz. Defaults to the global ``df``.

    Returns:
        weights_before_ip, weights_after_ip: dicts mapping baselines to real weight waterfalls
    '''
    # backwards compatible: fall back on the notebook-level globals set by the main loop
    notebook_globals = globals()
    if ip_nsamples is None:
        ip_nsamples = notebook_globals['ip_nsamples']
    if dt is None:
        dt = notebook_globals['dt']
    if df is None:
        df = notebook_globals['df']

    weights_before_ip = {}
    weights_after_ip = {}
    for bl in data:
        # presumably each of ant1/ant2 is an (antenna, antpol) tuple — confirm against utils.split_bl
        ant1, ant2 = utils.split_bl(bl)

        # pick the autos whose polarization matches each antenna's polarization
        auto_bl_1 = [k for k in ip_autos if k[2] == utils.join_pol(ant1[1], ant1[1])][0]
        auto_bl_2 = [k for k in ip_autos if k[2] == utils.join_pol(ant2[1], ant2[1])][0]
        noise = (np.abs(ip_autos[auto_bl_1] * ip_autos[auto_bl_2]) / (ip_nsamples[bl] * dt * df))**.5
        flags_here = (~np.isfinite(ip_autos[auto_bl_1])) | (~np.isfinite(ip_autos[auto_bl_2])) | np.all(flags[bl])
        weights_after_ip[bl] = np.where(flags_here, 0, noise**-2)

        # additionally zero out wherever the data or the autos were originally flagged
        flags_here |= flags[bl] | auto_flags[auto_bl_1] | auto_flags[auto_bl_2]
        weights_before_ip[bl] = np.where(flags_here, 0, noise**-2)

        # normalize in place (np.where returned fresh arrays), skipping all-zero weights to avoid a NaN mean
        for wgts in [weights_after_ip[bl], weights_before_ip[bl]]:
            if np.any(wgts > 0):
                wgts /= np.mean(wgts[wgts > 0])
    return weights_before_ip, weights_after_ip

In [None]:
def fit_2D_DPSS(data, weights, filter_delay, fr_centers, fr_hws, **kwargs):
    '''Fit a 2D DPSS model to all the baselines in data. The time-dimension is based
    on sky FRs and the FR spectrum of the autos (via fr_centers and fr_hws).

    For each baseline, the minimal time range and bands (split at the center of FM) with nonzero
    weight are found, and each band is fit separately with a time-axis DPSS basis centered on
    fr_centers[band] with half-width fr_hws[band] and a frequency-axis DPSS basis out to filter_delay.

    Arguments:
        data: datacontainer mapping baselines to complex visibility waterfalls (also provides times in days and freqs in Hz)
        weights: datacontainer mapping baselines to real weight waterfalls (0 means ignored)
        filter_delay: maximum delay in ns for the 2D filter
        fr_centers: dictionary mapping band to FR centers in Hz
        fr_hws: dictionary mapping band to FR half-widths in Hz
        kwargs: kwargs to pass into sparse_linear_fit_2D()
    
    Returns:
        dpss_fit: datacontainer mapping baselines to 2D DPSS models; NaN outside the fitted time range
            and bands, and everywhere for baselines whose weights are all zero
    '''
    dpss_fit = copy.deepcopy(data)
    for bl in data.keys():
        # set to all nans by default
        dpss_fit[bl] *= np.nan

        if np.all(weights[bl] == 0):
            continue
        
        # calculate the region with nonzero weight to filter, split into bands at the center of FM
        # NOTE(review): fr_centers/fr_hws must contain these per-baseline bands as keys; presumably they
        # coincide with the bands the caller computed — TODO confirm
        tslice, (low_band, high_band) = flag_utils.get_minimal_slices(weights[bl] == 0, freqs=data.freqs, 
                                                                      freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])

        for band in [low_band, high_band]:
            if (band is None) or np.all(weights[bl][tslice, band] == 0):
                continue

            # perform 2D DPSS filter (time axis in seconds, frequency axis in Hz, delay converted ns -> s)
            time_filters, _ = dpss_operator((data.times[tslice] - data.times[tslice][0]) * 3600 * 24, 
                                            [fr_centers[band]], [fr_hws[band]], eigenval_cutoff=[EIGENVAL_CUTOFF])
            freq_filters, _ = dpss_operator(data.freqs[band], [0.0], [filter_delay / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])
            
            fit, _ = sparse_linear_fit_2D(
                data=data[bl][tslice, band],
                weights=weights[bl][tslice, band],
                axis_1_basis=time_filters,
                axis_2_basis=freq_filters,
                precondition_solver=True,
                **kwargs,
            )
            dpss_fit[bl][tslice, band] = time_filters.dot(fit).dot(freq_filters.T)
            
    return dpss_fit

In [None]:
def distance_to_nearest_nonzero_vectorized(arr):
    """
    Return, for every position of the 1D array `arr`, how many indices away the closest nonzero entry is.

    Nonzero positions map to 0. If `arr` has no nonzero entries at all, every distance is inf.
    The result is a float array of the same length, computed with forward and backward running
    extrema (no Python-level loop).
    """
    pos = np.arange(len(arr))
    is_nonzero = arr != 0

    # index of the most recent nonzero entry at or before each position (-inf if there is none yet)
    last_before = np.maximum.accumulate(np.where(is_nonzero, pos, -np.inf))
    # index of the next nonzero entry at or after each position (+inf if none remain)
    first_after = np.flip(np.minimum.accumulate(np.flip(np.where(is_nonzero, pos, np.inf))))

    gap_before = np.where(np.isfinite(last_before), pos - last_before, np.inf)
    gap_after = np.where(np.isfinite(first_after), first_after - pos, np.inf)
    return np.minimum(gap_before, gap_after)

In [None]:
def four_pol_inpainting_figure(ip_data, flags, ip_flags, close=False):
    '''Plots all phase and amplitude waterfalls before and after inpainting for all 4 polarizations in the data.

    One row per baseline in ip_data (up to 4). Columns: phase with the original flags blanked, phase with
    ip_flags blanked, amplitude with the original flags blanked, amplitude with ip_flags blanked.

    Arguments:
        ip_data: datacontainer of inpainted visibilities (also provides lsts in radians and freqs in Hz)
        flags: original flags, keyed like ip_data
        ip_flags: flags after inpainting, keyed like ip_data
        close: if True, close the figure after drawing it (the Figure is still returned, e.g. to display later)

    Returns:
        fig: the matplotlib Figure
    '''
    fig, axes = plt.subplots(4, 4, figsize=(16, 30), sharex=True, sharey=True, dpi=200, gridspec_kw={'wspace': 0.02, 'hspace': 0.01})
    
    # shared amplitude color scale, from unflagged samples across all baselines
    vmin = np.nanmin([np.where(~flags[bl], np.abs(ip_data[bl]), np.nan) for bl in ip_data])
    vmax = np.nanmax([np.where(~flags[bl], np.abs(ip_data[bl]), np.nan) for bl in ip_data])
    # LSTs in hours; values above the last LST are shifted down by 24 so the axis is monotonic across a 24h wrap
    lst_grid = ip_data.lsts * 12 / np.pi
    lst_grid[lst_grid > lst_grid[-1]] -= 24
    extent = [ip_data.freqs[0] / 1e6, ip_data.freqs[-1] / 1e6, lst_grid[-1], lst_grid[0]]
    # label LST ticks modulo 24 hours
    mod24 = lambda x, _: f"{int(x % 24)}"
    
    # iterate over the baselines of the data being plotted (previously used the global `data`)
    for row, bl in zip(axes, ip_data):
    
        row[0].imshow(np.where(flags[bl], np.nan, np.angle(ip_data[bl])), aspect='auto', interpolation='none', cmap='twilight', extent=extent)
        row[1].imshow(np.where(ip_flags[bl], np.nan, np.angle(ip_data[bl])), aspect='auto', interpolation='none', cmap='twilight', extent=extent)
        row[2].imshow(np.where(flags[bl], np.nan, np.abs(ip_data[bl])), aspect='auto', interpolation='none', norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax), extent=extent)
        im = row[3].imshow(np.where(ip_flags[bl], np.nan, np.abs(ip_data[bl])), aspect='auto', interpolation='none', norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax), extent=extent)
        # a fresh formatter per axis: matplotlib formatters should not be shared between axes
        row[0].yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(mod24))
        
        row[0].set_ylabel('LST (hours)')
        for ax in row:
            ax.tick_params(axis='x', direction='in')
    
        row[0].text(0.02, 0.99, bl, transform=row[0].transAxes, ha='left', va='top', fontsize=12, color='white',
                bbox=dict(facecolor='black', alpha=0.5, pad=2))
    
    for ax in axes[-1]:
        ax.set_xlabel('Frequency (MHz)')
    
    plt.tight_layout()   
    plt.colorbar(im, ax=axes[0], location='top', label='|V| (Jy)', aspect=50)
    
    if close:
        plt.close(fig)
    return fig

## Generate smooth model of autos for noise modeling

In [None]:
# get autocorrelations
# get autocorrelations from the first file in the corner turn map whose baseline is an auto (e.g. '.0_0.')
all_outfiles = [outfile for outfiles in corner_turn_map['files_to_outfiles_map'].values() for outfile in outfiles]
for outfile in all_outfiles:
    match = re.search(r'\.(\d+)_(\d+)\.', os.path.basename(outfile))
    if match and match.group(1) == match.group(2):
        print(f'Loading {outfile} for autocorrelations to use for noise modeling.')
        hd_autos = io.HERAData(outfile)
        autos, auto_flags, auto_nsamples = hd_autos.read(polarizations=['ee', 'nn'])
        # integration time in seconds and channel width in Hz, used by the noise model
        dt = np.median(np.diff(hd_autos.times)) * 24 * 3600
        df = np.median(np.diff(hd_autos.freqs))
        for bl in auto_flags:
            auto_flags[bl] |= round_3_flags
        break
else:
    # without this, later cells would fail with an obscure NameError on `autos`
    raise ValueError("No autocorrelation file (baseline of the form '.N_N.') found in the corner turn map's "
                     "files_to_outfiles_map.")

In [None]:
# inverse noise variance weights for the autos (0 where flagged), normalized to mean 1 over nonzero weights
weights = {}
for bl in autos:
    noise = 2 * np.abs(autos[bl]) / (auto_nsamples[bl] * dt * df)**.5
    weights[bl] = np.where(auto_flags[bl], 0, noise**-2)
    # skip normalization for fully-flagged autos: the mean of an empty selection would make every weight NaN
    if np.any(weights[bl] > 0):
        weights[bl] /= np.mean(weights[bl][weights[bl] > 0])

In [None]:
# Minimal time range and low/high bands (split at the center of FM) where at least one auto has nonzero weight.
all_autos_zero_weight = np.all([weights[bl] == 0 for bl in weights], axis=0)
fm_center_cut = [(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6]
tslice, (low_band, high_band) = flag_utils.get_minimal_slices(all_autos_zero_weight,
                                                              freqs=autos.freqs,
                                                              freq_cuts=fm_center_cut)

# sky FR centers and buffered half-widths (in Hz) for each band, used for the time axis of the 2D DPSS fit
fr_centers, fr_hws = get_fr_centers_and_hws(autos.antpairs().pop(), hd_autos, autos.times[tslice],
                                            autos.freqs, [low_band, high_band])

In [None]:
# Smooth 2D DPSS model of the autos out to AUTO_INPAINT_DELAY, used below in place of the autos for noise modeling.
auto_fit = fit_2D_DPSS(
    autos,
    weights,
    AUTO_INPAINT_DELAY,
    fr_centers,
    fr_hws,
    atol=CG_TOL,
    btol=CG_TOL,
)

In [None]:
# remove unused objects to save memory
# Free memory: the raw autos, their nsamples, and their HERAData object are not used below
# (auto_fit and auto_flags are kept for the main loop).
del autos, auto_nsamples, hd_autos

## Main loop for inpainting crosses (and possibly also doing a FR=0 notch filter)

In [None]:
# For each single-baseline file: inpaint flagged data with an iterative 2D DPSS fit followed by a feathered,
# 2D-informed 1D DPSS fit; optionally FR=0-filter crosses; write the inpainted data (INPAINTED_EXTENSION),
# the FR=0-filtered data (FR0_FILTER_EXTENSION, crosses only), and where_inpainted flags (WHERE_INPAINTED_EXTENSION).
waterfall_figs = []

for single_bl_file in corner_turn_map['files_to_outfiles_map'][RED_AVG_FILE]:

    # load data
    print(f'Now loading {single_bl_file}')
    hd = io.HERAData(single_bl_file)
    data, flags, nsamples = hd.read()
    # integration time in seconds and channel width in Hz (overwrites the values computed from the autos)
    dt = np.median(np.diff(hd.times)) * 24 * 3600
    df = np.median(np.diff(hd.freqs))
    antpair = data.antpairs().pop()
    is_auto = (antpair[0] == antpair[1])
    for bl in flags:
        # update with round 3 flags
        flags[bl] |= round_3_flags

    if np.all([flags[bl] for bl in flags]):
        print('\tThis baseline is entirely flagged. Skipping...')
        continue
    
    # get tslice and bands (minimal region not flagged in all pols, split at the center of FM)
    tslice, (low_band, high_band) = flag_utils.get_minimal_slices(np.all([flags[bl] for bl in flags], axis=0), freqs=data.freqs, 
                                                                  freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
    # compute FR ranges plausibly attributable to the sky
    fr_centers, fr_hws = get_fr_centers_and_hws(data.antpairs().pop(), hd, data.times[tslice], 
                                                data.freqs, [low_band, high_band])

    # fill in nsamples in totally flagged integrations with reasonable values, used only for weighting, not updated in data
    ip_nsamples = get_ip_nsamples(nsamples, tslice, (low_band, high_band))

    # get weights for 2D DPSS fitting (relies on the ip_nsamples, dt, and df computed above)
    weights_before_ip, weights_after_ip = get_weights_for_2D_inpainting(data, flags, auto_fit, auto_flags)

    # set up ip_flags and where_inpainted
    # ip_flags: True outside tslice/bands, and for integrations whose weights_before_ip are zero across a whole band
    # where_inpainted: originally-flagged samples that are not in ip_flags (i.e. the ones that get filled in)
    ip_flags = copy.deepcopy(flags)
    where_inpainted = copy.deepcopy(flags)
    for bl in ip_flags:
        ip_flags[bl][:, :] = True
        where_inpainted[bl][:, :] = False
        for band in [low_band, high_band]:
            if band is None:
                continue
            ip_flags[bl][tslice, band] = np.all(weights_before_ip[bl][tslice, band] == 0, axis=1, keepdims=True)
            where_inpainted[bl][tslice, band] = flags[bl][tslice, band] & (~ip_flags[bl][tslice, band])

    # perform iterative 2D DPSS fitting and inpainting
    # (the delay grows by ITERATIVE_DELAY_DELTA per iteration and the last iteration uses exactly INPAINT_DELAY;
    # the loop never terminates if ITERATIVE_DELAY_DELTA <= 0)
    current_filter_delay = ITERATIVE_DELAY_DELTA
    dpss_fit = None
    ip_data = None
    while True:
        print(f'\tNow inpainting out to {current_filter_delay} ns.')

        if dpss_fit is None:
            # first fit with gaps in data
            weights = weights_before_ip
            data_here = data
        else:
            # subsequent fits: fit the previous iteration's inpainted data, weighted by the inpainted-auto noise model
            weights = weights_after_ip
            data_here = ip_data
    
        dpss_fit = fit_2D_DPSS(data_here, weights, current_filter_delay, fr_centers, fr_hws, 
                               atol=CG_TOL, btol=CG_TOL)
    
        # replace flagged samples with the 2D DPSS model, keeping the data elsewhere
        ip_data = copy.deepcopy(data)
        for bl in ip_data:
            ip_data[bl] = np.where(flags[bl], dpss_fit[bl], data[bl])

        # increment current delay until we finally do INPAINT_DELAY
        if current_filter_delay == INPAINT_DELAY:
            break
        current_filter_delay += ITERATIVE_DELAY_DELTA
        if current_filter_delay > INPAINT_DELAY:
            current_filter_delay = INPAINT_DELAY

    # Perform 2D-informed (feathered) 1D DPSS inpainting: fit each integration in frequency, using the 2D model
    # where flagged, with weight on the model reduced near unflagged channels
    ip_data = copy.deepcopy(data)
    print(f'\tNow performing 2D-informed 1D DPSS inpainting out to {INPAINT_DELAY} ns.')
    for bl in ip_data:
        # figure out feathering
        # distances: for each integration, the number of channels to the nearest unflagged channel
        distances = np.array([distance_to_nearest_nonzero_vectorized(~flags[bl][tind, :]) for tind in range(flags[bl].shape[0])])
        # feathering width in channels: INPAINT_WIDTH_FACTOR times 1 / (INPAINT_DELAY * df)
        width = (1e-9 * INPAINT_DELAY)**-1 / df * INPAINT_WIDTH_FACTOR
        # logistic in distance: INPAINT_ZERO_DIST_WEIGHT at distance 0, 0.5 at distance == width,
        # approaching 1 far from unflagged channels (when INPAINT_ZERO_DIST_WEIGHT < 0.5)
        rel_weights = (1 + np.exp(-np.log(INPAINT_ZERO_DIST_WEIGHT**-1 - 1) / width * (distances - width)))**-1
    
        d_mdl = np.full_like(data[bl], np.nan)
        for band in [low_band, high_band]: 
            if band is None:
                continue

            # weights from inpainted autos, except totally-flagged integrations, then multiplied by rel_weights where originally flagged
            wgts = np.where(ip_flags[bl][:, band], 0, weights_after_ip[bl][:, band])
            wgts = np.where(flags[bl][:, band], wgts * rel_weights[:, band], wgts)
            if np.any(wgts > 0):
                wgts /= np.mean(wgts[wgts > 0])

            # 1D DPSS fitting along frequency, with the 2D DPSS model standing in for the data where flagged
            d_mdl[:, band], _, _ = fourier_filter(data.freqs[band],
                                                  np.where(flags[bl], dpss_fit[bl], data[bl])[:, band],
                                                  wgts=wgts,
                                                  filter_centers=[0],
                                                  filter_half_widths=[INPAINT_DELAY * 1e-9], 
                                                  mode='dpss_solve',
                                                  eigenval_cutoff=[EIGENVAL_CUTOFF], 
                                                  suppression_factors=[EIGENVAL_CUTOFF], 
                                                  max_contiguous_edge_flags=len(data.freqs),
                                                  filter_dims=1)
        # fill in model where we inpaint, 2D dpss_fit where we don't but were still flagged, and data otherwise
        ip_data[bl] = np.where(where_inpainted[bl], d_mdl, np.where(flags[bl], dpss_fit[bl], data[bl]))

    # perform FR=0 filter, if desired (crosses only): fit a DPSS model in time within +/- FR0_HALFWIDTH
    # (mHz, converted to Hz) of FR=0 over tslice and subtract it
    if FR0_FILTER and not is_auto:
        fr0_filt_ip_data = copy.deepcopy(ip_data)
        for bl in fr0_filt_ip_data:
            wgts_here = np.where(ip_flags[bl], 0, weights_after_ip[bl])[tslice, :]
            d_mdl, _, info = fourier_filter(data.times[tslice] * 24 * 60 * 60, 
                                            np.where(wgts_here == 0, 0, fr0_filt_ip_data[bl][tslice]), 
                                            wgts=wgts_here,
                                            filter_centers=[0], 
                                            filter_half_widths=[FR0_HALFWIDTH / 1000], 
                                            mode='dpss_solve', 
                                            eigenval_cutoff=[EIGENVAL_CUTOFF], 
                                            suppression_factors=[EIGENVAL_CUTOFF], 
                                            max_contiguous_edge_flags=len(data.times), 
                                            filter_dims=0)
            fr0_filt_ip_data[bl][tslice] -= d_mdl
    
    # save figures to display later
    # (this condition always holds here, since entirely-flagged baselines were skipped above)
    if not np.all(list(flags.values())):
        waterfall_figs.append(four_pol_inpainting_figure((fr0_filt_ip_data if (FR0_FILTER and not is_auto) else ip_data), 
                                                         flags, ip_flags, close=True))

    # Save inpainting results
    # NOTE(review): str.replace substitutes every '.uvh5' in the path, not only the extension
    hd.update(data=ip_data, flags=ip_flags)
    hd.history += add_to_history
    hd.write_uvh5(single_bl_file.replace('.uvh5', INPAINTED_EXTENSION), clobber=True)

    # Save FR=0-filtered inpainting results (reuses hd, so flags are still ip_flags and history is already updated)
    if FR0_FILTER and not is_auto:
        hd.update(data=fr0_filt_ip_data)
        hd.write_uvh5(single_bl_file.replace('.uvh5', FR0_FILTER_EXTENSION), clobber=True)
    
    # Save where_inpainted metadata (hd's flags are replaced by where_inpainted and converted to a UVFlag)
    hd.update(flags=where_inpainted)
    uvf = UVFlag(hd, mode='flag', copy_flags=True)
    uvf.history += add_to_history
    uvf.write(single_bl_file.replace('.uvh5', WHERE_INPAINTED_EXTENSION), clobber=True)

# *Figure 1: 4-Pol Phase and Amplitude Waterfalls Before and After Inpainting*

Note that this includes FR=0 filtering if `FR0_FILTER` is `True`.

In [None]:
# Show each waterfall figure collected (and closed) during the main loop.
for fig in waterfall_figs:
    display(fig)

## Metadata

In [None]:
import importlib

# report the versions of key packages for reproducibility (no exec of dynamically built source)
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata', 'numpy']:
    print(f'{repo}: {importlib.import_module(repo).__version__}')

In [None]:
# Total wall-clock runtime, measured from tstart (set in the first cell).
elapsed_minutes = (time.time() - tstart) / 60
print(f'Finished execution in {elapsed_minutes:.2f} minutes.')