In [None]:
import numpy as np
import matplotlib.pyplot as plt
from hera_cal import frf
import glob
import os
from copy import deepcopy
from hera_cal import redcal
from IPython.display import display, HTML
from hera_cal.io import HERAData
from matplotlib.colors import LogNorm
from hera_cal import utils
%config Completer.use_jedi = False
from scipy.interpolate import interp1d

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Use environment variables to figure out path to data and plotting options.
JD = os.environ['JULIANDATE']
data_path = os.environ['DATA_PATH']
label = os.environ['LABEL']
nreds = int(os.environ['NREDS'])
max_bls_per_redgrp = int(os.environ['MAX_BLS_PER_REDGRP'])
nskip = int(os.environ['NSKIP'])
ext = os.environ['EXT']

# SPWS is a comma-separated list of "low~high" channel ranges. Each becomes an
# (int, int) tuple; the high bound is exclusive (it is used as an arange/slice
# end in the plotting functions), so low must be strictly below high.
spws = []
for spw_str in os.environ['SPWS'].split(','):
    bounds = spw_str.split('~')
    if len(bounds) != 2:
        raise ValueError(f'Malformed SPWS entry {spw_str!r}; expected "low~high".')
    spw_lo, spw_hi = int(bounds[0]), int(bounds[1])
    if not spw_lo < spw_hi:
        raise ValueError(f'SPWS entry {spw_str!r} must satisfy low < high.')
    spws.append((spw_lo, spw_hi))

print(f'JD = "{JD}"')
print(f'data_path = "{data_path}"')
print(f'label = "{label}"')
print(f'ext = "{ext}"')
print(f'nreds = "{nreds}"')
print(f'max_bls_per_redgrp = "{max_bls_per_redgrp}"')
print(f'nskip = "{nskip}"')
print(f'spws = "{spws}"')


In [None]:
from astropy.time import Time

# Convert the Julian date (read from the environment) into a calendar date
# purely for the reader's reference.
obs_time = Time(JD, format='jd')
utc = obs_time.datetime
print(f'Date: {utc.month}-{utc.day}-{utc.year}')

In [None]:
def _find_uvh5(kind, stage):
    """Return the sorted ``kind`` ('sum' or 'diff') files at processing ``stage`` for this JD.

    ``stage`` is the file-name tag (e.g. 'xtalk_filtered'); it is printed with
    dashes instead of underscores. Prints the search and the number of matches.
    """
    print(f"Looking for {kind} {stage.replace('_', '-')} data in", data_path, 'on JD', JD)
    pattern = os.path.join(data_path, f'zen.{JD}.?????.{kind}.{label}.{ext}.{stage}.tavg.uvh5')
    files = sorted(glob.glob(pattern))
    print('Found {} files.'.format(len(files)))
    return files


xtalk_filtered_sums = _find_uvh5('sum', 'xtalk_filtered')
xtalk_filtered_diffs = _find_uvh5('diff', 'xtalk_filtered')
time_inpainted_sums = _find_uvh5('sum', 'time_inpainted')
time_inpainted_diffs = _find_uvh5('diff', 'time_inpainted')

# The rest of the notebook requires the cross-talk-filtered sums (the
# time-inpainted products are optional and guarded), so fail here with a clear
# message rather than with an IndexError on xtalk_filtered_sums[0].
if not xtalk_filtered_sums:
    raise FileNotFoundError(f'No xtalk-filtered sum files found in {data_path} for JD {JD}.')

Examine waterfalls and fringe-rate (FR) plots for several redundant baseline groups.

In [None]:
hd = HERAData(xtalk_filtered_sums[0])
antpairs_data = hd.get_antpairs()
# Redundant groups from the antenna positions alone; they can contain
# baselines that are not in the data files, so they are pruned below.
reds = redcal.get_pos_reds(hd.antpos)
# Set for O(1) membership tests (antpairs_data stays a list for other users).
_antpairs_present = set(antpairs_data)
# Keep baselines present in either orientation, drop groups that end up empty,
# and order groups from most to least redundant (sort is stable).
reds = [[bl for bl in grp if bl in _antpairs_present or bl[::-1] in _antpairs_present] for grp in reds]
reds = sorted((grp for grp in reds if grp), key=len, reverse=True)

