In [3]:
# Notebook-wide imports and display configuration.
import numpy as np
import pandas as pd
import tensorflow as tf
# Imported as a module: the progress bar is used as tqdm.tqdm(...).
import tqdm
import matplotlib.pyplot as plt
import warnings
# NOTE(review): this silences *every* Python warning for the rest of the kernel
# session (not just library noise), which can hide real problems; consider
# narrowing it with a `category=` argument if a specific warning is the target.
warnings.filterwarnings("ignore")

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
# Download (if not already present) and extract MNIST into MNIST_data/, with
# one-hot encoded labels. `mnist` is used as a global by NeuralNetWork.train/test
# (its train, validation and test splits).
# NOTE(review): tensorflow.examples.tutorials.mnist is deprecated (see the
# "Instructions for updating" notices in this cell's output) and is gone in
# TensorFlow 2.x; consider tf.keras.datasets.mnist or tf.data instead.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


Instructions for updating:
Please write your own downloading logic.


Instructions for updating:
Please use urllib or similar directly.


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.




Instructions for updating:
Please use tf.data to implement this functionality.


Extracting MNIST_data/train-images-idx3-ubyte.gz




Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.




Instructions for updating:
Please use tf.data to implement this functionality.


Extracting MNIST_data/train-labels-idx1-ubyte.gz




Instructions for updating:
Please use tf.one_hot on tensors.


Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.




Extracting MNIST_data/t10k-images-idx3-ubyte.gz




Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.




Extracting MNIST_data/t10k-labels-idx1-ubyte.gz




Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.




