In [1]:
import gpflow as gp
import tensorflow as tf
import tensorflow_probability as tfp
from bayes_tec.datapack import DataPack
import numpy as np
from bayes_tec.utils.data_utils import calculate_weights, make_coord_array

import warnings
warnings.filterwarnings("ignore")


def gain_solve(datapack, ant_sel=None, time_sel=None,dir_sel=None, pol_sel=slice(0,1,1), freq_sel=slice(0,48,1),
                  flag_dirs=[3,9,10,11,12,16,17,27,31]):
    """Predict complex gains for every direction with a GP over (time, direction).

    Phase and amplitude are read from solset 'sol000' of the given datapack,
    combined into complex gains, and the real and imaginary parts are modelled
    jointly by a GPflow GPR whose kernel is RBF(time) * RBF(direction).  The
    directions listed in ``flag_dirs`` are left out of the training data but
    are still predicted.

    Note: the model is compiled and used for prediction with the kernel
    hyperparameters set below; no optimisation step is run in this function.

    Parameters
    ----------
    datapack : path/handle accepted by ``DataPack(...)``
    ant_sel, time_sel, dir_sel, pol_sel, freq_sel :
        Selections forwarded unchanged to ``DataPack.select``.
    flag_dirs : sequence of int
        Indices (into the selected directions) excluded from training.
        The default list is only read, never mutated.

    Returns
    -------
    gstar : complex ndarray, shape (Nd, Na, Nf, Nt)
        Predicted gains for all selected directions.
    gvarstar : complex ndarray, shape (Nd, Na, Nf, Nt)
        Predictive variances packed as a complex number: the real part is the
        variance of the real gain component, the imaginary part the variance
        of the imaginary gain component.
    """
    # Bind the opened datapack to a new name so the `datapack` argument is not shadowed.
    with DataPack(datapack,readonly=True) as dp:
        dp.switch_solset('sol000')
        dp.select(time=time_sel, ant=ant_sel,pol=pol_sel, dir=dir_sel,freq=freq_sel)
        phase, axes = dp.phase
        amp, axes = dp.amplitude
        patch_names, directions = dp.get_sources(axes['dir'])
        _, times = dp.get_times(axes['time'])

    # Indices of the directions that are NOT flagged; only these are used for training.
    select = np.where(~np.isin(np.arange(len(patch_names)), np.array(flag_dirs)))[0]

    gains = amp*np.exp(1j*phase)
    # Keep only the first selected polarisation and the unflagged directions.
    gains = gains[0,select,:,:]

    # Noise estimate: sum of the real- and imaginary-part estimates along the last axis.
    # NOTE(review): presumably calculate_weights returns a windowed variance with the
    # same shape as its input -- confirm against bayes_tec.utils.data_utils.
    var = calculate_weights(gains.real,indep_axis=-1, N=4,phase_wrap=False) + calculate_weights(gains.imag,indep_axis=-1, N=4,phase_wrap=False)
    var = var.mean(-2)
    # Overwrite the first two and last two samples along the last axis with
    # interior values (presumably the edge estimates are unreliable -- confirm).
    var[...,:2] = var[...,3:4]
    var[...,-2:] = var[..., -3:-2]
    # Collapse to a single scalar noise variance.
    var = np.median(np.median(var, axis=1))

    # Seconds since the first selected timestamp.  The original used 86000 for one
    # term and 86400 for the other; differencing in days first is both correct and
    # better conditioned numerically.
    X_t = (times.mjd - times[0].mjd)*86400.
    X_t = X_t[:,None]
    # Direction coordinates in degrees, centred on the mean pointing.
    X_d = np.array([directions.ra.deg - directions.ra.deg.mean(),
                  directions.dec.deg - directions.dec.deg.mean()]).T.astype(np.float64)
    # Training inputs use unflagged directions only; prediction inputs use all of them.
    X = make_coord_array(X_t, X_d[select,:],flat=True)
    Xstar = make_coord_array(X_t, X_d,flat=True)

    _,Nd,Na,Nf,Nt = phase.shape
    Nd_ = len(select)
    # (Nd_, Na, Nf, Nt) -> (Nt, Nd_, Na, Nf) -> (Nt*Nd_, Na*Nf)
    Y = gains.transpose((3,0,1,2)).reshape((Nt*Nd_,-1))
    # Real and imaginary parts become separate output columns: (Nt*Nd_, 2*Na*Nf).
    Y = np.concatenate([Y.real,Y.imag],axis=1)
    # Standardise the targets: per-column mean, one global scale.
    y_mean = Y.mean(0,keepdims=True)
    Y -= y_mean
    y_std = np.mean(Y.std(0,keepdims=True))+1e-8
    Y /= y_std

    # Express the noise variance in the standardised units.
    var /= y_std**2

    # A fresh graph/session per call keeps repeated solves from accumulating graph state.
    with tf.Session(graph=tf.Graph()):
        with gp.defer_build():
            kernt = gp.kernels.RBF(1,active_dims=slice(0,1,1))
            kernd = gp.kernels.RBF(2,active_dims=slice(1,3,1))
            kernt.lengthscales = 80.
            kernt.lengthscales.transform = gp.transforms.positiveRescale(80.)
            kernd.lengthscales = 1.
            kernt.variance = 1.
            kernt.variance.transform = gp.transforms.positiveRescale(0.75)
            # Only one amplitude is needed for a product kernel, so fix the second.
            kernd.variance.trainable = False
            kern = kernt*kernd
            m = gp.models.GPR(X.astype(np.float64),Y.astype(np.float64),kern)
            # Likelihood noise is fixed to the empirical estimate computed above.
            m.likelihood.variance = var
            m.likelihood.variance.trainable = False
            m.compile()
        ystar,varstar = m.predict_y(Xstar.astype(np.float64))

    # Undo the standardisation and unpack the columns back to (Nd, Na, Nf, Nt).
    ystar = (ystar * y_std + y_mean)
    real = ystar[:,:Na*Nf].reshape((Nt,Nd,Na,Nf)).transpose((1,2,3,0))
    imag = ystar[:,Na*Nf:].reshape((Nt,Nd,Na,Nf)).transpose((1,2,3,0))
    gstar = real + 1j*imag
    varstar = varstar * y_std**2
    real = varstar[:, :Na*Nf].reshape((Nt,Nd,Na,Nf)).transpose((1,2,3,0))
    imag = varstar[:, Na*Nf:].reshape((Nt,Nd,Na,Nf)).transpose((1,2,3,0))
    # Complex container only: real part = var(Re g), imaginary part = var(Im g).
    gvarstar = real + 1j*imag

    return gstar, gvarstar