In [None]:
# Load the cross-talk-filtered sums, redundantly average them (one key per
# group: its first baseline), then precompute per-spectral-window transforms
# of the averaged data.
frf_xtalk = frf.FRFilter(xtalk_filtered_sums)
frf_xtalk.read(axis='blt')
hd_xtalkr = utils.red_average(frf_xtalk.hd, inplace=False, reds=reds,
                              red_bl_keys=[grp[0] for grp in reds])
frf_xtalkr = frf.FRFilter(hd_xtalkr)
for spw_index, (chan_lo, chan_hi) in enumerate(spws):
    n_cut_hi = frf_xtalkr.Nfreqs - chan_hi
    # Transform over both axes (time, frequency); edge cuts apply to frequency only.
    frf_xtalkr.fft_data(window='bh', ax='both', assign=f'dfft2_spw_{spw_index}',
                        verbose=False, overwrite=True,
                        edgecut_low=(0, chan_lo), edgecut_hi=(0, n_cut_hi))
    # Transform over frequency only.
    frf_xtalkr.fft_data(window='bh', ax='freq', assign=f'dfft_spw_{spw_index}',
                        verbose=False, overwrite=True,
                        edgecut_low=chan_lo, edgecut_hi=n_cut_hi)


In [None]:
# Time-inpainted products are optional: when present, load and redundantly
# average them and precompute the same per-spectral-window transforms.
if time_inpainted_sums:
    frf_inpaint = frf.FRFilter(time_inpainted_sums)
    frf_inpaint.read(axis='blt')
    hd_inpaintr = utils.red_average(frf_inpaint.hd, inplace=False, reds=reds,
                                    red_bl_keys=[grp[0] for grp in reds])
    frf_inpaintr = frf.FRFilter(hd_inpaintr)
    for spw_index, (chan_lo, chan_hi) in enumerate(spws):
        n_cut_hi = frf_inpaintr.Nfreqs - chan_hi
        # Transform over both axes (time, frequency); edge cuts apply to frequency only.
        frf_inpaintr.fft_data(window='bh', ax='both', assign=f'dfft2_spw_{spw_index}',
                              verbose=False, overwrite=True,
                              edgecut_low=(0, chan_lo), edgecut_hi=(0, n_cut_hi))
        # Transform over frequency only.
        frf_inpaintr.fft_data(window='bh', ax='freq', assign=f'dfft_spw_{spw_index}',
                              verbose=False, overwrite=True,
                              edgecut_low=chan_lo, edgecut_hi=n_cut_hi)


