In [None]:
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], '/Users/andrew/Documents/Lab/aas/modules'))
import notebook_loading

from FC_Data import Data_Creator

import io
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [None]:
def _unscale_delay(scaled, abs_min_max_delay=0.040):
    """Map network-scale delay values back to delay units.

    This is the inverse of ``(x + d) / (2 * d)`` with ``d = abs_min_max_delay``,
    so 0 maps to ``-d`` and 1 maps to ``+d``. The 0.040 default is the range the
    plotting code has always used (its physical units are not visible here --
    TODO confirm against the data generator).

    Args:
        scaled: array-like of scaled values.
        abs_min_max_delay: half-width ``d`` of the delay range.

    Returns:
        ``np.ndarray`` with the same shape as ``scaled``.
    """
    return np.asarray(scaled) * 2. * abs_min_max_delay - abs_min_max_delay


def _first_column(values):
    """Return the first column of a 2-D array, or a 1-D array unchanged.

    ``values.T[0]`` would give the same column for an ``(N, k)`` array, but it
    silently collapses a flat ``(N,)`` array to a single scalar; this helper
    accepts both layouts.
    """
    arr = np.asarray(values)
    return arr[:, 0] if arr.ndim > 1 else arr


def _prediction_plot_png(predicted_values, actual_values, abs_min_max_delay=0.040):
    """Render predicted vs. actual delays as a PNG and return its bytes.

    Both inputs are unscaled with ``_unscale_delay`` and ordered by the actual
    value, so the red (actual) points form a monotone curve and the blue
    (predicted) points scatter around it. The title is the standard deviation
    of ``predicted - actual`` in unscaled units.
    """
    predicted = _unscale_delay(_first_column(predicted_values), abs_min_max_delay)
    actual = _unscale_delay(_first_column(actual_values), abs_min_max_delay)

    order = np.argsort(actual)
    predicted = predicted[order]
    actual = actual[order]

    fig, ax = plt.subplots(figsize=(5, 3), dpi=144)
    try:
        ax.plot(predicted,
                linestyle='none', marker='.', markersize=1,
                color='darkblue')
        ax.plot(actual,
                linestyle='none', marker='.', markersize=1, alpha=0.50,
                color='#E50000')
        ax.set_title('std: %.9f' % np.std(predicted - actual))

        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=144)
    finally:
        # Release the figure even if plotting fails; pyplot otherwise keeps it alive.
        plt.close(fig)

    return buf.getvalue()


def _evaluate(session, network, inputs, targets):
    """Evaluate ``network`` on a whole data set with dropout disabled.

    Two passes are needed: the first gets the predictions used to render the
    plot fed into ``network.image_buf``; the second computes MISG, MSE and the
    merged summary (which consumes that plot).

    Returns:
        ``[misg, mse, summary]`` as returned by ``session.run``.
    """
    flat_inputs = inputs.reshape(-1, 1024)

    predictions = session.run(network.predictions,
                              feed_dict={network.X: flat_inputs,
                                         network.sample_keep_prob: 1.,
                                         network.fcl_keep_prob: 1.})

    return session.run([network.MISG, network.MSE, network.summary],
                       feed_dict={network.X: flat_inputs,
                                  network.targets: targets.reshape(-1, 1),
                                  network.sample_keep_prob: 1.,
                                  network.fcl_keep_prob: 1.,
                                  network.image_buf: _prediction_plot_png(predictions, targets)})


