In [1]:
# RNN LSTM 主要面对 语音 文本 这样的序列化的问题 但是同样可以用在图片分类问题

In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# from tensorflow.contrib import rnn 
old_v = tf.logging.get_verbosity()
tf.logging.set_verbosity(tf.logging.ERROR)

In [2]:
# 载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 每批次训练数据大小 和 批次大小
batch_size = 50
n_batch = mnist.train.num_examples // batch_size

# 输入图片是28×28
depth = 28                 # 词向量维度 输入层有28个神经元 刚好对应数字图片一行有28个像素
max_time = 28              # 每一个时序 28是最大时间步长 一共28行 相当于输入28次
lstm_size = 100            # 隐藏层有100个“神经元” 它不是简单的神经元 它是一个block
n_classes = 10             # 共10个分类

# none表示第一个维度可以是任意的长度
# 给输入数据预留位置
x = tf.placeholder(tf.float32, shape = [None,784])
y = tf.placeholder(tf.float32, shape = [None,10])

# 初始化权值 权值矩阵
weights = tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1)) # [隐藏层数量，分类个数]
# 初始化偏置 偏置值矩阵
biases = tf.Variable(tf.constant(0.1,shape=[n_classes])) # 10个分类就10个偏置


# ？？LSTM的输入层到隐藏层没有权值和偏置？？？
# 如果不给 tf会自动分配


# 定义RNN网络
def RNN(X,weights,biases):
    # inputs的格式需要是：[batch_size,max_time,n_inputs]
    inputs = tf.reshape(X,[-1,max_time,depth])
    # 定义LSTM基本CELL
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(lstm_size)
#     lstm_cell = rnn.BasicLSTMCell(lstm_size)

    outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
    # outputs 记录着（输入28次）每一次输出的结果 就是ht 也是yt
    # final_state 记录着（输入28次）最后一次的输出结果

    # final_state[0]是cell_state
    # final_state[1]是hidden_state
    results = tf.nn.softmax(tf.matmul(final_state[1],weights) + biases)
    return results




# x一进来就先被reshape了
# 这里的batch_size是50 那就是说一次性能进来50张图 那None就是50 x的shape就是[50,784]
# -1的位置也就是50了 变换完就成了[batch_size,max_time,n_inputs]格式[50,28,28]
# 为什么要变成这样的格式呢 看在哪儿调用了inputs
# tf.nn.dynamic_rnn()这个函数 的inputs 必须是这样的格式[批次大小，序列长度，序列中每次传入的数据]
# LSTM的CELL构建 BasicLSTMCell可改 这里是定义了100个LSTM单元

# 执行tf.nn.dynamic_rnn （隐藏cell，输入数据，定义类型）
# 得到两个返回值 outputs和final_state 
# final_state：这个列表包括三个维度[state,batch_size,cell.state_size]
# state包含两个元素 state[0]是cell_state
#                   state[1]是hidden_state
# batch_size就是之前定义的批次中有多少个样本
# cell.state_size是隐藏单元的个数 就是创建cell时tf.contrib.rnn.BasicLSTMCell(lstm_size)的lstm_size

#  outputs: The RNN output `Tensor`.
#     If time_major == False (default), this will be a `Tensor` shaped:
#       `[batch_size, max_time, cell.output_size]`.
#     If time_major == True, this will be a `Tensor` shaped:
#       `[max_time, batch_size, cell.output_size]`.

#     Note, if `cell.output_size` is a (possibly nested) tuple of integers
#     or `TensorShape` objects, then `outputs` will be a tuple having the
#     same structure as `cell.output_size`, containing Tensors having shapes
#     corresponding to the shape data in `cell.output_size`.

#   state: The final state.  If `cell.state_size` is an int, this
#     will be shaped `[batch_size, cell.state_size]`.  If it is a
#     `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
#     If it is a (possibly nested) tuple of ints or `TensorShape`, this will
#     be a tuple having the corresponding shapes. If cells are `LSTMCells`
#     `state` will be a tuple containing a `LSTMStateTuple` for each cell.


# tf.nn.dynamic_rnn(
#     cell,
# 自己定义的LSTM的细胞单元,如是convLSTM,自己写也可以。
#     inputs,
# 一个三维的变量,[batchsize,time_step,input_size],搭配time_major=False。
# 其中batch_size表示batch的大小。time_steps序列长度，input_size输入数据单个序列单个时间维度上固有的长度。
# 这里还补充一点,就是叫dynamic的原因,就是输入数据的time_step不一定要相同
# 如果长短不一,会自动跟短的补0,但是处理时候,不会处理0,在0前面就截止了.
# 这就是dynamic对比static的好处.
# time_major
# If true,   these Tensors must be shaped [max_time, batch_size, depth].
# If false, these Tensors must be shaped `[batch_size, max_time, depth] 
#     sequence_length=None,
#     initial_state=None,
#     dtype=None,
#     parallel_iterations=None,
#     swap_memory=False,
#     time_major=False,
#     scope=None
# )




