In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.bet.configuration_bet import BeTConfig
from lerobot.policies.bet.modeling_bet import BeTPolicy

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/test_bet_30k_off1000")
output_directory.mkdir(parents=True, exist_ok=True)

# Prefer the GPU but fall back to the CPU when CUDA is unavailable, so this cell does not crash on
# CPU-only machines (the DataLoader below already handles the CPU case for `pin_memory`).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training budget and console logging cadence, both counted in optimizer steps.
training_steps = 30000
log_freq = 100

# Per-step histories appended to by the training loop.
losses = []
classification_losses = []
offset_losses = []

# When starting from scratch (i.e. not from a pretrained policy), we need to specify 2 things before
# creating the policy:
#   - input/output shapes: to properly size the policy
#   - dataset stats: for normalization and denormalization of input/outputs
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
# Features typed as ACTION are what the policy outputs; every other feature is an input.
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

# Policies are initialized with a configuration class. Besides the input/output features, this run
# overrides `offset_loss_multiplier` (presumably the weight of the offset-loss term -- confirm in
# BeTConfig); every other option keeps its default value.
cfg = BeTConfig(input_features=input_features, output_features=output_features, offset_loss_multiplier=1000)

# We can now instantiate our policy with this config and the dataset stats.
policy = BeTPolicy(cfg, dataset_stats=dataset_metadata.stats)
# To start from a saved checkpoint instead of from scratch, use:
# pretrained_policy_path = Path("outputs/train/test_bet")
# policy = BeTPolicy.from_pretrained(pretrained_policy_path)

# Put the policy in training mode and move its weights to the selected device.
policy.train()
policy.to(device)

# The policy config also determines which frames the dataset returns for every sample. The config
# exposes index offsets for observations and for actions; dividing each offset by the dataset fps
# converts it into the time offset (in seconds) that is passed to the dataset as `delta_timestamps`.
delta_timestamps = {
    feature_key: [frame_offset / dataset_metadata.fps for frame_offset in frame_offsets]
    for feature_key, frame_offsets in (
        ("observation.image", cfg.observation_delta_indices),
        ("observation.state", cfg.observation_delta_indices),
        ("action", cfg.action_delta_indices),
    )
}

# Author's note: with the standard configuration of the new BeT, the above is equivalent to
# delta_timestamps = {
#     # Images and states from 0.4 seconds in the past up to the current frame (0.0),
#     # one every 0.1 seconds.
#     "observation.image": [-0.4, -0.3, -0.2, -0.1, 0.0],
#     "observation.state": [-0.4, -0.3, -0.2, -0.1, 0.0],
#     # The previous action (-0.1), the next action to be executed (0.0) and one more into the
#     # future (0.1). All of them supervise the policy; BeT was set up with
#     # n_action_pred_token: int = 3.
#     "action": [-0.1, 0.0, 0.1],
# }

# Build the dataset so that every sample carries the frames requested above.
dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

# Optimizer and dataloader for offline training.
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    # Pinned host memory is only requested when training on a non-CPU device.
    pin_memory=(device.type != "cpu"),
)

# Offline training: keep cycling through the dataloader (epoch after epoch) until `training_steps`
# optimizer updates have been applied, then save the policy.
step = 0
done = False
while not done:
    for batch in dataloader:
        # Move every tensor of the batch to the training device; other entries pass through as-is.
        batch = {
            name: (value.to(device) if isinstance(value, torch.Tensor) else value)
            for name, value in batch.items()
        }

        # The policy returns the loss to optimize plus a dict holding its two components.
        loss, loss_dict = policy.forward(batch)
        cl_loss = loss_dict.pop('classification_loss')
        offset_loss = loss_dict.pop('offset_loss')

        # One optimizer update; gradients are cleared right afterwards for the next step.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Record the history of this step.
        # NOTE(review): the component losses are stored exactly as returned by the policy; if they
        # are tensors rather than floats they should probably be detached -- confirm in BeTPolicy.
        losses.append(loss.item())
        classification_losses.append(cl_loss)
        offset_losses.append(offset_loss)

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")

        step += 1
        done = step >= training_steps
        if done:
            break

# Persist the trained policy in the output directory.
policy.save_pretrained(output_directory)

  @autocast(enabled=False)


number of parameters: 26.00M
self._offset_loss_multiplier: 1000


Resolving data files:   0%|          | 0/206 [00:00<?, ?it/s]

step: 0 loss: -0.000
step: 100 loss: -0.000
step: 200 loss: -0.000
step: 300 loss: -0.000
step: 400 loss: -0.000
step: 500 loss: -0.000
step: 600 loss: -0.000
step: 700 loss: -0.000
step: 800 loss: -0.000
step: 900 loss: -0.000


K-means clustering: 100%|███████████████████████████████████████| 100/100 [00:12<00:00,  8.24it/s]


