In [None]:
import math
import numpy as np
import tensorflow as tf
import datetime
# Shorthand handed to num_parallel_calls in the Dataset.map() calls of this notebook.
AUTOTUNE = tf.data.AUTOTUNE
# Show which TensorFlow release produced the records.
print("Tensorflow version " + tf.__version__)

In [None]:
# Number of TFRecord files each split is divided into (used to derive the shard size).
SHARDS = 5

# Class names as bytes. The position of a name in this list is the integer
# class index stored in the records, so the order must not change.
CLASSES = [
    b'angel',
    b'cat',
    b'crown',
    b'the_eiffel_tower',
    b'the_mona_lisa',
]

In [None]:
# Placeholder: must be filled in before running the notebook.
# NOTE: despite the name, the value is interpolated directly after "gs://"
# in the path constants of this notebook, i.e. it is used as the bucket name.
GCS_PROJECT_ID = '<TO DEFINE>'

In [None]:
# Glob for the raw training PNGs: one sub-folder per class, images inside.
GCS_TRAINING_PATTERN = (
    f'gs://{GCS_PROJECT_ID}/raw_images/training_data/*/*.png'
)
# Prefix under which the training TFRecord shards are written
# (the trailing slash matters: file names are appended verbatim).
GCS_TRAINING_TFRECORDS = (
    f'gs://{GCS_PROJECT_ID}/tfrecord_data/training_data/'
)

In [None]:
# Count the training images, then pick a shard size such that SHARDS files
# are enough to hold all of them (the last shard may be smaller).
nb_images = len(tf.io.gfile.glob(GCS_TRAINING_PATTERN))
shard_size = math.ceil(nb_images / SHARDS)

In [None]:
# images are arranged in folders with corresponding labels
def decode_image_and_label(filename):
    """Read one PNG file and take its label from the parent folder name.

    Args:
        filename: string tensor with the path of a PNG file, as produced by
            the file-listing dataset this function is mapped over.

    Returns:
        A tuple (image, label): the tensor returned by tf.io.decode_png with
        its default arguments, and the second-to-last '/'-separated component
        of `filename` as a string tensor.
    """
    # Split the path on '/'; the flat values of the ragged result are the
    # path components, and the one before the file name is the class folder.
    path_parts = tf.strings.split(tf.expand_dims(filename, axis=-1), sep='/')
    label = path_parts.values[-2]

    image = tf.io.decode_png(tf.io.read_file(filename))
    return image, label

In [None]:
def recompress_image(image, label):
    """Re-encode a decoded image as JPEG, passing the label through untouched.

    Args:
        image: decoded image tensor; it is cast to uint8 before encoding.
        label: label associated with the image, returned unchanged.

    Returns:
        A tuple (jpeg, label) where `jpeg` is the string tensor produced by
        tf.image.encode_jpeg with size optimisation enabled and chroma
        downsampling disabled.
    """
    jpeg = tf.image.encode_jpeg(
        tf.cast(image, tf.uint8),
        optimize_size=True,
        chroma_downsampling=False,
    )
    return jpeg, label

In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
    """Returns a bytes_list Feature from a string / byte, or a list of bytes.

    Args:
        value: a single bytes object, a single str (encoded as UTF-8), or an
            iterable of bytes values. Iterables are passed through unchanged.

    Returns:
        tf.train.Feature whose bytes_list holds the given value(s).
    """
    if isinstance(value, str):
        value = value.encode("utf-8")
    if isinstance(value, bytes):
        # A bare bytes object would otherwise be iterated byte by byte
        # (yielding ints), which is not a valid BytesList payload.
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _float_feature(value):
    """Returns a float_list Feature from a float / double, or a list of them.

    Args:
        value: a single numeric scalar, or an iterable of numeric values.
            Iterables are passed through unchanged.

    Returns:
        tf.train.Feature whose float_list holds the given value(s).
    """
    if np.isscalar(value):
        # FloatList expects an iterable; wrap a lone scalar in a list.
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
    """Returns an int64_list Feature from a bool / enum / int / uint, or a list of them.

    Args:
        value: a single integer-like scalar, or an iterable of such values.
            Iterables are passed through unchanged.

    Returns:
        tf.train.Feature whose int64_list holds the given value(s).
    """
    if np.isscalar(value):
        # Int64List expects an iterable; wrap a lone scalar in a list.
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

