# ARM4R Inference Notebook

This notebook runs inference using the ARM4R policy on a real-world Kinova robot demonstration. It loads model weights, visualizes the demo, and plots predicted versus ground truth actions.


## Step 1: Import Dependencies
We begin by importing all required libraries.

In [None]:
import json
import os
from arm4r.models.policy.arm4r_wrapper import ARM4RWrapper
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import trange
import zarr

## Step 2: Define Helper Function
This function loads the camera images, proprioception, actions, and language instruction for a single demo from the demo path.

In [None]:
def get_data_from_demo_path(demo_path, length, image_resolution=None,
                             observation_key=None,
                             points_key='wrist_points', return_PIL_images=False):
    """Load camera frames, proprioception, actions and instruction for one demo.

    Frames are read either from ``<demo_path>/images/<basename(key)>/%05d.jpg``
    (when the ``images`` directory exists) or from the per-key path lists
    stored in ``<demo_path>/images.json``. Every frame is resized to the
    resolution given by ``image_resolution[0]``. A frame that cannot be
    located or opened is replaced by random uint8 noise of that resolution,
    and a warning reports how many frames were substituted.

    Args:
        demo_path: Directory of a single demonstration.
        length: Number of frames to load per camera (indices ``0..length-1``).
        image_resolution: List of ``[-1, height, width, channels]`` entries.
            Only the first entry is used, for every camera. Defaults to
            ``[[-1, 256, 456, 3], [-1, 256, 456, 3]]``.
        observation_key: Camera keys; the first is returned as the side view
            and the second as the wrist view. Defaults to the exterior-left
            and wrist-left keys.
        points_key: Unused by this function; kept so existing callers that
            pass it keep working.
        return_PIL_images: If true, return the frames as lists of PIL images
            instead of stacked arrays.

    Returns:
        dict with keys ``"side_images"``, ``"wrist_images"`` (stacked frame
        arrays, or lists of PIL images), and ``"actions"``, ``"proprios"``,
        ``"instruction"`` exactly as returned by ``zarr.load`` for
        ``action.zarr``, ``proprio.zarr`` and ``instruction.zarr``.
    """
    # Local import so this block does not depend on a file-level import.
    import warnings

    # Defaults are created per call instead of sharing mutable list defaults.
    if image_resolution is None:
        image_resolution = [[-1, 256, 456, 3], [-1, 256, 456, 3]]
    if observation_key is None:
        observation_key = ['observation/exterior_image_1_left', 'observation/wrist_image_left']

    # Array shape of one frame, and the matching PIL size (PIL uses (width, height)).
    frame_shape = tuple(image_resolution[0][1:])
    target_size = (image_resolution[0][2], image_resolution[0][1])

    # The storage layout does not change between frames, so resolve it once.
    images_dir = os.path.join(demo_path, 'images')
    use_image_dir = os.path.exists(images_dir)

    image_index = None
    if not use_image_dir:
        try:
            with open(os.path.join(demo_path, 'images.json')) as index_file:
                image_index = json.load(index_file)
        except (OSError, ValueError) as err:
            # Keep the best-effort behaviour (frames fall back to noise below),
            # but make the problem visible.
            warnings.warn(f"Could not read images.json in {demo_path!r}: {err}")

    camera_observations = {}
    for key in observation_key:
        frames = []
        n_missing = 0
        for img_idx in range(length):
            try:
                if use_image_dir:
                    frame_path = os.path.join(images_dir, os.path.basename(key), '%05d.jpg' % img_idx)
                else:
                    # Raises TypeError when images.json could not be loaded.
                    frame_path = image_index[key][img_idx]
                current_frame = Image.open(frame_path)
            except (OSError, KeyError, IndexError, TypeError, ValueError):
                # The placeholder already has the target shape, so it must not
                # go through the PIL-only size check / resize below.
                n_missing += 1
                frames.append(np.random.randint(0, 255, frame_shape, dtype="uint8"))
                continue
            if current_frame.size != target_size:
                current_frame = current_frame.resize(target_size)
            frames.append(np.asarray(current_frame))
        if n_missing:
            warnings.warn(
                f"{n_missing}/{length} frames for {key!r} could not be loaded from "
                f"{demo_path!r}; substituted random noise."
            )
        # np.stack rejects an empty list, so return an empty frame array instead.
        if frames:
            camera_observations[key] = np.stack(frames)
        else:
            camera_observations[key] = np.empty((0,) + frame_shape, dtype="uint8")

    side_images_l = camera_observations[observation_key[0]]
    wrist_images_l = camera_observations[observation_key[1]]

    proprios = zarr.load(os.path.join(demo_path, 'proprio.zarr'))
    actions = zarr.load(os.path.join(demo_path, 'action.zarr'))
    instruction = zarr.load(os.path.join(demo_path, 'instruction.zarr'))

    if return_PIL_images:
        side_images_l = [Image.fromarray(frame) for frame in side_images_l]
        wrist_images_l = [Image.fromarray(frame) for frame in wrist_images_l]

    return {
        "side_images": side_images_l,
        "wrist_images": wrist_images_l,
        "actions": actions,
        "proprios": proprios,
        "instruction": instruction
    }

