# Single Baseline 2D DPSS Filtered SNRs

**by Josh Dillon and Tyler Cox**, last updated February 6, 2025

This notebook performs single-baseline, full-day DPSS filtering on corner-turned files to calculate a 2D DPSS filtered SNR, which can later be combined to look for residual RFI or other systematics that may have evaded Round 2 RFI flagging based on 1D DPSS filtering in frequency/delay.

Here's a set of links to skip to particular figures and tables:
# [• Figure 1: Waterfalls of 2D DPSS Filtered SNRs](#Figure-1:-Waterfalls-of-2D-DPSS-Filtered-SNRs)
# [• Figure 2: Histograms of 2D DPSS Filtered SNRs](#Figure-2:-Histograms-of-2D-DPSS-Filtered-SNRs)

In [None]:
# Record the wall-clock start time; the last cell uses tstart to report the total runtime.
import time
tstart = time.time()
# IPython shell escape: print the name of the host executing this notebook.
!hostname

In [None]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
import yaml
import glob
import copy
import re
from hera_cal import io, redcal, red_groups
from hera_cal.frf import sky_frates
from hera_cal.smooth_cal import solve_2D_DPSS
from hera_filters.dspec import dpss_operator, sparse_linear_fit_2D
import matplotlib.pyplot as plt
from IPython.display import display
%matplotlib inline

In [None]:
# Notebook settings. Each one is read from the environment variable of the same name,
# falling back to the default given here.
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = '/lustre/aoc/projects/hera/jsdillon/H6C/corner_turn_dev/2459861/zen.2459861.25364.sum.smooth_calibrated.red_avg.uvh5'

# RED_AVG_FILE has no usable default: it is the key used to look up which single-baseline files to process.
# Fail here with a clear message instead of an opaque TypeError from os.path.dirname(None) below.
if RED_AVG_FILE is None:
    raise ValueError('RED_AVG_FILE is not set. Export it as an environment variable '
                     'or assign it explicitly in this cell.')

# Only derive the default map location (next to RED_AVG_FILE) when no override is provided.
CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML", None)
if CORNER_TURN_MAP_YAML is None:
    CORNER_TURN_MAP_YAML = os.path.join(os.path.dirname(RED_AVG_FILE),
                                        "single_baseline_files/corner_turn_map.yaml")

# Suffix that replaces '.uvh5' in each single-baseline file name when results are written out.
SNR_SUFFIX = os.environ.get("SNR_SUFFIX", ".2Dfilt_SNR.uvh5")

FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5)) # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0)) # in MHz

FILTER_DELAY = float(os.environ.get("FILTER_DELAY", 750)) # in ns
EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))

# Echo the settings so the executed notebook records how it was configured.
# String-valued settings are quoted; numeric ones are not.
string_settings = {
    'RED_AVG_FILE': RED_AVG_FILE,
    'CORNER_TURN_MAP_YAML': CORNER_TURN_MAP_YAML,
    'SNR_SUFFIX': SNR_SUFFIX,
}
numeric_settings = {
    'FM_LOW_FREQ': FM_LOW_FREQ,
    'FM_HIGH_FREQ': FM_HIGH_FREQ,
    'FILTER_DELAY': FILTER_DELAY,
    'EIGENVAL_CUTOFF': EIGENVAL_CUTOFF,
}
for setting, value in string_settings.items():
    print(f'{setting} = "{value}"')
for setting, value in numeric_settings.items():
    print(f'{setting} = {value}')

## Preliminaries

In [None]:
# Load the corner-turn map; later cells read its 'files_to_outfiles_map' entry.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects, so this file must come from a
# trusted source. Switching to yaml.safe_load is only possible if the map contains no Python-specific
# tags -- TODO confirm before changing.
with open(CORNER_TURN_MAP_YAML, 'r') as file:
    corner_turn_map = yaml.unsafe_load(file)

In [None]:
# Load autocorrelations: take the first output file whose name contains ".<number>_<number>." with
# both numbers equal, and read its 'ee' and 'nn' polarizations.
auto_name_pattern = re.compile(r'\.(\d+)_(\d+)\.')

all_outfiles = []
for outfiles in corner_turn_map['files_to_outfiles_map'].values():
    all_outfiles.extend(outfiles)

