# Single Baseline pI FRF SNR

**by Josh Dillon**, last updated February 13, 2026

This notebook takes corner-turned, calibrated, redundantly-averaged visibility data, forms pseudo-Stokes pI, and computes delay-filtered, fringe-rate-filtered SNR waterfalls. The results are written out as uvh5 files to be combined across baselines to look for residual structure that fringes like the main beam.

Here's a set of links to skip to particular figures:
# [• Figure 1: Delay-Filtered pI SNR in Fringe Rate Space](#Figure-1:-Delay-Filtered-pI-SNR-in-Fringe-Rate-Space)
# [• Figure 2: Delay+FR Filtered pI SNR Waterfall](#Figure-2:-Delay+FR-Filtered-pI-SNR-Waterfall)
# [• Figure 3: Delay+FR Filtered pI SNR Histogram](#Figure-3:-Delay+FR-Filtered-pI-SNR-Histogram)

In [None]:
# Record the wall-clock start time (used for the total runtime printed at the end of the notebook)
# and print which host this notebook is running on (IPython shell magic).
import time
tstart = time.time()
!hostname

In [None]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import h5py
import hdf5plugin
import numpy as np
import yaml
import glob
import copy
import re
import matplotlib
from astropy import units
from scipy import interpolate
from scipy.signal.windows import blackmanharris
from hera_cal import io, utils, flag_utils
from hera_cal.frf import sky_frates_single, get_FR_buffer_from_spectra, get_m2f_mixer
from hera_filters.dspec import dpss_operator, fourier_filter
import hera_pspec as hp
import uvtools
import matplotlib.pyplot as plt
from IPython.display import display
%matplotlib inline

In [None]:
# Input redundantly-averaged file (required) and the corner-turn map describing its single-baseline files.
RED_AVG_FILE = os.environ.get("RED_AVG_FILE", None)
# RED_AVG_FILE = '/lustre/aoc/projects/hera/h6c-analysis/IDR3/2459935/zen.2459935.25792.sum.smooth_calibrated.red_avg.uvh5' # 7_61
if RED_AVG_FILE is None:
    # Fail with a clear message rather than an opaque TypeError from os.path.dirname(None) below.
    raise ValueError('The RED_AVG_FILE environment variable must be set to the redundantly-averaged uvh5 file to process.')

CORNER_TURN_MAP_YAML = os.environ.get("CORNER_TURN_MAP_YAML", 
                                      os.path.join(os.path.dirname(RED_AVG_FILE), "single_baseline_files/corner_turn_map.yaml"))

# Replaces '.uvh5' in each single-baseline filename to form the output SNR filename.
SNR_SUFFIX = os.environ.get("SNR_SUFFIX", ".pI_FRF_SNR.uvh5")

# FM band edges; their midpoint is where the frequency axis is split into separate bands.
FM_LOW_FREQ = float(os.environ.get("FM_LOW_FREQ", 87.5))  # in MHz
FM_HIGH_FREQ = float(os.environ.get("FM_HIGH_FREQ", 108.0))  # in MHz

# Delay-filter half-width (centered on zero delay) and the DPSS eigenvalue cutoff / suppression factor.
FILTER_DELAY = float(os.environ.get("FILTER_DELAY", 750))  # in ns
EIGENVAL_CUTOFF = float(os.environ.get("EIGENVAL_CUTOFF", 1e-12))

# Simulated main-beam m-mode spectra, the spectrum used to build the sky fringe-rate buffer, and the
# half-width of the FR = 0 band (baselines whose main beam overlaps it are skipped).
FR_SPECTRA_FILE = os.environ.get("FR_SPECTRA_FILE", 
                                 "/lustre/aoc/projects/hera/h6c-analysis/IDR3/beam_simulation_products/spectra_cache_hera_core.h5")
AUTO_FR_SPECTRUM_FILE = os.environ.get("AUTO_FR_SPECTRUM_FILE",
                                       "/lustre/aoc/projects/hera/zmartino/hera_frf/spectra_cache/spectra_cache_hera_auto.h5")
