In [1]:
import dataclasses
import os

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

In [2]:
# Runtime environment for openpi + JAX.
#
# NOTE: JAX reads the XLA_* / CUDA_* variables when it first initializes a backend,
# not at `import jax`, so this cell must run before any JAX computation. Restart the
# kernel after changing them; edits made after the backend exists have no effect.

# Where openpi caches downloaded checkpoints. The default path is specific to one
# machine, so a value exported before launching Jupyter takes precedence.
os.environ.setdefault('OPENPI_DATA_HOME', '/mnt/22TB_IndEgo/yahuan/openpi/checkpoints')

# GPU/JAX configuration. GPU selection is machine-specific too, so respect an
# existing CUDA_VISIBLE_DEVICES.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')

# Append to (rather than overwrite) any XLA_FLAGS already in the environment, and
# skip the append when the flag is already there so re-running the cell is idempotent.
_STRICT_CONV_PICKER_OFF = '--xla_gpu_strict_conv_algorithm_picker=false'
_xla_flags = os.environ.get('XLA_FLAGS', '').split()
if _STRICT_CONV_PICKER_OFF not in _xla_flags:
    _xla_flags.append(_STRICT_CONV_PICKER_OFF)
os.environ['XLA_FLAGS'] = ' '.join(_xla_flags)

# Allocate GPU memory on demand instead of reserving a large block up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# NOTE(review): JAX documents MEM_FRACTION as the share of GPU memory to preallocate;
# confirm what effect it has while PREALLOCATE is 'false'.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.90'

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [3]:
# Used for both the training config and the checkpoint path, which share a name here.
# Keeping it in one place stops the two from drifting apart when switching models.
POLICY_CONFIG_NAME = "pi0_fast_droid"

config = _config.get_config(POLICY_CONFIG_NAME)
checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{POLICY_CONFIG_NAME}")
print(f"Checkpoint downloaded to: {os.path.abspath(checkpoint_dir)}")

# Create a trained policy.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# Delete the policy to free up memory before the next section loads another model.
del policy

print("Actions shape:", result["actions"].shape)

Checkpoint downloaded to: /mnt/22TB_IndEgo/yahuan/openpi/checkpoints/openpi-assets/checkpoints/pi0_fast_droid
Actions shape: (10, 8)


# Working with a live model


The following example shows how to create a live model from a checkpoint and compute the training loss. First, we are going to demonstrate how to do it with fake data.


In [4]:
# Used for both the training config and the checkpoint path, which share a name here.
MODEL_CONFIG_NAME = "pi0_aloha_sim"
config = _config.get_config(MODEL_CONFIG_NAME)

checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{MODEL_CONFIG_NAME}")
# Fixed PRNG key so the computed loss is repeatable; the real-data cell below reuses it.
key = jax.random.key(0)

# Create a model from the checkpoint.
model = config.model.load(_model.restore_params(checkpoint_dir / "params"))

# We can create fake observations and actions to test the model.
obs, act = config.model.fake_obs(), config.model.fake_act()

# Compute the training loss on the fake batch.
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)

Loss shape: (1, 50)


Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [5]:
# Reduce the batch size to reduce memory usage. Re-binding `config` with the same
# replacement is harmless if the cell is re-run.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on the real batch, using the `model` and `key`
# created in the previous cell.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory. Because of this, re-running this cell
# requires re-running the previous cell first to recreate `model`.
del model

print("Loss shape:", loss.shape)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

info.json: 0.00B [00:00, ?B/s]

tasks.jsonl:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

episodes_stats.jsonl: 0.00B [00:00, ?B/s]

episodes.jsonl: 0.00B [00:00, ?B/s]

Fetching 106 files:   0%|          | 0/106 [00:00<?, ?it/s]

.gitattributes: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

data/chunk-000/episode_000001.parquet:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

data/chunk-000/episode_000000.parquet:   0%|          | 0.00/52.2k [00:00<?, ?B/s]

data/chunk-000/episode_000003.parquet:   0%|          | 0.00/53.8k [00:00<?, ?B/s]

data/chunk-000/episode_000002.parquet:   0%|          | 0.00/54.6k [00:00<?, ?B/s]

data/chunk-000/episode_000005.parquet:   0%|          | 0.00/53.7k [00:00<?, ?B/s]

