In [None]:
import numpy as np
import tensorflow as tf
import json

# Fraction of GPU memory TensorFlow may claim for this process (presumably kept
# tiny so the notebook can share the device with other jobs). Tune it here
# instead of editing the GPUOptions call.
GPU_MEMORY_FRACTION = 0.02

# Compact, fixed-point numpy printing for inspecting arrays.
np.set_printoptions(suppress=True, precision=4, linewidth=250)
opts = tf.GPUOptions(per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION)
conf = tf.ConfigProto(gpu_options=opts)
# TF 1.x API: switch to eager execution using the GPU options above.
tf.enable_eager_execution(config=conf)

In [None]:
# Synthetic random trajectories used to exercise the TFRecord writer below.
# Per example: a sequence of 3-channel image frames, 3-D states and 4-D actions.
# NOTE: float64 images are 256*30*64*64*3*8 bytes (~755 MB) in memory.
SEED = 0
examples = 256
time_steps = 30
image_size = 64
# Seeded generator so the written test.tfrecord is reproducible across runs.
rng = np.random.RandomState(SEED)
test_images = rng.randn(examples, time_steps, image_size, image_size, 3)
test_states = rng.randn(examples, time_steps, 3)
test_actions = rng.randn(examples, time_steps, 4)

In [None]:
# Slice along axis 0 so each dataset element is one whole trajectory:
# (images[i], states[i], actions[i]).
test_dataset = tf.data.Dataset.from_tensor_slices(
    tensors=(test_images, test_states, test_actions)
)

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` (bytes_list)."""
    single_item_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=single_item_list)

def _float_feature(value):
    """Wrap a single float value in a ``tf.train.Feature`` (float_list)."""
    single_item_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=single_item_list)

def _float_vector_feature(values):
    """Wrap a 1-D sequence of numbers in a ``tf.train.Feature`` (float_list).

    Args:
        values: a 1-D numpy array, or anything ``np.asarray`` turns into one
            (list, tuple, eager tensor).

    Returns:
        tf.train.Feature whose ``float_list`` holds ``values``. FloatList stores
        32-bit floats, so float64 input is kept at reduced precision.

    Raises:
        ValueError: if ``values`` is not 1-dimensional.
    """
    values = np.asarray(values)
    # Explicit check instead of ``assert`` so it is not stripped by ``python -O``.
    if values.ndim != 1:
        raise ValueError(
            'expected a 1-D array, got shape {}'.format(values.shape))
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def serialize_example_pyfunction(images_traj, states_traj, actions_traj):
    """Serialize one trajectory into a ``tf.train.Example`` byte string.

    Meant to run inside ``tf.py_function`` (see ``tf_serialize_example``), so
    the arguments usually arrive as eager tensors; numpy arrays work too.

    Args:
        images_traj: per-step images indexed by time along axis 0. Each step
            is stored via ``tf.io.serialize_tensor`` under
            ``'{t}/image_aux1/encoded'``.
        states_traj: per-step 1-D state vectors, stored as float lists under
            ``'{t}/endeffector_pos'``.
        actions_traj: per-step 1-D action vectors, stored as float lists under
            ``'{t}/action'``.

    Returns:
        bytes: the serialized ``tf.train.Example``.

    Raises:
        ValueError: if the three inputs disagree on the number of time steps.
    """
    # Convert once up front: indexing numpy arrays is cheaper than slicing an
    # eager tensor (one op) per step, and it also admits plain numpy input.
    images = np.asarray(images_traj)
    states = np.asarray(states_traj)
    actions = np.asarray(actions_traj)

    num_steps = images.shape[0]
    # Previously extra state/action steps were silently dropped; fail loudly.
    if states.shape[0] != num_steps or actions.shape[0] != num_steps:
        raise ValueError(
            'trajectory length mismatch: images={}, states={}, actions={}'
            .format(num_steps, states.shape[0], actions.shape[0]))

    feature = {}
    for t in range(num_steps):
        image_bytes = tf.io.serialize_tensor(images[t]).numpy()
        feature['{}/image_aux1/encoded'.format(t)] = _bytes_feature(image_bytes)
        feature['{}/endeffector_pos'.format(t)] = _float_vector_feature(states[t])
        feature['{}/action'.format(t)] = _float_vector_feature(actions[t])

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

def tf_serialize_example(f0,f1,f2):
    """Dataset-mappable wrapper around ``serialize_example_pyfunction``.

    Runs the Python serializer through ``tf.py_function`` and returns its
    result as a scalar ``tf.string`` tensor.
    """
    serialized = tf.py_function(
        func=serialize_example_pyfunction,
        inp=[f0, f1, f2],
        Tout=tf.string,
    )
    # py_function output has no static shape; pin it to a scalar string so
    # consumers that require scalar elements accept it.
    return tf.reshape(serialized, [])

# Each element becomes one serialized tf.train.Example (scalar tf.string).
serialized_test_dataset = test_dataset.map(map_func=tf_serialize_example)

In [None]:
filename = 'test.tfrecord'
# Write every serialized Example from the dataset as one record of an
# uncompressed TFRecord file (compression_type=None is the default).
writer = tf.data.experimental.TFRecordWriter(filename, compression_type=None)
writer.write(serialized_test_dataset)

In [None]:
from google.protobuf.json_format import MessageToDict, ParseDict

In [None]:
# Reference: list the feature keys of the first record in the BAIR test shard.
filename = "./data/bair/test/traj_0_to_255.tfrecords"
example = next(tf.python_io.tf_record_iterator(filename), None)
if example is None:
    raise ValueError("no records found in {}".format(filename))
# Read keys straight off the parsed proto; MessageToDict would base64-encode
# every bytes field (whole encoded images) just to list the keys.
example_proto = tf.train.Example.FromString(example)
# Sorted so the key list can be compared against the generated file's.
print(sorted(example_proto.features.feature.keys()))

In [None]:
# Generated file: list the feature keys of its first record for comparison.
filename = "./test.tfrecord"
example = next(tf.python_io.tf_record_iterator(filename), None)
if example is None:
    raise ValueError("no records found in {}".format(filename))
# Read keys straight off the parsed proto; MessageToDict would base64-encode
# every bytes field (whole serialized image tensors) just to list the keys.
example_proto = tf.train.Example.FromString(example)
# Sorted so the key list can be compared against the BAIR shard's.
print(sorted(example_proto.features.feature.keys()))

In [None]:
ls data