In [None]:
def delay_plots(frft, frft_red, spw_num):
    """Plot fringe-rate/delay and LST/delay waterfalls for several redundant groups.

    For each of the first ``nreds`` groups in ``reds[::nskip]`` one figure is
    shown with two rows (ee, nn). The left half of the columns shows
    |FFT over time and frequency| (fringe rate vs. delay) and the right half
    shows |FFT over frequency| (LST vs. delay). Within each half, the first
    ``min(len(grp), max_bls_per_redgrp)`` columns are individual baselines from
    ``frft`` and the last column is the redundantly averaged baseline from
    ``frft_red``. Dashed white lines mark +/- the group's horizon delay.

    Parameters
    ----------
    frft : frf.FRFilter
        Per-baseline data. Its ``dfft2_spw_{spw_num}`` / ``dfft_spw_{spw_num}``
        attributes are overwritten here, one baseline at a time.
    frft_red : frf.FRFilter
        Redundantly averaged data; its ``dfft2_spw_{spw_num}`` and
        ``dfft_spw_{spw_num}`` attributes must already have been computed.
    spw_num : int
        Index into the module-level ``spws`` list of (low, high) channel bounds.

    Notes
    -----
    Reads module-level globals ``spws``, ``reds``, ``nskip``, ``nreds`` and
    ``max_bls_per_redgrp``.
    """
    spw = spws[spw_num]
    # Channels cut from the top of the band, derived from the dataset being
    # plotted rather than from another cell's FRFilter object.
    edge_hi = len(frft.freqs) - spw[1]
    ref_key = reds[0][0] + ('nn',)
    # Not plotted directly: transforming one baseline for this spw sets up the
    # transform axes of ``frft`` (presumably ``frft.delays`` / ``frft.frates``),
    # which are read below for the image extents. Do not remove.
    frft.fft_data(window='bh', ax='both', assign=f'dfft2_{spw_num}', keys=[ref_key], overwrite=True,
                  edgecut_low=(0, spw[0]), edgecut_hi=(0, edge_hi))
    df = np.mean(np.diff(frft.freqs))
    dt = np.mean(np.diff(frft.times * 3600 * 24))
    # Colour limits: peak of the averaged reference baseline rounded to the
    # nearest power of ten, spanning five decades.
    cmax_frate = 10 ** np.round(np.log10(np.abs(getattr(frft_red, f'dfft2_spw_{spw_num}')[ref_key] * dt * df).max()))
    cmin_frate = cmax_frate / 1e5
    cmax_delay = 10 ** np.round(np.log10(np.abs(getattr(frft_red, f'dfft_spw_{spw_num}')[ref_key] * df).max()))
    cmin_delay = cmax_delay / 1e5

    # Loop-invariant quantities: they depend on the spw, not on the group.
    ext_frate = [frft.delays.min(), frft.delays.max(), frft.frates.max(), frft.frates.min()]
    ext_tdelay = [frft.delays.min(), frft.delays.max(), frft.times.max(), frft.times.min()]
    # Maps time (same units as frft.times) to LST in hours for y-axis labels.
    lst_func = interp1d(frft.times, frft.lsts * 12 / np.pi)
    band_label = f'{frft.freqs[spw[0]] / 1e6:.1f} - {frft.freqs[spw[1] - 1] / 1e6:.1f} '
    red_keys = set(frft_red.data.keys())

    for grp in reds[::nskip][:nreds]:
        fig, axarr = plt.subplots(2, 2 * min(len(grp) + 1, max_bls_per_redgrp + 1))
        # Columns per half: up to max_bls_per_redgrp baselines, then the average.
        nbls = (len(axarr[0]) - 1) // 2
        fig.set_size_inches(32, 8)
        cbax1 = fig.add_axes([0.105, 0.35, 0.005, 0.3])
        cbax2 = fig.add_axes([0.915, 0.35, 0.005, 0.3])
        # Horizon delay [ns] and baseline vector oriented as grp[0], whichever
        # orientation frft stores.
        if grp[0] in frft.bllens:
            hrzn_dly = frft.bllens[grp[0]] * 1e9
            blvec = frft.blvecs[grp[0]]
        else:
            hrzn_dly = frft.bllens[grp[0][::-1]] * 1e9
            blvec = -frft.blvecs[grp[0][::-1]]

        # Left half: |FFT over time and frequency| (fringe rate vs. delay).
        for pn, pol in enumerate(['ee', 'nn']):
            for blnum in range(nbls + 1):
                plt.sca(axarr[pn][blnum])
                if blnum < nbls:
                    blk = grp[blnum] + (pol,)
                    frft.fft_data(window='bh', ax='both', assign=f'dfft2_spw_{spw_num}', keys=[blk], overwrite=True,
                                  edgecut_low=(0, spw[0]), edgecut_hi=(0, edge_hi))
                    cm = plt.imshow(np.abs(getattr(frft, f'dfft2_spw_{spw_num}')[blk] * df * dt), norm=LogNorm(cmin_frate, cmax_frate), extent=ext_frate, aspect='auto', interpolation='nearest', cmap='inferno')
                    plt.title(f'{blk} \n{band_label}')
                else:
                    blk = grp[0] + (pol,)
                    d = getattr(frft_red, f'dfft2_spw_{spw_num}')[blk] * df * dt
                    if blk not in red_keys:
                        # Only the reversed baseline is stored: mirror both
                        # transform axes to show grp[0]'s orientation.
                        d = np.conj(d[::-1, ::-1])
                    cm = plt.imshow(np.abs(d), norm=LogNorm(cmin_frate, cmax_frate), extent=ext_frate, aspect='auto', interpolation='nearest', cmap='inferno')
                    plt.title(f'{blvec[0]:.1f} m, {blvec[1]:.1f} m, {pol}\n{band_label}')
                plt.xlim(-1000, 1000)
                plt.ylim(-1.5, 1.5)
                plt.axvline(hrzn_dly, ls='--', color='w', lw=1)
                plt.axvline(-hrzn_dly, ls='--', color='w', lw=1)
                if pn == 0:
                    cbar = fig.colorbar(cm, orientation='vertical', cax=cbax1)
                    cbax1.yaxis.set_ticks_position('left')
                    plt.gca().set_xticklabels(['' for tick in plt.gca().get_xticklabels()])
                    cbar.ax.set_ylabel('Abs($\\widetilde{V}_{\\tau, f_r}$) [Jy]', rotation=90)
                else:
                    plt.gca().set_xlabel('$\\tau$ [ns]')
                if blnum > 0:
                    plt.gca().set_yticklabels(['' for tick in plt.gca().get_yticklabels()])
                else:
                    plt.gca().set_ylabel('$f_r$ [mHz]')

        # Right half: |FFT over frequency| (LST vs. delay).
        for pn, pol in enumerate(['ee', 'nn']):
            for blnum in range(nbls + 1):
                plt.sca(axarr[pn][blnum + nbls + 1])
                if blnum < nbls:
                    blk = grp[blnum] + (pol,)
                    frft.fft_data(window='bh', ax='freq', assign=f'dfft_spw_{spw_num}', keys=[blk], overwrite=True,
                                  edgecut_low=spw[0], edgecut_hi=edge_hi)
                    cm = plt.imshow(np.abs(getattr(frft, f'dfft_spw_{spw_num}')[blk] * df), norm=LogNorm(cmin_delay, cmax_delay), extent=ext_tdelay, aspect='auto', interpolation='nearest', cmap='inferno')
                    plt.title(f'{blk}')
                else:
                    blk = grp[0] + (pol,)
                    d = getattr(frft_red, f'dfft_spw_{spw_num}')[blk] * df
                    if blk not in red_keys:
                        # Only the reversed baseline is stored: mirror the delay axis.
                        d = np.conj(d[:, ::-1])
                    cm = plt.imshow(np.abs(d), norm=LogNorm(cmin_delay, cmax_delay), extent=ext_tdelay, aspect='auto', interpolation='nearest', cmap='inferno')
                    plt.title(f'{blvec[0]:.1f} m, {blvec[1]:.1f} m, {pol}')
                plt.xlim(-1000, 1000)
                plt.axvline(hrzn_dly, ls='--', color='w', lw=1)
                plt.axvline(-hrzn_dly, ls='--', color='w', lw=1)
                # Keep only ticks inside the observed time range; interp1d
                # raises for points outside it.
                plt.gca().set_yticks([t for t in plt.gca().get_yticks() if t >= ext_tdelay[-1] and t <= ext_tdelay[-2]])
                if pn == 0:
                    plt.gca().set_xticklabels(['' for tick in plt.gca().get_xticklabels()])
                else:
                    plt.gca().set_xlabel('$\\tau$ [ns]')
                if blnum < nbls:
                    plt.gca().set_yticklabels(['' for tick in plt.gca().get_yticklabels()])
                else:
                    plt.gca().set_ylabel('LST [Hrs]')
                    plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])
                    cbar = fig.colorbar(cm, orientation='vertical', cax=cbax2)
                    cbar.ax.set_ylabel('Abs($\\widetilde{V}$) [Jy Hz]', rotation=90)

                plt.gca().yaxis.tick_right()
                plt.gca().yaxis.set_label_position("right")

        plt.show()


