Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lerobot/common/datasets/lerobot_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,13 +1162,15 @@ def __getitem__(self, idx) -> dict:

episode_index = item["episode_index"].item()
# don't convert to timestamp to `float`, because torch.float64 is not supported on MPS
timestamp = item["timestamp"]
frame_index = item["frame_index"].item()

# change data naming to standard data format
item = self._to_standard_data_format(item)

if self.meta.advantages is not None:
advantage = self.meta.advantages.get((episode_index, timestamp), 0)
# if the advantage file is present, it should contain all advantage values
advantage = self.meta.advantages.get((episode_index, frame_index), 0)
logging.warning(f"Unable to query advantage value for {episode_index=} and {frame_index=}.")
item["advantage"] = torch.tensor(advantage, dtype=torch.bfloat16)
else:
item["advantage"] = torch.tensor(0.0, dtype=torch.bfloat16)
Expand All @@ -1195,7 +1197,7 @@ def __getitem__(self, idx) -> dict:
item["current_idx"] = idx
item["last_step"] = idx + self.cfg.policy.reward_config.N_steps_look_ahead >= ep_end
item["episode_index"] = episode_index
item["timestamp"] = timestamp
item["frame_index"] = frame_index
else:
item["return_bin_idx"] = torch.tensor(0, dtype=torch.long)
item["return_continuous"] = torch.tensor(0, dtype=torch.float32)
Expand Down
3 changes: 2 additions & 1 deletion lerobot/common/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def load_advantages(local_dir: Path) -> dict:
if not (local_dir / ADVANTAGES_PATH).exists():
return None
advantages = load_json(local_dir / ADVANTAGES_PATH)
return {(int(k.split(",")[0]), float(k.split(",")[1])): v for k, v in advantages.items()}
# keys are of the form "episode_index,frame_index", where both episode_index and frame_index are integers
return {tuple(map(int, k.split(","))): v for k, v in advantages.items()}


def write_task(task_index: int, task: dict, local_dir: Path):
Expand Down
13 changes: 7 additions & 6 deletions lerobot/scripts/get_advantage_and_percentiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def ensure_primitive(maybe_tensor):
return ensure_primitive(torch.from_numpy(maybe_tensor))
if isinstance(maybe_tensor, torch.Tensor):
assert maybe_tensor.numel() == 1, f"Tensor must be a single value, got shape={maybe_tensor.numel()}"
return maybe_tensor.item()
return maybe_tensor


Expand Down Expand Up @@ -171,7 +172,7 @@ def main(cfg: TrainPipelineConfig):
success=success,
n_steps_look_ahead=cfg.policy.reward_config.N_steps_look_ahead,
episode_end_idx=episode_end_idx,
max_episode_length=cfg.policy.reward_config.reward_normalizer,
reward_normalizer=cfg.policy.reward_config.reward_normalizer,
current_idx=current_idx,
c_neg=cfg.policy.reward_config.C_neg,
)
Expand All @@ -180,14 +181,14 @@ def main(cfg: TrainPipelineConfig):

# Second pass to compute the advantages
for batch in dataloader:
for episode_index, current_idx, timestamp in zip(
for episode_index, current_idx, frame_index in zip(
batch["episode_index"],
batch["current_idx"],
batch["timestamp"],
batch["frame_index"],
strict=True,
):
episode_index, current_idx, timestamp = map(
ensure_primitive, (episode_index, current_idx, timestamp)
episode_index, current_idx, frame_index = map(
ensure_primitive, (episode_index, current_idx, frame_index)
)
# check if the value for the next n_steps_look_ahead steps is available, else set it to 0
look_ahead_idx = current_idx + cfg.policy.reward_config.N_steps_look_ahead
Expand All @@ -196,7 +197,7 @@ def main(cfg: TrainPipelineConfig):
v0 = values.get((episode_index, current_idx), _default0)["v0"]
advantage = ensure_primitive(reward + vn - v0)
advantages.append(advantage)
ds_advantage[(episode_index, timestamp)] = advantage
ds_advantage[(episode_index, frame_index)] = advantage

# Convert tuple keys to strings for JSON serialization
advantage_data_json = {f"{ep_idx},{ts}": val for (ep_idx, ts), val in ds_advantage.items()}
Expand Down
Loading