def get_freq_weights(gains, freqs):
    """Learn normalised per-frequency weights that minimise the weighted dTEC variance.

    The gain phases are divided by a per-frequency conversion factor
    (``-8.448e9 / freq``) to give one dTEC estimate per frequency.  Weights
    ``w = softmax(log_w)`` over frequency are then fitted with Adam (1000 steps,
    learning rate 1e-3) to minimise the mean weighted variance of those
    estimates across frequency.

    Parameters
    ----------
    gains : complex array
        The frequency axis must be the second-to-last axis, with length
        ``len(freqs)``, as required by the broadcasting and the ``axis=-2``
        reductions below.
    freqs : 1-D array of float
        Frequencies, one per entry of the frequency axis.
        NOTE(review): presumably in Hz, matching the -8.448e9 constant -- confirm.

    Returns
    -------
    ndarray, shape (Nf,)
        The weights fetched on the final optimisation step; positive and
        summing to one.
    """
    # The original referenced an undefined `Nf`; derive it from the input instead.
    Nf = len(freqs)
    tec_conv = -8.448e9/freqs
    # The phases do not change between optimisation steps, so compute them once.
    phi = np.angle(gains)

    with tf.Session(graph=tf.Graph()) as sess:
        phi_pl = tf.placeholder(tf.float64)
        tec_conv_pl = tf.placeholder(tf.float64)
        # Optimise in log space and normalise, so the weights stay positive and sum to one.
        log_w = tf.Variable(np.zeros(Nf),dtype=tf.float64)
        w = tf.exp(log_w)
        w /= tf.reduce_sum(w)

        # One dTEC estimate per frequency; (Nf, 1) broadcasts along axis -2 of the phases.
        dtec = phi_pl / tec_conv_pl[:,None]
        dtec_mu = tf.reduce_sum(dtec*w[:,None],axis=-2)

        # Weighted variance across frequency: E_w[x^2] - E_w[x]^2.
        dtec_var = tf.reduce_sum(dtec**2*w[:,None],axis=-2) - dtec_mu**2

        loss = tf.reduce_mean(dtec_var)# + tf.reduce_sum(tf.abs(w_))

        opt = tf.train.AdamOptimizer(1e-3).minimize(loss,var_list=[log_w])
        init = tf.global_variables_initializer()
        sess.run(init)
        feed = {phi_pl: phi, tec_conv_pl: tec_conv}
        for step in range(1000):
            _, w_ = sess.run([opt, w], feed_dict=feed)
        return w_
            
            

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*