In [None]:
def to_tfrecord(img_bytes, label):
    """Build a tf.train.Example for one encoded image and its class label.

    Args:
        img_bytes: encoded image bytes, stored under the "image" key.
        label: class name as bytes; must be one of the entries of CLASSES.

    Returns:
        tf.train.Example with the features "image", "class_num" (index of
        `label` in CLASSES), "label" and "one_hot_class" (float vector of
        length len(CLASSES)).

    Raises:
        ValueError: if `label` is not present in CLASSES.
    """
    # Look the label up explicitly: an argmax over an all-False match array
    # returns 0, which would silently file unknown labels under class 0.
    try:
        class_num = CLASSES.index(label)
    except ValueError:
        raise ValueError(
            f"Unknown label {label!r}; expected one of {CLASSES}"
        ) from None
    one_hot_class = np.eye(len(CLASSES))[class_num]

    feature = {
      "image": _bytes_feature([img_bytes]),
      "class_num": _int64_feature([class_num]),
      "label": _bytes_feature([label]),
      "one_hot_class": _float_feature(one_hot_class.tolist())
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))

In [None]:
# List the training files; list_files also shuffles them (fixed seed).
filenames = tf.data.Dataset.list_files(GCS_TRAINING_PATTERN, seed=35155)

# Decode each PNG together with its label, re-encode it as JPEG, then group
# the records so that one "batch" corresponds to one output file (shard).
quickdraw_dataset = (
    filenames
    .map(decode_image_and_label, num_parallel_calls=AUTOTUNE)
    .map(recompress_image, num_parallel_calls=AUTOTUNE)
    .batch(shard_size)
)

In [None]:
def write_dataset(dataset, filepath):
    """Write every batch of `dataset` to its own TFRecord file.

    Args:
        dataset: tf.data.Dataset yielding (image, label) batches; each batch
            becomes one file, with one tf.train.Example per batch element.
        filepath: prefix prepended verbatim to each generated file name, so a
            directory prefix must end with a '/'.

    Files are named "quickdraw_dataset<NN>_<count>.tfrec", where <NN> is the
    zero-padded batch index and <count> the number of records in the file.
    """
    print(f'Starting writing {datetime.datetime.now().strftime("%H:%M:%S")}')
    for shard, (image, label) in dataset.enumerate():
        # Convert the batch to NumPy once. Calling .numpy() for every record
        # would re-materialise the whole batch on each iteration.
        images = image.numpy()
        labels = label.numpy()
        shard_size = images.shape[0]
        filename = filepath + f"quickdraw_dataset{str(shard.numpy()).rjust(2, '0')}_{shard_size}.tfrec"
        print(f'Starting file writing {datetime.datetime.now().strftime("%H:%M:%S")}')

        with tf.io.TFRecordWriter(filename) as tf_writer:
            for img_bytes, img_label in zip(images, labels):
                example = to_tfrecord(img_bytes, img_label)
                tf_writer.write(example.SerializeToString())

        # Report success only once the writer has been closed.
        print(f'Wrote file {filename} containing {shard_size} records')
        print(f'Wrote file at {datetime.datetime.now().strftime("%H:%M:%S")}')

In [None]:
# Write the training shards under the training TFRecord prefix.
write_dataset(
    dataset=quickdraw_dataset,
    filepath=GCS_TRAINING_TFRECORDS,
)

In [None]:
# Glob for the raw validation PNGs: one sub-folder per class, images inside.
GCS_VALIDATION_PATTERN = (
    f'gs://{GCS_PROJECT_ID}/raw_images/validation_data/*/*.png'
)
# Prefix under which the validation TFRecord shards are written
# (the trailing slash matters: file names are appended verbatim).
GCS_VALIDATION_TFRECORDS = (
    f'gs://{GCS_PROJECT_ID}/tfrecord_data/validation_data/'
)

In [None]:
# Count the validation images, then pick a shard size such that SHARDS files
# are enough to hold all of them (the last shard may be smaller).
nb_images_valid = len(tf.io.gfile.glob(GCS_VALIDATION_PATTERN))
shard_size_valid = math.ceil(nb_images_valid / SHARDS)

In [None]:
# List the validation files; list_files also shuffles them (fixed seed).
filenames_valid = tf.data.Dataset.list_files(GCS_VALIDATION_PATTERN, seed=35155)

# Same pipeline as for training: decode, re-encode as JPEG, then group the
# records so that one "batch" corresponds to one output file (shard).
quickdraw_valid_dataset = (
    filenames_valid
    .map(decode_image_and_label, num_parallel_calls=AUTOTUNE)
    .map(recompress_image, num_parallel_calls=AUTOTUNE)
    .batch(shard_size_valid)
)

In [None]:
# Write the validation shards under the validation TFRecord prefix.
write_dataset(
    dataset=quickdraw_valid_dataset,
    filepath=GCS_VALIDATION_TFRECORDS,
)