## Imports

In [1]:
import numpy as np
from pyuvdata import UVData
import hera_cal as hc
import random
from threading import Thread
from glob import glob
import sys

## Functions

In [6]:
def load_relevant_data(miriad_path, calfits_path):
    """Load redundant baselines, gains, and visibility data for one JD.

    Arguments:

        miriad_path : string - path to a miriad file for some JD
        calfits_path : string - path to a calfits file for the same JD

    Returns:

        red_bls : 1-D numpy object array - one entry per redundant group, each
                  entry being one group exactly as returned by ``info.get_reds()``.
        gains : first return value of ``hc.io.load_cal(calfits_path)``.
        uvd : UVData - the data read from ``miriad_path``.

    """

    # read the data
    uvd = UVData()
    uvd.read_miriad(miriad_path)

    # get the redundancies for that data
    aa = hc.utils.get_aa_from_uv(uvd)
    info = hc.omni.aa_to_info(aa)
    reds = info.get_reds()

    # Redundant groups usually have different lengths. ``np.array(reds)`` on
    # such a ragged nested sequence raises ValueError on NumPy >= 1.24, and
    # when every group happens to have the same length it builds a 3-D
    # array instead of one entry per group. Filling a pre-allocated 1-D
    # object array gives one entry per group in every case.
    red_bls = np.empty(len(reds), dtype=object)
    for i, group in enumerate(reds):
        red_bls[i] = group

    # gains for same data
    gains, _ = hc.io.load_cal(calfits_path)

    return red_bls, gains, uvd

In [7]:
def get_good_red_bls(red_bls, gain_keys, min_group_len = 4):
    """Select the separations made only from good antennas in red_bls.

    Each baseline group in red_bls has its bad separations removed.
    Groups left with fewer than min_group_len separations are dropped.

    An antenna counts as "good" when it is represented in gain_keys. The
    original author's working assumption was that the calibration source
    only contains good antennas; confirm this for your data source.

    Arguments:

        red_bls : list of lists - Each sublist is a group of redundant separations
                                  for a unique baseline.
        gain_keys : iterable - gains.keys() from hc.io.load_cal(). Only the first
                               element of each key (the antenna index) is used.
        min_group_len : int - Minimum number of separations in a 'good' sublist.
                              (Default = 4, so that both training and testing can take two seps)

    Returns:

        list of lists - Each sublist is a sorted list of at least min_group_len
                        separations whose two antennas both appear in gain_keys.

    """
    # Build the set of good antennas once. Re-scanning gain_keys for every
    # separation made the filter O(#seps * #gains).
    good_ants = {key[0] for key in gain_keys}

    def ants_good(sep):
        """Return True if both antennas of sep are in gain_keys."""
        return sep[0] in good_ants and sep[1] in good_ants

    good_redundant_baselines = []
    for group in red_bls:
        # only retain seps made from good antennas
        new_group = [sep for sep in group if ants_good(sep)]

        # make sure groups are large enough that both the training set
        # and the testing set can take two seps
        if len(new_group) >= min_group_len:
            good_redundant_baselines.append(sorted(new_group))

    return good_redundant_baselines

