In [1]:
# View more Python learning tutorials on my YouTube and Youku channels!!!

# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
# Youku video tutorial: http://i.youku.com/pythontutorial

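"""
Please note: this code is for Python 3 only. If you are using Python 2, modify the code accordingly.
Run this script with TensorFlow 1.x: it uses tf.contrib.rnn, tf.contrib.legacy_seq2seq and tf.summary,
which are missing from older releases (such as r0.10) and were removed in TensorFlow 2.
"""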
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


# Running start index on the synthetic time axis; get_batch() advances it by
# TIME_STEPS on every call so each call returns the next window of the curves.
BATCH_START: int = 0
TIME_STEPS: int = 20    # unrolled sequence length per batch
BATCH_SIZE: int = 50    # number of parallel sequences per batch
INPUT_SIZE: int = 1     # features per time step (the sin value)
OUTPUT_SIZE: int = 1    # targets per time step (the cos value)
CELL_SIZE: int = 10     # number of LSTM hidden units
LR: float = 0.006       # Adam learning rate used when building the model


def get_batch():
    """Return the next batch of the sin -> cos sequence-regression task.

    Builds a (BATCH_SIZE, TIME_STEPS) grid of consecutive integers starting at
    the global ``BATCH_START``. Row ``i`` covers
    ``[BATCH_START + i*TIME_STEPS, BATCH_START + (i+1)*TIME_STEPS)``. The grid
    is scaled by ``1 / (10*pi)``, and ``sin`` (inputs) and ``cos`` (targets)
    are evaluated on it.

    ``BATCH_START`` is then advanced by ``TIME_STEPS``, so row ``i`` of the
    next batch starts exactly where row ``i`` of this batch ended. That
    continuity is what makes it valid to feed the final LSTM state of one
    batch back in as the initial state of the next.

    Returns:
        list: ``[seq, res, xs]``.
            ``seq`` (sin) and ``res`` (cos) have shape
            (BATCH_SIZE, TIME_STEPS, 1).
            ``xs`` holds the scaled time points, used for plotting, and has
            shape (BATCH_SIZE, TIME_STEPS).
    """
    # Only BATCH_START is reassigned here; TIME_STEPS/BATCH_SIZE are just read.
    global BATCH_START
    start = BATCH_START
    xs = np.arange(start, start + TIME_STEPS * BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10 * np.pi)
    seq = np.sin(xs)
    res = np.cos(xs)
    BATCH_START += TIME_STEPS
    # Add a trailing feature axis so seq/res match the (batch, step, 1) placeholders.
    return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs]


class LSTMRNN(object):
    """Single-layer LSTM that maps an input sequence to an output sequence.

    The TF1 graph built in ``__init__`` is, in order:
        xs (batch, n_steps, input_size)
          -> dense layer   [variable scope 'in_hidden']  -> (batch, n_steps, cell_size)
          -> BasicLSTMCell [variable scope 'LSTM_cell'] run with tf.nn.dynamic_rnn
          -> dense layer   [variable scope 'out_hidden'] -> pred (batch * n_steps, output_size)
        cost     = sum of per-element losses / batch_size
        train_op = Adam(learning_rate).minimize(cost)

    NOTE: the LSTM initial state is built from a fixed ``batch_size``
    (``zero_state(self.batch_size, ...)``). Every batch fed to the graph must
    therefore have exactly ``batch_size`` rows, even though the placeholders
    declare the batch dimension as ``None``.
    """

    def __init__(self, n_steps, input_size, output_size, cell_size, batch_size,
                 learning_rate=None):
        """Build placeholders, layers, cost and training op in the default graph.

        Args:
            n_steps: number of time steps per sequence (unroll length).
            input_size: features per time step in ``xs``.
            output_size: targets per time step in ``ys``.
            cell_size: number of LSTM units; also the width of the input layer.
            batch_size: sequences per batch; fixes the initial-state shape.
            learning_rate: Adam learning rate. ``None`` (the default) uses the
                module-level ``LR``, which matches the previous behaviour.
        """
        self.n_steps = n_steps
        self.input_size = input_size
        self.output_size = output_size
        self.cell_size = cell_size
        self.batch_size = batch_size
        self.learning_rate = LR if learning_rate is None else learning_rate
        with tf.name_scope('inputs'):
            self.xs = tf.placeholder(tf.float32, [None, n_steps, input_size], name='xs')
            self.ys = tf.placeholder(tf.float32, [None, n_steps, output_size], name='ys')
        with tf.variable_scope('in_hidden'):
            self.add_input_layer()
        with tf.variable_scope('LSTM_cell'):
            self.add_cell()
        with tf.variable_scope('out_hidden'):
            self.add_output_layer()
        with tf.name_scope('cost'):
            self.compute_cost()
        with tf.name_scope('train'):
            self.train_op = tf.train.AdamOptimizer(self.learning_rate).minimize(self.cost)

    def add_input_layer(self):
        """Project every time step from input_size to cell_size; sets ``self.l_in_y``."""
        # Flatten time into the batch axis so one matmul handles all steps:
        # (batch * n_steps, input_size).
        l_in_x = tf.reshape(self.xs, [-1, self.input_size], name='2_2D')
        Ws_in = self._weight_variable([self.input_size, self.cell_size])
        bs_in = self._bias_variable([self.cell_size, ])
        with tf.name_scope('Wx_plus_b'):
            # (batch * n_steps, cell_size)
            l_in_y = tf.matmul(l_in_x, Ws_in) + bs_in
        # Restore the time axis for dynamic_rnn: (batch, n_steps, cell_size).
        self.l_in_y = tf.reshape(l_in_y, [-1, self.n_steps, self.cell_size], name='2_3D')

    def add_cell(self):
        """Run a BasicLSTMCell over ``self.l_in_y``.

        Sets ``self.cell_init_state``, ``self.cell_outputs`` and
        ``self.cell_final_state``.
        """
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.cell_size, forget_bias=1.0, state_is_tuple=True)
        with tf.name_scope('initial_state'):
            # A zero state by default. The training loop may feed a previous
            # final state into this tensor to carry memory across batches.
            self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
        self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
            lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False)

    def add_output_layer(self):
        """Project LSTM outputs to output_size.

        Sets ``self.pred`` with shape (batch * n_steps, output_size).
        """
        # (batch * n_steps, cell_size)
        l_out_x = tf.reshape(self.cell_outputs, [-1, self.cell_size], name='2_2D')
        Ws_out = self._weight_variable([self.cell_size, self.output_size])
        bs_out = self._bias_variable([self.output_size, ])
        with tf.name_scope('Wx_plus_b'):
            self.pred = tf.matmul(l_out_x, Ws_out) + bs_out

    def compute_cost(self):
        """Build ``self.cost`` from a squared-error loss and log it as a summary."""
        # Predictions and targets are flattened to one long vector, each
        # element weighted 1, with ms_error as the per-element loss.
        # NOTE(review): with unit weights, averaging across timesteps should
        # leave each entry (approximately) its squared error. The cost below
        # is then the summed error divided by batch_size, i.e. roughly
        # n_steps x MSE rather than a per-element mean — confirm before
        # comparing against other MSE numbers.
        losses = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
            [tf.reshape(self.pred, [-1], name='reshape_pred')],
            [tf.reshape(self.ys, [-1], name='reshape_target')],
            [tf.ones([self.batch_size * self.n_steps], dtype=tf.float32)],
            average_across_timesteps=True,
            softmax_loss_function=self.ms_error,
            name='losses'
        )
        with tf.name_scope('average_cost'):
            self.cost = tf.div(
                tf.reduce_sum(losses, name='losses_sum'),
                self.batch_size,
                name='average_cost')
            tf.summary.scalar('cost', self.cost)

    @staticmethod
    def ms_error(labels, logits):
        """Element-wise squared error; the keyword names match the
        ``softmax_loss_function(labels=..., logits=...)`` calling convention."""
        return tf.square(tf.subtract(labels, logits))

    def _weight_variable(self, shape, name='weights'):
        """Create (or fetch) a weight variable drawn from N(0, 1) in the current variable scope."""
        initializer = tf.random_normal_initializer(mean=0., stddev=1.,)
        return tf.get_variable(shape=shape, initializer=initializer, name=name)

    def _bias_variable(self, shape, name='biases'):
        """Create (or fetch) a bias variable initialised to 0.1 in the current variable scope."""
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)


if __name__ == '__main__':
    # Train the LSTM to predict cos(x) from sin(x), plotting progress live and
    # writing TensorBoard summaries to ./logs.
    model = LSTMRNN(TIME_STEPS, INPUT_SIZE, OUTPUT_SIZE, CELL_SIZE, BATCH_SIZE)
    # The session is intentionally left open after training so that later
    # notebook cells can keep using `sess` / `model`.
    sess = tf.Session()
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter("logs", sess.graph)
    # tf.initialize_all_variables() was superseded by
    # tf.global_variables_initializer() in TF 0.12. Pick whichever this TF
    # provides instead of parsing the version string.
    if hasattr(tf, 'global_variables_initializer'):
        init = tf.global_variables_initializer()
    else:
        init = tf.initialize_all_variables()
    sess.run(init)
    # To inspect in TensorBoard, run this from the script's directory and
    # open http://0.0.0.0:6006/ in a browser:
    # $ tensorboard --logdir='logs'

    plt.ion()
    plt.show()
    state = None
    try:
        for i in range(200):
            seq, res, xs = get_batch()
            feed_dict = {model.xs: seq, model.ys: res}
            if state is not None:
                # Continue from the previous batch's final LSTM state. get_batch()
                # makes each row resume where the same row left off. On the first
                # step nothing is fed, so the graph's zero initial state is used.
                feed_dict[model.cell_init_state] = state

            _, cost, state, pred = sess.run(
                [model.train_op, model.cost, model.cell_final_state, model.pred],
                feed_dict=feed_dict)

            # Plot the first sequence: target (red) vs prediction (blue dashed).
            # pred is (batch * steps, 1), so its first TIME_STEPS rows belong to row 0.
            plt.plot(xs[0, :], res[0].flatten(), 'r', xs[0, :], pred.flatten()[:TIME_STEPS], 'b--')
            plt.ylim((-1.2, 1.2))
            plt.draw()
            plt.pause(0.3)

            if i % 20 == 0:
                print('cost: ', round(cost, 4))
                result = sess.run(merged, feed_dict)
                writer.add_summary(result, i)
    finally:
        # Flush pending events to disk and release the event file even if
        # training is interrupted; otherwise the last summaries can be lost.
        writer.close()

<matplotlib.figure.Figure at 0x1a21d9dc18>

cost:  30.1298


<matplotlib.figure.Figure at 0x1a21dc5978>

<matplotlib.figure.Figure at 0x1a21e14ef0>

<matplotlib.figure.Figure at 0x1a25291eb8>

<matplotlib.figure.Figure at 0x1a2540add8>

<matplotlib.figure.Figure at 0x1a21e3e9b0>

<matplotlib.figure.Figure at 0x1a21e48940>

<matplotlib.figure.Figure at 0x1a256fc160>

<matplotlib.figure.Figure at 0x1a259a8668>

<matplotlib.figure.Figure at 0x1a2556aeb8>

<matplotlib.figure.Figure at 0x1a259df080>

<matplotlib.figure.Figure at 0x1a25c51400>

<matplotlib.figure.Figure at 0x1a21e11e48>

<matplotlib.figure.Figure at 0x1a21e0c438>

<matplotlib.figure.Figure at 0x1a25868320>

<matplotlib.figure.Figure at 0x1a252b4e10>

<matplotlib.figure.Figure at 0x1a255711d0>

<matplotlib.figure.Figure at 0x1a25419cc0>

<matplotlib.figure.Figure at 0x1a25116e80>

<matplotlib.figure.Figure at 0x1a256ce080>

<matplotlib.figure.Figure at 0x1a25de4d30>

cost:  7.7164


<matplotlib.figure.Figure at 0x1a25f42ef0>

<matplotlib.figure.Figure at 0x1a25f420b8>

<matplotlib.figure.Figure at 0x1a26218da0>

<matplotlib.figure.Figure at 0x1a26379160>

<matplotlib.figure.Figure at 0x1a264e5c50>

<matplotlib.figure.Figure at 0x1a2664de10>

<matplotlib.figure.Figure at 0x1a267ab1d0>

<matplotlib.figure.Figure at 0x1a2691dcc0>

<matplotlib.figure.Figure at 0x1a26a85e80>

<matplotlib.figure.Figure at 0x1a26be9080>

<matplotlib.figure.Figure at 0x1a26e81d30>

<matplotlib.figure.Figure at 0x1a26bfcac8>

<matplotlib.figure.Figure at 0x1a269313c8>

<matplotlib.figure.Figure at 0x1a269464a8>

<matplotlib.figure.Figure at 0x1a26c165f8>

<matplotlib.figure.Figure at 0x1a26bffb00>

<matplotlib.figure.Figure at 0x1a26946898>

<matplotlib.figure.Figure at 0x1a2692d470>

<matplotlib.figure.Figure at 0x1a21e20518>

<matplotlib.figure.Figure at 0x1a21e781d0>

cost:  2.8231


<matplotlib.figure.Figure at 0x1a26a93ac8>

<matplotlib.figure.Figure at 0x1a252b0c50>

<matplotlib.figure.Figure at 0x1a264e44a8>

<matplotlib.figure.Figure at 0x1a267d7748>

<matplotlib.figure.Figure at 0x1a26635240>

<matplotlib.figure.Figure at 0x1a26213550>

<matplotlib.figure.Figure at 0x1a264dce80>

<matplotlib.figure.Figure at 0x1a2636f390>

<matplotlib.figure.Figure at 0x1a256f04a8>

<matplotlib.figure.Figure at 0x1a261f6ef0>

<matplotlib.figure.Figure at 0x1a263757f0>

<matplotlib.figure.Figure at 0x1a26363fd0>

<matplotlib.figure.Figure at 0x1a25dca240>

<matplotlib.figure.Figure at 0x1a25dc73c8>

<matplotlib.figure.Figure at 0x1a25dccf28>

<matplotlib.figure.Figure at 0x1a21df9c50>

<matplotlib.figure.Figure at 0x1a25c3ee10>

<matplotlib.figure.Figure at 0x1a25dc2320>

<matplotlib.figure.Figure at 0x1a256f0b38>

<matplotlib.figure.Figure at 0x1a259df470>

cost:  0.8094


<matplotlib.figure.Figure at 0x1a25db89b0>

<matplotlib.figure.Figure at 0x1a25daaf60>

<matplotlib.figure.Figure at 0x1a25dbdc18>

<matplotlib.figure.Figure at 0x1a25c71f60>

<matplotlib.figure.Figure at 0x1a2542e2b0>

<matplotlib.figure.Figure at 0x1a25581d30>

<matplotlib.figure.Figure at 0x1a2557f860>

<matplotlib.figure.Figure at 0x1a25277d68>

<matplotlib.figure.Figure at 0x1a25291fd0>

<matplotlib.figure.Figure at 0x1a253fc898>

<matplotlib.figure.Figure at 0x1a25f29470>

<matplotlib.figure.Figure at 0x1a26210f28>

<matplotlib.figure.Figure at 0x1a264d7eb8>

<matplotlib.figure.Figure at 0x1a264e6588>

<matplotlib.figure.Figure at 0x1a267d2f60>

<matplotlib.figure.Figure at 0x1a269172b0>

<matplotlib.figure.Figure at 0x1a26a8e6a0>

<matplotlib.figure.Figure at 0x1a26c151d0>

<matplotlib.figure.Figure at 0x1a21e155f8>

<matplotlib.figure.Figure at 0x1a2693d240>

cost:  1.4052


<matplotlib.figure.Figure at 0x1a26917898>

<matplotlib.figure.Figure at 0x1a26a8e320>

<matplotlib.figure.Figure at 0x1a25f4bcc0>

<matplotlib.figure.Figure at 0x1a267a7278>

<matplotlib.figure.Figure at 0x1a26c17358>

<matplotlib.figure.Figure at 0x1a252a1d30>

<matplotlib.figure.Figure at 0x1a252a70b8>

<matplotlib.figure.Figure at 0x1a26a736a0>

<matplotlib.figure.Figure at 0x1a25599da0>

<matplotlib.figure.Figure at 0x1a260ae160>

<matplotlib.figure.Figure at 0x1a259a6c50>

<matplotlib.figure.Figure at 0x1a26658128>

<matplotlib.figure.Figure at 0x1a253f6f28>

<matplotlib.figure.Figure at 0x1a2635ecc0>

<matplotlib.figure.Figure at 0x1a25851278>

<matplotlib.figure.Figure at 0x1a25c44f60>

<matplotlib.figure.Figure at 0x1a25db1d30>

<matplotlib.figure.Figure at 0x1a25c42ef0>

<matplotlib.figure.Figure at 0x1a25851390>

<matplotlib.figure.Figure at 0x1a25862828>

cost:  0.862


<matplotlib.figure.Figure at 0x1a25c664a8>

<matplotlib.figure.Figure at 0x1a25c665f8>

<matplotlib.figure.Figure at 0x1a25c429b0>

<matplotlib.figure.Figure at 0x1a25c78ac8>

<matplotlib.figure.Figure at 0x1a25c66cc0>

<matplotlib.figure.Figure at 0x1a21e72f98>

<matplotlib.figure.Figure at 0x1a21e3ec50>

<matplotlib.figure.Figure at 0x1a262266d8>

<matplotlib.figure.Figure at 0x1a266512b0>

<matplotlib.figure.Figure at 0x1a25123cf8>

<matplotlib.figure.Figure at 0x1a25125828>

<matplotlib.figure.Figure at 0x1a26088160>

<matplotlib.figure.Figure at 0x1a259d1d68>

<matplotlib.figure.Figure at 0x1a259c5eb8>

<matplotlib.figure.Figure at 0x1a259ddb70>

<matplotlib.figure.Figure at 0x1a2608a208>

<matplotlib.figure.Figure at 0x1a267c0d30>

<matplotlib.figure.Figure at 0x1a26082fd0>

<matplotlib.figure.Figure at 0x1a25f51240>

<matplotlib.figure.Figure at 0x1a25f442e8>

cost:  0.4282


<matplotlib.figure.Figure at 0x1a26be3cf8>

<matplotlib.figure.Figure at 0x1a26aa64a8>

<matplotlib.figure.Figure at 0x1a25f26e10>

<matplotlib.figure.Figure at 0x1a25f34438>

<matplotlib.figure.Figure at 0x1a2691b160>

<matplotlib.figure.Figure at 0x1a26aa66d8>

<matplotlib.figure.Figure at 0x1a21e127f0>

<matplotlib.figure.Figure at 0x1a21e77f60>

<matplotlib.figure.Figure at 0x1a25f19c18>

<matplotlib.figure.Figure at 0x1a26e806d8>

<matplotlib.figure.Figure at 0x1a2557f2b0>

<matplotlib.figure.Figure at 0x1a2510dcf8>

<matplotlib.figure.Figure at 0x1a256f8240>

<matplotlib.figure.Figure at 0x1a2514ab70>

<matplotlib.figure.Figure at 0x1a2557cfd0>

<matplotlib.figure.Figure at 0x1a25de3160>

<matplotlib.figure.Figure at 0x1a253f8b70>

<matplotlib.figure.Figure at 0x1a256e85c0>

<matplotlib.figure.Figure at 0x1a259d8eb8>

<matplotlib.figure.Figure at 0x1a25c74780>

cost:  0.1837


<matplotlib.figure.Figure at 0x1a2637ae48>

<matplotlib.figure.Figure at 0x1a25c48978>

<matplotlib.figure.Figure at 0x1a25c76cf8>

<matplotlib.figure.Figure at 0x1a26209438>

<matplotlib.figure.Figure at 0x1a25c6c0f0>

<matplotlib.figure.Figure at 0x1a25c74f28>

<matplotlib.figure.Figure at 0x1a2584fb70>

<matplotlib.figure.Figure at 0x1a26650710>

<matplotlib.figure.Figure at 0x1a25c577f0>

<matplotlib.figure.Figure at 0x1a21e46f60>

<matplotlib.figure.Figure at 0x1a25de2780>

<matplotlib.figure.Figure at 0x1a25837dd8>

<matplotlib.figure.Figure at 0x1a26656b70>

<matplotlib.figure.Figure at 0x1a259a9cf8>

<matplotlib.figure.Figure at 0x1a259b8828>

<matplotlib.figure.Figure at 0x1a25575048>

<matplotlib.figure.Figure at 0x1a256e7c50>

<matplotlib.figure.Figure at 0x1a2679f128>

<matplotlib.figure.Figure at 0x1a25574cf8>

<matplotlib.figure.Figure at 0x1a26c1acc0>

cost:  0.43


<matplotlib.figure.Figure at 0x1a264e4278>

<matplotlib.figure.Figure at 0x1a254150b8>

<matplotlib.figure.Figure at 0x1a26eac048>

<matplotlib.figure.Figure at 0x1a26a847b8>

<matplotlib.figure.Figure at 0x1a25f195f8>

<matplotlib.figure.Figure at 0x1a26eb45f8>

<matplotlib.figure.Figure at 0x1a264c8e48>

<matplotlib.figure.Figure at 0x1a264e0c50>

<matplotlib.figure.Figure at 0x1a26e9a128>

<matplotlib.figure.Figure at 0x1a26eb4828>

<matplotlib.figure.Figure at 0x1a252a2cc0>

<matplotlib.figure.Figure at 0x1a267c6048>

<matplotlib.figure.Figure at 0x1a264d17f0>

<matplotlib.figure.Figure at 0x1a26937048>

<matplotlib.figure.Figure at 0x1a25f403c8>

<matplotlib.figure.Figure at 0x1a26c01198>

<matplotlib.figure.Figure at 0x1a267c51d0>

<matplotlib.figure.Figure at 0x1a2559e048>

<matplotlib.figure.Figure at 0x1a260b2c50>

<matplotlib.figure.Figure at 0x1a26669128>

cost:  0.1454


<matplotlib.figure.Figure at 0x1a256cd630>

<matplotlib.figure.Figure at 0x1a2559ccf8>

<matplotlib.figure.Figure at 0x1a25c42278>

<matplotlib.figure.Figure at 0x1a25dd1c88>

<matplotlib.figure.Figure at 0x1a2583aef0>

<matplotlib.figure.Figure at 0x1a25dd00b8>

<matplotlib.figure.Figure at 0x1a2583a5c0>

<matplotlib.figure.Figure at 0x1a26208160>

<matplotlib.figure.Figure at 0x1a26669470>

<matplotlib.figure.Figure at 0x1a21e7ac50>

<matplotlib.figure.Figure at 0x1a26357128>

<matplotlib.figure.Figure at 0x1a2635f9e8>

<matplotlib.figure.Figure at 0x1a2584f400>

<matplotlib.figure.Figure at 0x1a26663860>

<matplotlib.figure.Figure at 0x1a2512cf60>

<matplotlib.figure.Figure at 0x1a25117048>

<matplotlib.figure.Figure at 0x1a259dc1d0>

<matplotlib.figure.Figure at 0x1a2663c668>

<matplotlib.figure.Figure at 0x1a21e7d828>