# Train and Eval

In [1]:
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import tensorflow as tf

import model
from reader import Cifar10Reader

from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

In [2]:
EPOCH = 30
LEARNING_RATE = 0.001
data_dir = "./data/cifar-10-batches-bin/"
check_point_dir = "./check_point_dir/"
test_data = "./data/cifar-10-batches-bin/test_batch.bin"
graph_dir = "./graph_dir/"

In [3]:
def _loss(logits, label):
    labels = tf.cast(label, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits, labels, name="cross_entropy_per_example")
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name="cross_entropy")
    return cross_entropy_mean

def _train(total_loss, global_step):
    opt = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
    grads = opt.compute_gradients(total_loss)
    train_op = opt.apply_gradients(grads, global_step=global_step)
    return train_op

In [4]:
filenames = [
  os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)
  ]

In [5]:
def _eval(sess, top_k_op, input_image, label_placeholder):
    if not test_data:
        return np.nan
    
    image_reader = Cifar10Reader(test_data)
    true_count = 0
    for index in range(10000):
        image = image_reader.read(index)
        
        predictions = sess.run([top_k_op],
                              feed_dict={
                                  input_image: image.byte_array,
                                  label_placeholder: image.label
                              })
        true_count += np.sum(predictions)
    image_reader.close()
    
    return (true_count / 10000.0)

In [6]:
def _restore(saver, sess):
    checkpoint = tf.train.get_checkpoint_state(check_point_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        save.restore(sess, checkpoint.model_checkpoint_path)

def _export_graph(sess, epoch):
    constant_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["output/logits"])
    
    file_path = os.path.join(graph_dir, 'graph_%02d_epoch.pb' % epoch)
    with gfile.FastGFile(file_path, "wb") as f:
        f.write(constant_graph_def.SerializeToString())

# Start Train

In [7]:
global_step = tf.Variable(0, trainable=False)
with tf.device('/gpu:0'):
    train_placeholder = tf.placeholder(tf.float32, shape=[32, 32, 3],
                                   name="input_image")
    label_placeholder = tf.placeholder(tf.int32, shape=[1], name="label")

    image_node = tf.expand_dims(train_placeholder, 0)

    logits = model.inference(image_node)
    total_loss = _loss(logits, label_placeholder)
    train_op = _train(total_loss, global_step)

top_k_op = tf.nn.in_top_k(logits, label_placeholder, 1)

saver = tf.train.Saver(tf.all_variables())

Instructions for updating:
Please use tf.global_variables instead.


In [None]:
with tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(tf.initialize_all_variables())
    
    total_duration = 0
    
    _export_graph(sess, 0)
    
    for epoch in range(1, EPOCH + 1):
        start_time = time.time()
        
        for file_index in range(5):
            print("EPOCH %d: %s" % (epoch, filenames[file_index]))
            reader = Cifar10Reader(filenames[file_index])
            
            for index in range(10000):
                image = reader.read(index)
                
                _, loss_value = sess.run([train_op, total_loss],
                                        feed_dict={
                                            train_placeholder: image.byte_array,
                                            label_placeholder: image.label
                                        })
                assert not np.isnan(loss_value)
                
            reader.close()
        
    duration = time.time() - start_time
    total_duration += duration
    
    prediction = _eval(sess, top_k_op,
                       train_placeholder, label_placeholder)
    print("epoch %d duration = %d sec, prediction = %.3f"
          % (epoch, duration, prediction))
    
    tf.train.SummaryWriter(check_point_dir, sess.graph)
    sarver.save(sess, check_point_dir, global_step=epoch)
    _export_graph(sess, epoch)
    
print("Total duration = %d sec" % total_duration)

Instructions for updating:
Use `tf.global_variables_initializer` instead.
INFO:tensorflow:Froze 22 variables.
Converted 10 variables to const ops.
EPOCH 1: ./data/cifar-10-batches-bin/data_batch_1.bin
EPOCH 1: ./data/cifar-10-batches-bin/data_batch_2.bin
EPOCH 1: ./data/cifar-10-batches-bin/data_batch_3.bin
EPOCH 1: ./data/cifar-10-batches-bin/data_batch_4.bin
EPOCH 1: ./data/cifar-10-batches-bin/data_batch_5.bin
EPOCH 2: ./data/cifar-10-batches-bin/data_batch_1.bin
EPOCH 2: ./data/cifar-10-batches-bin/data_batch_2.bin
EPOCH 2: ./data/cifar-10-batches-bin/data_batch_3.bin
EPOCH 2: ./data/cifar-10-batches-bin/data_batch_4.bin
EPOCH 2: ./data/cifar-10-batches-bin/data_batch_5.bin
EPOCH 3: ./data/cifar-10-batches-bin/data_batch_1.bin
EPOCH 3: ./data/cifar-10-batches-bin/data_batch_2.bin
EPOCH 3: ./data/cifar-10-batches-bin/data_batch_3.bin
EPOCH 3: ./data/cifar-10-batches-bin/data_batch_4.bin
EPOCH 3: ./data/cifar-10-batches-bin/data_batch_5.bin
EPOCH 4: ./data/cifar-10-batches-bin/data_b