In [1]:
import sys
import numpy as np
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import ArmActionMode, JointVelocity, JointPosition, EndEffectorPoseViaPlanning, EndEffectorPoseViaIK
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.environment import Environment
from rlbench.observation_config import ObservationConfig, CameraConfig
from rlbench.tasks import ReachTarget, PickAndLift, StackBlocks, PushButton, StackBlocks, PickUpCup, PlaceHangerOnRack, PickDescribedObject
import matplotlib.pyplot as plt
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from PIL import Image
import torch
from transformers import BitsAndBytesConfig
from pyquaternion import Quaternion
from rlbench.backend.robot import Robot
from scipy.spatial.transform import Rotation
from rlbench.backend.scene import Scene
from pathlib import Path
import os, json

from transformers import AutoModelForVision2Seq
from peft import PeftModel
import argparse
import torch
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import RLbenchDataset
from peft import LoraConfig, PeftModel, get_peft_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint locations. Each one can be overridden with an environment variable
# so the notebook also runs on machines with a different disk layout; the
# defaults are the original hard-coded values.
# Base checkpoint (also used to load the processor).
base_model_path = os.environ.get(
    "VLA_BASE_MODEL_PATH",
    "/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge",
)
# LoRA adapter directory.
adapter_path = os.environ.get(
    "VLA_ADAPTER_PATH",
    "adapter-tmp/Weighted_loss_4+pick_described_object+e1+b8+lr-1e-05+lora-r16+dropout-0.0+q-4bit",
)
# Directory the model used for inference is loaded from.
save_path = os.environ.get(
    "VLA_SAVE_PATH",
    "/media/lawrence/Work/checkpoints/vla-rl-ecot",
)

In [3]:
# Min-Max normalization statistics passed to the action tokenizer.
# Two arrays of 7 values each; presumably (minima, maxima) — TODO confirm order
# against RLbenchPoseTokenizer.
dataset_statistics: tuple = (
    np.array([
        -0.20173775, -0.36754665, 0.81396234,
        -3.14153998, -0.38798628, -3.14158631,
        0.,
    ]),
    np.array([
        0.41802976, 0.45118147, 1.47966564,
        3.14159215, 0.30391057, 3.14157801,
        1.,
    ]),
)

In [4]:
# Load the processor stored with the base checkpoint; trust_remote_code=True
# allows the custom code bundled with that checkpoint to be executed.
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
# Build the action tokenizer from the processor's text tokenizer and the
# Min-Max statistics defined in the previous cell.
action_tokenizer = RLbenchPoseTokenizer(processor.tokenizer,dataset_statistics)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# One-off merge step, kept commented out: it loads the base checkpoint on the
# CPU, applies the LoRA adapter from `adapter_path`, merges the adapter weights
# into the model and writes the result to `save_path`.
# NOTE(review): the next cell loads the model from `save_path`, so this step
# must have been run at least once for the notebook to work on a new machine.
# base_model = AutoModelForVision2Seq.from_pretrained(
#         base_model_path,
#         torch_dtype=torch.bfloat16,
#         attn_implementation="sdpa",
#         # quantization_config=quantization_config,
#         low_cpu_mem_usage=True,
#         trust_remote_code=True,
#         device_map = "cpu"
#     )
# vla = PeftModel.from_pretrained(base_model, adapter_path)
# vla = vla.merge_and_unload()
# vla.save_pretrained(save_path)

In [6]:
# 4-bit NF4 quantization with bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # llm_int8_skip_modules=['projector'],  # left disabled
)

# Load the checkpoint stored at `save_path` onto the GPU using the
# quantization settings above.
vla = AutoModelForVision2Seq.from_pretrained(
    save_path,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="cuda",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)
# Alternative kept for reference: attach the adapter to the loaded model.
# vla.load_adapter(adapter_path)
# vla.enable_adapters()



Loading checkpoint shards: 100%|██████████| 3/3 [00:30<00:00, 10.22s/it]


