In [29]:
from lerobot.datasets.lerobot_dataset import LeRobotDataset 
from collections import defaultdict
import torch
from collections import Counter
import os

In [30]:
# Show the working directory: relative paths used later (OUT_PATH = "../data/...") resolve against it.
os.getcwd()

'/home/santari/Projects/pred2control/notebooks'

In [31]:
# Output file for the per-episode action tensors (relative to the notebook's working directory).
OUT_PATH = "../data/pred2control_target.pt"
# LeRobot dataset repository to load.
repo_id = "aleksantari/pred2control_target"

dataset = LeRobotDataset(repo_id=repo_id)


In [18]:
# Display the dataset summary: repo id, number of episodes/samples, and feature names.
dataset

LeRobotDataset({
    Repository ID: 'aleksantari/pred2control_target',
    Number of selected episodes: '50',
    Number of selected samples: '15000',
    Features: '['action', 'observation.state', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index']',
})',

In [19]:
# Total number of frames (samples) across all selected episodes.
len(dataset)

15000

In [20]:
# Confirm the concrete class of the loaded dataset object.
type(dataset)

lerobot.datasets.lerobot_dataset.LeRobotDataset

In [21]:

# Per-feature schema: dtype, shape, and (for action / observation.state) the six joint names.
dataset.features

{'action': {'dtype': 'float32',
  'names': ['shoulder_pan.pos',
   'shoulder_lift.pos',
   'elbow_flex.pos',
   'wrist_flex.pos',
   'wrist_roll.pos',
   'gripper.pos'],
  'shape': (6,)},
 'observation.state': {'dtype': 'float32',
  'names': ['shoulder_pan.pos',
   'shoulder_lift.pos',
   'elbow_flex.pos',
   'wrist_flex.pos',
   'wrist_roll.pos',
   'gripper.pos'],
  'shape': (6,)},
 'timestamp': {'dtype': 'float32', 'shape': (1,), 'names': None},
 'frame_index': {'dtype': 'int64', 'shape': (1,), 'names': None},
 'episode_index': {'dtype': 'int64', 'shape': (1,), 'names': None},
 'index': {'dtype': 'int64', 'shape': (1,), 'names': None},
 'task_index': {'dtype': 'int64', 'shape': (1,), 'names': None}}

In [22]:

# Inspect one sample (index 1) to see which keys each item exposes.
# Note: items also carry a 'task' string that is not listed in dataset.features.
test_sample = dataset[1] 
print(test_sample.keys())

dict_keys(['action', 'observation.state', 'timestamp', 'frame_index', 'episode_index', 'index', 'task_index', 'task'])


In [23]:
# Full contents of that sample: per-frame tensors plus the task description string.
test_sample

{'action': tensor([ -6.5605, -85.3180,  89.4975,  68.0070,   3.1013,   0.9934]),
 'observation.state': tensor([ -2.5660, -79.8212,  80.3114,  79.1123,   3.2967,   0.8368]),
 'timestamp': tensor(0.0333),
 'frame_index': tensor(1),
 'episode_index': tensor(0),
 'index': tensor(1),
 'task_index': tensor(0),
 'task': 'Home -> shared prefix -> branch -> reach target (no gripper)'}

In [24]:
# The 6-dim action for this frame; element order follows dataset.features['action']['names'].
test_sample["action"]

tensor([ -6.5605, -85.3180,  89.4975,  68.0070,   3.1013,   0.9934])


The dataset is one flat list of frames, so we separate it by episode and attach each episode's category:


1:direct,
2:left arc,
3:right arc,
4:combined,
5:random

In [25]:
# Group the flat list of frames into per-episode action lists and label each episode
# with its category (1: direct, 2: left arc, 3: right arc, 4: combined, 5: random).
# NOTE(review): assumes episodes are stored in contiguous blocks of
# EPISODES_PER_CATEGORY per category, in the order listed above — confirm against recording.
EPISODES_PER_CATEGORY = 10

