In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pyuvdata import UVData

In [2]:
import copy

In [11]:
class uvdata_pol_calibrator():
    
    def __init__(self, model_data=None, real_data=None):
        """
        """
        self.model_data, self.real_data = model_data, real_data
        
    def data_slice(self, use_all_times=False, md_time_range=[], rd_time_range=[], use_all_frequencies=False, md_freq_range=[], rd_freq_range=[], 
                   use_all_ants=False, md_ants=[], rd_ants=[]):
        """
        Get a slice of the data.
        
        Parameters
        ----------

       


        """
        if use_all_ants:
            ants = np.unique(self.model_data.ant_1_array)
        else:
            assert isinstance(md_ants, list) and isinstance(rd_ants, list), "ants must be a list." 
            md_ants, rd_ants= list(np.sort(md_ants)), list(np.sort(rd_ants))
            assert md_ants == rd_ants
            ants = md_ants
        self.ants2cal = ants
        
        if use_all_times:
            md_time_range = [np.unique(self.model_data.time_array)[0]-1e-2, np.unique(self.model_data.time_array)[-1]+1e-2]
            rd_time_range = [np.unique(self.real_data.time_array)[0]-1e-2, np.unique(self.real_data.time_array)[-1]+1e-2]
        else:
            assert isinstance(md_time_range, list) and isinstance(rd_time_range, list), "time_range must be a list"
            assert len(md_time_range) == len(rd_time_range), "Length of time_ranges must be the same."
            assert len(md_time_range) == 2, "Length of time_range must be 2."
        
        if use_all_frequencies:
            md_freqs = list(range(self.model_data.Nfreqs))
            rd_freqs = list(range(self.real_data.Nfreqs))
        else:
            assert isinstance(md_freq_range, list) and isinstance(rd_freq_range, list), "freq_range must be a list"
            assert md_freq_range == rd_freq_range, "freq_ranges must be the same."
            if len(md_freq_range) == 1:
                md_freqs = md_freq_range
                rd_freqs = rd_freq_range
            elif len(md_freq_range) == 2:
                md_freqs = list(range(md_freq_range[0],md_freq_range[1]+1))
                rd_freqs = list(range(rd_freq_range[0],rd_freq_range[1]+1))
            else:
                raise ValueError("Length of freq_range must be 1 or 2.")
            
        model_data_slice = self.model_data.select(antenna_nums=ants, freq_chans=md_freqs, time_range=md_time_range, inplace=False) 
        real_data_slice = self.real_data.select(antenna_nums=ants, freq_chans=rd_freqs, time_range=rd_time_range, inplace=False)             
        model_data_dict = {}
        real_data_dict = {}
        
        for ant1 in ants:
            for ant2 in ants:
                if ant1 <= ant2:
                    baseline_number = 2048*(ant1+1)+(ant2+1)+2**16
                    # baseline index = 2048 * (ant1+1) + (ant2+1) + 2**16
                    model_data_dict[(ant1,ant2)] = np.copy(model_data_slice.get_data(baseline_number))
                    model_data_copy = model_data_slice.get_data(baseline_number)
                    data_shape = model_data_slice.get_data(baseline_number).shape
                    model_data_dict[(ant1,ant2)][:,:,1],  model_data_dict[(ant1,ant2)][:,:,2],  model_data_dict[(ant1,ant2)][:,:,3] =  model_data_copy[:,:,2], model_data_copy[:,:,3], model_data_copy[:,:,1]
                    model_data_dict[(ant1,ant2)] = model_data_dict[(ant1,ant2)].reshape((data_shape[0],data_shape[1],2,2))

                    real_data_dict[(ant1,ant2)] = np.copy(real_data_slice.get_data(baseline_number))
                    real_data_copy = real_data_slice.get_data(baseline_number)
                    real_data_dict[(ant1,ant2)][:,:,1],  real_data_dict[(ant1,ant2)][:,:,2], real_data_dict[(ant1,ant2)][:,:,3] =  real_data_copy[:,:,2],  real_data_copy[:,:,3],  real_data_copy[:,:,1]
                    real_data_dict[(ant1,ant2)] = real_data_dict[(ant1,ant2)].reshape((data_shape[0],data_shape[1],2,2))
                    """
                    data_array orginally stores the polarization information as a 1d array [-5, -6, -7, -8], corresponding to [XX, YY, XY, YX]. 
                    Here we modify it into a 2d array[[-5,-7],[-8,-6]], or [[XX, XY], [YX, YY]].
                    """

                if ant1 > ant2:
                    model_data_dict[(ant1,ant2)] = np.transpose(np.conj(model_data_dict[(ant2,ant1)]), (0,1,3,2))
                    real_data_dict[(ant1,ant2)] = np.transpose(np.conj(real_data_dict[(ant2,ant1)]), (0,1,3,2))
            
        self.model_data_dict,  self.real_data_dict = model_data_dict, real_data_dict
        
    def Wirtinger_lm_cal(self, diagonalize=False, Niteration=50, including_autobaseline=False):
        """
        Using Newton-Gauss method to obtain calibration gains G which minimizing \sum{D_[ij]-G_i M_{ij} G_j^H}, where D, G and M are all 2*2 matrices. 
        Update each step: G_{k+1} = [J(G_k)^H J(G_k)]^{-1} * J(G_k)^H * D, where J is the Jacobian matrix. 

        Parameters
        ----------

       

        Returns
        -------

        gain : dict
            calibration gains

        residual : 

        """
        ants, data, model = self.ants2cal, self.real_data_dict, self.model_data_dict
        gain_prev, gain_H_prev, gain_next, gain_H_next = {}, {}, {}, {}
        for ant in ants:
            gain_prev[ant] = np.array([[1,0],[0,1]]).astype(np.complex128)
            gain_H_prev[ant] = np.transpose(np.conj(gain_prev[ant]))
        residual = np.zeros(Niteration)

        for iteration in range(Niteration):
            for ant in ants:
                JH_J = np.zeros((2,2)).astype(np.complex128)
                JH_D = np.zeros((2,2)).astype(np.complex128)
                for ant_q in ants:
                    if including_autobaseline:
                        # sum over baselines, frequencies and times
                        JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_q)], np.matmul(gain_H_prev[ant_q], np.matmul(gain_prev[ant_q], model[(ant_q,ant)]))), axis=0), axis=0)
                        JH_D += np.sum(np.sum(np.matmul(data[(ant,ant_q)], np.matmul(gain_prev[q], model[(ant_q, ant)])), axis=0), axis=0)
                    else:
                        if ant_q != ant:
                        # if not including auto-baseline
                            JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_q)], np.matmul(gain_H_prev[ant_q], np.matmul(gain_prev[ant_q], model[(ant_q,ant)]))), axis=0), axis=0)
                            JH_D += np.sum(np.sum(np.matmul(data[(ant,ant_q)], np.matmul(gain_prev[ant_q], model[(ant_q, ant)])), axis=0), axis=0)

                if diagonalize==True:
                    gain_next[ant] = np.diag(np.diag(np.matmul(JH_D, np.linalg.inv(JH_J))))
                else:
                    gain_next[ant] = np.matmul(JH_D, np.linalg.inv(JH_J))
                print(ant, 'JH_J', JH_J, 'JH_D', JH_D, 'gain', gain_next[ant])
                
                JH_J = np.zeros((2,2)).astype(np.complex128)
                JH_D = np.zeros((2,2)).astype(np.complex128)
                for ant_p in ants:
                    if including_autobaseline:
                        # sum over baselines, frequencies and times
                        JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], np.matmul(gain_prev[ant_p], model[(ant_p,ant)]))), axis=0), axis=0)
                        JH_D += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], data[(ant_p, ant)])), axis=0), axis=0)
                    else:
                        if ant_p != ant:
                        # if not including auto-baseline
                            JH_J += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], np.matmul(gain_prev[ant_p], model[(ant_p,ant)]))), axis=0), axis=0)
                            JH_D += np.sum(np.sum(np.matmul(model[(ant,ant_p)], np.matmul(gain_H_prev[ant_p], data[(ant_p, ant)])), axis=0), axis=0)
                
                if diagonalize==True:
                    gain_H_next[ant] = np.diag(np.diag(np.matmul(np.linalg.inv(JH_J), JH_D)))
                else:
                    gain_H_next[ant] = np.matmul(np.linalg.inv(JH_J), JH_D)

            gain_prev = gain_next
            gain_H_prev = gain_H_next

            for ant in ants:
                for ant_r in ants:
                    if including_autobaseline:
                         residual[iteration] += np.linalg.norm(data[(ant,ant_r)] - np.matmul(gain_prev[ant], np.matmul(model[(ant,ant_r)], gain_H_prev[ant_r])))
                    else:
                        if ant_r != ant:
                            residual[iteration] += np.linalg.norm(data[(ant,ant_r)] - np.matmul(gain_prev[ant], np.matmul(model[(ant,ant_r)], gain_H_prev[ant_r])))
                    
        self.gain, self.gain_H, self.cal_res = gain_prev, gain_H_prev, residual                    


