In [1]:
import rnn_cells
from rnn import dynamic_rnn
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
num_epochs = 100
total_series_length = 50000
max_time_steps = 10
state_size = 10
num_classes = 2
echo_step = 3
batch_size = 5
num_batches = total_series_length//max_time_steps//batch_size

In [3]:
def generate_data():
    x = np.array(np.random.choice(2, total_series_length, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1, 1))
    y = y.reshape((batch_size, -1, 1))

    return (x, y)

In [4]:
batchX = tf.placeholder(tf.float32, [batch_size, max_time_steps, 1], name="X")
batchY = tf.placeholder(tf.int32, [batch_size, max_time_steps, 1], name="Y")

inputs_series = tf.unstack(batchX, axis=1)
labels_series = tf.unstack(tf.reshape(batchY,[batch_size, max_time_steps]), axis=1)

In [5]:
W_class = tf.Variable(np.random.rand(state_size, num_classes),dtype=tf.float32)
b_class = tf.Variable(np.zeros((1,num_classes)), dtype=tf.float32)

In [6]:
lstm_cell = rnn_cells.LSTMCell(state_size, input_size=1)       

In [14]:
init_cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
initial_state = (init_cell_state, init_hidden_state)
output_states, final_state = dynamic_rnn(lstm_cell, batchX, initial_state=initial_state, dtype=tf.float32)

In [8]:
logits_series = [tf.matmul(state, W_class) + b_class for state in output_states]
predictions_series = [tf.nn.softmax(logits) for logits in logits_series]

losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels) 
          for logits, labels in zip(logits_series, labels_series)]
total_loss = tf.reduce_mean(losses)

In [9]:
print labels_series[0]
print logits_series[0]

Tensor("unstack_1:0", shape=(5,), dtype=int32)
Tensor("add:0", shape=(5, 2), dtype=float32)


In [10]:
train_step = tf.train.AdagradOptimizer(0.3).minimize(total_loss)

In [18]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loss_list = []
    
    _cell_state = np.zeros([batch_size, state_size], dtype=np.float32)
    _hidden_state = np.zeros([batch_size, state_size], dtype=np.float32)
    _state = (_cell_state, _hidden_state)
    
    for epoch_idx in range(num_epochs):
        x, y = generate_data()
        
        print("New data, epoch", epoch_idx)
        
        for batch_idx in range(num_batches):
            start_idx = batch_idx * max_time_steps
            end_idx = start_idx + max_time_steps
            
            _batchX = x[:,start_idx:end_idx,:]
            _batchY = y[:,start_idx:end_idx]
            
            _total_loss, _train_step, _state, _predictions_series = sess.run(
                [total_loss, train_step, initial_state, predictions_series],
                feed_dict={
                    batchX: _batchX,
                    batchY: _batchY,
                    initial_state: _state
                })

            loss_list.append(_total_loss)

            if batch_idx%100 == 0:
                print("Step",batch_idx, "Loss", _total_loss)

