In [1]:
import os, sys
sys.path.append('.')

from dataclasses import dataclass
from pathlib import Path

import draccus
import torch
import torch.distributed as dist
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.optim.lr_scheduler import StepLR
from transformers import get_linear_schedule_with_warmup

from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForPosePrediction, runningLoss
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import RLbenchCotDataset
import numpy as np
import torch.nn.functional as F

import numpy as np
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import ArmActionMode, JointVelocity, JointPosition, EndEffectorPoseViaPlanning, EndEffectorPoseViaIK


from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.environment import Environment
from rlbench.observation_config import ObservationConfig, CameraConfig
# from rlbench.tasks.pick_described_object import PickDescribedObject
from rlbench.tasks import PutGroceriesInCupboard, PickAndLift, StackBlocks, PlaceHangerOnRack, PickDescribedObject, TakeLidOffSaucepan, SetTheTable
from scipy.spatial.transform import Rotation as R
from matplotlib import pyplot as plt
from PIL import Image


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_model_path = "/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge"
adapter_path = "adapter-tmp/nll_loss_cot_chat_random_init_angle+pick_described_object+e2+b8+lr-5e-05+lora-r16+dropout-0.0+q-4bit"
# save_path = "/media/lawrence/Work/checkpoints/vla-rl-ecot"

In [3]:
dataset_statistics: tuple = (np.array([-0.27499999, -0.65500004,  0.75199986, -np.pi, -np.pi, -np.pi,  0. ]), np.array([0.77499999, 0.65500004, 1.75199986, np.pi, np.pi, np.pi, 1.])) # Min-Max normalization statistics


In [4]:
processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
action_tokenizer = RLbenchPoseTokenizer(processor.tokenizer,dataset_statistics)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
quantization_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", #llm_int8_skip_modules = ['projector'],
        )
vla = AutoModelForVision2Seq.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        device_map = "cuda"
    )
vla.load_adapter(adapter_path)

Loading checkpoint shards: 100%|██████████| 3/3 [00:33<00:00, 11.08s/it]


In [6]:
class Agent(object):
    def __init__(self, vla, processor, action_tokenizer):
        self.vla = vla
        self.processor = processor
        self.action_tokenizer = action_tokenizer

    def get_openvla_prompt(self, instruction: str):
        SYSTEM_PROMPT = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
        )
        return f"{SYSTEM_PROMPT} USER: What action should the robot take to {instruction.lower()}? ASSISTANT: Let's think step by step,"

    def act(self, obs, instr):
        prompt = self.get_openvla_prompt(instr)
        image = Image.fromarray(obs.front_rgb)
        inputs = self.processor(prompt, image).to(self.vla.device, dtype=torch.bfloat16)
        output_dict = self.vla.generate(**inputs, max_new_tokens = 1024, output_scores = True, return_dict_in_generate=True)
        action_mask, gripper_mask, object_mask, target_mask = self.action_tokenizer.get_mask(output_dict.sequences)
        print(self.processor.tokenizer.decode(output_dict.sequences[0]))
        output_logits = torch.stack(output_dict.scores, dim = 1)
        action_logits = output_logits[action_mask[:,-output_logits.size(1):]][:,self.action_tokenizer.action_token_begin_idx:self.processor.tokenizer.vocab_size].view(1,-1,self.action_tokenizer.n_bins)
        gripper_logits = output_logits[gripper_mask[:,-output_logits.size(1):]][:,self.action_tokenizer.action_token_begin_idx:self.processor.tokenizer.vocab_size].view(1,-1,self.action_tokenizer.n_bins)
        object_logits = output_logits[object_mask[:,-output_logits.size(1):]][:,self.action_tokenizer.action_token_begin_idx:self.processor.tokenizer.vocab_size].view(1,-1,self.action_tokenizer.n_bins)
        target_logits = output_logits[target_mask[:,-output_logits.size(1):]][:,self.action_tokenizer.action_token_begin_idx:self.processor.tokenizer.vocab_size].view(1,-1,self.action_tokenizer.n_bins)
        action = self.action_tokenizer.decode(action_logits).cpu().numpy().squeeze(0)
        return action
        

In [7]:
agent = Agent(vla, processor, action_tokenizer)

In [8]:
camera = CameraConfig(image_size=(224, 224), depth=False, point_cloud=False, mask=False)
obs_config = ObservationConfig(left_shoulder_camera=camera, right_shoulder_camera=camera, front_camera=camera, overhead_camera=camera)

