In [1]:
import numpy as np
import tensorflow as tf
import os
import sys
sys.path.insert(0, 'module')
import time
from tools import print_time
import cells
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
# Data Related
TRAIN_SIZE = mnist.train.num_examples
VALID_SIZE = mnist.validation.num_examples
TEST_SIZE = mnist.test.num_examples
BATCH_SIZE = 500

In [3]:
# Save & Load
CKPT_DIR = 'checkpoint/fnn'

In [4]:
class fnn(object):
    def __init__(self, cell, lr, scope = None):
        with tf.variable_scope(scope or type(self).__name__):
            self.cell = cell
            self.inputs = tf.placeholder(dtype = tf.float32, shape = [None, cell.input_size], name = 'inputs')
            self.targets = tf.placeholder(dtype = tf.float32, shape = [None, cell.output_size], name = 'targets')
            self.global_step = tf.Variable(0, trainable = False, name = 'global_step')
            self.total_time = tf.Variable(0, dtype = tf.float32, trainable = True, name = 'total_time')

            # cell output
            self.output, self.state = cell(self.inputs)
            self.loss = tf.reduce_mean(-tf.reduce_sum(self.targets * tf.log(self.output), reduction_indices = [1]))
            tf.summary.scalar('loss', self.loss) # added

            # Training related
            self.params = tf.trainable_variables()
            self.gradients = tf.gradients(self.loss, self.params)
            self.opt_func = tf.train.GradientDescentOptimizer(lr)
            #clipped_gradients, norm = tf.clip_by_global_norm(gradients, 1)
            self.updates = self.opt_func.apply_gradients(zip(self.gradients, self.params), global_step = self.global_step)

            # Testing related
            self.correct_prediction = tf.equal(tf.argmax(self.output, 1), tf.argmax(self.targets, 1))
            self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))

            # Save & Load related
            self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 5)
        
    def save(self, sess, ckpt_dir):
        ckpt_path = os.path.join(ckpt_dir, 'params.ckpt')
        self.saver.save(sess, ckpt_path, self.global_step)
    
    def check_and_load(self, sess, ckpt_dir):
        # Preparation for Save & Load
        if not os.path.exists(ckpt_dir):
            print 'Create directory for saving checkpoint...',
            os.makedirs(ckpt_dir)
            print 'Completed'
        else:
            print 'Checkpoint directory exists\n'
            
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        
        # Check if checkpoint exists
        if ckpt:
            #print 'Read model parameters from %s... '%ckpt.model_checkpoint_path,
            self.saver.restore(sess, ckpt.model_checkpoint_path)
            print '\n'
        else:
            print 'Create new model...\n'
            sess.run(tf.global_variables_initializer())
        
    
    def train(self, sess, batch_gen, ckpt_dir, log_dir, max_epoch, patience, lr, lrdf, evaluate_per):
        '''
            args:
                batch_gen : batch generator with batch size, [train_gen, valid_gen, test_gen, batch_size]
                ckpt_dir : checkpoint directory
                log_dir : summary directory
                patience
                lr : learning rate
                lrdf : learning rate decay factor
        '''
        print('----------------------NN INFORMATION-----------------------')
        print('Input Size %d, Output Size %d' %(self.cell.input_size, self.cell.output_size))
        print('Stack Size %d, Hidden Size %d' %(self.cell.num_layer, self.cell.state_size))
        print('-----------------------------------------------------------\n')
        print ('---------------------BASIC PARAMETERS----------------------')
        print ('Batch size %d, Learning rate %f, Decay factor %f' %(batch_gen[-1], lr, lrdf))
        print ('Evaluate per %d iteration' %(evaluate_per))
        print ('Saving directory %s' %(ckpt_dir))
        print ('-----------------------------------------------------------')
        
        # Summary Writer
        merged = tf.summary.merge_all()
        
        train_writer = tf.summary.FileWriter(log_dir + '/train', sess.graph)
        valid_writer = tf.summary.FileWriter(log_dir + '/valid', sess.graph)
        
        train_gen, valid_gen, test_gen, batch_size = batch_gen
        
        # prevention for repetition
        valid_x, valid_y = valid_gen.next_batch(VALID_SIZE)
        test_x, test_y = test_gen.next_batch(TEST_SIZE)

        valid_input_feed = {self.inputs: valid_x, self.targets: valid_y}
        valid_output_feed = [merged, self.loss, self.accuracy]

        test_input_feed = {self.inputs: test_x, self.targets: test_y}
        test_output_feed = [self.loss, self.accuracy]
        
        self.check_and_load(sess, ckpt_dir)
        
        min_valid_loss = 9999
        iter_count = 0
        
        
        for epoch in xrange(max_epoch):
            iner_step = TRAIN_SIZE / batch_size # indicator for epoch completion
            # Time Check
            start_time = time.time()
            
            while True:
                try:
                    
                    train_x, train_y = train_gen.next_batch(batch_size)

                    input_feed = {self.inputs: train_x, self.targets: train_y}
                    output_feed = [merged, self.loss, self.updates]
                    summary, loss, _ = sess.run(output_feed, input_feed)
                    train_writer.add_summary(summary, iter_count)
                    
                    iter_count += 1
                    iner_step -= 1

                    # Evaluation step
                    if((self.global_step).eval() % evaluate_per == 0 and (self.global_step).eval() > 1):
                        
                        summary, valid_loss, accuracy = sess.run(valid_output_feed, valid_input_feed)
                        valid_writer.add_summary(summary, iter_count)
                        print 'Epoch : %3d | Evaluation : %4d | Learning Rate : %2.2f'%(epoch, self.global_step.eval()/evaluate_per, lr)
                        print '-------------------------------------------------------'
                        print '%-22s %10.5s'%('Training Loss :', loss)
                        print '%-22s %10.5s'%('Validation Loss :', valid_loss)
                        print '%-22s %10.5s'%('Best Validation loss :', min_valid_loss)
                        print '%-22s %10.5s%s'%('Validation Accuracy :', accuracy*100, '%')
                        
                        
                        # Check & Print training time
                        till_now = time.time() - start_time
                        sess.run(tf.assign(self.total_time, self.total_time + till_now)) # we must assign value to the tensor
                        print_time(self.total_time.eval())
                        
                        if valid_loss < min_valid_loss:
                            min_valid_loss = valid_loss
                            self.save(sess, ckpt_dir)
                    

                    # Check for one epoch completion
                    if iner_step == 0:
                        raise StopIteration

                except StopIteration:
                    break


