In [2]:
import tensorflow as tf
tf.executing_eagerly()

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def encode(features):
    package = {}
    for key, value in features.items():
        package[key] = _int64_feature(value.flatten().tolist())
    print(package)
    example_proto = tf.train.Example(features=tf.train.Features(feature=package))
    return example_proto.SerializeToString()

def _parse(example_proto):
    features = {
        'feature': tf.io.FixedLenFeature((1,), tf.int64)
    }
    parsed_features = tf.io.parse_single_example(example_proto, features)
    print(parsed_features)
    return parsed_features['feature']

## Sample and write records



In [4]:
record0 = np.arange(4)
record1 = np.arange(10, 14)
record2 = np.arange(100, 104)

records = [record0, record1, record2]

In [5]:
results_path = os.path.join(
    os.environ['HOME'],
    'world-models-experiments/tf-records-debug'
)

!rm -rf ~/world-models-experiments/tf-records-debug

os.makedirs(results_path, exist_ok=True)
paths = []
for num, rec in enumerate(records):
    path = '{}/rec{}.tfrecord'.format(results_path, num)
    paths.append(path)
    with tf.io.TFRecordWriter(path) as writer:
        for act in rec:
            encoded = encode({'feature': act})
            writer.write(encoded)

{'feature': int64_list {
  value: 0
}
}
{'feature': int64_list {
  value: 1
}
}
{'feature': int64_list {
  value: 2
}
}
{'feature': int64_list {
  value: 3
}
}
{'feature': int64_list {
  value: 10
}
}
{'feature': int64_list {
  value: 11
}
}
{'feature': int64_list {
  value: 12
}
}
{'feature': int64_list {
  value: 13
}
}
{'feature': int64_list {
  value: 100
}
}
{'feature': int64_list {
  value: 101
}
}
{'feature': int64_list {
  value: 102
}
}
{'feature': int64_list {
  value: 103
}
}


In [6]:
!ls ~/world-models-experiments/tf-records-debug

rec0.tfrecord rec1.tfrecord rec2.tfrecord


## Make a simple, unshuffled dataset

A simple dataset, which will read record0 then record1

In [7]:
files = tf.data.TFRecordDataset(paths)
dataset = files.map(_parse)

dataset = iter(dataset)
for _ in range(12):
    print(next(dataset))

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
tf.Tensor([0], shape=(1,), dtype=int64)
tf.Tensor([1], shape=(1,), dtype=int64)
tf.Tensor([2], shape=(1,), dtype=int64)
tf.Tensor([3], shape=(1,), dtype=int64)
tf.Tensor([10], shape=(1,), dtype=int64)
tf.Tensor([11], shape=(1,), dtype=int64)
tf.Tensor([12], shape=(1,), dtype=int64)
tf.Tensor([13], shape=(1,), dtype=int64)
tf.Tensor([100], shape=(1,), dtype=int64)
tf.Tensor([101], shape=(1,), dtype=int64)
tf.Tensor([102], shape=(1,), dtype=int64)
tf.Tensor([103], shape=(1,), dtype=int64)


## Unshuffled & batched

In [10]:
files = tf.data.TFRecordDataset(paths)
dataset = files.map(_parse)
dataset = dataset.batch(4)/

dataset = iter(dataset)
for _ in range(6):
    print(next(dataset))

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
tf.Tensor(
[[0]
 [1]
 [2]
 [3]], shape=(4, 1), dtype=int64)
tf.Tensor(
[[10]
 [11]
 [12]
 [13]], shape=(4, 1), dtype=int64)
tf.Tensor(
[[100]
 [101]
 [102]
 [103]], shape=(4, 1), dtype=int64)


StopIteration: 

## Shuffled dataset

In [37]:
files = tf.data.TFRecordDataset(paths)
dataset = files.map(_parse).shuffle(10)