XTALK_FR = float(os.environ.get("XTALK_FR", 0.01))  # in mHz

# Quantiles of the main-beam fringe-rate power spectrum that set the fringe-rate filter bounds.
FR_QUANTILE_LOW = float(os.environ.get("FR_QUANTILE_LOW", 0.05))
FR_QUANTILE_HIGH = float(os.environ.get("FR_QUANTILE_HIGH", 0.95))

# A baseline is skipped unless some polarization's median nsamples exceeds this fraction of the autos' median.
MIN_SAMP_FRAC = float(os.environ.get("MIN_SAMP_FRAC", .05))

# Echo the settings; look them up by name in the notebook namespace instead of eval().
for setting in ['RED_AVG_FILE', 'CORNER_TURN_MAP_YAML', 'SNR_SUFFIX', 'FR_SPECTRA_FILE', 'AUTO_FR_SPECTRUM_FILE']:
    print(f'{setting} = "{globals()[setting]}"')
for setting in ['FM_LOW_FREQ', 'FM_HIGH_FREQ', 'FILTER_DELAY', 'EIGENVAL_CUTOFF', 'XTALK_FR',
                'FR_QUANTILE_LOW', 'FR_QUANTILE_HIGH', 'MIN_SAMP_FRAC']:
    print(f'{setting} = {globals()[setting]}')

## Preliminaries

In [None]:
# Load the corner-turn map; its 'files_to_outfiles_map' entry is used below to find the
# single-baseline files belonging to each input file.
# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects. Presumably it is used because the
# map contains Python-specific YAML tags (e.g. tuples) — only run this on trusted, pipeline-generated files.
with open(CORNER_TURN_MAP_YAML, 'r') as file:
    corner_turn_map = yaml.unsafe_load(file)

In [None]:
# Get autocorrelations from the first single-baseline file in the corner-turn map whose name encodes an
# auto (filenames appear to contain '.<ant1>_<ant2>.'; ant1 == ant2 means an autocorrelation).
# TODO: generalize for not-previously inpainted data
all_outfiles = [outfile.replace('.uvh5', '.inpainted.uvh5')
                for outfiles in corner_turn_map['files_to_outfiles_map'].values()
                for outfile in outfiles]
for outfile in all_outfiles:
    match = re.search(r'\.(\d+)_(\d+)\.', os.path.basename(outfile))
    if match and match.group(1) == match.group(2):
        hd_autos = io.HERAData(outfile)
        autos, _, auto_nsamples = hd_autos.read(polarizations=['ee', 'nn'])
        break
else:
    # Without autos the noise variance cannot be computed; fail here rather than with a NameError later.
    raise RuntimeError(f'No autocorrelation file found among the {len(all_outfiles)} files in {CORNER_TURN_MAP_YAML}.')

## Define Plotting and Helper Functions

