In [None]:
!pip install lmdb torchvision

In [2]:
# Copy the `droid_100` directory from the public GCS bucket into the current
# working directory (`-m`: parallel transfer, `-r`: recursive).
!gsutil -m cp -r gs://gresearch/robotics/droid_100 .

Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00016-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00017-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00018-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00019-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00020-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00021-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00022-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00023-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00024-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2d2_faceblur-train.tfrecord-00025-of-00031...
Copying gs://gresearch/robotics/droid_100/1.0.0/r2

In [3]:
import tensorflow_datasets as tfds

# Build a TFDS builder from the files copied to ./droid_100/1.0.0 and load
# its "train" split.
tf_builder = tfds.builder_from_directory(builder_dir="./droid_100/1.0.0")
tf_dataset = tf_builder.as_dataset(split="train")

Filter out failure trajectories and flatten the dataset

In [4]:
import tensorflow as tf

def filter_success(trajectory: dict):
  """Tell whether an episode's file path contains a ``/success/`` component.

  Args:
    trajectory: one episode of the dataset; only
      ``trajectory['episode_metadata']['file_path']`` is read.

  Returns:
    The result of ``tf.strings.regex_full_match`` on the file path, i.e. a
    boolean that is true when the whole path matches ``.*/success/.*``.
  """
  file_path = trajectory['episode_metadata']['file_path']
  return tf.strings.regex_full_match(file_path, ".*/success/.*")
# Keep only the episodes for which filter_success is true.
tf_dataset = tf_dataset.filter(filter_success)
# After this line the dataset yields individual steps instead of episodes.
tf_dataset = tf_dataset.flat_map(lambda episode: episode['steps']) # Extract steps from episodes

Filter out the IS_LAST steps: in RLDS, such a step contains only the final observation, so its action and reward are meaningless.

In [5]:
def _is_not_last(step):
  if step['is_last']:
    return False
  return True

# Keep only the steps for which _is_not_last is true.
tf_dataset = tf_dataset.filter(_is_not_last)

Save the data to LMDB

In [6]:
import os
import json
from pathlib import Path
import torch
from torchvision.io import encode_jpeg
import lmdb
from pickle import dumps

# Output directory: it receives one LMDB sub-directory per split plus split.json.
lmdb_dir = Path("./droid_lmdb")
# Number of episodes stored in each LMDB split.
single_file_episode_num = 30 # to be changed!

# Create the output directory (and any missing parents) in one race-free call;
# an already existing directory is accepted.
lmdb_dir.mkdir(parents=True, exist_ok=True)

# State shared with the writer loop below.
split_id = -1     # index of the split currently being written
# One boundary per split, filled in by the writer loop: when a new split is
# started the loop appends the index of the step that opens it, and at the end
# of the data it appends the index of the very last step.
max_steps = []
cur_episode = -1  # index of the episode currently being written
cur_step = -1     # index of the step currently being written

def to_numpy(mydict):
  """Replace every leaf of a (possibly nested) dict by its ``.numpy()`` value.

  Nested dicts are processed recursively. The conversion is done in place and
  the very same dict object is returned.
  """
  for key, value in mydict.items():
    # Only values are reassigned, so iterating over items() stays valid.
    mydict[key] = to_numpy(value) if isinstance(value, dict) else value.numpy()
  return mydict

# Write every step of `tf_dataset` into LMDB, starting a new LMDB split every
# `single_file_episode_num` episodes. Keys written (all values are pickled):
#   lang_{episode}      list of the non-empty language instructions (bytes)
#   cur_episode_{step}  episode index the step belongs to
#   others_{step}       the step converted by to_numpy, without images/instructions
#   img_{step}          dict of JPEG-encoded camera images
IMAGE_KEYS = ('exterior_image_1_left', 'exterior_image_2_left', 'wrist_image_left')
LANG_KEYS = ('language_instruction', 'language_instruction_2', 'language_instruction_3')

env = None  # LMDB environment of the split currently being written
txn = None  # write transaction opened on `env`


def _close_split(boundary_step):
  """Record a split boundary, rewrite split.json and close the open split."""
  max_steps.append(boundary_step)
  # `with` guarantees that split.json is flushed and closed.
  with open(lmdb_dir/'split.json', 'w') as split_file:
    json.dump(max_steps, split_file)
  txn.commit()
  env.close()


