In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pyuvdata import UVData

In [2]:
import copy

In [19]:
class uvdata_pol_calibrator():
    
    def __init__(self, model_data=None, real_data=None):
        """
        """
        self.model_data, self.real_data = model_data, real_data
        
    def data_slice(self, use_all_times=False, md_time_range=[], rd_time_range=[], use_all_frequencies=False, md_freq_range=[], rd_freq_range=[], 
                   use_all_ants=False, md_ants=[], rd_ants=[]):
        """
        Get a slice of the data.
        
        Parameters
        ----------

       


        """
        if use_all_ants:
            ants = np.unique(self.model_data.ant_1_array)
        else:
            assert isinstance(md_ants, list) and isinstance(rd_ants, list), "ants must be a list." 
            md_ants, rd_ants= list(np.sort(md_ants)), list(np.sort(rd_ants))
            assert md_ants == rd_ants
            ants = md_ants
        self.ants2cal = ants
        
        if use_all_times:
            md_time_range = [np.unique(self.model_data.time_array)[0]-1e-2, np.unique(self.model_data.time_array)[-1]+1e-2]
            rd_time_range = [np.unique(self.real_data.time_array)[0]-1e-2, np.unique(self.real_data.time_array)[-1]+1e-2]
        else:
            assert isinstance(md_time_range, list) and isinstance(rd_time_range, list), "time_range must be a list"
            assert len(md_time_range) == len(rd_time_range), "Length of time_ranges must be the same."
            assert len(md_time_range) == 2, "Length of time_range must be 2."
        
        if use_all_frequencies:
            md_freqs = list(range(self.model_data.Nfreqs))
            rd_freqs = list(range(self.real_data.Nfreqs))
        else:
            assert isinstance(md_freq_range, list) and isinstance(rd_freq_range, list), "freq_range must be a list"
            assert md_freq_range == rd_freq_range, "freq_ranges must be the same."
            if len(md_freq_range) == 1:
                md_freqs = md_freq_range
                rd_freqs = rd_freq_range
            elif len(md_freq_range) == 2:
                md_freqs = list(range(md_freq_range[0],md_freq_range[1]+1))
                rd_freqs = list(range(rd_freq_range[0],rd_freq_range[1]+1))
            else:
                raise ValueError("Length of freq_range must be 1 or 2.")
            
        model_data_slice = self.model_data.select(antenna_nums=ants, freq_chans=md_freqs, time_range=md_time_range, inplace=False) 
        real_data_slice = self.real_data.select(antenna_nums=ants, freq_chans=rd_freqs, time_range=rd_time_range, inplace=False)             
        
        # original data_array shape: (Nblts, Nspws, Nfreqs, )
        # intermidiate data_array shape: (Nants, Nants, Nspws, Ntimes, Nfreqs, Npols)
        # final data_array shape: (Nants, Nants, :, 2, 2)
        assert model_data_slice.data_array.shape == real_data_slice.data_array.shape
        model_data_array = np.zeros((len(ants), len(ants), model_data_slice.Ntimes, model_data_slice.Nspws, model_data_slice.Nfreqs, model_data_slice.Npols)).astype(np.complex128)
        real_data_array = np.zeros((len(ants), len(ants), real_data_slice.Ntimes, real_data_slice.Nspws, real_data_slice.Nfreqs, real_data_slice.Npols)).astype(np.complex128)
        
        for (i, ant1) in enumerate(ants):
            for (j,ant2) in enumerate(ants):
                for spw in range(model_data_slice.Nspws): 
                    if ant1 <= ant2:
                        baseline_number = 2048*(ant1+1)+(ant2+1)+2**16
                        # baseline index = 2048 * (ant1+1) + (ant2+1) + 2**16
                        model_data_array[i,j,spw] = np.copy(model_data_slice.get_data(baseline_number))
                        # uvdata.get_data() returns an array with a shape (Ntimes, Nfreqs, Npols) 
                        model_data_array[i,j,spw, :,:,1],  model_data_array[i,j,spw, :,:,2],  model_data_array[i,j,spw, :,:,3] = model_data_array[i,j,spw, :,:,2],  model_data_array[i,j,spw, :,:,3],  model_data_array[i,j,spw, :,:,1]
                        """
                        data_array orginally stores the polarization information as a 1d array [-5, -6, -7, -8], corresponding to [XX, YY, XY, YX].
                        Here we first modify it into [-5,-7, -8,-6], 
                        then we will modify it into a 2d array[[-5,-7],[-8,-6]], corresponding to [[XX, XY], [YX, YY]].
                        """
                        real_data_array[i,j,spw] = np.copy(real_data_slice.get_data(baseline_number))
                        real_data_array[i,j,spw, :,:,1],  real_data_array[i,j,spw, :,:,2],  real_data_array[i,j,spw, :,:,3] = real_data_array[i,j,spw, :,:,2], real_data_array[i,j,spw, :,:,3], real_data_array[i,j,spw, :,:,1]

                    if ant1 > ant2:
                        baseline_number = 2048*(ant2+1)+(ant1+1)+2**16
                        model_data_array[i,j,spw] = np.conj(np.copy(model_data_slice.get_data(baseline_number)))
                        model_data_array[i,j,spw, :,:,1],  model_data_array[i,j,spw, :,:,2],  model_data_array[i,j,spw, :,:,3] = model_data_array[i,j,spw, :,:,3], model_data_array[i,j,spw, :,:,2], model_data_array[i,j,spw, :,:,1]
                        """
                        Since V_{ji} = V_{ij}^H, we should take conjugate values here and then reorder the pols as
                        [-5,-8, -7,-6], which becomes [[-5,-8],[-7,-6]] after converted into a 2d array.
                        """
                        real_data_array[i,j,spw] = np.conj(np.copy(real_data_slice.get_data(baseline_number)))
                        real_data_array[i,j,spw, :,:,1], real_data_array[i,j,spw, :,:,2], real_data_array[i,j,spw, :,:,3] = real_data_array[i,j,spw, :,:,3], real_data_array[i,j,spw, :,:,2], real_data_array[i,j,spw, :,:,1]


        data_shape = model_data_array.shape
        model_data_array = model_data_array.reshape((data_shape[0], data_shape[1], data_shape[2], data_shape[3], data_shape[4], 2, 2))
        real_data_array = real_data_array.reshape((data_shape[0], data_shape[1], data_shape[2], data_shape[3], data_shape[4], 2, 2))
        # reshape the pols
        model_data_array = model_data_array.reshape((data_shape[0], data_shape[1], data_shape[2]*data_shape[3]*data_shape[4], 2, 2))
        real_data_array = real_data_array.reshape((data_shape[0], data_shape[1], data_shape[2]*data_shape[3]*data_shape[4], 2, 2))
        # concatenate axis to a shape (Nants, Nants, :, 2, 2)

        self.model_data_array,  self.real_data_array = model_data_array, real_data_array
        
    def Wirtinger_lm_cal(self, diagonalize=False, Niteration=50, including_autobaseline=False):
        """
        Using Newton-Gauss method to obtain calibration gains G which minimizing \sum{D_[ij]-G_i M_{ij} G_j^H}, where D, G and M are all 2*2 matrices. 
        Update each step: G_{k+1} = [J(G_k)^H J(G_k)]^{-1} * J(G_k)^H * D, where J is the Jacobian matrix. 

        Parameters
        ----------

       

        Returns
        -------

        gain : dict
            calibration gains

        residual : 

        """
        ants, data, model = self.ants2cal, self.real_data_array, self.model_data_array
        gain_prev = np.array([[1,0],[0,1]]).astype(np.complex128)
        gain_prev = np.repeat(gain_prev[np.newaxis,:,:], len(ants), axis=0)
        gain_H_prev = np.copy(gain_prev)
        gain_next, gain_H_next = np.zeros_like(gain_prev), np.zeros_like(gain_prev)
        
        residual = np.zeros(Niteration)

        for iteration in range(Niteration):
            for (i,ant) in enumerate(ants):
                JH_J = np.zeros((2,2)).astype(np.complex128)
                JH_D = np.zeros((2,2)).astype(np.complex128)
                for (j,ant_q) in enumerate(ants):
                    # sum over baselines, frequencies and times
                    JH_J += np.sum(np.matmul(model[i,j], np.matmul(gain_H_prev[j], np.matmul(gain_prev[j], model[j,i]))), axis=0)
                    JH_D += np.sum(np.matmul(data[i,j], np.matmul(gain_prev[j], model[(j, i)])), axis=0)
                    
                if not including_autobaseline:
                    # if not including auto-baseline
                    JH_J -= np.sum(np.matmul(model[i,i], np.matmul(gain_H_prev[i], np.matmul(gain_prev[i], model[i,i]))), axis=0)
                    JH_D -= np.sum(np.matmul(data[i,i], np.matmul(gain_prev[i], model[(i, i)])), axis=0)         
                    """
                    if including_autobaseline:
                        # sum over baselines, frequencies and times
                        JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_q)], np.matmul(gain_H_prev[ant_q], np.matmul(gain_prev[ant_q], model[(ant_q,ant)]))), axis=0), axis=0)
                        JH_D += np.sum(np.sum(np.matmul(data[(ant,ant_q)], np.matmul(gain_prev[ant_q], model[(ant_q, ant)])), axis=0), axis=0)
                    else:
                        if ant_q != ant:
                        # if not including auto-baseline
                            JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_q)], np.matmul(gain_H_prev[ant_q], np.matmul(gain_prev[ant_q], model[(ant_q,ant)]))), axis=0), axis=0)
                            JH_D += np.sum(np.sum(np.matmul(data[(ant,ant_q)], np.matmul(gain_prev[ant_q], model[(ant_q, ant)])), axis=0), axis=0)
                    """
               
                if diagonalize==True:
                    gain_next[i] = np.diag(np.diag(np.matmul(JH_D, np.linalg.inv(JH_J))))
                else:
                    gain_next[i] = np.matmul(JH_D, np.linalg.inv(JH_J))
                print(i, ant, 'JH_J', JH_J, 'JH_D', JH_D, 'gain', gain_next[i])
                JH_J = np.zeros((2,2)).astype(np.complex128)
                JH_D = np.zeros((2,2)).astype(np.complex128)
                for (k,ant_p) in enumerate(ants):
                    JH_J += np.sum(np.matmul(model[i,k], np.matmul(gain_H_prev[k], np.matmul(gain_prev[k], model[k,i]))), axis=0)
                    JH_D += np.sum(np.matmul(model[i,k], np.matmul(gain_H_prev[k], data[k, i])), axis=0)
                
                if not including_autobaseline:
                    JH_J -= np.sum(np.matmul(model[i,i], np.matmul(gain_H_prev[i], np.matmul(gain_prev[i], model[i,i]))), axis=0)
                    JH_D -= np.sum(np.matmul(model[i,i], np.matmul(gain_H_prev[i], data[i, i])), axis=0)
                    """
                    if including_autobaseline:
                        # sum over baselines, frequencies and times
                        JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], np.matmul(gain_prev[ant_p], model[(ant_p,ant)]))), axis=0), axis=0)
                        JH_D += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], data[(ant_p, ant)])), axis=0), axis=0)
                    else:
                        if ant_p != ant:
                        # if not including auto-baseline
                            JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], np.matmul(gain_prev[ant_p], model[(ant_p,ant)]))), axis=0), axis=0)
                            JH_D += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], data[(ant_p, ant)])), axis=0), axis=0)
                    """
                   
                if diagonalize==True:
                    gain_H_next[i] = np.diag(np.diag(np.matmul(np.linalg.inv(JH_J), JH_D)))
                else:
                    gain_H_next[i] = np.matmul(np.linalg.inv(JH_J), JH_D)

            gain_prev = gain_next
            gain_H_prev = gain_H_next

            for ant in range(len(ants)):
                for ant_r in range(len(ants)):
                    if including_autobaseline:
                         residual[iteration] += np.linalg.norm(data[ant,ant_r] - np.matmul(gain_prev[ant], np.matmul(model[ant,ant_r], gain_H_prev[ant_r])))
                    else:
                        if ant_r != ant:
                            residual[iteration] += np.linalg.norm(data[ant,ant_r] - np.matmul(gain_prev[ant], np.matmul(model[ant,ant_r], gain_H_prev[ant_r])))
                    
        self.gain, self.gain_H, self.cal_res = gain_prev, gain_H_prev, residual                    


