In [1]:
import pickle
import numpy as np
import tensorflow as tf
def unpickle(file):
    """Load one pickled CIFAR-10 batch file.

    Args:
        file: Path to a batch file (e.g. ``data_batch_1`` or ``test_batch``).

    Returns:
        The unpickled object. ``encoding='bytes'`` makes Python-2 ``str``
        objects in the pickle load as ``bytes``, which is why callers index
        the result with byte keys such as ``b'data'`` and ``b'labels'``.

    Security:
        ``pickle.load`` can execute arbitrary code while loading. Only call
        this on trusted files such as the official CIFAR-10 download.
    """
    with open(file, 'rb') as f:
        # Named `batch` rather than `dict` to avoid shadowing the builtin.
        batch = pickle.load(f, encoding='bytes')
    return batch

In [2]:
# Load the first training batch for a quick visual sanity check below.
cifar = unpickle('./cifar-10-batches-py/data_batch_1')
data, label = cifar[b'data'], cifar[b'labels']

In [3]:
from PIL import Image
import matplotlib.pyplot as plt

# Each CIFAR row holds three 32x32 colour planes (red, green, blue) one after
# another. Reshape to (C, H, W) and move channels last so matplotlib can draw
# the uint8 array directly -- no PIL split/merge round-trip needed.
img = np.transpose(data[-1].reshape(3, 32, 32), (1, 2, 0))
fig, ax = plt.subplots(num="image")
ax.imshow(img)
ax.set_title("Last image of data_batch_1 (label index %d)" % label[-1])
ax.axis("off")
plt.show()

<Figure size 640x480 with 1 Axes>

In [4]:
## Data preprocessing
class CifarData():
    """Shuffled, normalized mini-batch iterator over CIFAR-10 arrays.

    On construction the examples are shuffled once. Each flattened row is
    reshaped to (3, 32, 32), scaled to [0, 1], and transposed to channels-last
    (32, 32, 3). Integer labels are one-hot encoded.

    Args:
        data: ``np.ndarray`` with one flattened 3x32x32 image per row.
        labels: ``np.ndarray`` of integer class ids in ``[0, num_classes)``.
            Must be an array (not a list) because it is fancy-indexed when
            shuffling.
        num_classes: Width of the one-hot label vectors (10 for CIFAR-10).
    """

    def __init__(self, data, labels, num_classes=10):
        self._data = data
        self._labels = labels
        self._num_classes = num_classes
        self._indicator = 0
        self._example_num = len(data)
        self._random_shuffle()
        self._deal_img()

    def _random_shuffle(self):
        """Apply one random permutation to data and labels together."""
        p = np.random.permutation(self._example_num)
        self._data = self._data[p]
        self._labels = self._labels[p]

    def _deal_img(self):
        """Reshape/normalize images to NHWC floats and one-hot encode labels."""
        data = self._data.reshape(self._example_num, 3, 32, 32)
        data = data / 255
        self._data = np.transpose(data, [0, 2, 3, 1])
        # Row k of the identity matrix is the one-hot vector for class k, so a
        # single fancy-index replaces the per-example Python loop.
        self._labels = np.eye(self._num_classes)[np.asarray(self._labels)]

    def next_batch(self, batch_size):
        """Return the next ``(batch_data, batch_label)`` pair of ``batch_size`` rows.

        When fewer than ``batch_size`` unseen examples remain, the leftover tail
        is skipped, the dataset is reshuffled, and iteration restarts from the
        front.

        Raises:
            ValueError: if ``batch_size`` is not in ``[1, number of examples]``.
                Such values would otherwise silently yield short or empty batches.
        """
        if not 0 < batch_size <= self._example_num:
            raise ValueError("batch_size must be in [1, {}], got {!r}".format(
                self._example_num, batch_size))
        end_indicator = self._indicator + batch_size
        if end_indicator > self._example_num:
            self._random_shuffle()
            self._indicator = 0
            end_indicator = batch_size
        batch_data = self._data[self._indicator: end_indicator]
        batch_label = self._labels[self._indicator: end_indicator]
        self._indicator = end_indicator
        return batch_data, batch_label
    

In [5]:
# Smoke-test CifarData on one training batch and inspect ten one-hot labels.
cifar = unpickle('./cifar-10-batches-py/data_batch_1')
data, labels = cifar[b'data'], np.asarray(cifar[b'labels'])
cifar_data = CifarData(data, labels)
batch_data, batch_label = cifar_data.next_batch(10)
batch_label

array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

