# Single Baseline 2D DPSS Filtered SNRs

**by Josh Dillon and Tyler Cox**, last updated May 15, 2025

This notebook performs single-baseline, full-day DPSS filtering on corner-turned files to calculate a 2D DPSS filtered SNR, which can later be combined to look for residual RFI or other systematics that may have evaded Round 2 RFI flagging based on 1D DPSS filtering in frequency/delay.

Here's a set of links to skip to particular figures and tables:
# [• Figure 1: Waterfalls of 2D DPSS Filtered SNRs](#Figure-1:-Waterfalls-of-2D-DPSS-Filtered-SNRs)
# [• Figure 2: Histograms of 2D DPSS Filtered SNRs](#Figure-2:-Histograms-of-2D-DPSS-Filtered-SNRs)

In [None]:
import time
# Record the wall-clock start time; the final cell uses it to report the total execution time.
tstart = time.time()
# IPython shell escape: print the name of the host this notebook is running on.
!hostname

In [None]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
import yaml
import glob
import copy
import re
from astropy import units
from scipy import interpolate, optimize, constants
from hera_cal import io, redcal, red_groups, flag_utils
from hera_cal.frf import sky_frates, get_FR_buffer_from_spectra
from hera_cal.smooth_cal import solve_2D_DPSS
from hera_filters.dspec import dpss_operator, sparse_linear_fit_2D, fourier_filter
import matplotlib.pyplot as plt
from IPython.display import display
%matplotlib inline

In [None]:
# Notebook settings. Each one is read from an environment variable of the same name, with the default shown.

# RED_AVG_FILE is the key used to look up this notebook's single-baseline files in the corner-turn map.
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = '/lustre/aoc/projects/hera/jsdillon/H6C/IDR3/2459861/zen.2459861.25319.sum.smooth_calibrated.red_avg.uvh5'
if RED_AVG_FILE is None:
    # Without this check, os.path.dirname(None) below fails with an opaque TypeError (the default argument
    # is evaluated even when CORNER_TURN_MAP_YAML is set), so fail early with an actionable message instead.
    raise ValueError('RED_AVG_FILE is not set. Set the RED_AVG_FILE environment variable or define it in this cell.')

# By default the corner-turn map is looked for in single_baseline_files/ next to RED_AVG_FILE.
CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML", 
                                        os.path.join(os.path.dirname(RED_AVG_FILE), "single_baseline_files/corner_turn_map.yaml"))

# Replaces '.uvh5' in each single-baseline filename to form the output filename.
SNR_SUFFIX =  os.environ.get("SNR_SUFFIX", ".2Dfilt_SNR.uvh5")

# The midpoint of these two frequencies is where the band is split into low and high sub-bands.
FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5)) # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0)) # in MHz

# Delay half-widths of the 2D filter and of the per-band 1D post-filters.
FILTER_DELAY = float(os.environ.get("FILTER_DELAY", 750)) # in ns
POST_FILTER_DELAY_LOW_BAND = float(os.environ.get("POST_FILTER_DELAY_LOW_BAND", 200.0)) # in ns
POST_FILTER_DELAY_HIGH_BAND = float(os.environ.get("POST_FILTER_DELAY_HIGH_BAND", 50.0)) # in ns

# Overridable like every other setting; the default is the previously hard-coded path.
AUTO_FR_SPECTRUM_FILE = os.environ.get("AUTO_FR_SPECTRUM_FILE",
                                       '/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5')
GAUSS_FIT_BUFFER_CUT = float(os.environ.get("GAUSS_FIT_BUFFER_CUT", 1e-5))

EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))
CG_ITER_LIM = int(os.environ.get("CG_ITER_LIM", 500))
# Comma-separated edges of TV allocations, compared against frequencies divided by 1e6.
TV_CHAN_EDGES = os.environ.get("TV_CHAN_EDGES", "174,182,190,198,206,214,222,230,238,246,254")
TV_THRESH = float(os.environ.get("TV_THRESH", 1.0))
MIN_SAMP_FRAC = float(os.environ.get("MIN_SAMP_FRAC", .15))

