In [324]:
import os
import math
import tensorflow as tf
import six

tf.random.set_seed(1)

In [327]:
LIST_DIR = './data/VOCdevkit/VOC2012/ImageSets/Segmentation'
OUTPUT_DIR = "./data/VOCdevkit/tfrecord"

IMAGE_DIR = './data/VOCdevkit/VOC2012/JPEGImages'
IMAGE_FORMAT = 'jpg'

SEGMENTATION_DIR = "./data/VOCdevkit/VOC2012/SegmentationClass"
SEGMENTATION_FORMAT = 'png'

NUM_SHARDS = 6

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def img_seg_to_example(filename, image_data, seg_data):
    image_shape = tf.image.decode_jpeg(image_data).shape
    
    feature = {
        'image/encoded': _bytes_list_feature(image_data),
        'image/format': _bytes_list_feature(IMAGE_FORMAT),
        'image/filename': _bytes_list_feature(filename),
        'image/height': _int64_feature(image_shape[0]),
        'image/width': _int64_feature(image_shape[1]),
        'image/channels':  _int64_feature(image_shape[2]),
        'image/segmentation/class/encoded': (_bytes_list_feature(seg_data)),
        'image/segmentation/class/format': _bytes_list_feature(SEGMENTATION_FORMAT),
    }
    
    return tf.train.Example(features=tf.train.Features(feature=feature))

def convert_dataset(dataset_split):
    dataset = os.path.basename(dataset_split)[:-4]
    
    filenames = [x.strip('\n') for x in open(dataset_split, 'r')]
    num_images = len(filenames)
    
    print(f"Processing {dataset}: {num_images} Images")
    num_per_shard = int(math.ceil(num_images / NUM_SHARDS))
    
    for shard_id in range(NUM_SHARDS):
        output_filename = os.path.join(OUTPUT_DIR, '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, NUM_SHARDS))
        
        with tf.io.TFRecordWriter(output_filename) as writer:
            start_idx = shard_id * num_per_shard
            end_idx = min((shard_id + 1) * num_per_shard, num_images)
            
            for i in range(start_idx, end_idx):
                # READ IMAGE  
                image_filename = os.path.join(IMAGE_DIR, filenames[i] + '.' + IMAGE_FORMAT)
                image_data = tf.io.gfile.GFile(image_filename, 'rb').read()

                # READ SEGMENTATION
                seg_filename = os.path.join(SEGMENTATION_DIR, filenames[i] + '.' + SEGMENTATION_FORMAT)
                seg_data = tf.io.gfile.GFile(seg_filename, mode='rb').read()
                
                # CREATE TFRECORD EXAMPLE
                example = img_seg_to_example(filenames[i], image_data, seg_data)

                # WRITE TO DISK
                writer.write(example.SerializeToString())
            
def main():
    dataset_splits = tf.io.gfile.glob(os.path.join(LIST_DIR, "*.txt"))
    for dataset_split in dataset_splits:
        convert_dataset(dataset_split)

In [328]:
if __name__ == "__main__":
    main()

Processing train: 1464 Images
Processing val: 1449 Images


In [430]:
IMG_DIM = 224
BATCH_SIZE = 32
BUFFER_SIZE = 500

def parse_image(content, channels):
    return tf.cond(
        tf.image.is_jpeg(content), 
        lambda: tf.image.decode_jpeg(content, channels), 
        lambda: tf.image.decode_png(content, channels)
    )
    
def parse_dataset(example_proto):
    features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/filename':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.io.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/height':
            tf.io.FixedLenFeature((), tf.int64, default_value=0),
        'image/width':
            tf.io.FixedLenFeature((), tf.int64, default_value=0),
        'image/segmentation/class/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/segmentation/class/format':
            tf.io.FixedLenFeature((), tf.string, default_value='png'),
    }
    
    parsed_feature = tf.io.parse_single_example(example_proto, features)
    
    image = parse_image(parsed_feature['image/encoded'], channels=3)
    image = tf.image.resize(image, (IMG_DIM, IMG_DIM))
    image = tf.image.convert_image_dtype(image, tf.float32)
    
    label = parse_image(parsed_feature['image/segmentation/class/encoded'], channels=1)
    label = tf.image.resize(label, (IMG_DIM, IMG_DIM))
    label = tf.image.convert_image_dtype(label, tf.int64)
    
    image_name = parsed_feature['image/filename']
    
    sample = {
        "image_name": image_name,
        "image": image,
        "label": label,
        "height": IMG_DIM,
        "width": IMG_DIM,
    }
    
    return sample

list_ds = tf.data.Dataset.list_files(OUTPUT_DIR + "/train-*.tfrecord")

dataset = (tf.data
    .TFRecordDataset(list_ds)
    .map(parse_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE))
    

In [431]:
for i in dataset.take(1):
    print(i)

{'image_name': <tf.Tensor: id=239035, shape=(32,), dtype=string, numpy=
array([b'2010_004960', b'2010_002962', b'2010_003534', b'2008_000238',
       b'2008_002067', b'2010_003342', b'2008_000832', b'2010_001329',
       b'2007_008948', b'2010_000986', b'2010_003887', b'2007_008801',
       b'2010_001630', b'2008_000495', b'2010_004963', b'2010_001933',
       b'2008_003180', b'2008_001413', b'2010_002218', b'2010_001154',
       b'2010_003798', b'2008_002248', b'2007_009216', b'2008_001137',
       b'2010_002556', b'2007_009464', b'2010_004072', b'2010_003250',
       b'2007_008821', b'2008_004365', b'2008_003779', b'2007_009605'],
      dtype=object)>, 'image': <tf.Tensor: id=239034, shape=(32, 224, 224, 3), dtype=float32, numpy=
array([[[[8.96086349e+01, 1.24458740e+02, 1.29896240e+02],
         [1.17411316e+02, 1.52563095e+02, 1.73422684e+02],
         [1.14934113e+02, 1.40332504e+02, 1.56468124e+02],
         ...,
         [8.08107605e+01, 1.12998253e+02, 2.64135838e+01],
        