In [1]:
import numpy as np
import tensorflow as tf
# Start from an empty default graph so re-running the notebook from the top
# does not stack a second copy of the placeholders/variables onto the old graph.
# Must run before any op is created, hence before the session below.
tf.reset_default_graph()
# Session used by every later cell through `sess.run(...)`.
sess = tf.InteractiveSession()

In [2]:
# Input Format: every data row is `n_fact` factor columns followed by
# `seq_length` light values (the width of the `l` placeholder).
n_fact, seq_length = 2, 5

In [3]:
# Input Preprocessing
def cast(string):
    """Convert one comma-separated field to a float; empty fields become 0.

    Surrounding whitespace is ignored, so both '' and ' ' map to 0.0.
    """
    string = string.strip()
    if string == '':
        return 0.
    return float(string)

# `with` guarantees the file handle is closed once the text is read.
with open('lights.txt', 'r') as fh:
    # Skip blank lines (e.g. a trailing newline at EOF); each would otherwise
    # become a one-field row and make the resulting array ragged.
    train_lines = [line for line in fh.read().splitlines() if line.strip()]
X = [[cast(a) for a in line.split(',')] for line in train_lines]
# First n_fact columns feed the `f` placeholder, the rest feed `l`.
factors = np.array([x[:n_fact] for x in X])
lights = np.array([x[n_fact:] for x in X])
# The `l` placeholder is declared with width seq_length; fail early with a
# readable message instead of inside sess.run.
assert lights.ndim == 2 and lights.shape[1] == seq_length, \
    'expected %d light values per row, got shape %s' % (seq_length, lights.shape)

In [4]:
# Network Variables
hidden_size = 10
# Fix the graph-level seed so the random_normal initialisers below produce
# the same starting weights on every Restart & Run All.
tf.set_random_seed(0)
# l: light sequences, one row per example; training starts the RNN from column 0.
l = tf.placeholder(tf.float32, [None, seq_length])
# f: per-example factor vector, fed to the hidden state at every step.
f = tf.placeholder(tf.float32, [None, n_fact])
# Scalar step size, supplied through feed_dict.
learning_rate = tf.placeholder(tf.float32, [])
# Weight matrices start as small Gaussians, biases start at zero.
params = {
    'Wxh': tf.Variable(tf.random_normal([1, hidden_size], stddev=0.01)),            # input -> hidden
    'Wfh': tf.Variable(tf.random_normal([n_fact, hidden_size], stddev=0.01)),       # factors -> hidden
    'Whh': tf.Variable(tf.random_normal([hidden_size, hidden_size], stddev=0.01)),  # hidden -> hidden
    'Why': tf.Variable(tf.random_normal([hidden_size, 1], stddev=0.01)),            # hidden -> output
    'bh': tf.Variable(tf.zeros([hidden_size])),
    'by': tf.Variable(tf.zeros([1])),
}

In [5]:
# RNN architecture
def RNN(x_t, f, params, n_steps=None):
    """Unroll a simple tanh RNN, feeding each prediction back in as the next input.

    Args:
        x_t: first value of each sequence; the training cell passes l[:, 0:1],
            i.e. shape (batch, 1).
        f: factor tensor of shape (batch, n_fact), added to the hidden
            pre-activation at every step.
        params: dict with keys 'Wxh', 'Wfh', 'Whh', 'Why', 'bh', 'by'.
        n_steps: length of the returned sequence, including x_t itself.
            Defaults to the global seq_length (original behaviour).

    Returns:
        Tensor of shape (batch, n_steps): x_t followed by n_steps - 1
        predicted values.
    """
    if n_steps is None:
        n_steps = seq_length
    batch_size = tf.shape(x_t)[0]
    # Hidden width comes from the recurrent weight matrix, not a global.
    n_hidden = int(params['Whh'].get_shape()[0])
    h = tf.zeros([batch_size, n_hidden])
    # The factor projection is identical at every step, so build it once.
    f_proj = tf.matmul(f, params['Wfh'])
    outputs = [x_t]
    for _ in range(1, n_steps):
        # Same summation order as before: ((x*Wxh + f*Wfh) + h*Whh) + bh.
        h = tf.tanh(tf.matmul(x_t, params['Wxh']) + f_proj
                    + tf.matmul(h, params['Whh']) + params['bh'])
        x_t = tf.matmul(h, params['Why']) + params['by']
        outputs.append(x_t)
    if len(outputs) == 1:
        return outputs[0]
    # One concat over all steps instead of re-concatenating a growing tensor.
    return tf.concat(1, outputs)

In [6]:
# Optimization Technique
y = RNN(l[:,0:1],f,params)
# Sum of squared errors over the batch. Column 0 of y is l[:,0:1] itself,
# so only the predicted steps contribute to the loss.
loss = tf.reduce_sum(tf.squared_difference(y,l))
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
grads_and_vars = optimizer.compute_gradients(loss)
# Clip every gradient element to [-1, 1] before applying it. compute_gradients
# yields None for variables the loss does not depend on, and clip_by_value
# cannot take None, so those pairs are dropped.
capped_grads_and_vars = [(tf.clip_by_value(grad, -1., 1.), var)
                         for grad, var in grads_and_vars if grad is not None]
