# Estimate Pspec Window Function for All Baselines

**by Kai-Feng Chen and Steven Murray**, last updated July 06, 2025

In [None]:
# Record the wall-clock time (seconds since the epoch) at which this cell ran;
# presumably compared against a later timestamp to report the total runtime.
import time
tstart = time.time()

In [None]:
import os
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import hdf5plugin  # REQUIRED to have the compression plugins available
import numpy as np
import matplotlib.pyplot as plt
import hera_pspec as hp
from IPython.display import HTML
from pathlib import Path
from hera_notebook_templates.utils import parse_band_str
import pickle
from hera_pspec.uvwindow import FTBeam
import datetime


%matplotlib inline

## Parse settings

In [None]:
# Input files
# Power-spectrum file; when RESULTS_FOLDER is None, the directory containing
# this file is used as the output folder.
TAVG_PSPEC_FILE: str = "/lustre/aoc/projects/hera/h6c-analysis/IDR2/lstbin-outputs/redavg-smoothcal-inpaint-500ns-lstcal/inpaint/single_baseline_files/baselines_merged.tavg.pspec.h5"
# Beam file handed to hp.PSpecBeamUV when the "old-style" algorithm is selected.
BEAM_FILE: str = "/lustre/aoc/projects/hera/H4C/beams/NF_HERA_Vivaldi_efield_beam_healpix_pstokes.fits"
# Fourier-transformed beam file; presumably the input of the "exact" algorithm.
FT_BEAM_FILE: str = "/lustre/aoc/projects/hera/agorce/FTBeams_Vivaldi/FT_beam_HERA_Vivaldi_pI.hdf5"

# Output Files
# Folder that receives the window-function products (None -> see above).
RESULTS_FOLDER: str | None = None

# Which window-function estimator to run.
# NOTE(review): the cells visible here branch only on "old-style" and "exact";
# confirm whether "with-inpainting" is handled elsewhere.
WINDOW_FUNCTION_ALGO: str = "exact"  # "old-style" or "exact" or "with-inpainting"

# Comma-separated band numbers to process.
BANDS_TO_USE: str = "1,2,3,5,6,9,10,13"
# Parsed into a list of ints. Bands are 1-indexed: band b is later mapped to
# spectral-window index b - 1.
BANDS_TO_USE = [int(band) for band in BANDS_TO_USE.split(",")] # 1 indexed

In [None]:
# Normalise the configured file locations to Path objects so that `.parent`,
# `.name` and `.with_suffix` can be used on them below.
TAVG_PSPEC_FILE = Path(TAVG_PSPEC_FILE)
# Other cells refer to the power-spectrum file as SINGLE_BL_PSPEC_FILE, but the
# settings only define TAVG_PSPEC_FILE; bind both names to the same file so
# those references resolve instead of raising NameError.
SINGLE_BL_PSPEC_FILE = TAVG_PSPEC_FILE
FT_BEAM_FILE = Path(FT_BEAM_FILE)

# Default the output folder to the directory holding the power-spectrum file.
if RESULTS_FOLDER is None:
    RESULTS_FOLDER = TAVG_PSPEC_FILE.parent
else:
    RESULTS_FOLDER = Path(RESULTS_FOLDER)

In [None]:
# Open the power-spectrum container in read mode (keep_open=False, so
# presumably the file handle is not held between accesses) and fetch the
# 'time_and_interleave_averaged' spectra stored under 'stokespol'.
# `uvp` is the object every later cell reads spectral windows, frequencies
# and delays from.
psc = hp.container.PSpecContainer(SINGLE_BL_PSPEC_FILE, mode='r', keep_open=False)
uvp = psc.get_pspec('stokespol', 'time_and_interleave_averaged')

## Old-style analysis

In [None]:
from scipy import special, integrate
import uvtools.dspec as dspec
from hera_pspec import uvwindow, conversions
from astropy import constants
import copy