## Step 3: Initialize Model and Load Configurations

You need to set your data path and checkpoint paths here.

In [None]:
# Demonstration to replay: any single demo directory from the dataset.
# NOTE(review): PREFIX presumably stands for your local dataset root - confirm.
demo_path = 'PREFIX/real_kinova_release_data/pick_cube/pick_yellow_cube/kinova_tasks/common_task/2025-01-06T18:49:39.032022'

# Policy weights; the released pick-cube Kinova checkpoint can be used here.
checkpoint_path = '../arm4r-ckpts/model_ckpts/ft_kinova_pick_cube/ft_kinova_pick_cube.pth'
# Vision encoder weights.
vision_encoder_path = "../arm4r-ckpts/vision_encoder/cross-mae-rtx-vitb.pth"
# The training config is looked up as run.yaml next to the checkpoint file.
checkpoint_dir = os.path.dirname(checkpoint_path)
train_yaml_path = os.path.join(checkpoint_dir, "run.yaml")

# Each entry is [-1, height, width, channels]; get_data_from_demo_path reads
# only the first entry and applies it to every camera.
image_resolution = [[-1, 224, 224, 3], [-1, 224, 224, 3]]
# First key is treated as the side camera, second as the wrist camera.
observation_key = [
    'observation/exterior_image_1_left',
    'observation/wrist_image_left',
]
# Name of the zarr array whose row count defines the demo length in Step 4.
points_key = 'action'

arm4r = ARM4RWrapper(train_yaml_path, checkpoint_path, vision_encoder_path)

## Step 4: Load Demonstration Data

In [None]:
# The number of steps to load is one less than the row count of the
# `points_key` array stored in the demo directory.
points_array = zarr.load(os.path.join(demo_path, f'{points_key}.zarr'))
length = points_array.shape[0] - 1

obs_dict = get_data_from_demo_path(
    demo_path,
    length,
    image_resolution=image_resolution,
    observation_key=observation_key,
    points_key=points_key,
)

# Pull the individual fields out of the returned dict into notebook globals.
side_images_l, wrist_images_l, proprios, actions, instruction = (
    obs_dict[field]
    for field in ("side_images", "wrist_images", "proprios", "actions", "instruction")
)

## Step 5: Visualize the Demonstration
This animates the camera views from the side and wrist.

In [None]:
# Join both camera streams along axis 1 so each animation frame shows the two
# views together (assumes both arrays are (T, H, W, C) with matching widths).
frames = np.concatenate((side_images_l, wrist_images_l), axis=1)
T = len(frames)

fig, ax = plt.subplots()
img = ax.imshow(frames[0])


def update(frame):
    """Show `frame` in the existing image artist and return it for blitting."""
    img.set_data(frame)
    return [img]


# Keep a reference to the animation object so it is not garbage-collected.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=frames,
    interval=50,
    blit=True,
)
plt.show()

## Step 6: Run Policy Inference
Use the ARM4R model to predict actions step-by-step.

In [None]:
# The policy is called with PIL images, so convert each stored frame array.
side_images = [Image.fromarray(frame) for frame in side_images_l]
wrist_images = [Image.fromarray(frame) for frame in wrist_images_l]

arm4r.reset()
pred_actions = []

# Keyword arguments that stay the same for every step.
policy_kwargs = dict(
    instruction=instruction,
    use_temporal=True,
    teacher_forcing=False,
    binary_gripper=True,
)

for i in trange(len(side_images)):
    # The previous ground-truth action row is passed in as `action`;
    # there is none before the first step.
    prev_action = actions[i - 1:i] if i > 0 else None
    action = arm4r(
        side_images[i],
        wrist_images[i],
        proprios[i:i + 1],
        action=prev_action,
        **policy_kwargs,
    )
    pred_actions.append(action)

pred_actions = np.asarray(pred_actions)

## Step 7: Plot Predicted vs Ground Truth Actions

In [None]:
action_keys = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
T = pred_actions.shape[0]
steps = range(T)

# Drop the last ground-truth row; with points_key = 'action' in Step 4 this
# leaves as many rows as there are predicted steps.
ground_truth = actions[:-1, :]

plt.figure(figsize=(12, 6))
# One subplot per action dimension, laid out on a 2 x 4 grid.
for i, key_name in enumerate(action_keys):
    plt.subplot(2, 4, i + 1)
    plt.plot(steps, pred_actions[:, i], label='Predicted')
    plt.plot(steps, ground_truth[:, i], label='Ground Truth')
    plt.xlabel('Time Step')
    plt.ylabel('Action Value')
    plt.title(key_name)
    plt.legend()

plt.tight_layout()
plt.show()