In [1]:
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], '/Users/andrew/Documents/Lab/aas/modules'))
import notebook_loading

from FC_Data import Data_Creator

import io
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

importing Jupyter notebook from FC_Data.ipynb


In [None]:
class FCN_Train(object):
    """Train a fully connected network on data produced by ``Data_Creator``.

    Currently the only differences between this class and CNN_Train are
    the shape of the input to the network (2D vs 4D) and that this one has
    no conv_keep_prob argument (and no such entry in the feed_dicts).

    The two could probably be combined into one class, but it is unclear
    how best to account for those differences.

    Usage: construct with a network, call ``add_data`` and then ``train``.
    After training, ``MISGs``, ``MSEs`` and ``PWTs`` each hold one
    ``(training, testing)`` tuple per epoch.
    """

    # Every sample is flattened to a row of this many values before it is
    # fed to ``network.X``.
    _INPUT_WIDTH = 1024

    # ``train`` takes the number of samples per epoch to be
    # ``num_flatnesses * 60`` (presumably what Data_Creator generates per
    # flatness -- TODO confirm against FC_Data).
    _ENTRIES_PER_FLATNESS = 60.

    def __init__(self,
                 network=None,
                 abs_min_max_delay=0.040,
                 num_flatnesses=100,
                 num_epochs=100,
                 batch_size=32,
                 log_dir='logs/',
                 model_save_interval=25,
                 pretrained_model_path=None,
                 sample_keep_prob=0.80,
                 fcl_keep_prob=0.50,
                 clean=False):
        """Store the training configuration; no graph or data is created here.

        Args:
            network: Object exposing ``create_graph()``, ``name`` and the
                tensors/ops used by ``train`` (``X``, ``targets``,
                ``sample_keep_prob``, ``fcl_keep_prob``, ``image_buf``,
                ``optimizer``, ``predictions``, ``MISG``, ``MSE``, ``PWT``,
                ``summary``). Must be set before calling ``train``.
            abs_min_max_delay: Half-width ``d`` of the delay range. Passed
                to ``Data_Creator`` and used to map scaled values ``x``
                back to delays via ``x * 2d - d`` for the summary plots.
            num_flatnesses: Passed to ``Data_Creator``; also determines the
                assumed number of samples per epoch.
            num_epochs: Number of passes of the training loop.
            batch_size: Number of samples per optimizer step.
            log_dir: Prefix for the output directory. The network name is
                appended by plain string concatenation, so include a
                trailing separator.
            model_save_interval: A checkpoint is written every this many
                epochs.
            pretrained_model_path: Checkpoint to restore before training;
                when ``None`` the variables are freshly initialized.
            sample_keep_prob: Value fed to ``network.sample_keep_prob``
                during optimizer steps (1.0 is fed during evaluation).
            fcl_keep_prob: Value fed to ``network.fcl_keep_prob`` during
                optimizer steps (1.0 is fed during evaluation).
            clean: Forwarded to ``Data_Creator`` as its ``clean`` keyword.
        """
        self.network = network
        self.abs_min_max_delay = abs_min_max_delay
        self.num_flatnesses = num_flatnesses
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.log_dir = log_dir
        self.model_save_interval = model_save_interval
        self.pretrained_model_path = pretrained_model_path
        self.sample_keep_prob = sample_keep_prob
        self.fcl_keep_prob = fcl_keep_prob
        self.clean = clean

        # Populated by add_data(); train() refuses to run while these are
        # still None.
        self.train_batcher = None
        self.test_batcher = None

    def _make_batcher(self, info, gains):
        """Build a ``Data_Creator`` for ``info`` and start its first batch."""
        batcher = Data_Creator(self.num_flatnesses,
                               info[0],
                               info[1],
                               gains,
                               self.abs_min_max_delay,
                               clean=self.clean)
        batcher.gen_data()
        return batcher

    def add_data(self, train_info, test_info, gains):
        """Create the training and testing data generators.

        Args:
            train_info: Indexable whose first two items are forwarded to
                ``Data_Creator`` for the training set.
            test_info: Same as ``train_info``, for the testing set.
            gains: Forwarded unchanged to both ``Data_Creator`` instances.

        Note (from the original author): when ``clean`` is set, passing
        the info arguments does nothing except waste time and resources.
        """
        self.train_batcher = self._make_batcher(train_info, gains)
        self.test_batcher = self._make_batcher(test_info, gains)

    def _gen_plot(self, predicted_values, actual_values):
        """Plot predicted and actual delays and return the figure as PNG bytes.

        Both inputs are mapped from the network's scaled values back to
        delays with ``x * 2d - d`` (``d = abs_min_max_delay``). The first
        column of each is plotted, ordered by increasing actual delay, and
        the plot title reports the standard deviation of their difference.
        """
        half_range = self.abs_min_max_delay
        predicted = np.array(predicted_values) * 2. * half_range - half_range
        actual = np.array(actual_values) * 2. * half_range - half_range

        order = np.argsort(actual.T[0])
        predicted_sorted = predicted.T[0][order]
        actual_sorted = actual.T[0][order]

        fig, ax = plt.subplots(figsize=(5, 3), dpi=144)
        try:
            ax.plot(predicted_sorted,
                    linestyle='none', marker='.', markersize=1,
                    color='darkblue')
            ax.plot(actual_sorted,
                    linestyle='none', marker='.', markersize=1, alpha=0.50,
                    color='#E50000')
            ax.set_title('std: %.9f' % np.std(predicted_sorted - actual_sorted))

            buf = io.BytesIO()
            fig.savefig(buf, format='png', dpi=144)
        finally:
            # Always release the figure, even if plotting or saving fails,
            # so figures do not pile up across epochs.
            plt.close(fig)

        return buf.getvalue()

    def _evaluate(self, session, inputs, targets):
        """Run the network on a full data set with both keep probs at 1.

        Returns the result of evaluating
        ``[network.MISG, network.MSE, network.PWT, network.summary]``.
        """
        net = self.network
        feed = {net.X: inputs.reshape(-1, self._INPUT_WIDTH),
                net.sample_keep_prob: 1.,
                net.fcl_keep_prob: 1.}

        # The predictions are needed first because the rendered plot is
        # itself an input (image_buf) of the summary op.
        predictions = session.run(net.predictions, feed_dict=feed)

        feed[net.targets] = targets.reshape(-1, 1)
        feed[net.image_buf] = self._gen_plot(predictions, targets)

        return session.run([net.MISG, net.MSE, net.PWT, net.summary],
                           feed_dict=feed)

    def train(self):
        """Build the graph, train for ``num_epochs`` epochs and log summaries.

        Each epoch takes the current data from both batchers (and asks
        them to generate the next set), runs the optimizer over consecutive
        mini-batches, evaluates the full training and testing sets, writes
        their summaries under ``log_dir + network.name`` and records the
        metrics in ``MISGs``, ``MSEs`` and ``PWTs``. A checkpoint is saved
        every ``model_save_interval`` epochs to ``model_save_location``.

        Raises:
            ValueError: If no network was supplied.
            RuntimeError: If ``add_data`` has not been called.
        """
        if self.network is None:
            raise ValueError('FCN_Train.train() requires a network; '
                             'pass one to the constructor.')
        if (getattr(self, 'train_batcher', None) is None
                or getattr(self, 'test_batcher', None) is None):
            raise RuntimeError('FCN_Train.add_data() must be called '
                               'before train().')

        self.MISGs = []
        self.MSEs = []
        self.PWTs = []

        tf.reset_default_graph()

        net = self.network
        net.create_graph()
        saver = tf.train.Saver()

        archive_loc = self.log_dir + net.name
        self.model_save_location = archive_loc + '/trained_model.ckpt'

        batch_size = self.batch_size
        # If the division has a remainder, the trailing samples are
        # simply not used for optimizer steps.
        batches_per_epoch = int(self.num_flatnesses
                                * self._ENTRIES_PER_FLATNESS / batch_size)

        # The context manager closes the session; no explicit close needed.
        with tf.Session() as session:

            if self.pretrained_model_path is None:
                session.run(tf.global_variables_initializer())
            else:
                saver.restore(session, self.pretrained_model_path)

            training_writer = tf.summary.FileWriter(archive_loc + '/training',
                                                    session.graph)
            testing_writer = tf.summary.FileWriter(archive_loc + '/testing',
                                                   session.graph)

            try:
                for epoch in range(self.num_epochs):

                    # Take the data that is ready, then immediately request
                    # the data for the next epoch.
                    training_inputs, training_targets = self.train_batcher.get_data()
                    self.train_batcher.gen_data()
                    testing_inputs, testing_targets = self.test_batcher.get_data()
                    self.test_batcher.gen_data()

                    for j in range(batches_per_epoch):
                        rows = slice(j * batch_size, (j + 1) * batch_size)
                        session.run(
                            [net.optimizer],
                            feed_dict={
                                net.X: training_inputs[rows].reshape(-1, self._INPUT_WIDTH),
                                net.targets: training_targets[rows].reshape(-1, 1),
                                net.sample_keep_prob: self.sample_keep_prob,
                                net.fcl_keep_prob: self.fcl_keep_prob})

                    (training_MISG, training_MSE,
                     training_PWT, training_summary) = self._evaluate(
                         session, training_inputs, training_targets)

                    training_writer.add_summary(training_summary, epoch)
                    training_writer.flush()

                    (testing_MISG, testing_MSE,
                     testing_PWT, testing_summary) = self._evaluate(
                         session, testing_inputs, testing_targets)

                    sys.stdout.write(
                        '\rEpoch: {} (Training, Testing)'
                        ' MISG: ({:0.4f}, {:0.4f})'
                        ' MSE: ({:0.4f}, {:0.4f})'
                        ' PWT: ({:2.2f}, {:2.2f})'.format(
                            epoch,
                            training_MISG, testing_MISG,
                            training_MSE, testing_MSE,
                            training_PWT, testing_PWT))
                    # No newline is written, so flush explicitly or the
                    # progress line may sit in the buffer.
                    sys.stdout.flush()

                    testing_writer.add_summary(testing_summary, epoch)
                    testing_writer.flush()

                    self.MISGs.append((training_MISG, testing_MISG))
                    self.MSEs.append((training_MSE, testing_MSE))
                    self.PWTs.append((training_PWT, testing_PWT))

                    if (epoch + 1) % self.model_save_interval == 0:
                        saver.save(session, self.model_save_location, epoch)

                print('\rTraining Finished')

            finally:
                # Close the event files even when training is interrupted.
                training_writer.close()
                testing_writer.close()