# XOR problem

## 1. Imports

In [None]:
%%bash
# Run this cell to delete the logdirs folder.
# NOTE: the %%bash cell magic has to be the very first line of the cell,
# otherwise IPython does not recognise it as a cell magic.
rm -rf logdirs

In [None]:
import tensorflow as tf
import FFBP

## 2. Construction

In [None]:
# ---- Training configuration -------------------------------------------------
NUM_EPOCHS = 350  # epochs requested from both input pipelines (and run by the training loop)
BATCH_SIZE = 4    # batch size of the training pipeline
INP_SIZE = 2      # width of the input placeholder
TARG_SIZE = 1     # width of the target placeholder
DATA_LEN = 4      # passed as data_len; presumably the number of patterns in xor_data.txt

lr = 0.5  # learning rate handed to the momentum optimizer
m = 0.9   # momentum handed to the momentum optimizer

# Every op below lives in this dedicated graph rather than the default one.
FFBP_GRAPH = tf.Graph()

with FFBP_GRAPH.as_default():

    # Training pipeline: shuffled with a fixed shuffle seed, BATCH_SIZE items per batch.
    with tf.name_scope('train_data'):
        train_examples = FFBP.InputData(
            path_to_data_file='xor_data.txt',
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            inp_size=INP_SIZE,
            targ_size=TARG_SIZE,
            data_len=DATA_LEN,
            shuffle=True,
            shuffle_seed=1,
        )

    # Test pipeline: same data file, unshuffled, one batch of DATA_LEN items.
    with tf.name_scope('test_data'):
        test_examples = FFBP.InputData(
            path_to_data_file='xor_data.txt',
            num_epochs=NUM_EPOCHS,
            batch_size=DATA_LEN,
            inp_size=INP_SIZE,
            targ_size=TARG_SIZE,
            data_len=DATA_LEN,
            shuffle=False,
        )

    # ---- Network: input (2) -> sigmoid hidden (2) -> sigmoid output (1) ----
    model_name = 'xor_model'
    with tf.name_scope(model_name):

        # Inputs are fed at run time; the batch dimension is left open.
        model_inp = tf.placeholder(
            dtype=tf.float32,
            shape=[None, INP_SIZE],
            name='model_inp',
        )

        # Two sigmoid units with bias; seed=None leaves weight init unseeded.
        hidden_layer = FFBP.BasicLayer(
            layer_name='hidden_layer',
            layer_input=model_inp,
            size=2,
            wrange=[-1, 1],
            nonlin=tf.nn.sigmoid,
            bias=True,
            seed=None,
        )

        # A single sigmoid unit with bias, reading the hidden layer's output.
        output_layer = FFBP.BasicLayer(
            layer_name='output_layer',
            layer_input=hidden_layer.output,
            size=1,
            wrange=[-1, 1],
            nonlin=tf.nn.sigmoid,
            bias=True,
            seed=None,
        )

        # Targets are fed at run time alongside the inputs.
        target = tf.placeholder(
            dtype=tf.float32,
            shape=[None, TARG_SIZE],
            name='targets',
        )

        # Loss: squared error between target and network output, summed over
        # every element of the batch.
        sum_squared_error = tf.reduce_sum(
            tf.squared_difference(target, output_layer.output),
            name='loss_function',
        )
        momentum_optimizer = tf.train.MomentumOptimizer(lr, m)

        # Bundle layers, data pipelines, placeholders, loss and optimizer.
        MODEL = FFBP.Model(
            name=model_name,
            layers=[hidden_layer, output_layer],
            train_data=train_examples,
            inp=model_inp,
            targ=target,
            loss=sum_squared_error,
            optimizer=momentum_optimizer,
            test_data=test_examples,
        )

## 3. Running model

In [None]:
# Prevent unwanted logging messages by tensorflow (log ERROR messages only)
tf.logging.set_verbosity(tf.logging.ERROR)

# Set up run parameters
NUM_RUNS = 1                                            # number of independent training runs
TEST_EPOCHS = [0, 1, 3, 5, 30, 60, 120, 180, 270, 300]  # epochs at which the model is tested before training
SAVE_EPOCHS = [NUM_EPOCHS - 1]                          # epochs after which the model is checkpointed
ECRIT = 0.01                                            # training stops once the epoch loss falls below this
RESTORE_DIR = 'exercise_params'  # provide path to logdir with 'checkpoint_files' directory to restore a saved model

# Create ModelSaver to manage test data saving and model checkpointing:
# if restore_from is None (or not provided), graph variables will be initialized from scratch
saver = FFBP.ModelSaver(restore_from=RESTORE_DIR, make_new_logdir=True)

for run_ind in range(NUM_RUNS):
    print('>>> RUN {}'.format(run_ind))

    with tf.Session(graph=FFBP_GRAPH) as sess:

        # restore or initialize FFBP_GRAPH variables:
        start_epoch = saver.init_model(session=sess)
        last_epoch = start_epoch + (NUM_EPOCHS - 1)

        # create coordinator and start queue runners
        coordinator = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coordinator)

        # The queue-runner threads must be stopped no matter how the loop
        # ends; otherwise an exception raised while testing, training or
        # saving would leave them running.
        try:
            for i in range(start_epoch, start_epoch + NUM_EPOCHS):
                # Test model occasionally
                if i in TEST_EPOCHS:
                    loss, snap = MODEL.test_epoch(session=sess, verbose=True)
                    saver.snap2pickle(snap, run_ind)

                # Run one training epoch
                loss = MODEL.train_epoch(session=sess, verbose=False)

                # Save model occasionally
                if i in SAVE_EPOCHS:
                    saver.save_model(session=sess, model=MODEL)

                # Do final test, save, and break out from training loop
                if loss < ECRIT or i == last_epoch:
                    print('Final test ({})'.format(
                        'loss < ecrit' if loss < ECRIT else 'num_epochs reached'))

                    loss, snap = MODEL.test_epoch(session=sess, verbose=True)
                    saver.snap2pickle(snap, run_ind)

                    saver.save_model(session=sess, model=MODEL)
                    break
        finally:
            # Stop the input queues and wait for their threads to finish
            coordinator.request_stop()
            coordinator.join(threads)