In [1]:

import torch
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training, PeftConfig
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.optim.lr_scheduler import StepLR
from transformers import get_linear_schedule_with_warmup

from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForPosePrediction, runningLoss
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import RLbenchCotDataset
import numpy as np
import torch.nn.functional as F

import numpy as np
from PIL import Image

from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import ArmActionMode, JointVelocity, JointPosition, EndEffectorPoseViaPlanning, EndEffectorPoseViaIK


from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.environment import Environment
from rlbench.observation_config import ObservationConfig, CameraConfig
# from rlbench.tasks.pick_described_object import PickDescribedObject
from rlbench.tasks import PutGroceriesInCupboard, PickAndLift, StackBlocks, PlaceHangerOnRack, PickDescribedObject, TakeLidOffSaucepan, SetTheTable, PutGroceriesInCupboard
from scipy.spatial.transform import Rotation as R
from matplotlib import pyplot as plt
from PIL import Image
from pyrep.const import RenderMode

2024-08-07 11:34:14.290491: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-07 11:34:14.290535: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-07 11:34:14.293438: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-07 11:34:14.301269: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
base_model_path = "/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge"
adapter_path = "adapter-tmp/dynamic_weighted_loss+weighted+pick_described_object_replay+e1+b8+lr-0.0005+lora-r16+dropout-0.0+q-4bit"
data_path = "datasets/pick_described_object_replay/data.pt"
# test_adapter = "adapter-tmp/nosie_data_dynamic_weight+weighted+pick_described_object_replay1+e1+b8+lr-0.0005+lora-r16+dropout-0.0+q-4bit"
# adapter_path1 = "adapter-tmp/weighted_loss_cot_1+nll+pick_described_object2+e1+b8+lr-0.0001+lora-r16+dropout-0.0+q-4bit"

In [4]:
quantization_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", #llm_int8_skip_modules = ['projector'],
        )
basemodel = AutoModelForVision2Seq.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        device_map = "cuda"
    )
basemodel = prepare_model_for_kbit_training(basemodel)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
item_num = 5
stage_num = 2 
add_tokens = ['<g>', '</g>'] + [f'<item_{i}>' for i in np.arange(item_num)] + ['<o>', '</o>', '<t>', '</t>'] + [f'<stage_{i}>' for i in np.arange(stage_num)] + ['<a>', '</a>', '<q>', '<cot>']

processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.tokenizer.add_tokens(add_tokens)
dataset_statistics: tuple = (np.array([-0.2, -0.35,  0.75199986, -np.pi/2, -np.pi/2, -np.pi/2,  0. ]), np.array([0.5, 0.35, 1.3, np.pi/2, 0, np.pi/2, 1.])) # Min-Max normalization statistics

action_tokenizer = RLbenchPoseTokenizer(processor.tokenizer, dataset_statistics)
trainset = RLbenchCotDataset(
    data_path,
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
)

collator = PaddedCollatorForPosePrediction(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)

train_dataloader = DataLoader(
    trainset,
    batch_size=2,
    shuffle=False,
    sampler=None,
    collate_fn=collator,
    num_workers=1,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)

In [6]:
batch = next(iter(train_dataloader))

In [7]:
torch.cuda.empty_cache()
vla = PeftModel.from_pretrained(basemodel, adapter_path, adapter_name="online", is_trainable=True)
vla.load_adapter(adapter_path, adapter_name="target", is_trainable=True)
vla.set_adapter("online")
vla.add_weighted_adapter(["online", "target"], [0.05,0.95], combination_type="linear", adapter_name="target1")

In [7]:
# device_id = 'cuda'
# with torch.autocast("cuda", dtype=torch.bfloat16):
#     output = vla.generate(
#         input_ids=batch["input_ids"][0:1,:-28].to(device_id),
#         # attention_mask=batch["attention_mask"].to(device_id),
#         pixel_values=batch["pixel_values"][0:1].to(torch.bfloat16).to(device_id),
#         # labels=batch["labels"],
#         # use_cache=False
#         max_new_tokens = 28,
#     )
# processor.tokenizer.decode(output[0])

In [14]:
class Agent(object):
    def __init__(self, vla, processor, action_tokenizer):
        self.vla = vla
        self.processor = processor
        self.action_tokenizer = action_tokenizer

    def get_actor_prompt(self, gripper, instruction: str):
        gripper = self.action_tokenizer(gripper)
        prompt = (
            f"In: What should be the next key pose of the gripper to {instruction}? The current gripper pose is <g>{gripper} </g>.\n "
            "Out: <cot>"
        )
        return prompt

    def get_critic_prompt(self, gripper, instruction: str, item: int, object_: str, target: str, stage: int, action: str):
        prompt = (
            f"In: What should be the next key pose of the gripper to {instruction}? The current gripper pose is <g>{gripper} </g>.\n "
            f"Out: <cot> <item_{item}>, <o>{object_} </o>, <t>{target} </t>, <stage_{stage}>, <a>{action} </a>, <q>"
        )
        return prompt

    def act(self, obs, instr,temperature = 1, deterministic = False):
        gripper = np.concatenate([obs.gripper_pose,[obs.gripper_open]])
        prompt = self.get_actor_prompt(gripper,instr)
        image = Image.fromarray(obs.front_rgb)
        inputs = self.processor(prompt, image).to(self.vla.device, dtype=torch.bfloat16)
        while True:
            with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.vla.generate(
                    **inputs,
                    max_new_tokens = 28,
                    do_sample=False,
                    temperature=1,
                    return_dict_in_generate=True,
                    output_scores = True,
                )
            # output_dict = vla.generate(**inputs, max_new_tokens = 50, output_scores = True, return_dict_in_generate=True, do_sample = True, temperature = 0.5)
            gripper_mask, item_mask, object_mask, target_mask, stage_mask, action_mask = self.action_tokenizer.get_mask(output_dict.sequences)
            print(processor.tokenizer.decode(output_dict.sequences[0]))
            if action_mask.sum().item() != 7:
                print("Action mask is not correct")
                continue
            break
        
        output_logits = torch.stack(output_dict.scores, dim = 1)
        action_logits = output_logits[action_mask[:,-output_logits.size(1):]][:,action_tokenizer.action_token_begin_idx:processor.tokenizer.vocab_size].view(1,-1,action_tokenizer.n_bins)
        action = self.action_tokenizer.get_action(action_logits,temperature = temperature, deterministic = deterministic)
        action_rotation = R.from_euler('xyz', action[3:6])
        action_quaternion = action_rotation.as_quat()
        action = np.concatenate([action[0:3], action_quaternion, action[-1:]])
        # get_action(self, logits: torch.tensor, temperature: float = 1, deterministic: bool = False)
        # q = output_logits[:,-1, 32016]

        return action#, q
    
    # def get_q(self, batch):
        
agent = Agent(vla,processor,action_tokenizer)

In [15]:
def get_data(task, variation_num):
    obs = task._scene.get_observation()
    img = Image.fromarray(obs.front_rgb,'RGB')
    gripper_pose = obs.gripper_pose
    gripper_open = obs.gripper_open
    object_pos = task._task.get_graspable_objects()[variation_num].get_position()
    target_pos = task._task.dropin_box.get_position()
    return img, gripper_pose, gripper_open, object_pos, target_pos


In [10]:
# class ReplayBuffer:
#     """
#     A simple FIFO experience replay buffer for DDPG agents.
#     """

#     def __init__(self, obs_dim, gripper_dim, act_dim, size):
#         self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
#         self.gripper_buf = np.zeros(size,)
#         self.obs2_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
#         self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
#         self.rew_buf = np.zeros(size, dtype=np.float32)
#         self.done_buf = np.zeros(size, dtype=np.float32)
#         self.ptr, self.size, self.max_size = 0, 0, size

#     def store(self, obs, act, rew, next_obs, done):
#         self.obs_buf[self.ptr] = obs
#         self.obs2_buf[self.ptr] = next_obs
#         self.act_buf[self.ptr] = act
#         self.rew_buf[self.ptr] = rew
#         self.done_buf[self.ptr] = done
#         self.ptr = (self.ptr+1) % self.max_size
#         self.size = min(self.size+1, self.max_size)

#     def sample_batch(self, batch_size=32):
#         idxs = np.random.randint(0, self.size, size=batch_size)
#         batch = dict(obs=self.obs_buf[idxs],
#                      obs2=self.obs2_buf[idxs],
#                      act=self.act_buf[idxs],
#                      rew=self.rew_buf[idxs],
#                      done=self.done_buf[idxs])
#         return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in batch.items()}


# RLbench Env

In [8]:
camera = CameraConfig(image_size=(224, 224), depth=False, point_cloud=False, mask=False)
obs_config = ObservationConfig(left_shoulder_camera=camera, right_shoulder_camera=camera, front_camera=camera, overhead_camera=camera)
obs_config.front_camera.render_mode = RenderMode.OPENGL

env = Environment(
    action_mode=MoveArmThenGripper(
        arm_action_mode=EndEffectorPoseViaPlanning(absolute_mode=True, collision_checking=False), gripper_action_mode=Discrete()),
    obs_config=obs_config,
    headless=False, shaped_rewards = True)
env.launch()
task = env.get_task(PickDescribedObject)

# Train DDPG

In [10]:
import time

In [11]:
max_ep_len = 30
steps_per_epoch = 30
epochs = 100
start_steps = 1000
update_after = 1000
update_every = 50

