In [1]:
import tensorflow as tf
from object_detection.utils import dataset_util
import os
import imagesize
import pickle

In [2]:
# the food related classes
id_to_class = {'/m/01599': (1, 'Beer'), '/m/015wgc': (2, 'Croissant'), '/m/01b9xk': (3, 'Hot dog'), '/m/01_bhs': (4, 'Fast food'), '/m/01dwsz': (5, 'Waffle'), '/m/01dwwc': (6, 'Pancake'), '/m/01f91_': (7, 'Pretzel'), '/m/01fb_0': (8, 'Bagel'), '/m/01hrv5': (9, 'Popcorn'), '/m/01j3zr': (10, 'Burrito'), '/m/01nkt': (11, 'Cheese'), '/m/01tcjp': (12, 'Muffin'), '/m/01ww8y': (13, 'Snack'), '/m/01z1kdw': (14, 'Juice'), '/m/021mn': (15, 'Cookie'), '/m/024g6': (16, 'Cocktail'), '/m/0270h': (17, 'Dessert'), '/m/0271t': (18, 'Drink'), '/m/0284d': (19, 'Dairy'), '/m/02g30s': (20, 'Guacamole'), '/m/02jnhm': (21, 'Tin can'), '/m/02vqfm': (22, 'Coffee'), '/m/02wbm': (23, 'Food'), '/m/02xwb': (24, 'Fruit'), '/m/02y6n': (25, 'French fries'), '/m/033cnk': (26, 'Egg'), '/m/04zpv': (27, 'Milk'), '/m/052lwg6': (28, 'Baked goods'), '/m/05z55': (29, 'Pasta'), '/m/0663v': (30, 'Pizza'), '/m/06nwz': (31, 'Seafood'), '/m/06pcq': (32, 'Submarine sandwich'), '/m/07030': (33, 'Sushi'), '/m/07clx': (34, 'Tea'), '/m/07crc': (35, 'Taco'), '/m/081qc': (36, 'Wine'), '/m/09728': (37, 'Bread'), '/m/09tvcd': (38, 'Wine glass'), '/m/0cdn1': (39, 'Hamburger'), '/m/0cxn2': (40, 'Ice cream'), '/m/0f4s2w': (41, 'Vegetable'), '/m/0fszt': (42, 'Cake'), '/m/0grw1': (43, 'Salad'), '/m/0jy4k': (44, 'Doughnut'), '/m/0l515': (45, 'Sandwich')}

In [3]:
""" 
This script creates an tfrecord file from the Open Images Dataset V4. 
We used this in order to pretrain our model, hoping that it would increase its performance 
"""
def create_tf_example(bb):
    """ this function creates an tf example """
    width, height = imagesize.get('images/' + bb[0][0] + '.jpg')
    filename = str.encode(bb[0][0] + '.jpg')
    encoded_image_data = tf.gfile.FastGFile('images/' + bb[0][0] + '.jpg', 'rb').read()
    image_format = str.encode('jpeg')
    xmins = [] # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = [] # List of normalized right x coordinates in bounding box
             # (1 per box)
    ymins = [] # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = [] # List of normalized bottom y coordinates in bounding box
             # (1 per box)
    classes_text = [] # List of string class name of bounding box (1 per box)
    classes = [] # List of integer class id of bounding box (1 per box)
    for b in bb:
        xmins.append(float(b[4]))
        xmaxs.append(float(b[5]))
        ymins.append(float(b[6]))
        ymaxs.append(float(b[7]))
        class_label = id_to_class[b[2]]
        classes.append(class_label[0])
        classes_text.append(str.encode(class_label[1]))
    tf_label_and_data = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename),
      'image/source_id': dataset_util.bytes_feature(filename),
      'image/encoded': dataset_util.bytes_feature(encoded_image_data),
      'image/format': dataset_util.bytes_feature(image_format),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_label_and_data

In [4]:
with open('bb', 'rb') as f:
    bbd = pickle.load(f)

In [5]:
import contextlib2
from object_detection.dataset_tools import tf_record_creation_util
num_shards=100
output_filebase='sharded_data/oid_data'

""" create sharded tf record file """
with contextlib2.ExitStack() as tf_record_close_stack:
  output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
      tf_record_close_stack, output_filebase, num_shards)
  for index, bb in enumerate(bbd.values()):
    tf_data = create_tf_example(bb)
    output_shard_index = index % num_shards
    output_tfrecords[output_shard_index].write(tf_data.SerializeToString())

Instructions for updating:
Use tf.gfile.GFile.
