In [1]:
import os
os.sys.path.append(os.path.dirname(os.path.abspath('..')))

# 数据准备

In [2]:
from dataset.dataset import load_cifar10

train_data,test_data=load_cifar10(batch_size=64)



(50000, 3072) (50000,)




(10000, 3072) (10000,)


# 网络结构设计

In [3]:
import tensorflow as tf

unit_I = train_data.n_features    # 输入单元数，等于特征数

filters = 32    # 卷积核的数量
conv_size = (3, 3)    # 卷积核尺寸

pool_size = (2, 2)    # 池化核尺寸
strides = (2, 2)    # 核移动的步长

fc_size = 128

unit_O = 10    # 输出单元数，类别数

  from ._conv import register_converters as _register_converters


# 搭建网络
TensorFlow自身提供了增加层数的方法。

In [4]:
# 输入必须是可由用户指定的，所以设为placeholder
X = tf.placeholder(tf.float32, [None, unit_I])  # 数据的样本数不指定，只指定特征数
Y = tf.placeholder(tf.int64, [None])    # 目标值为列向量，int64为了兼容
X_img = tf.transpose(tf.reshape(X, [-1, 3, 32, 32]),
                     perm=[0, 2, 3, 1])    # 转为图片格式送入模型，(n_samples,width,height,depth)

# 网络结构图
with tf.name_scope('CNN'):
    conv1 = tf.layers.conv2d(X_img, filters=filters,
                             kernel_size=conv_size, padding='same',
                             activation=tf.nn.relu, name='conv1')
    pooling1 = tf.layers.max_pooling2d(conv1, pool_size=pool_size,
                                       strides=strides, name='pooling1')
    conv2 = tf.layers.conv2d(pooling1, filters=filters,
                             kernel_size=conv_size, padding='same',
                             activation=tf.nn.relu, name='conv2')
    pooling2 = tf.layers.max_pooling2d(conv2, pool_size=pool_size,
                                       strides=strides, name='pooling2')
    conv3 = tf.layers.conv2d(pooling2, filters=filters,
                             kernel_size=conv_size, padding='same',
                             activation=tf.nn.relu, name='conv3')
    pooling3 = tf.layers.max_pooling2d(conv3, pool_size=pool_size,
                                       strides=strides, name='pooling3')
    fc = tf.layers.dense(tf.layers.flatten(pooling3),
                         fc_size, activation=tf.nn.relu)
    logits = tf.layers.dense(fc, unit_O, activation=None)

# 评估图
with tf.name_scope('Eval'):
    # 计算一维向量与onehot向量之间的损失
    loss = tf.losses.sparse_softmax_cross_entropy(labels=Y, logits=logits)
    predict = tf.argmax(logits, 1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, Y), tf.float32))

# 优化图
with tf.name_scope('train_op'):
    lr = 1e-3
    train_op = tf.train.AdamOptimizer(lr).minimize(loss)

init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True    # 按需使用显存

Instructions for updating:
Use keras.layers.conv2d instead.
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use keras.layers.max_pooling2d instead.
Instructions for updating:
Use keras.layers.flatten instead.
Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Use tf.cast instead.


# 训练网络

In [5]:
import numpy as np

with tf.Session(config=config) as sess:
    sess.run(init)
    epochs = 20

    batch_cnt = 0
    for epoch in range(epochs):
        for batch_data, batch_labels in train_data.next_batch():
            batch_cnt += 1
            loss_val, acc_val, _ = sess.run(
                [loss, accuracy, train_op], feed_dict={
                    X: batch_data,
                    Y: batch_labels})

            # 每1000batch输出一次信息
            if (batch_cnt+1) % 1000 == 0:
                print('epoch: {}, batch_loss: {}, batch_acc: {}'.format(
                    epoch, loss_val, acc_val))

            # 每5000batch做一次验证
            if (batch_cnt+1) % 5000 == 0:
                all_test_acc_val = list()
                for test_batch_data, test_batch_labels in test_data.next_batch():
                    test_acc_val = sess.run(
                        [accuracy],
                        feed_dict={
                            X: test_batch_data,
                            Y: test_batch_labels
                        })
                    all_test_acc_val.append(test_acc_val)
                test_acc = np.mean(all_test_acc_val)
                print('epoch: {}, test_acc: {}'.format(epoch, test_acc))

epoch: 1, batch_loss: 1.1121983528137207, batch_acc: 0.65625
epoch: 2, batch_loss: 1.0409471988677979, batch_acc: 0.546875
epoch: 3, batch_loss: 0.9590150117874146, batch_acc: 0.625
epoch: 5, batch_loss: 0.5987226963043213, batch_acc: 0.828125
epoch: 6, batch_loss: 0.739209771156311, batch_acc: 0.734375
epoch: 6, test_acc: 0.6999198794364929
epoch: 7, batch_loss: 0.82710862159729, batch_acc: 0.703125
epoch: 8, batch_loss: 0.7348452806472778, batch_acc: 0.75
epoch: 10, batch_loss: 0.43963366746902466, batch_acc: 0.828125
epoch: 11, batch_loss: 0.7753157615661621, batch_acc: 0.734375
epoch: 12, batch_loss: 0.6161017417907715, batch_acc: 0.796875
epoch: 12, test_acc: 0.7143429517745972
epoch: 14, batch_loss: 0.4765985310077667, batch_acc: 0.75
epoch: 15, batch_loss: 0.4126424789428711, batch_acc: 0.8125
epoch: 16, batch_loss: 0.7285220623016357, batch_acc: 0.78125
epoch: 17, batch_loss: 0.5084350109100342, batch_acc: 0.796875
epoch: 19, batch_loss: 0.39072269201278687, batch_acc: 0.828125