In [6]:
def get_default_parms():
    """Return the default hyper-parameters for the CIFAR-10 CNN.

    Fields:
        batch_size: mini-batch size (128).
        channels: output channels of each conv block, in order.
        learning_rate: Adam step size (0.002).

    Relies on ``tf.contrib.training.HParams``, which exists only in
    TensorFlow 1.x.
    """
    defaults = dict(
        batch_size=128,
        channels=[32, 64, 128],
        learning_rate=0.002,
    )
    return tf.contrib.training.HParams(**defaults)

In [7]:
## Conv block wrapper
def conv2d_warpper(inputs, out_channel, name, training):
    """5x5 stride-2 SAME conv -> batch norm -> leaky ReLU (alpha = 0.2).

    All layers are created inside ``tf.variable_scope(name)``; the final
    activation op is named ``output`` within that scope.

    Args:
        inputs: 4-D image tensor (assumed NHWC, matching NetWork's placeholder).
        out_channel: number of convolution filters.
        name: variable-scope name for this block.
        training: forwarded to ``tf.layers.batch_normalization``.
    """
    with tf.variable_scope(name):
        features = tf.layers.conv2d(
            inputs, out_channel, [5, 5], strides=(2, 2), padding="SAME")
        normed = tf.layers.batch_normalization(features, training=training)
        # Leaky ReLU expressed as max(x, 0.2 * x).
        return tf.maximum(normed, 0.2 * normed, name='output')


In [8]:
class NetWork():
    """Conv stack + two dense layers producing 10-way class probabilities.

    Args:
        hps: hyper-parameter object exposing ``channels`` (one conv block per
            entry) and ``learning_rate`` (used by ``build_op``).
    """

    def __init__(self, hps):
        self._hps = hps
        # Scalar bool tensor created in build(). It defaults to True (batch
        # norm in training mode, as before). Feed False to evaluate with the
        # learned moving mean/variance instead of per-batch statistics.
        self.is_training = None

    def build(self):
        """Create the placeholders, network and loss.

        Returns:
            Tuple ``(input, label, probabilities, loss)``:
            - ``input``: float32 placeholder ``[None, 32, 32, 3]``;
            - ``label``: float32 one-hot placeholder ``[None, 10]``;
            - ``probabilities``: softmax over the 10 logits;
            - ``loss``: scalar mean softmax cross-entropy.
        """
        self._input = tf.placeholder(tf.float32, [None, 32, 32, 3])
        # One-hot / probability targets are real-valued; an int32 placeholder
        # would truncate anything other than hard 0/1 labels.
        self._label = tf.placeholder(tf.float32, [None, 10])
        self.is_training = tf.placeholder_with_default(True, shape=[], name='is_training')

        ## Build the network
        conv_input = self._input
        with tf.variable_scope('conv'):
            for i, channel in enumerate(self._hps.channels):
                conv_input = conv2d_warpper(conv_input, channel, 'conv2d_%d' % i,
                                            training=self.is_training)
        with tf.variable_scope('fc'):
            flatten = tf.layers.flatten(conv_input)
            # NOTE(review): no activation here, so 'fc' followed by 'output'
            # composes to a single linear map. Adding one would change what a
            # restored checkpoint computes, so it is left as-is.
            fc = tf.layers.dense(flatten, 1024, name='fc')
        out = tf.layers.dense(fc, 10, name='output')
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=out, labels=self._label))
        return self._input, self._label, tf.nn.softmax(out), loss

    def build_op(self, loss):
        """Return an Adam train op that also updates batch-norm moving statistics.

        ``tf.layers.batch_normalization`` registers its moving mean/variance
        updates in ``tf.GraphKeys.UPDATE_OPS``. Those updates only run if the
        train op depends on them, so ``build()`` must be called first.
        """
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        opt = tf.train.AdamOptimizer(learning_rate=self._hps.learning_rate)
        with tf.control_dependencies(update_ops):
            return opt.minimize(loss)
    

In [None]:
import os

hps = get_default_parms()
tf.reset_default_graph()
net = NetWork(hps)
input_ts, labels_ts, out, loss = net.build()
opt = net.build_op(loss)

# Build the accuracy ops once, before the session. Creating them inside the
# training loop would add new nodes to the graph on every step.
correct_predict = tf.equal(tf.argmax(out, 1), tf.argmax(labels_ts, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))

## Load training data (data_batch_1 .. data_batch_5)
train_batches = [unpickle('./cifar-10-batches-py/data_batch_%d' % k) for k in range(1, 6)]
all_data = np.concatenate([b[b'data'] for b in train_batches], axis=0)
all_label = np.concatenate([np.array(b[b'labels']) for b in train_batches], axis=0)
del train_batches  # raw dicts are no longer needed once merged
# Construct the iterator once, after all batches are merged. Building it inside
# the loop reshuffled and renormalized the growing array on every iteration.
cifar_data = CifarData(all_data, all_label)