In [None]:
def freq_plots(frft, frft_red, spw_num):
    """Plot amplitude and phase waterfalls (LST vs. frequency) for several redundant groups.

    For each of the first ``nreds`` groups in ``reds[::nskip]`` one figure is
    shown with two rows (ee, nn). The left half of the columns shows |V| on a
    log colour scale and the right half shows arg(V). Within each half, the
    first ``min(len(grp), max_bls_per_redgrp)`` columns are individual baselines
    from ``frft`` and the last column is the redundantly averaged baseline from
    ``frft_red``. Flagged samples are masked (left blank).

    Parameters
    ----------
    frft : frf.FRFilter
        Per-baseline data; its ``data`` and ``flags`` containers are read.
    frft_red : frf.FRFilter
        Redundantly averaged data, presumably keyed by the first baseline of
        each group; its ``data`` and ``flags`` containers are read.
    spw_num : int
        Index into the module-level ``spws`` list; channels
        ``spws[spw_num][0]`` up to (excluding) ``spws[spw_num][1]`` are shown.

    Notes
    -----
    Reads module-level globals ``spws``, ``reds``, ``nskip``, ``nreds`` and
    ``max_bls_per_redgrp``.
    """
    ref_key = reds[0][0] + ('nn',)
    # Amplitude colour limits: averaged reference-baseline peak rounded to the
    # nearest power of ten, spanning five decades.
    cmax_freq = 10 ** np.round(np.log10(np.abs(frft_red.data[ref_key]).max()))
    cmin_freq = cmax_freq / 1e5
    spw_inds = np.arange(spws[spw_num][0], spws[spw_num][1]).astype(int)
    # Loop-invariant quantities: frequency [MHz] on x, time on y (earliest at top).
    ext_freq = [frft.freqs[spw_inds].min() / 1e6, frft.freqs[spw_inds].max() / 1e6,
                frft.times.max(), frft.times.min()]
    # Maps time (same units as frft.times) to LST in hours for y-axis labels.
    lst_func = interp1d(frft.times, frft.lsts * 12 / np.pi)

    def _masked(values, flags):
        # imshow leaves masked pixels blank; unlike dividing by ``~flags`` this
        # does not emit divide-by-zero / invalid-value warnings.
        return np.ma.masked_array(values, mask=flags)

    for grp in reds[::nskip][:nreds]:
        fig, axarr = plt.subplots(2, 2 * min(len(grp) + 1, max_bls_per_redgrp + 1))
        cbax1 = fig.add_axes([0.105, 0.35, 0.005, 0.3])
        cbax2 = fig.add_axes([0.915, 0.35, 0.005, 0.3])
        # Columns per half: up to max_bls_per_redgrp baselines, then the average.
        nbls = (len(axarr[0]) - 1) // 2
        fig.set_size_inches(32, 8)
        # Baseline vector oriented as grp[0], whichever orientation frft stores.
        if grp[0] in frft.bllens:
            blvec = frft.blvecs[grp[0]]
        else:
            blvec = -frft.blvecs[grp[0][::-1]]

        # Left half: amplitude waterfalls.
        for pn, pol in enumerate(['ee', 'nn']):
            for blnum in range(nbls + 1):
                plt.sca(axarr[pn][blnum])
                if blnum < nbls:
                    blk = grp[blnum] + (pol,)
                    amp = _masked(np.abs(frft.data[blk][:, spw_inds]), frft.flags[blk][:, spw_inds])
                    cm = plt.imshow(amp, norm=LogNorm(cmin_freq, cmax_freq), extent=ext_freq, aspect='auto', interpolation='nearest', cmap='inferno')
                    plt.title(f'{blk}')
                else:
                    blk = grp[0] + (pol,)
                    # Mask with the averaged data's own flags.
                    amp = _masked(np.abs(frft_red.data[blk][:, spw_inds]), frft_red.flags[blk][:, spw_inds])
                    cm = plt.imshow(amp, norm=LogNorm(cmin_freq, cmax_freq), extent=ext_freq, aspect='auto', interpolation='nearest', cmap='inferno')
                    plt.title(f'{blvec[0]:.1f} m, {blvec[1]:.1f} m, {pol}')
                # Keep only ticks inside the observed time range; interp1d
                # raises for points outside it.
                plt.gca().set_yticks([t for t in plt.gca().get_yticks() if t >= ext_freq[-1] and t <= ext_freq[-2]])
                if pn == 0:
                    plt.gca().set_xticklabels(['' for tick in plt.gca().get_xticklabels()])
                    cbar = fig.colorbar(cm, orientation='vertical', cax=cbax1)
                    cbax1.yaxis.set_ticks_position('left')
                    cbar.ax.set_ylabel('Abs(V) [Jy]', rotation=90)
                else:
                    plt.gca().set_xlabel('$\\nu$ [MHz]')
                if blnum > 0:
                    plt.gca().set_yticklabels(['' for tick in plt.gca().get_yticklabels()])
                else:
                    plt.gca().set_ylabel('LST [Hrs]')
                    plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])

        # Right half: phase waterfalls.
        for pn, pol in enumerate(['ee', 'nn']):
            for blnum in range(nbls + 1):
                plt.sca(axarr[pn][blnum + nbls + 1])
                if blnum < nbls:
                    blk = grp[blnum] + (pol,)
                    phase = _masked(np.angle(frft.data[blk][:, spw_inds]), frft.flags[blk][:, spw_inds])
                    cm = plt.imshow(phase, vmin=-np.pi, vmax=np.pi, extent=ext_freq, aspect='auto', interpolation='nearest', cmap='twilight')
                    plt.title(f'{blk}')
                else:
                    blk = grp[0] + (pol,)
                    # NOTE(review): a lookup by a reversed key already returns
                    # grp[0]'s orientation (hera_cal DataContainer conjugates on
                    # lookup -- verify against its docs), so conjugating again
                    # would flip the sign of the phase. No extra conj here.
                    phase = _masked(np.angle(frft_red.data[blk][:, spw_inds]), frft_red.flags[blk][:, spw_inds])
                    cm = plt.imshow(phase, vmin=-np.pi, vmax=np.pi, extent=ext_freq, aspect='auto', interpolation='nearest', cmap='twilight')
                    plt.title(f'{blvec[0]:.1f} m, {blvec[1]:.1f} m, {pol}')
                plt.gca().set_yticks([t for t in plt.gca().get_yticks() if t >= ext_freq[-1] and t <= ext_freq[-2]])
                if pn == 0:
                    plt.gca().set_xticklabels(['' for tick in plt.gca().get_xticklabels()])
                else:
                    plt.gca().set_xlabel('$\\nu$ [MHz]')
                if blnum < nbls:
                    plt.gca().set_yticklabels(['' for tick in plt.gca().get_yticklabels()])
                else:
                    plt.gca().set_ylabel('LST [Hrs]')
                    plt.gca().set_yticklabels([f'{lst_func(t):.1f}' for t in plt.gca().get_yticks()])
                    cbar = fig.colorbar(cm, orientation='vertical', cax=cbax2)
                    cbar.ax.set_ylabel('Arg(V) [rad]', rotation=270)

                plt.gca().yaxis.tick_right()
                plt.gca().yaxis.set_label_position("right")

        plt.show()

        

In [None]:
# Amplitude/phase waterfalls of the time-inpainted data (only if it exists),
# one pass per spectral window.
if time_inpainted_sums:
    for spw_num, _spw in enumerate(spws):
        freq_plots(frf_inpaint, frf_inpaintr, spw_num)

In [None]:
# Fringe-rate and delay waterfalls of the time-inpainted data (only if it
# exists), one pass per spectral window.
if time_inpainted_sums:
    for spw_num, _spw in enumerate(spws):
        delay_plots(frf_inpaint, frf_inpaintr, spw_num)


In [None]:
# Amplitude/phase waterfalls of the cross-talk-filtered data, one pass per spectral window.
for spw_num, _spw in enumerate(spws):
    freq_plots(frf_xtalk, frf_xtalkr, spw_num)

In [None]:
# Fringe-rate and delay waterfalls of the cross-talk-filtered data, one pass per spectral window.
for spw_num, _spw in enumerate(spws):
    delay_plots(frf_xtalk, frf_xtalkr, spw_num)