In [1]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.90' 

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [None]:
# The training config and the checkpoint must refer to the same model. A config for one
# model should not be mixed with another model's checkpoint (e.g. a "pi0_fast_droid"
# config with the "pi05_droid" checkpoint). Deriving both from one name keeps them in sync.
model_name = "pi0_fast_droid"

config = _config.get_config(model_name)

# NOTE: this checkpoint download is large (about 10 GB for pi0_fast_droid). Make sure the
# download cache has enough free disk space, otherwise the download fails with
# "No space left on device".
checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{model_name}")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Delete the policy to free up memory.
del policy

print("Actions shape:", result["actions"].shape)

  0%|          | 0.00/10.1G [00:00<?, ?iB/s]

ERROR:asyncio:Task exception was never retrieved
future: <Task finished name='Task-26' coro=<_run_coros_in_chunks.<locals>._run_coro() done, defined at /home/yahuanshi/OpenPI/openpi/.venv/lib/python3.11/site-packages/fsspec/asyn.py:243> exception=OSError(28, 'No space left on device')>
Traceback (most recent call last):
  File "/home/yahuanshi/OpenPI/openpi/.venv/lib/python3.11/site-packages/fsspec/asyn.py", line 245, in _run_coro
    return await asyncio.wait_for(coro, timeout=timeout), i
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/asyncio/tasks.py", line 452, in wait_for
    return await fut
           ^^^^^^^^^
  File "/home/yahuanshi/OpenPI/openpi/.venv/lib/python3.11/site-packages/fsspec/callbacks.py", line 81, in func
    return await fn(path1, path2, callback=child, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yahuanshi/OpenPI/openpi/.venv/lib/python3.11/site-packages/gcsfs/core.py", line 1613, in _g

Exception: Encountered error while reading array index: (slice(0, 18, 1), slice(0, 2, 1), slice(0, 2048, 1), slice(0, 16384, 1)). See full TensorStore details: <bound method PyCapsule.spec of TensorStore({
  'base': {
    'driver': 'zarr',
    'dtype': 'float32',
    'fill_missing_data_reads': False,
    'kvstore': {
      'base': {
        'driver': 'file',
        'path': '/home/yahuanshi/.cache/openpi/openpi-assets/checkpoints/pi0_fast_droid/params/',
      },
      'cache_pool': 'cache_pool#ocdbt',
      'config': {
        'compression': {'id': 'zstd'},
        'max_decoded_node_bytes': 100000000,
        'max_inline_value_bytes': 1024,
        'uuid': 'ca888d2fcf3b0d15d240930fdc9d4ce1',
        'version_tree_arity_log2': 4,
      },
      'driver': 'ocdbt',
      'path': 'params.PaliGemma.llm.layers.mlp.gating_einsum/',
    },
    'metadata': {
      'chunks': [18, 2, 2048, 4096],
      'compressor': {'id': 'zstd', 'level': 1},
      'dimension_separator': '.',
      'dtype': '<f4',
      'fill_value': None,
      'filters': None,
      'order': 'C',
      'shape': [18, 2, 2048, 16384],
      'zarr_format': 2,
    },
    'recheck_cached_data': False,
    'recheck_cached_metadata': False,
  },
  'context': {
    'cache_pool': {},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_locking': {},
    'file_io_memmap': False,
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'cast',
  'dtype': 'bfloat16',
  'transform': {
    'input_exclusive_max': [[18], [2], [2048], [16384]],
    'input_inclusive_min': [0, 0, 0, 0],
  },
})>.

# Working with a live model


The following example shows how to create a live model from a checkpoint and compute the training loss. First, we demonstrate this with fake data.


In [None]:
# Derive the config and the checkpoint from one name so they always describe the same model.
model_name = "pi0_aloha_sim"
config = _config.get_config(model_name)
checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{model_name}")

# Fixed PRNG key passed to compute_loss, so re-running the cell gives the same result.
key = jax.random.key(0)

# Create a model from the checkpoint.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Compute the training loss on the fake batch.
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [None]:
# Reduce the batch size to reduce memory usage.
# Reduce the batch size to reduce memory usage. Bind the result to a new name instead of
# overwriting `config`, so the model config from the previous cell stays unchanged.
loader_config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(loader_config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on the real batch.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory. Re-run the previous cell to recreate `model`
# before running this cell again.
del model

print("Loss shape:", loss.shape)