for candidate in all_outfiles:
    found = auto_name_pattern.search(os.path.basename(candidate))
    if not found:
        continue
    if found.group(1) != found.group(2):
        continue
    # first match wins; no need to look at the remaining files
    hd_autos = io.HERAData(candidate)
    autos, _, _ = hd_autos.read(polarizations=['ee', 'nn'])
    break

## Define functions for main loop

In [None]:
def get_slices(flags, freqs=None, fm_low_freq=None, fm_high_freq=None):
    '''Gets the minimal boxes of all unflagged data above and below FM, handling case where one might be entirely flagged.

    A sample is treated as flagged only if it is flagged for every key in flags.

    Parameters
    ----------
    flags : dict-like
        Maps each key to a 2D boolean array, indexed here as (time, frequency).
    freqs : array-like, optional
        Frequency of each channel in Hz. Defaults to the notebook-level data.freqs.
    fm_low_freq : float, optional
        Lower edge of the FM band in MHz. Defaults to the notebook-level FM_LOW_FREQ.
    fm_high_freq : float, optional
        Upper edge of the FM band in MHz. Defaults to the notebook-level FM_HIGH_FREQ.

    Returns
    -------
    low_band : slice or None
        Channels from the lowest not-always-flagged channel through the highest one strictly below
        fm_low_freq. None if there is no such channel.
    high_band : slice or None
        Channels from the lowest not-always-flagged channel strictly above fm_high_freq through the
        highest not-always-flagged channel. None if there is no such channel.
    tslice : slice or None
        Times from the first through the last not-always-flagged integration. None if everything is flagged.
    '''
    # Fall back on the notebook globals so that existing get_slices(flags) calls keep working.
    if freqs is None:
        freqs = data.freqs
    if fm_low_freq is None:
        fm_low_freq = FM_LOW_FREQ
    if fm_high_freq is None:
        fm_high_freq = FM_HIGH_FREQ
    freqs = np.asarray(freqs)

    and_of_flags = np.all([flags[bl] for bl in flags], axis=0)
    low_band, high_band, tslice = None, None, None
    if np.all(and_of_flags):
        return low_band, high_band, tslice

    # get band slices
    good_chans = np.flatnonzero(~np.all(and_of_flags, axis=0))
    good_freqs = freqs[good_chans]
    below_fm = good_freqs < fm_low_freq * 1e6
    # The upper box must start above the TOP of the FM band; comparing against the lower edge
    # would pull any unflagged FM-band channel into the "above FM" box.
    above_fm = good_freqs > fm_high_freq * 1e6

    if np.any(below_fm):
        low_start = good_chans[np.argmin(good_freqs)]
        low_stop = good_chans[below_fm][np.argmax(good_freqs[below_fm])]
        low_band = slice(int(low_start), int(low_stop) + 1)
    if np.any(above_fm):
        high_start = good_chans[above_fm][np.argmin(good_freqs[above_fm])]
        high_stop = good_chans[np.argmax(good_freqs)]
        high_band = slice(int(high_start), int(high_stop) + 1)

    # get time slice
    good_tinds = np.flatnonzero(~np.all(and_of_flags, axis=1))
    tslice = slice(int(np.min(good_tinds)), int(np.max(good_tinds)) + 1)
    return low_band, high_band, tslice

In [None]:
def plot_2D_filtered_SNR_waterfalls():
    '''Build one |2D DPSS filtered SNR| waterfall panel per key in data, with a shared colorbar.

    Reads the notebook-level data, flags, and filtered_SNR. Flagged pixels are drawn as NaN.
    The figure is closed before being returned so that it is only shown when explicitly displayed.
    '''
    color_max = 10
    jd_offset = int(data.times[0])
    # [left, right, bottom, top]: frequency in MHz, time relative to the integer JD of the first integration
    extent = [data.freqs[0] / 1e6, data.freqs[-1] / 1e6,
              data.times[-1] - jd_offset, data.times[0] - jd_offset]

    fig, axes = plt.subplots(1, len(data), figsize=(14,10), dpi=100, sharex=True, sharey=True)
    for bl, ax in zip(data, axes):
        abs_snr = np.where(flags[bl], np.nan, np.abs(filtered_SNR[bl]))
        im = ax.imshow(abs_snr, aspect='auto', interpolation='none', cmap='afmhot_r',
                       vmin=0, vmax=color_max, extent=extent)
        ax.set_title(bl)
        ax.set_xlabel('Frequency (MHz)')

    axes[0].set_ylabel(f'JD - {jd_offset}')
    plt.tight_layout()

    # Only mark the colorbar as extended when some unflagged pixel actually exceeds the color scale.
    per_bl_peaks = [np.max(np.abs(filtered_SNR[bl][~flags[bl]]))
                    for bl in filtered_SNR if not np.all(flags[bl])]
    largest_pixel = np.max(per_bl_peaks)
    if largest_pixel > color_max:
        extend = 'max'
    else:
        extend = None
    plt.colorbar(im, ax=axes, label='|2D DPSS Filtered SNR|', pad=.02, extend=extend)

    plt.close(fig)
    return fig