# 预测值 计算RNN的返回结果
prediction = RNN(x, weights, biases)  # x:一个批次的训练数据  权值  偏置值  （不传也可以 因为这里是全局变量）
# 交叉熵代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y,logits = prediction))
# 优化器
train = tf.train.AdamOptimizer(0.0001).minimize(loss)
# 定义求准确率的方法
correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
# 初始化变量
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(101):
        for batch in range(n_batch):
            batch_data,batch_tag = mnist.train.next_batch(batch_size)
            sess.run(train,feed_dict = {x:batch_data,y:batch_tag})
        if epoch%10==0:
            train_accuracy_rate = sess.run(accuracy,feed_dict = {x:mnist.train.images,y:mnist.train.labels})# 训练集正确率
            test_accuracy_rate = sess.run(accuracy,feed_dict = {x:mnist.test.images,y:mnist.test.labels})   # 测试集正确率
            print('第' + str(epoch + 1) + '次，train准确率 ' + str(train_accuracy_rate) + ' test准确率 ' + str(test_accuracy_rate))
tf.logging.set_verbosity(old_v)


Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz


InternalError: Blas GEMM launch failed : a.shape=(50, 128), b.shape=(128, 400), m=50, n=400, k=128
	 [[node rnn/while/basic_lstm_cell/MatMul (defined at <ipython-input-2-b7fe08cdbbc1>:37) ]]

Caused by op 'rnn/while/basic_lstm_cell/MatMul', defined at:
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel\kernelapp.py", line 486, in start
    self.io_loop.start()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tornado\platform\asyncio.py", line 127, in start
    self.asyncio_loop.run_forever()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\asyncio\base_events.py", line 422, in run_forever
    self._run_once()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\asyncio\base_events.py", line 1432, in _run_once
    handle._run()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\asyncio\events.py", line 145, in _run
    self._callback(*self._args)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tornado\platform\asyncio.py", line 117, in _handle_events
    handler_func(fileobj, events)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\zmq\eventloop\zmqstream.py", line 450, in _handle_events
    self._handle_recv()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\zmq\eventloop\zmqstream.py", line 480, in _handle_recv
    self._run_callback(callback, msg)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\zmq\eventloop\zmqstream.py", line 432, in _run_callback
    callback(*args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tornado\stack_context.py", line 276, in null_wrapper
    return fn(*args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel\kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel\ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\ipykernel\zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\IPython\core\interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\IPython\core\interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\IPython\core\interactiveshell.py", line 2903, in run_ast_nodes
    if self.run_code(code, result):
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-b7fe08cdbbc1>", line 108, in <module>
    prediction = RNN(x, weights, biases)  # x:一个批次的训练数据  权值  偏置值  （不传也可以 因为这里是全局变量）
  File "<ipython-input-2-b7fe08cdbbc1>", line 37, in RNN
    outputs,final_state = tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\rnn.py", line 671, in dynamic_rnn
    dtype=dtype)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\rnn.py", line 879, in _dynamic_rnn_loop
    swap_memory=swap_memory)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3556, in while_loop
    return_same_structure)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3087, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3022, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\control_flow_ops.py", line 3525, in <lambda>
    body = lambda i, lv: (i + 1, orig_body(*lv))
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\rnn.py", line 847, in _time_step
    (output, new_state) = call_cell()
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\rnn.py", line 833, in <lambda>
    call_cell = lambda: cell(input_t, state)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 371, in __call__
    *args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\layers\base.py", line 530, in __call__
    outputs = super(Layer, self).__call__(inputs, *args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 554, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py", line 748, in call
    array_ops.concat([inputs, h], 1), self._kernel)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\math_ops.py", line 2455, in matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 5630, in mat_mul
    name=name)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\util\deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\framework\ops.py", line 3300, in create_op
    op_def=op_def)
  File "G:\Anaconda3\Anaconda3-5.2.0\lib\site-packages\tensorflow\python\framework\ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

InternalError (see above for traceback): Blas GEMM launch failed : a.shape=(50, 128), b.shape=(128, 400), m=50, n=400, k=128
	 [[node rnn/while/basic_lstm_cell/MatMul (defined at <ipython-input-2-b7fe08cdbbc1>:37) ]]
