In [1]:
import tensorflow as tf
import numpy as np
import os

  from ._conv import register_converters as _register_converters


# 数据准备
首先看一下 CIFAR-10 数据是如何读取的。下面的函数改编自官网示例，直接返回 ndarray 形式的特征矩阵 X 与标签向量 Y。注意：`pickle` 在加载时可以执行任意代码，只应加载可信的数据文件。

In [2]:
import matplotlib.pyplot as plt


def unpickle(file):
    '''
    Read one pickled CIFAR-10 batch file.

    Returns a tuple ``(data, labels)``: ``data`` is the object stored under the
    ``b'data'`` key, and ``labels`` is the ``b'labels'`` entry converted to a
    NumPy array. Dictionary keys are bytes because the file is unpickled with
    ``encoding='bytes'``.

    Security: pickle can execute arbitrary code while loading, so only pass
    trusted files.
    '''
    import pickle

    with open(file, 'rb') as handle:
        batch = pickle.load(handle, encoding='bytes')
    images = batch[b'data']
    labels = np.array(batch[b'labels'])
    return images, labels


data, target = unpickle('../dataset/cifar-10-batches-py/data_batch_1')

# `data` is 2-D: one flattened row of 3072 values per sample (see the
# (n_samples, 3072) shapes printed further down). Reshaping a row to
# (3, 32, 32) gives (channel, row, col); imshow expects (row, col, channel),
# so the channel axis is moved last.
sample_idx = 0
fig, ax = plt.subplots()
ax.imshow(data[sample_idx].reshape((3, 32, 32)).transpose((1, 2, 0)))
ax.set_title('data_batch_1[{}], label={}'.format(sample_idx, target[sample_idx]))
ax.axis('off')
plt.show()

<Figure size 640x480 with 1 Axes>

对于深度学习中的大型数据集，mini-batch 式训练是很有必要的，同时还需要频繁地对数据进行打乱、标准化等操作。因此定义一个专门的数据类，负责载入、打乱、标准化数据，并按 mini-batch 迭代：

In [3]:
from sklearn.preprocessing import StandardScaler


class CifarData:
    '''
    In-memory container for CIFAR-10 batches that serves mini-batches.

    Loads one or more pickled batch files via ``unpickle``, stacks them into a
    single 2-D feature array plus a 1-D label array, optionally shuffles and
    standardizes them, and yields mini-batches from ``next_batch``.
    '''

    def __init__(self, paths, batch_size=32, normalize=False, shuffle=False,
                 scaler=None):
        '''
        paths: iterable of batch-file paths; all files are loaded and concatenated.
        batch_size: number of samples per mini-batch yielded by ``next_batch``.
        normalize: if True, standardize the features (see ``scaler``).
        shuffle: if True, shuffle once after loading and again after every
            complete pass of ``next_batch``; if False the loaded order is kept.
        scaler: optional already-fitted scaler (e.g. ``train_data.scaler``).
            When given and ``normalize`` is True, its ``transform`` is applied,
            so evaluation data can be scaled with the training statistics.
            When None and ``normalize`` is True, a new StandardScaler is fitted
            on this data. Either way the scaler used is kept in ``self.scaler``.
        '''
        self._data = list()
        self._target = list()
        self._n_samples = 0
        self.n_features = 0

        self._batch_size = batch_size
        self._shuffle = shuffle    # remembered so next_batch honours shuffle=False
        self.scaler = scaler

        self._load(paths)

        if shuffle:
            self._shuffle_data()
        if normalize:
            self._normalize_data()

        print(self._data.shape, self._target.shape)

    def _load(self, paths):
        '''
        Load every batch file in ``paths`` and concatenate them row-wise.

        Raises ValueError if ``paths`` is empty.
        '''
        data_parts, label_parts = [], []
        for path in paths:
            data, labels = unpickle(path)
            data_parts.append(data)
            label_parts.append(labels)

        if not data_parts:
            raise ValueError('CifarData needs at least one batch file path')

        # Stack all batches: features vertically, labels into one flat vector.
        self._data = np.vstack(data_parts)
        self._target = np.hstack(label_parts)
        self._n_samples, self.n_features = self._data.shape[0], self._data.shape[1]

    def _shuffle_data(self):
        '''Apply one random permutation to features and labels so rows stay aligned.'''
        idxs = np.random.permutation(self._n_samples)
        self._data = self._data[idxs]
        self._target = self._target[idxs]

    def _normalize_data(self):
        '''
        Standardize the features column-wise.

        Uses ``self.scaler.transform`` when a scaler was supplied; otherwise it
        fits a new StandardScaler on this data and stores it in ``self.scaler``.
        '''
        if self.scaler is None:
            self.scaler = StandardScaler()
            self._data = self.scaler.fit_transform(self._data)
        else:
            self._data = self.scaler.transform(self._data)

    def next_batch(self):
        '''
        Yield ``(batch_data, batch_labels)`` tuples covering every sample once.

        The last batch may be smaller than ``batch_size``. The read position is
        local to each generator, so every call starts at the first sample even
        if an earlier generator was abandoned part-way. After a complete pass
        the data is reshuffled, but only when the object was built with
        ``shuffle=True``.
        '''
        for start in range(0, self._n_samples, self._batch_size):
            stop = start + self._batch_size
            yield self._data[start:stop], self._target[start:stop]

        if self._shuffle:
            self._shuffle_data()