('New data, epoch', 0)
('Step', 0, 'Loss', 0.74724555)
('Step', 100, 'Loss', 0.6117664)
('Step', 200, 'Loss', 0.3784672)
('Step', 300, 'Loss', 0.24713343)
('Step', 400, 'Loss', 0.23487505)
('Step', 500, 'Loss', 0.21619901)
('Step', 600, 'Loss', 0.24305928)
('Step', 700, 'Loss', 0.19282819)
('Step', 800, 'Loss', 0.24289352)
('Step', 900, 'Loss', 0.21515656)
('New data, epoch', 1)
('Step', 0, 'Loss', 0.22615322)
('Step', 100, 'Loss', 0.18510377)
('Step', 200, 'Loss', 0.22237438)
('Step', 300, 'Loss', 0.21045233)
('Step', 400, 'Loss', 0.19415361)
('Step', 500, 'Loss', 0.205651)
('Step', 600, 'Loss', 0.22585781)
('Step', 700, 'Loss', 0.23508471)
('Step', 800, 'Loss', 0.27521381)
('Step', 900, 'Loss', 0.1776036)
('New data, epoch', 2)
('Step', 0, 'Loss', 0.20564742)
('Step', 100, 'Loss', 0.20453459)
('Step', 200, 'Loss', 0.22780757)
('Step', 300, 'Loss', 0.20777182)
('Step', 400, 'Loss', 0.208685)
('Step', 500, 'Loss', 0.19484146)
('Step', 600, 'Loss', 0.22363579)
('Step', 700, 'Loss', 0.21

('Step', 800, 'Loss', 0.21132553)
('Step', 900, 'Loss', 0.20658045)
('New data, epoch', 23)
('Step', 0, 'Loss', 0.20973623)
('Step', 100, 'Loss', 0.21102639)
('Step', 200, 'Loss', 0.209539)
('Step', 300, 'Loss', 0.208103)
('Step', 400, 'Loss', 0.20687893)
('Step', 500, 'Loss', 0.20269209)
('Step', 600, 'Loss', 0.20401873)
('Step', 700, 'Loss', 0.20549235)
('Step', 800, 'Loss', 0.19663039)
('Step', 900, 'Loss', 0.21122558)
('New data, epoch', 24)
('Step', 0, 'Loss', 0.21345328)
('Step', 100, 'Loss', 0.20687634)
('Step', 200, 'Loss', 0.21581723)
('Step', 300, 'Loss', 0.21012922)
('Step', 400, 'Loss', 0.21118237)
('Step', 500, 'Loss', 0.20949073)
('Step', 600, 'Loss', 0.20573786)
('Step', 700, 'Loss', 0.20747207)
('Step', 800, 'Loss', 0.216536)
('Step', 900, 'Loss', 0.20243372)
('New data, epoch', 25)
('Step', 0, 'Loss', 0.21710312)
('Step', 100, 'Loss', 0.20868771)
('Step', 200, 'Loss', 0.20870104)
('Step', 300, 'Loss', 0.21373236)
('Step', 400, 'Loss', 0.21121091)
('Step', 500, 'Loss', 

('Step', 500, 'Loss', 0.2094093)
('Step', 600, 'Loss', 0.20708349)
('Step', 700, 'Loss', 0.20448755)
('Step', 800, 'Loss', 0.20354719)
('Step', 900, 'Loss', 0.21069324)
('New data, epoch', 46)
('Step', 0, 'Loss', 0.2123227)
('Step', 100, 'Loss', 0.21109529)
('Step', 200, 'Loss', 0.20627977)
('Step', 300, 'Loss', 0.20706205)
('Step', 400, 'Loss', 0.20906042)
('Step', 500, 'Loss', 0.20821744)
('Step', 600, 'Loss', 0.21604539)
('Step', 700, 'Loss', 0.21235579)
('Step', 800, 'Loss', 0.20458385)
('Step', 900, 'Loss', 0.2087865)
('New data, epoch', 47)
('Step', 0, 'Loss', 0.20553219)
('Step', 100, 'Loss', 0.20389067)
('Step', 200, 'Loss', 0.21220441)
('Step', 300, 'Loss', 0.20960306)
('Step', 400, 'Loss', 0.20760426)
('Step', 500, 'Loss', 0.20605822)
('Step', 600, 'Loss', 0.21095629)
('Step', 700, 'Loss', 0.20589045)
('Step', 800, 'Loss', 0.2061141)
('Step', 900, 'Loss', 0.2062639)
('New data, epoch', 48)
('Step', 0, 'Loss', 0.21371324)
('Step', 100, 'Loss', 0.20889468)
('Step', 200, 'Loss',

('Step', 200, 'Loss', 0.20300381)
('Step', 300, 'Loss', 0.20580882)
('Step', 400, 'Loss', 0.20970556)
('Step', 500, 'Loss', 0.20677668)
('Step', 600, 'Loss', 0.20932619)
('Step', 700, 'Loss', 0.2091226)
('Step', 800, 'Loss', 0.21105972)
('Step', 900, 'Loss', 0.20761007)
('New data, epoch', 69)
('Step', 0, 'Loss', 0.20752212)
('Step', 100, 'Loss', 0.21060877)
('Step', 200, 'Loss', 0.20501564)
('Step', 300, 'Loss', 0.21587564)
('Step', 400, 'Loss', 0.20395203)
('Step', 500, 'Loss', 0.20876274)
('Step', 600, 'Loss', 0.21395607)
('Step', 700, 'Loss', 0.20703869)
('Step', 800, 'Loss', 0.20984948)
('Step', 900, 'Loss', 0.2113343)
('New data, epoch', 70)
('Step', 0, 'Loss', 0.2058149)
('Step', 100, 'Loss', 0.20378026)
('Step', 200, 'Loss', 0.20897061)
('Step', 300, 'Loss', 0.20889753)
('Step', 400, 'Loss', 0.21209323)
('Step', 500, 'Loss', 0.21014221)
('Step', 600, 'Loss', 0.20780222)
('Step', 700, 'Loss', 0.20544623)
('Step', 800, 'Loss', 0.21425007)
('Step', 900, 'Loss', 0.19384743)
('New d

('New data, epoch', 91)
('Step', 0, 'Loss', 0.20811327)
('Step', 100, 'Loss', 0.20569159)
('Step', 200, 'Loss', 0.21300979)
('Step', 300, 'Loss', 0.21315536)
('Step', 400, 'Loss', 0.20366915)
('Step', 500, 'Loss', 0.21210176)
('Step', 600, 'Loss', 0.20998184)
('Step', 700, 'Loss', 0.20168449)
('Step', 800, 'Loss', 0.20581466)
('Step', 900, 'Loss', 0.21177496)
('New data, epoch', 92)
('Step', 0, 'Loss', 0.21085475)
('Step', 100, 'Loss', 0.21173756)
('Step', 200, 'Loss', 0.20963798)
('Step', 300, 'Loss', 0.21247587)
('Step', 400, 'Loss', 0.21026385)
('Step', 500, 'Loss', 0.20890631)
('Step', 600, 'Loss', 0.20334201)
('Step', 700, 'Loss', 0.20743492)
('Step', 800, 'Loss', 0.2050105)
('Step', 900, 'Loss', 0.20948572)
('New data, epoch', 93)
('Step', 0, 'Loss', 0.20887695)
('Step', 100, 'Loss', 0.20587192)
('Step', 200, 'Loss', 0.2155868)
('Step', 300, 'Loss', 0.20501962)
('Step', 400, 'Loss', 0.20739086)
('Step', 500, 'Loss', 0.21924637)
('Step', 600, 'Loss', 0.20519508)
('Step', 700, 'Los