# Train and Eval

In [1]:
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import tensorflow as tf

import model
from reader import Cifar10Reader

from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile


In [2]:
# Number of training epochs; presumably consumed by a training loop that is
# not part of this view.
EPOCH = 30
# Step size for the gradient-descent optimizer built in _train.
LEARNING_RATE = 0.001
# Backward-compatible alias for the original misspelled name. _train reads
# LEARNING_RATE, so that correctly spelled constant must exist.
LEARNIONG_RATE = LEARNING_RATE
# Directory holding the CIFAR-10 binary training batches.
data_dir = "./data/cifar-10-batches-bin/"
# Directory searched for the latest checkpoint when restoring.
check_point_dir = "./check_point_dir/"
# Binary file read record-by-record during evaluation.
test_data = "./data/cifar-10-batches-bin/test_batch.bin"
# Output directory for frozen .pb graphs.
graph_dir = "./graph_dir/"

In [3]:
def _loss(logits, label):
    """Build the batch-mean sparse softmax cross-entropy loss.

    Args:
        logits: unscaled class scores; presumably shaped
            (batch, num_classes) -- TODO confirm against ``model``.
        label: integer class ids; cast to int64 before use.

    Returns:
        Scalar tensor named "cross_entropy" holding the mean loss.
    """
    labels = tf.cast(label, tf.int64)
    # Named arguments are required: TF 1.x raises ValueError on positional
    # use, and TF 2 orders the parameters (labels, logits), so positional
    # (logits, labels) would be silently swapped.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="cross_entropy_per_example")
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name="cross_entropy")
    return cross_entropy_mean

def _train(total_loss, global_step, learning_rate=None):
    """Create a plain gradient-descent training op for ``total_loss``.

    Args:
        total_loss: scalar loss tensor to minimize.
        global_step: step variable handed to ``apply_gradients``, which
            increments it after each update.
        learning_rate: optional step size; when omitted the module-level
            ``LEARNING_RATE`` is used, so existing callers are unaffected.

    Returns:
        The op that applies one gradient-descent update.
    """
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    opt = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    # Equivalent to opt.minimize(); the two-step form leaves room to inspect
    # or clip the (gradient, variable) pairs before applying them.
    grads = opt.compute_gradients(total_loss)
    train_op = opt.apply_gradients(grads, global_step=global_step)
    return train_op

In [4]:
# Paths of the five CIFAR-10 training batch files,
# data_batch_1.bin through data_batch_5.bin, inside data_dir.
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % batch_no)
             for batch_no in range(1, 6)]

In [5]:
def _eval(sess, top_k_op, input_image, label_placeholder, num_examples=10000):
    """Return the fraction of correct predictions over the test records.

    Records are read one at a time from ``test_data`` and fed through
    ``input_image`` / ``label_placeholder``. ``top_k_op`` is evaluated for
    each record, and its summed result is counted as correct predictions.

    Args:
        sess: session in which the graph is run.
        top_k_op: op whose evaluated value marks correct predictions
            (presumably a ``tf.nn.in_top_k`` boolean -- TODO confirm).
        input_image: placeholder fed with ``image.byte_array``.
        label_placeholder: placeholder fed with ``image.label``.
        num_examples: number of records to read, starting at index 0;
            defaults to 10000, the previously hard-coded count.

    Returns:
        ``true_count / num_examples`` as a float, or ``np.nan`` when
        ``test_data`` is unset or ``num_examples`` is not positive.
    """
    if not test_data or num_examples <= 0:
        return np.nan

    image_reader = Cifar10Reader(test_data)
    true_count = 0
    try:
        for index in range(num_examples):
            image = image_reader.read(index)
            predictions = sess.run([top_k_op],
                                   feed_dict={
                                       input_image: image.byte_array,
                                       label_placeholder: image.label,
                                   })
            true_count += np.sum(predictions)
    finally:
        # Release the reader even if a read or sess.run raises.
        image_reader.close()

    return true_count / float(num_examples)

In [None]:
def _restore(saver, sess):
    """Restore ``sess`` from the latest checkpoint in ``check_point_dir``.

    Args:
        saver: object whose ``restore(sess, path)`` loads the variables
            (presumably a ``tf.train.Saver``).
        sess: session whose variables are restored.

    Returns:
        True if a checkpoint was found and restored, False otherwise.
    """
    checkpoint = tf.train.get_checkpoint_state(check_point_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        # Must go through the ``saver`` argument; there is no ``save`` name.
        saver.restore(sess, checkpoint.model_checkpoint_path)
        return True
    return False

def _export_graph(sess, epoch, output_node_names=None):
    """Freeze the session's graph and write it under ``graph_dir``.

    Variables are converted to constants so the serialized GraphDef carries
    its weights. The file is named ``graph_%02d_epoch.pb`` formatted with
    ``epoch``.

    Args:
        sess: session holding the current variable values.
        epoch: integer epoch number used in the file name.
        output_node_names: names of the graph's result nodes; defaults to
            ``["output/logits"]``.

    Returns:
        Path of the written ``.pb`` file.
    """
    if output_node_names is None:
        output_node_names = ["output/logits"]
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, list(output_node_names))

    # Create the target directory on first use so the write cannot fail
    # just because graph_dir does not exist yet.
    if not os.path.isdir(graph_dir):
        os.makedirs(graph_dir)

    file_path = os.path.join(graph_dir, 'graph_%02d_epoch.pb' % epoch)
    # gfile provides ``GFile`` (and ``FastGFile``); there is no ``FastGfile``.
    with gfile.GFile(file_path, "wb") as f:
        f.write(constant_graph_def.SerializeToString())
    return file_path