In [None]:
def compute_mb_fr_ranges(hd, antpair):
    '''Main-beam fringe-rate bounds for one baseline, derived from simulated beam spectra.

    Reads the m-mode power spectrum of the baseline's group from FR_SPECTRA_FILE, projects it onto
    fringe rates for the data's time grid with get_m2f_mixer, normalizes it for each simulated
    frequency, and finds where its cumulative distribution crosses FR_QUANTILE_LOW and
    FR_QUANTILE_HIGH. Those bounds are then linearly interpolated (and extrapolated) onto hd.freqs.

    Parameters
    ----------
    hd : HERAData
        Supplies .times (JD) and .freqs (Hz) of the baseline being processed.
    antpair : tuple
        (ant1, ant2). If only the reversed pair is listed in the spectra file, that entry is used
        with the m-mode indices negated.

    Returns
    -------
    tuple
        (freqs in MHz, upper main-beam FR bound in mHz, lower main-beam FR bound in mHz),
        one value per data frequency.
    '''
    with h5py.File(FR_SPECTRA_FILE, "r") as spectra_h5:
        meta = spectra_h5["metadata"]
        group_of = {}
        for group_idx, group_antpairs in meta["baseline_groups"].items():
            for ap in group_antpairs:
                group_of[tuple(ap)] = int(group_idx)
        sim_freqs_Hz = meta["frequencies_MHz"][()] * 1e6
        m_modes = meta["erh_mode_integer_index"][()]
        reversed_only = antpair not in group_of
        lookup_key = antpair[::-1] if reversed_only else antpair
        mmode_power = spectra_h5["erh_mode_power_spectrum"][:, :, group_of[lookup_key]]
        if reversed_only:
            # the reversed baseline is handled by flipping the sign of the m-mode indices
            m_modes = m_modes * -1

    # Fringe-rate grid for this data's times; times are in ks, so fringe rates come out in mHz.
    jd = hd.times
    times_ks = (jd - jd[0] + np.median(np.diff(jd))) * units.day.to(units.ks)
    frate_grid = np.fft.fftshift(np.fft.fftfreq(times_ks.size, d=np.median(np.diff(times_ks))))
    mixer = get_m2f_mixer(times_ks, m_modes)

    # FR power per simulated frequency c: sum over m of |mixer[f, m]|^2 * mmode_power[m, c]
    fr_power = np.abs(np.einsum("fm,mc,mf->fc", mixer, mmode_power, mixer.T.conj()))
    fr_power /= fr_power.sum(axis=0, keepdims=True)
    fr_cdf = np.cumsum(fr_power, axis=0)

    def _bound_at(quantile):
        # invert each simulated channel's CDF to find the fringe rate where it reaches `quantile`
        return np.array([np.interp(quantile, fr_cdf[:, c], frate_grid)
                         for c in range(len(sim_freqs_Hz))])

    spec_tops = _bound_at(FR_QUANTILE_HIGH)
    spec_bottoms = _bound_at(FR_QUANTILE_LOW)

    def _onto_data_freqs(values):
        # linear interpolation from simulated to data frequencies, extrapolating past the ends
        return interpolate.interp1d(sim_freqs_Hz, values, fill_value='extrapolate')(hd.freqs)

    return (hd.freqs / 1e6, _onto_data_freqs(spec_tops), _onto_data_freqs(spec_bottoms))


def compute_sky_fr_ranges(hd, antpair):
    '''Sky fringe-rate bounds for one baseline at every data frequency.

    The center and half-width of the sky's fringe rates come from sky_frates_single. Both edges are
    then pushed outward by the buffer returned by get_FR_buffer_from_spectra (computed from
    AUTO_FR_SPECTRUM_FILE with gauss_fit_buffer_cut=1e-5).

    Returns
    -------
    tuple
        (freqs in MHz, upper sky FR bound in mHz, lower sky FR bound in mHz).
    '''
    baseline_vector = hd.antpos[antpair[0]] - hd.antpos[antpair[1]]
    lat_rad = hd.telescope.location.lat.rad
    centers, half_widths = sky_frates_single(hd.freqs, baseline_vector, lat_rad)  # mHz
    buffer = get_FR_buffer_from_spectra(AUTO_FR_SPECTRUM_FILE, hd.times, hd.freqs,
                                        gauss_fit_buffer_cut=1e-5)
    upper = centers + half_widths + buffer
    lower = centers - half_widths - buffer
    return (hd.freqs / 1e6, upper, lower)