capped_optimizer = optimizer.apply_gradients(capped_grads_and_vars)
sess.run(tf.initialize_all_variables())

In [7]:
# Training
n_train_iter = 10000  # total gradient steps
log_every = 100       # print rounded predictions and loss this often
# The feed never changes between steps, so build it once.
train_feed = {l: lights, f: factors, learning_rate: 0.001}
for n_iter in range(n_train_iter):
    _, loss_, y_ = sess.run([capped_optimizer, loss, y], feed_dict=train_feed)
    if n_iter % log_every == 0:
        y_proc = [[round(a) for a in b] for b in y_]
        # Function-call print: same output on Python 2, also valid on Python 3.
        print(y_proc)
        print('iter'+str(n_iter)+' loss:'+str(loss_))

[[5.0, 0.0, -0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0, 0.0], [5.0, 0.0, -0.0, -0.0, -0.0], [5.0, 0.0, 0.0, 0.0, 0.0]]
iter0 loss:87.9799
[[5.0, 1.0, 0.0, 0.0, 0.0], [5.0, 1.0, 0.0, 0.0, 0.0], [5.0, 1.0, 0.0, 0.0, 0.0], [5.0, 1.0, 0.0, 0.0, 0.0]]
iter100 loss:68.5406
[[5.0, 2.0, 1.0, 1.0, 1.0], [5.0, 2.0, 1.0, 1.0, 1.0], [5.0, 2.0, 1.0, 1.0, 1.0], [5.0, 2.0, 1.0, 1.0, 1.0]]
iter200 loss:49.1722
[[5.0, 3.0, 2.0, 1.0, 1.0], [5.0, 3.0, 1.0, 1.0, 1.0], [5.0, 3.0, 1.0, 1.0, 0.0], [5.0, 3.0, 1.0, 0.0, 0.0]]
iter300 loss:40.5794
[[5.0, 3.0, 2.0, 1.0, 1.0], [5.0, 3.0, 2.0, 1.0, 1.0], [5.0, 3.0, 1.0, 0.0, -0.0], [5.0, 3.0, 1.0, -0.0, -0.0]]
iter400 loss:36.539
[[5.0, 3.0, 1.0, 2.0, 1.0], [5.0, 3.0, 1.0, 2.0, 1.0], [5.0, 3.0, 0.0, 1.0, -0.0], [5.0, 2.0, -0.0, 0.0, -0.0]]
iter500 loss:28.1409
[[5.0, 4.0, 1.0, 4.0, 0.0], [5.0, 3.0, 1.0, 3.0, 0.0], [5.0, 3.0, 0.0, 1.0, -1.0], [5.0, 2.0, 0.0, 0.0, -1.0]]
iter600 loss:20.9809
[[5.0, 4.0, 1.0, 4.0, 0.0], [5.0, 3.0, 1.0, 3.0, 0.0], [5.0, 3.0, 0.0, 1.0, -1.0]

KeyboardInterrupt: 

In [117]:
# Testing
# Reuses cast() from the preprocessing cell instead of redefining (and
# silently shadowing) it here.
with open('lights_test.txt', 'r') as fh:
    # Skip blank lines so a trailing newline does not produce a ragged row.
    test_lines = [line for line in fh.read().splitlines() if line.strip()]
X = [[cast(a) for a in line.split(',')] for line in test_lines]
batch_size = len(X)
factors_test = np.array([x[:n_fact] for x in X])
lights_test = np.array([x[n_fact:] for x in X])
# Must match the `l` placeholder width used at inference time.
assert lights_test.ndim == 2 and lights_test.shape[1] == seq_length, \
    'expected %d light values per row, got shape %s' % (seq_length, lights_test.shape)

In [118]:
# Generate sequences for the test set: only column 0 of lights_test is used as
# the starting input, the remaining columns of y_ are the model's predictions.
y_ = sess.run(y,feed_dict={l:lights_test,f:factors_test})
# Function-call print, consistent with the training cell; also valid on Python 3.
print(y_)

[[  5.00000000e+00   1.04601383e-02   1.07107162e-02  -7.81846046e-03
   -7.71862268e-03]
 [  5.00000000e+00   6.00600719e+00   1.95697546e-02   3.21156979e-02
   -1.24495625e-02]
 [  5.00000000e+00   3.00422955e+00   1.01781833e+00   4.02139473e+00
    1.47323012e-02]
 [  5.00000000e+00   3.00449300e+00   1.01458693e+00   4.01371527e+00
    4.50116396e-03]]


In [104]:
# Close the TensorFlow session created in the first cell; run this last.
sess.close()