In [8]:
def _train_test_split_red_bls(red_bls, training_percent = 0.80):
    """Split a list of redundant baselines into a training set and a testing set.

    Each of the two sets receives at least one pair of separations from every
    group in red_bls, and no separation appears in both sets. The groups in
    red_bls are not modified; the function works on shuffled copies.

    Arguments:

        red_bls : list of lists - Each sublist is a group of redundant separations
                                  for a unique baseline.
                                  Each sublist must have at least 4 separations.
        training_percent : float - *Approximate* portion of the separations that will
                                   appear in the training set.

    Returns:

        tuple of dicts - train_red_bls_dict, test_red_bls_dict. Both are keyed by
                         the smallest separation of each group (``min(group)``)
                         and map to lists of separations.

    Raises:

        ValueError - if any group has fewer than 4 separations.

    """
    training_redundant_baselines_dict = {}
    testing_redundant_baselines_dict = {}

    # Seps left over after each set took its guaranteed two, keyed like the outputs.
    thinned_groups_dict = {}
    for group in red_bls:
        if len(group) < 4:
            raise ValueError(
                "Each redundant group needs at least 4 separations, got {}".format(len(group)))

        # group key is the sep with the lowest antenna indices
        key = min(group)

        # Shuffle a copy so the caller's list is left untouched.
        group = list(group)
        random.shuffle(group)

        # pop off two seps for the training set, then two for the testing set
        training_redundant_baselines_dict[key] = [group.pop(), group.pop()]
        testing_redundant_baselines_dict[key] = [group.pop(), group.pop()]

        # if there are still more seps in the group, save them for later assignment
        if group:
            thinned_groups_dict[key] = group

    # Shuffle the keys of groups with leftovers; the first part of the shuffled
    # keys sends its leftovers to training, the rest to testing.
    thinned_dict_keys = list(thinned_groups_dict)
    random.shuffle(thinned_dict_keys)

    # Because each set is guaranteed seps from every group, the effective
    # train/test ratio drops a few percent below training_percent. An arbitrary
    # +0.15 shift, found by trial and error, roughly compensates: without it,
    # training_percent = 0.80 gave about a 65/35 split instead of 80/20.
    t_pct = min(0.95, training_percent + 0.15)
    n_train = int(len(thinned_dict_keys) * t_pct)

    for key in thinned_dict_keys[:n_train]:
        training_redundant_baselines_dict[key].extend(thinned_groups_dict[key])

    for key in thinned_dict_keys[n_train:]:
        testing_redundant_baselines_dict[key].extend(thinned_groups_dict[key])

    return training_redundant_baselines_dict, testing_redundant_baselines_dict

In [9]:
def _loadnpz(filename):
    """Load the first array (``arr_0``) stored in an ``.npz`` file.

    ``np.savez(fn, obj)`` stores its first positional argument under the key
    ``arr_0``. For a dict saved that way the result is a 0-d object array, so
    unwrap it with ``_loadnpz(fn)[()]``.

    Security note: ``allow_pickle=True`` is needed to read object arrays such
    as pickled dicts, but unpickling can execute arbitrary code. Only use this
    on files you produced yourself (as this notebook does).
    """
    # The context manager closes the underlying zip file. Only the member we
    # return is read, instead of eagerly loading every array in the archive.
    with np.load(filename, allow_pickle=True) as archive:
        return archive['arr_0']

In [10]:
def get_or_gen_test_train_red_bls_dicts(red_bls = None,
                                        gain_keys = None,
                                        training_percent = 0.80):
    """Load the cached train/test redundant-baseline dicts, or build and cache them.

    If both ``training_redundant_baselines_dict.npz`` and
    ``testing_redundant_baselines_dict.npz`` exist in the current working
    directory, they are loaded and returned and the other arguments are
    ignored. Otherwise the dicts are built with get_good_red_bls() and
    _train_test_split_red_bls(), saved to those two files, and returned.

    Arguments:

        red_bls : list of lists - redundant groups (e.g. from load_relevant_data()).
                                  Required when the cache files are missing.
        gain_keys : gains.keys() from load_relevant_data().
                    Required when the cache files are missing.
        training_percent : float - approximate training share, passed to
                                   _train_test_split_red_bls().

    Returns:

        tuple of dicts - training_red_bls_dict, testing_red_bls_dict

    Raises:

        ValueError - if the cache files are missing and red_bls or gain_keys is None.
    """
    import os

    train_fn = 'training_redundant_baselines_dict.npz'
    test_fn = 'testing_redundant_baselines_dict.npz'

    # Check for the two specific cache files. Counting *.npz files in the cwd
    # would also match unrelated archives.
    if os.path.isfile(train_fn) and os.path.isfile(test_fn):
        training_red_bls_dict = _loadnpz(train_fn)[()]
        testing_red_bls_dict = _loadnpz(test_fn)[()]
    else:
        # Compare the values themselves: `type(x) != None` is always true.
        if red_bls is None:
            raise ValueError("Provide a list of redundant baselines")
        if gain_keys is None:
            raise ValueError("Provide a list of gain keys")

        good_red_bls = get_good_red_bls(red_bls, gain_keys)
        training_red_bls_dict, testing_red_bls_dict = _train_test_split_red_bls(
            good_red_bls, training_percent = training_percent)

        np.savez(train_fn, training_red_bls_dict)
        np.savez(test_fn, testing_red_bls_dict)

    return training_red_bls_dict, testing_red_bls_dict