In [8]:
class NeuralNetWork():
    """Fully connected MNIST classifier, optionally with batch normalization.

    Built in TensorFlow 1.x graph mode (placeholders + Session). Every hidden
    layer applies ``activation_fn``; when ``use_batch_norm`` is True a hidden
    layer is ``matmul -> tf.layers.batch_normalization -> activation``. The
    output layer is always a plain affine layer producing logits.

    ``train`` and ``test`` read the module-level ``mnist`` dataset loaded
    earlier in the notebook.
    """

    def __init__(self, initial_weights, activation_fn, use_batch_norm):
        """
        Create the network object and build its graph.

        :param initial_weights: list of initial weight matrices, one per layer;
            the last matrix belongs to the output layer
        :param activation_fn: activation function applied to the hidden layers
        :param use_batch_norm: whether the hidden layers use batch normalization
        """
        self.use_batch_norm = use_batch_norm
        self.name = "With Batch Norm" if use_batch_norm else "Without Batch Norm"

        # Fed True for training steps and False for evaluation; batch norm uses
        # it to choose between batch statistics and its moving averages.
        self.is_training = tf.placeholder(tf.bool, name='is_training')

        # Validation accuracies recorded periodically by train().
        self.training_accuracies = []

        self.build_network(initial_weights, activation_fn)

    def build_network(self, initial_weights, activation_fn):
        """
        Build the forward graph.

        Creates ``self.input_layer`` (float32 placeholder of width
        ``initial_weights[0].shape[0]``), one hidden layer per weight matrix
        except the last, and ``self.output_layer`` (unactivated logits). Also
        stores in ``self.update_ops`` the UPDATE_OPS added while building this
        network (batch-norm moving-average updates; empty without batch norm).

        :param initial_weights: list of initial weight matrices, one per layer
        :param activation_fn: activation function for the hidden layers
        """
        # Anything already in UPDATE_OPS belongs to other networks in this graph.
        num_prior_update_ops = len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

        self.input_layer = tf.placeholder(tf.float32, [None, initial_weights[0].shape[0]])
        layer_in = self.input_layer

        # Hidden layers: every weight matrix except the last.
        for layer_weights in initial_weights[:-1]:
            layer_in = self.fully_connected(layer_in, layer_weights, activation_fn)

        # Output layer: no activation, so fully_connected returns raw logits.
        self.output_layer = self.fully_connected(layer_in, initial_weights[-1])

        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)[num_prior_update_ops:]

    def fully_connected(self, layer_in, layer_weights, activation_fn=None):
        """
        Fully connected layer, optionally with batch norm and an activation.

        Batch normalization is used only when the network has it enabled AND
        the layer has an activation (i.e. hidden layers). Otherwise (no batch
        norm, or the output layer) the layer computes ``layer_in @ W + b``,
        followed by ``activation_fn`` if one is given.

        :param layer_in: input tensor
        :param layer_weights: initial value of this layer's weight matrix
        :param activation_fn: activation to apply, or None for a linear layer
        :return: the layer's output tensor
        """
        weights = tf.Variable(layer_weights)

        if self.use_batch_norm and activation_fn:
            # No bias here: batch normalization's learned offset replaces it.
            linear_output = tf.matmul(layer_in, weights)
            batch_normalized_output = tf.layers.batch_normalization(
                linear_output, training=self.is_training)
            return activation_fn(batch_normalized_output)

        bias = tf.Variable(tf.zeros([layer_weights.shape[-1]]))
        linear_output = tf.add(tf.matmul(layer_in, weights), bias)
        return activation_fn(linear_output) if activation_fn else linear_output

    def train(self, sess, learning_rate, training_batches, batches_per_validate_data,
              save_model=None, batch_size=60):
        """
        Train with mini-batch gradient descent on ``mnist.train``.

        After training batch ``i``, whenever ``i % batches_per_validate_data == 0``,
        the accuracy on ``mnist.validation`` is appended to
        ``self.training_accuracies``. After the loop the validation accuracy of
        the fully trained model is printed and returned.

        :param sess: TensorFlow Session (variables are not initialized here)
        :param learning_rate: gradient-descent learning rate
        :param training_batches: number of mini-batches to train on
        :param batches_per_validate_data: validate once every this many batches
        :param save_model: optional checkpoint path; if given, all variables are
            saved there after training
        :param batch_size: mini-batch size
        :return: validation accuracy after the last training batch
        :raises ValueError: if ``batches_per_validate_data`` is not positive
        """
        if batches_per_validate_data < 1:
            raise ValueError("batches_per_validate_data must be a positive integer, got {}"
                             .format(batches_per_validate_data))

        # One-hot labels for the 10 classes.
        labels = tf.placeholder(tf.float32, [None, 10])

        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=self.output_layer))

        correct_prediction = tf.equal(tf.argmax(self.output_layer, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Batch norm's moving statistics only change when its update ops run,
        # so each training step depends on this network's own update ops
        # (an empty list for networks without batch norm).
        with tf.control_dependencies(self.update_ops):
            train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cross_entropy)

        validation_feed = {self.input_layer: mnist.validation.images,
                           labels: mnist.validation.labels,
                           self.is_training: False}

        for i in tqdm.tqdm(range(training_batches)):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={self.input_layer: batch_x,
                                            labels: batch_y,
                                            self.is_training: True})
            if i % batches_per_validate_data == 0:
                self.training_accuracies.append(sess.run(accuracy, feed_dict=validation_feed))

        # Re-validate after the loop: the last periodic check can be several
        # batches stale, and it never happens when training_batches == 0.
        val_accuracy = sess.run(accuracy, feed_dict=validation_feed)
        print("{}: The final accuracy on validation data is {}".format(self.name, val_accuracy))

        if save_model:
            tf.train.Saver().save(sess, save_model)

        return val_accuracy

    def test(self, sess, test_training_accuracy=False, restore=None):
        """
        Evaluate on ``mnist.test`` and print the accuracy.

        :param sess: TensorFlow Session
        :param test_training_accuracy: value fed to ``is_training`` during
            evaluation. False (default) makes batch norm use its moving
            averages; True makes it normalize with the test set's own batch
            statistics. Networks without batch norm are unaffected.
        :param restore: optional checkpoint path to restore variables from first
        :return: the test accuracy
        """
        labels = tf.placeholder(tf.float32, [None, 10])

        correct_prediction = tf.equal(tf.argmax(self.output_layer, 1), tf.argmax(labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        if restore:
            tf.train.Saver().restore(sess, restore)

        # Feed the caller's choice so batch norm can be evaluated in either mode.
        test_accuracy = sess.run(accuracy, feed_dict={self.input_layer: mnist.test.images,
                                                      labels: mnist.test.labels,
                                                      self.is_training: test_training_accuracy})

        print("{}: The final accuracy on test data is {}".format(self.name, test_accuracy))
        return test_accuracy