保存/加载参数 - save/load model of TensorFlow
====
>MNIST数据集<br>


In [1]:
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange

In [2]:
# dataset
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("../MNIST_data/")

import tensorflow as tf
import os
import numpy as np

# default graph
sess = tf.InteractiveSession()

# hyper parameters
input_size = 784
num_classes = 10

# placeholder
x = tf.placeholder(tf.float32, [None, input_size], name='placeholder_x')
y = tf.placeholder(tf.int64, [None, ], name='placeholder_y')

# model
# hidden layer
fc1 = tf.layers.dense(x, units=128, activation=tf.nn.relu)
y_ = tf.layers.dense(fc1, units=num_classes)
tf.add_to_collection('logits', y_)

# metrics
loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y_, labels=y)
                         , axis=-1)
tf.add_to_collection('loss_op', loss_op)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = optimizer.minimize(loss_op)
tf.add_to_collection('train_op', train_op)
accuracy_op = tf.reduce_mean(tf.cast(tf.equal(y, tf.argmax(y_, 1)), dtype=tf.float32))
tf.add_to_collection('accuracy_op', accuracy_op)

# initialize global graph
tf.global_variables_initializer().run()

# evaluate
print(accuracy_op.eval({x: mnist.test.images, y: mnist.test.labels}))


def mkdir(filename):
    if os.path.exists(filename):
        pass
    else:
        os.mkdir(filename)

  from ._conv import register_converters as _register_converters


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
0.082


In [3]:
# training
for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    train_op.run({x: batch_xs, y: batch_ys})

# evaluate
print(accuracy_op.eval({x: mnist.test.images, y: mnist.test.labels}))

0.9561


## 方法一
利用`TensorFlow`自带的`save`和`restore`，保存和加载模型。<br>
该方式加载模型需要写出模型代码。

In [4]:
def load(sess, filename='model1'):
    # load
    filename = './%s/model.ckpt' % (filename)
    saver = tf.train.Saver()
    saver.restore(sess, filename)


def save(sess, filename='model1'):
    # save
    mkdir(filename)
    filename = './%s/model.ckpt' % (filename)
    # 导入网络参数
    # Saver参数为空，表示导入全部；若是参数不足，则默认从第0个元素开始倒入，直至参数完成
    # saver = tf.train.Saver(tf.all_variables())
    saver = tf.train.Saver()
    saver.save(sess, filename)

save(sess)
# load(sess)

## 方法二
利用`Numpy`保存和导入参数<br>
该模型的速度非常慢，不建议使用。但是可以将模型的数值保存为numpy格式，便于其他用途

In [5]:
def load_np(filename='model2'):
    # load
    for variables in tf.trainable_variables():
        variables.load(value=np.load(os.path.join(filename, variables.name.replace('/', '_') + '.npy')))


def save_np(filename='model2'):
    # save
    mkdir(filename)
    for variables in tf.trainable_variables():
        np.save(file=os.path.join(filename, variables.name.replace('/', '_') + '.npy'),
                arr=variables.eval())

save_np()
# load_np()

## 方法三
利用`TensorFlow`自带的`save`和`restore`，保存和加载图（grarph）。<br>
该方式加载模型不需要写出代码，直接使用即可

In [6]:
def load_graph(sess, filename='model3', global_step=None):
    # load
    saver = tf.train.import_meta_graph('./%s/model-%s.meta' % (filename, global_step))
    saver.restore(sess, './%s/model-%s' % (filename, global_step))


def save_graph(sess, filename='model3', global_step=None):
    # save
    mkdir(filename)
    saver = tf.train.Saver()
    filename = './%s/model' % filename
    saver.save(sess, filename, global_step)

save_graph(sess, global_step=100)
# load_graph(sess, global_step=100)

## 方法三加载模型
在另一脚本中写入一下模型运行

In [7]:
with tf.Session() as sess:
    load_graph(sess, global_step=100)
    graph = tf.get_default_graph()

    x = graph.get_operation_by_name('placeholder_x').outputs[0]
    y = graph.get_operation_by_name('placeholder_y').outputs[0]

    logits = tf.get_collection('logits')[0]
    loss_op = tf.get_collection('loss_op')[0]
    train_op = tf.get_collection('train_op')[0]
    accuracy_op = tf.get_collection('accuracy_op')[0]

    feed_dict = {x: mnist.test.images, y: mnist.test.labels}
    print('before re-training', accuracy_op.eval(feed_dict=feed_dict))

    # continue training
    for _ in xrange(100):
        batch_images, batch_labels = mnist.test.next_batch(batch_size=100)
        feed_dict = {x: batch_images, y: batch_labels}
        sess.run(train_op, feed_dict=feed_dict)

    feed_dict = {x: mnist.test.images, y: mnist.test.labels}
    print('after re-training', accuracy_op.eval(feed_dict=feed_dict))


INFO:tensorflow:Restoring parameters from ./model3/model-100
before re-training 0.9561
after re-training 0.9688
