In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.keras import datasets   # 我们使用这个函数来下载数据
import os
%matplotlib inline
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


In [2]:
# Load CIFAR-10 through Keras (the archive URL appears in this cell's output).
# Each split is an (images, labels) pair; the model class below reads these
# two globals when it builds its training and test batches.
(trainSet, testSet) = datasets.cifar10.load_data()

Downloading data from http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

In [None]:
class CifarModel(object):
    """Small three-conv-layer CNN for CIFAR-10 built with the TensorFlow 1.x graph API.

    Constructing an instance resets the default graph and builds the full
    training / evaluation graph. `train()` runs plain gradient descent, writes
    summaries under ``log/`` and saves a checkpoint to ``model/cifarModel``;
    `test()` restores that checkpoint and prints the mean accuracy over all
    full test batches.

    The batch generators read the module-level ``trainSet`` / ``testSet``
    globals, each an ``(images, labels)`` pair.

    Args:
        lr: learning rate for ``tf.train.GradientDescentOptimizer``.
        batch_size: number of examples in every generated batch.
        iter_num: number of training steps executed by `train()`.
    """

    def __init__(self, lr, batch_size, iter_num):
        self.lr = lr
        self.batch_size = batch_size
        self.iter_num = iter_num

        tf.reset_default_graph()
        # Inputs: 32x32 RGB images and one-hot labels over 10 classes.
        self.X = tf.placeholder(tf.float32, [None, 32, 32, 3])
        self.y = tf.placeholder(tf.int32, [None, 10])
        # NOTE: despite the name, this is passed to tf.nn.dropout as the *keep*
        # probability: train() feeds 0.5, evaluation feeds 1.0 (no dropout).
        self.dropRate = tf.placeholder(tf.float32)

        conv1 = tf.layers.conv2d(self.X, 32, 5, padding='same', activation=tf.nn.relu,
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.0001, seed=0),
                                 bias_initializer=tf.constant_initializer(0.001))
        pool1 = tf.layers.max_pooling2d(conv1, 3, 2, padding='same')        # 32x32 -> 16x16
        conv2 = tf.layers.conv2d(pool1, 32, 5, padding='same', activation=tf.nn.relu,
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.01, seed=0),
                                 bias_initializer=tf.constant_initializer(0.001))
        pool2 = tf.layers.average_pooling2d(conv2, 3, 2, padding='same')    # 16x16 -> 8x8
        conv3 = tf.layers.conv2d(pool2, 64, 5, padding='same', activation=tf.nn.relu,
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.01, seed=0),
                                 bias_initializer=tf.constant_initializer(0.001))
        pool3 = tf.layers.average_pooling2d(conv3, 3, 2, padding='same')    # 8x8 -> 4x4

        # -1 (rather than a fixed batch_size) lets the graph accept any batch
        # size, matching the placeholder's `None` leading dimension.
        flatten = tf.reshape(pool3, [-1, 4 * 4 * 64])
        dense1 = tf.layers.dense(flatten, 64, activation=tf.nn.relu, use_bias=True,
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.01, seed=0),
                                 bias_initializer=tf.constant_initializer(0.001))
        dense1 = tf.nn.dropout(dense1, self.dropRate)  # second arg is keep_prob in TF 1.x
        dense2 = tf.layers.dense(dense1, 10, use_bias=True,
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.01, seed=0),
                                 bias_initializer=tf.constant_initializer(0.1))

        self.loss = tf.losses.softmax_cross_entropy(onehot_labels=self.y, logits=dense2)
        self.train_step = tf.train.GradientDescentOptimizer(self.lr).minimize(self.loss)

        # Accuracy ops used to monitor / evaluate the model.
        self.correct_prediction = tf.equal(tf.argmax(self.y, axis=1), tf.argmax(dense2, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
        # Used to save and restore the trained model.
        self.saver = tf.train.Saver()

        # The scalar summaries register themselves in the default summary
        # collection; merge_all() gathers them into one op.
        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
        self.merged_summary_op = tf.summary.merge_all()

    @staticmethod
    def _iterate_batches(images, labels, batch_size, num_classes=10):
        """Endlessly yield ``(images, one_hot_labels)`` batches of exactly `batch_size` rows.

        The arrays are walked in order. When the next batch would run past the
        end, iteration wraps back to the start, so a trailing partial batch is
        never yielded. Labels become an int one-hot matrix of shape
        ``(batch_size, num_classes)``; a label outside ``range(num_classes)``
        becomes an all-zero row.

        Raises:
            ValueError: on the first ``next()``, if `batch_size` is not positive
                or there are fewer than `batch_size` examples.
        """
        total = len(images)
        if batch_size <= 0 or total < batch_size:
            raise ValueError('cannot form a batch of %d from %d examples' % (batch_size, total))
        class_ids = np.arange(num_classes)
        start = 0
        while True:
            batch_x = images[start:start + batch_size]
            batch_y = labels[start:start + batch_size]
            start += batch_size
            if start + batch_size > total:
                start = 0
            # Works for labels shaped (n,) or (n, 1).
            one_hot = (class_ids == np.reshape(batch_y, (-1, 1))).astype(int)
            yield batch_x, one_hot

    def get_next_train_batch(self):
        """Return an endless generator of training batches read in order from ``trainSet``."""
        return self._iterate_batches(trainSet[0], trainSet[1], self.batch_size)

    def get_next_test_batch(self):
        """Return an endless generator of test batches read in order from ``testSet``."""
        return self._iterate_batches(testSet[0], testSet[1], self.batch_size)

    def train(self):
        """Run `iter_num` gradient-descent steps, then save a checkpoint.

        Every 1000 steps (including step 0), train and test accuracy are
        computed with dropout disabled, summaries are written to
        ``log/train_base`` / ``log/test_base``, and a progress line is printed.
        The final weights are saved to ``model/cifarModel``.
        """
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            train_batches = self.get_next_train_batch()
            test_batches = self.get_next_test_batch()

            summary_writer = tf.summary.FileWriter('log/train_base', sess.graph)
            summary_writer_test = tf.summary.FileWriter('log/test_base')
            try:
                for i in range(self.iter_num):
                    # Builtin next() works on Python 2 and 3; gen.next() is Python 2 only.
                    batch_x, batch_y = next(train_batches)
                    loss, _ = sess.run([self.loss, self.train_step],
                                       feed_dict={self.X: batch_x, self.y: batch_y, self.dropRate: 0.5})

                    if i % 1000 == 0:
                        # Measure train accuracy on a fresh batch (consumed here,
                        # not trained on), with keep probability 1.0.
                        batch_x, batch_y = next(train_batches)
                        train_accuracy, summary_str = sess.run(
                            [self.accuracy, self.merged_summary_op],
                            feed_dict={self.X: batch_x, self.y: batch_y, self.dropRate: 1.})
                        summary_writer.add_summary(summary_str, i)

                        test_x, test_y = next(test_batches)
                        test_accuracy, summary_str = sess.run(
                            [self.accuracy, self.merged_summary_op],
                            feed_dict={self.X: test_x, self.y: test_y, self.dropRate: 1.})
                        summary_writer_test.add_summary(summary_str, i)
                        print('iter\t%i\tloss\t%f\ttrain_accuracy\t%f\ttest_accuracy\t%f'
                              % (i, loss, train_accuracy, test_accuracy))

                # Saver.save refuses to write when the parent directory is missing.
                save_path = 'model/cifarModel'
                save_dir = os.path.dirname(save_path)
                if save_dir and not os.path.isdir(save_dir):
                    os.makedirs(save_dir)
                self.saver.save(sess, save_path)
            finally:
                # close() flushes pending events and releases the event files,
                # even if training raised.
                summary_writer.close()
                summary_writer_test.close()

    def test(self):
        """Restore ``model/cifarModel`` and print the mean accuracy over all full test batches."""
        with tf.Session() as sess:
            self.saver.restore(sess, 'model/cifarModel')
            test_batches = self.get_next_test_batch()

            accuracies = []
            for _ in range(len(testSet[0]) // self.batch_size):
                test_x, test_y = next(test_batches)
                accuracies.append(sess.run(self.accuracy,
                                           feed_dict={self.X: test_x, self.y: test_y, self.dropRate: 1.0}))
            print('==' * 15)
            print('Test Accuracy: ', np.mean(np.array(accuracies)))