# In [2]:
import tensorflow as tf
import numpy as np

from experiments.utils import generate_synthetic_arithmetic_dataset
from layers.nalu_layer import NaluLayer

if __name__ == "__main__":
    # Train a NALU layer on a synthetic "add" dataset, then evaluate it on a
    # test set drawn from a value range 10x wider than the training range.

    EPOCHS = 10000
    PATIENCE = 15  # consecutive non-improving epochs tolerated before stopping
    MIN_DELTA = .00001  # smallest epoch-loss decrease that counts as an improvement
    LEARNING_RATE = .005
    BATCH_SIZE = 16384
    FEATURES_NUM = 100
    ACC_TOLERANCE = .1  # absolute error within which a prediction counts as correct

    # `boundaries` from the training set is reused for the test set --
    # presumably so both share the same feature layout; TODO confirm.
    X_train, y_train, boundaries = generate_synthetic_arithmetic_dataset("add", -1000, 1000, FEATURES_NUM, 100000)
    X_test, y_test, _ = generate_synthetic_arithmetic_dataset("add", -10000, 10000, FEATURES_NUM, 100000, boundaries)

    # Named `x_input` (not `input`) to avoid shadowing the builtin.
    x_input = tf.placeholder(tf.float32, shape=[None, FEATURES_NUM])
    y_true = tf.placeholder(tf.float32, shape=[None, ])
    nalu_layer = NaluLayer(FEATURES_NUM, 1, 2, 2, core_cell_type="nalu")
    # Squeeze to match y_true's 1-D shape (assumes the layer emits one value
    # per sample -- TODO confirm).
    y_pred = tf.squeeze(nalu_layer(x_input))

    loss = tf.losses.mean_squared_error(y_true, y_pred)  # NALU uses mse
    optimizer = tf.train.RMSPropOptimizer(LEARNING_RATE)
    train_op = optimizer.minimize(loss)

    n_train = len(X_train)

    # The context manager guarantees the session is closed when done.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Start at +inf so the first epoch always counts as an improvement
        # (starting at 0 made epoch 0 consume one unit of patience).
        prev_loss = float("inf")
        stale_epochs = 0
        for epoch in range(EPOCHS):
            loss_sum = 0.0
            n_correct = 0
            for start in range(0, n_train, BATCH_SIZE):
                X_batch = X_train[start:start + BATCH_SIZE]
                y_batch = y_train[start:start + BATCH_SIZE]

                _, output_batch, batch_loss = sess.run(
                    [train_op, y_pred, loss],
                    feed_dict={x_input: X_batch, y_true: y_batch})

                # Weight by batch size: the final batch is usually smaller.
                loss_sum += batch_loss * len(y_batch)
                n_correct += np.sum(np.isclose(output_batch, y_batch, atol=ACC_TOLERANCE, rtol=0))

            # Running metrics over the whole epoch, not just the last (partial)
            # batch, so early stopping tracks a less noisy signal.
            l = loss_sum / n_train
            acc = n_correct / n_train
            print('epoch {2}, loss: {0}, accuracy: {1}'.format(l, acc, epoch))

            if prev_loss - l < MIN_DELTA:
                stale_epochs += 1
                # Stop after exactly PATIENCE stale epochs, matching the message.
                if stale_epochs >= PATIENCE:
                    print("Early Stopping after {} epochs of no improvements".format(PATIENCE))
                    break
            else:
                stale_epochs = 0
            prev_loss = l

        output_test, l = sess.run([y_pred, loss],
                                  feed_dict={x_input: X_test, y_true: y_test})
        acc = np.mean(np.isclose(output_test, y_test, atol=ACC_TOLERANCE, rtol=0))
        print('test loss: {0}, accuracy: {1}'.format(l, acc))