# XOR problem

## Import

In [None]:
%matplotlib notebook
import tensorflow as tf
import FFBP

# Prevent unwanted logging messages by TensorFlow: only messages at ERROR level
# (and above) are emitted. tf.logging is the TF 1.x logging API, matching the
# graph / placeholder / queue-runner style used in the rest of this notebook.
tf.logging.set_verbosity(tf.logging.ERROR)

## Build

In [None]:
# TRAIN CONFIGS
# (Run length -- NUM_RUNS, NUM_EPOCHS, test/save schedule -- is set in the "Run" cell.)
BATCH_SIZE = 4   # equals DATA_LEN, so (assuming data_len counts patterns) one batch = whole data set
INP_SIZE = 2     # width of the model input placeholder
HID_SIZE = 2     # number of hidden units
TARG_SIZE = 1    # width of the target placeholder; the output layer must have this many units
DATA_LEN = 4
WR = .45         # layers get wrange=[-WR, WR] (presumably the weight-initialization range)

LR = 0.1         # learning rate passed to MomentumOptimizer
M = 0.9          # momentum passed to MomentumOptimizer

FFBP_GRAPH = tf.Graph()

with FFBP_GRAPH.as_default():

    # Training and test readers both read 'xor_data.txt'; the test reader uses
    # batch_size=1 (presumably one pattern per test step).
    with tf.name_scope('train_data'):
        train_data = FFBP.InputData(
            path_to_data_file='xor_data.txt',
            batch_size=BATCH_SIZE,
            inp_size=INP_SIZE,
            targ_size=TARG_SIZE,
            data_len=DATA_LEN,
            shuffle_seed=None,
        )

    with tf.name_scope('test_data'):
        test_data = FFBP.InputData(
            path_to_data_file='xor_data.txt',
            batch_size=1,
            inp_size=INP_SIZE,
            targ_size=TARG_SIZE,
            data_len=DATA_LEN,
            shuffle_seed=None,
        )

    # NETWORK CONSTRUCTION: input -> sigmoid hidden layer -> sigmoid output layer
    model_name = 'xor_model'
    with tf.name_scope(model_name):

        input_ = tf.placeholder(dtype=tf.float32, shape=[None, INP_SIZE], name='model_inp')

        hidden_layer = FFBP.BasicLayer(
            layer_name='hidden_layer',
            layer_input=input_,
            size=HID_SIZE,
            wrange=[-WR, WR],
            nonlin=tf.nn.sigmoid,
            seed=None,
        )

        # Output width is tied to TARG_SIZE so it always matches the target
        # placeholder used in the squared-difference loss below.
        output_layer = FFBP.BasicLayer(
            layer_name='output_layer',
            layer_input=hidden_layer.output,
            size=TARG_SIZE,
            wrange=[-WR, WR],
            nonlin=tf.nn.sigmoid,
            seed=None,
        )

        target = tf.placeholder(dtype=tf.float32, shape=[None, TARG_SIZE], name='targets')

        MODEL = FFBP.Model(
            name=model_name,
            layers=[hidden_layer, output_layer],
            train_data=train_data,
            inp=input_,
            targ=target,
            # Summed squared error: reduce_sum with no axis adds (target - output)^2
            # over every output unit of every example in the batch.
            loss=tf.reduce_sum(tf.squared_difference(target, output_layer.output), name='loss_function'),
            optimizer=tf.train.MomentumOptimizer(LR, M),
            test_data=test_data,
        )

## Run

In [None]:
# Set up run parameters
NUM_RUNS = 3        # each run opens a fresh session and calls saver.init_model
NUM_EPOCHS = 500
TEST_EPOCHS = list(range(0, NUM_EPOCHS, 100))  # epoch numbers at which to test (before training that epoch)
SAVE_EPOCHS = [None]  # [None] never equals an epoch number, so no periodic checkpoints are written
# Stop early once the training loss drops below ECRIT. With ECRIT = 0 this never
# triggers if train_epoch returns the summed squared error built in the Build cell
# (that loss cannot be negative) -- confirm against FFBP.Model.train_epoch.
ECRIT = 0
RESTORE_DIR = 'exercise_params' # path to logdir with 'checkpoint_files' directory to restore a saved model

# Create ModelSaver to manage test data saving and model checkpointing.
# NOTE: RESTORE_DIR is not used below; presumably pass restore_from=RESTORE_DIR
# to start from the saved parameters instead of a fresh initialization.
saver = FFBP.ModelSaver(restore_from=None, logdir=None)

for run_ind in range(NUM_RUNS):
    print('>>> RUN {}'.format(run_ind))

    with tf.Session(graph=FFBP_GRAPH) as sess:

        # restore or initialize FFBP_GRAPH variables:
        start_epoch = saver.init_model(session=sess)
        last_epoch = start_epoch + NUM_EPOCHS - 1

        # create coordinator and start queue runners
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

        # The finally-clause guarantees the queue-runner threads are stopped and
        # joined before the session closes, even if training raises (or is
        # interrupted) or the loop body never runs (NUM_EPOCHS == 0).
        try:
            for i in FFBP.prog_bar(
                    sequence=range(start_epoch, start_epoch + NUM_EPOCHS),
                    name='Run {}/{}, Epoch'.format(run_ind + 1, NUM_RUNS)):

                # Test model occasionally
                if i in TEST_EPOCHS:
                    testloss, snap = MODEL.test_epoch(session=sess, verbose=True)
                    saver.save_test(snap, run_ind)

                # Run one training epoch
                loss = MODEL.train_epoch(session=sess, verbose=False)
                saver.save_loss(loss, run_ind)

                # Save model occasionally
                if i in SAVE_EPOCHS:
                    checkpoint_path = saver.save_model(session=sess, model=MODEL, run_ind=run_ind)

                # Do final test and break out from the training loop
                reached_ecrit = loss < ECRIT
                if reached_ecrit or i == last_epoch:
                    print('Final test ({})'.format(
                        'loss < ecrit' if reached_ecrit else 'num_epochs reached'))

                    testloss, snap = MODEL.test_epoch(session=sess, verbose=True)
                    saver.save_test(snap, run_ind)

                    # To also checkpoint the final model, uncomment:
                    # saver.save_model(session=sess, model=MODEL, run_ind=run_ind)
                    break
        finally:
            coordinator.request_stop()
            coordinator.join(threads)

FFBP.vis_utils.view_progress(logdir=saver.logdir)