In [2]:
import seaborn as sns
import pylab as plt
import pandas as pd
from bayes_tec.utils.data_utils import define_equal_subsets
import os

# Block-size settings handed to define_equal_subsets below.
max_block_size, min_overlap = 30, 0

# Values shared by the solve and the write-back, named once so the two stay in step.
datapack_file = '../../scripts/data/killms_datapack_4.hdf5'
flagged_dirs = [3,9,10,11,12,16,17,27,31]
first_pol = slice(0,1,1)
freq_range = slice(0,48,1)

# Read the axis metadata (directions, antennas, times, frequencies, polarisations).
with DataPack(datapack_file, readonly=False) as datapack:
    datapack.select(time=None, ant=None)
    axes = datapack.axes_phase
    patch_names, directions = datapack.get_sources(axes['dir'])
    antenna_labels, antennas = datapack.get_antennas(axes['ant'])
    timestamps, times = datapack.get_times(axes['time'])
    _,freqs = datapack.get_freqs(axes['freq'])
    pol_labels,pols = datapack.get_pols(axes['pol'])

# Create the output solset with frequency-dependent phase and amplitude tables.
# NOTE(review): these calls run after the `with` block above has exited;
# presumably DataPack re-opens the file internally -- confirm.
direction_radec = np.stack([directions.ra.rad, directions.dec.rad], axis=1)
datapack.switch_solset(
    'posterior_sol',
    array_file=DataPack.lofar_array,
    directions=direction_radec,
    patch_names=patch_names,
)
datapack.add_freq_dep_tab('phase', times.mjd*86400., pols=pol_labels, freqs=freqs)
datapack.add_freq_dep_tab('amplitude', times.mjd*86400., pols=pol_labels, freqs=freqs)

# Split the time axis into blocks that are solved independently.
Nt = len(times)
solve_slices, _, _ = define_equal_subsets(Nt, max_block_size, min_overlap)

# Solve one antenna and one time block at a time, writing each result
# into 'posterior_sol' as soon as it is available.
for i in range(1, 62):
    ant = antenna_labels[i]
    single_ant = slice(i, i+1, 1)
    for j, solve_slice in enumerate(solve_slices):
        time_slice = slice(*solve_slice, 1)
        gstar, gvarstar = gain_solve(
            datapack_file,
            ant_sel=single_ant,
            time_sel=time_slice,
            dir_sel=None,
            pol_sel=first_pol,
            freq_sel=freq_range,
            flag_dirs=flagged_dirs,
        )
        # Prepend the length-1 polarisation axis before storing.
        posterior_gains = gstar[None, ...]
        with datapack:
            datapack.switch_solset('posterior_sol')
            datapack.select(ant=single_ant, time=time_slice, dir=None,
                            pol=first_pol, freq=freq_range)
            datapack.phase = np.angle(posterior_gains)
            datapack.amplitude = np.abs(posterior_gains)
        

ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solution tables. Ignore referencing.
ERROR:root:Reference possible only for phase, scalarphase, clock, tec, tec3rd, and rotation solut

KeyboardInterrupt: 