step: 1000 loss: 205.325
step: 1100 loss: 8.672
step: 1200 loss: 8.614
step: 1300 loss: 10.484
step: 1400 loss: 9.117
step: 1500 loss: 8.791
step: 1600 loss: 8.839
step: 1700 loss: 9.316
step: 1800 loss: 8.467
step: 1900 loss: 9.395
step: 2000 loss: 8.994
step: 2100 loss: 7.850
step: 2200 loss: 7.256
step: 2300 loss: 7.077
step: 2400 loss: 6.213
step: 2500 loss: 5.487
step: 2600 loss: 5.856
step: 2700 loss: 5.712
step: 2800 loss: 5.334
step: 2900 loss: 6.595
step: 3000 loss: 5.199
step: 3100 loss: 5.516
step: 3200 loss: 5.704
step: 3300 loss: 5.428
step: 3400 loss: 5.347
step: 3500 loss: 4.849
step: 3600 loss: 4.391
step: 3700 loss: 4.637
step: 3800 loss: 4.126
step: 3900 loss: 4.970
step: 4000 loss: 4.725
step: 4100 loss: 4.274
step: 4200 loss: 4.028
step: 4300 loss: 3.412
step: 4400 loss: 3.477
step: 4500 loss: 3.899
step: 4600 loss: 4.305
step: 4700 loss: 4.379
step: 4800 loss: 4.677
step: 4900 loss: 3.780
step: 5000 loss: 4.410
step: 5100 loss: 3.554
step: 5200 loss: 3.716
step: 53

In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import gym_pusht  # noqa: F401
import gymnasium as gym
import imageio
import numpy

# This cell rolls out the in-memory `policy` trained in the previous cell; no checkpoint is reloaded.
# Training left the policy in train mode, so switch it to evaluation mode first: otherwise any
# train-only behaviour (e.g. dropout) would still be active while predicting actions.
# NOTE: call `policy.train()` again before resuming training.
policy.eval()

# Send observations to whichever device the policy weights live on.
device = next(policy.parameters()).device

# Initialize evaluation environment to render two observation types:
# an image of the scene and state/position of the agent. The environment
# also automatically stops running after 300 interactions/steps.
env = gym.make(
    "gym_pusht/PushT-v0",
    obs_type="pixels_agent_pos",
    max_episode_steps=300,
)

# We can verify that the shapes of the features expected by the policy match the ones from the observations
# produced by the environment
print(policy.config.input_features)
print(env.observation_space)

# Similarly, we can check that the actions produced by the policy will match the actions expected by the
# environment
print(policy.config.output_features)
print(env.action_space)

# Reset the policy and environments to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)

# Prepare to collect all the rewards and all the frames of the episode,
# from initial state to final state.
rewards = []
frames = []

# Render frame of the initial state
frames.append(env.render())

step = 0
done = False
while not done:
    # Prepare observation for the policy running in Pytorch
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])

    # Convert to float32, and convert the image from channel last in [0,255]
    # to channel first in [0,1]
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)

    # Send data tensors to the policy device
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)

    # Add extra (empty) batch dimension, required to forward the policy
    state = state.unsqueeze(0)
    image = image.unsqueeze(0)

    # Create the policy input dictionary
    observation = {
        "observation.state": state,
        "observation.image": image,
    }

    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action = policy.select_action(observation)

    # Prepare the action for the environment
    numpy_action = action.squeeze(0).to("cpu").numpy()
    print(f"numpy_action {numpy_action}")
    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")

    # Keep track of all the rewards and frames
    rewards.append(reward)
    frames.append(env.render())

    # The rollout is considered done when the success state is reached (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = bool(terminated or truncated)
    step += 1

if terminated:
    print("Success!")
else:
    print("Failure!")

# Get the speed of environment (i.e. its number of frames per second).
fps = env.metadata["render_fps"]

# All frames have been rendered, so the environment can release its resources.
env.close()

# Encode all frames into a mp4 video.
video_path = output_directory / "rollout.mp4"
imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)

print(f"Video of the evaluation is available in '{video_path}'.")

  from pkg_resources import resource_stream, resource_exists


{'observation.image': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 96, 96)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,))}
Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'pixels': Box(0, 255, (96, 96, 3), uint8))
{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}
Box(0.0, 512.0, (2,), float32)
numpy_action [169.76578 411.27423]
step=0 reward=np.float64(0.0) terminated=False
numpy_action [193.74083 408.32437]
step=1 reward=np.float64(0.0) terminated=False
numpy_action [193.11534 431.39883]
step=2 reward=np.float64(0.0) terminated=False
numpy_action [234.46233 423.3732 ]
step=3 reward=np.float64(0.0) terminated=False
numpy_action [204.08383 419.43323]
step=4 reward=np.float64(0.0) terminated=False
numpy_action [252.39441 434.70892]
step=5 reward=np.float64(0.0) terminated=False
numpy_action [256.12677 430.1405 ]
step=6 reward=np.float64(0.0) terminated=False
numpy_action [315.72968 440.87143]
step=7 reward=np



Video of the evaluation is available in 'outputs/train/test_bet_30k_off1000/rollout.mp4'.