dataset = iter(dataset)
for _ in range(12):
    print(next(dataset))

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
tf.Tensor([12], shape=(1,), dtype=int64)
tf.Tensor([13], shape=(1,), dtype=int64)
tf.Tensor([0], shape=(1,), dtype=int64)
tf.Tensor([100], shape=(1,), dtype=int64)
tf.Tensor([103], shape=(1,), dtype=int64)
tf.Tensor([3], shape=(1,), dtype=int64)
tf.Tensor([11], shape=(1,), dtype=int64)
tf.Tensor([1], shape=(1,), dtype=int64)
tf.Tensor([2], shape=(1,), dtype=int64)
tf.Tensor([102], shape=(1,), dtype=int64)
tf.Tensor([10], shape=(1,), dtype=int64)
tf.Tensor([101], shape=(1,), dtype=int64)


## Batch, no shuffle

In [54]:
files = tf.data.TFRecordDataset(paths)
dataset = files.map(_parse).batch(2)

dataset = iter(dataset)
for _ in range(4):
    print(next(dataset).numpy())

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
[[0]
 [1]]
[[2]
 [3]]
[[10]
 [11]]
[[12]
 [13]]


## Batch and Shuffle

Because of a small shuffle buffer, we get no shuffling

In [65]:
files = tf.data.TFRecordDataset(paths)
dataset = files.map(_parse).shuffle(1).batch(2)

dataset = iter(dataset)
for _ in range(4):
    print(next(dataset).numpy())

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
[[0]
 [1]]
[[2]
 [3]]
[[10]
 [11]]
[[12]
 [13]]


In [73]:
files = tf.data.TFRecordDataset(paths)
dataset = files.map(_parse).shuffle(10).batch(2)

dataset = iter(dataset)
for _ in range(4):
    print(next(dataset).numpy())

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
[[11]
 [13]]
[[3]
 [0]]
[[2]
 [1]]
[[100]
 [ 10]]


In [87]:
files = tf.data.Dataset.from_tensor_slices(paths)

batch_size = 2

dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4, block_length=16)
dataset = dataset.shuffle(20)
dataset = dataset.apply(tf.data.experimental.map_and_batch(map_func=_parse, batch_size=batch_size,         drop_remainder=True,
        num_parallel_batches=4))
dataset = dataset.repeat().prefetch(10)

dataset = iter(dataset)
for _ in range(10000):
    print(next(dataset).numpy())

{'feature': <tf.Tensor 'ParseSingleExample/ParseSingleExample:0' shape=(1,) dtype=int64>}
[[3]
 [0]]
[[102]
 [  1]]
[[10]
 [11]]
[[101]
 [ 12]]
[[  2]
 [103]]
[[ 13]
 [100]]
[[103]
 [102]]
[[10]
 [ 2]]
[[11]
 [ 0]]
[[ 13]
 [100]]
[[101]
 [  1]]
[[12]
 [ 3]]
[[0]
 [1]]
[[13]
 [12]]
[[100]
 [ 11]]
[[10]
 [ 2]]
[[101]
 [103]]
[[  3]
 [102]]
[[100]
 [102]]
[[11]
 [13]]
[[ 1]
 [10]]
[[103]
 [ 12]]
[[3]
 [2]]
[[  0]
 [101]]
[[ 0]
 [12]]
[[103]
 [  2]]
[[ 10]
 [100]]
[[1]
 [3]]
[[102]
 [ 11]]
[[101]
 [ 13]]
[[2]
 [1]]
[[102]
 [  0]]
[[12]
 [11]]
[[  3]
 [103]]
[[10]
 [13]]
[[101]
 [100]]
[[ 2]
 [13]]
[[12]
 [11]]
[[ 3]
 [10]]
[[100]
 [102]]
[[103]
 [  1]]
[[  0]
 [101]]
[[101]
 [102]]
[[3]
 [2]]
[[100]
 [ 10]]
[[13]
 [11]]
[[  0]
 [103]]
[[12]
 [ 1]]
[[ 0]
 [10]]
[[11]
 [12]]
[[100]
 [101]]
[[13]
 [ 3]]
[[103]
 [102]]
[[2]
 [1]]
[[103]
 [ 12]]
[[ 11]
 [101]]
[[ 1]
 [13]]
[[3]
 [2]]
[[ 0]
 [10]]
[[102]
 [100]]
[[  1]
 [102]]
[[101]
 [ 12]]
[[ 3]
 [10]]
[[2]
 [0]]
[[100]
 [ 11]]
[[103]
 [ 13]]
