In [3]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
%matplotlib inline

### Scaffolding

In [5]:
class ML(object):
    """Base scaffolding for a TensorFlow (graph/session API) model.

    Subclasses override ``inputs``, ``loss``, ``train`` (and usually
    ``inference`` / ``evaluate``); ``run_training`` then builds the graph in
    ``self.graph``, trains, checkpoints to ``saves/<name>/`` and evaluates.
    """

    def __init__(self, name='ML'):
        # ``name`` selects the checkpoint directory ``saves/<name>``.
        self.name = name
        # All ops built by run_training live in this private graph.
        self.graph = tf.Graph()

    def inference(self, X):
        """Compute inference model over data X and return the result."""
        raise NotImplementedError('Method inference(X) must be implemented.')

    def loss(self, X, Y):
        """Compute loss over training data X and expected outputs Y.

        Subclasses must override this; the base version returns None.
        """
        pass

    def inputs(self):
        """Read/generate input training data X and expected outputs Y.

        Subclasses must override this to return an ``(X, Y)`` pair; the base
        version returns None, which run_training cannot unpack.
        """
        pass

    def train(self, total_loss):
        """Train / adjust model parameters according to computed total loss.

        Subclasses must override this to return a runnable training op; the
        base version returns None.
        """

    def evaluate(self, sess, X, Y):
        """Evaluate the resulting trained model (no-op in the base class)."""

    def _restore_latest(self, sess, saver, save_dir):
        """Restore the newest checkpoint in save_dir; return the step to resume at."""
        ckpt = tf.train.get_checkpoint_state(save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # Checkpoints are saved as ``<prefix>-<completed steps>``.
            return int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
        return 0

    def run_training(self, training_steps, print_every=10, save_every=1000):
        """Build the graph, train for ``training_steps`` steps, then evaluate.

        Training resumes from the newest checkpoint in ``saves/<name>/`` if
        one exists. The loss is printed every ``print_every`` steps and a
        checkpoint is written every ``save_every`` completed steps, plus once
        at the end.

        The session is stored on ``self.sess`` and is closed before this
        method returns, including when training raises.
        """
        save_dir = os.path.join('saves', self.name)
        # Saver.save() writes its ``checkpoint`` index file into the directory
        # of the save path, so save *inside* save_dir; otherwise
        # get_checkpoint_state(save_dir) never finds anything to restore.
        save_path = os.path.join(save_dir, 'model')
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        # Everything must be created in self.graph, the graph the session is
        # bound to, and only after the model ops exist, so that the
        # initializer and Saver see every variable.
        with self.graph.as_default():
            X, Y = self.inputs()
            total_loss = self.loss(X, Y)
            train_op = self.train(total_loss)
            init_op = tf.initialize_all_variables()
            saver = tf.train.Saver()

            self.sess = sess = tf.Session(graph=self.graph)
            coord = tf.train.Coordinator()
            threads = []
            try:
                # Initialize first, then let a checkpoint overwrite the values.
                sess.run(init_op)
                initial_step = self._restore_latest(sess, saver, save_dir)

                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                # The actual training loop
                for step in range(initial_step, training_steps):
                    sess.run([train_op])

                    # for debugging and learning purposes, see how the loss
                    # gets decremented thru training steps
                    if step % print_every == 0:
                        print('loss: ', sess.run([total_loss]))

                    # Tag checkpoints with the number of completed steps so a
                    # resumed run does not repeat the last step.
                    completed = step + 1
                    if completed % save_every == 0:
                        saver.save(sess, save_path, global_step=completed)

                saver.save(sess, save_path, global_step=training_steps)

                self.evaluate(sess, X, Y)
            finally:
                # Always stop the input threads and release the session.
                coord.request_stop()
                coord.join(threads)
                sess.close()

    def __del__(self):
        """No-op: run_training closes its session itself."""
        pass