# XOR problem

## 1. Preliminaries
### 1.1. Imports
We begin by importing several Python libraries:

In [None]:
# %%bash
# rm -rf logdirs

In [None]:
import os
import pickle
import tensorflow as tf
import numpy as np
from collections import OrderedDict, namedtuple
# Print the library versions and the current working directory in one call
# (one item per line).
print('\n'.join([
    'tensorflow version: {}'.format(tf.__version__),
    'numpy version: {}'.format(np.__version__),
    'current working directory: {}'.format(os.getcwd()),
]))

# Set TensorFlow's logging threshold to ERROR.
tf.logging.set_verbosity(tf.logging.ERROR)

from FFBP.constructors import InputData, BasicLayer, FFBPModel, FFBPSaver

## 2. Construction

In [None]:
# TRAIN CONFIGS
# Number of epochs: passed as `num_epochs` to both InputData readers and used
# as the length of the training loop in the "Running model" cell.
NUM_EPOCHS = 350
# Batch size for the training reader.
BATCH_SIZE = 4
# Width of the model input placeholder (and `inp_size` of the readers).
INP_SIZE = 2
# Width of the target placeholder (and `targ_size` of the readers).
TARG_SIZE = 1
# Passed as `data_len` to both readers; also used as the test batch size.
# Presumably the number of patterns in the data file -- verify against InputData.
DATA_LEN = 4

# Passed positionally to tf.train.MomentumOptimizer(lr, m), i.e. as the
# learning rate and the momentum respectively.
lr = 0.5
m = 0.9

# Dedicated graph that holds every op and variable built in this cell.
XOR_GRAPH = tf.Graph()

with XOR_GRAPH.as_default():

    # Training reader: BATCH_SIZE patterns per batch, shuffled with a fixed seed.
    with tf.name_scope('train_data'):
        train_examples = InputData(
            path_to_data_file='train_data_B.txt',
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            inp_size=INP_SIZE,
            targ_size=TARG_SIZE,
            data_len=DATA_LEN,
            shuffle=True,
            shuffle_seed=1,
        )

    # Test reader: same data file, unshuffled, DATA_LEN patterns per batch.
    with tf.name_scope('test_data'):
        test_examples = InputData(
            path_to_data_file='train_data_B.txt',
            num_epochs=NUM_EPOCHS,
            batch_size=DATA_LEN,
            inp_size=INP_SIZE,
            targ_size=TARG_SIZE,
            data_len=DATA_LEN,
            shuffle=False,
        )

    # NETWORK CONSTRUCTION
    # Input -> 2 sigmoid hidden units -> 1 sigmoid output unit, all created
    # under a name scope that matches the model name.
    model_name = 'xor_model'
    with tf.name_scope(model_name):

        model_inp = tf.placeholder(
            dtype=tf.float32,
            shape=[None, INP_SIZE],
            name='model_inp',
        )

        hidden_layer = BasicLayer(
            layer_name='hidden_layer',
            layer_input=model_inp,
            size=2,
            wrange=[-1, 1],
            nonlin=tf.nn.sigmoid,
            bias=True,
            seed=None,
        )

        output_layer = BasicLayer(
            layer_name='output_layer',
            layer_input=hidden_layer.output,
            size=1,
            wrange=[-1, 1],
            nonlin=tf.nn.sigmoid,
            bias=True,
            seed=None,
        )

        target = tf.placeholder(
            dtype=tf.float32,
            shape=[None, TARG_SIZE],
            name='targets',
        )

        # Loss: sum of squared differences between targets and network output.
        squared_error = tf.squared_difference(target, output_layer.output)
        loss_op = tf.reduce_sum(squared_error, name='loss_function')

        # Momentum optimizer configured with the learning rate and momentum
        # from the train configs.
        momentum_optimizer = tf.train.MomentumOptimizer(lr, m)

        model = FFBPModel(
            name=model_name,
            layers=[hidden_layer, output_layer],
            train_data=train_examples,
            inp=model_inp,
            targ=target,
            loss=loss_op,
            optimizer=momentum_optimizer,
            test_data=test_examples,
        )

## 3. Running model

In [None]:
# Epoch indices at which the model is tested before that epoch's training step.
TEST_EPOCHS = [0,1,3,5,30,60,100,180,300]
# Epoch indices after which the model is saved.
SAVE_EPOCHS = [NUM_EPOCHS-1]
# Training stops early once the training loss drops below this value.
ECRIT = 0.01
# CHECKPOINT_DIR = 'logdirs/ffbp_logdir_000/'
# Directory handed to the saver as `logdir_path` when restoring the model.
EXERCISE_PARAMS_DIR = 'exercise_params'

with tf.Session(graph=XOR_GRAPH) as sess:

    # create saver within current session
    saver = FFBPSaver(session=sess)

    # restore or initialize XOR_GRAPH variables
    # for random global variables (weights) initialization use: start_epoch = saver.init_model()
    start_epoch = saver.restore_model(logdir_path=EXERCISE_PARAMS_DIR, make_new_logdir=True)

    # create coordinator and queue runners
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coordinator)

    last_epoch = start_epoch + (NUM_EPOCHS - 1)
    finished = False  # set once the final test has completed

    # The try/finally guarantees that the queue-runner threads are asked to
    # stop and are joined even if training raises or is interrupted;
    # previously they were only stopped on the normal exit path.
    try:
        for i in range(start_epoch, start_epoch + NUM_EPOCHS):
            # Test model occasionally
            if i in TEST_EPOCHS:
                loss, snap = model.test_epoch(session=sess, verbose=True)
                saver.snap2pickle(snap)

            # Run one training epoch
            loss = model.train_epoch(session=sess, verbose=False)

            # NOTE(review): when start_epoch == 0 the last epoch matches both
            # SAVE_EPOCHS and the final-save condition below, so the model is
            # saved twice -- confirm whether that is intended.
            if i in SAVE_EPOCHS:
                saver.save_model(model)

            # Do final test and break out from training loop
            if loss < ECRIT or i == last_epoch:
                print('Final test ({})'.format(
                    'loss < ecrit' if loss < ECRIT else 'num_epochs reached'))

                loss, snap = model.test_epoch(session=sess, verbose=True)
                saver.snap2pickle(snap)

                finished = True
                break
    finally:
        # Stop queues gracefully, whatever happened above.
        coordinator.request_stop()
        coordinator.join(threads)

    # Final save happens after the queues are stopped, as before, and only
    # when the final test was actually reached.
    if finished:
        saver.save_model(model)
        