In [1]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

: 

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy DROID observation.

In [2]:
import os

# GPU to expose to JAX. CUDA_VISIBLE_DEVICES only takes effect if it is set before JAX
# initializes its backend (presumably at the first `jax.devices()` call below, since `jax`
# itself was already imported in the first cell), so keep this ahead of any JAX work.
# `setdefault` respects a value already exported in the environment instead of forcing a
# machine-specific index that may not exist on another host.
GPU_INDEX = "2"
os.environ.setdefault("CUDA_VISIBLE_DEVICES", GPU_INDEX)
print(jax.devices())
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))

# Single source of truth for the config name and its checkpoint location.
POLICY_NAME = "pi0_fast_droid"
config = _config.get_config(POLICY_NAME)
checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{POLICY_NAME}")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Delete the policy to free up memory.
del policy

print("Actions shape:", result["actions"].shape)

[CudaDevice(id=0)]
CUDA_VISIBLE_DEVICES = 2


  0%|          | 0.00/4.07M [00:00<?, ?iB/s]

processor_config.json:   0%|          | 0.00/253 [00:00<?, ?B/s]

processing_action_tokenizer.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/physical-intelligence/fast:
- processing_action_tokenizer.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer_config.json:   0%|          | 0.00/322 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/3.00 [00:00<?, ?B/s]

2025-09-17 21:31:11.474257: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng11{k2=3,k3=0} for conv %cudnn-conv-bias-activation.9 = (f32[1,1152,16,16]{3,2,1,0}, u8[0]{0}) custom-call(%bitcast.8894, %bitcast.9147, %bitcast.9149), window={size=14x14 stride=14x14}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", metadata={op_name="jit(fun)/jit(main)/_Module/embedding/conv_general_dilated" source_file="/data/home/zhangjing2/th/openpi/.venv/lib/python3.11/site-packages/flax/linen/linear.py" source_line=658}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-09-17 21:31:12.113195: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.639047881s
Trying algorithm eng11{k2=3,k3=0} for conv %cudnn-conv-bias-activatio

Actions shape: (10, 8)


# Working with a live model


The following example shows how to load a live model from a checkpoint and compute its training loss. We first demonstrate this with fake data.


In [None]:
# Single source of truth for the config name and its checkpoint location.
MODEL_NAME = "pi0_aloha_sim"
config = _config.get_config(MODEL_NAME)

checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{MODEL_NAME}")
# `compute_loss` takes a PRNG key; a fixed seed keeps the result reproducible across runs.
key = jax.random.key(0)

# Create a model from the checkpoint.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Compute the training loss on the fake batch (no actions are sampled here).
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

Next, we create a data loader and compute the loss on a real batch of training data. This cell reuses the config, model, and PRNG key from the previous cell.

In [None]:
# Reuses `config`, `model`, and `key` from the previous cell; run that cell first.

# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on a real batch (no actions are sampled here).
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
# NOTE: after this, re-running this cell requires re-running the previous cell to reload `model`.
del model

print("Loss shape:", loss.shape)