In [31]:
num_test_episodes = 10
TestEpRet = []
TestEpLen = []
def test_agent():
    for j in range(num_test_episodes):
        task.sample_variation()
        (ins, o), d, ep_ret, ep_len = task.reset(), False, 0, 0
        while not(d or (ep_len == max_ep_len)):
            # Take deterministic actions at test time (noise_scale=0)
            o, r, d = task.step(agent.act(o, ins[0], temperature= 1))
            ep_ret += r
            ep_len += 1
        TestEpRet.append(ep_ret)
        TestEpLen.append(ep_len)
        

In [35]:
torch.cuda.empty_cache()
test_agent()

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 

: 

In [None]:
trainset.data.keys()

dict_keys(['images', 'instructions', 'grippers', 'items', 'objects', 'targets', 'stages', 'actions', 'rewards', 'dones', 'next_images', 'next_grippers'])

In [21]:
# imgs = 
data = torch.load(data_path)

In [22]:
data['grippers'][0]

array([-0.08452094,  0.02624014,  1.22730494, -2.66500001, -0.74425952,
       -1.40327554,  1.        ])

In [23]:
img = Image.open(data['images'][0])
instruction = data['instructions'][0]
gripper = action_tokenizer(data['grippers'][0])
item = data['items'][0]
object_ = action_tokenizer(data['objects'][0])
target = action_tokenizer(data['targets'][0])
action = action_tokenizer(data['actions'][0])

stage = data['stages'][0]
critic_prompt = agent.get_critic_prompt(gripper,instruction,item,object_,target,stage,action)
# input_ids = action_tokenizer.tokenizer(critic_prompt)
inputs = processor(critic_prompt, img).to(agent.vla.device, dtype=torch.bfloat16)

In [43]:
img = Image.open(data['images'][0])
instruction = data['instructions'][0]
gripper = action_tokenizer(data['grippers'][0])
item = data['items'][0]
object_ = action_tokenizer(data['objects'][0])
target = action_tokenizer(data['targets'][0])
action = action_tokenizer(data['actions'][0])

stage = data['stages'][0]
critic_prompt = agent.get_critic_prompt(gripper,instruction,item,object_,target,stage,action)
# input_ids = action_tokenizer.tokenizer(critic_prompt)
inputs = processor(critic_prompt, img).to(agent.vla.device, dtype=torch.bfloat16)
agent.vla.train()
# agent.vla.gradient_checkpointing_enable()
vla.set_adapter("online")
# vla.requires_grad_ = True
with torch.autocast("cuda", dtype=torch.bfloat16):
    output_dict = basemodel.generate_with_grad(
                    **inputs,
                    max_new_tokens = 1,
                    do_sample=False,
                    temperature=1,
                    return_dict_in_generate=True,
                    output_scores = True,
                    # output_logits = True
                )
q = output_dict.scores[0][0,32016]

TypeError: _DecoratorContextManager.clone() got an unexpected keyword argument 'input_ids'

In [45]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'pixel_values'])

In [37]:
from types import MethodType

In [42]:
generate_with_grad = undecorated(basemodel.generate)
basemodel.generate_with_grad = MethodType(generate_with_grad, basemodel)

In [27]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'pixel_values'])

In [28]:
inputs.pixel_values.requires_grad = True

In [29]:
next_img = Image.open(data['next_images'][0])
next_gripper = (data['actions'][0])
actor_prompt = agent.get_actor_prompt(next_gripper,instruction)
inputs = processor(actor_prompt,next_img).to(agent.vla.device, dtype=torch.bfloat16)
vla.eval()
vla.set_adapter("target")
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    output_dict = agent.vla.generate(
                    **inputs,
                    max_new_tokens = 28,
                    do_sample=False,
                    temperature=1,
                    return_dict_in_generate=True,
                    output_scores = True,
                )
q_pi_targ = output_dict.scores[-1][0,32016]

In [3]:
import torch
dt = torch.load("datasets/pick_described_object_replay1/data.pt")
dt.keys()

dict_keys(['images', 'instructions', 'grippers', 'items', 'objects', 'targets', 'stages', 'actions', 'rewards', 'dones', 'next_images', 'next_grippers'])

In [9]:
len(dt['items'])

1750

In [1]:
q.requires_grad

NameError: name 'q' is not defined

In [31]:
with torch.autocast("cuda", dtype=torch.bfloat16):
    output_dict = agent.vla.generate(
                    **inputs,
                    max_new_tokens = 28,
                    do_sample=False,
                    temperature=1,
                    return_dict_in_generate=True,
                    output_scores = True,
                )

In [32]:
gamma = 0.99
d = data['dones'][0]
r = data['rewards'][0]
backup = r + gamma * (1 - d) * q_pi_targ

