In [1]:
import pylab
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
import logging
import pytest
from astropy.time import Time
from astropy import units as u

%load_ext autoreload
%autoreload 2
from craft.corruvfits import CorrUvFitsFile
from craco.uvfitsfile_sink import *
from craft.craco import ant2bl, baseline_iter
from craco import uvfitsfile_sink
import craco.card_averager
from craft import uvfits


import craco


In [2]:

class TestVisblock:
     def __init__(self, d, mjdmid, uvw, valid_ants_0based):
        self.data = d
        self.fid_start = 1234
        self.fid_mid = self.fid_start + np.uint64(NSAMP_PER_FRAME//2)
        self.mjd_mid = mjdmid
        self.uvw = uvw
        self.source_index = 0
        nant = len(valid_ants_0based)
        self.antflags = np.zeros(nant, dtype=bool)
        af = self.antflags
        self.baseline_flags = np.array([af[blinfo.ia1] | af[blinfo.ia2] for blinfo in baseline_iter(valid_ants_0based)])

fcent = 850e6
foff = 1e6
npol = 1
tstart = Time.now().mjd
source_list = [{'name':'test', 'ra':123, 'dec':-33}]
antennas = []
extra_header= {}
nbeam = 36
nant = 24
valid_ants_0based = np.arange(nant)
nc_per_card = 24
nt = 32
npol = 1
vis_fscrunch = 6
vis_tscrunch = 1
real_dtype = np.float32
cplx_dtype = np.float32
nrx = 72
nchan = nc_per_card*nrx // vis_fscrunch
vis_nt = nt // vis_tscrunch
nbl = nant*(nant-1)//2
dt = craco.card_averager.get_averaged_dtype(nbeam, nant, nc_per_card, nt, npol, vis_fscrunch, vis_tscrunch, real_dtype, cplx_dtype)
valid_ants_0based = np.arange(nant)
uvw = np.random.randn(nbl*3).reshape(nbl,3)
tstart = Time(60467.28828320785, format='mjd', scale='utc')
fits_sourceidx = 1
inttime = 13.4e-3
mjds = np.array([(tstart + inttime*u.second*i).utc.value for i in range(vis_nt)])
sampleidxs = np.arange(vis_nt)
mjdiffs = sampleidxs*inttime/86400
baseline_info = list(baseline_iter(valid_ants_0based))
blids = [bl.blid for bl in baseline_iter(valid_ants_0based)]

In [3]:
fast_uvout = CorrUvFitsFile('fast.uvfits',
                            fcent,
                            foff,
                            nchan,
                            npol,
                            tstart.value,
                            source_list,
                            antennas,
                            extra_header=extra_header,
                            instrume='CRACO')



In [4]:
prepper = DataPrepper(fast_uvout, baseline_info, vis_nt, fits_sourceidx, inttime)

In [5]:
input = np.zeros(nrx, dtype=dt)
input['vis'][:] = np.random.randn(*input['vis'].shape)
vis_block = TestVisblock(input['vis'], tstart, uvw, valid_ants_0based)

In [6]:
_= prepper.write(vis_block)
 

In [7]:
%timeit prepper.write(vis_block)

177 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [8]:
%timeit prepper._set_uvw_baselines(vis_block)

375 µs ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
prepper.dout.shape

(32, 276)

In [10]:
from astropy.time import Time
t = Time.now()

In [11]:
t.utc.iso

'2024-12-10 05:09:59.104'

In [12]:
prepper.dout.shape

(32, 276)

In [14]:
uvfitsfile_sink.prep_data_fast_numba(prepper.dout, vis_block.data, prepper.uvw_baselines, prepper.iblk, prepper.inttime_days)

In [15]:
%timeit uvfitsfile_sink.prep_data_fast_numba(prepper.dout, vis_block.data, prepper.uvw_baselines, prepper.iblk, prepper.inttime_days)

161 ms ± 240 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [84]:
@njit # damn - njit doesn't support big endian on intel.
def prep_data_fast_numba2(dout, vis_data, uvw_baselines, iblk, inttim):
    '''
    dout is the dtype = np.dtype([('UU', dt), ('VV', dt), ('WW', dt), \
            ('DATE', dt), ('BASELINE', dt), \
            ('FREQSEL', dt), ('SOURCE', dt), ('INTTIM', dt), \
            ('DATA', dt, (1, 1, 1, nchan, npol, self.ncomplex))])

    it has shape (vis_nt_out, nbl)

    ncomplex = 2 if no flags, and 3 if there are flags

    vis_data is the input and is [nrx, nbl, vis_nc, vis_nt ]
    '''
    nrx, nbl, vis_nc, vis_nt = vis_data.shape[:4]
    for ibl in range(nbl):
        for it in range(vis_nt):
            for irx in range(nrx):
                for ic in range(vis_nc):
                    cout = ic + vis_nc*irx            
                    isamp = it + iblk
                    mjddiff = isamp*inttim                
                    d = dout[it, ibl]
                    d['UU'] = uvw_baselines[ibl,0]
                    d['VV'] = uvw_baselines[ibl,1]
                    d['WW'] = uvw_baselines[ibl,2]
                    d['DATE'] = mjddiff
                    data = d['DATA']
                    vis0 = vis_data[irx,ibl,ic,it,0]
                    vis1 = vis_data[irx,ibl,ic,it,1]

                    if vis0 == 0 and vis1 == 0:
                        weight = 0
                    else:
                        weight = 1

                    
                    #data[0,0,0,cout,0,0] = vis0
                    #data[0,0,0,cout,0,1] = vis1
                    #data[0,0,0,cout,0,2] = weight
                    this_dout = data[0,0,0,cout,0,:]
                    this_dout[0] = vis0
                    this_dout[1] = vis1
                    this_dout[2] = weight
                    

In [85]:
%timeit prep_data_fast_numba2(prepper.dout, vis_block.data, prepper.uvw_baselines, prepper.iblk, prepper.inttime_days)

9.44 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [79]:
vis_block.data.size/0.0278/1e9

0.18299395683453237

In [76]:
# damn - njit doesn't support big endian on intel.

In [69]:
v = prepper.dout.view(np.float32)
%timeit v = prepper.dout.view(np.float32)


545 ns ± 1.33 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [70]:
%timeit v.byteswap(inplace=True) # FITS is big endian. Damn.


3.82 ms ± 9.36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [72]:
prepper.dout.flags

  C_CONTIGUOUS : True
  F_CONTIGUOUS : False
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False

In [136]:
@njit(cache=True) # damn - njit doesn't support big endian on intel.
def prep_data_fast_numba_tscrunch2(dout, vis_data, uvw_baselines, iblk, inttim):
    '''
    dout is the dtype = np.dtype([('UU', dt), ('VV', dt), ('WW', dt), \
            ('DATE', dt), ('BASELINE', dt), \
            ('FREQSEL', dt), ('SOURCE', dt), ('INTTIM', dt), \
            ('DATA', dt, (1, 1, 1, nchan, npol, self.ncomplex))])

    it has shape (vis_nt_out, nbl)

    ncomplex = 2 if no flags, and 3 if there are flags

    vis_data is the input and is [nrx, nbl, vis_nc, vis_nt ]

    inttim is the integration time in days per input sample
    output data can be an integer fraction less than the input, in which case it does tscrunching
    '''
    nrx, nbl, vis_nc, vis_nt = vis_data.shape[:4]

    vis_nt_out, nbl_out = dout.shape
    assert vis_nt % vis_nt_out == 0
    assert nbl_out == nbl

    tscrunch = vis_nt // vis_nt_out
    scale = np.float32(1./tscrunch)


    for ibl in range(nbl):
        for it in range(vis_nt_out):
            for irx in range(nrx):
                for ic in range(vis_nc):
                    cout = ic + vis_nc*irx
                    isamp = (it + iblk)                    
                    mjddiff = isamp*inttim
                    d = dout[it, ibl]
                    d['UU']= uvw_baselines[ibl,0]
                    d['VV']= uvw_baselines[ibl,1]
                    d['WW'] = uvw_baselines[ibl,2]
                    d['DATE'] = mjddiff
                    data = d['DATA']
                    vs0 =  np.float32(0)
                    vs1 = np.float32(0)
                    vstart = vis_data[irx, ibl, ic, it*tscrunch:(it+1)*tscrunch, :]
                    for ix in range(tscrunch):
                        vs0 += vstart[ix, 0]
                        vs1 += vstart[ix, 1]

                    vs0 *= scale
                    vs1 *= scale
                    
                    if vs0 == 0 and vs1 == 1:
                        weight = 0
                    else:
                        weight = 1

                    this_dout = data[0,0,0,cout,0,:]
                    this_dout[0] = vs0
                    this_dout[1] = vs1
                    this_dout[2] = weight

In [135]:
%timeit prep_data_fast_numba_tscrunch2(prepper.dout, vis_block.data, prepper.uvw_baselines, prepper.iblk, prepper.inttime_days)

17 ms ± 58.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