# Iterating the dataset directly (instead of rebinding `tf_dataset` to an
# iterator) keeps `tf_dataset` usable if this cell is run again.
for cur_step, step_data in enumerate(tf_dataset):
  if bool(step_data['is_first']) or cur_episode == -1:
    cur_episode += 1
    print('Saving episode ', cur_episode)
    if cur_episode % single_file_episode_num == 0:
      split_id += 1
      if cur_episode != 0:
        # The recorded boundary is the index of the step that opens the new
        # split, i.e. one past the last step of the split being closed.
        _close_split(cur_step)
      print('Create split ', split_id, ' in episode ', cur_episode)
      env = lmdb.open(str(lmdb_dir/str(split_id)), map_size=int(3e12), readonly=False, lock=False) # maximum size of memory map is 3TB
      txn = env.begin(write=True)
    # Store the non-empty instructions once per episode.
    lang = []
    for lang_key in LANG_KEYS:
      instruction = step_data[lang_key].numpy()
      if instruction != b'':
        lang.append(instruction)
    txn.put(f'lang_{cur_episode}'.encode(), dumps(lang))
  observation = step_data['observation']
  if len(observation['exterior_image_1_left'].shape) != 3:
    # NOTE: the step is skipped but `cur_step` still advances, so no
    # cur_episode_/others_/img_ record exists for this step index.
    continue # may be a problem!!!
  # encode_jpeg is given the image with its last axis moved to the front.
  img = {
    image_key: encode_jpeg(torch.from_numpy(observation[image_key].numpy()).permute(2,0,1))
    for image_key in IMAGE_KEYS
  }
  # Images and instructions are stored separately, so drop them before the
  # remainder of the step is converted (to_numpy works in place).
  for image_key in IMAGE_KEYS:
    del observation[image_key]
  for lang_key in LANG_KEYS:
    del step_data[lang_key]
  others = to_numpy(step_data)
  txn.put(f'cur_episode_{cur_step}'.encode(), dumps(cur_episode))
  txn.put(f'others_{cur_step}'.encode(), dumps(others))
  txn.put(f'img_{cur_step}'.encode(), dumps(img))

if txn is None:
  # No split was ever opened, so there is nothing to commit or close.
  raise RuntimeError('tf_dataset yielded no steps; nothing was written to LMDB')
# Final boundary: the index of the last step of the dataset.
_close_split(cur_step)

Saving episode  0
Create split  0  in episode  0
Saving episode  1
Saving episode  2
Saving episode  3
Saving episode  4
Saving episode  5
Saving episode  6
Saving episode  7
Saving episode  8
Saving episode  9
Saving episode  10
Saving episode  11
Saving episode  12
Saving episode  13
Saving episode  14
Saving episode  15
Saving episode  16
Saving episode  17
Saving episode  18
Saving episode  19
Saving episode  20
Saving episode  21
Saving episode  22
Saving episode  23
Saving episode  24
Saving episode  25
Saving episode  26
Saving episode  27
Saving episode  28
Saving episode  29
Saving episode  30
Create split  1  in episode  30
Saving episode  31
Saving episode  32
Saving episode  33
Saving episode  34
Saving episode  35
Saving episode  36
Saving episode  37
Saving episode  38
Saving episode  39
Saving episode  40
Saving episode  41
Saving episode  42
Saving episode  43
Saving episode  44
Saving episode  45
Saving episode  46
Saving episode  47
Saving episode  48
Saving episode  

In [63]:
import shutil

# Destructive reset of the generated database. It is disabled by default so
# that running the notebook top to bottom does not delete ./droid_lmdb right
# before the reader cells below use it. Set the flag to True only to discard
# the database and regenerate it.
RESET_LMDB = False
if RESET_LMDB:
    shutil.rmtree("./droid_lmdb", ignore_errors=True)

Define the DROID LMDB reader class

In [7]:
import json
from bisect import bisect_right
from pathlib import Path
from pickle import dumps, loads

import lmdb
from torchvision.io import decode_jpeg
from tqdm import tqdm

