In [2]:
import os

import tensorflow as tf
from object_detection.utils import dataset_util
import matplotlib.pyplot as plt

In [3]:
# Human-readable class names. The integer class id of a label is its index in this list.
category_names = [
    'Bier',
    'Bier Maß',
    'Weißbier',
    'Cola',
    'Wasser',
    'Curry-Wurst',
    'Weißwein',
    'A-Schorle',
    'Jägermeister',
    'Pommes',
    'Burger',
    'Williamsbirne',
    'Alm-Breze',
    'Brotzeitkorb',
    'Käsespätzle',
]


def create_tf_example(label_line, path_to_image, plot=False, height=1080, width=1920):
    """
    This function converts an image with its labels into a tensorflow example.
    label_line: one line written by our annotation tool, OpenCV style, i.e.
                "<file_name> <number_of_labels> (<category> <x> <y> <width> <height> )*"
                where x/y is the top-left corner of the box in pixels.
                Surrounding whitespace (including the trailing newline left by
                iterating over a file) is ignored.
    path_to_image: the directory where the image is located; the file name is
                   appended to it verbatim, so it must end with a path separator
    plot: whether the image should be plotted for debug purposes
    height: image height in pixels, used to normalize y coordinates
    width: image width in pixels, used to normalize x coordinates
    returns: a tf.train.Example, or None if the line is blank or has no labels
    """
    # example for label_line: 1526752420389_70.jpg 2 12 845 590 520 470 12 250 410 495 540
    # str.split() with no argument also drops the trailing '\n' and collapses repeated spaces.
    tokens = label_line.split()
    if not tokens:
        # blank line, e.g. an empty last line in files.txt
        return None

    file_name = tokens[0]
    amt_bb_label = int(tokens[1])
    if amt_bb_label == 0:
        # no labels
        return None

    image_path = path_to_image + file_name
    if plot:
        img = plt.imread(image_path)
        plt.imshow(img)
        plt.show()

    # Encoded image bytes; the context manager guarantees the file handle is closed.
    with tf.gfile.GFile(image_path, 'rb') as image_file:
        encoded_image_data = image_file.read()

    filename = file_name.encode('utf8')  # stored as both filename and source id
    image_format = b'jpeg'  # every image is assumed to be a JPEG

    # Everything after the two header tokens is a flat list of 5-tuples (category, x, y, w, h).
    box_tokens = tokens[2:]
    amt_bb = len(box_tokens) // 5
    if amt_bb != amt_bb_label or len(box_tokens) % 5 != 0:
        print("Incorrect number of items in label of %s: expected %d boxes, found %d"
              % (file_name, amt_bb_label, amt_bb))

    # Box corners normalized to [0, 1] by the image size, plus class info; one entry per box.
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text, classes = [], []
    for start in range(0, amt_bb * 5, 5):
        category, x, y, box_width, box_height = box_tokens[start:start + 5]
        x_s = float(x)
        y_s = float(y)
        class_id = int(category)
        xmins.append(x_s / float(width))
        xmaxs.append((x_s + float(box_width)) / float(width))
        ymins.append(y_s / float(height))
        ymaxs.append((y_s + float(box_height)) / float(height))
        # NOTE(review): class ids are 0-based indices into category_names and are written
        # unchanged to 'image/object/class/label'. The TF Object Detection API reserves id 0
        # for background and expects label-map ids starting at 1 — confirm the label map
        # used for training matches these ids.
        classes.append(class_id)
        classes_text.append(category_names[class_id].encode('utf8'))

    tf_label_and_data = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename),
      'image/source_id': dataset_util.bytes_feature(filename),
      'image/encoded': dataset_util.bytes_feature(encoded_image_data),
      'image/format': dataset_util.bytes_feature(image_format),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_label_and_data

In [13]:
def load_and_save_all_labels(path_to_labels, categories, out_path):
    """
    This function creates the tfrecord file.
    path_to_labels: path to the dataset created with the ExtractThumbnails script;
                    it holds one sub-folder per category, each annotated with the
                    AnnotationTool script (i.e. containing a files.txt label file)
    categories: the categories (sub-folder names) that are annotated
    out_path: where the tfrecord should be saved
    Lines without labels are skipped; a progress counter is printed while reading.
    """
    count = 0
    # The context managers close the writer and each label file even if a line is malformed.
    with tf.python_io.TFRecordWriter(out_path) as writer_train:
        for c in categories:
            # Trailing separator: create_tf_example appends the image file name directly.
            category_dir = os.path.join(path_to_labels, c, '')
            with open(os.path.join(category_dir, 'files.txt'), 'r') as label_file:
                for line in label_file:
                    tf_data = create_tf_example(line, category_dir)
                    if tf_data is not None:
                        # Write right away instead of buffering every encoded image of the
                        # category in memory; the resulting record order is the same.
                        writer_train.write(tf_data.SerializeToString())
                        count += 1
                    print("%d images completed"%count,end='\r')
    print("Found %d images for training"%count)

In [14]:
# Write every labelled training thumbnail into a single TFRecord file.
load_and_save_all_labels(
    path_to_labels='../../data_split/train/thumbnails/',
    categories=category_names,
    out_path='train_data',
)

Found 1107 images for training