data/chunk-000/episode_000004.parquet:   0%|          | 0.00/52.5k [00:00<?, ?B/s]

data/chunk-000/episode_000006.parquet:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

data/chunk-000/episode_000007.parquet:   0%|          | 0.00/54.0k [00:00<?, ?B/s]

data/chunk-000/episode_000013.parquet:   0%|          | 0.00/52.3k [00:00<?, ?B/s]

data/chunk-000/episode_000008.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

data/chunk-000/episode_000010.parquet:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

data/chunk-000/episode_000012.parquet:   0%|          | 0.00/54.2k [00:00<?, ?B/s]

data/chunk-000/episode_000011.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

data/chunk-000/episode_000009.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

data/chunk-000/episode_000015.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

data/chunk-000/episode_000014.parquet:   0%|          | 0.00/53.8k [00:00<?, ?B/s]

data/chunk-000/episode_000016.parquet:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

data/chunk-000/episode_000018.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

data/chunk-000/episode_000020.parquet:   0%|          | 0.00/52.4k [00:00<?, ?B/s]

data/chunk-000/episode_000019.parquet:   0%|          | 0.00/53.0k [00:00<?, ?B/s]

data/chunk-000/episode_000017.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

data/chunk-000/episode_000021.parquet:   0%|          | 0.00/52.9k [00:00<?, ?B/s]

data/chunk-000/episode_000022.parquet:   0%|          | 0.00/54.4k [00:00<?, ?B/s]

data/chunk-000/episode_000023.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

data/chunk-000/episode_000024.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

data/chunk-000/episode_000025.parquet:   0%|          | 0.00/53.8k [00:00<?, ?B/s]

data/chunk-000/episode_000026.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

data/chunk-000/episode_000027.parquet:   0%|          | 0.00/52.4k [00:00<?, ?B/s]

data/chunk-000/episode_000028.parquet:   0%|          | 0.00/53.4k [00:00<?, ?B/s]

data/chunk-000/episode_000030.parquet:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

data/chunk-000/episode_000029.parquet:   0%|          | 0.00/54.2k [00:00<?, ?B/s]

data/chunk-000/episode_000031.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

data/chunk-000/episode_000032.parquet:   0%|          | 0.00/52.8k [00:00<?, ?B/s]

data/chunk-000/episode_000033.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

data/chunk-000/episode_000035.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

data/chunk-000/episode_000034.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

data/chunk-000/episode_000036.parquet:   0%|          | 0.00/52.9k [00:00<?, ?B/s]

data/chunk-000/episode_000037.parquet:   0%|          | 0.00/52.9k [00:00<?, ?B/s]

data/chunk-000/episode_000038.parquet:   0%|          | 0.00/54.4k [00:00<?, ?B/s]

data/chunk-000/episode_000039.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

data/chunk-000/episode_000040.parquet:   0%|          | 0.00/53.4k [00:00<?, ?B/s]

data/chunk-000/episode_000041.parquet:   0%|          | 0.00/52.6k [00:00<?, ?B/s]

data/chunk-000/episode_000042.parquet:   0%|          | 0.00/54.0k [00:00<?, ?B/s]

data/chunk-000/episode_000043.parquet:   0%|          | 0.00/53.7k [00:00<?, ?B/s]

data/chunk-000/episode_000045.parquet:   0%|          | 0.00/53.9k [00:00<?, ?B/s]

data/chunk-000/episode_000046.parquet:   0%|          | 0.00/54.1k [00:00<?, ?B/s]

data/chunk-000/episode_000044.parquet:   0%|          | 0.00/52.2k [00:00<?, ?B/s]

data/chunk-000/episode_000047.parquet:   0%|          | 0.00/51.7k [00:00<?, ?B/s]

data/chunk-000/episode_000048.parquet:   0%|          | 0.00/52.5k [00:00<?, ?B/s]

data/chunk-000/episode_000049.parquet:   0%|          | 0.00/53.5k [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.31M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.40M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.39M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.38M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.40M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.38M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.33M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.32M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.32M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.29M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.31M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.37M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.37M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.38M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.30M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.40M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.32M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.35M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.37M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.38M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.37M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.38M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.31M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.32M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.37M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.38M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.33M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.36M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.34M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.33M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.39M [00:00<?, ?B/s]

videos/chunk-000/observation.images.top/(…):   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/50 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Loss shape: (2, 50)