class DroidReader():
    """Random-access reader for a directory of LMDB splits and its ``split.json``.

    Layout read from ``lmdb_dir``:

    * ``split.json`` -- a JSON list with one boundary per split. A step index
      is looked up in the first split whose boundary is strictly greater than
      the index, or in the last split when there is none. The dataset length
      is the final boundary plus one.
    * ``<split_id>/`` -- one LMDB environment per split, holding pickled
      records under the keys ``cur_episode_{step}``, ``img_{step}``,
      ``others_{step}``, ``lang_{episode}`` and, once
      :meth:`write_lang_token_id` has run, ``lang_token_{episode}``.

    Records are decoded with ``pickle.loads``, which can execute arbitrary
    code: only open databases you trust.
    """

    def __init__(self, lmdb_dir):
        """Read ``split.json``; the LMDB environments are opened lazily.

        Args:
            lmdb_dir: directory containing ``split.json`` and the splits,
                given as ``str`` or ``Path``.
        """
        if isinstance(lmdb_dir, str):
            lmdb_dir = Path(lmdb_dir)
        self.lmdb_dir = lmdb_dir
        self.envs = []
        self.txns = []
        self._writable = False  # whether the open transactions accept put()
        with open(lmdb_dir/'split.json', 'r') as split_file:
            self.max_steps = json.load(split_file)
        # get_split_id() sends idx == max_steps[k] to split k + 1, so the first
        # step of split k + 1 is max_steps[k] itself.
        self.min_steps = [0] + self.max_steps[:-1]
        self.dataset_len = self.max_steps[-1] + 1

    def __len__(self):
        """Return the final boundary of ``split.json`` plus one."""
        return self.dataset_len

    def open_lmdb(self, write=False):
        """Open every split and begin one transaction per split.

        Args:
            write: open the environments and transactions for writing.
        """
        if self.envs:
            # Never stack a second set of environments on top of an open one.
            self.close_lmdb()
        for split_id in range(len(self.max_steps)):
            split_path = self.lmdb_dir / str(split_id)
            env = lmdb.open(str(split_path), readonly=not write, create=False, lock=False, map_size=int(3e12))
            self.envs.append(env)
            self.txns.append(env.begin(write=write))
        self._writable = write

    def close_lmdb(self):
        """Commit every transaction, close every environment and forget them."""
        for txn in self.txns:
            txn.commit()
        for env in self.envs:
            env.close()
        self.envs = []
        self.txns = []
        self._writable = False

    def get_split_id(self, idx, array):
        """Return the index of the split that ``idx`` is looked up in.

        Args:
            idx: step index.
            array: ascending split boundaries (callers pass ``self.max_steps``).

        Returns:
            The first position ``k`` with ``array[k] > idx``; the last
            position when no boundary exceeds ``idx``; ``0`` for an empty
            ``array``.
        """
        # The last boundary is never compared: everything at or past the
        # previous boundaries belongs to the last split.
        return bisect_right(array, idx, 0, max(len(array) - 1, 0))

    def _load(self, key, idx):
        """Unpickle the record ``key`` from the split that owns step ``idx``.

        Raises:
            KeyError: if the split holds no record under ``key``.
        """
        if self.envs == []:
            self.open_lmdb()
        split_id = self.get_split_id(idx, self.max_steps)
        raw = self.txns[split_id].get(key.encode())
        if raw is None:
            raise KeyError(f'{key!r} not found in split {split_id} of {self.lmdb_dir}')
        return loads(raw)

    def get_episode(self, idx):
        """Return the episode index stored for step ``idx``."""
        return self._load(f'cur_episode_{idx}', idx)

    def get_img(self, idx):
        """Return the image dict of step ``idx`` with its three JPEGs decoded."""
        img = self._load(f'img_{idx}', idx)
        for image_key in ('exterior_image_1_left', 'exterior_image_2_left', 'wrist_image_left'):
            img[image_key] = decode_jpeg(img[image_key])
        return img

    def get_langs(self, idx):
        """Return the language instructions of the episode step ``idx`` belongs to."""
        return self._load(f'lang_{self.get_episode(idx)}', idx)

    def get_lang_tokens(self, idx):
        """Return the token ids stored for the episode step ``idx`` belongs to.

        The record only exists after :meth:`write_lang_token_id` has run.
        """
        return self._load(f'lang_token_{self.get_episode(idx)}', idx)

    def get_others(self, idx):
        """Return the ``others`` record stored for step ``idx``."""
        return self._load(f'others_{idx}', idx)

    def write_lang_token_id(self, tokenizer):
        """Tokenize each episode's instructions and store them as ``lang_token_{episode}``.

        Every step index of the dataset is visited, including the last one.
        Step indices without a ``cur_episode_`` record are skipped. The
        databases are committed and closed when the method returns.

        Args:
            tokenizer: callable invoked as ``tokenizer(text, return_tensors="pt", ...)``
                whose result provides ``['input_ids'][0].numpy()``.
        """
        if self.envs == [] or not getattr(self, '_writable', False):
            # A reader that was opened read-only by an earlier get_*() call has
            # to be reopened before put() can succeed.
            self.open_lmdb(write=True)
        last_episode_id = -1
        for idx in tqdm(range(self.dataset_len), desc="Write lang token id"):
            split_id = self.get_split_id(idx, self.max_steps)
            raw_episode = self.txns[split_id].get(f'cur_episode_{idx}'.encode())
            if raw_episode is None:
                continue  # no record was stored for this step index
            episode_id = loads(raw_episode)
            if episode_id == last_episode_id:
                continue  # this episode has already been tokenized
            lang_tokens = []
            for lang in self.get_langs(idx):
                lang_token = tokenizer(
                    lang.decode('utf-8'),
                    return_tensors="pt",
                    padding=False,
                    max_length=None,
                    truncation=None,
                    return_token_type_ids=False,
                )['input_ids'][0].numpy()
                lang_tokens.append(lang_token)
            self.txns[split_id].put(
                f'lang_token_{episode_id}'.encode(),
                dumps(lang_tokens),
            )
            last_episode_id = episode_id
        self.close_lmdb()

