## Optimizing tf.data.Dataset

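A key challenge :) This notebook times how long it takes to draw a batch from two TFRecord layouts of the same kind of data: one record per timestep (`multiple`) versus one record per whole episode (`single`).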

In [1]:
import os

import numpy as np
import tensorflow as tf

from worldmodels.data.tf_records import encode_floats
from worldmodels.data.tf_records import shuffle_samples

In [2]:
def shuffle_full_episodes(parse_func, records, episode_length, batch_size, num_cpu=4):
    """Return an iterator over shuffled batches of episodes (used in memory training).

    The pipeline interleaves the record files in blocks of `episode_length`
    records, parses every record, groups each run of `episode_length` parsed
    records into one element, shuffles those elements with a buffer of 1000,
    repeats, and finally groups `batch_size` elements per batch.

    Args:
        parse_func: callable mapped over every serialized record.
        records: TFRecord file paths, passed to `from_tensor_slices`.
        episode_length: records taken per file per interleave block, and
            records grouped into one element.
            Assumes each file holds exactly this many records so that a
            block lines up with one episode - TODO confirm.
        batch_size: number of grouped elements per yielded batch.
        num_cpu: parallelism of the parsing `map` only; the interleave step
            uses AUTOTUNE for both cycle length and parallel calls.

    Returns:
        The result of `iter()` on the final dataset.
    """
    autotune = tf.data.experimental.AUTOTUNE

    def open_record_file(path):
        return tf.data.TFRecordDataset(path)

    pipeline = (
        tf.data.Dataset.from_tensor_slices(records)
        .interleave(
            open_record_file,
            block_length=episode_length,
            num_parallel_calls=autotune,
            cycle_length=autotune,
        )
        .map(parse_func, num_parallel_calls=num_cpu)
        .batch(episode_length)
        .shuffle(1000)
        .repeat(None)
        .batch(batch_size)
    )
    return iter(pipeline)

In [3]:
# Scratch directories for the two TFRecord layouts compared in this notebook.
# expanduser('~') replaces os.environ['HOME'], which raises KeyError on
# platforms where HOME is not set (e.g. Windows).
home = os.path.expanduser('~')
experiments_root = os.path.join(home, 'world-models-experiments')

# `multiple` receives the files written one record per timestep.
multiple = os.path.join(experiments_root, 'tf-record-opt-multiple')
os.makedirs(multiple, exist_ok=True)

# `single` receives the files written as one record per episode.
single = os.path.join(experiments_root, 'tf-record-opt-single')
os.makedirs(single, exist_ok=True)

def parse_func_multiple(example_proto):
    """Parse one serialized example and return its 'sample' feature.

    The feature is declared as a fixed-length float32 vector of 32 values.
    """
    feature_spec = {'sample': tf.io.FixedLenFeature((32,), tf.float32)}
    return tf.io.parse_single_example(example_proto, feature_spec)['sample']

def parse_func_single(example_proto):
    """Parse one serialized example and return its 'sample' feature.

    The feature is declared as a fixed-length float32 array of shape
    (episode_len, 32). `episode_len` is the notebook-level global, looked up
    each time this function runs, so it must be defined before the first call.
    """
    feature_spec = {
        'sample': tf.io.FixedLenFeature((episode_len, 32), tf.float32),
    }
    return tf.io.parse_single_example(example_proto, feature_spec)['sample']

def write_test_data_multiple_records(name, episode_len):
    """Write a random episode to the TFRecord file `name`, one record per row.

    Args:
        name: path of the TFRecord file to write.
        episode_len: number of rows (each of 32 float32 values) to generate.

    Returns:
        The generated (episode_len, 32) float32 array.
    """
    episode = np.random.random_sample((episode_len, 32)).astype(np.float32)

    with tf.io.TFRecordWriter(name) as record_file:
        for step in range(len(episode)):
            # Each row is encoded and written as its own record.
            record_file.write(encode_floats({'sample': episode[step]}))

    return episode

def write_test_data_single_records(name, episode_len):
    """Write a random episode to the TFRecord file `name` as a single record.

    Args:
        name: path of the TFRecord file to write.
        episode_len: number of rows (each of 32 float32 values) to generate.

    Returns:
        The generated (episode_len, 32) float32 array.
    """
    random_values = np.random.random_sample(size=(episode_len, 32))
    episode = random_values.astype(np.float32)

    with tf.io.TFRecordWriter(name) as record_file:
        # The whole array goes into one record.
        record_file.write(encode_floats({'sample': episode}))

    return episode

In [4]:
def make_records(num, direct, rec_func, episode_length=None):
    """Write `num` episode files into `direct` and return their paths.

    Args:
        num: number of files to create, named ep0.tfrecord .. ep{num-1}.tfrecord.
        direct: directory the files are written into.
        rec_func: callable `(path, episode_length)` that writes one file;
            its return value is ignored.
        episode_length: length passed to `rec_func`. When None, falls back to
            the notebook-level `episode_len` global, as the original did.

    Returns:
        The list of file paths, in creation order.
    """
    if episode_length is None:
        # Backward-compatible fallback to the global defined further down
        # in the notebook; it must exist by the time this function is called.
        episode_length = episode_len

    records = [
        os.path.join(direct, 'ep{}.tfrecord'.format(n))
        for n in range(num)
    ]

    # rec_func is called only for its side effect of writing the file, so
    # use a plain loop instead of building a list that is thrown away.
    for path in records:
        rec_func(path, episode_length)

    return records
    
# Rows per synthetic episode; read as a global by parse_func_single and
# make_records, so it must be set before either of them is called.
episode_len = 1000

# 100 files per layout: per-timestep records in `multiple`, one record per
# episode in `single`.
m_rec = make_records(
    num=100,
    direct=multiple,
    rec_func=write_test_data_multiple_records,
)
s_rec = make_records(
    num=100,
    direct=single,
    rec_func=write_test_data_single_records,
)

In [5]:
# Iterator over the per-timestep-record files, 100 grouped episodes per batch.
m_ds = shuffle_full_episodes(
    parse_func=parse_func_multiple,
    records=m_rec,
    episode_length=episode_len,
    batch_size=100,
)

In [6]:
%%timeit
# Time drawing one batch from the per-timestep-record pipeline (m_ds).
next(m_ds).shape

4.69 s ± 428 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# Iterator over the one-record-per-episode files.
# NOTE(review): batch_size is 10 here but 100 in the m_ds cell, so the two
# timings may not be directly comparable per batch - confirm this is intended.
s_ds = shuffle_samples(parse_func_single, s_rec, batch_size=10)

In [8]:
%%timeit
# Time drawing one batch from the one-record-per-episode pipeline (s_ds).
next(s_ds).shape

3.41 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
