In [2]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [3]:
config = _config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = droid_policy.make_droid_example()
print(example)
result = policy.infer(example)

# Delete the policy to free up memory.
del policy

print("Actions shape:", result["actions"].shape)

Some kwargs in processor config are unused and will not have any effect: vocab_size, time_horizon, scale, min_token, action_dim. 
Some kwargs in processor config are unused and will not have any effect: vocab_size, time_horizon, scale, min_token, action_dim. 


{'observation/exterior_image_1_left': array([[[ 43, 156,  92],
        [180, 190,  43],
        [231,  25, 180],
        ...,
        [255,  90,  53],
        [210, 110,   3],
        [132, 204, 206]],

       [[161, 212, 112],
        [207, 202, 152],
        [ 77,  27,  14],
        ...,
        [100, 120, 230],
        [155,  94, 139],
        [ 18, 183, 223]],

       [[237, 201,  74],
        [253,  11, 249],
        [ 52,  70, 153],
        ...,
        [125, 185,  41],
        [185, 167,  69],
        [ 42, 246,  88]],

       ...,

       [[234, 203, 175],
        [168,  34,  48],
        [124, 127, 170],
        ...,
        [ 59, 100, 141],
        [172, 162, 170],
        [169, 224, 165]],

       [[  7,  25, 135],
        [101, 193, 134],
        [ 57, 193, 241],
        ...,
        [ 23, 182, 226],
        [185, 181, 134],
        [211, 195,  86]],

       [[237, 126,  62],
        [119, 119, 108],
        [121,  50,  80],
        ...,
        [121,  12,  95],
        [13

# Working with a live model


The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.


In [3]:
config = _config.get_config("pi0_aloha_sim")

checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim")
key = jax.random.key(0)

# Create a model from the checkpoint.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Sample actions from the model.
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

  0%|          | 0.00/11.2G [00:00<?, ?iB/s]



Loss shape: (1, 50)


Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [4]:
# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Sample actions from the model.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
del model

print("Loss shape:", loss.shape)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

meta%2Ftasks.jsonl:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

meta%2Fstats.json:   0%|          | 0.00/6.40k [00:00<?, ?B/s]

meta%2Finfo.json:   0%|          | 0.00/3.37k [00:00<?, ?B/s]

meta%2Fepisodes.jsonl:   0%|          | 0.00/5.99k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 106 files:   0%|          | 0/106 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/4.14k [00:00<?, ?B/s]

episode_000003.parquet:   0%|          | 0.00/53.8k [00:00<?, ?B/s]

episode_000002.parquet:   0%|          | 0.00/54.6k [00:00<?, ?B/s]

episode_000005.parquet:   0%|          | 0.00/53.7k [00:00<?, ?B/s]

episode_000001.parquet:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

episode_000004.parquet:   0%|          | 0.00/52.5k [00:00<?, ?B/s]

episode_000000.parquet:   0%|          | 0.00/52.2k [00:00<?, ?B/s]

episode_000006.parquet:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/2.42k [00:00<?, ?B/s]

episode_000008.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

episode_000010.parquet:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

episode_000007.parquet:   0%|          | 0.00/54.0k [00:00<?, ?B/s]

episode_000011.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

episode_000009.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

episode_000012.parquet:   0%|          | 0.00/54.2k [00:00<?, ?B/s]

episode_000013.parquet:   0%|          | 0.00/52.3k [00:00<?, ?B/s]

episode_000014.parquet:   0%|          | 0.00/53.8k [00:00<?, ?B/s]

episode_000015.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

episode_000018.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

episode_000017.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

episode_000016.parquet:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

episode_000020.parquet:   0%|          | 0.00/52.4k [00:00<?, ?B/s]

episode_000021.parquet:   0%|          | 0.00/52.9k [00:00<?, ?B/s]

episode_000022.parquet:   0%|          | 0.00/54.4k [00:00<?, ?B/s]

episode_000024.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

episode_000026.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

episode_000025.parquet:   0%|          | 0.00/53.8k [00:00<?, ?B/s]

episode_000028.parquet:   0%|          | 0.00/53.4k [00:00<?, ?B/s]

episode_000027.parquet:   0%|          | 0.00/52.4k [00:00<?, ?B/s]

episode_000029.parquet:   0%|          | 0.00/54.2k [00:00<?, ?B/s]

episode_000030.parquet:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

episode_000032.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

episode_000031.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

episode_000034.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

episode_000033.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

episode_000019.parquet:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

episode_000036.parquet:   0%|          | 0.00/52.9k [00:00<?, ?B/s]

episode_000037.parquet:   0%|          | 0.00/52.9k [00:00<?, ?B/s]

episode_000038.parquet:   0%|          | 0.00/54.4k [00:00<?, ?B/s]

episode_000035.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

episode_000040.parquet:   0%|          | 0.00/53.4k [00:00<?, ?B/s]

episode_000039.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

episode_000041.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

episode_000042.parquet:   0%|          | 0.00/54.0k [00:00<?, ?B/s]

episode_000043.parquet:   0%|          | 0.00/53.7k [00:00<?, ?B/s]

episode_000045.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

episode_000046.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

episode_000048.parquet:   0%|          | 0.00/52.5k [00:00<?, ?B/s]

episode_000044.parquet:   0%|          | 0.00/52.2k [00:00<?, ?B/s]

episode_000047.parquet:   0%|          | 0.00/51.7k [00:00<?, ?B/s]

episode_000049.parquet:   0%|          | 0.00/53.5k [00:00<?, ?B/s]

episode_000000.mp4:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

episode_000001.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000002.mp4:   0%|          | 0.00/1.40M [00:00<?, ?B/s]

episode_000004.mp4:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

episode_000003.mp4:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

episode_000005.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000006.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000007.mp4:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

episode_000008.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000009.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000010.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000011.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000012.mp4:   0%|          | 0.00/1.40M [00:00<?, ?B/s]

episode_000015.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000013.mp4:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

episode_000016.mp4:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

episode_000014.mp4:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

episode_000018.mp4:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

episode_000019.mp4:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

episode_000017.mp4:   0%|          | 0.00/1.33M [00:00<?, ?B/s]

episode_000020.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000022.mp4:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

episode_000021.mp4:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

episode_000023.mp4:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

episode_000024.mp4:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

episode_000025.mp4:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

episode_000026.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000027.mp4:   0%|          | 0.00/1.30M [00:00<?, ?B/s]

episode_000028.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000029.mp4:   0%|          | 0.00/1.40M [00:00<?, ?B/s]

episode_000030.mp4:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

episode_000031.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000033.mp4:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

episode_000034.mp4:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

episode_000035.mp4:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

episode_000032.mp4:   0%|          | 0.00/1.35M [00:00<?, ?B/s]

episode_000036.mp4:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

episode_000037.mp4:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

episode_000038.mp4:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

episode_000039.mp4:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

episode_000040.mp4:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

episode_000041.mp4:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

episode_000042.mp4:   0%|          | 0.00/1.38M [00:00<?, ?B/s]

episode_000043.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000044.mp4:   0%|          | 0.00/1.33M [00:00<?, ?B/s]

episode_000045.mp4:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

episode_000046.mp4:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

episode_000047.mp4:   0%|          | 0.00/1.33M [00:00<?, ?B/s]

episode_000048.mp4:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

episode_000049.mp4:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

episode_000023.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/50 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Loss shape: (2, 50)