In [11]:
def get_seps_data(red_bls_dict, uvd):
    """Collect the visibilities of every separation listed in a redundant-baselines dict.

    Arguments:

        red_bls_dict : dict - maps a group key to a list of separations.
        uvd : object providing ``get_data(sep)`` (e.g. the UVData from load_relevant_data()).

    Returns:

        dict - separation -> ``uvd.get_data(separation)``, visited group by group.
    """
    return {
        sep: uvd.get_data(sep)
        for group_seps in red_bls_dict.values()
        for sep in group_seps
    }

## Classes

In [80]:
class block_Data_Creator(object):
    """Creates flatness training data in an alternate thread.

    Each call to gen_data() starts a worker thread that builds one batch and
    appends it to an internal list. get_data() waits for the worker if no
    batch is ready yet, then pops the oldest batch.

    ## usage:
    ## data_maker = block_Data_Creator(num_flatnesses=250, bl_data=bl_data,
    ##                                 bl_dict=bl_dict, gains=gains)
    ## data_maker.gen_data() (before loop)
    ## inputs, labels = data_maker.get_data() (start of loop)
    ## data_maker.gen_data() (immediately after get_data())

    """

    def __init__(self,
                 num_flatnesses,
                 bl_data = None,
                 bl_dict = None,
                 gains = None,
                 abs_min_max_delay = 0.040):
        """
        Arguments
            num_flatnesses : int - number of flatnesses used to generate data.
                                   Number of data samples = num_flatnesses * rows per
                                   visibility array (60 for the data this was written for).
            bl_data : dict - sep -> visibility array. Output of get_seps_data()
            bl_dict : dict - Dictionary of seps with bls as keys. An output of
                             get_or_gen_test_train_red_bls_dicts()
            gains : dict - Gains for this data, indexed here as (antenna, 'x').
                           An output of load_relevant_data()
            abs_min_max_delay : float - delays are drawn uniformly from
                                        [-abs_min_max_delay, abs_min_max_delay).
        """
        self._num = num_flatnesses

        self._bl_data = bl_data
        self._bl_data_c = None   # conjugates of bl_data, computed on first use

        self._bl_dict = bl_dict

        self._gains = gains
        self._gains_c = None     # conjugates of gains, computed on first use

        self._epoch_batch = []   # finished (inputs, labels) batches, oldest first
        self._thread = None      # most recently started worker thread
        self._nu = np.arange(1024)  # channel index grid; assumes 1024 channels per row
        self._tau = abs_min_max_delay

    @staticmethod
    def _angle_tx(x):
        """Scale angles from (-pi, pi] to (0, 1] (the NN likes data in (0, 1))."""
        return (np.asarray(x) + np.pi) / (2. * np.pi)

    def _check_inputs(self):
        """Raise ValueError if any data source needed by _gen_data is missing."""
        if self._bl_data is None:
            raise ValueError("Provide visibility data")
        if self._bl_dict is None:
            raise ValueError("Provide dict of baselines")
        if self._gains is None:
            raise ValueError("Provide antenna gains")

    def _flatness(self, seps):
        """Create a flatness from a given pair of separations, their data & their gains.

        Returns V[seps[0]] * conj(g_a) * g_b * conj(V[seps[1]]) * g_c * conj(g_d),
        where (a, b) = seps[0], (c, d) = seps[1], and g_i = gains[(i, 'x')].
        """
        a, b = seps[0][0], seps[0][1]
        c, d = seps[1][0], seps[1][1]

        return self._bl_data[seps[0]]   * self._gains_c[(a,'x')] * self._gains[(b,'x')] * \
               self._bl_data_c[seps[1]] * self._gains[(c,'x')]   * self._gains_c[(d,'x')]

    def _delay_labels(self, targets):
        """One-hot labels (lists of ints) for |delay| rounded to 0.00025 precision.

        There is one class per multiple of 0.00025 from 0 up to self._tau,
        which gives 161 classes for the default tau of 0.04.
        """
        # round(d * 40, 2) / 40 snaps d to a multiple of 0.00025.
        rounded_targets = np.round(np.abs(np.round(targets * 40, 2) / 40), 5).reshape(-1)

        # Derive the class count from tau rather than hard-coding 0.04, so a
        # larger abs_min_max_delay does not raise KeyError below.
        n_classes = int(np.rint(self._tau / 0.00025)) + 1
        classes = np.arange(n_classes) * 0.00025

        eye = np.eye(n_classes, dtype = int)
        classes_labels = {np.round(key, 5): eye[i].tolist() for i, key in enumerate(classes)}

        return [classes_labels[x] for x in rounded_targets]

    def _gen_data(self):
        """Build one batch of (inputs, labels) and append it to self._epoch_batch."""
        self._check_inputs()

        # Conjugates are computed once and reused for later batches.
        if self._bl_data_c is None:
            self._bl_data_c = {key : self._bl_data[key].conjugate() for key in self._bl_data}

        if self._gains_c is None:
            self._gains_c = {key : self._gains[key].conjugate() for key in self._gains}

        # random.sample() needs a sequence (dict views raise TypeError on
        # Python >= 3.11), so pick baselines from a list of keys.
        baselines = list(self._bl_dict)
        flatnesses = []
        for _ in range(self._num):
            unique_baseline = random.choice(baselines)
            two_seps = random.sample(self._bl_dict[unique_baseline], 2)
            flatnesses.append(self._flatness(two_seps))

        flatnesses = np.array(flatnesses).reshape(-1, self._nu.size)
        n_samples = flatnesses.shape[0]

        # One random delay per sample plus one shared random phase offset.
        targets = np.random.uniform(low = -self._tau, high = self._tau, size = (n_samples, 1))
        applied_delay = np.exp(-2j * np.pi * (targets * self._nu + np.random.uniform()))

        inputs = np.angle(flatnesses * applied_delay)
        labels = self._delay_labels(targets)

        # Shuffle inputs AND labels with the same permutation so each input
        # stays paired with its own label.
        permutation_index = np.random.permutation(n_samples)

        self._epoch_batch.append((self._angle_tx(inputs[permutation_index]),
                                  [labels[i] for i in permutation_index]))

    def gen_data(self):
        """Start a new thread that generates one batch there.

        Raises:
            ValueError - if bl_data, bl_dict or gains was not provided.
        """
        # Validate in the caller's thread so the error is not lost in the worker.
        self._check_inputs()

        self._thread = Thread(target = self._gen_data, args=())
        self._thread.start()

    def get_data(self, timeout = 10):
        """Retrieve the oldest generated batch, waiting for the worker thread if needed.

        Arguments:
            timeout : float - seconds to wait for the worker when no batch is ready.

        Returns:
            tuple (inputs, labels):
                inputs : ndarray of shape (n_samples, 1024), phases scaled to (0, 1]
                labels : list of n_samples one-hot lists, in the same order as inputs
            where n_samples = num_flatnesses * rows per visibility array.

        Raises:
            IndexError - if no batch is ready when the wait ends (gen_data() was
                         never called, the worker is still running after
                         `timeout`, or the worker failed).
        """
        if not self._epoch_batch and self._thread is not None:
            self._thread.join(timeout)

        return self._epoch_batch.pop(0)

In [7]:
import numpy as np

# Delay-magnitude bins: 0 to 0.04 inclusive in 0.00025 steps (stop 0.04025 keeps 0.04).
classes = np.arange(0, 0.04025, 0.00025)

In [8]:
# One-hot row per bin, keyed by the bin value rounded to 5 decimals.
eye = np.eye(len(classes), dtype = int)
classes_labels = dict(zip(np.round(classes, 5), eye.tolist()))


In [11]:
len(classes_labels)

161