In [None]:
def plot_2D_filtered_SNR_histograms():
    '''Build overlaid density histograms of |2D DPSS filtered SNR| for each key in filtered_SNR.

    Reads the notebook-level flags and filtered_SNR. Flagged pixels are replaced by NaN before
    histogramming. The curve 2 x exp(-x^2) is overplotted as the noise-only reference.
    The figure is closed before being returned so that it is only shown when explicitly displayed.
    '''
    fig = plt.figure(figsize=(12, 4))
    bins = np.arange(0, 25, .05)

    # collect every non-empty bin height so the log-scaled y-range can be set from the data
    nonzero_densities = []
    for bl in filtered_SNR:
        abs_snr = np.where(flags[bl], np.nan, np.abs(filtered_SNR[bl]))
        densities, _, _ = plt.hist(abs_snr.ravel(), bins=bins, label=str(bl), density=True, alpha=.5)
        nonzero_densities.extend(densities[densities > 0])

    rayleigh = 2 * bins * np.exp(-bins**2)
    plt.plot(bins, rayleigh, 'k--', label='Rayleigh Distribution (Noise-Only)')
    plt.yscale('log')
    y_low = np.min(nonzero_densities) / 2
    y_high = np.max(nonzero_densities) * 2
    plt.ylim(y_low, y_high)
    plt.legend()
    plt.ylabel('Density')
    plt.xlabel('2D DPSS Filtered SNR')
    plt.tight_layout()

    plt.close(fig)
    return fig

In [None]:
def estimate_variance_correction(flags, time_filters, freq_filters):
    """
    Estimate the variance correction from a 2D DPSS fit, treating the flags as separable in time
    and frequency and then rescaling for the portion of the flags which is not separable.

    Parameters
    ----------
    flags : 2D boolean array
        Flags with shape (ntimes, nfreqs).
    time_filters : 2D array
        Basis along the time axis; its first dimension is matched against the time mask.
    freq_filters : 2D array
        Basis along the frequency axis; its first dimension is matched against the frequency mask.

    Returns
    -------
    2D array with the shape of flags, equal to (1 - flagging_frac ** 2) * (1 - leverage).
    """
    def _axis_leverage(filters, mask):
        # Diagonal of filters @ pinv(filters^H W filters) @ filters^H W, where W = diag(mask)
        weighted_adjoint = filters.T.conj() * mask
        fit_operator = np.linalg.pinv(weighted_adjoint.dot(filters)).dot(weighted_adjoint)
        return np.sum(filters.T * fit_operator, axis=0)

    ntimes, nfreqs = flags.shape

    # Separable portion of the flags: 1 where a channel / integration is not entirely flagged
    freq_mask = (~np.all(flags, axis=0)).astype(float)
    time_mask = (~np.all(flags, axis=1)).astype(float)

    # Leverage along the frequency axis and along the time axis
    leverage_f = _axis_leverage(freq_filters, freq_mask)
    leverage_t = _axis_leverage(time_filters, time_mask)

    # The 2D leverage is approximated by the outer product of the per-axis leverages
    leverage = np.abs(np.outer(leverage_t, leverage_f))

    # Count flags explained by fully flagged channels and integrations (without double-counting
    # their intersections); the remaining flags are the ones not separable in time and frequency
    n_flagged_chans = np.sum(1 - freq_mask)
    n_flagged_times = np.sum(1 - time_mask)
    n_separable_flags = (
        n_flagged_chans * ntimes +
        n_flagged_times * nfreqs -
        n_flagged_times * n_flagged_chans
    )
    flagging_frac = (flags.sum() - n_separable_flags) / flags.size

    # Rescale to account for the non-separable flags
    return (1 - flagging_frac ** 2) * (1 - leverage)

## Perform 2D DPSS filtering, looping over baselines

