In [1]:
import dataclasses

import jax
import torch

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import ur3_robotiq_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [2]:
config = _config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Delete the policy to free up memory.
del policy

print("Actions shape:", result["actions"][0])

Policy infer inputs: [0.84179922 0.90663587 0.75344159 0.66017157 0.41391253 0.97284278
 0.86954556] [0.00413281]
Policy infer transformed inputs: [ 0.95836034  0.55242179  0.96191687  1.96338149  0.30179398 -1.22520143
  0.35176651 -0.99165932]
Policy infer outputs: 4022.0
Policy infer transformed outputs: [ 0.17652283 -0.23692529 -0.30517823 -0.01090471 -0.17994839  0.24778192
  0.24548847 -0.00596324]
Actions shape: [ 0.17652283 -0.23692529 -0.30517823 -0.01090471 -0.17994839  0.24778192
  0.24548847 -0.00596324]


# Working with a live model


The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.


In [None]:
config = _config.get_config("pi0_ur3_robotiq")

checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_base")
key = jax.random.key(0)

# # Create a model from the checkpoint.
# model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# # We can create fake observations and actions to test the model.
# obs, act = config.model.fake_obs(), config.model.fake_act()

# # Sample actions from the model.
# loss = model.compute_loss(key, obs, act)
# print("Loss shape:", loss.shape)

Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [None]:
# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Sample actions from the model.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
del model

print("Loss shape:", loss.shape)

In [3]:
from pathlib import Path

assets_root = download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/")
print(f"Assets directory: {assets_root}")
for entry in sorted(Path(assets_root).iterdir()):
    if entry.is_dir():
        print(f"[DIR] {entry.name}")
        for child in sorted(entry.iterdir()):
            print(f"  - {child.name}")
    else:
        print(f"[FILE] {entry.name}")

Assets directory: /home/user/.cache/openpi/openpi-assets/checkpoints/pi0_base
[DIR] assets
  - aloha.lock
  - arx
  - arx_mobile
  - droid
  - fibocom_mobile
  - franka
  - trossen
  - trossen_mobile
  - ur5e
  - ur5e_dual
[DIR] params
  - _CHECKPOINT_METADATA
  - _METADATA
  - _sharding
  - d
  - manifest.ocdbt
  - ocdbt.process_0


In [5]:
from pathlib import Path

assets_root = download.maybe_download("gs://openpi-assets/checkpoints/pi0_libero/assets/")
print(f"Assets directory: {assets_root}")
for entry in sorted(Path(assets_root).iterdir()):
    if entry.is_dir():
        print(f"[DIR] {entry.name}")
        for child in sorted(entry.iterdir()):
            print(f"  - {child.name}")
    else:
        print(f"[FILE] {entry.name}")
        print(f"[FILE] {entry}")

  0%|          | 0.00/4.40k [00:00<?, ?iB/s]

Assets directory: /home/user/.cache/openpi/openpi-assets/checkpoints/pi0_libero/assets
[DIR] physical-intelligence
  - libero


In [3]:
import time
print("start: ", time.strftime("%Y-%m-%d %H:%M:%S"))
config = _config.get_config("pi0_ur3_robotiq")
checkpoint_dir = download.maybe_download("gs://openpi-assets/checkpoints/pi0_base_pytorch")
torch.manual_seed(42)
# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)
print("Policy created: ", time.strftime("%Y-%m-%d %H:%M:%S"))

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = ur3_robotiq_policy.make_ur3_example()
print("Example state:", example["state"])
print("Example images:", {k: v.shape for k, v in example["images"].items()})
result = policy.infer(example)
print("Inference completed: ", time.strftime("%Y-%m-%d %H:%M:%S"))

# Delete the policy to free up memory.
del policy

start:  2025-12-07 17:05:27
Using PyTorch device: cuda
Policy created:  2025-12-07 17:06:05
Example state: [ 2.74 -1.65  0.49 -1.55 -0.45  0.18  0.5 ]
Example images: {'cam_high': (2048, 2048, 3), 'cam_left_wrist': (2048, 2048, 3)}
input before transform: [ 2.74 -1.65  0.49 -1.55 -0.45  0.18  0.5 ]
input before transform (degrees): [156.99043  -94.53803   28.074932 -88.80846  -25.7831    10.31324
  28.647888]
output after transform: [-0.03538096  0.1144994  -0.02814207  0.00599358 -0.09668985  0.09598636
  0.8825611 ]
output after transform (degrees): [-2.0271795   6.5603323  -1.612422    0.34340695 -5.5399203   5.499613
 50.567024  ]
Inference completed:  2025-12-07 17:06:18
