In [18]:

import torch
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training, PeftConfig
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.optim.lr_scheduler import StepLR
from transformers import get_linear_schedule_with_warmup

from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForPosePrediction, runningLoss
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import RLbenchCotDataset
import numpy as np
import torch.nn.functional as F

import numpy as np
from PIL import Image

In [3]:
base_model_path = "/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge"
adapter_path = "adapter-tmp/1_sample+nll+pick_described_object+e1+b8+lr-0.0001+lora-r16+dropout-0.0+q-4bit"
adapter_path1 = "adapter-tmp/weighted_loss_cot_1+nll+pick_described_object2+e1+b8+lr-0.0001+lora-r16+dropout-0.0+q-4bit"
data_path = "datasets/pick_described_object/train_data.pt"

In [4]:
quantization_config = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4", #llm_int8_skip_modules = ['projector'],
        )
base_model = AutoModelForVision2Seq.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        device_map = "cuda"
    )
base_model = prepare_model_for_kbit_training(base_model)

Loading checkpoint shards: 100%|██████████| 3/3 [00:37<00:00, 12.41s/it]


In [5]:
item_num = 5
stage_num = 2 
add_tokens = ['<g>', '</g>'] + [f'<item_{i}>' for i in np.arange(item_num)] + ['<o>', '</o>', '<t>', '</t>'] + [f'<stage_{i}>' for i in np.arange(stage_num)] + ['<a>', '</a>']

processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
processor.tokenizer.add_tokens(add_tokens)
dataset_statistics: tuple = (np.array([-0.2, -0.35,  0.75199986, -np.pi/2, -np.pi/2, -np.pi/2,  0. ]), np.array([0.5, 0.35, 1.3, np.pi/2, 0, np.pi/2, 1.])) # Min-Max normalization statistics

action_tokenizer = RLbenchPoseTokenizer(processor.tokenizer, dataset_statistics)
trainset = RLbenchCotDataset(
    data_path,
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
)

collator = PaddedCollatorForPosePrediction(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)

train_dataloader = DataLoader(
    trainset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    collate_fn=collator,
    num_workers=1,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)

In [6]:
adapter_config = LoraConfig.from_pretrained(adapter_path)
adapter_config.inference_mode = False

In [7]:
vla = get_peft_model(base_model, adapter_config, adapter_name = "actor_critic")

In [96]:
"<q><PAD></s>"

'<q><PAD></s>'

In [98]:
processor.tokenizer.add_tokens(["<q>"])

1

In [101]:
processor.tokenizer(", <q><PAD></s>")

{'input_ids': [1, 1919, 32016, 32000, 2], 'attention_mask': [1, 1, 1, 1, 1]}

In [79]:
vla.print_trainable_parameters()

trainable params: 48,801,984 || all params: 7,236,926,592 || trainable%: 0.6743


In [65]:
def get_instruct_prompt(action_tokenizer, gripper, instruction: str):
    # gripper = action_tokenizer(gripper)
    prompt = (
        "In: What the next key pose of gripper should the robot take to {instruction}? Current pose is <g>{gripper} </g>, let's think step by step.\n "
        "Out: "
    )
    return prompt

In [56]:
batch = next(iter(train_dataloader))

In [57]:
batch.keys()

dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels', 'grippers', 'items', 'objects', 'targets', 'stages', 'actions'])

In [58]:
# Image.fromarray(batch['pixel_values'].numpy())

In [68]:
torch.tensor(processor.tokenizer(get_instruct_prompt(None,1,1)).input_ids)

tensor([    1,   512, 29901,  1724,   278,  2446,  1820, 18593,   310,   330,
          374,  2496,   881,   278, 19964,  2125,   304,   426,  2611,  4080,
        29913, 29973,  9626, 18593,   338, 32001, 29912, 29887,   374,  2496,
        29913, 32002, 29892,  1235, 29915, 29879,  1348,  4331,   491,  4331,
        29889,    13,  4451, 29901, 29871])

In [69]:
batch['input_ids'][:,]

tensor([[    1,   512, 29901,  1724,   881,   367,   278,  2446,  1820, 18593,
           310,   278,   330,   374,  2496,   304,  4337,   278, 26438,  3800,
           304,   278, 25972, 29973,   450,  1857,   330,   374,  2496, 18593,
           338, 32001, 31417, 31501, 31668, 31747, 31888, 31995, 31998,   829,
         29887, 15513,  2803, 29915, 29879,  1348,  4331,   491,  4331, 29889,
            13,  4451, 29901, 32007, 29892, 32008, 31437, 31588, 31613, 32009,
         29892, 32010, 31398, 31501, 31643, 32011, 29892, 32012, 29892, 32014,
         31417, 31501, 31668, 31747, 31888, 31995, 31999, 32015,     2]])

In [None]:
prompt = get_instruct_prompt(gripper,instr)
image = Image.fromarray(obs.front_rgb)
inputs = processor(prompt, image).to(vla.device, dtype=torch.bfloat16)


In [73]:
vla.set_adapter('actor_critic')

In [85]:
base_model.load_adapter(adapter_path1, adapter_name= "test")

In [88]:
base_model.active_adapters()

['test']

In [92]:
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    output = base_model.generate(
        input_ids = batch['input_ids'][:,:-26].to(vla.device),
        pixel_values=batch["pixel_values"].to(torch.bfloat16).to(vla.device),
        max_new_tokens = 40
    )

In [93]:
output.cpu().numpy()

array([[    1,   512, 29901,  1724,   881,   367,   278,  2446,  1820,
        18593,   310,   278,   330,   374,  2496,   304,  4337,   278,
        26438,  3800,   304,   278, 25972, 29973,   450,  1857,   330,
          374,  2496, 18593,   338, 32001, 31417, 31501, 31668, 31747,
        31888, 31995, 31998,   829, 29887, 15513,  2803, 29915, 29879,
         1348,  4331,   491,  4331, 29889,    13,  4451, 29901, 32001,
        31530, 31501, 31742, 31772, 31876, 31958, 31999, 32002, 29892,
        32007, 29892, 32008, 31474, 31613, 32009, 29892, 32010, 31398,
        31501, 31643, 32011, 29892, 32012, 29892, 32014, 31474, 31613,
        31625, 31748, 31897, 31997, 31998, 32015,     2]])

In [94]:
processor.tokenizer.decode(output[0].cpu().numpy())

"<s> In: What should be the next key pose of the gripper to move the sugar box to the basket? The current gripper pose is<g>마ペή면项黃弘</g>. Let's think step by step.\n Out:<g>马ペĖₗ頭ヨ给</g>,<item_4>,<o>意态</o>,<t>交ペམ</t>,<stage_0>,<a>意态ൾķḳ收弘</a></s>"

In [22]:
# vla = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=True, config=adapter_config, device_map="cuda")