In [5]:
tf.reset_default_graph()

In [6]:
cell = cells.FNNCell(784, 100, 1, 10)
batch_gen = [mnist.train, mnist.validation, mnist.test, BATCH_SIZE]

In [7]:
model = fnn(cell, 0.5)

In [8]:
sess = tf.InteractiveSession()

In [None]:
model.train(sess, batch_gen, CKPT_DIR, './', max_epoch = 10000, patience = 0, lr = 0.5, lrdf = 1, evaluate_per = 100)

----------------------NN INFORMATION-----------------------
Input Size 784, Output Size 10
Stack Size 1, Hidden Size 100
-----------------------------------------------------------

---------------------BASIC PARAMETERS----------------------
Batch size 500, Learning rate 0.500000, Decay factor 1.000000
Evaluate per 100 iteration
Saving directory checkpoint/fnn
-----------------------------------------------------------
Checkpoint directory exists

Create new model...

Epoch :   0 | Evaluation :    1 | Learning Rate : 0.50
-------------------------------------------------------
Training Loss :             0.537
Validation Loss :           0.493
Best Validation loss :       9999
Validation Accuracy :       88.11%
Total time cost : 0.53 seconds

Epoch :   1 | Evaluation :    2 | Learning Rate : 0.50
-------------------------------------------------------
Training Loss :             0.364
Validation Loss :           0.365
Best Validation loss :      0.493
Validation Accuracy :       90.03%