In [1]:
import tensorflow as tf
from tensorflow.contrib.slim.nets import resnet_v2
from datagenerator import ImageDataGenerator
from tensorflow.contrib.data import Iterator
import numpy as np
from datetime import datetime
import time
import os
# Number GPUs by PCI bus ID and expose only device 0 to this process.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="0")

  from ._conv import register_converters as _register_converters


In [2]:
# Experiment configuration: data-list files, hyper-parameters and output paths.
# Output directory names encode the run settings, so they are derived from the
# values below rather than typed by hand.
num = 0              # index of the train/valid split (selects train<num>/valid<num> lists)
date = "20180808"    # run tag used in the output directory names

# Text files listing the training and validation examples
data_dir = '/HS_code/0_Code'
train_file = '%s/train%d_HR_ICROS.txt' % (data_dir, num)
val_file1 = '%s/valid%d_HR_ICROS.txt' % (data_dir, num)

batch_size = 128
num_classes = 2
learning_rate = 0.001
training_epochs = 50
display_step = 20    # write a TensorBoard summary every `display_step` training steps
layer_num = 50       # ResNet-v2 depth recorded in the output paths; must match the network built

# Learning-rate tag for directory names (0.001 -> "0_001"), derived from
# `learning_rate` so the path can never disagree with the value actually used.
lr_tag = str(learning_rate).replace('.', '_')
run_name = "HS_%s_ResNet_%d_hr%d_valid%d_%depoch_%s" % (
    date, layer_num, num, num, training_epochs, lr_tag)

# Path for tf.summary.FileWriter and to store model checkpoints
filewriter_path = "/HS_code/2_Result_TB/tensorboard_" + run_name
checkpoint_path = "/HS_code/1_Model_CP/checkpoints_" + run_name

In [3]:
# Build the graph: input pipeline, ResNet-v2 classifier, loss / train / accuracy
# ops, TensorBoard summaries and the checkpoint saver.
# Start from an empty default graph so re-running this cell does not try to
# re-create the ResNet variables (slim would raise "Variable ... already exists").
tf.reset_default_graph()

with tf.device('/gpu:0'):
    # Data loaders (project-local ImageDataGenerator; `.data` is the dataset fed
    # to the iterator, `.data_size` is used below to count batches).
    tr_data = ImageDataGenerator(train_file,
                                 mode='training',
                                 batch_size=batch_size,
                                 num_classes=num_classes,
                                 shuffle=True)
    val_data1 = ImageDataGenerator(val_file1,
                                   mode='inference',
                                   batch_size=batch_size,
                                   num_classes=num_classes,
                                   shuffle=False)

    # A single reinitializable iterator serves both datasets; running one of
    # the init ops below switches which dataset `next_batch` reads from.
    iterator = Iterator.from_structure(tr_data.data.output_types,
                                       tr_data.data.output_shapes)
    next_batch = iterator.get_next()
    training_init_op = iterator.make_initializer(tr_data.data)
    validation_init_op1 = iterator.make_initializer(val_data1.data)

    # Graph inputs. The batch dimension is left open so batches smaller than
    # `batch_size` can also be fed.
    x = tf.placeholder(tf.float32, [None, 227, 227, 3])
    y = tf.placeholder(tf.float32, [None, num_classes])

    # Batch-norm mode. Defaults to True (batch statistics, as before); feed
    # False to evaluate with the learned moving averages instead.
    is_training = tf.placeholder_with_default(True, shape=(), name='is_training')

    # Build the ResNet-v2 variant matching `layer_num` (50 -> resnet_v2_50).
    resnet_fn = getattr(resnet_v2, 'resnet_v2_%d' % layer_num, None)
    if resnet_fn is None:
        raise ValueError('Unsupported ResNet-v2 depth: %d' % layer_num)
    net, net_points = resnet_fn(x, num_classes=num_classes, is_training=is_training)

    # `net` is [batch, 1, 1, num_classes] or [batch, num_classes] depending on
    # the slim version's `spatial_squeeze` default; flatten once to 2-D logits
    # so the loss and the accuracy are computed from the same tensor.
    logits = tf.reshape(net, [-1, num_classes])

    # Op for calculating the loss
    with tf.name_scope("cross_ent"):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y))
    cost = loss  # alias kept for code that still refers to `cost`

    # Train op. slim's batch_norm registers its moving-mean/variance updates in
    # UPDATE_OPS; they only run when the train op depends on them. Without this
    # the moving statistics saved in the checkpoint never leave their initial values.
    with tf.name_scope("train"):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = tf.train.AdamOptimizer(
                learning_rate=learning_rate).minimize(loss)

    with tf.name_scope("accuracy"):
        correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # TensorBoard scalars, merged into a single op
    tf.summary.scalar('cross_entropy', loss)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)

    # Saver for model checkpoints (created after all variables exist)
    saver = tf.train.Saver()

    # Whole batches only; any remainder smaller than `batch_size` is skipped.
    train_batches_per_epoch = int(tr_data.data_size // batch_size)
    val_batches_per_epoch1 = int(val_data1.data_size // batch_size)
    print(train_batches_per_epoch)

Instructions for updating:
Use `tf.data.Dataset.from_tensor_slices()`.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
28


In [4]:
# Training loop: `training_epochs` passes over the training set, TensorBoard
# summaries every `display_step` steps, a validation pass after every epoch
# and a checkpoint after the last epoch.
# log_device_placement stays off: it logs the placement of every op in the
# graph, which is debugging output rather than training output.
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
config.gpu_options.allow_growth = True
with tf.Session(config=config, graph=tf.get_default_graph()) as sess:
    sess.run(tf.global_variables_initializer())

    # Add the model graph to TensorBoard
    writer.add_graph(sess.graph)

    print("{} Start training...".format(datetime.now()))
    print("{} Open tensorboard --logdir={}".format(datetime.now(),
                                                      filewriter_path))

    for epoch in range(training_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch+1))

        # Point the shared iterator at the training dataset
        sess.run(training_init_op)

        for step in range(train_batches_per_epoch):
            img_batch, label_batch = sess.run(next_batch)
            feed = {x: img_batch, y: label_batch}

            if step % display_step == 0:
                # Fetch the summary in the same run as the update instead of a
                # second forward pass; it reflects the weights before this step.
                _, summary = sess.run([train_op, merged_summary], feed_dict=feed)
                writer.add_summary(summary, epoch*train_batches_per_epoch + step)
                print("{} {} step".format(datetime.now(), step))
            else:
                sess.run(train_op, feed_dict=feed)

        # Make this epoch's summaries visible in TensorBoard right away
        writer.flush()

        # Validate the model on every full batch of the validation set
        print("{} Start validation".format(datetime.now()))
        sess.run(validation_init_op1)
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch1):
            img_batch, label_batch = sess.run(next_batch)
            test_acc += sess.run(accuracy, feed_dict={x: img_batch,
                                                      y: label_batch})
            test_count += 1

        if test_count:
            test_acc /= test_count
            print("{} Validation Accuracy = {:.4f}".format(datetime.now(),
                                                           test_acc))
        else:
            # No full validation batch: report instead of dividing by zero.
            print("{} Validation skipped: no full validation batch".format(
                datetime.now()))

    if training_epochs > 0:
        print("{} Saving checkpoint of model...".format(datetime.now()))
        # Saver.save fails when the target directory does not exist, which
        # would only surface after the whole training run.
        if not os.path.isdir(checkpoint_path):
            os.makedirs(checkpoint_path)
        checkpoint_name = os.path.join(checkpoint_path,
                                       'model_epoch'+str(training_epochs)+'.ckpt')
        save_path = saver.save(sess, checkpoint_name)
        print("{} Model checkpoint saved at {}".format(datetime.now(),
                                                       save_path))

