In [1]:
import tensorflow as tf
import numpy as np
import os
import struct
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


# 数据准备

In [2]:
def load_image(path):
    with open(path, 'rb') as fd:
        magic, num, rows, cols = struct.unpack('>IIII', fd.read(16))
        res = np.fromfile(fd, dtype=np.uint8).reshape(-1, 784)
    return res


def load_label(path):
    with open(path, 'rb') as fd:
        magic, n = struct.unpack('>II', fd.read(8))
        res = np.fromfile(fd, dtype=np.uint8)
    return res

from sklearn.preprocessing import StandardScaler


class MnistData:
    def __init__(self, data_path,label_path, batch_size=32, normalize=False, shuffle=False):
        '''
        paths: 文件路径
        '''
        self._data = list()
        self._target = list()
        self._n_samples = 0
        self.n_features = 0

        self._idx = 0    # mini-batch的游标
        self._batch_size = batch_size

        self._load(data_path,label_path)

        if shuffle:
            self._shuffle_data()
        if normalize:
            self._normalize_data()

        print(self._data.shape, self._target.shape)
        
    def _load(self, data_path,label_path):
        '''
        载入数据
        '''
        self._data=load_image(data_path)
        self._target=load_label(label_path)

        self._n_samples, self.n_features = self._data.shape[0], self._data.shape[1]
        
    def _shuffle_data(self):
        '''
        打乱数据
        '''
        idxs = np.random.permutation(self._n_samples)
        self._data = self._data[idxs]
        self._target = self._target[idxs]

    def _normalize_data(self):
        scaler = StandardScaler()
        self._data = scaler.fit_transform(self._data)

    def next_batch(self):
        '''
        生成mini-batch
        '''
        while self._idx < self._n_samples:
            yield self._data[self._idx: (self._idx+self._batch_size)], self._target[self._idx: (self._idx+self._batch_size)]
            self._idx += self._batch_size

        self._idx = 0
        self._shuffle_data()

In [3]:
MNIST_DIR = '../dataset/MNIST/'
train_data_path = os.path.join(MNIST_DIR, 'train-images.idx3-ubyte')
train_label_path = os.path.join(MNIST_DIR, 'train-labels.idx1-ubyte')
test_data_path = os.path.join(MNIST_DIR, 't10k-images.idx3-ubyte')
test_label_path = os.path.join(MNIST_DIR, 't10k-labels.idx1-ubyte')

batch_size = 32
train_data = MnistData(train_data_path,train_label_path, batch_size=batch_size,
                       normalize=False, shuffle=True)
test_data = MnistData(test_data_path,test_label_path, batch_size=batch_size,
                      normalize=False, shuffle=False)

(60000, 784) (60000,)
(10000, 784) (10000,)


# 网络结构设计

In [None]:
unit_I = 28
n_steps = 28    # 状态数
unit_h = 256
unit_O = 10

# 搭建网络

In [None]:
X = tf.placeholder(tf.float32, [None, n_steps, unit_I])
Y = tf.placeholder(tf.int64, [None])

# RNN网络图
with tf.name_scope('RNN'):
    X_seq = tf.transpose(X, [1, 0, 2])    # 把状态移到第一维
    X_seq = tf.reshape(X_seq, [-1, unit_I])
    X_seq = tf.split(X_seq, n_steps)

    lstm_fw_cell = tf.nn.rnn_cell.BasicLSTMCell(unit_h)    # 前向单元
    lstm_bw_cell = tf.nn.rnn_cell.BasicLSTMCell(unit_h)    # 反向单元

    # 双向循环单元
    outputs, _, _ = tf.nn.static_bidirectional_rnn(
        cell_fw=lstm_fw_cell,
        cell_bw=lstm_bw_cell,
        inputs=X_seq,
        dtype=tf.float32
    )
    
    logits=tf.layers.dense(outputs[-1], unit_O, activation=None)

# 评估图
with tf.name_scope('Eval'):
    loss = tf.losses.sparse_softmax_cross_entropy(labels=Y, logits=logits)
    predict = tf.argmax(logits, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, Y), tf.float32))

# 优化图
with tf.name_scope('train_op'):
    lr = 1e-3
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True    # 按需使用显存

Instructions for updating:
This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.
Instructions for updating:
Please use `keras.layers.Bidirectional(keras.layers.RNN(cell, unroll=True))`, which is equivalent to this API
Instructions for updating:
Please use `keras.layers.RNN(cell, unroll=True)`, which is equivalent to this API
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use tf.cast instead.


# 训练网络

In [None]:
with tf.Session(config=config) as sess:
    sess.run(init)
    epochs = 20

    batch_cnt = 0
    for epoch in range(epochs):
        for batch_data, batch_labels in train_data.next_batch():
            batch_data = batch_data.reshape((-1, n_steps, unit_I))
            batch_cnt += 1
            loss_val, acc_val, _ = sess.run(
                [loss, accuracy, train_op],
                feed_dict={
                    X: batch_data,
                    Y: batch_labels})

            # 每1000batch输出一次信息
            if (batch_cnt+1) % 1000 == 0:
                print('epoch: {}, batch_loss: {}, batch_acc: {}'.format(
                    epoch, loss_val, acc_val))

            # 每5000batch做一次验证
            if (batch_cnt+1) % 5000 == 0:
                all_test_acc_val = list()
                for test_batch_data, test_batch_labels in test_data.next_batch():
                    test_batch_data = test_batch_data.reshape((-1, n_steps, unit_I))
                    test_acc_val = sess.run(
                        [accuracy],
                        feed_dict={
                            X: test_batch_data,
                            Y: test_batch_labels
                        })
                    all_test_acc_val.append(test_acc_val)
                test_acc = np.mean(all_test_acc_val)
                print('epoch: {}, test_acc: {}'.format(epoch, test_acc))

epoch: 0, batch_loss: 0.3887373208999634, batch_acc: 0.8125
epoch: 1, batch_loss: 0.2442668378353119, batch_acc: 0.84375
epoch: 1, batch_loss: 0.1861989051103592, batch_acc: 0.9375
epoch: 2, batch_loss: 0.26198431849479675, batch_acc: 0.90625
epoch: 2, batch_loss: 0.2022436261177063, batch_acc: 0.9375
epoch: 2, test_acc: 0.9263178706169128
epoch: 3, batch_loss: 0.32303673028945923, batch_acc: 0.9375
epoch: 3, batch_loss: 0.3292866349220276, batch_acc: 0.90625
epoch: 4, batch_loss: 0.0797535702586174, batch_acc: 1.0
epoch: 4, batch_loss: 0.2004379779100418, batch_acc: 0.90625
epoch: 5, batch_loss: 0.10955925285816193, batch_acc: 0.96875
epoch: 5, test_acc: 0.9399960041046143
epoch: 5, batch_loss: 0.06019783765077591, batch_acc: 0.96875
epoch: 6, batch_loss: 0.09795217961072922, batch_acc: 0.96875
epoch: 6, batch_loss: 0.20790743827819824, batch_acc: 0.9375
epoch: 7, batch_loss: 0.14399084448814392, batch_acc: 0.9375
epoch: 7, batch_loss: 0.10233209282159805, batch_acc: 0.96875
epoch: 7,