# Train CNN to predict bounding boxes

## Load packages

In [2]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import BB_CNN

## Load data

In [3]:
def _shuffled_record_stream(record_path, buffer_size=5000):
    """Open a TFRecord file as a shuffled dataset with an initializable iterator.

    Returns a ``(dataset, iterator, next_element)`` triple, where
    ``next_element`` is the tensor that yields one serialized record per
    ``sess.run`` once ``iterator.initializer`` has been run.
    """
    records = tf.data.TFRecordDataset(record_path).shuffle(buffer_size=buffer_size)
    record_iterator = records.make_initializable_iterator()
    return records, record_iterator, record_iterator.get_next()


# Training records first, then validation records (same construction order
# as before, so the ops are added to the graph in the same sequence).
dataset_train, iterator_train, next_element_train = _shuffled_record_stream(
    'train_with_negative.record')

dataset_val, iterator_val, next_element_val = _shuffled_record_stream(
    'val_with_negative.record')

## Train CNN

In [4]:
# Spatial size of the network input.
image_height = 224
image_width = 224

# Every image is first resized to (resize_height, image_width) and then
# centrally cropped/padded to (image_height, image_width).
# NOTE(review): 126x224 presumably matches 16:9 source frames -- confirm.
resize_height = 126

num_iter = 1000   # number of single-example training steps
val_every = 100   # run validation after this many training steps
num_val = 10      # number of validation examples averaged per report

# Lower bound applied to a box's width/height before taking the log, so a
# degenerate (zero or negative extent) box cannot feed -inf/NaN targets
# into the optimizer.
min_box_extent = 1e-6

feature_with_bb = {'image/encoded': tf.FixedLenFeature([], tf.string),
                   'image/object/bbox/xmin': tf.FixedLenFeature([], tf.float32),
                   'image/object/bbox/xmax': tf.FixedLenFeature([], tf.float32),
                   'image/object/bbox/ymin': tf.FixedLenFeature([], tf.float32),
                   'image/object/bbox/ymax': tf.FixedLenFeature([], tf.float32)}

feature_without_bb = {'image/encoded': tf.FixedLenFeature([], tf.string)}


def _next_record(sess, iterator, next_element):
    """Fetch one serialized record, restarting the iterator when it runs dry.

    If the dataset is genuinely empty the second fetch raises
    ``tf.errors.OutOfRangeError`` again and that error propagates.
    """
    try:
        return sess.run(next_element)
    except tf.errors.OutOfRangeError:
        sess.run(iterator.initializer)
        return sess.run(next_element)


def _make_example_loader(sess):
    """Build the parsing/preprocessing ops once and return a loader function.

    Creating these ops a single time (instead of on every loop iteration)
    keeps the graph at a fixed size for the whole run.

    The returned ``load_example(iterator, next_element)`` yields
    ``(x, prob, box)`` where ``x`` is a ``(1, image_height, image_width, 3)``
    float array scaled to [0, 1], ``prob`` is ``[1]`` for a record with a
    bounding box and ``[0]`` otherwise, and ``box`` is
    ``[[xmin, ymin, log(width), log(height)]]``.
    """
    serialized = tf.placeholder(tf.string, [])
    parsed_with_bb = tf.parse_single_example(serialized, features=feature_with_bb)
    parsed_without_bb = tf.parse_single_example(serialized, features=feature_without_bb)

    # channels=3 forces RGB output so grayscale JPEGs still fit the
    # 3-channel network input.
    pixels = tf.image.decode_jpeg(parsed_without_bb['image/encoded'], channels=3)
    pixels = tf.cast(pixels, tf.float32) / 255.

    # VGG16 only
    pixels = tf.image.resize_images(pixels, [resize_height, image_width])
    pixels = tf.image.resize_image_with_crop_or_pad(pixels, image_height, image_width)

    def load_example(iterator, next_element):
        record = _next_record(sess, iterator, next_element)
        feed = {serialized: record}
        try:
            parsed_out = sess.run(parsed_with_bb, feed_dict=feed)
        except tf.errors.InvalidArgumentError:
            # The record cannot be parsed with the bounding-box schema:
            # treat it as a negative example whose box is the full image.
            xmin, xmax, ymin, ymax = 0, 1, 0, 1
            prob = 0
        else:
            # Clamp the box to the normalized image extent [0, 1].
            xmin = max(parsed_out['image/object/bbox/xmin'], 0)
            xmax = min(parsed_out['image/object/bbox/xmax'], 1)
            ymin = max(parsed_out['image/object/bbox/ymin'], 0)
            ymax = min(parsed_out['image/object/bbox/ymax'], 1)
            prob = 1

        image = sess.run(pixels, feed_dict=feed)
        x = image.reshape((1, image_height, image_width, 3))
        box = [[xmin, ymin,
                np.log(max(xmax - xmin, min_box_extent)),
                np.log(max(ymax - ymin, min_box_extent))]]
        return x, [prob], box

    return load_example


# device_count {'GPU': 0} keeps the session on the CPU.
config = tf.ConfigProto(device_count={'GPU': 0})
with tf.Session(config=config) as sess:
    bb_net = BB_CNN.BB_CNN(kernel_size = 13 * [3], kernel_stride = 13 * [1], num_filters =  2 * [64] + 2 * [128] + 3 * [256] + 6 * [512],
                           pool_size = 2 * [1, 2] + 3 * [1, 1, 2], pool_stride = 2 * [1, 2] + 3 * [1, 1, 2], hidden_dim = 2 * [4096],
                           dropout = 0.5, weight_scale = 1, file_name = 'vgg16.npy')
    # Smaller alternative configuration (7 conv layers, no weights file):
    #bb_net = BB_CNN.BB_CNN(kernel_size = 7 * [3], kernel_stride = 7 * [1], num_filters =  7 * [4],
    #                       pool_size = 7 * [2], pool_stride = 7 * [2], hidden_dim = [100],
    #                       dropout = 0.5, weight_scale = 1)

    # TensorFlow image tensors are (batch, height, width, channels).
    images = tf.placeholder(tf.float32, [1, image_height, image_width, 3])
    train_mode = tf.placeholder(tf.bool)
    target_prob = tf.placeholder(tf.float32, [1])
    target_bb = tf.placeholder(tf.float32, [1, 4])
    bb_net.build(images, train_mode)
    bb_net.predict()
    # NOTE(review): presumably loss() stores the loss tensor on bb_net.loss,
    # which is then read as an attribute below -- confirm in BB_CNN.
    bb_net.loss(target_prob, target_bb)
    train_step = tf.train.AdamOptimizer(1e-4).minimize(bb_net.loss)

    load_example = _make_example_loader(sess)

    sess.run(tf.global_variables_initializer())
    sess.run(iterator_train.initializer)
    sess.run(iterator_val.initializer)
    for i in range(num_iter):
        x, prob, box = load_example(iterator_train, next_element_train)
        # train_mode is True here so train-only behaviour is active while
        # optimizing. NOTE(review): assumes the flag gates dropout in BB_CNN.
        sess.run(train_step, feed_dict={images: x, train_mode: True,
                                        target_prob: prob, target_bb: box})

        if (i + 1) % val_every == 0:
            net_loss = np.zeros(num_val)
            for j in range(num_val):
                # Validation examples come from the held-out validation set.
                x, prob, box = load_example(iterator_val, next_element_val)
                net_loss[j] = sess.run(bb_net.loss,
                                       feed_dict={images: x, train_mode: False,
                                                  target_prob: prob, target_bb: box})
            print('loss after ' + str(i + 1) + ' iterations = ' + str(np.mean(net_loss)))

KeyboardInterrupt: 