2018-08-09 06:31:53.731606 Start training...
2018-08-09 06:31:53.732305 Open tensorboard --logdir=/HS_code/2_Result_TB/tensorboard_HS_20180808_ResNet_50_hr0_valid0_50epoch_0_001
2018-08-09 06:31:53.732677 Epoch number: 1
2018-08-09 06:31:58.536603 0 step
2018-08-09 06:32:11.553163 20 step
2018-08-09 06:32:16.061356 Start validation
2018-08-09 06:32:18.256695 Validation Accuracy = 0.7009
2018-08-09 06:32:18.256882 Epoch number: 2
2018-08-09 06:32:20.824613 0 step
2018-08-09 06:32:34.028343 20 step
2018-08-09 06:32:38.568022 Start validation
2018-08-09 06:32:40.656147 Validation Accuracy = 0.7734
2018-08-09 06:32:40.656413 Epoch number: 3
2018-08-09 06:32:43.200503 0 step
2018-08-09 06:32:56.394551 20 step
2018-08-09 06:33:00.947538 Start validation
2018-08-09 06:33:03.003923 Validation Accuracy = 0.7612
2018-08-09 06:33:03.004047 Epoch number: 4
2018-08-09 06:33:05.547037 0 step
2018-08-09 06:33:18.800560 20 step
2018-08-09 06:33:23.366566 Start validation
2018-08-09 06:33:25.440888 Val

2018-08-09 06:47:06.866175 Validation Accuracy = 0.9732
2018-08-09 06:47:06.866375 Epoch number: 39
2018-08-09 06:47:09.703554 0 step
2018-08-09 06:47:23.770062 20 step
2018-08-09 06:47:28.690393 Start validation
2018-08-09 06:47:31.472212 Validation Accuracy = 0.9699
2018-08-09 06:47:31.472361 Epoch number: 40
2018-08-09 06:47:34.334978 0 step
2018-08-09 06:47:48.581697 20 step
2018-08-09 06:47:53.370656 Start validation
2018-08-09 06:47:56.363277 Validation Accuracy = 0.9688
2018-08-09 06:47:56.363469 Epoch number: 41
2018-08-09 06:47:59.201882 0 step
2018-08-09 06:48:13.462881 20 step
2018-08-09 06:48:18.456973 Start validation
2018-08-09 06:48:21.106845 Validation Accuracy = 0.9732
2018-08-09 06:48:21.106967 Epoch number: 42
2018-08-09 06:48:23.936907 0 step
2018-08-09 06:48:38.362394 20 step
2018-08-09 06:48:43.121487 Start validation
2018-08-09 06:48:45.822987 Validation Accuracy = 0.9777
2018-08-09 06:48:45.823100 Epoch number: 43
2018-08-09 06:48:48.636076 0 step
2018-08-09 06: