In [1]:
import os
import torch
from PIL import Image
import numpy as np
import cv2

In [2]:
from transformers import AutoModelForVision2Seq

In [3]:
import robosuite as suite
from robosuite.controllers import load_controller_config
from robosuite.wrappers import Wrapper





In [4]:
from transformers import AutoProcessor

2025-11-18 11:34:27.177009: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
device = torch.device("cuda")

In [6]:
processor = AutoProcessor.from_pretrained(
    "openvla/openvla-7b",
    trust_remote_code=True,
    cache_dir = "./vla_cache"
)

In [7]:
# vla = AutoModelForVision2Seq.from_pretrained(
#     "openvla/openvla-7b",
#     torch_dtype=torch.bfloat16,
#     trust_remote_code=True,
#     low_cpu_mem_usage = True,
#     cache_dir = "./vla_cache"
# )

In [8]:
vla = AutoModelForVision2Seq.from_pretrained(
    "./vla_cache",
    local_files_only = True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage = True,
    trust_remote_code=True,
)

Expected `transformers==4.40.1` and `tokenizers==0.19.1` but got `transformers==4.46.3` and `tokenizers==0.20.3`; there might be inference-time regressions due to dependency changes. If in doubt, pleaseuse the above versions.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
vla = vla.to(device)

In [59]:
env = suite.make(
    "PickPlace",                     # simple manipulation task
    robots="Panda",             # Franka Panda 7-DoF arm
    has_renderer=True,          # <-- GUI window ON
    has_offscreen_renderer=True,
    use_camera_obs=True,        # return camera obs
    camera_names=["frontview", "birdview", "robot0_robotview"], # one camera is enough
    camera_heights=224,
    camera_widths=224,
    render_camera="frontview",  # show this in GUI
    control_freq=20,            # 20 Hz control
)
env = Wrapper(env)
obs = env.reset()

In [50]:
# from IPython.display import display
# frame = obs["frontview_image"]
# prompt = "take the cube"
# img = Image.fromarray(frame).transpose(method=Image.FLIP_TOP_BOTTOM)
# display(img)
# inps = processor(text = [prompt], images = [img]).to(device=device, dtype=torch.bfloat16)
# with torch.no_grad():
#     action = vla.predict_action(
#         **inps,
#         unnorm_key = "bridge_orig",
#         do_sample = False
#     )
# action = action.tolist()
# print(action)

In [60]:
env.render()
prompt = "Pick up the objects and place them in correct locations"
for step in range(300):
    frame1 = obs["robot0_robotview_image"]
    # image1 = Image.fromarray(frame1)
    image1 = Image.fromarray(frame1).transpose(method=Image.FLIP_TOP_BOTTOM)
    # frame2 = obs["birdview_image"]
    # image2 = Image.fromarray(frame2)
    env.render()
    inputs = processor(text = [prompt], images = [image1]).to(device=device, dtype=torch.bfloat16)
    with torch.no_grad():
        action = vla.predict_action(
            **inputs,
            unnorm_key = 'viola',
            do_sample=False
        )
    action = action.tolist()
    action.append(action[-1])
    # print(action)
    obs, reward, done, info = env.step(action)

In [55]:
env.close()

In [53]:
obs

OrderedDict([('robot0_joint_pos_cos',
              array([ 0.99999995,  0.97629465,  0.99998839, -0.88489527,  0.99928289,
                     -0.97731779,  0.71354295])),
             ('robot0_joint_pos_sin',
              array([ 3.19153199e-04,  2.16445742e-01, -4.81855473e-03, -4.65790030e-01,
                     -3.78643331e-02,  2.11778045e-01,  7.00611492e-01])),
             ('robot0_joint_vel', array([0., 0., 0., 0., 0., 0., 0.])),
             ('robot0_eef_pos',
              array([-0.07177103, -0.1005209 ,  0.98834387])),
             ('robot0_eef_quat',
              array([0.99941589, 0.02103147, 0.02642837, 0.00520536])),
             ('robot0_gripper_qpos', array([ 0.020833, -0.020833])),
             ('robot0_gripper_qvel', array([0., 0.])),
             ('frontview_image',
              array([[[143, 125, 113],
                      [146, 126, 115],
                      [147, 127, 116],
                      ...,
                      [123, 110,  97],
            