env = Environment(
    action_mode=MoveArmThenGripper(
        arm_action_mode=EndEffectorPoseViaPlanning(absolute_mode=True, collision_checking=False), gripper_action_mode=Discrete()),
    obs_config=obs_config,
    headless=False)
env.launch()

In [9]:
task = env.get_task(PickDescribedObject)
training_steps = 1000
episode_length = 100
obs = None
for i in range(training_steps):
    if i % episode_length == 0:
        print('Reset Episode')
        descriptions, obs = task.reset()
        print(descriptions[1])
    try:
        action = agent.act(obs,descriptions[1])
        action_rotation = R.from_euler('xyz', action[3:6])
        action_quaternion = action_rotation.as_quat()
        # print(delta_quaternion)  # returns (qx, qy, qz, qw)
        action = np.concatenate([action[0:3], action_quaternion, action[-1:]])
        print(action)
        obs, reward, terminate = task.step(action)
        print(reward)
    except Exception as e:
        print(e)
        continue

Reset Episode
pick up the chocolate jello and place in the basket
<s> A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: What action should the robot take to pick up the chocolate jello and place in the basket? ASSISTANT: Let's think step by step, we need to first move to the chocolate jello and pick it up. The chocolate jello is now in front of the gripper, we need to close the gripper. The chocolate jello is now below the gripper, so we need to move down and close the gripper. The gripper is now closed, and the chocolate jello is below the gripper, so we need to lift the chocolate jello. The basket is now empty, so we need to open the gripper. The chocolate jello is now inside the basket, so we need to stop. The task is now complete, so we need to stop. The robot task [96, 1, 155, 146], and place it [27, 90, 90, 161] ACTION: ந鳥才Ἐ백索Ÿ</s>
Expected `angles` to be at most 2-dimen

KeyboardInterrupt: 

In [None]:
env.shutdown()

[CoppeliaSim:loadinfo]   done.


In [26]:
test_data_path: Path = Path(f"./datasets/pick_described_object/test_data1.pt")
testset = RLbenchCotDataset(
    test_data_path,
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=PurePromptBuilder,
)
collator = PaddedCollatorForPosePrediction(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)
test_dataloader = DataLoader(
    testset,
    batch_size=2,
    sampler=None,
    collate_fn=collator,
    num_workers=1,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)


In [29]:
test = next(iter(test_dataloader))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [46]:
processor.tokenizer.decode([ 1, 29871,    13,  4706,512,4706])

'<s> \n        In       '

In [35]:
test['input_ids'][0]

tensor([    1, 29871,    13,  4706,   512, 29901,  1724,  3158,   881,   278,
        19964,  2125,   304,  5839,   701,   278,  1818,   538,   322,  2058,
          297,   278, 25972, 29973,    13,  4706,  4451, 29901,  2803, 29915,
        29879,  1348,  4331,   491,  4331, 29892, 14893,   491,  8401,   278,
          330,   374,  2496,   304,   278,  1818,   538,   322,  5839,   292,
          372,   701, 29889,  2860,   393, 29892,  4337,   278,   330,   374,
         2496,   975,   278, 25972,   322,  2058,   278,  1818,   538,  2768,
        29889,   450,  1818,   538,   338,  5982,   472, 32005, 31458, 31532,
        31603, 32006, 29892,   278, 25972, 29915, 29879,  2602,   338, 32007,
        31398, 31503, 31625, 32008, 29892,   322,   278,   330,   374,  2496,
        29915, 29879, 18593,   338, 32003, 31461, 31539, 31601, 31703, 31856,
        31965, 31999, 32004, 29889,   450,   330,   374,  2496, 22602, 29915,
        29873, 25274,   287,   278,  1818,   538, 29889,   450, 

In [34]:
processor.tokenizer.decode(test['input_ids'][0])

"<s> \n        In: What action should the robot take to pick up the mustard and place in the basket?\n        Out: Let's think step by step, Begin by moving the gripper to the mustard and picking it up. After that, move the gripper over the basket and place the mustard inside. The mustard is located at<object>산线호</object>, the basket's position is<target>交宇ൾ</target>, and the gripper's pose is<gripper>密球ུ守터進给</gripper>. The gripper hasn't grasped the mustard. The current task step is Move the gripper to the mustard and pick it up. and the next key pose of the gripper to perform is<action>န洋그군ھ双给</action>.</s>"

In [None]:
instr

'pick up the chocolate jello and place in the basket'