In [None]:
import tensorflow as tf
import yaml
import os
import sys
from tqdm import trange

# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from object_detection.utils import dataset_util

In [None]:
# Maps each traffic-light class name found in the annotation YAML to the
# integer id written into 'image/object/class/label'. Ids are assigned in
# list order starting at 1 (id 0 is not used here).
LABELS = {
    name: class_id
    for class_id, name in enumerate(
        [
            "Green",
            "GreenStraight",
            "GreenStraightLeft",
            "GreenStraightRight",
            "GreenLeft",
            "GreenRight",
            "Yellow",
            "Red",
            "RedStraight",
            "RedStraightLeft",
            "RedStraightRight",
            "RedLeft",
            "RedRight",
            "off",
        ],
        start=1,
    )
}

In [None]:
def create_tf_example(example, height=720, width=1280):
    """Convert one annotated image record into a ``tf.train.Example``.

    Args:
        example: dict with keys ``'file_name'`` (str), ``'abs_path'`` (str,
            path of the encoded image on disk) and ``'boxes'`` (list of dicts,
            each with ``'x_min'``, ``'x_max'``, ``'y_min'``, ``'y_max'`` and
            ``'label'``). Box coordinates are divided by ``width``/``height``
            here, so they are assumed to be in pixels.
        height: image height in pixels used to normalize y coordinates.
            Defaults to 720; assumes every image in the dataset has this size.
        width: image width in pixels used to normalize x coordinates.
            Defaults to 1280.

    Returns:
        A ``tf.train.Example`` holding the encoded image bytes, its size,
        file name and one normalized bounding box + class per entry in
        ``example['boxes']``.

    Raises:
        KeyError: if a box's ``'label'`` is not a key of ``LABELS``.
    """
    filename = example['file_name'].encode()
    # Use a context manager so the file handle is closed deterministically
    # instead of relying on garbage collection of the unreferenced GFile.
    with tf.gfile.GFile(example['abs_path'], 'rb') as image_file:
        encoded_image = image_file.read()
    # NOTE(review): assumes every image is PNG; confirm for new data sources.
    image_format = b'png'

    xmins = []        # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = []        # List of normalized right x coordinates in bounding box (1 per box)
    ymins = []        # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = []        # List of normalized bottom y coordinates in bounding box (1 per box)
    classes_text = [] # List of string class name of bounding box (1 per box)
    classes = []      # List of integer class id of bounding box (1 per box)

    for box in example['boxes']:
        xmins.append(box['x_min'] / width)
        xmaxs.append(box['x_max'] / width)
        ymins.append(box['y_min'] / height)
        ymaxs.append(box['y_max'] / height)
        classes_text.append(box['label'].encode())
        classes.append(LABELS[box['label']])

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))

    return tf_example

In [None]:
# Root directory of the traffic-lights dataset. Set the BOSCH_DATASET_DIR
# environment variable to point elsewhere instead of editing this path; the
# previous hard-coded location is kept as the default.
DIRECTORY = os.environ.get(
    "BOSCH_DATASET_DIR", "/home/apc/projects/bosh_small_traffic_lights_dataset"
)

# Training annotations and output record, both relative to DIRECTORY.
TRAIN_YAML = "train.yaml"
TRAIN_OUTPUT = "train.tfrecord"

# Test annotations and output record, relative to DIRECTORY. Test images are
# located by file name under TEST_DIRECTORY rather than by the YAML 'path'.
TEST_YAML = "test.yaml"
TEST_DIRECTORY = 'rgb/test'
TEST_OUTPUT = "test.tfrecord"

In [None]:
train_yaml_path = os.path.join(DIRECTORY, TRAIN_YAML)
# safe_load: yaml.load() without an explicit Loader is deprecated (PyYAML 5.1)
# and a TypeError in PyYAML 6; it can also construct arbitrary objects.
# Assumes the annotation file holds only plain lists/dicts/scalars.
# The `with` block closes the file handle once parsing is done.
with open(train_yaml_path, 'rb') as yaml_file:
    train_data = yaml.safe_load(yaml_file)

In [None]:
# Add 'abs_path' (YAML 'path' joined onto DIRECTORY) and 'file_name' (the
# path's final component) to each training record, in place.
for idx in trange(len(train_data)):
    record = train_data[idx]
    yaml_path = record['path']
    record['abs_path'] = os.path.join(DIRECTORY, yaml_path)
    record['file_name'] = os.path.basename(yaml_path)

In [None]:
output_path = os.path.join(DIRECTORY, TRAIN_OUTPUT)

# Use the writer as a context manager so the TFRecord file is closed even if
# converting a record raises (e.g. unknown label or unreadable image).
with tf.python_io.TFRecordWriter(output_path) as train_writer:
    for i in trange(len(train_data)):
        tf_example = create_tf_example(train_data[i])
        train_writer.write(tf_example.SerializeToString())

In [None]:
test_yaml_path = os.path.join(DIRECTORY, TEST_YAML)
# safe_load: yaml.load() without an explicit Loader is deprecated (PyYAML 5.1)
# and a TypeError in PyYAML 6; it can also construct arbitrary objects.
# Assumes the annotation file holds only plain lists/dicts/scalars.
# The `with` block closes the file handle once parsing is done.
with open(test_yaml_path, 'rb') as yaml_file:
    test_data = yaml.safe_load(yaml_file)

In [None]:
# Add 'abs_path' and 'file_name' to each test record, in place. Only the final
# component of the YAML 'path' is kept; the image is expected under
# DIRECTORY/TEST_DIRECTORY regardless of the directory recorded in the YAML.
for idx in trange(len(test_data)):
    record = test_data[idx]
    base_name = os.path.basename(record['path'])
    record['abs_path'] = os.path.join(DIRECTORY, TEST_DIRECTORY, base_name)
    record['file_name'] = base_name

In [None]:
test_output_path = os.path.join(DIRECTORY, TEST_OUTPUT)

# Use the writer as a context manager so the TFRecord file is closed even if
# converting a record raises (e.g. unknown label or unreadable image).
with tf.python_io.TFRecordWriter(test_output_path) as test_writer:
    for i in trange(len(test_data)):
        tf_example = create_tf_example(test_data[i])
        test_writer.write(tf_example.SerializeToString())