In [8]:
from transformers import AutoProcessor

# Set the tokenizers environment variable before the tokenizer is created.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

reader = DroidReader(lmdb_dir)
# NOTE: trust_remote_code=True runs Python code downloaded from the model
# repository; pin a revision if that code must not change between runs.
tokenizer = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True).tokenizer
# Store the token ids of every episode's instructions in the LMDB splits.
reader.write_lang_token_id(tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

processing_florence2.py:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Florence-2-base:
- processing_florence2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer_config.json:   0%|          | 0.00/34.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.10M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.43k [00:00<?, ?B/s]

configuration_florence2.py:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Florence-2-base:
- configuration_florence2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Write lang token id: 100%|██████████| 28542/28542 [00:00<00:00, 259763.49it/s]


In [9]:
# Smoke test: read one record of every kind back through the reader.
# NOTE(review): the step indices 28541 and 10 are hard-coded; confirm they are
# still valid if the dataset or the split size changes.
length = len(reader)
episode = reader.get_episode(28541)
img = reader.get_img(28541)
others = reader.get_others(28541)
langs = reader.get_langs(10)
lang_tokens = reader.get_lang_tokens(10)

In [10]:
# Language instructions of the episode that contains step 10.
langs

[b'Put the marker in the pot',
 b'Get the marker from the table and put it inside the silver pot',
 b'Put the marker inside the silver pot']

In [11]:
# Token ids stored for the same episode, one array per instruction.
lang_tokens

[array([    0, 35123,     5, 17540,    11,     5,  4728,     2]),
 array([    0, 14181,     5, 17540,    31,     5,  2103,     8,   342,
           24,  1025,     5,  4334,  4728,     2]),
 array([    0, 35123,     5, 17540,  1025,     5,  4334,  4728,     2])]

In [13]:
# Non-image, non-language fields stored for step 28541.
others

{'action': array([ 0.46315864, -0.06332067,  0.48752004, -2.5370121 , -0.30514622,
        -1.15552211,  0.        ]),
 'action_dict': {'cartesian_position': array([ 0.46315864, -0.06332067,  0.48752004, -2.5370121 , -0.30514622,
         -1.15552211]),
  'cartesian_velocity': array([0., 0., 0., 0., 0., 0.]),
  'gripper_position': array([0.]),
  'gripper_velocity': array([0.]),
  'joint_position': array([ 0.21305765, -0.378001  , -0.3413465 , -2.50903964, -0.08747865,
          2.81095266,  0.98531276]),
  'joint_velocity': array([-0.73450541,  0.14465635,  0.75720179, -0.03155373,  0.82139796,
          0.07027046, -0.5974372 ])},
 'discount': 1.0,
 'is_first': False,
 'is_last': False,
 'is_terminal': False,
 'observation': {'cartesian_position': array([ 0.46371341, -0.06080873,  0.48677155, -2.541821  , -0.30842647,
         -1.1704489 ]),
  'gripper_position': array([0.]),
  'joint_position': array([ 0.33121359, -0.39866462, -0.4616833 , -2.50273585, -0.25965858,
          2.797436

In [14]:
# Pack the generated droid_lmdb directory (recursively) into droid_lmdb.zip.
!zip -r droid_lmdb.zip droid_lmdb

  adding: droid_lmdb/ (stored 0%)
  adding: droid_lmdb/1/ (stored 0%)
  adding: droid_lmdb/1/data.mdb (deflated 18%)
  adding: droid_lmdb/0/ (stored 0%)
  adding: droid_lmdb/0/data.mdb (deflated 17%)
  adding: droid_lmdb/2/ (stored 0%)
  adding: droid_lmdb/2/data.mdb (deflated 18%)
  adding: droid_lmdb/split.json (stored 0%)