class SimplePspec:
    """Stand-alone helper that builds delay-spectrum normalisation matrices.

    For one spectral window it stores the frequency/delay grids, a frequency
    taper and a beam-derived normalisation matrix (``qnorm_exact``), and from
    those computes the ``G``/``H`` matrices (:meth:`get_GH`) and the
    normalisation / window-function matrices ``M``/``W`` (:meth:`get_MW`).

    Parameters
    ----------
    freqs : numpy.ndarray
        1-D array of channel frequencies (assumed to be in Hz, since the
        derived delays are treated as seconds -- TODO confirm). The number of
        delay modes is set equal to the number of channels.
    beamfunc : hp.pspecbeam.PSpecBeamUV or tuple
        Either a beam object, whose ``beam_normalized_response(pol='pI',
        freq=freqs)`` is evaluated here, or the pre-computed 3-tuple
        ``(beam, beam_omega, N)`` that call returns. ``N`` enters only through
        the factor ``pi / (3 N^2)`` (presumably the HEALPix pixel area for
        ``nside = N`` -- TODO confirm).
    beamtype : str, optional
        Accepted for interface compatibility; not used by this class.
    cosmo : object, optional
        Cosmology providing ``f2z`` and ``X2Y``. Defaults to
        ``conversions.Cosmo_Conversions()``.
    little_h : bool, optional
        Forwarded to ``cosmo.X2Y``.
    vis_unit : str, optional
        When ``'Jy'``, an additional per-channel conversion factor
        ``self.Jy2mK`` is stored. The attribute is not created otherwise.
    taper : str, optional
        Name of the taper passed to ``dspec.gen_window``; ``'none'`` gives a
        flat (all-ones) taper.

    Raises
    ------
    ValueError
        If ``beamfunc`` is neither a ``PSpecBeamUV`` nor a tuple.
    AssertionError
        If ``beamfunc`` is a tuple whose length is not 3.
    """

    def __init__(self, freqs, beamfunc, beamtype="pspec_beam",
                 cosmo=None, little_h=True,
                 vis_unit='mK', taper='blackman-harris'):
        self.spw_Nfreqs = freqs.size
        self.spw_Ndlys = freqs.size
        self.freqs = freqs
        # Channel width; reused below for the cosmological integration range.
        df = np.median(np.diff(self.freqs))
        self.delays = np.fft.fftshift(np.fft.fftfreq(self.spw_Ndlys, d=df))  # in sec

        self.taper = taper
        if self.taper == 'none':
            self.taper_func = np.ones(self.spw_Nfreqs)
        else:
            self.taper_func = dspec.gen_window(self.taper, self.spw_Nfreqs)

        # isinstance (rather than exact type equality) so that subclasses of
        # the beam class or of tuple are accepted as well.
        if isinstance(beamfunc, hp.pspecbeam.PSpecBeamUV):
            _beam, beam_omega, N = beamfunc.beam_normalized_response(
                pol='pI', freq=self.freqs
            )
        elif isinstance(beamfunc, tuple):
            assert len(beamfunc) == 3, "Invalid beam function"
            _beam, beam_omega, N = beamfunc
        else:
            raise ValueError("Invalid beam function")

        self.omega_p = beam_omega.real
        self.omega_pp = np.sum(_beam**2, axis=-1).real * np.pi / (3. * N * N)
        # Normalise each frequency row of the beam by its integral.
        _beam = _beam / self.omega_p[:, None]

        # Frequency-frequency beam overlap, scaled by the delay resolution.
        self.qnorm_exact = np.pi / (3. * N * N) * np.dot(_beam, _beam.T)
        self.qnorm_exact *= np.median(np.diff(self.delays))

        self.cosmo = cosmo if cosmo is not None else conversions.Cosmo_Conversions()

        # Band-average of the cosmological X2Y factor, evaluated on a fine
        # frequency grid spanning the spectral window.
        integration_freqs = np.linspace(self.freqs.min(),
                                        self.freqs.min() + df * self.spw_Nfreqs,
                                        5000, endpoint=True, dtype=float)
        redshifts = self.cosmo.f2z(integration_freqs).flatten()
        X2Y = np.array([self.cosmo.X2Y(z, little_h=little_h) for z in redshifts])
        self.scalar = (
            integrate.trapezoid(X2Y, x=integration_freqs)
            / np.abs(integration_freqs[-1] - integration_freqs[0])
        )

        if vis_unit == 'Jy':
            # Physical constants are taken in cgs units.
            c = constants.c.cgs.value
            k_b = constants.k_B.cgs.value
            self.Jy2mK = 1e3 * 1e-23 * c**2 / (2 * k_b * self.freqs**2 * self.omega_p)

    def get_R(self):
        """Return the weighting matrix: the taper placed on a diagonal."""
        return np.diag(self.taper_func)

    def get_Q_alt(self, mode):
        """Return the rank-one matrix ``conj(m) m^T`` for one delay mode.

        ``m`` is the complex exponential
        ``exp(-2j*pi*(mode - Ndlys//2) * n / Ndlys)`` over the channel index
        ``n``; the offset centres the delay modes to match the fft-shifted
        ``self.delays`` ordering.

        Parameters
        ----------
        mode : int
            Index of the delay mode, ``0 <= mode < spw_Ndlys``.

        Returns
        -------
        numpy.ndarray
            Complex array of shape ``(spw_Nfreqs, spw_Nfreqs)``.
        """
        # -(N // 2) equals -N/2 for even N and -(N - 1)/2 for odd N.
        start_idx = -(self.spw_Ndlys // 2)
        phase = (start_idx + mode) * np.arange(self.spw_Nfreqs)
        m = np.exp(-2j * np.pi * phase / self.spw_Ndlys)
        return np.outer(m.conj(), m)

    def get_GH(self, operator=None):
        """Compute, cache and return the ``G`` and ``H`` matrices.

        ``G[i, j] = tr(A_i B_j) / 2`` with ``A_i = [O^H] R^H Q_i`` and
        ``B_j = R [O] (Q_j * qnorm_exact)``; ``H`` is the same except that
        ``B_j`` is additionally multiplied element-wise by
        ``sinc((a - b) / Nfreqs)``. ``O`` is the optional ``operator``.
        A matrix that comes out identically zero is replaced by the identity
        before the factor 1/2 is applied.

        Parameters
        ----------
        operator : numpy.ndarray, optional
            Extra ``(Nfreqs, Nfreqs)`` matrix applied together with ``R``.

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            ``(G, H)``, each of shape ``(spw_Ndlys, spw_Ndlys)``. The same
            values are stored on ``self.G`` and ``self.H``.
        """
        n_dlys, n_freqs = self.spw_Ndlys, self.spw_Nfreqs
        R = self.get_R()

        # sinc((a - b) / Nfreqs) for every pair of channels.
        chan = np.arange(n_freqs)
        sinc_matrix = np.sinc(np.subtract.outer(chan, chan) / float(n_freqs))

        # The left/right weighting factors do not depend on the delay mode.
        left_op = np.conj(R).T
        right_op = R
        if operator is not None:
            left_op = np.conj(operator).T @ left_op
            right_op = right_op @ operator

        q_alts = [self.get_Q_alt(ch) for ch in range(n_dlys)]
        left = np.stack([left_op @ q for q in q_alts])
        right = np.stack([right_op @ (q * self.qnorm_exact) for q in q_alts])
        right_win = np.stack(
            [right_op @ (q * self.qnorm_exact * sinc_matrix) for q in q_alts]
        )

        # tr(A_i B_j) = sum_ab A_i[a, b] * B_j[b, a], for all (i, j) at once.
        G = np.tensordot(left, right, axes=([1, 2], [2, 1]))
        H = np.tensordot(left, right_win, axes=([1, 2], [2, 1]))

        if np.count_nonzero(G) == 0:
            G = np.eye(n_dlys)
        if np.count_nonzero(H) == 0:
            H = np.eye(n_dlys)
        self.G = G / 2.
        self.H = H / 2.
        # Hand back copies so callers cannot modify the cached matrices.
        return self.G.copy(), self.H.copy()

    def get_MW(self, GH=None, operator=None):
        """Return the normalisation matrix ``M`` and window matrix ``W``.

        ``M`` is diagonal with the inverse row sums of ``G``; ``W`` is ``H``
        with every row divided by its sum, so each row of ``W`` sums to one.

        Parameters
        ----------
        GH : tuple, optional
            Pre-computed ``(G, H)``. When given, it is used as is.
        operator : numpy.ndarray, optional
            Forwarded to :meth:`get_GH`. When an operator is supplied the
            matrices are always recomputed, so a cached ``self.G``/``self.H``
            built for a different operator is never reused by mistake.

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            ``(M, W)``, each of shape ``(spw_Ndlys, spw_Ndlys)``.
        """
        if GH is not None:
            G, H = GH
        elif operator is None and hasattr(self, 'G'):
            G, H = self.G, self.H
        else:
            G, H = self.get_GH(operator)
        M = np.diag(1. / np.sum(G, axis=1))
        W_norm = np.diag(1. / np.sum(H, axis=1))
        W = np.dot(W_norm, H)
        return M, W
    

In [None]:
def get_window_functions_old_style():
    """Compute the delay-space window-function matrix of every spectral window.

    Reads two module-level objects: ``uvp`` (spectral-window list, frequency
    array and delays) and ``pre_calc_beam`` (beam responses, indexed by the
    spectral-window value). For each spectral window a ``SimplePspec`` is
    built with a Blackman-Harris taper and its ``W`` matrix is taken from
    ``get_MW()``.

    Returns
    -------
    dict
        Maps each entry of ``uvp.spw_array`` to the real part of ``W`` with
        singleton leading and trailing axes added, i.e. an array of shape
        ``(1, Ndlys, Ndlys, 1)``.

    Raises
    ------
    AssertionError
        If the delays derived from the frequencies differ from
        ``uvp.get_dlys(spw)``.
    """
    window_function_array = {}
    for spw in uvp.spw_array:
        freqs = uvp.freq_array[uvp.spw_freq_array == spw]
        ps_obj = SimplePspec(
            freqs=freqs,
            beamfunc=pre_calc_beam[spw], beamtype="pspec_beam",
            cosmo=None, little_h=True,
            vis_unit='Jy', taper='blackman-harris'
        )
        # Sanity check: the delay grid built here must match the one stored
        # with the power spectra, otherwise W would be on the wrong axis.
        assert np.allclose(uvp.get_dlys(spw), ps_obj.delays), \
            f"Delay grid mismatch for spw {spw}"
        _M, _W = ps_obj.get_MW()
        window_function_array[spw] = _W.real[None, :, :, None]
    return window_function_array


First, construct a beam object:

In [None]:
if WINDOW_FUNCTION_ALGO == 'old-style':
    # Load the beam model once from BEAM_FILE.
    beam_hera = hp.PSpecBeamUV(BEAM_FILE)
    # Cache the normalised pI beam response at each spectral window's
    # frequencies. Entries are appended in uvp.spw_array order, while
    # get_window_functions_old_style() looks them up as pre_calc_beam[spw];
    # this assumes the spw values are 0, 1, 2, ... in order -- TODO confirm.
    pre_calc_beam = []
    for spw in uvp.spw_array:
        beam_tuple = beam_hera.beam_normalized_response(pol='pI', freq=uvp.freq_array[uvp.spw_freq_array == spw])
        pre_calc_beam.append(beam_tuple)

In [None]:
%%time
if WINDOW_FUNCTION_ALGO == "old-style":
    # Window functions for every spectral window, keyed by spw.
    window_function_array = get_window_functions_old_style()

    # Pickle the dict into RESULTS_FOLDER. The output file name is the
    # power-spectrum file name with its last suffix replaced by ".window.pkl".
    with open(RESULTS_FOLDER / SINGLE_BL_PSPEC_FILE.with_suffix(".window.pkl").name, 'wb') as fl:
        pickle.dump(window_function_array, fl)

In [None]:
if WINDOW_FUNCTION_ALGO == "old-style":
    # One figure per spectral window, showing a subset of window-function rows.
    for spw in uvp.spw_array:
        _win = window_function_array[spw]
        _dly = uvp.get_dlys(spw)
        for i in range(_dly.size):
            # Keep the plot readable: draw only every third row.
            if i % 3 != 0:
                continue
            # Delays are converted from seconds to nanoseconds for the x axis.
            plt.plot(_dly*1e9, _win[0, i, :, 0], ls='-', zorder=20)
        plt.xlim(0, 3500)
        plt.ylim(0, 0.5)
        plt.xlabel("Delays [ns]", fontsize=24)
        plt.ylabel("Window Functions", fontsize=24)
        plt.show()

## Exact Window Functions

First, load the Fourier-transformed beam needed by the exact window-function calculation. The power spectra for all baselines were already read into `uvp` above:

In [None]:
if WINDOW_FUNCTION_ALGO == 'exact':
    # Load the Fourier-transformed beam from the configured FT_BEAM_FILE.
    # (The name FT_BEAM_PATH used here before is not defined anywhere in the
    # settings; FT_BEAM_FILE already points at FT_beam_HERA_Vivaldi_pI.hdf5.)
    # Path() keeps the argument a Path whether or not it was converted earlier.
    ftbeam = FTBeam.from_file(Path(FT_BEAM_FILE))

In [None]:
%%time

if WINDOW_FUNCTION_ALGO == 'exact':
    # Compute the exact window functions one band at a time and store the
    # results as .npy files in RESULTS_FOLDER. Bands whose window-function
    # file already exists are skipped, so an interrupted run can be resumed.
    for band in BANDS_TO_USE:
        # Bands are 1-indexed, spectral windows are 0-indexed.
        spw2select = band - 1
        wf_file = RESULTS_FOLDER / f"exact_window-spw{spw2select:02d}-wf.npy"
        if wf_file.exists():
            print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}: Band-{band} exists, skipping....")
            continue

        # Down-select one SPW (pI polarisation pair only) into a new object.
        thisuvp = uvp.select(polpairs=["pI"], spws=[spw2select], inplace=False)

        print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}: Band-{band} starts...")

        kperpbins_spws, kparabins_spws, wf_full_spws = thisuvp.get_exact_window_functions(ftbeam=ftbeam, verbose=True, inplace=False)

        # Only one spectral window was selected, hence element [0] of each
        # result. The k-bin files are written first and the window-function
        # file last: its existence is what marks the band as finished above,
        # so it must not appear before the other products are on disk.
        np.save(RESULTS_FOLDER / f"exact_window-spw{spw2select:02d}-kperp.npy", kperpbins_spws[0])
        np.save(RESULTS_FOLDER / f"exact_window-spw{spw2select:02d}-kpara.npy", kparabins_spws[0])
        np.save(wf_file, wf_full_spws[0])
        print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}: Band-{band} done...")