# TRAIN

In [1]:
# coding: UTF-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os 

import numpy as np

from reader import Cifar10Reader

In [2]:
import tensorflow as tf
import os 
import time
import model

# Epoch setting read by the training loop.
EPOCH = 30
# Directory that holds the CIFAR-10 binary batch files (data_batch_N.bin).
data_dir = "./data/cifar-10-batches-bin/"
# Directory handed to the summary writer in the training cell.
checkpoint_dir = "./checkpoint_dir/"

In [3]:
# Paths of the five training batch files, data_batch_1.bin .. data_batch_5.bin,
# all located under data_dir.
filenames = []
for batch_no in range(1, 6):
    batch_name = 'data_batch_%d.bin' % batch_no
    filenames.append(os.path.join(data_dir, batch_name))

In [4]:
# Learning rate passed to the gradient-descent optimizer created in _train.
LEARNING_RATE = 0.001

def _loss(logits, label):
    """Build the mean sparse softmax cross-entropy loss.

    Args:
        logits: Unscaled class scores returned by the model
            (presumably shaped [batch, num_classes] -- confirm against
            model.inference).
        label: Integer class-index tensor; it is cast to int64 before use.

    Returns:
        A tensor named 'cross_entropy' holding the mean of the per-example
        cross-entropy values.
    """
    labels = tf.cast(label, tf.int64)
    # Pass labels/logits by keyword rather than by position: the positional
    # order of this op's arguments is not stable across TensorFlow releases
    # (TF >= 1.0 rejects positional calls outright), while the keyword names
    # are accepted by both the older and newer signatures.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    return tf.reduce_mean(cross_entropy, name='cross_entropy')


def _train(total_loss, global_step):
    """Create the op that performs one gradient-descent update.

    Args:
        total_loss: Loss tensor to minimise.
        global_step: Variable handed to ``apply_gradients`` as the step
            counter.

    Returns:
        The op produced by ``apply_gradients``; running it applies the
        gradients of ``total_loss`` using a learning rate of LEARNING_RATE.
    """
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
    grads_and_vars = optimizer.compute_gradients(total_loss)
    return optimizer.apply_gradients(grads_and_vars, global_step=global_step)

In [5]:
# Step counter handed to the optimizer; marked non-trainable so it is not
# itself updated by gradients.
global_step = tf.Variable(0, trainable=False)

with tf.device('/gpu:0'):
    # One 32x32x3 float image is fed per training step.
    train_placeholder = tf.placeholder(
        tf.float32, shape=[32, 32, 3], name="input_image")

    # One int32 label is fed alongside the image.
    label_placeholder = tf.placeholder(
        tf.int32, shape=[1], name="label")

    # Insert a leading axis of size 1 so the single image forms a batch
    # before it is passed to the model.
    image_node = tf.expand_dims(train_placeholder, 0)

    # Forward pass, loss and update op.
    logits = model.inference(image_node)
    total_loss = _loss(logits, label_placeholder)
    train_op = _train(total_loss, global_step)

In [6]:
# Number of records read from each batch file (indices 0..9999).
RECORDS_PER_FILE = 10000

# allow_soft_placement lets TensorFlow fall back to another device when an op
# cannot be placed on the '/gpu:0' device requested by the graph cell (no GPU
# present, or no GPU kernel for the op) instead of raising an error.
session_config = tf.ConfigProto(log_device_placement=True,
                                allow_soft_placement=True)

with tf.Session(config=session_config) as sess:
    # tf.initialize_all_variables() is deprecated in favour of this call.
    sess.run(tf.global_variables_initializer())

    # Write the graph definition once, up front, rather than creating a new
    # (never closed) writer after every batch file.
    summary_writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)

    total_duration = 0

    try:
        # range(1, EPOCH + 1) so that exactly EPOCH epochs are run, numbered
        # from 1 for the log messages.
        for epoch in range(1, EPOCH + 1):
            start_time = time.time()

            for filename in filenames:
                print("Epoch %d: %s" % (epoch, filename))
                reader = Cifar10Reader(filename)

                # Close the reader once per file, after all of its records
                # have been consumed (or on error), not after every record.
                try:
                    for index in range(RECORDS_PER_FILE):
                        image = reader.read(index)

                        _, loss_value, logits_value = sess.run(
                            [train_op, total_loss, logits],
                            feed_dict={
                                train_placeholder: image.byte_array,
                                label_placeholder: image.label
                                })
                        assert not np.isnan(loss_value), \
                            "Model diverged with loss = NaN"

                        # Periodic progress report: label and raw logits.
                        if index % 1000 == 0:
                            print("[%d]: %r" %(image.label, logits_value))
                finally:
                    reader.close()

            # Measure each epoch exactly once so the running total is not
            # inflated by repeatedly adding the cumulative epoch time.
            duration = time.time() - start_time
            total_duration += duration

            print('epoch %d duration = %d sec' % (epoch, duration))
            print('Total duration = %d sec' % total_duration)
    finally:
        # Flush and release the event file even if training is interrupted.
        summary_writer.close()

Instructions for updating:
Use `tf.global_variables_initializer` instead.
Epoch 1: ./data/cifar-10-batches-bin/data_batch_1.bin
[6]: array([[ 0.01269201,  0.00350406,  0.00831947,  0.01605432,  0.00153285,
         0.00880676, -0.00178913,  0.02524121, -0.01442572, -0.00948294]], dtype=float32)
[9]: array([[ 0.34206891,  0.59747398, -0.23578905,  0.63309163, -1.18035305,
         0.92095989, -1.10292315, -0.03301514,  0.60466611, -0.50838846]], dtype=float32)
[7]: array([[ 0.24253657, -0.9067204 ,  0.66818041,  0.11894067,  0.23858941,
         0.5085429 , -0.14684136, -0.13119869, -0.05172829, -0.43493044]], dtype=float32)
[3]: array([[-1.50766766, -0.83135581,  1.10887206,  0.70167381,  1.04970694,
        -0.19736518,  1.98130405,  0.18806779, -1.56505501, -0.49789575]], dtype=float32)
[5]: array([[ 0.97052425,  0.0746243 ,  0.0536098 , -0.66864258, -0.16027038,
        -0.05318333, -1.10253489,  0.167454  ,  0.76967359,  0.01122805]], dtype=float32)
[6]: array([[-0.83200151, -0.712

KeyboardInterrupt: 