# 通用代码框架

In [1]:
import tensorflow as tf

## 0. 必要参数定义

In [2]:
trainset_path = "./image/train/"
testset_path = "./image/test/"
trainset_record_path = "./output/training-images/"
validset_record_path = "./output/validing-images/"
testset_record_path = "./output/testing-images/"
log_dir = "./logdir/"
ckpt_path = "./checkpoint/"

image_width = 24
image_height = 24
image_channels = 3

#总共3类
num_breed = 3

## 1. 辅助函数定义

In [3]:
# 定义w,b初始化函数
def variable_with_weight_loss(shape, stddev, wl):
    var = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
    if wl is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return var

## 2. 通用框架函数定义

In [4]:
def inference(X):
    with tf.name_scope("input") as scope:
        #x = tf.placeholder(tf.float32,shape = [None, image_height, image_width, image_channels], name = "x_input")
        #y = tf.placeholder(tf.int64,shape = [X.shape[0]],name = "y_input")
        #x = image_batch
        #y = label_batch
        x = X
        tf.summary.image("image_input",X, batch_size)

    #卷积层一
    with tf.name_scope("conv1") as scope:
        weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], stddev=5e-2, wl=0.0)
        kernel1 = tf.nn.conv2d(x, weight1, [1, 1, 1, 1], padding='SAME')
        bias1 = tf.Variable(tf.constant(0.0, shape=[64]))
        conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bias1))
        pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
        norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
    
        tf.summary.histogram("weight1",weight1)
    
    #卷积层二
    with tf.name_scope("conv2") as scope:
        weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], stddev=5e-2, wl=0.0)
        kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding='SAME')
        bias2 = tf.Variable(tf.constant(0.1, shape=[64]))
        conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bias2))
        norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)
        pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
    
        tf.summary.histogram("weight2",weight2)

    #全连接层一
    with tf.name_scope("full_connect1") as scope:
        reshape = tf.reshape(pool2, [batch_size, -1])
        weight3 = variable_with_weight_loss(shape=[6*6*64, 384], stddev=0.04, wl=0.004)
        bias3 = tf.Variable(tf.constant(0.1, shape=[384]))
        local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bias3)
    
        tf.summary.histogram("weight3",weight3)
    
    #全连接层二
    with tf.name_scope("full_connect2") as scope:
        weight4 = variable_with_weight_loss(shape=[384, 192], stddev=0.04, wl=0.004)
        bias4 = tf.Variable(tf.constant(0.1, shape=[192]))                                      
        local4 = tf.nn.relu(tf.matmul(local3, weight4) + bias4)
    
        tf.summary.histogram("weight4",weight4)
    
    #全连接层三
    with tf.name_scope("full_connect3") as scope:
        weight5 = variable_with_weight_loss(shape=[192, 3], stddev=1/192.0, wl=0.0)
        bias5 = tf.Variable(tf.constant(0.0, shape=[3]))
        logits = tf.add(tf.matmul(local4, weight5), bias5)
    
        tf.summary.histogram("weight5",weight5)
        # softmax处理
        #Y_ = tf.nn.softmax(logits)
        return logits

def loss(X, Y):
    logits = inference(X)
    with tf.name_scope("cross_entropy") as scope:
        #定义训练代价函数
        cross_entropy = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits = logits, labels = Y))
        tf.summary.scalar("cross_entropy",cross_entropy)
        
    return cross_entropy

def inputs(data_dir, batch_size, shuffle = True):
    filenames =tf.train.match_filenames_once(data_dir+"*.tfrecords")
    filenames_queue = tf.train.string_input_producer(filenames, shuffle = True)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filenames_queue)
    features = tf.parse_single_example(
        serialized_example,
        features = {
            "label":tf.FixedLenFeature([],tf.int64),
            "image_raw":tf.FixedLenFeature([],tf.string)
        }
    )

    # 解码图像数据
    im = tf.decode_raw(features["image_raw"],tf.uint8) 
    reshape = tf.reshape(im,(image_height,image_width,image_channels))
    image = tf.cast(reshape,tf.float32)
    label = tf.cast(features["label"],tf.int64)


    # 组合训练数据
    min_after_dequeue = 100*batch_size
    capacity = min_after_dequeue + 3*batch_size

    image_batch, label_batch = tf.train.shuffle_batch(
    (image,label),batch_size = batch_size,
    capacity = capacity,min_after_dequeue = min_after_dequeue
    )
    return image_batch, label_batch


def train(total_loss):
    # 定义梯度优化算法
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(total_loss, global_step = global_step)
    return train_op

def evaluate(sess, X, Y):
    Y_ = inference(X)
    correct_prediction = tf.equal(tf.argmax(Y_,1), tf.argmax(Y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype = tf.float32))
    
    return sess.run([accuracy])

## 3. 定义数据流图

In [5]:
batch_size = 10
X, Y = inputs(trainset_record_path, batch_size)
total_loss = loss(X, Y)
train_op = train(total_loss)

# 须在数据流图定义完后在merge_all
merged = tf.summary.merge_all()

#在迭代控制中,记得添加tf.initialize_local_variables(),官网教程没有说明,但是如过不加，会出错
init = [tf.global_variables_initializer(),tf.local_variables_initializer()]

## 4. 训练

In [6]:
num_iter = 4000
init_step = 0
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
        
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess = sess, coord = coord) 
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)    
    
     ### 从检查点恢复训练
    init_step = 0
    # 验证之前是否已经保存了检查点文件
    ckpt = tf.train.get_checkpoint_state(ckpt_path)
    if ckpt and ckpt.model_checkpoint_path:
        # 从检查点中恢复模型参数
        saver.restore(sess,ckpt.model_checkpoint_path)
        init_step = int(ckpt.model_checkpoint_path.rsplit("-",1)[1])
    
    #训练    
    print("start training......\n")
    for step in range(init_step, num_iter):
        sess.run(train_op)
        #print(step)
        # 每隔一定迭代步数参看loss情况,并保存检查点
        if step % 100 == 0:
            #loss_value  = sess.run([total_loss])
            summary, loss_value  = sess.run([merged, total_loss])
            summary_writer.add_summary(summary, step)
            print(step,loss_value)
        
        # 每隔3000 步保存一次
        if step % 3000 == 0:
            saver.save(sess,ckpt_path+ "my-model",global_step = step)
    
    saver.save(sess, ckpt_path+ "my-model" ,global_step = num_iter)
    summary_writer.close()
    
    coord.request_stop()
    coord.join(threads)
    
    print("finished.\n")
    

start training......

0 1.40794
100 0.835881
200 0.641825
300 0.333636
400 0.244037
500 0.064397
600 0.00418409
700 0.00477062
800 0.00688567
900 0.00301323
1000 0.00152629
1100 0.00108007
1200 0.0013108
1300 0.000406786
1400 0.000935769
1500 0.000275467
1600 0.000725677
1700 0.000449241
1800 0.000287554
1900 0.00019782
2000 0.000185532
2100 0.000185026
2200 0.000173659
2300 0.000104061
2400 0.000114449
2500 8.12437e-05
2600 5.23779e-05
2700 7.05063e-05
2800 0.0001757
2900 5.18061e-05
3000 7.84466e-05
3100 1.97883e-05
3200 7.54767e-05
3300 3.89909e-05
3400 2.77033e-05
3500 2.4008e-05
3600 9.91807e-06
3700 2.49972e-05
3800 1.92634e-05
3900 8.72602e-06
finished.