In [4]:
# Five training batch files (data_batch_1 .. data_batch_5) plus one test batch file.
CIFAR_DIR = "../dataset/cifar-10-batches-py/"
batch_size = 32

train_filenames = [os.path.join(CIFAR_DIR, 'data_batch_{}'.format(k))
                   for k in (1, 2, 3, 4, 5)]
test_filenames = [os.path.join(CIFAR_DIR, name) for name in ('test_batch',)]

# Training data is shuffled at load time; test data is loaded unshuffled.
# NOTE(review): normalize=True fits a separate StandardScaler on each set, so
# the test set is standardized with its own statistics, not the training ones.
train_data = CifarData(train_filenames, batch_size=batch_size,
                       shuffle=True, normalize=True)
test_data = CifarData(test_filenames, batch_size=batch_size,
                      shuffle=False, normalize=True)



(50000, 3072) (50000,)




(10000, 3072) (10000,)




# 网络结构设计

In [5]:
# Layer widths: the input width equals the number of features, followed by two
# hidden layers of 100 and 50 units and 10 output units (multi-class task).
unit_I = train_data.n_features    # input layer width = feature count
unit_h1, unit_h2 = 100, 50        # first / second hidden layer widths
unit_O = 10                       # output layer width = number of classes

# 搭建网络
TensorFlow 自带 `tf.layers.dense` 等高层 API，可以直接堆叠全连接层，而无需手动创建权重与偏置变量。

In [6]:
# Start from an empty graph so re-running this cell does not stack duplicate
# layers/ops onto the default graph (keeps Restart & Run All consistent).
tf.reset_default_graph()
SEED = 42
tf.set_random_seed(SEED)    # graph-level seed: makes weight initialization repeatable

# Inputs are supplied at run time through feed_dict, so they are placeholders.
X = tf.placeholder(tf.float32, [None, unit_I])  # batch size left open; width = feature count
# 1-D vector of integer class ids (sparse labels). int64 matches the dtype of
# tf.argmax below, so tf.equal can compare predictions and labels directly.
Y = tf.placeholder(tf.int64, [None])

# Network: two fully connected ReLU hidden layers and a linear output layer (logits).
with tf.name_scope('DNN'):
    hidden1 = tf.layers.dense(X, unit_h1, activation=tf.nn.relu)
    hidden2 = tf.layers.dense(hidden1, unit_h2, activation=tf.nn.relu)
    Y_pred = tf.layers.dense(hidden2, unit_O)