In [19]:
class Agent:
    """Queries a vision-language-action model for the next action.

    Bundles the model, its processor and the action tokenizer, and exposes
    `act`, which maps an observation plus an instruction to the value decoded
    by the action tokenizer.
    """

    def __init__(self, vla, processor, action_tokenizer):
        """Keep references to the model, the processor and the action tokenizer."""
        self.action_tokenizer = action_tokenizer
        self.processor = processor
        self.vla = vla

    def act(self, obs, instr):
        """Predict an action for the given observation and instruction.

        Args:
            obs: Observation exposing `front_rgb`, an array that
                `Image.fromarray` can convert.
            instr: Task instruction; it is lower-cased before being placed
                into the prompt.

        Returns:
            The result of `action_tokenizer.decode_token_score_to_actions`
            applied (with `soft=True`) to the generated scores.
        """
        prompt = f"In: What action should the robot take to {instr.lower()}?\nOut:"
        frame = Image.fromarray(obs.front_rgb)
        model_inputs = self.processor(prompt, frame).to(self.vla.device, dtype=torch.bfloat16)
        # Disabled experiment: append token id 29871 to the prompt ids.
        # model_inputs['input_ids'] = torch.cat((model_inputs['input_ids'], torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(model_inputs['input_ids'].device)), dim=1)
        generation = self.vla.generate(
            **model_inputs,
            max_new_tokens=1024,
            output_scores=True,
            return_dict_in_generate=True,
        )
        # Vocabulary columns kept for decoding: from the first action-token id
        # up to the tokenizer's vocab size.
        first_action_id = self.action_tokenizer.action_token_begin_idx
        vocab_end = self.processor.tokenizer.vocab_size
        # Stack the per-step score tensors and drop dimension 1
        # (presumably the batch dimension, i.e. batch size 1 — TODO confirm).
        step_scores = torch.stack(generation['scores']).squeeze(1)
        action_scores = step_scores[:, first_action_id:vocab_end]
        return self.action_tokenizer.decode_token_score_to_actions(action_scores, soft=True)



In [20]:
# Wire the loaded model, processor and action tokenizer into one agent.
agent = Agent(
    vla=vla,
    processor=processor,
    action_tokenizer=action_tokenizer,
)

In [7]:
# One 224x224 camera configuration with depth, point cloud and mask outputs
# disabled; it is shared by the four cameras configured below.
camera = CameraConfig(
    image_size=(224, 224),
    depth=False,
    point_cloud=False,
    mask=False,
)
obs_config = ObservationConfig(
    left_shoulder_camera=camera,
    right_shoulder_camera=camera,
    front_camera=camera,
    overhead_camera=camera,
)

# Arm actions are absolute end-effector poses reached via planning (collision
# checking off), followed by a discrete gripper command. Runs headless.
env = Environment(
    action_mode=MoveArmThenGripper(
        arm_action_mode=EndEffectorPoseViaPlanning(
            absolute_mode=True,
            collision_checking=False,
        ),
        gripper_action_mode=Discrete(),
    ),
    obs_config=obs_config,
    headless=True,
)
env.launch()



In [9]:
# Select the PickDescribedObject task from the launched environment.
task = env.get_task(PickDescribedObject)

In [14]:
# Reset the task; yields the episode's language descriptions (indexed below)
# and the initial observation.
descriptions, obs = task.reset()

In [33]:
# System preamble placed in front of the chat-style prompt.
SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


def get_openvla_prompt(instruction: str) -> str:
    """Build the chat-style prompt (system preamble + USER/ASSISTANT turn).

    Args:
        instruction: Task instruction; it is lower-cased before insertion.

    Returns:
        The full prompt string, ending with ``"ASSISTANT: "``.
    """
    return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT: "


# Example instruction kept for ad-hoc experiments; no visible cell reads it.
INSTRUCTION = "place the watermelon on the towel"