def overlaps_FR0(bands_bl, mb_frate_tops, mb_frate_bottoms, xtalk_fr=None):
    '''Check whether the main-beam fringe rates overlap the FR = 0 ± xtalk_fr band.

    Parameters
    ----------
    bands_bl : iterable
        Frequency-band selectors (e.g. slices) into the per-frequency FR arrays. ``None`` entries
        denote bands with no data and are skipped, as in the main processing loop.
    mb_frate_tops, mb_frate_bottoms : np.ndarray
        Upper/lower main-beam fringe-rate bounds per frequency (mHz).
    xtalk_fr : float, optional
        Half-width of the excluded band around FR = 0 (mHz). Defaults to XTALK_FR.

    Returns
    -------
    bool
        True if, in any band, the main-beam FRs are not entirely above +xtalk_fr and not
        entirely below -xtalk_fr; False otherwise.
    '''
    if xtalk_fr is None:
        xtalk_fr = XTALK_FR
    for band in bands_bl:
        # A None band would index as arr[None] (the whole array), wrongly testing every frequency.
        if band is None:
            continue
        tops = mb_frate_tops[band]
        bottoms = mb_frate_bottoms[band]
        all_above = np.all(tops > xtalk_fr) and np.all(bottoms > xtalk_fr)
        all_below = np.all(tops < -xtalk_fr) and np.all(bottoms < -xtalk_fr)
        if not (all_above or all_below):
            return True
    return False

In [None]:
def plot_fr_waterfall(snr_wf, flags_wf, taper_2d, freqs, times, title,
                      mb_frate_freqs_MHz=None, mb_frate_tops=None, mb_frate_bottoms=None,
                      sky_frate_freqs_MHz=None, sky_frate_tops=None, sky_frate_bottoms=None,
                      vmax=5):
    '''Plot frequency vs. fringe rate |SNR| by FFTing each channel along the time axis.

    Accepts pre-assembled full (ntimes, nfreqs) waterfalls and a 2D per-band taper, which is
    expected to be zero outside processed bands. Samples that are flagged, non-finite, or
    untapered are zeroed before the FFT. Each channel is normalized by
    sqrt(ntimes * mean(w**2)) over the weights w actually applied, so unit-variance white noise
    gives E|SNR|^2 = 1 in fringe-rate space. Channels with no usable samples are left NaN (blank).

    Optional sky (dotted) and main-beam (dashed) fringe-rate bounds, given as MHz / mHz arrays,
    are overplotted. Returns the figure after closing it, so it renders only when displayed.
    '''
    ntimes = len(times)
    times_in_seconds = (times - times[0]) * 24 * 3600
    frates = uvtools.utils.fourier_freqs(times_in_seconds) * 1000  # mHz

    # Only usable samples enter the FFT. Flagged samples are zeroed explicitly because they may
    # hold filter-model values rather than zeros, which the flag-aware normalization assumes.
    usable = (taper_2d > 0) & ~flags_wf & np.isfinite(snr_wf)
    weights = np.where(usable, taper_2d, 0.0)
    norm = np.sqrt(ntimes * np.mean(weights**2, axis=0))

    # FFT with per-band taper
    to_plot = np.fft.fftshift(np.fft.fft(weights * np.where(usable, snr_wf, 0), axis=0), axes=0)
    with np.errstate(divide='ignore', invalid='ignore'):
        # columns with norm == 0 have no usable data and become NaN
        to_plot = np.abs(to_plot) / norm[np.newaxis, :]

    fig = plt.figure(figsize=(14, 8), dpi=200)
    extent = [freqs[0] / 1e6, freqs[-1] / 1e6, frates[-1], frates[0]]
    im = plt.imshow(to_plot, aspect='auto', interpolation='none',
                    extent=extent, vmin=0, vmax=vmax, cmap='plasma')
    plt.colorbar(im, extend='max', label='|pI SNR|')
    plt.xlabel('Frequency (MHz)')
    plt.ylabel('Fringe Rate (mHz)')
    plt.title(title)

    if sky_frate_freqs_MHz is not None:
        plt.plot(sky_frate_freqs_MHz, sky_frate_tops, 'w:', lw=1, label='Sky FRs')
        plt.plot(sky_frate_freqs_MHz, sky_frate_bottoms, 'w:', lw=1)
        # symmetric y-range with 25% headroom beyond the largest sky fringe rate
        fr_extent = np.max([np.abs(sky_frate_tops), np.abs(sky_frate_bottoms)]) * 1.25
        plt.ylim([-fr_extent, fr_extent])
    else:
        plt.ylim([-5, 5])
    if mb_frate_freqs_MHz is not None:
        plt.plot(mb_frate_freqs_MHz, mb_frate_tops, 'w--', lw=1, label='Main Beam FRs')
        plt.plot(mb_frate_freqs_MHz, mb_frate_bottoms, 'w--', lw=1)
    if sky_frate_freqs_MHz is not None or mb_frate_freqs_MHz is not None:
        plt.legend()

    plt.tight_layout()
    plt.close(fig)
    return fig