# Evaluation sub-graph.
with tf.name_scope('Eval'):
    # Sparse softmax cross-entropy takes integer labels directly (no one-hot needed).
    loss = tf.losses.sparse_softmax_cross_entropy(labels=Y, logits=Y_pred)
    predict = tf.argmax(Y_pred, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, Y), tf.float32))

# Optimization sub-graph.
with tf.name_scope('train_op'):
    lr = 1e-3
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True    # allocate GPU memory on demand

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.


# 训练网络

In [7]:
# Training schedule (counted in optimizer steps, i.e. mini-batches).
epochs = 20
log_every = 1000     # print the current mini-batch loss/accuracy every N steps
eval_every = 5000    # evaluate on the whole test set every N steps


def evaluate_accuracy(sess, dataset):
    '''
    Return the sample-weighted accuracy of the current model over ``dataset``.

    Each batch accuracy is weighted by its batch size, so a smaller final batch
    (e.g. 16 samples when 10000 test samples are split into batches of 32) is
    not over-counted as a plain mean of batch accuracies would do.
    Returns NaN if the dataset yields no samples.
    '''
    n_correct, n_seen = 0.0, 0
    for batch_data, batch_labels in dataset.next_batch():
        batch_acc = sess.run(accuracy, feed_dict={X: batch_data,
                                                  Y: batch_labels})
        n_correct += batch_acc * len(batch_labels)
        n_seen += len(batch_labels)
    return n_correct / n_seen if n_seen else float('nan')


# NOTE: the trained variables are discarded when this session closes; add a
# tf.train.Saver if the model is needed after this cell.
with tf.Session(config=config) as sess:
    sess.run(init)

    batch_cnt = 0    # global step counter across epochs
    for epoch in range(epochs):
        for batch_data, batch_labels in train_data.next_batch():
            loss_val, acc_val, _ = sess.run([loss, accuracy, train_op],
                                            feed_dict={X: batch_data,
                                                       Y: batch_labels})
            batch_cnt += 1

            if batch_cnt % log_every == 0:
                print('epoch: {}, batch_loss: {}, batch_acc: {}'
                      .format(epoch, loss_val, acc_val))

            if batch_cnt % eval_every == 0:
                test_acc = evaluate_accuracy(sess, test_data)
                print('epoch: {}, test_acc: {}'.format(epoch, test_acc))

epoch: 0, batch_loss: 1.8539013862609863, batch_acc: 0.375
epoch: 1, batch_loss: 2.1398262977600098, batch_acc: 0.34375
epoch: 1, batch_loss: 1.2638787031173706, batch_acc: 0.5625
epoch: 2, batch_loss: 1.2767705917358398, batch_acc: 0.46875
epoch: 3, batch_loss: 1.144681692123413, batch_acc: 0.59375
epoch: 3, test_acc: 0.46575480699539185
epoch: 3, batch_loss: 1.4559743404388428, batch_acc: 0.53125
epoch: 4, batch_loss: 1.0261310338974, batch_acc: 0.65625
epoch: 5, batch_loss: 1.0441352128982544, batch_acc: 0.65625
epoch: 5, batch_loss: 1.1458637714385986, batch_acc: 0.625
epoch: 6, batch_loss: 1.600606083869934, batch_acc: 0.5
epoch: 6, test_acc: 0.5044928193092346
epoch: 7, batch_loss: 0.6468024253845215, batch_acc: 0.875
epoch: 7, batch_loss: 1.311150312423706, batch_acc: 0.5
epoch: 8, batch_loss: 1.0780742168426514, batch_acc: 0.6875
epoch: 8, batch_loss: 1.5215224027633667, batch_acc: 0.53125
epoch: 9, batch_loss: 1.3531839847564697, batch_acc: 0.53125
epoch: 9, test_acc: 0.506489