def get_instruction_prompt(instruction: str) -> str:
    """Build the short ``In: ... Out:`` prompt for an instruction.

    The previous version ignored ``instruction`` and returned only ``"In: "``.
    It now produces the same template that ``Agent.act`` builds inline, so
    both code paths can share one prompt format.

    Args:
        instruction: Task instruction; it is lower-cased before insertion.

    Returns:
        The prompt string, ending with ``"\\nOut:"``.
    """
    return f"In: What action should the robot take to {instruction.lower()}?\nOut:"


In [34]:
# NOTE(review): `instr` is first assigned in the following cell, so on a fresh
# top-to-bottom run this cell raises NameError; the value shown below comes
# from an earlier execution.
instr

'pick up the chocolate jello and place in the basket'

In [35]:
# Manual single-step inference using the chat-style prompt.
image = Image.fromarray(obs.front_rgb)  # front camera frame as a PIL image
instr = descriptions[1]                 # second description of the episode
prompt = get_openvla_prompt(instr)

inputs = processor(prompt, image).to(vla.device, dtype=torch.bfloat16)
# Disabled experiment: append token id 29871 to the prompt ids.
# inputs['input_ids'] = torch.cat((inputs['input_ids'], torch.unsqueeze(torch.Tensor([29871]).long(), dim=0).to(inputs['input_ids'].device)), dim=1)

# Keep the per-step scores and return a dict-like result so the following
# cells can inspect `action_dict.sequences`.
action_dict = vla.generate(
    **inputs,
    max_new_tokens=1024,
    output_scores=True,
    return_dict_in_generate=True,
)

# agent.act(obs, instr)


In [36]:
# Show the token ids of the first sequence returned by generate().
action_dict.sequences[0]

tensor([    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
        21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
          322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
        29889,  3148,  1001, 29901,  1724,  3158,   881,   278, 19964,  2125,
          304,  5839,   701,   278,   521,   542, 23167,   432,  3156,   322,
         2058,   297,   278, 25972, 29973,   319,  1799,  9047, 13566, 29901,
          521,   542, 23167,   432,  3156, 21521,  1660, 29901, 32005, 16999,
        12064,  5195, 29909,  3094,  4214, 29901,   521,   542, 23167,   432,
         3156,   338,   297,   278,   330,   374,  2496, 29892,   541,   451,
          297,   278, 25972, 29892,   577,  4337,   304,   278, 25972, 16999,
        12064, 29901,  4337,  1492,   402,  3960, 29925, 13171,   349,  3267,
        22122, 29901,   518, 29896, 29896, 29946, 29892, 29871, 29929, 29946,
        29962,   478,  3235,  8979,  1307,   438, 29933, 17637, 

In [37]:
# Decode the first returned sequence back to text to inspect the output.
processor.tokenizer.decode(action_dict.sequences[0])

"<s> A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: What action should the robot take to pick up the chocolate jello and place in the basket? ASSISTANT: chocolate jello POSE:<object> MOVE REASONING: chocolate jello is in the gripper, but not in the basket, so move to the basket MOVE: move right GRIPPER POSITION: [114, 94] VISIBLE OBJECTS: and a grey metal machine [99, 1, 155, 145], a yellow cube [52, 139, 70, 171], a white wooden table [15, 123, 239, 250], a robot [93, 0, 159, 147] ACTION: ھპ项ほ唐ċŸ</s>"

: 

In [33]:
# Count how many token ids in the returned sequences equal 32002.
# NOTE(review): the meaning of id 32002 is not visible in this notebook —
# presumably an added/special token; confirm and give it a named constant.
(action_dict.sequences == 32002).sum()

tensor(0, device='cuda:0')

In [34]:
# Display the instruction currently stored in `instr`.
instr

'pick up the chocolate jello and place in the basket'

In [35]:
# Run the agent once on the current observation with the first description.
# NOTE(review): the recorded output is an AttributeError — the tokenizer has no
# `decode_token_score_to_actions`, which Agent.act calls; the tokenizer API and
# the agent need to be brought back in sync.
agent.act(obs,descriptions[0])

AttributeError: 'RLbenchPoseTokenizer' object has no attribute 'decode_token_score_to_actions'

In [22]:
# Roll the agent out on PickUpCup for a fixed total number of control steps.
task = env.get_task(PickUpCup)
training_steps = 1000  # total control steps across all episodes
episode_length = 100   # maximum steps per episode before a forced reset
obs = None
episode_step = 0
# Start in the "terminated" state so that the first iteration resets the task.
terminate = True
for i in range(training_steps):
    # Start a new episode when the previous one reported termination or used
    # up its step budget. Previously the `terminate` flag returned by
    # task.step() was ignored, so a finished episode kept being stepped until
    # the next multiple of `episode_length`. If the task never terminates, the
    # reset cadence is unchanged (steps 0, 100, 200, ...).
    if terminate or episode_step >= episode_length:
        print('Reset Episode')
        descriptions, obs = task.reset()
        # NOTE(review): description 1 is printed but description 0 is what the
        # agent receives below — confirm this is intended.
        print(descriptions[1])
        episode_step = 0
        terminate = False
    episode_step += 1
    try:
        action = agent.act(obs,descriptions[0]).cpu().numpy()
        # Elements 3:6 are treated as 'xyz' Euler angles and converted to a
        # quaternion; SciPy returns it as (qx, qy, qz, qw).
        action_rotation = Rotation.from_euler('xyz', action[3:6])
        action_quaternion = action_rotation.as_quat()
        # Final action layout: first 3 elements, quaternion (4), last element.
        action = np.concatenate([action[0:3], action_quaternion, action[-1:]])
        print(action)
        obs, reward, terminate = task.step(action)
        print(reward)
    except Exception as e:
        # Best-effort rollout: report the failure and move on to the next
        # step. `obs` and `terminate` keep their previous values, so the next
        # iteration retries from the same observation.
        print(e)
        continue

Reset Episode
grasp the red cup and lift it


  x_pred = F.softmax(x_score) @ torch.tensor(self.x_bin_centers, dtype=torch.float32).to(device)
  y_pred = F.softmax(y_score) @ torch.tensor(self.y_bin_centers, dtype=torch.float32).to(device)
  z_pred = F.softmax(z_score) @ torch.tensor(self.z_bin_centers, dtype=torch.float32).to(device)
  grip_pred = F.softmax(grip_score) @ torch.tensor([0,1], dtype=torch.float32).to(device)


[ 0.27572709  0.24583329  1.46532702 -0.20408997 -0.25716859 -0.85771209
  0.39565334  0.99771273]
0.0
[ 0.20059341  0.05374869  1.33494139  0.23751288 -0.90787488 -0.24849725
  0.23999988  0.715424  ]
0.0
[ 0.20427898  0.33727607  1.2896384  -0.72007487 -0.00627378 -0.58384206
  0.37494169  0.99982733]
0.0
[ 0.22827733  0.29828966  1.36124277 -0.13839479 -0.44249046 -0.10972493
  0.8792096   0.99659038]
0.0
[ 0.22979419  0.16954969  1.41112947 -0.23194927  0.4435576  -0.86251087
  0.074372    0.92944044]
0.0
[ 0.18916911  0.2723802   1.37525153 -0.47196301 -0.3703273  -0.67230804
  0.43371709  0.88991213]
0.0
[ 0.21629931  0.12004524  1.46467423  0.70475247  0.54476549  0.44603796
 -0.08720467  0.99576843]
0.0
[ 0.25534695  0.23317787  1.34824216  0.01521532  0.74741209 -0.10307417
  0.65613975  0.99922061]
0.0
[ 0.19314648  0.23666552  1.42500782  0.28048186 -0.64344575  0.57425778
  0.4213496   0.96138906]
0.0
[ 0.18919125  0.29094267  1.24881077 -0.86361804 -0.13638627 -0.39551703


KeyboardInterrupt: 

In [28]:
# Shut the environment down once the rollouts are finished.
env.shutdown()

[CoppeliaSim:loadinfo]   done.


In [None]:
# Display the most recent value stored in `action`.
action

array([-0.02001152,  0.02458013, -0.01649398,  0.00402673, -0.01346854,
        0.02819962,  0.99607843])