def fc_train(network,
             num_flatnesses,
             num_epochs,
             batch_size,
             log_dir,
             model_save_interval,
             train_info,
             test_info,
             gains,
             clean = False,
             pretrained_model_path = None,
             sample_keep_prob = 0.80,
             fcl_keep_prob = 0.50):
    """Train ``network`` and log training/testing metrics to TensorBoard.

    Each epoch takes freshly generated data from a ``Data_Creator`` for both
    the training and the testing set. It runs ``network.optimizer`` over
    consecutive mini-batches of the training data. It then evaluates both sets
    with dropout disabled and writes one summary (including a prediction plot)
    per set.

    Args:
        network: model object exposing the placeholders ``X``, ``targets``,
            ``sample_keep_prob``, ``fcl_keep_prob`` and ``image_buf``, and the
            tensors/ops ``optimizer``, ``predictions``, ``MISG``, ``MSE`` and
            ``summary``. Inputs are fed reshaped to ``(-1, 1024)`` and
            targets to ``(-1, 1)``.
        num_flatnesses: first argument passed to ``Data_Creator``. An epoch is
            assumed to hold ``num_flatnesses * 60`` examples (TODO confirm
            against ``Data_Creator``).
        num_epochs: number of training epochs.
        batch_size: mini-batch size. Examples past the last full batch of an
            epoch are not trained on.
        log_dir: summaries go to ``log_dir + '/training'`` and
            ``log_dir + '/testing'``; checkpoints to
            ``log_dir + '/trained_model.ckpt'``.
        model_save_interval: save a checkpoint every this many epochs, using
            the zero-based epoch index as the global step.
        train_info, test_info: ``(bl_data, bl_dict)`` pairs forwarded to
            ``Data_Creator``.
        gains, clean: forwarded to both ``Data_Creator`` instances.
        pretrained_model_path: checkpoint to restore; when ``None`` all global
            variables are freshly initialized instead.
        sample_keep_prob, fcl_keep_prob: dropout keep probabilities used while
            training (evaluation always feeds 1.0).

    Returns:
        ``(misg_history, mse_history)``: lists with one
        ``(training_value, testing_value)`` tuple per epoch.

    Raises:
        ValueError: if ``batch_size`` or ``model_save_interval`` is not positive.
    """
    # Fail fast, before the (potentially slow) data generation below.
    if batch_size < 1:
        raise ValueError('batch_size must be positive, got {!r}'.format(batch_size))
    if model_save_interval < 1:
        raise ValueError('model_save_interval must be positive, got {!r}'.format(model_save_interval))

    num_batches = (num_flatnesses * 60) // batch_size

    train_data, train_dict = train_info
    train_batcher = Data_Creator(num_flatnesses, bl_data=train_data, bl_dict=train_dict, gains=gains, clean=clean)
    train_batcher.gen_data()

    test_data, test_dict = test_info
    test_batcher = Data_Creator(num_flatnesses, bl_data=test_data, bl_dict=test_dict, gains=gains, clean=clean)
    test_batcher.gen_data()

    saver = tf.train.Saver()

    misg_history = []
    mse_history = []

    # The ``with`` block closes the session on exit; no explicit close() needed.
    with tf.Session() as session:

        if pretrained_model_path is None:
            session.run(tf.global_variables_initializer())
        else:
            saver.restore(session, pretrained_model_path)

        training_writer = tf.summary.FileWriter(log_dir + '/training', session.graph)
        testing_writer = tf.summary.FileWriter(log_dir + '/testing', session.graph)
        model_save_location = log_dir + '/trained_model.ckpt'

        try:
            for epoch in range(num_epochs):

                # Take this epoch's data, then ask each creator to generate the next set.
                training_inputs, training_targets = train_batcher.get_data()
                train_batcher.gen_data()
                testing_inputs, testing_targets = test_batcher.get_data()
                test_batcher.gen_data()

                for j in range(num_batches):
                    batch = slice(j * batch_size, (j + 1) * batch_size)
                    session.run([network.optimizer],
                                feed_dict={network.X: training_inputs[batch].reshape(-1, 1024),
                                           network.targets: training_targets[batch].reshape(-1, 1),
                                           network.sample_keep_prob: sample_keep_prob,
                                           network.fcl_keep_prob: fcl_keep_prob})

                training_MISG, training_MSE, training_summary = _evaluate(
                    session, network, training_inputs, training_targets)
                training_writer.add_summary(training_summary, epoch)
                training_writer.flush()

                testing_MISG, testing_MSE, testing_summary = _evaluate(
                    session, network, testing_inputs, testing_targets)
                testing_writer.add_summary(testing_summary, epoch)
                testing_writer.flush()

                # Report both sets on a single '\r'-refreshed line; writing them as two
                # separate '\r' lines makes the testing line overwrite the training one.
                sys.stdout.write('\rEpoch: {}. Training: MISG = {:.6f}, MSE = {:.6f}. '
                                 'Testing: MISG = {:.6f}, MSE = {:.6f}'.format(
                                     epoch, training_MISG, training_MSE, testing_MISG, testing_MSE))
                sys.stdout.flush()

                misg_history.append((training_MISG, testing_MISG))
                mse_history.append((training_MSE, testing_MSE))

                if (epoch + 1) % model_save_interval == 0:
                    saver.save(session, model_save_location, epoch)
        finally:
            # Close the event files even if training is interrupted.
            training_writer.close()
            testing_writer.close()

    # Step off the in-place progress line so the last epoch's stats stay readable.
    print('\nDone')

    return misg_history, mse_history