def plot_time_freq_waterfall(snr_wf, flags_wf, freqs, times, lsts, title, vmax=5):
    '''Draw |SNR| as a frequency-vs-time image with flagged pixels blanked (NaN).

    The left axis shows JD minus the integer day of the first time. A twin right axis labels the
    same span in LST hours (lsts are converted from radians), wrapped into [0, 24) for display.
    Returns the figure after closing it, so it renders only when displayed.
    '''
    masked_amp = np.where(flags_wf, np.nan, np.abs(snr_wf))
    jd_day = int(times[0])

    fig, ax = plt.subplots(figsize=(14, 8), dpi=200)
    image = ax.imshow(masked_amp, aspect='auto', interpolation='none',
                      extent=[freqs[0] / 1e6, freqs[-1] / 1e6, times[-1] - jd_day, times[0] - jd_day],
                      vmin=0, vmax=vmax, cmap='plasma')
    plt.colorbar(image, extend='max', label='|pI SNR|', ax=ax)
    ax.set_xlabel('Frequency (MHz)')
    ax.set_ylabel(f'JD - {jd_day}')
    ax.set_title(title)

    # LST in hours; values above the final LST are shifted down by 24 h so the axis stays monotonic
    # across a wrap (assumes LSTs increase with at most one wrap — TODO confirm for long files).
    lst_hours = lsts * 12 / np.pi
    lst_hours[lst_hours > lst_hours[-1]] -= 24
    lst_ax = ax.twinx()
    lst_ax.set_ylim(lst_hours[-1], lst_hours[0])

    def _wrap_label(value, _pos):
        return f"{value % 24:.1f}"

    lst_ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(_wrap_label))
    lst_ax.set_ylabel('LST (hours)')

    plt.tight_layout()
    plt.close(fig)
    return fig


def plot_snr_histograms(snr_wf, flags_wf, title):
    '''Histogram unflagged |SNR| values and overlay the noise-only Rayleigh expectation.

    Zero and non-finite amplitudes are dropped, and bin edges are np.arange(0, 10, 0.01). The dashed
    reference 2 x exp(-x^2) is the Rayleigh PDF of |z| for complex z with E[|z|^2] = 1.
    Returns the figure after closing it, so it renders only when displayed.
    '''
    fig = plt.figure(figsize=(12, 5))
    edges = np.arange(0, 10, .01)
    amplitudes = np.abs(snr_wf[~flags_wf])
    keep = np.isfinite(amplitudes) & (amplitudes > 0)
    counts, _, _ = plt.hist(amplitudes[keep], bins=edges, density=True, label='Real-space |pI SNR|')

    rayleigh_pdf = 2 * edges * np.exp(-edges**2)
    plt.plot(edges, rayleigh_pdf, 'k--', label='Rayleigh Distribution (Noise-Only)')
    plt.yscale('log')

    # frame the y-axis around the nonzero densities (log scale cannot show zeros)
    positive = counts[counts > 0]
    if positive.size > 0:
        plt.ylim(np.min(positive) / 2, np.max(positive) * 2)

    plt.legend()
    plt.ylabel('Density')
    plt.xlabel('|pI SNR|')
    plt.xlim([-.5, 10])
    plt.title(title)
    plt.tight_layout()
    plt.close(fig)
    return fig

## Compute pI SNR, Looping Over Baselines

In [None]:
# Figures are accumulated across ALL baselines and displayed in the cells below.
delay_fr_figs = []
waterfall_figs = []
histogram_figs = []