# Echo all settings for provenance: string-valued ones in quotes, numeric ones bare.
# A plain namespace lookup is used instead of eval(), which is unnecessary for reading a variable by name.
for setting in ['RED_AVG_FILE', 'CORNER_TURN_MAP_YAML', 'SNR_SUFFIX', 'AUTO_FR_SPECTRUM_FILE', 'TV_CHAN_EDGES']:
    print(f'{setting} = "{globals()[setting]}"')
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'FILTER_DELAY', 'POST_FILTER_DELAY_LOW_BAND', 'POST_FILTER_DELAY_HIGH_BAND',
                'GAUSS_FIT_BUFFER_CUT', 'EIGENVAL_CUTOFF', 'CG_ITER_LIM', 'TV_THRESH', 'MIN_SAMP_FRAC']:
    print(f'{setting} = {globals()[setting]}')

## Preliminaries

In [None]:
# Load the corner-turn map; its 'files_to_outfiles_map' entry is used below to find single-baseline files.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects, so this is only safe when the YAML
# file comes from a trusted source. Presumably the map contains Python-specific tags that yaml.safe_load
# would reject -- TODO confirm whether safe_load would suffice.
with open(CORNER_TURN_MAP_YAML, 'r') as file:
    corner_turn_map = yaml.unsafe_load(file)

In [None]:
# get autocorrelations
# Scan every single-baseline output file in the corner-turn map and read the first one whose filename
# contains ".<i>_<j>." with identical antenna-number strings i and j.
all_outfiles = [outfile for outfiles in corner_turn_map['files_to_outfiles_map'].values() for outfile in outfiles]
for outfile in all_outfiles:
    match = re.search(r'\.(\d+)_(\d+)\.', os.path.basename(outfile))
    if match and match.group(1) == match.group(2):
        hd_autos = io.HERAData(outfile)
        autos, _, auto_nsamples = hd_autos.read(polarizations=['ee', 'nn'])
        break
else:
    # Without this, `autos` and `auto_nsamples` would be left undefined and later cells would fail
    # with a NameError that gives no hint about the real cause.
    raise ValueError(f'No autocorrelation file found among the {len(all_outfiles)} outfiles in {CORNER_TURN_MAP_YAML}.')

In [None]:
# Turn the comma-separated TV allocation edges into channel slices along the autocorrelation frequency axis.
# Edges are compared against autos.freqs / 1e6, and allocations containing no channels are dropped.
tv_edges = [float(edge) for edge in TV_CHAN_EDGES.split(',')]
tv_slices = []
for lower_edge, upper_edge in zip(tv_edges[:-1], tv_edges[1:]):
    # strict inequalities: channels exactly on an edge belong to neither neighboring allocation
    chans_in_alloc = np.flatnonzero((autos.freqs / 1e6 > lower_edge) & (autos.freqs / 1e6 < upper_edge))
    if chans_in_alloc.size == 0:
        continue
    # indices come back in ascending order, so the first/last entries are the min/max channel
    tv_slices.append(slice(chans_in_alloc[0], chans_in_alloc[-1] + 1))

## Define functions for main loop

In [None]:
def plot_2D_filtered_SNR_waterfalls(vmax=5):
    """Plot |filtered SNR| waterfalls, one panel per key in the notebook-level `data`.

    Reads the notebook-level variables `data`, `flags`, `filtered_SNR`, and `tv_edges` rather than
    taking them as arguments. Flagged pixels are drawn as NaN and TV allocation edges below the top
    of the band are marked with dashed vertical lines.

    Parameters
    ----------
    vmax : float
        Upper limit of the color scale. The colorbar gets a 'max' extension arrow when any unflagged
        pixel exceeds it.

    Returns
    -------
    fig : matplotlib Figure
        The figure, already closed with plt.close so that it is not rendered until explicitly displayed.
    """
    # squeeze=False always yields a 2D array of Axes; taking row 0 gives a 1D array even when `data`
    # has a single key, where the default would return a bare Axes that cannot be zipped or indexed.
    fig, axes = plt.subplots(1, len(data), figsize=(14,10), dpi=200, sharex=True, sharey=True, squeeze=False)
    axes = axes[0]
    # frequencies are divided by 1e6 for the x-axis; times are shown relative to the integer part of the first time
    extent = [data.freqs[0] / 1e6, data.freqs[-1] / 1e6, data.times[-1] - int(data.times[0]), data.times[0] - int(data.times[0])]
    
    for bl, ax in zip(data, axes):
        im = ax.imshow(np.where(flags[bl], np.nan, np.abs(filtered_SNR[bl])), aspect='auto', interpolation='none', 
                       cmap='afmhot_r', vmin=0, vmax=vmax, extent=extent)
        ax.set_title(bl)
        ax.set_xlabel('Frequency (MHz)')
        for freq in tv_edges:
            # only mark edges that fall below the highest plotted frequency
            if freq < data.freqs[-1] * 1e-6:
                ax.axvline(freq, lw=.5, ls='--', color='k')

    
    axes[0].set_ylabel(f'JD - {int(data.times[0])}')
    plt.tight_layout()
    # largest unflagged |SNR| over all keys that are not entirely flagged; decides whether to extend the colorbar
    largest_pixel = np.max([np.max(np.abs(filtered_SNR[bl][~flags[bl]])) 
                            for bl in filtered_SNR if not np.all(flags[bl])])
    plt.colorbar(im, ax=axes, label='|2D DPSS Filtered SNR|', pad=.02, 
                 extend=('max' if largest_pixel > vmax else None))
    plt.close(fig)
    return fig