In [33]:
loss_q = ((q - backup)**2).mean()

In [34]:
loss_q

tensor(0.0337, device='cuda:0')

In [35]:
# Prepare for interaction with environment
total_steps = steps_per_epoch * epochs
start_time = time.time()
task.sample_variation()
o, ep_ret, ep_len= task.reset(), 0, 0
item = task._variation_number

NameError: name 'task' is not defined

In [None]:
# Main loop: collect experience in env and update/log each epoch
for t in range(total_steps):
    
    # Until start_steps have elapsed, randomly sample actions
    # from a uniform distribution for better exploration. Afterwards, 
    # use the learned policy (with some noise, via act_noise). 
    if t > start_steps:
        a, q = agent.act(o, temperature=1, deterministic=False)
    else:
        a, q = agent.act(o, temperature=10, deterministic=False)

    # Step the env
    o2, r, d = task.step(a)
    ep_ret += r
    ep_len += 1

    # Ignore the "done" signal if it comes from hitting the time
    # horizon (that is, when it's an artificial terminal signal
    # that isn't based on the agent's state)
    d = False if ep_len==max_ep_len else d

    # Store experience to replay buffer
    replay_buffer.store(o, a, q, r, o2, d)

    # Super critical, easy to overlook step: make sure to update 
    # most recent observation!
    o = o2

    # End of trajectory handling
    if d or (ep_len == max_ep_len):
        logger.store(EpRet=ep_ret, EpLen=ep_len)
        o, ep_ret, ep_len = task.reset(), 0, 0

    # Update handling
    if t >= update_after and t % update_every == 0:
        for _ in range(update_every):
            batch = replay_buffer.sample_batch(batch_size)
            update(data=batch)

    



In [None]:
agent.vla.set_adapter("actor_critic")
agent.vla.active_adapter


'actor_critic'

In [None]:
import gc
gc.collect()

306

In [None]:
data = torch.load(data_path)

In [None]:
data.keys()

dict_keys(['images', 'instructions', 'grippers', 'items', 'objects', 'targets', 'stages', 'actions', 'rewards', 'dones', 'next_images', 'next_grippers'])

In [None]:
torch.cuda.empty_cache()

In [22]:
task.sample_variation()
descriptions, obs = task.reset()
agent.vla.set_adapter("online")
# agent.vla.active_adapter
torch.cuda.empty_cache()
# Image.fromarray(obs.front_rgb)
while True:
    try:
        action = agent.act(obs, descriptions[0], temperature=1, deterministic=False)
        print(action)
        obs, reward, terminate = task.step(action)
        print(reward, terminate)
    except Exception as e:
        # continue
        print(e)
        pass

<s> In: What should be the next key pose of the gripper to put the sugar packet in the basket? The current gripper pose is<g>ܐ宝親ḳ瀬</g>.
 Out:<cot><item_4>,<o>森ˆ态</o>,<t>交ペམ</t>,<stage_0>,<a>森ˆ್면ḳ彦弘</a>,<q>,
[ 4.05499991e-01 -3.32500000e-01  8.91739893e-01  8.64664612e-01
 -5.02272840e-01 -5.49101968e-04  8.76365217e-03  0.00000000e+00]
1.4000289033513216 False
<s> In: What should be the next key pose of the gripper to put the sugar packet in the basket? The current gripper pose is<g>森್̂ữḳ梅</g>.
 Out:<cot><item_4>,<o>森̂态</o>,<t>交ペམ</t>,<stage_1>,<a>交ペ电면ḳ彦弘</a>,<q>,
[-0.1965     -0.3255      1.19861997  0.96240976 -0.27145967  0.00164758
  0.00862488  0.        ]
-3.668792332815686 False
<s> In: What should be the next key pose of the gripper to put the sugar packet in the basket? The current gripper pose is<g>交ペ电ữḳ深</g>.
 Out:<cot><item_4>,<o>ホペ态</o>,<t>交ペམ</t>,<stage_0>,<a>ホペ್면ḳ彦弘</a>,<q>,
[ 0.15349999 -0.3255      0.89173989  0.88004402  0.47481094  0.00718534
  0.00504718  0.       

KeyboardInterrupt: 

In [17]:
agent.act(obs, descriptions[0], temperature=10, deterministic=False)

<s> In: What should be the next key pose of the gripper to put the soup container in the basket? The current gripper pose is<g>记－节ữ菜梅</g>.
 Out:<cot><item_1>,<o>意黒関</o>,<t>交ペམ</t>,<stage_0>,<a>意黒共면ḳ彦弘</a>,<q>,


array([ 0.0835    , -0.2625    ,  1.12737995,  0.95266526, -0.30145016,
       -0.00809692,  0.03861537,  0.        ])