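# CNN: random hyperparameter search

Trains batch-normalized fully convolutional networks (`bn_fully_CN_Network`) using randomly sampled architectures and hyperparameters. Each trial does one of two things:
- it starts a new random network, or
- it continues training the previous trial's network (the "seed") at a lower learning rate.

A seed is kept only while its score keeps improving.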

## Imports

In [1]:
import os

# Expose a single GPU to this notebook. CUDA_DEVICE_ORDER=PCI_BUS_ID orders devices
# by PCI bus ID (the numbering nvidia-smi shows), so "2" names a predictable card.
# These variables only take effect if set before CUDA is initialized.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
})

In [2]:
import sys, os

# Make the project's shared modules in ../modules (relative to sys.path[0]) importable.
sys.path.insert(1, os.path.join(sys.path[0], '../modules'))
# Importing notebook_loading apparently lets the imports below load .ipynb notebooks
# as modules (see the "importing Jupyter notebook from ..." output under this cell).
import notebook_loading

from bn_CNN_Train import bn_CNN_Train
# NOTE(review): star import. This notebook calls load_relevant_data,
# get_or_gen_test_train_red_bls_dicts and get_seps_data, and uses `np`, without
# importing them explicitly. Presumably they all come from here; prefer explicit imports.
from CNN_Data import *
from batch_norm_fully_CNN import bn_fully_CN_Network

importing Jupyter notebook from bn_CNN_Train.ipynb
importing Jupyter notebook from CNN_Data.ipynb
importing Jupyter notebook from batch_norm_fully_CNN.ipynb


## Load & Setup Data

In [3]:
# Input files: the miriad visibility file and its calfits calibration file.
zen_uv_path = '../zen_data/zen.2458098.58037.xx.HH.uv'
zen_calfits_path = '../zen_data/zen.2458098.58037.xx.HH.uv.abs.calfits'

# Extract the redundant baselines, their gains, and the visibility data.
red_bls, gains, uvd = load_relevant_data(zen_uv_path, zen_calfits_path)

# Split the redundant baselines into training and testing sets. Judging by its name,
# the helper reuses a split previously saved to disk and otherwise generates a new one.
training_redundant_baselines_dict, testing_redundant_baselines_dict = get_or_gen_test_train_red_bls_dicts(red_bls, gains.keys())

# Separate the visibilities for each set (training first, then testing).
training_baselines_data, testing_baselines_data = (
    get_seps_data(red_bls_dict, uvd)
    for red_bls_dict in (training_redundant_baselines_dict, testing_redundant_baselines_dict)
)

In [4]:
from pprint import pprint

In [5]:
def rand_params():
    """Sample a random hyperparameter set for ``bn_fully_CN_Network``.

    How the layers are built:
    - The input width is 1024. Each layer divides the current width by a
      width-reduction factor (WRF).
    - The WRF must divide the width exactly and must leave more than
      ``minimum_layer_width`` (8) samples.
    - The layer's wide-filter width (WFW) is then drawn from ``[2, new_width // 4)``.
    - The first layer's WRF is at least 2; later layers may use 1 (no reduction).

    All randomness comes from the global ``np.random`` state, so a preceding
    ``np.random.seed(...)`` makes the result reproducible.

    Returns:
        dict with keys
          'WFWs': list of 2-4 ints, one wide-filter width per layer;
          'WRFs': list of ints (same length), one width-reduction factor per layer;
          'LR':   learning rate drawn uniformly from [1e-5, 1e-4), rounded to 10 decimals;
          'BS':   batch size in [32, 128);
          'num_1x1_conv_filters': int in [4, 16).
    """
    # `np` is not imported explicitly elsewhere in this notebook (it presumably
    # arrives via `from CNN_Data import *`); import it here so this function
    # does not depend on that star import.
    import numpy as np

    input_width = 1024
    minimum_layer_width = 8  # arbitrary floor on layer width
    min_wfw = 2

    num_layers = np.random.randint(low = 2, high = 5)  # 2-4 layers; arbitrary range

    def draw_wrf(width, low):
        """Draw a WRF in [low, width // minimum_layer_width) that divides ``width`` exactly."""
        high = width // minimum_layer_width
        wrf = np.random.randint(low = low, high = high)
        # Rejection-sample until the factor divides the width. Use an integer modulo:
        # a test like `a/b - a/float(b) == 0` only works with Python 2 integer
        # division and always passes on Python 3 (PEP 238).
        while width % wrf != 0:
            wrf = np.random.randint(low = low, high = high)
        return wrf

    worfs = []
    wifes = []
    width = input_width
    for layer in range(num_layers):
        # The first layer must actually reduce the width; later ones may keep it.
        wrf = draw_wrf(width, low = 2 if layer == 0 else 1)
        width //= wrf
        worfs.append(wrf)
        wifes.append(np.random.randint(low = min_wfw, high = width // 4))

    # Draw order (LR, then BS, then 1x1 filters) is kept fixed so seeded runs are reproducible.
    learning_rate = round(np.random.uniform(low = 1e-5, high = 1e-4), 10)
    batch_size = np.random.randint(low = 32, high = 128)
    num_1x1_conv_filters = np.random.randint(low = 4, high = 16)  # arbitrary range

    return {
        'WFWs': wifes,
        'WRFs': worfs,
        'LR': learning_rate,
        'BS': batch_size,
        'num_1x1_conv_filters': num_1x1_conv_filters,
    }

## Trainer

In [50]:
# Initial trainer instance. The random-search loop below always starts with no seed,
# so its first trial creates a fresh bn_CNN_Train and replaces this object.
trainer = bn_CNN_Train()

In [56]:
# Random search with "seed continuation":
#   * With no live seed, a trial trains a freshly sampled network, which becomes the seed.
#   * With a live seed, the trial reloads the seed's saved parameters and checkpoint,
#     lowers its learning rate and trains it further. The continued network becomes
#     the new seed only if its score beats the old one by more than
#     SEED_MIN_IMPROVEMENT. Otherwise the seed is dropped and the next trial starts fresh.
# Higher score is better.
num_trials = 100
SEED_LR_DECAY = 0.95           # learning-rate multiplier applied each time a seed is continued
SEED_MIN_IMPROVEMENT = 1.0025  # required score ratio (> 0.25% better) to avoid plateaus
# {trial index of the current seed network: its score}; empty when there is no live seed
previous_seed = {}
# number of continuation trials since the current seed was created
continue_count = 0


def _seed_key(seed):
    """Return the trial index of the single-entry ``seed`` dict.

    ``next(iter(...))`` works on Python 2 and 3; ``seed.keys()[0]`` raises
    TypeError on Python 3, where ``keys()`` returns a non-indexable view.
    """
    return next(iter(seed))


def _train_and_score(trainer):
    """Clear the trainer's metric histories, run training and return the trial score.

    The score is the mean of column 1 of ``trainer.PWTs``, excluding the last 20
    entries. Column 1 is presumably the testing PWT of each (training, testing)
    pair, judging by the training log; TODO confirm.
    NOTE(review): ``[:-20]`` drops the 20 MOST RECENT values. If the intent was
    "mean over the last 20 epochs", this should be ``[-20:]``.
    """
    trainer.PWTs = []
    trainer.MSEs = []
    trainer.MISGs = []
    trainer.train()
    return np.mean(np.array(trainer.PWTs).T[1][:-20])


for trial in range(num_trials):

    if not previous_seed:
        # First trial, or the previous seed stopped improving: sample a new network.
        params = rand_params()
        network = bn_fully_CN_Network(name = 'bn_rand_{}'.format(trial),
                                      wide_filter_widths = params['WFWs'],
                                      width_reduction_factors = params['WRFs'],
                                      learning_rate = params['LR'],
                                      num_1x1_conv_filters = params['num_1x1_conv_filters'],
                                      g_shift = 1e-5)
        trainer = bn_CNN_Train()
        trainer.pred_keep_prob = 1.0
        trainer.add_data((training_baselines_data, training_redundant_baselines_dict),
                         (testing_baselines_data, testing_redundant_baselines_dict),
                         gains)
        trainer.network = network

        trainer.num_epochs = 1000
        trainer.model_save_interval = 250
        trainer.batch_size = params['BS']
        score = _train_and_score(trainer)
        previous_seed = {trial: score}
        print('New seed {}'.format(trial))
        continue_count = 0

    else:
        previous_seed_key = _seed_key(previous_seed)
        print('Continuing seed {} in {}'.format(previous_seed_key, trial))
        continue_count += 1
        network = bn_fully_CN_Network('bn_rand_{}'.format(trial), [], [])

        previous_seed_name = 'bn_rand_{}'.format(previous_seed_key)
        previous_seed_dir = 'logs/' + previous_seed_name
        network.load_params(previous_seed_dir + '/params/bn_fully_CN_Network')
        network.learning_rate *= SEED_LR_DECAY
        # `trainer` is the object left over from the previous trial. It presumably
        # still holds the data added when the seed was created; only its network and
        # saved settings are replaced here.
        trainer.network = network
        trainer.load_params(previous_seed_dir + '/params/bn_CNN_Train')
        # Checkpoint index num_epochs - 1 (num_epochs as restored by load_params),
        # presumably the seed's final epoch.
        trainer.pretrained_model_path = previous_seed_dir + '/trained_model.ckpt-{}'.format(trainer.num_epochs - 1)
        score = _train_and_score(trainer)

        # Only keep this network as the seed if it beat the old score by the margin.
        if score > previous_seed[previous_seed_key] * SEED_MIN_IMPROVEMENT:
            previous_seed = {trial: score}
        else:
            previous_seed = {}
            print('Previous seed dead')
      

Training Finished
New seed 0
Continuing seed 0 in 1
Network ReadyINFO:tensorflow:Restoring parameters from logs/bn_rand_0/trained_model.ckpt-99
Training Finished
Continuing seed 1 in 2
Network ReadyINFO:tensorflow:Restoring parameters from logs/bn_rand_1/trained_model.ckpt-99
Epoch: 7 (Training, Testing) MISG: (0.9433, 0.9363) MSE: (0.0741, 0.0728) PWT: (0.97, 1.05)

KeyboardInterrupt: 

1.3325