In [None]:
def plot_2D_filtered_SNR_histograms():
    """Histogram the unflagged |filtered SNR| of every key in the notebook-level `filtered_SNR`.

    Reads the notebook-level `filtered_SNR` and `flags`. Each key gets its own normalized,
    semi-transparent histogram, and the curve 2x * exp(-x^2) is overplotted as the noise-only reference.
    The log-scaled y-range spans a factor of two beyond the smallest and largest nonzero bin densities.

    Returns
    -------
    fig : matplotlib Figure
        The figure, already closed with plt.close so that it is not rendered until explicitly displayed.
    """
    fig = plt.figure(figsize=(12, 4))
    ax = fig.gca()
    snr_bins = np.arange(0, 25, .05)

    nonzero_densities = []
    for bl in filtered_SNR:
        # flagged samples become NaN so that they do not enter the histogram
        unflagged_abs_snr = np.where(flags[bl], np.nan, np.abs(filtered_SNR[bl]))
        densities = ax.hist(unflagged_abs_snr.ravel(), bins=snr_bins, label=str(bl), density=True, alpha=.5)[0]
        # empty bins are excluded because they cannot set limits on a log axis
        nonzero_densities.extend(densities[densities > 0])

    ax.plot(snr_bins, 2 * snr_bins * np.exp(-snr_bins**2), 'k--', label='Rayleigh Distribution (Noise-Only)')
    ax.set_yscale('log')
    ax.set_ylim(np.min(nonzero_densities) / 2, np.max(nonzero_densities) * 2)
    ax.legend()
    ax.set_ylabel('Density')
    ax.set_xlabel('2D DPSS Filtered SNR')
    fig.tight_layout()
    plt.close(fig)
    return fig

In [None]:
# Maps (antpair, tslice, band) -> (fringe-rate center, fringe-rate half-width).
# NOTE(review): tslice and band look like slice objects, which are only hashable on Python >= 3.12 --
# confirm the deployment environment satisfies this.
FR_CENTER_AND_HW_CACHE = {}

def cache_fr_center_and_hw(hd, antpair, tslice, band):
    '''Figure out the range of FRs in Hz spanned for a given band and tslice, buffered by the size of the
    autocorrelation FR kernel, and store the result in FR_CENTER_AND_HW_CACHE (if it hasn't already been computed).

    Arguments:
        hd: data object providing .times, .freqs, and .select(); it is passed (frequency-selected) to sky_frates.
        antpair: first element of the cache key, identifying the baseline.
        tslice: index into hd.times for the time range of interest. Nothing is done if this is None.
        band: index into hd.freqs for the band of interest. Nothing is done if this is None.

    Returns:
        None. The (fr_center, fr_hw) tuple is written to FR_CENTER_AND_HW_CACHE[(antpair, tslice, band)].
    '''
    if tslice is None or band is None:
        return
    cache_key = (antpair, tslice, band)
    if cache_key in FR_CENTER_AND_HW_CACHE:
        return

    # calculate fringe rate center and half-width and then update cache
    fr_buffer = get_FR_buffer_from_spectra(AUTO_FR_SPECTRUM_FILE, hd.times[tslice], hd.freqs[band], 
                                           gauss_fit_buffer_cut=GAUSS_FIT_BUFFER_CUT)
    hd_here = hd.select(inplace=False, frequencies=hd.freqs[band])
    # sky_frates is evaluated once and both of its outputs reused, instead of being called twice on the same object
    frates = sky_frates(hd_here)
    # only the first entry of each output is used -- presumably hd holds a single baseline; verify against caller
    fr_center = list(frates[0].values())[0] / 1e3  # converts to Hz
    fr_hw = (list(frates[1].values())[0] + fr_buffer) / 1e3
    FR_CENTER_AND_HW_CACHE[cache_key] = fr_center, fr_hw