In [4]:
uvd_model, uvd_data = UVData(), UVData()

In [5]:
uvd_data.read("/lustre/aoc/projects/hera/H1C_IDR2/2458106/zen.2458106.22245.HH.uvh5")

In [6]:
uvd_model.read("/lustre/aoc/projects/hera/zmartino/hera_calib_model/IDR2/abscal_files/zen.2458106.21913.uvh5")

Telescope mock-HERA is not in known_telescopes.


In [7]:
print(np.unique(uvd_data.time_array))

[2458106.22251081 2458106.22263509 2458106.22275937 2458106.22288364
 2458106.22300792 2458106.22313219 2458106.22325647 2458106.22338074
 2458106.22350502 2458106.2236293  2458106.22375357 2458106.22387785
 2458106.22400212 2458106.2241264  2458106.22425067 2458106.22437495
 2458106.22449923 2458106.2246235  2458106.22474778 2458106.22487205
 2458106.22499633 2458106.2251206  2458106.22524488 2458106.22536916
 2458106.22549343 2458106.22561771 2458106.22574198 2458106.22586626
 2458106.22599053 2458106.22611481 2458106.22623909 2458106.22636336
 2458106.22648764 2458106.22661191 2458106.22673619 2458106.22686046
 2458106.22698474 2458106.22710901 2458106.22723329 2458106.22735757
 2458106.22748184 2458106.22760612 2458106.22773039 2458106.22785467
 2458106.22797894 2458106.22810322 2458106.2282275  2458106.22835177
 2458106.22847605 2458106.22860032 2458106.2287246  2458106.22884887
 2458106.22897315 2458106.22909743 2458106.2292217  2458106.22934598
 2458106.22947025 2458106.22959453

In [8]:
print(np.unique(uvd_model.time_array))

[2458106.21913026 2458106.21925454 2458106.21937881 2458106.21950309
 2458106.21962736 2458106.21975164 2458106.21987591 2458106.22000019
 2458106.22012446 2458106.22024874 2458106.22037301 2458106.22049729
 2458106.22062156 2458106.22074584 2458106.22087012 2458106.22099439
 2458106.22111867 2458106.22124294 2458106.22136722 2458106.22149149
 2458106.22161577 2458106.22174004 2458106.22186432 2458106.22198859
 2458106.22211287 2458106.22223714 2458106.22236142 2458106.22248569
 2458106.22260997 2458106.22273424 2458106.22285852 2458106.22298279
 2458106.22310707 2458106.22323134 2458106.22335562 2458106.22347989
 2458106.22360417 2458106.22372845 2458106.22385272 2458106.223977
 2458106.22410127 2458106.22422555 2458106.22434982 2458106.2244741
 2458106.22459837 2458106.22472265 2458106.22484692 2458106.2249712
 2458106.22509547 2458106.22521975 2458106.22534402 2458106.2254683
 2458106.22559257 2458106.22571685 2458106.22584112 2458106.2259654
 2458106.22608967 2458106.22621395 24581

In [12]:
uv_polcal = uvdata_pol_calibrator(model_data=uvd_model, real_data=uvd_data)

In [13]:
uv_polcal.data_slice(md_time_range=[2458106.21925450, 2458106.21925464], rd_time_range=[2458106.22263509, 2458106.22263519],
                     md_freq_range=[0], rd_freq_range=[0], use_all_ants=True)

In [14]:
uv_polcal.model_data_dict[(26,0)], uv_polcal.model_data_dict[(0,26)]

(array([[[[-49.53772006-36.05405439j,  33.71876991+19.82678821j],
          [ 31.38067383+19.84352004j, -65.23585302-35.83477334j]]]]),
 array([[[[-49.53772006+36.05405439j,  31.38067383-19.84352004j],
          [ 33.71876991-19.82678821j, -65.23585302+35.83477334j]]]]))

In [15]:
uv_polcal.Wirtinger_lm_cal(diagonalize=False, Niteration=2)

0 JH_J [[145431.19647545   +0.j         -57550.48565315-3221.13177944j]
 [-57550.48565315+3221.13177944j 121269.53407376   +0.j        ]] JH_D [[-2.58509924e-06-1.38883309e-06j  2.47513985e-05+2.20625336e-05j]
 [-9.14298838e-07+8.21878263e-07j  2.20499519e-05-1.79001045e-05j]] gain [[8.25787733e-11+7.13681146e-11j 2.41395839e-10+2.17992099e-10j]
 [7.68796223e-11-6.99679306e-11j 2.20168962e-10-1.78768342e-10j]]
1 JH_J [[132638.02713531   +0.j         -42615.78057247-2730.09406404j]
 [-42615.78057247+2730.09406404j  94978.9490475    +0.j        ]] JH_D [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]] gain [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
2 JH_J [[111247.18504384   +0.j         -24538.96968403-2460.57470418j]
 [-24538.96968403+2460.57470418j  80644.20432677   +0.j        ]] JH_D [[-2.48610294e-06-3.15224769e-06j -2.28937263e-05+1.24872002e-06j]
 [ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j]] gain [[-9.07790543e-11-1.99966491e-11j -3.10898318e-10+6.62979087e-12j]
 [ 0.00000000e+00+

 [0.+0.j 0.+0.j]] gain [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
88 JH_J [[145247.2310775   +0.j         -65580.25712528-787.62133371j]
 [-65580.25712528+787.62133371j 169116.9764498   +0.j        ]] JH_D [[ 1.70681609e-05-3.30908102e-06j  3.24791261e-05+7.43697944e-06j]
 [-3.04049458e-05+3.10263954e-05j  2.23103652e-05-2.26389547e-05j]] gain [[ 2.47866352e-10-4.81108317e-12j  2.88191378e-10+4.32640906e-11j]
 [-1.82441725e-10+1.84817835e-10j  6.03146043e-11-6.30465894e-11j]]
98 JH_J [[ 62193.30995553  +0.j         -19561.32767705-280.18615672j]
 [-19561.32767705+280.18615672j  76386.20093645  +0.j        ]] JH_D [[-1.32299989e-04-4.39197419e-05j  3.32713449e-06-3.55608047e-06j]
 [ 1.28301380e-04+5.57775113e-05j -9.19579497e-05-8.12810267e-05j]] gain [[-2.29895547e-09-7.84195649e-10j -5.42293836e-10-2.55806988e-10j]
 [ 1.82666804e-09+6.17317076e-10j -7.38338347e-10-8.99294337e-10j]]
120 JH_J [[127810.22856799   +0.j         -41893.63256018-5198.84088298j]
 [-41893.63256018+5198.84088298j 12864

54 JH_J [[1627.41072091+4.54747351e-13j -802.6519146 -6.74395006e+02j]
 [-802.6519146 +6.74395006e+02j  966.79943048+1.13033321e-26j]] JH_D [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]] gain [[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
55 JH_J [[ 2142.24600756+7.38964445e-13j -1975.80568591-7.71859433e+02j]
 [-1975.80568591+7.71859433e+02j  2793.31133273+5.13823087e-27j]] JH_D [[ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j]
 [-4.75272192e-15-3.12796137e-14j -1.19193917e-13-1.50068908e-14j]] gain [[ 0.00000000e+00+0.00000000e+00j  0.00000000e+00+0.00000000e+00j]
 [-1.75401306e-16-1.68577905e-17j -1.62080383e-16-6.57641560e-17j]]
65 JH_J [[ 33.65569155+3.55271368e-15j -55.56165615-3.07056405e+00j]
 [-55.56165615+3.07056405e+00j  92.37657478+1.42108547e-14j]] JH_D [[-1.57882914e-12+1.05438834e-12j -1.77042006e-12+1.64903136e-12j]
 [-1.70237194e-12+1.30681488e-12j  7.77623169e-14+2.36685108e-12j]] gain [[-1.91757723e-11+1.55923284e-11j -1.20710833e-11+8.75875932e-12j]
 [-1.16803109e-11+2.020