# Train CNN to predict bounding boxes

## Load packages

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import BB_CNN

## Load data

In [2]:
dataset_train = tf.data.TFRecordDataset('train.record')
dataset_train = dataset_train.shuffle(buffer_size=2000)
iterator_train = dataset_train.make_initializable_iterator()
next_element_train = iterator_train.get_next()

dataset_val = tf.data.TFRecordDataset('val.record')
dataset_val = dataset_val.shuffle(buffer_size=2000)
iterator_val = dataset_val.make_initializable_iterator()
next_element_val = iterator_val.get_next()

## Train CNN

In [3]:
num_iter = 1000

feature = {'image/height': tf.FixedLenFeature([], tf.int64),
           'image/width': tf.FixedLenFeature([], tf.int64),
           'image/filename': tf.FixedLenFeature([], tf.string),
           'image/source_id': tf.FixedLenFeature([], tf.string),
           'image/key/sha256': tf.FixedLenFeature([], tf.string),
           'image/encoded': tf.FixedLenFeature([], tf.string),
           'image/format': tf.FixedLenFeature([], tf.string),
           'image/object/bbox/xmin': tf.FixedLenFeature([], tf.float32),
           'image/object/bbox/xmax': tf.FixedLenFeature([], tf.float32),
           'image/object/bbox/ymin': tf.FixedLenFeature([], tf.float32),
           'image/object/bbox/ymax': tf.FixedLenFeature([], tf.float32),
           'image/object/class/text': tf.FixedLenFeature([], tf.string),
           'image/object/class/label': tf.FixedLenFeature([], tf.int64),
           'image/object/difficult': tf.FixedLenFeature([], tf.int64),
           'image/object/truncated': tf.FixedLenFeature([], tf.int64),
           'image/object/view': tf.FixedLenFeature([], tf.string)}

config = tf.ConfigProto(device_count = {'GPU': 0})
with tf.Session(config=config)  as sess:
    bb_net = BB_CNN.BB_CNN(kernel_size = 7 * [3], kernel_stride = 7 * [1], num_filters =  7 * [4],
                           pool_size = 7 * [2], pool_stride = 7 * [2], hidden_dim = [100],
                           dropout = 0.5, weight_scale = 1)
    images = tf.placeholder(tf.float32, [1, 1280, 720, 3])
    train_mode = tf.placeholder(tf.bool)
    target_prob = tf.placeholder(tf.float32, [1])
    target_bb = tf.placeholder(tf.float32, [1, 4])
    bb_net.build(images, train_mode)
    bb_net.predict()
    bb_net.loss(target_prob, target_bb)
    train_step = tf.train.AdamOptimizer(1e-4).minimize(bb_net.loss)
    sess.run(tf.global_variables_initializer())
    sess.run(iterator_train.initializer)
    sess.run(iterator_val.initializer)
    for i in range(num_iter):
        try:
            next_example_train = sess.run(next_element_train)
            example_parsed = tf.parse_single_example(next_example_train, features=feature)
            example_parsed_out = sess.run(example_parsed)
            image = sess.run(tf.image.decode_jpeg(example_parsed_out['image/encoded']))
            xmin = max(example_parsed_out['image/object/bbox/xmin'], 0)
            xmax = min(example_parsed_out['image/object/bbox/xmax'], 1)
            ymin = max(example_parsed_out['image/object/bbox/ymin'], 0)
            ymax = min(example_parsed_out['image/object/bbox/ymax'], 1)
            x = image.reshape((1, 1280, 720, 3)) / 255.
            sess.run(train_step, feed_dict={images: x, train_mode: False, target_prob: [1], target_bb: [[xmin, ymin, xmax - xmin, ymax - ymin]]})
        except:
            print('The image ' + str(i + 1) + ' of the training set could not be loaded.')
        if (i + 1) % 100 == 0:
            net_loss = np.zeros(10)
            for j in range(10):
                try:
                    next_example_val = sess.run(next_element_val)
                    example_parsed = tf.parse_single_example(next_example_val, features=feature)
                    example_parsed_out = sess.run(example_parsed)
                    image = sess.run(tf.image.decode_jpeg(example_parsed_out['image/encoded']))
                    xmin = max(example_parsed_out['image/object/bbox/xmin'], 0)
                    xmax = min(example_parsed_out['image/object/bbox/xmax'], 1)
                    ymin = max(example_parsed_out['image/object/bbox/ymin'], 0)
                    ymax = min(example_parsed_out['image/object/bbox/ymax'], 1)
                    x = image.reshape((1, 1280, 720, 3)) / 255.
                    net_loss[j] = sess.run(bb_net.loss, feed_dict={images: x, train_mode: False, target_prob: [1], target_bb: [[xmin, ymin, np.log(xmax - xmin), np.log(ymax - ymin)]]})
                except:
                    print('The image ' + str(i // 100 + j + 1) + ' of the validation set could not be loaded.')
            print('loss after ' + str(i + 1) + ' iterations = ' + str(np.mean(net_loss)))

loss after 100 iterations = 4.18179235458
loss after 200 iterations = 4.06661677361
loss after 300 iterations = 1.62983140349
loss after 400 iterations = 1.99229063988
loss after 500 iterations = 2.37045666575
loss after 600 iterations = 1.59642532766
loss after 700 iterations = 1.26043538451
The image 718 of the training set could not be loaded.
loss after 800 iterations = 2.48356351256
The image 842 of the training set could not be loaded.
loss after 900 iterations = 1.12850987315
loss after 1000 iterations = 1.59097614288


In [6]:
j

9