In [263]:
import os
import math
import tensorflow as tf
import six

# Seed TensorFlow's global random generator so random ops are repeatable.
# NOTE(review): nothing in the visible cells draws random numbers — confirm
# whether this seed is still needed.
tf.random.set_seed(1)

In [291]:
# --- Inputs ---------------------------------------------------------------
# Folder holding the split list files (*.txt); each line names one sample stem.
LIST_FOLDER = './data/VOCdevkit/VOC2012/ImageSets/Segmentation'

# Source images: <IMAGE_FOLDER>/<stem>.<IMAGE_FORMAT>
IMAGE_FORMAT = 'jpg'
IMAGE_FOLDER = './data/VOCdevkit/VOC2012/JPEGImages'

# Segmentation masks: <SEGMENTATION_FOLDER>/<stem>.<SEGMENTATION_FORMAT>
SEGMENTATION_FORMAT = 'png'
SEGMENTATION_FOLDER = "./data/VOCdevkit/VOC2012/SegmentationClass"

# --- Outputs --------------------------------------------------------------
# Folder the TFRecord shards are written to.
OUTPUT_DIR = "./data/VOCdevkit/tfrecord"

# Number of TFRecord files each split is divided into.
NUM_SHARDS = 4

def _bytes_feature(value):
    """Wrap one string / bytes value in a ``tf.train.Feature`` (``bytes_list``).

    Eager tensors are unwrapped with ``.numpy()`` first, because ``BytesList``
    won't unpack a string from an ``EagerTensor`` on its own.
    """
    eager_tensor_type = type(tf.constant(0))
    payload = value.numpy() if isinstance(value, eager_tensor_type) else value
    wrapped = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=wrapped)

def _float_feature(value):
    """Wrap one float / double in a ``tf.train.Feature`` (``float_list``)."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

def _int64_feature(value):
    """Wrap one bool / enum / int / uint in a ``tf.train.Feature`` (``int64_list``)."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def img_seg_to_example(filename, image_data, seg_data):
    """Build a ``tf.train.Example`` pairing an encoded image with its segmentation.

    Args:
        filename: Sample identifier stored under ``image/filename``. A ``str``
            is UTF-8 encoded first, since ``BytesList`` only accepts bytes.
        image_data: Encoded image file contents (bytes), stored as-is.
        seg_data: Encoded segmentation file contents (bytes), stored as-is.

    Returns:
        A ``tf.train.Example`` with the features ``image/encoded``,
        ``image/filename``, ``image/channels`` and
        ``image/segmentation/class/encoded``.
    """
    # The notebook defines `_bytes_feature`; the previously used
    # `_bytes_list_feature` does not exist here and raised NameError on a
    # fresh kernel.
    if isinstance(filename, str):
        filename = filename.encode('utf-8')

    feature = {
        'image/encoded': _bytes_feature(image_data),
        'image/filename': _bytes_feature(filename),
        # Channel count is hard-coded, not read from the image data.
        'image/channels': _int64_feature(3),
        'image/segmentation/class/encoded': _bytes_feature(seg_data),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))

def convert_dataset(dataset_split):
    """Convert one split list file into ``NUM_SHARDS`` TFRecord shards.

    Each non-empty line of ``dataset_split`` is treated as a sample stem. For
    every stem the image (``IMAGE_FOLDER``) and segmentation
    (``SEGMENTATION_FOLDER``) files are read as raw bytes, packed with
    ``img_seg_to_example`` and appended to the shard covering that index.

    Args:
        dataset_split: Path to the split list file (e.g. ``.../train.txt``).
            Its basename without extension is used as the shard name prefix.
    """
    dataset = os.path.splitext(os.path.basename(dataset_split))[0]

    # Close the list file deterministically; drop blank lines and stray
    # whitespace / '\r' so they never turn into bogus file paths.
    with open(dataset_split, 'r') as list_file:
        filenames = [line.strip() for line in list_file if line.strip()]
    num_images = len(filenames)

    print(f"Processing {dataset}: {num_images} Images")
    num_per_shard = int(math.ceil(num_images / NUM_SHARDS))

    # TFRecordWriter does not create missing parent directories.
    tf.io.gfile.makedirs(OUTPUT_DIR)

    for shard_id in range(NUM_SHARDS):
        output_filename = os.path.join(OUTPUT_DIR, '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, NUM_SHARDS))
        start_idx = shard_id * num_per_shard
        end_idx = min((shard_id + 1) * num_per_shard, num_images)
        with tf.io.TFRecordWriter(output_filename) as writer:
            # Iterate the whole shard range; the former upper bound of 2 was a
            # debugging leftover that wrote at most two images in total.
            for i in range(start_idx, end_idx):
                # READ IMAGE
                image_filename = os.path.join(IMAGE_FOLDER, filenames[i] + '.' + IMAGE_FORMAT)
                with tf.io.gfile.GFile(image_filename, mode='rb') as image_file:
                    image_data = image_file.read()

                # READ SEGMENTATION
                seg_filename = os.path.join(SEGMENTATION_FOLDER, filenames[i] + '.' + SEGMENTATION_FORMAT)
                with tf.io.gfile.GFile(seg_filename, mode='rb') as seg_file:
                    seg_data = seg_file.read()

                example = img_seg_to_example(filenames[i], image_data, seg_data)
                writer.write(example.SerializeToString())
            
def main():
    """Convert every ``*.txt`` split list found in ``LIST_FOLDER``."""
    split_pattern = os.path.join(LIST_FOLDER, "*.txt")
    for split_path in tf.io.gfile.glob(split_pattern):
        convert_dataset(split_path)

In [292]:
# Entry point: run the conversion of all split lists when executed as __main__.
if __name__ == "__main__":
    main()

Processing train: 1464 Images
Processing val: 1449 Images