# These depend only on the autocorrelations loaded above, so compute them once rather than per baseline.
med_auto_nsamples = {bl[2]: np.median(n) for bl, n in auto_nsamples.items()}
auto_antpair = autos.antpairs().pop()

for single_bl_file in corner_turn_map['files_to_outfiles_map'][RED_AVG_FILE]:
    # Load data (the inpainted version of each single-baseline file)
    single_bl_file = single_bl_file.replace('.uvh5', '.inpainted.uvh5')
    print(f'Now loading {single_bl_file}')
    hd = io.HERAData(single_bl_file)
    data, flags, nsamples = hd.read(polarizations=['ee', 'nn'])
    dt = np.median(np.diff(hd.times)) * 24 * 3600  # integration time in seconds
    df = np.median(np.diff(hd.freqs))  # channel width
    antpair = data.antpairs().pop()

    if antpair[0] == antpair[1]:
        print('\tThis baseline is an autocorrelation. Skipping...')
        continue

    # Need at least one polarization whose median nsamples is a non-negligible fraction of the autos'
    if not any(np.median(nsamples[bl]) > MIN_SAMP_FRAC * med_auto_nsamples[bl[2]] for bl in nsamples):
        print('\tNo polarization has enough nsamples to be worth filtering. Skipping...')
        continue

    # Get tslice and bands split around FM
    tslices_bl, bands_bl = flag_utils.get_minimal_slices(
        flags[antpair + ('ee',)] | flags[antpair + ('nn',)],
        freqs=data.freqs, freq_cuts=[(FM_LOW_FREQ + FM_HIGH_FREQ) * .5e6])

    # Compute FR ranges for this baseline
    print('\tComputing fringe rate ranges...')
    (mb_frate_freqs_MHz, mb_frate_tops, mb_frate_bottoms) = compute_mb_fr_ranges(hd, antpair)
    (sky_frate_freqs_MHz, sky_frate_tops, sky_frate_bottoms) = compute_sky_fr_ranges(hd, antpair)

    # Exclude baselines that overlap FR=0, since there's sometimes extra high delay structure there that's not relevant to this step
    if overlaps_FR0(bands_bl, mb_frate_tops, mb_frate_bottoms):
        print(f'\tThis baseline overlaps with the FR = (0 ± {XTALK_FR}) mHz band. Skipping...')
        continue

    # Full-size outputs; pixels not covered by a processed band stay flagged / NaN / untapered
    wf_shape = (len(hd.times), len(hd.freqs))
    filt_flags_full = np.ones(wf_shape, dtype=bool)
    dly_filt_SNR_full = np.full(wf_shape, np.nan, dtype=complex)
    frf_SNR_full = np.full(wf_shape, np.nan, dtype=complex)
    taper_2d = np.zeros(wf_shape)

    for tslice, band in zip(tslices_bl, bands_bl):
        if (band is None) or np.all(flags[antpair + ('ee',)][tslice, band] | flags[antpair + ('nn',)][tslice, band]):
            continue

        # Extract per-pol data for this band
        d_ee = data[antpair + ('ee',)][tslice, band]
        d_nn = data[antpair + ('nn',)][tslice, band]
        f_ee = flags[antpair + ('ee',)][tslice, band]
        f_nn = flags[antpair + ('nn',)][tslice, band]
        n_ee = nsamples[antpair + ('ee',)][tslice, band]
        n_nn = nsamples[antpair + ('nn',)][tslice, band]
        a_ee = np.abs(autos[auto_antpair + ('ee',)][tslice, band])
        a_nn = np.abs(autos[auto_antpair + ('nn',)][tslice, band])

        # Compute variance from autos (|auto|^2 / (dt * df * nsamples), summed over polarizations)
        var_pI = a_ee**2 / (dt * df) / n_ee + a_nn**2 / (dt * df) / n_nn

        # Form pseudo-Stokes pI
        d_pI, f_pI, n_pI = hp.pstokes._combine_pol_arrays(
            'ee', 'nn', 'pI', pol_convention=hd.pol_convention,
            data_list=[d_ee, d_nn], flags_list=[f_ee, f_nn],
            nsamples_list=[n_ee, n_nn],
            x_orientation=hd.telescope.get_x_orientation_from_feeds())
        d_pI[f_pI] = 0

        # Compute SNR
        SNR = d_pI / var_pI**.5

        # Delay filter: subtract a DPSS fit of delays within ±FILTER_DELAY, keeping the high-delay residual
        print(f'\tDelay filtering band {band}...')
        result, _, _ = fourier_filter(hd.freqs[band], SNR,
            wgts=np.where(f_pI, 0, 1), filter_centers=[0],
            filter_half_widths=[FILTER_DELAY * 1e-9], mode='dpss_solve',
            eigenval_cutoff=[EIGENVAL_CUTOFF], suppression_factors=[EIGENVAL_CUTOFF],
            max_contiguous_edge_flags=len(hd.freqs))
        dly_filt_SNR = SNR - result

        # Leverage correction: the residual's noise variance scales as (1 - leverage), so divide by its sqrt.
        # Uses uniform weights (an approximation that ignores f_pI), so it is computed once per band.
        X = dpss_operator(hd.freqs[band], [0],
            filter_half_widths=[FILTER_DELAY / 1e9], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
        XtX = X.T @ X
        if np.all(np.isclose(XtX.imag, 0)):
            XtX = np.real(XtX)
        # imaginary part is expected to be ~0 for filter_centers=[0] (see the XtX check above)
        leverage = np.real(np.diag(X @ np.linalg.pinv(XtX) @ X.T))
        # leverage >= 1 leaves no residual degrees of freedom; mark those channels invalid instead of
        # dividing by ~0 (or taking the sqrt of a negative number)
        correction = np.where(leverage < 1, np.sqrt(np.clip(1 - leverage, 0, None)), np.nan)
        dly_filt_SNR /= correction[np.newaxis, :]

        # Per-channel fringe rate filter (using full-resolution main beam bounds)
        print(f'\tFR filtering band {band}...')
        frf_SNR = dly_filt_SNR.copy()  # real-space normalized
        times_in_seconds = (hd.times[tslice] - hd.times[tslice][0]) * 24 * 3600
        for chan, (fr_low, fr_high) in enumerate(zip(mb_frate_bottoms[band], mb_frate_tops[band])):
            fr_center = (fr_low + fr_high) / 2 / 1000  # mHz -> Hz
            fr_halfwidth = (fr_high - fr_low) / 2 / 1000  # mHz -> Hz

            # fringe stop to the main-beam FR center, then keep only FRs within ±fr_halfwidth
            phase_ramp = np.exp(-2j * np.pi * fr_center * times_in_seconds)[:, np.newaxis]
            stopped_data = dly_filt_SNR[:, chan:chan+1] * phase_ramp
            result, _, _ = fourier_filter(times_in_seconds, stopped_data,
                wgts=np.where(f_pI[:, chan:chan+1], 0, 1), filter_centers=[0],
                filter_half_widths=[fr_halfwidth], mode='dpss_solve',
                eigenval_cutoff=[EIGENVAL_CUTOFF], suppression_factors=[EIGENVAL_CUTOFF],
                max_contiguous_edge_flags=len(data.times), filter_dims=0)

            # Temporal leverage correction: the weighted fit's noise variance is (approximately) its
            # leverage, so divide by sqrt(leverage). Built on the same relative time axis as the filter.
            Xt = dpss_operator(times_in_seconds, filter_centers=[0],
                filter_half_widths=[fr_halfwidth], eigenval_cutoff=[EIGENVAL_CUTOFF])[0]
            W = np.where(f_pI[:, chan], 0, 1)
            XtWXt = np.dot(Xt.T * W, Xt)
            if np.all(np.isclose(XtWXt.imag, 0)):
                XtWXt = np.real(XtWXt)
            lev_t = np.diag(Xt @ np.linalg.pinv(XtWXt) @ (Xt.T * W))

            frf_SNR[:, chan:chan+1] = np.where(f_pI[:, chan:chan+1], 0, result / np.where(lev_t > 0, lev_t**.5, np.nan)[:, None])

        filt_flags_full[tslice, band] = f_pI
        dly_filt_SNR_full[tslice, band] = dly_filt_SNR
        frf_SNR_full[tslice, band] = frf_SNR

        # build 2D taper for plotting (Blackman-Harris along time within this band's time slice);
        # use the sliced length so open-ended slices (start/stop of None) also work
        band_ntimes = hd.times[tslice].size
        taper_2d[tslice, band] = blackmanharris(band_ntimes)[:, np.newaxis]

    # filt_flags_full starts all-True and is only cleared by processed bands
    if np.all(filt_flags_full):
        print(f'\t{antpair} is entirely flagged.')
        continue

    # Now produce figures to display later (accumulated across baselines)

    # Figure 1: Delay-filtered SNR in FR space
    delay_fr_figs.append(plot_fr_waterfall(
        dly_filt_SNR_full, filt_flags_full, taper_2d, hd.freqs, hd.times,
        f'{antpair} Delay-Filtered pI',
        mb_frate_freqs_MHz=mb_frate_freqs_MHz, mb_frate_tops=mb_frate_tops, mb_frate_bottoms=mb_frate_bottoms,
        sky_frate_freqs_MHz=sky_frate_freqs_MHz, sky_frate_tops=sky_frate_tops, sky_frate_bottoms=sky_frate_bottoms))

    # Figure 2: FR-filtered SNR in time-freq space
    waterfall_figs.append(plot_time_freq_waterfall(
        frf_SNR_full, filt_flags_full, hd.freqs, hd.times, hd.lsts,
        f'{antpair} Delay+FR-Filtered pI'))

    # Figure 3: Histogram of real-space SNR from both bands
    histogram_figs.append(plot_snr_histograms(
        frf_SNR_full, filt_flags_full, f'{antpair} Delay+FR-Filtered pI'))

    # Store pI FRF SNR in the ee slot of the data container for output, then recast pols as pI
    data[antpair + ('ee',)] = np.where(np.isfinite(frf_SNR_full), frf_SNR_full, 0)
    flags[antpair + ('ee',)] = filt_flags_full
    hd.update(data=data, flags=flags)
    hd.select(polarizations=['ee'])
    hd.polarization_array[0] = utils.polstr2num('pI')
    outfile = single_bl_file.replace('.uvh5', SNR_SUFFIX)
    print(f'\tWriting results to {outfile}')
    hd.write_uvh5(outfile, clobber=True)

# *Figure 1: Delay-Filtered pI SNR in Fringe Rate Space*

In [None]:
# Show each figure collected in delay_fr_figs.
for figure in delay_fr_figs:
    display(figure)

# *Figure 2: Delay+FR Filtered pI SNR Waterfall*

In [None]:
# Show each figure collected in waterfall_figs.
for figure in waterfall_figs:
    display(figure)

# *Figure 3: Delay+FR Filtered pI SNR Histogram*

In [None]:
# Show each figure collected in histogram_figs.
for figure in histogram_figs:
    display(figure)

## Metadata

In [None]:
import importlib

# Print the version of each package used, for provenance. import_module avoids exec() and does not
# leak a global __version__ into the notebook namespace.
for repo in ['hera_cal', 'hera_qm', 'hera_filters', 'hera_notebook_templates', 'hera_pspec', 'pyuvdata', 'numpy']:
    print(f'{repo}: {importlib.import_module(repo).__version__}')

In [None]:
# Report total wall-clock runtime since tstart was recorded in the first cell.
elapsed_minutes = (time.time() - tstart) / 60
print(f'Finished execution in {elapsed_minutes:.2f} minutes.')