In [54]:
import tensorflow as tf
import numpy as np
from PIL import Image
import os
import glob

In [45]:
# Relative path to the 'data1' directory.
# NOTE(review): `npz_path` is not referenced by any other cell visible here — confirm it is still needed.
npz_path = '../data1'

In [46]:
def decode_shoulder_img(image_data, image_hw=256):
    """Decode a JPEG-encoded frame into a square 3-channel image tensor.

    Args:
        image_data: JPEG bytes (scalar string tensor) of a single frame.
        image_hw: Side length, in pixels, of the square image. The decoded
            frame must contain exactly ``image_hw * image_hw * 3`` values,
            otherwise the reshape fails.

    Returns:
        Tensor of static shape ``[image_hw, image_hw, 3]``.
    """
    # Static shape is stated explicitly because TPUs need it.
    target_shape = [image_hw, image_hw, 3]
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    return tf.reshape(decoded, target_shape)

In [47]:
def decode_gripper_img(image_data, image_hw=64):
    """Decode a JPEG-encoded gripper frame into a square 3-channel image tensor.

    Mirrors ``decode_shoulder_img``: the side length is a parameter rather than
    a hard-coded constant, with a default of 64 so existing single-argument
    calls behave exactly as before.

    Args:
        image_data: JPEG bytes (scalar string tensor) of a single frame.
        image_hw: Side length, in pixels, of the square image. The decoded
            frame must contain exactly ``image_hw * image_hw * 3`` values,
            otherwise the reshape fails.

    Returns:
        Tensor of static shape ``[image_hw, image_hw, 3]``.
    """
    image = tf.image.decode_jpeg(image_data, channels=3)
    # explicit size needed for TPU
    image = tf.reshape(image, [image_hw, image_hw, 3])
    return image

In [48]:
# Per-simulator sizes, keyed by simulator name.
# read_tfrecord reads the vector lengths ('obs', 'acts', 'achieved_goals' and
# their '*_extra_info' variants) and 'shoulder_img_hw' from this table.
# NOTE(review): 'hz' is not read by any cell visible here — presumably the
# control/recording rate; confirm against its consumer.
dimensions = {
    'Pybullet': dict(
        obs=18,
        obs_extra_info=18,
        acts=7,
        achieved_goals=11,
        achieved_goals_extra_info=11,
        shoulder_img_hw=200,
        hz=25,
    ),
}

def read_tfrecord(include_imgs=False, include_imgs2=False, include_gripper_imgs=False, sim='Pybullet'):
    """Build a parser that turns one serialized example into a dict of tensors.

    Args:
        include_imgs: If truthy, also parse and decode the 'img' feature. This
            flag additionally selects the '*_extra_info' sizes from
            ``dimensions[sim]`` for 'obs' and 'achieved_goals'.
        include_imgs2: If truthy, also parse and decode the 'img2' feature.
        include_gripper_imgs: If truthy, also parse and decode 'gripper_img'.
        sim: Key into the module-level ``dimensions`` table.

    Returns:
        A function ``example -> dict`` suitable for ``tf.data.Dataset.map``.
        The dict always holds 'obs', 'achieved_goals', 'acts',
        'sequence_index' and 'sequence_id', plus whichever image keys were
        requested.
    """
    # (feature key, requested?) for the optional JPEG features, in output order.
    optional_images = (
        ('img', include_imgs),
        ('img2', include_imgs2),
        ('gripper_img', include_gripper_imgs),
    )

    def read_tfrecord_helper(example):
        """Parse a single serialized example according to the enclosing flags."""
        # Serialized tensors are stored as byte strings (tf.string).
        feature_spec = {
            key: tf.io.FixedLenFeature([], tf.string)
            for key in ('obs', 'acts', 'achieved_goals')
        }
        feature_spec['sequence_index'] = tf.io.FixedLenFeature([], tf.int64)
        feature_spec['sequence_id'] = tf.io.FixedLenFeature([], tf.int64)
        for key, wanted in optional_images:
            if wanted:
                # JPEG payloads are byte strings as well.
                feature_spec[key] = tf.io.FixedLenFeature([], tf.string)

        parsed = tf.io.parse_single_example(example, feature_spec)
        dims = dimensions[sim]

        def float_vector(key, size_key):
            """Deserialize parsed[key] as float32 and pin its length to dims[size_key]."""
            return tf.ensure_shape(tf.io.parse_tensor(parsed[key], tf.float32), (dims[size_key],))

        if include_imgs:
            obs_size_key, goal_size_key = 'obs_extra_info', 'achieved_goals_extra_info'
        else:
            obs_size_key, goal_size_key = 'obs', 'achieved_goals'

        output = {
            'obs': float_vector('obs', obs_size_key),
            'achieved_goals': float_vector('achieved_goals', goal_size_key),
            'acts': float_vector('acts', 'acts'),
            'sequence_index': tf.cast(parsed['sequence_index'], tf.int32),
            # Deliberately int32, although the value is serialized as int64.
            'sequence_id': tf.cast(parsed['sequence_id'], tf.int32),
        }

        # Both shoulder views share the same side length.
        for key in ('img', 'img2'):
            if key in feature_spec:
                output[key] = decode_shoulder_img(parsed[key], dims['shoulder_img_hw'])
        if 'gripper_img' in feature_spec:
            output['gripper_img'] = decode_gripper_img(parsed['gripper_img'])

        return output

    return read_tfrecord_helper