In [None]:
# Figures are built (closed) inside the loop and collected here so that later cells can display them together.
waterfall_figs = []
histogram_figs = []

for single_bl_file in corner_turn_map['files_to_outfiles_map'][RED_AVG_FILE]:
    # Load data
    print(f'Now loading {single_bl_file}')
    hd = io.HERAData(single_bl_file)
    data, flags, nsamples = hd.read(polarizations=['ee', 'nn'])
    dt = np.median(np.diff(hd.times)) * 24 * 3600  # days -> seconds
    df = np.median(np.diff(hd.freqs))
    low_band, high_band, tslice = get_slices(flags)

    # get sky-like FR ranges; these depend only on hd, so compute them once per file
    fr_centers, fr_half_widths = sky_frates(hd)

    # Perform filtering
    filtered_SNR = copy.deepcopy(data)
    for bl in filtered_SNR.keys():
        fr_center = fr_centers[bl]
        fr_hw = fr_half_widths[bl]

        # calculate noise from the autocorrelation with the same polarization as this baseline
        auto_bl = [k for k in autos if k[2] == bl[2]][0]
        noise = np.abs(autos[auto_bl]) / (nsamples[bl] * dt * df)**.5

        for band in [low_band, high_band]:
            # flags are indexed as (time, frequency), matching every other use of tslice and band below
            if (band is None) or np.all(flags[bl][tslice, band]):
                continue
            # NOTE(review): the / 1e3 assumes sky_frates returns mHz while times are in seconds here -- confirm
            time_filters, _ = dpss_operator((data.times[tslice] - data.times[tslice][0]) * 3600 * 24,
                                            [fr_center / 1e3], [fr_hw / 1e3], eigenval_cutoff=[EIGENVAL_CUTOFF])
            # FILTER_DELAY is in ns; / 1e9 converts it to seconds
            freq_filters, _ = dpss_operator(data.freqs[band], [0.0], [FILTER_DELAY / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])

            # fit a smooth 2D DPSS model to the unflagged SNR and subtract it
            SNR = data[bl][tslice, band] / noise[tslice, band]
            fit, meta = sparse_linear_fit_2D(
                data=SNR,
                weights=(~flags[bl][tslice, band]).astype(float),
                axis_1_basis=time_filters,
                axis_2_basis=freq_filters,
            )
            d_mdl = time_filters.dot(fit).dot(freq_filters.T)
            filtered_SNR[bl][tslice, band] = SNR - d_mdl

            # estimate the leverage as the outer-product of the leverages along each axis assuming separable flags
            # and a small correction which accounts for the non-separable flags
            variance_correction = estimate_variance_correction(flags[bl][tslice, band], time_filters, freq_filters)
            # The correction is the factor by which the fit reduces the residual *variance*; the SNR is an
            # amplitude, so it is renormalized by the square root of that factor.
            filtered_SNR[bl][tslice, band] /= variance_correction ** .5

    # save figures to display later
    if not np.all(list(flags.values())):
        waterfall_figs.append(plot_2D_filtered_SNR_waterfalls())
        histogram_figs.append(plot_2D_filtered_SNR_histograms())
    else:
        print(f'{list(flags.keys())} are all entirely flagged.')

    # save results
    hd.update(data=filtered_SNR)
    print(f"Writing results to {single_bl_file.replace('.uvh5', SNR_SUFFIX)}")
    hd.write_uvh5(single_bl_file.replace('.uvh5', SNR_SUFFIX), clobber=True)

# *Figure 1: Waterfalls of 2D DPSS Filtered SNRs*

In [None]:
# Show the waterfall figures collected during the main loop, in the order they were created.
for waterfall_fig in waterfall_figs:
    display(waterfall_fig)

# *Figure 2: Histograms of 2D DPSS Filtered SNRs*

In [None]:
# Show the histogram figures collected during the main loop, in the order they were created.
for histogram_fig in histogram_figs:
    display(histogram_fig)

## Metadata

In [None]:
# Report the version of each package relevant to this notebook.
import importlib

for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'pyuvdata', 'numpy']:
    # import_module avoids exec() on a constructed string and does not overwrite a global __version__
    print(f'{repo}: {importlib.import_module(repo).__version__}')

In [None]:
# Total wall-clock time since tstart was recorded in the first cell.
elapsed_minutes = (time.time() - tstart) / 60
print(f'Finished execution in {elapsed_minutes:.2f} minutes.')