episode_actions = defaultdict(list)  # episode id -> list of (6,) action tensors
episode_frames = defaultdict(list)   # episode id -> frame_index of each collected action
episode_category = {}                # episode id -> category label

for i in range(len(dataset)):
    sample = dataset[i]
    ep = int(sample["episode_index"])  # which episode this frame belongs to

    a = torch.as_tensor(sample["action"], dtype=torch.float32)  # (6,) joint-position action
    episode_actions[ep].append(a)
    episode_frames[ep].append(int(sample["frame_index"]))

    if ep not in episode_category:
        # The category is a property of the whole episode, not of an individual frame.
        episode_category[ep] = ep // EPISODES_PER_CATEGORY + 1

# Don't rely on the dataset yielding frames in time order: reorder every episode by
# frame_index so the stacked (T, 6) tensors later are guaranteed to be chronological.
for ep, frames in episode_frames.items():
    assert len(set(frames)) == len(frames), f"duplicate frame_index in episode {ep}"
    order = sorted(range(len(frames)), key=frames.__getitem__)
    episode_actions[ep] = [episode_actions[ep][k] for k in order]


In [26]:
# Show the collected episode ids, then the category assigned to each episode (one per line).
print(episode_actions.keys(), episode_category, sep="\n")

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
{0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 2, 11: 2, 12: 2, 13: 2, 14: 2, 15: 2, 16: 2, 17: 2, 18: 2, 19: 2, 20: 3, 21: 3, 22: 3, 23: 3, 24: 3, 25: 3, 26: 3, 27: 3, 28: 3, 29: 3, 30: 4, 31: 4, 32: 4, 33: 4, 34: 4, 35: 4, 36: 4, 37: 4, 38: 4, 39: 4, 40: 5, 41: 5, 42: 5, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5, 49: 5}


In [27]:
# Each episode's actions are still a plain Python list of (6,) tensors at this point.
# Caution: episode_actions is a defaultdict(list), so indexing a missing id inserts an empty list.
type(episode_actions[0])

list

Convert each episode's list of per-frame action tensors into a single `(T, 6)` tensor, then save all episodes together with their metadata.

In [32]:
# Stack each episode's list of (6,) action tensors into one (T, 6) tensor, pair it with
# per-episode metadata, and save everything to OUT_PATH as a single torch payload.
episodes, meta = [], []

for ep_id in sorted(episode_actions):
    ep_tensor = torch.stack(episode_actions[ep_id], dim=0)  # (T, action_dim)
    episodes.append(ep_tensor)

    meta.append({
        "episode_id": ep_id,
        "category": int(episode_category[ep_id]),
    })

payload = {
    "episodes": episodes,
    "meta": meta,
    # Prefer the dataset's own frame rate over a hard-coded value; falls back to 30.
    # NOTE(review): relies on LeRobotDataset exposing `fps` — confirm for the installed version.
    "fps": int(getattr(dataset, "fps", 30)),
    # Derived from the stacked data rather than hard-coded.
    "action_dim": int(episodes[0].shape[1]) if episodes else 6,
}

# OUT_PATH is relative to the notebook's working directory; make sure its parent
# directory exists so torch.save does not fail on a fresh checkout.
out_dir = os.path.dirname(OUT_PATH)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

torch.save(payload, OUT_PATH)
print("Saved:", len(episodes), "episodes")
print("Example:", episodes[0].shape, meta[0])

Saved: 50 episodes
Example: torch.Size([300, 6]) {'episode_id': 0, 'category': 1}


Sanity checks: episode lengths, action dimensionality, and category balance.

In [68]:
# Report the shortest and longest episode length T, then require 6 action dims everywhere.
lengths = [episode.size(0) for episode in episodes]
print(min(lengths), max(lengths))
assert not any(episode.size(1) != 6 for episode in episodes)


300 300


In [71]:

# Count how many saved episodes fall into each category label.
cats = list(map(lambda entry: entry["category"], meta))
print(Counter(cats))


Counter({1: 10, 2: 10, 3: 10, 4: 10, 5: 10})