In [49]:
def extract_tfrecords(paths, include_imgs=False, include_imgs2=False, include_gripper_imgs=False, sim='Pybullet',
                      ordered=True, num_workers=1):
    """Create a dataset of parsed records from one or more TFRecord files.

    Args:
        paths: TFRecord file path(s), passed straight to ``tf.data.TFRecordDataset``.
        include_imgs: Forwarded to ``read_tfrecord``.
        include_imgs2: Forwarded to ``read_tfrecord``.
        include_gripper_imgs: Forwarded to ``read_tfrecord``.
        sim: Forwarded to ``read_tfrecord``.
        ordered: Value assigned to the dataset's determinism option; record
            order matters for this data, so it defaults to True.
        num_workers: ``num_parallel_calls`` used when mapping the parser.

    Returns:
        A ``tf.data.Dataset`` yielding the dicts produced by ``read_tfrecord``.
    """
    options = tf.data.Options()
    options.experimental_deterministic = ordered

    parse_fn = read_tfrecord(include_imgs, include_imgs2, include_gripper_imgs, sim)

    # NOTE(review): files are read with num_parallel_reads=1 — presumably this is
    # the "must be 1" needed to keep order while streaming from GCS; confirm.
    return (
        tf.data.TFRecordDataset(paths, num_parallel_reads=1)
        .with_options(options)
        .map(parse_fn, num_parallel_calls=num_workers)
    )


In [73]:
# Directory holding the recorded data; the TFRecord files live in its 'tf_records' subfolder.
path = '../data/UR5'
# glob.glob returns matches in arbitrary order. The position of a file in this
# list is used as its output directory name, so sort to make that mapping
# reproducible between runs and machines.
files = sorted(glob.glob(os.path.join(path, 'tf_records', '*.tfrecords')))


In [None]:
# Export every TFRecord file as <path2>/<index>/imgs/image<j>.png plus
# <path2>/<index>/data.npy, where <index> is the file's position in `files`.
path2 = '../data1/UR5'
# Non-image fields copied from every record into data.npy.
record_keys = ('obs', 'acts', 'achieved_goals', 'sequence_index', 'sequence_id')
for i, file in enumerate(files):
    dataset = extract_tfrecords(file, include_imgs=True)
    path_traj = os.path.join(path2, str(i))
    path_traj_img = os.path.join(path_traj, 'imgs')
    # makedirs also creates path2 and path_traj when missing, and exist_ok
    # makes re-running the cell safe.
    os.makedirs(path_traj_img, exist_ok=True)

    # Walk the dataset once: each pass re-reads the file and re-decodes every
    # JPEG, so the frames and the numeric fields are collected together.
    columns = {key: [] for key in record_keys}
    for j, data in enumerate(dataset):
        img_np = data['img'].numpy()
        Image.fromarray(img_np).save(os.path.join(path_traj_img, 'image' + str(j) + '.png'))
        for key in record_keys:
            columns[key].append(data[key].numpy())

    obs, acts, achieved_goals, sequence_index, sequence_id = (
        np.array(columns[key]) for key in record_keys
    )
    d = {'obs': obs, 'acts': acts, 'achieved_goals': achieved_goals, 'sequence_index': sequence_index, 'sequence_id': sequence_id}
    # np.save stores a dict as a pickled object array; read it back with
    # np.load(..., allow_pickle=True).item().
    np.save(os.path.join(path_traj, 'data.npy'), d)
