# Batch Inference Example

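This notebook demonstrates batch inference with OpenPI. Several observations are passed to `policy.infer_batch` in a single call, and the wall-clock time is compared against calling `policy.infer` once per observation. The observations are random dummy inputs, so only output shapes and timings are meaningful, not the predicted actions.

**Note:** in the saved run, `policy.infer_batch` raised `AttributeError: 'list' object has no attribute 'item'` (see the last cell), so no speedup was measured.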


In [5]:
# Imports: stdlib, then third-party, then openpi.
import dataclasses
import importlib
import time

import jax.numpy as jnp
import numpy as np

# The importlib.reload calls re-execute each openpi module after importing it,
# presumably so source edits are picked up without a kernel restart (dev convenience).
# `_model` is reloaded first; the remaining order is kept as-is — TODO confirm
# whether any ordering dependency between these modules matters.
from openpi.models import model as _model
importlib.reload(_model)

from openpi.policies import libero_policy
importlib.reload(libero_policy)

from openpi.policies import policy_config as _policy_config
importlib.reload(_policy_config)

from openpi.shared import download
importlib.reload(download)

from openpi.training import config as _config
# Trailing `;` keeps the returned module repr (which includes an absolute local
# path) from being displayed as the cell's output.
importlib.reload(_config);


<module 'openpi.training.config' from '/research/data/zhenyang/openpi/src/openpi/training/config.py'>

## Set Up the Policy


In [2]:
# Load the training config and its checkpoint. The config name and the checkpoint
# directory on the openpi-assets bucket share the same name, so keep it in one place.
CONFIG_NAME = "pi05_libero"
config = _config.get_config(CONFIG_NAME)
# Presumably returns a local path, downloading only when not already cached — TODO confirm.
checkpoint_dir = download.maybe_download(f"gs://openpi-assets/checkpoints/{CONFIG_NAME}")

# Build the inference policy from the config and checkpoint.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)
print("Policy loaded successfully!")


Policy loaded successfully!


## Create a Batch of Example Observations


In [3]:
def create_batch_examples(
    batch_size: int = 3,
    seed: "int | np.random.Generator | None" = None,
) -> list[dict]:
    """Create a batch of random example observations.

    Each example is a dict with a random float64 ``observation/state`` vector of
    length 8, random uint8 ``observation/image`` and ``observation/wrist_image``
    arrays of shape (224, 224, 3), and a numbered text ``prompt``.

    Args:
        batch_size: Number of examples to create; ``0`` yields an empty list.
        seed: An int seed or an ``np.random.Generator`` for reproducible examples.
            ``None`` (the default) draws fresh entropy, so every call differs.

    Returns:
        A list of ``batch_size`` observation dicts.
    """
    rng = np.random.default_rng(seed)
    return [
        {
            "observation/state": rng.random(8),
            "observation/image": rng.integers(256, size=(224, 224, 3), dtype=np.uint8),
            "observation/wrist_image": rng.integers(256, size=(224, 224, 3), dtype=np.uint8),
            "prompt": f"Task {i + 1}: Pick up the object",
        }
        for i in range(batch_size)
    ]

# Build a reproducible batch so outputs and timings are comparable across runs.
BATCH_SIZE = 4
SEED = 0
batch_examples = create_batch_examples(BATCH_SIZE, seed=SEED)
assert len(batch_examples) == BATCH_SIZE
print(f"Created {len(batch_examples)} examples")
print(f"Example keys: {batch_examples[0].keys()}")


Created 4 examples
Example keys: dict_keys(['observation/state', 'observation/image', 'observation/wrist_image', 'prompt'])


In [7]:
# `policy` must stay alive here: both inference cells below call it, so deleting it
# at this point makes them fail with NameError under Restart & Run All.
# To free the policy's memory, run `del policy` only after the comparison has finished.

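## Single vs. Batch Inference Comparison

Both approaches run on the same `batch_examples`. The reported speedup is the single-inference time divided by the batch-inference time.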


In [None]:
# Single inference (original method): one `policy.infer` call per observation.
print("=== Single Inference ===")
# Untimed warm-up so any one-off cost of the first call (e.g. compilation, if the
# policy does any — not visible here) isn't counted in the loop timing.
policy.infer(batch_examples[0])

# perf_counter is monotonic and high-resolution, unlike time.time().
single_start = time.perf_counter()
single_results = []
for i, example in enumerate(batch_examples, start=1):
    result = policy.infer(example)
    single_results.append(result)
    print(f"Single inference {i}: Actions shape = {result['actions'].shape}")
single_time = time.perf_counter() - single_start
print(f"Total single inference time: {single_time:.3f}s")


In [6]:
# Batch inference (new method): all observations in one `policy.infer_batch` call.
print("\n=== Batch Inference ===")
# Untimed warm-up with the same batch, so any shape-specific first-call cost
# is excluded from the measurement (mirrors the single-inference cell).
policy.infer_batch(batch_examples)

batch_start = time.perf_counter()
batch_results = policy.infer_batch(batch_examples)
batch_time = time.perf_counter() - batch_start

# Sanity-check the batched path against the per-example path before comparing speed.
# Only shapes are compared: the inputs are random, so action values are not meaningful here.
assert len(batch_results) == len(batch_examples), (
    f"expected {len(batch_examples)} results, got {len(batch_results)}"
)
for i, (single, batched) in enumerate(zip(single_results, batch_results), start=1):
    assert single["actions"].shape == batched["actions"].shape, f"example {i}: action shape mismatch"

for i, result in enumerate(batch_results, start=1):
    print(f"Batch inference {i}: Actions shape = {result['actions'].shape}")
print(f"Total batch inference time: {batch_time:.3f}s")
print(f"Speedup: {single_time / batch_time:.2f}x")



=== Batch Inference ===


AttributeError: 'list' object has no attribute 'item'