In [4]:
uvd_model, uvd_data = UVData(), UVData()

In [5]:
uvd_data.read("/lustre/aoc/projects/hera/H1C_IDR2/2458106/zen.2458106.22245.HH.uvh5")

In [6]:
uvd_model.read("/lustre/aoc/projects/hera/zmartino/hera_calib_model/IDR2/abscal_files/zen.2458106.21913.uvh5")

Telescope mock-HERA is not in known_telescopes.


In [20]:
uv_polcal = uvdata_pol_calibrator(model_data=uvd_model, real_data=uvd_data)

In [21]:
uv_polcal.data_slice(md_time_range=[2458106.21925450, 2458106.21925464], rd_time_range=[2458106.22263509, 2458106.22263519], 
                     md_freq_range=[0], rd_freq_range=[0], use_all_ants=True)
uv_polcal.Wirtinger_lm_cal(diagonalize=False, Niteration=2)


0 0 JH_J [[145431.19647545   +0.j         -15274.56548279-3783.95399348j]
 [-15086.30485436+4955.9677596j   30719.83125085+1172.01376612j]] JH_D [[-2.58509924e-06-1.38883309e-06j -2.41325935e-06-2.11664948e-06j]
 [-2.58509924e-06-1.38883309e-06j -2.41325935e-06-2.11664948e-06j]] gain [[-3.01938047e-11-1.42134286e-11j -9.47216721e-11-7.60743205e-11j]
 [-3.01938047e-11-1.42134286e-11j -9.47216721e-11-7.60743205e-11j]]
