# GR00T Inference

This tutorial shows how to use the GR00T inference model to predict the actions from the observations, given a test dataset.

In [2]:
import os
import torch
import gr00t

from gr00t.data.dataset.lerobot_episode_loader import LeRobotEpisodeLoader
from gr00t.data.dataset.sharded_single_step_dataset import extract_step_data
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.policy.gr00t_policy import Gr00tPolicy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# User-tunable settings: edit these to point at your own checkpoint and data.

# Checkpoint to load (presumably a Hugging Face Hub model id or a local path).
MODEL_PATH = "nvidia/GR00T-N1.6-3B"

# Embodiment name as a plain string; it is wrapped in `EmbodimentTag` where used.
EMBODIMENT_TAG = "gr1"

# `gr00t.__file__` points inside the installed package, so stripping two path
# components gives the directory that contains the `gr00t` package
# (the repository root when gr00t is installed from a source checkout).
REPO_PATH = os.path.dirname(os.path.dirname(gr00t.__file__))
DATASET_PATH = os.path.join(REPO_PATH, "demo_data/gr1.PickNPlace")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Loading Pretrained Policy

The policy model is loaded just like any other Hugging Face model.

There are 2 new concepts here in the GR00T model:
 - modality config: This defines the keys in the dictionary used by the model. (e.g. `action`, `state`, `annotation`, `video`)
 - modality_transform: A sequence of transforms that are applied during data loading

In [8]:
# Build the inference policy from the checkpoint configured above.
# The embodiment string is converted to its `EmbodimentTag` enum member here.
policy = Gr00tPolicy(
    device=device,
    embodiment_tag=EmbodimentTag(EMBODIMENT_TAG),
    model_path=MODEL_PATH,
    strict=True,
)

# Show the module tree of the loaded model.
print(policy.model)

Tune backbone llm: False
Tune backbone visual: False
Backbone trainable parameter: model.language_model.model.layers.12.self_attn.q_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.self_attn.k_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.self_attn.v_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.self_attn.o_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.self_attn.q_norm.weight
Backbone trainable parameter: model.language_model.model.layers.12.self_attn.k_norm.weight
Backbone trainable parameter: model.language_model.model.layers.12.mlp.gate_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.mlp.up_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.mlp.down_proj.weight
Backbone trainable parameter: model.language_model.model.layers.12.input_layernorm.weight
Backbone trainable parameter: model.language_mode

  embedding_dim=self.inner_dim, compute_dtype=self.compute_dtype
  self.proj_out_2 = nn.Linear(self.inner_dim, self.output_dim)


Tune action head projector: True
Tune action head diffusion model: True
Tune action head vlln: True


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.54it/s]