## Load test data
cifar = unpickle('./cifar-10-batches-py/test_batch')
test_data = cifar[b'data']
test_label = np.array(cifar[b'labels'])
cifar_test = CifarData(test_data, test_label)

## Train, resuming from the saved checkpoint when one exists
CHECKPOINT_DIR = './checkpoints'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, 'myModel')
NUM_STEPS = 1000
TRAIN_BATCH_SIZE = 256  # NOTE(review): differs from hps.batch_size; reconcile
TEST_BATCH_SIZE = 128
SAVE_EVERY = 200

# saver.save needs the parent directory to exist.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Restore only if a checkpoint is present, so a fresh run does not crash.
    if tf.train.checkpoint_exists(CHECKPOINT_PATH):
        saver.restore(sess, CHECKPOINT_PATH)
    for i in range(1, NUM_STEPS + 1):
        batch_data, batch_label = cifar_data.next_batch(TRAIN_BATCH_SIZE)
        batch_test, batch_test_label = cifar_test.next_batch(TEST_BATCH_SIZE)
        train_feed = {input_ts: batch_data, labels_ts: batch_label}

        sess.run(opt, feed_dict=train_feed)
        # Loss and accuracy share one forward pass on the updated weights.
        loss_value, train_acc = sess.run([loss, accuracy], feed_dict=train_feed)
        # NOTE(review): batch norm runs in training mode here, so this test
        # accuracy uses per-batch statistics of the test batch.
        test_acc = sess.run(accuracy, feed_dict={input_ts: batch_test,
                                                 labels_ts: batch_test_label})

        print("step: %d loss=%s train_acc: %s test_acc: %s"
              % (i, loss_value, train_acc, test_acc))

        # Checkpoint every SAVE_EVERY steps.
        if i % SAVE_EVERY == 0:
            saver.save(sess, CHECKPOINT_PATH)
        

INFO:tensorflow:Restoring parameters from ./checkpoints/myModel
step: 1 loss=0.35854995 train_acc:0.8984375test_acc: 0.671875
step: 2 loss=0.42117316 train_acc:0.86328125test_acc: 0.6796875
step: 3 loss=0.48204365 train_acc:0.8359375test_acc: 0.7265625
step: 4 loss=0.42130262 train_acc:0.875test_acc: 0.75
step: 5 loss=0.44047892 train_acc:0.8515625test_acc: 0.7265625
step: 6 loss=0.41938305 train_acc:0.8515625test_acc: 0.640625
step: 7 loss=0.5212034 train_acc:0.828125test_acc: 0.671875
step: 8 loss=0.4213887 train_acc:0.87890625test_acc: 0.71875
step: 9 loss=0.42119008 train_acc:0.86328125test_acc: 0.71875
step: 10 loss=0.45440602 train_acc:0.8515625test_acc: 0.6953125
step: 11 loss=0.3733456 train_acc:0.87109375test_acc: 0.6796875
step: 12 loss=0.43982542 train_acc:0.82421875test_acc: 0.625
step: 13 loss=0.3566019 train_acc:0.90234375test_acc: 0.6328125
step: 14 loss=0.3473129 train_acc:0.89453125test_acc: 0.6953125
step: 15 loss=0.49865213 train_acc:0.82421875test_acc: 0.7578125
ste

step: 131 loss=0.54348195 train_acc:0.78125test_acc: 0.75
step: 132 loss=0.5547216 train_acc:0.80078125test_acc: 0.6875
step: 133 loss=0.4944804 train_acc:0.8203125test_acc: 0.7734375
step: 134 loss=0.51389325 train_acc:0.83203125test_acc: 0.765625
step: 135 loss=0.5015711 train_acc:0.82421875test_acc: 0.671875
step: 136 loss=0.59536374 train_acc:0.78515625test_acc: 0.6640625
step: 137 loss=0.52120936 train_acc:0.8203125test_acc: 0.609375
step: 138 loss=0.5648858 train_acc:0.765625test_acc: 0.7421875
step: 139 loss=0.50144005 train_acc:0.828125test_acc: 0.6640625
step: 140 loss=0.45124343 train_acc:0.828125test_acc: 0.6640625
step: 141 loss=0.51765466 train_acc:0.8046875test_acc: 0.7109375
step: 142 loss=0.6092316 train_acc:0.7890625test_acc: 0.71875
step: 143 loss=0.5404019 train_acc:0.78125test_acc: 0.6484375
step: 144 loss=0.48840755 train_acc:0.8359375test_acc: 0.765625
step: 145 loss=0.5907675 train_acc:0.78125test_acc: 0.671875
step: 146 loss=0.54513645 train_acc:0.80078125test_a