1 1 JH_J [[132638.0271353    +0.j         -12126.0375439 -4316.28827029j]
 [-12128.83969733+5738.41245692j  25716.04348658+1422.12418663j]] JH_D [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]] gain [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
2 2 JH_J [[111247.18504384   +0.j          -7134.21591215-4859.77580739j]
 [ -7141.15717475+6317.95622833j  17500.45724261+1458.18042094j]] JH_D [[-2.48610294e-06-3.15224769e-06j -9.94892587e-07-3.30810830e-06j]
 [-2.48610294e-06-3.15224769e-06j -9.94892587e-07-3.30810830e-06j]] gain [[-3.93157992e-11-3.71770311e-11j -7.99211565e-11-2.08443912e-10j]
 [-3.931

45 137 JH_J [[56449.64474154  +0.j         -2680.69396438+289.96328655j]
 [-2566.61966733 +39.67445774j  7502.53729894+329.6377443j ]] JH_D [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]] gain [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
46 138 JH_J [[46173.16826622  +0.j           642.10011618-191.55101269j]
 [  798.96577055 +70.18484945j  5419.73314296-121.36616324j]] JH_D [[2.32781233e-06+4.31482510e-06j 4.74803442e-06+3.18708455e-06j]
 [3.31844664e-05+1.63959181e-05j 8.01359771e-06+7.55642736e-06j]] gain [[3.65228261e-11+8.17868562e-11j 8.55437241e-10+5.98809247e-10j]
 [6.97431569e-10+3.28637438e-10j 1.35277311e-09+1.41025095e-09j]]
47 139 JH_J [[57994.34583423  +0.j           612.5610381 +868.74886499j]
 [  759.47809675-720.10514218j  7826.90087153+148.64372281j]] JH_D [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]] gain [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
48 140 JH_J [[104081.00558347   +0.j          -8898.92347782+2917.67915192j]
 [ -8572.9333706 -1806.30350317j  14893.40930534+1111.37564875j]] JH_D [[0.00000000e+0

38 98 JH_J [[  80.15766578+133.86282697j  -67.37441933+137.86377235j]
 [ -88.81480558+115.02822476j -100.63300341+121.80826819j]] JH_D [[ 5.78098459e-13-2.02665850e-14j  2.51609065e-14-1.50694707e-13j]
 [ 5.60942871e-13+6.32362276e-13j -1.85597393e-14-5.65161467e-13j]] gain [[ 4.08350251e-15+7.12532580e-16j -4.85439175e-15+7.38825669e-16j]
 [ 4.61691294e-15+8.17819324e-15j -8.89884948e-15-4.30562488e-15j]]
39 120 JH_J [[622.11616131+12.90241096j  -1.0534758  +3.2853078j ]
 [-24.69469945 -3.22821984j -10.90791547+15.14793104j]] JH_D [[ 5.45037112e-14+7.93582763e-13j -2.35644135e-13-1.96561866e-13j]
 [ 1.41616264e-13+6.19930834e-13j -2.10405555e-13-1.81242710e-13j]] gain [[-6.84160079e-18+1.90671310e-15j -1.27596861e-15+1.60619533e-14j]
 [ 1.28812872e-16+1.56421033e-15j -1.40439242e-15+1.45531383e-14j]]
40 121 JH_J [[3078.38359979 +0.53908878j -244.10148836+50.26552068j]
 [-107.64148676-38.09079038j  -47.44371356+34.00188761j]] JH_D [[1.37127448e-12-1.46554007e-12j 1.95558662e-12-9.50031