In [1]:
import pickle
import numpy as np
import tensorflow as tf
def unpickle(file):
    with open(file, 'rb') as f:
        dict = pickle.load(f, encoding='bytes')
    return dict

In [2]:
cifar = unpickle('./cifar-10-batches-py/data_batch_1')
data = cifar[b'data']
label = cifar[b'labels']

In [3]:
from PIL import Image
import matplotlib.pyplot as plt

img = data[-1].reshape(3, 32, 32)
red = Image.fromarray(img[0])
green = Image.fromarray(img[1])
blue = Image.fromarray(img[2])
img = Image.merge("RGB", (red, green, blue))
plt.figure("image")
plt.imshow(img)
plt.show()

<Figure size 640x480 with 1 Axes>

In [4]:
## 数据预处理
class CifarData():
    def __init__(self, data, labels):
        self._data = data
        self._labels = labels
        self._indicator = 0
        self._example_num = len(data)
        self._random_shuffle()
        self._deal_img()
        
    def _random_shuffle(self):
        p = np.random.permutation(self._example_num)
        self._data = self._data[p]
        self._labels = self._labels[p]
        
    def _deal_img(self):
        data = self._data.reshape(self._example_num, 3, 32, 32)
        data = data / 255
        self._data = np.transpose(data, [0, 2, 3, 1])
        new_label = np.zeros([self._example_num, 10])
        for i in range(self._example_num):
            new_label[i, self._labels[i]] = 1 
        self._labels = new_label
        
    def next_batch(self, batch_size):
        end_indicator = self._indicator + batch_size
        if end_indicator > self._example_num:
            self._random_shuffle()
            self._indicator = 0
            end_indicator = batch_size
        batch_data = self._data[self._indicator: end_indicator]
        batch_label = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator        
        return batch_data, batch_label
    

In [5]:
cifar = unpickle('./cifar-10-batches-py/data_batch_1')
data = cifar[b'data']
labels = np.array(cifar[b'labels'])
cifar_data = CifarData(data, labels)
batch_data, batch_label = cifar_data.next_batch(10)
batch_label

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [6]:
def get_default_parms():
    return tf.contrib.training.HParams(
        batch_size = 128,
        channels = [32, 64, 128],
        learning_rate = 0.002
    )

In [7]:
## conv封装
def conv2d_warpper(inputs, out_channel, name, training):
    def leaky_relu(x, alpha=0.2, name = ''):
        return tf.maximum(x, alpha * x, name=name)
    with tf.variable_scope(name):
        conv2d = tf.layers.conv2d(inputs, out_channel, [5, 5], strides=(2, 2), padding="SAME")
        bn = tf.layers.batch_normalization(conv2d, training=training)
        return leaky_relu(bn, name='output')


In [8]:
class NetWork():
    
    def __init__(self, hps):
        self._hps = hps
        
    def build(self):
        
        self._input = tf.placeholder(tf.float32, [None, 32, 32, 3])
        self._label = tf.placeholder(tf.int32, [None, 10])
        
        ## 搭建网络
        conv_input = self._input
        with tf.variable_scope('conv'):
            for i in range(len(self._hps.channels)):
                conv_input = conv2d_warpper(conv_input, self._hps.channels[i], 'conv2d_%d' % i, training=True)
        fc_inputs = conv_input
        with tf.variable_scope('fc'):
            flatten = tf.layers.flatten(fc_inputs)
            fc= tf.layers.dense(flatten, 1024, name='fc')
        out = tf.layers.dense(fc, 10, name='output')
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=out, labels=self._label))
        return self._input, self._label, tf.nn.softmax(out), loss
    
    def build_op(self, loss):
        opt = tf.train.AdamOptimizer(learning_rate=self._hps.learning_rate)
        return opt.minimize(loss)
    

In [None]:
hps = get_default_parms()
tf.reset_default_graph()
net = NetWork(hps)
input_ts, labels_ts, out, loss = net.build()
opt = net.build_op(loss)

## 读取训练数据
cifar = unpickle('./cifar-10-batches-py/data_batch_1')
all_data = cifar[b'data']
all_label = np.array(cifar[b'labels'])
for i in range(1, 5):
    cifar = unpickle('./cifar-10-batches-py/data_batch_%d' % (i+1))
    all_data = np.concatenate((all_data, cifar[b'data']), axis=0)
    all_label = np.concatenate((all_label, np.array(cifar[b'labels'])), axis=0)
    print(all_label.shape)
    cifar_data = CifarData(all_data, all_label)

## 读取测试数据
cifar = unpickle('./cifar-10-batches-py/test_batch')
test_data = cifar[b'data']
test_label = np.array(cifar[b'labels'])
cifar_test = CifarData(test_data, test_label)

## 保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    step = 1001
    for i in range(1, step):
        batch_data, batch_label = cifar_data.next_batch(256)
        batch_test, batch_test_label = cifar_test.next_batch(128)
        
        sess.run(opt, feed_dict={input_ts:batch_data, labels_ts:batch_label})
        loss_value = sess.run(loss, feed_dict={input_ts:batch_data, labels_ts:batch_label})
        

        correct_predict = tf.equal(tf.argmax(out, 1), tf.argmax(labels_ts, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
        train_acc = sess.run(accuracy, feed_dict={input_ts:batch_data, labels_ts:batch_label})
        
        test_acc = sess.run(accuracy, feed_dict={input_ts:batch_test, labels_ts:batch_test_label})
        
        print("step: %d" % i + " loss=" + str(loss_value) + "train acc:" + str(train_acc)+ \
               "  test acc: " + str(test_acc))
        
        ### 200step 保存一下模型
        if i % 200 == 0:
            saver.save(sess, "./checkpoints/myModel")
        

(20000,)
(30000,)
(40000,)
(50000,)
step: 1 loss=13.3192425train acc:0.19140625test acc: 0.1484375
step: 2 loss=14.528498train acc:0.23046875test acc: 0.1953125
step: 3 loss=11.755217train acc:0.21875test acc: 0.203125
step: 4 loss=10.617779train acc:0.19140625test acc: 0.140625
step: 5 loss=9.04225train acc:0.1953125test acc: 0.21875
step: 6 loss=8.166451train acc:0.18359375test acc: 0.1484375
step: 7 loss=7.2402315train acc:0.171875test acc: 0.171875
step: 8 loss=7.5644464train acc:0.1796875test acc: 0.171875
step: 9 loss=5.4138765train acc:0.25390625test acc: 0.234375
step: 10 loss=5.9129663train acc:0.234375test acc: 0.234375
step: 11 loss=5.4094753train acc:0.23828125test acc: 0.2890625
step: 12 loss=4.7698627train acc:0.27734375test acc: 0.2734375
step: 13 loss=4.333523train acc:0.24609375test acc: 0.25
step: 14 loss=4.9808016train acc:0.265625test acc: 0.2734375
step: 15 loss=4.2850633train acc:0.2578125test acc: 0.234375
step: 16 loss=3.8146276train acc:0.33203125test acc: 0.27