In [None]:
def estimate_SNR_correction(wgts, time_filters, freq_filters):
    """
    Estimate the factor by which residuals of a weighted 2D DPSS fit are suppressed relative to the noise.

    The weights are approximated as separable (an outer product of per-time and per-frequency weights) so that
    the leverage can be computed one axis at a time, which keeps the calculation tractable. A scalar factor
    then accounts for the flags that are not separable in time and frequency.

    Arguments:
        wgts: 2D array of shape (ntimes, nfreqs). Entries equal to 0 are treated as flagged.
        time_filters: basis along the time axis, with ntimes rows.
        freq_filters: basis along the frequency axis, with nfreqs rows.

    Returns:
        Array of shape (ntimes, nfreqs) equal to (1 - f**2) * sqrt(1 - leverage), where f is the fraction
        of samples flagged outside of entirely-flagged rows/columns.
    """
    def _leverage_along_axis(basis, axis_wgts):
        # diagonal of basis @ pinv(basis^H W basis) @ basis^H W, computed without forming the full square matrix
        weighted_basis_H = basis.T.conj() * axis_wgts
        solution_operator = np.linalg.pinv(weighted_basis_H.dot(basis)).dot(weighted_basis_H)
        return np.sum(basis.T * solution_operator, axis=0)

    ntimes, nfreqs = wgts.shape
    is_flagged = (wgts == 0)

    # Separable part of the weights: mean unflagged weight per channel, then per time after
    # dividing out the per-channel weights. Entirely flagged channels/times get zero weight.
    freq_mask = np.any(~is_flagged, axis=0).astype(float)
    avg_freq_wgts = np.where(freq_mask, np.nanmean(np.where(is_flagged, np.nan, wgts), axis=0), 0)
    time_mask = np.any(~is_flagged, axis=1).astype(float)
    avg_time_wgts = np.where(time_mask, np.nanmean(np.where(is_flagged, np.nan, wgts / avg_freq_wgts), axis=1), 0)

    # Leverage along the frequency axis and along the time axis, combined as an outer product
    leverage_f = _leverage_along_axis(freq_filters, avg_freq_wgts)
    leverage_t = _leverage_along_axis(time_filters, avg_time_wgts)
    leverage = np.abs(np.outer(leverage_t, leverage_f))

    # Count samples in entirely-flagged channels or times (without double-counting their intersections);
    # whatever flags remain are the non-separable ones that the rescaling below accounts for.
    n_flagged_chans = np.sum(1 - freq_mask)
    n_flagged_times = np.sum(1 - time_mask)
    n_separable_flags = n_flagged_chans * ntimes + n_flagged_times * nfreqs - n_flagged_times * n_flagged_chans
    flagging_frac = (is_flagged.sum() - n_separable_flags) / wgts.size
    return (1 - flagging_frac ** 2) * (1 - leverage)**.5

## Perform 2D DPSS filtering, looping over baselines

In [None]:
# Main loop: for every single-baseline file associated with RED_AVG_FILE, compute a 2D DPSS filtered SNR,
# remove low-delay structure with a 1D DPSS filter, correct for the signal loss of both fits, and write the
# result to a new file. Figures are accumulated (closed) so that later cells can display them together.
# NOTE: `data`, `flags`, and `filtered_SNR` are deliberately notebook-level names; the plotting functions read them.
waterfall_figs = []
histogram_figs = []

