In [89]:
# stretching the cells
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# cuda settings
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)

In [90]:
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## Reading then tfrecords

In [91]:
def read_training_data(data_path='train.tfrecords'):
    data_path = "train.tfrecords"
    feature = {'train/image':tf.FixedLenFeature([],tf.string), 'train/label':tf.FixedLenFeature([],tf.int64)}
    filename_queue = tf.train.string_input_producer([data_path],num_epochs=10)

    reader = tf.TFRecordReader()
    _,serialzed_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialzed_example,features=feature)

    image = tf.decode_raw(features["train/image"],tf.float32)
    label = tf.cast(features["train/label"],tf.int32)

    image = tf.reshape(image,[224,224,3])
    label = tf.one_hot(label,depth=8)
    images,labels = tf.train.shuffle_batch([image,label],batch_size=100,capacity=100,num_threads=1,min_after_dequeue=10)

    return (images,labels)
    
# init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
# with tf.Session() as sess:
#     sess.run(init_op)

#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(coord=coord)

#     for i in range(1):
#         img,lbl = sess.run([images,labels])
#         print(img.shape,lbl.shape)        
#     coord.request_stop()

#     coord.join(threads)
#     sess.close()

        
# with tf.Session() as sess:
#     images,labels = read_training_data()
#     init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())  
#     sess.run(init_op)
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(coord=coord)
#     for i in range(1):
#         i,l = sess.run([images,labels])
#         i = i.astype(np.uint8)
#         print(i.shape,l.shape)
#         plt.imshow(i[99])
#         plt.title(l[99])
#         plt.show()
        

#     coord.request_stop()

#     coord.join(threads)
#     sess.close()


In [92]:
def read_validation_data(data_path='valid.tfrecords'):
    data_path = "valid.tfrecords"
    feature = {'valid/image':tf.FixedLenFeature([],tf.string), 'valid/label':tf.FixedLenFeature([],tf.int64)}
    filename_queue = tf.train.string_input_producer([data_path],num_epochs=1)

    reader = tf.TFRecordReader()
    _,serialzed_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialzed_example,features=feature)

    image = tf.decode_raw(features["valid/image"],tf.float32)
    label = tf.cast(features["valid/label"],tf.int32)

    image = tf.reshape(image,[224,224,3])
    label = tf.one_hot(label,depth=8)
    images,labels = tf.train.shuffle_batch([image,label],batch_size=100,capacity=100,num_threads=1,min_after_dequeue=10)

    return (images,labels)
    

## Creating the default graph

In [93]:
tf.reset_default_graph()

print("Loading default graph.....")
graph = tf.Graph()
with graph.as_default():
    
    model_path = "./vgg_16.ckpt"
    assert(os.path.isfile(model_path))

    vgg = slim.nets.vgg
    image_size = vgg.vgg_16.default_image_size
    num_classes = 8
    
    x = tf.placeholder(tf.float32,shape=(None,image_size,image_size,3))
    y = tf.placeholder(tf.int32,shape=(None,num_classes))
    with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0001)):
        logits,end_points = vgg.vgg_16(x,num_classes=num_classes,is_training=True)
        


    variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
    init_fn = slim.assign_from_checkpoint_fn(model_path,variables_to_restore)

    fc8_variables = slim.get_variables('vgg_16/fc8')
    fc8_init = tf.variables_initializer(fc8_variables)

    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,labels=y))

    fc8_optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
    fc8_train_op = fc8_optimizer.minimize(loss,var_list=fc8_variables)

    full_optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.0001)
    full_train_op = full_optimizer.minimize(loss)

    correct_prediction = tf.equal(tf.argmax(logits,1),tf.argmax(y,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    # actual = tf.argmax(y,1)
    prediction = tf.argmax(logits,1)
        
    # tf.get_default_graph().finalize()  
    
    print("\nFinalized the graph ... !")


Loading default graph.....

Finalized the graph ... !


In [94]:
assert logits.graph == graph

# Training

In [95]:
with tf.Session(graph=graph,config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    images,labels = read_training_data()
    valid_images,valid_labels = read_validation_data()
    
    init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())  
    sess.run(init_op)
    
    init_fn(sess)
    sess.run(fc8_init)
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    print("Started Training ... !")

    epoch=0
    try:
        while not coord.should_stop():
            img,lbl = sess.run([images,labels])
            img = img.astype(np.uint8)
            # print("Next batch : {}".format(img.shape,lbl.shape))
            sess.run(fc8_train_op,feed_dict={x:img,y:lbl})
            print("Loss : {:.2f}".format(sess.run(loss,feed_dict={x:img,y:lbl})))
            epoch+=1
    except tf.errors.OutOfRangeError as e:
        coord.request_stop(e)
        
    val_img,val_lbl = sess.run([valid_images,valid_labels])
    val_img = val_img.astype(np.uint8)
    print("Validation Accuracy after retraining for {} epoch : {:.2f}".format(epoch,sess.run(accuracy,feed_dict={x:val_img,y:val_lbl})))
    print(len(sess.run(prediction,feed_dict={x:val_img})))


#     try:
#         while not coord.should_stop():
#             img,lbl = sess.run([images,labels])
#             print("Next batch : {}".format(img.shape,lbl.shape))
#             sess.run(full_train_op)
#     except tf.errors.OutOfRangeError as e:
#         coord.request_stop(e)
#     val_img,val_lbl = sess.run([valid_images,valid_labels])
#     print("Validation accuracy after fune tuning : {}".format(sess.run(accuracy)))  

    coord.request_stop()
    coord.join(threads)
    sess.close()


INFO:tensorflow:Restoring parameters from ./vgg_16.ckpt
Started Training ... !
Loss : 1472.16
Loss : 1259.70
Loss : 2760.25
Loss : 2103.44
Loss : 2394.38
Loss : 3645.92
Loss : 3833.62
Loss : 1626.60
Loss : 1766.15
Loss : 974.02
Loss : 1511.09
Loss : 1259.14
Loss : 519.53
Loss : 882.81
Loss : 314.82
Loss : 763.79
Loss : 878.60
Loss : 761.53
Loss : 654.65
Loss : 445.76
Loss : 675.90
Loss : 180.58
Loss : 530.42
Loss : 668.54
Loss : 990.98
Loss : 563.93
Loss : 968.18
Loss : 185.87
Loss : 433.65
Loss : 358.65
Loss : 57.59
Loss : 566.99
Loss : 434.27
Loss : 244.52
Loss : 922.28
Loss : 46.98
Loss : 405.30
Loss : 325.35
Loss : 82.46
Loss : 710.17
Loss : 383.95
Loss : 165.17
Loss : 501.37
Loss : 228.52
Loss : 215.53
Loss : 344.00
Loss : 125.06
Loss : 537.87
Loss : 256.54
Loss : 204.62
Loss : 211.54
Loss : 2.39
Loss : 277.20
Loss : 244.27
Loss : 0.23
Loss : 323.13
Loss : 506.49
Loss : 159.35
Loss : 582.47
Loss : 44.92
Loss : 130.67
Loss : 190.47
Loss : 129.25
Loss : 241.10
Loss : 336.43
Loss : 1

# Experiment