Gr00tN1d6(
  (backbone): EagleBackbone(
    (model): Eagle3_VLForConditionalGeneration(
      (vision_model): Siglip2VisionModel(
        (vision_model): Siglip2VisionTransformer(
          (embeddings): Siglip2VisionEmbeddings(
            (patch_embedding): Linear(in_features=588, out_features=1152, bias=True)
            (position_embedding): Embedding(256, 1152)
          )
          (encoder): Siglip2Encoder(
            (rope_2d): Rope2DPosEmb(dim=72, max_height=512, max_width=512, theta_base=14)
            (layers): ModuleList(
              (0-26): 27 x Siglip2EncoderLayer(
                (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                (self_attn): Siglip2Attention(
                  (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
                  (out_proj)

## Loading dataset

First, check which embodiment tags were used to pretrain the `Gr00tPolicy` model.

In [5]:
import numpy as np

# Ask the policy which modalities (and which keys per modality) it expects.
modality_config = policy.get_modality_config()

print(modality_config.keys())

# Print one line per modality: arrays are summarised by their shape,
# anything else is printed as-is.
for modality_name, modality_spec in modality_config.items():
    summary = modality_spec.shape if isinstance(modality_spec, np.ndarray) else modality_spec
    print(modality_name, summary)


dict_keys(['video', 'state', 'action', 'language'])
video ModalityConfig(delta_indices=[0], modality_keys=['ego_view_bg_crop_pad_res256_freq20'], sin_cos_embedding_keys=None, mean_std_embedding_keys=None, action_configs=None)
state ModalityConfig(delta_indices=[0], modality_keys=['left_arm', 'right_arm', 'left_hand', 'right_hand', 'waist'], sin_cos_embedding_keys=['left_arm', 'right_arm', 'left_hand', 'right_hand', 'waist'], mean_std_embedding_keys=None, action_configs=None)
action ModalityConfig(delta_indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], modality_keys=['left_arm', 'right_arm', 'left_hand', 'right_hand', 'waist'], sin_cos_embedding_keys=None, mean_std_embedding_keys=None, action_configs=[ActionConfig(rep=<ActionRepresentation.RELATIVE: 'relative'>, type=<ActionType.NON_EEF: 'non_eef'>, format=<ActionFormat.DEFAULT: 'default'>, state_key=None), ActionConfig(rep=<ActionRepresentation.RELATIVE: 'relative'>, type=<ActionType.NON_EEF: 'non_eef'>, format=<ActionFor

In [6]:
# Episode-level loader over the demo dataset. It is configured with the
# modality layout reported by the policy; videos are decoded with the
# "torchcodec" backend using its default options.
dataset = LeRobotEpisodeLoader(
    modality_configs=modality_config,
    dataset_path=DATASET_PATH,
    video_backend_kwargs=None,
    video_backend="torchcodec",
)

Let's print out a single data point and inspect its contents.

In [7]:
import numpy as np

# Load the first episode and pull out its very first step, without padding.
episode_data = dataset[0]
step_data = extract_step_data(
    episode_data,
    step_index=0,
    modality_configs=modality_config,
    embodiment_tag=EmbodimentTag(EMBODIMENT_TAG),
    allow_padding=False,
)

# Raw dump of the step object.
print(step_data)

# Compact per-modality summary: key name plus array shape(s).
print("\n\n====================================")
print("Images:")
for img_key, frames in step_data.images.items():
    print(" " * 4, img_key, f"{len(frames)} x {frames[0].shape}")

print("\nStates:")
for state_key, state_value in step_data.states.items():
    print(" " * 4, state_key, state_value.shape)

print("\nActions:")
for action_key, action_value in step_data.actions.items():
    print(" " * 4, action_key, action_value.shape)

print("\nTask: ", step_data.text)


task
VLAStepData(images={'ego_view_bg_crop_pad_res256_freq20': [array([[[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       ...,

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]],

       [[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        ...,
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]]], dtype=uint8)]}, states={'left_arm': array([[-0.01147083,  0.12207967,  0.04229397, -2.1       , -0.01441445,
        -0.03013532, -0.00384387]], dtyp

Let's plot just the "right arm" state and action data to see what it looks like. We also show a few camera images (the `ego_view` video key) sampled across the episode.

In [None]:
import matplotlib.pyplot as plt

# Plot one joint group's state and ground-truth action over an episode, and
# show a handful of camera frames sampled along the way.

episode_index = 0
max_steps = 400  # upper bound on the number of steps to plot
joint_name = "right_arm"
image_key = "ego_view_bg_crop_pad_res256_freq20"

state_joints_across_time = []
gt_action_joints_across_time = []
images = []
image_steps = []  # step index of each sampled frame, used as the subplot title

sample_images = 6  # number of frames to show
episode_data = dataset[episode_index]
print(len(episode_data))

# Do not step past the end of the episode.
# NOTE(review): assumes len(episode_data) is the number of steps in the episode
# and that, with allow_padding=False, every action delta index must fall inside
# the episode -- confirm against extract_step_data.
action_lookahead = max(modality_config["action"].delta_indices)
num_steps = min(max_steps, len(episode_data) - action_lookahead)
if num_steps <= 0:
    raise ValueError(
        f"Episode {episode_index} has {len(episode_data)} steps, which is too short "
        f"for an action look-ahead of {action_lookahead} steps."
    )

# Spacing between sampled frames. The floor of 1 avoids a modulo-by-zero when
# there are fewer steps than requested frames.
image_stride = max(1, num_steps // sample_images)

# The tag is the same for every step, so build it once outside the loop.
embodiment_tag = EmbodimentTag(EMBODIMENT_TAG)

for step_count in range(num_steps):
    data_point = extract_step_data(
        episode_data,
        step_index=step_count,
        modality_configs=modality_config,
        embodiment_tag=embodiment_tag,
        allow_padding=False,
    )
    # Index 0 selects the first entry along the leading (time) axis.
    state_joints = data_point.states[joint_name][0]
    gt_action_joints = data_point.actions[joint_name][0]

    state_joints_across_time.append(state_joints)
    gt_action_joints_across_time.append(gt_action_joints)

    # Keep every `image_stride`-th frame, but never more than `sample_images`
    # (an integer stride can otherwise yield one frame too many).
    if step_count % image_stride == 0 and len(images) < sample_images:
        images.append(data_point.images[image_key][0])
        image_steps.append(step_count)

# Size is (num_steps, num_joints)
state_joints_across_time = np.array(state_joints_across_time)
gt_action_joints_across_time = np.array(gt_action_joints_across_time)


# Plot the joint angles across time.
# squeeze=False keeps `axes` a 2-D array even when there is a single joint.
num_joints = state_joints_across_time.shape[1]
fig, axes = plt.subplots(
    nrows=num_joints, ncols=1, figsize=(8, 2 * num_joints), squeeze=False
)

for i, ax in enumerate(axes.flat):
    ax.plot(state_joints_across_time[:, i], label="state joints")
    ax.plot(gt_action_joints_across_time[:, i], label="gt action joints")
    ax.set_title(f"Joint {i}")
    ax.legend()

plt.tight_layout()
plt.show()


# Plot the sampled images in a row (one column per collected frame).
fig, axes = plt.subplots(nrows=1, ncols=len(images), figsize=(16, 4), squeeze=False)

for ax, image, step in zip(axes.flat, images, image_steps):
    ax.imshow(image)
    ax.set_title(f"step {step}")
    ax.axis("off")

plt.tight_layout()
plt.show()

Now we can run the policy from the pretrained checkpoint.

In [None]:
# Assemble a single-step observation for the policy. Indexing with `[None]`
# prepends a batch axis of size 1 to each array.
language_key = modality_config["language"].modality_keys[0]

video_obs = {
    name: np.stack(frames)[None]  # stack images and add batch dimension
    for name, frames in step_data.images.items()
}
state_obs = {
    name: state[None]  # add batch dimension
    for name, state in step_data.states.items()
}
action_obs = {
    name: action[None]  # add batch dimension
    for name, action in step_data.actions.items()
}

observation = {
    "video": video_obs,
    "state": state_obs,
    "action": action_obs,
    "language": {language_key: [[step_data.text]]},  # add time and batch dimension
}

# Run the policy and report the shape of each predicted action entry.
predicted_action, info = policy.get_action(observation)
for action_name, action_chunk in predicted_action.items():
    print(action_name, action_chunk.shape)

For more details on the policy (e.g. expected input and output), please refer to the [policy documentation](policy.md).