for single_bl_file in corner_turn_map['files_to_outfiles_map'][RED_AVG_FILE]:
    # Load data
    print(f'Now loading {single_bl_file}')
    hd = io.HERAData(single_bl_file)
    data, flags, nsamples = hd.read(polarizations=['ee', 'nn'])
    # median time step scaled by 24 * 3600 (days -> seconds, assuming hd.times is in JD) and median channel width
    dt = np.median(np.diff(hd.times)) * 24 * 3600
    df = np.median(np.diff(hd.freqs))
    # presumably each single-baseline file contains exactly one antenna pair -- TODO confirm
    antpair = data.antpairs().pop()

    # skip this file unless at least one polarization has a median nsamples above MIN_SAMP_FRAC times
    # the median nsamples of the autocorrelation with the same polarization
    med_auto_nsamples = {bl[2]: np.median(n) for bl, n in auto_nsamples.items()}
    if not any([np.median(nsamples[bl]) > MIN_SAMP_FRAC * med_auto_nsamples[bl[2]] for bl in nsamples]):
        print('\tNo polarization has enough nsamples to be worth filtering. Skipping...')
        continue

    # get tslice and bands and figure out the corresponding fr_centers and fr_hws
    # the single frequency cut is the midpoint of FM_LOW_FREQ and FM_HIGH_FREQ (the .5e6 also converts MHz to Hz)
    tslices = {}
    bands = {}
    for bl in data:
        tslices[bl], bands[bl] = flag_utils.get_minimal_slices(flags[bl], freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])
        for tslice, band in zip(tslices[bl], bands[bl]):
            # keyed on bl[0:2] here but looked up with `antpair` below; assumes these are the same pair
            cache_fr_center_and_hw(hd, bl[0:2], tslice, band)
    
    # Perform filtering
    # start from a copy of the data so the container has the right keys/shapes; entries are overwritten below
    filtered_SNR = copy.deepcopy(data)
    for bl in filtered_SNR.keys():

        # calculate noise
        # noise is |autocorrelation of the same polarization| / sqrt(nsamples * dt * df)
        auto_bl = [k for k in autos if k[2] == bl[2]][0]
        noise = np.abs(autos[auto_bl]) / (nsamples[bl] * dt * df)**.5
        # inverse-variance weights, zero where flagged, normalized to unit mean over the positive weights
        wgts = np.where(flags[bl], 0, noise**-2)
        if np.any(wgts > 0):
            wgts /= np.mean(wgts[wgts > 0])
        
        for tslice, band in zip(tslices[bl], bands[bl]):
            # nothing to filter when there is no band or everything in this time/frequency range is flagged
            if (band is None) or np.all(flags[bl][tslice, band]):
                continue

            # perform 2D DPSS filter
            fr_center, fr_hw = FR_CENTER_AND_HW_CACHE[(antpair, tslice, band)]
            # time basis is built on seconds since the start of tslice; frequency basis is centered on zero delay
            time_filters, _ = dpss_operator((data.times[tslice] - data.times[tslice][0]) * 3600 * 24, 
                                            [fr_center], [fr_hw], eigenval_cutoff=[EIGENVAL_CUTOFF])
            freq_filters, _ = dpss_operator(data.freqs[band], [0.0], [FILTER_DELAY / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])
            fit, meta = sparse_linear_fit_2D(
                data=data[bl][tslice, band],
                weights=wgts[tslice, band],
                axis_1_basis=time_filters,
                axis_2_basis=freq_filters,
                precondition_solver=True,
                iter_lim=CG_ITER_LIM,
            )
            # evaluate the fitted model on the time-frequency grid and form the noise-normalized residual
            d_mdl = time_filters.dot(fit).dot(freq_filters.T)
            filtered_SNR[bl][tslice, band] = np.where(flags[bl][tslice, band], 0, 
                                                      (data[bl][tslice, band] - d_mdl) / noise[tslice, band])
            # estimate the leverage as the outer-product of the leverages along each axis assuming separable weights
            # and a small correction which accounts for the non-separable flags
            SNR_correction = estimate_SNR_correction(wgts[tslice, band], time_filters, freq_filters)
            filtered_SNR[bl][tslice, band] /= SNR_correction

            # identify TV channels with high SNR and give them near-0 weight when 1D DPSS filtering
            # wgts_1D covers the full waterfall and is rebuilt on every (tslice, band) iteration
            wgts_1D = (~flags[bl]).astype(float)
            if band == bands[bl][1]:  # high band
                # mean and standard deviation of the noise-only distribution 2x * exp(-x^2) of |SNR|
                predicted_mean = np.sqrt(np.pi) / 2
                predicted_std = np.sqrt((4 - np.pi) / 4)
                zscore = np.where(flags[bl], np.nan, (np.abs(filtered_SNR[bl]) - predicted_mean) / predicted_std)
                for tvs in tv_slices:
                    for tind in range(zscore.shape[0]):
                        # an integration's TV allocation is down-weighted when its mean z-score exceeds TV_THRESH
                        if np.nanmean(zscore[tind, tvs]) > TV_THRESH:
                            wgts_1D[tind, tvs] *= np.finfo(float).eps  # make weight very small
            
            # filter out very low delay modes in 1D 
            post_filter_delay = (POST_FILTER_DELAY_LOW_BAND if band == bands[bl][0] else POST_FILTER_DELAY_HIGH_BAND)
            # NOTE(review): max_contiguous_edge_flags=len(data.freqs) presumably prevents fourier_filter from
            # skipping integrations with many flagged edge channels -- confirm against hera_filters
            d_mdl_1D, _, _ = fourier_filter(data.freqs[band], 
                                            filtered_SNR[bl][tslice, band], 
                                            wgts=wgts_1D[tslice, band], 
                                            filter_centers=[0], 
                                            filter_half_widths=[post_filter_delay / 1e9],
                                            mode='dpss_solve', 
                                            eigenval_cutoff=[EIGENVAL_CUTOFF],
                                            suppression_factors=[EIGENVAL_CUTOFF], 
                                            max_contiguous_edge_flags=len(data.freqs))
            filtered_SNR[bl][tslice, band] = np.where(flags[bl][tslice, band], 0, filtered_SNR[bl][tslice, band] - d_mdl_1D)
    
            # calculate and apply another correction factor based on the leverage to flatten out the SNR
            correction_factors = np.full_like(wgts_1D[tslice, band], np.nan)     
            X = dpss_operator(data.freqs[band], [0], filter_half_widths=[post_filter_delay / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
            for tind in range(wgts_1D[tslice, band].shape[0]):
                # NOTE(review): assumes tslice.start is an integer rather than None -- verify get_minimal_slices
                W = wgts_1D[tslice.start + tind, band]
                if not np.all(W == 0):
                    # leverage is the diagonal of the weighted least-squares projection onto the 1D basis X
                    leverage = np.diag(X @ np.linalg.pinv(np.dot(X.T * W, X)) @ (X.T * W))
                    # non-positive leverage (e.g. zero-weight channels) becomes NaN rather than a correction of 1
                    correction_factors[tind, :] = np.where(leverage > 0, (1 - leverage)**.5, np.nan)
            filtered_SNR[bl][tslice, band] /= correction_factors

        # get rid of nans/infs in flagged channels
        filtered_SNR[bl] = np.where(flags[bl], 0, filtered_SNR[bl])
    
    # save figures to display later
    if not np.all(list(flags.values())):
        waterfall_figs.append(plot_2D_filtered_SNR_waterfalls())
        histogram_figs.append(plot_2D_filtered_SNR_histograms())
    else:
        print(f'{list(flags.keys())} are all entirely flagged.')

    # save results
    # the filtered SNR replaces the data in hd, which is written next to the input with '.uvh5' -> SNR_SUFFIX
    hd.update(data=filtered_SNR)
    print(f"Writing results to {single_bl_file.replace('.uvh5', SNR_SUFFIX)}")
    hd.write_uvh5(single_bl_file.replace('.uvh5', SNR_SUFFIX), clobber=True)

# *Figure 1: Waterfalls of 2D DPSS Filtered SNRs*

In [None]:
# Render every stored waterfall figure, in the order the baselines were processed
display(*waterfall_figs)

# *Figure 2: Histograms of 2D DPSS Filtered SNRs*

In [None]:
# Render every stored histogram figure, in the order the baselines were processed
display(*histogram_figs)

## Metadata

In [None]:
import importlib

# Print the version of each relevant package for provenance.
# importlib is used instead of exec() on a generated import statement: it needs no dynamically built source
# and does not rebind a notebook-level __version__ on every iteration.
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata', 'numpy']:
    print(f'{repo}: {importlib.import_module(repo).__version__}')

In [None]:
# Report the total wall-clock time since tstart was recorded in the first cell, in minutes.
print(f'Finished execution in {(time.time() - tstart) / 60:.2f} minutes.')