In [1]:
import os
import sys
sys.path.append('.')  # make the repo-local `vla` package importable

from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import draccus
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from vla.base_prompter import PurePromptBuilder
from vla.utils import PaddedCollatorForPosePrediction, runningLoss
from vla.action_tokenizer import RLbenchPoseTokenizer
from vla.dataset import save_dataset_statistics, RLbenchCotDataset


2024-07-29 17:47:32.041431: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-29 17:47:32.041471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-29 17:47:32.042769: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-29 17:47:32.050326: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
@dataclass
class FinetuneConfig:
    """Paths and hyper-parameters for LoRA fine-tuning OpenVLA on an RLBench CoT dataset.

    ``train_data_path`` / ``test_data_path`` default to
    ``./datasets/{dataset_name}/train_data2.pt`` and ``./datasets/{dataset_name}/test_data2.pt``.
    They are resolved in ``__post_init__`` rather than in the class body, so passing a
    different ``dataset_name`` also moves the default data paths. An explicitly given
    path is left untouched.
    """

    # fmt: off
    vla_path: str = "/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge"   # OpenVLA checkpoint the processor is loaded from (also used in the run name)
    vla_path_q: str = "/media/lawrence/Work/checkpoints/openvla-cot-4b"         # OpenVLA checkpoint whose weights are loaded and fine-tuned

    experiment_name: str = "Weighted_loss_2"
    dataset_name: str = "pick_described_object"                                 # Name of fine-tuning dataset (e.g., `droid_wipe`)
    # data_path: Path = Path(f"./datasets/{dataset_name}/data.pt")
    train_data_path: Optional[Path] = None                                      # None => ./datasets/{dataset_name}/train_data2.pt
    test_data_path: Optional[Path] = None                                       # None => ./datasets/{dataset_name}/test_data2.pt
    run_root_dir: Path = Path("./runs")                                         # Path to directory to store logs & checkpoints
    adapter_dir: Path = Path("./adapter-tmp")                                   # Temporary directory for LoRA weights before fusing

    # Fine-tuning Parameters
    seed: int = 42                                                              # Random seed
    episode: int = 1
    batch_size: int = 2                                                         # Fine-tuning batch size (alt: 16)
    test_batch_size: int = 2                                                    # Evaluation batch size
    test_limit_length: int = 30
    save_steps: int = 20                                                        # Interval for checkpoint saving (alt: 5000)
    learning_rate: float = 5e-5                                                 # Fine-tuning learning rate
    weight_decay: float = 0.01                                                  # Fine-tuning weight decay
    grad_accumulation_steps: int = 4                                            # Gradient accumulation steps

    # LoRA Arguments
    use_lora: bool = True                                                       # Whether to use LoRA fine-tuning
    lora_rank: int = 8                                                          # Rank of LoRA weight matrix (alt: 32)
    lora_dropout: float = 0.0                                                   # Dropout applied to LoRA weights
    use_quantization: bool = True                                               # Whether to 4-bit quantize VLA for LoRA fine-tuning
                                                                                #   => CAUTION: Reduces memory but hurts performance
    # Min-Max normalization statistics as (min_array, max_array), 7 entries each.
    # NOTE(review): presumably x, y, z, three rotation angles (the ±pi bounds suggest radians)
    # and gripper open state — confirm against RLbenchPoseTokenizer.
    # default_factory gives every instance its own arrays instead of sharing one mutable default.
    dataset_statistics: tuple = field(default_factory=lambda: (
        np.array([-0.20173775, -0.36754665,  0.81396234, -3.14153998, -0.38798628, -3.14158631,  0. ]),
        np.array([0.41802976, 0.45118147, 1.47966564, 3.14159215, 0.30391057, 3.14157801, 1.]),
    ))

    # Tracking Parameters
    wandb_project: str = "vla-rl"                                               # Name of W&B project to log to (use default!)
    wandb_entity: str = "lawrence-rs-lin-university-of-toronto"                 # Name of entity to log under

    # fmt: on

    def __post_init__(self) -> None:
        # Derive default data paths from this instance's dataset_name. A class-body f-string
        # would be frozen to the default dataset name at class-definition time.
        dataset_dir = Path("./datasets") / self.dataset_name
        if self.train_data_path is None:
            self.train_data_path = dataset_dir / "train_data2.pt"
        if self.test_data_path is None:
            self.test_data_path = dataset_dir / "test_data2.pt"


cfg = FinetuneConfig()

In [3]:
# The tokenizer/image processor comes from `vla_path`, but the weights come from `vla_path_q`.
# Log both so the message names the checkpoint that is actually fine-tuned.
print(
    f"Fine-tuning OpenVLA weights `{cfg.vla_path_q}` "
    f"(processor from `{cfg.vla_path}`) on `{cfg.dataset_name}`"
)

# [Validate] Ensure GPU Available & Set Device / Distributed Context
assert torch.cuda.is_available(), "Fine-tuning assumes at least one GPU is available!"
distributed_state = PartialState()
device_id = distributed_state.local_process_index
torch.cuda.set_device(device_id)
torch.cuda.empty_cache()

# Configure Unique Experiment ID & Log Directory.
# `Path(...).name` still yields the last path component when the path has a trailing slash,
# where `split('/')[-1]` would return an empty string.
# NOTE(review): the ID is derived from `vla_path` although the weights come from `vla_path_q`;
# left as-is so existing run directories keep their names.
exp_id = (
    f"{Path(cfg.vla_path).name}+{cfg.dataset_name}"
    f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
    f"+lr-{cfg.learning_rate}"
)
if cfg.use_lora:
    exp_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
if cfg.use_quantization:
    exp_id += "+q-4bit"

# Start =>> Build Directories (only `run_dir` is created here; `adapter_dir` is just computed)
run_dir, adapter_dir = cfg.run_root_dir / exp_id, cfg.adapter_dir / exp_id
os.makedirs(run_dir, exist_ok=True)

# Quantization Config =>> only if LoRA fine-tuning
quantization_config = None
if cfg.use_quantization:
    assert cfg.use_lora, "Quantized training only supported for LoRA fine-tuning!"
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        # llm_int8_skip_modules=['projector'],
    )

# Load OpenVLA Processor and Model using HF AutoClasses
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    cfg.vla_path_q,
    torch_dtype=torch.bfloat16,
    # attn_implementation="sdpa",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="auto",
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Fine-tuning OpenVLA Model `/media/lawrence/Work/checkpoints/ecot-openvla-7b-bridge` on `pick_described_object`


The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


In [4]:
# Device Placement =>> note that BitsAndBytes automatically handles for quantized training
if cfg.use_quantization:
    vla = prepare_model_for_kbit_training(vla)
else:
    vla = vla.to(device_id)

# [LoRA] Wrap Model w/ PEFT `LoraConfig` =>> by default we set `target_modules=all-linear`
if cfg.use_lora:
    # Seed immediately before wrapping so the random (Gaussian) LoRA initialization is reproducible.
    torch.manual_seed(cfg.seed)
    lora_config = LoraConfig(
        r=cfg.lora_rank,
        lora_alpha=min(cfg.lora_rank, 16),
        lora_dropout=cfg.lora_dropout,
        target_modules="all-linear",
        init_lora_weights="gaussian",
    )
    vla = get_peft_model(vla, lora_config)
    vla.print_trainable_parameters()

# # Wrap VLA in PyTorch DDP Wrapper for Multi-GPU Training (disabled here)
# vla = DDP(vla, device_ids=[device_id], find_unused_parameters=True, gradient_as_bucket_view=True)

# Create Optimizer =>> constant learning rate over the trainable (LoRA) parameters only.
# Pass `cfg.weight_decay` explicitly so the config value actually takes effect.
trainable_params = [param for param in vla.parameters() if param.requires_grad]
optimizer = AdamW(trainable_params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)


trainable params: 24,400,992 || all params: 7,212,525,600 || trainable%: 0.3383


In [5]:
# Create Action Tokenizer from the text tokenizer and the (min, max) statistics in the config.
action_tokenizer = RLbenchPoseTokenizer(processor.tokenizer, cfg.dataset_statistics)

# Train and test splits share the tokenizer, image transform and prompt format;
# only the data file differs. Train is built first, then test.
trainset, testset = (
    RLbenchCotDataset(
        data_path,
        action_tokenizer,
        processor.tokenizer,
        image_transform=processor.image_processor.apply_transform,
        prompt_builder_fn=PurePromptBuilder,
    )
    for data_path in (cfg.train_data_path, cfg.test_data_path)
)

# Create Collator (pads on the right up to the tokenizer's model_max_length) and DataLoaders
collator = PaddedCollatorForPosePrediction(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)

# NOTE(review): the train loader does not shuffle (`sampler=None`, no `shuffle=True`), so batches
# follow file order. If RLbenchCotDataset is a map-style dataset, consider `shuffle=True` with
# `generator=torch.Generator().manual_seed(cfg.seed)`.
train_dataloader = DataLoader(
    trainset,
    batch_size=cfg.batch_size,
    sampler=None,
    collate_fn=collator,
    num_workers=1,
)
# Evaluation batch size comes from the config rather than a hard-coded literal.
test_dataloader = DataLoader(
    testset,
    batch_size=cfg.test_batch_size,
    sampler=None,
    collate_fn=collator,
    num_workers=1,
)



In [6]:
# Sanity check: one training-mode forward pass on the first training batch.
# There is no backward pass or optimizer step here.
vla.train()
vla.gradient_checkpointing_enable()
optimizer.zero_grad()

batch = next(iter(train_dataloader))

with torch.autocast("cuda", dtype=torch.bfloat16):
    output: CausalLMOutputWithPast = vla(
        input_ids=batch["input_ids"].to(device_id),
        attention_mask=batch["attention_mask"].to(device_id),
        pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
        labels=batch["labels"],
        use_cache=False,  # the KV cache is incompatible with gradient checkpointing
    )
    train_nll_loss = output.loss


In [7]:
# Turn the first sample's token ids back into text to eyeball the prompt format.
processor.tokenizer.decode(batch["input_ids"][0, :])

"<s> You are an assistant helping to control a robotic manipulator. The robot performs tasks by following a series of steps to interact with objects in its environment. The environment includes items like soup cans and baskets, and the robot uses a gripper to pick up and move these items.\n\nInstructions format:\n- 'USER': Describes the task to be performed.\n- 'ASSISTANT': Provides a detailed step-by-step plan for the robot to execute the task.\n\nThe 'ASSISTANT' response includes:\n1. A logical step-by-step plan for the task.\n2. The current positions of relevant objects and the gripper.\n3. The current state of the gripper (whether it has grasped the object or not).\n4. The next key pose of the gripper to achieve the task.\n\nExample:\n\nUSER: What action should the robot take to pick up the soup and place it in the basket?\nASSISTANT: Let's think step by step. The plan is to move the gripper to the soup and pick it up, then move over the basket, and then place the soup in the baske

In [8]:
# Drop the vision-patch positions and the final position so that output_logits[:, i] is scored
# against output_gt[:, i] (labels shifted by one).
# NOTE(review): assumes the `num_patches` patch embeddings sit right after the first token — confirm.
num_patches = vla.vision_backbone.featurizer.patch_embed.num_patches
output_logits = output.logits[:, num_patches:-1]
output_gt = batch["labels"][:, 1:].to(device_id)
# Masks selecting the action / gripper / object / target token positions in output_gt.
action_mask, gripper_mask, object_mask, target_mask = action_tokenizer.get_mask(output_gt)

# Reshape with the batch size actually present: the last batch of an epoch can be smaller than
# cfg.batch_size. `.view` also assumes every sample has the same number of masked positions.
batch_size = output_gt.size(0)
bin_columns = slice(action_tokenizer.action_token_begin_idx, processor.tokenizer.vocab_size)
action_logits, gripper_logits, object_logits, target_logits = (
    output_logits[mask][:, bin_columns].view(batch_size, -1, action_tokenizer.n_bins)
    for mask in (action_mask, gripper_mask, object_mask, target_mask)
)

gt_object = batch['target_item_poses'].to(device_id)
gt_target = batch['basket_positions'].to(device_id)
gt_gripper = batch['gripper_poses'].to(device_id)
gt_action = batch['actions'].to(device_id)

loss_dict = action_tokenizer.get_loss(action_logits=action_logits, gripper_logits=gripper_logits, object_logits=object_logits, target_logits=target_logits, gt_action=gt_action, gt_gripper=gt_gripper, gt_object=gt_object, gt_target=gt_target)


In [9]:
# Shifted labels of the first training sample (-100 entries are presumably the loss ignore_index).
output_gt[0, :]

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 

## Test: forward pass and pose losses on one test batch

In [10]:
# Evaluate one test batch: the LM negative log-likelihood plus the pose losses from the tokenizer.
vla.eval()  # evaluation mode: disable training-only behavior such as dropout
batch = next(iter(test_dataloader))

test_nll_loss = []

with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    output: CausalLMOutputWithPast = vla(
        input_ids=batch["input_ids"].to(device_id),
        attention_mask=batch["attention_mask"].to(device_id),
        pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
        labels=batch["labels"],
    )
    test_nll_loss_ = output.loss
    test_nll_loss.append(test_nll_loss_)

    # Drop vision-patch positions and the final position so logits line up with labels[:, 1:].
    num_patches = vla.vision_backbone.featurizer.patch_embed.num_patches
    output_logits = output.logits[:, num_patches:-1]
    output_gt = batch["labels"][:, 1:].to(device_id)
    action_mask, gripper_mask, object_mask, target_mask = action_tokenizer.get_mask(output_gt)

    # The test loader's batch size need not equal cfg.batch_size, and the last batch may be
    # partial, so reshape with the batch size actually present.
    batch_size = output_gt.size(0)
    bin_columns = slice(action_tokenizer.action_token_begin_idx, processor.tokenizer.vocab_size)
    action_logits, gripper_logits, object_logits, target_logits = (
        output_logits[mask][:, bin_columns].view(batch_size, -1, action_tokenizer.n_bins)
        for mask in (action_mask, gripper_mask, object_mask, target_mask)
    )

    gt_object = batch['target_item_poses'].to(device_id)
    gt_target = batch['basket_positions'].to(device_id)
    gt_gripper = batch['gripper_poses'].to(device_id)
    gt_action = batch['actions'].to(device_id)

    loss_dict = action_tokenizer.get_loss(action_logits=action_logits, gripper_logits=gripper_logits, object_logits=object_logits, target_logits=target_logits, gt_action=gt_action, gt_gripper=gt_gripper, gt_object=gt_object, gt_target=gt_target)


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


In [12]:
# Shifted labels of the first test sample (-100 entries are presumably the loss ignore_index).
output_gt[0, :]

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 

In [25]:
# Expected (batch, positions, n_bins), per the `.view(...)` reshape of the action logits.
action_logits.size()

torch.Size([2, 7, 602])

In [None]:
# NOTE(review): disabled full-test-set evaluation loop. Known problems to fix before re-enabling:
#   - The `torch.stack(...).mean()` reductions at the end are indented inside the `for` loop.
#     They rebind the list names to tensors, so the second iteration's `.append` fails.
#     Dedent them so they run once, after the loop.
#   - `.view(4, ...)` hard-codes the batch size. Use the batch size actually present
#     (e.g. `output_gt.size(0)`) so partial or differently-sized batches work.
#   - `best_test_loss` and `test_loss` are initialised but never used.
# vla.eval()
# best_test_loss = float("inf")
# test_loss = 0
# test_nll_loss = []
# test_object_position_loss = []
# test_object_orientation_loss = []
# test_target_position_loss = []
# test_gripper_position_loss = []
# test_gripper_orientation_loss = []
# test_gripper_open_loss = []
# test_action_position_loss = []
# test_action_orientation_loss = []
# test_action_open_loss = []
# for batch in test_dataloader:
#     with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
#         output: CausalLMOutputWithPast = vla(
#             input_ids=batch["input_ids"].to(device_id),
#             attention_mask=batch["attention_mask"].to(device_id),
#             pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device_id),
#             labels=batch["labels"],
#         )
#         test_nll_loss_ = output.loss
#         test_nll_loss.append(test_nll_loss_)

#         output_logits = output.logits[:, vla.vision_backbone.featurizer.patch_embed.num_patches:-1]
#         output_gt = batch["labels"][:, 1:].to(device_id)
#         action_mask, gripper_mask, object_mask, target_mask = action_tokenizer.get_mask(output_gt)

#         action_logits = output_logits[action_mask][:,action_tokenizer.action_token_begin_idx:processor.tokenizer.vocab_size].view(4,-1,action_tokenizer.n_bins)
#         gripper_logits = output_logits[gripper_mask][:,action_tokenizer.action_token_begin_idx:processor.tokenizer.vocab_size].view(4,-1,action_tokenizer.n_bins)
#         object_logits = output_logits[object_mask][:,action_tokenizer.action_token_begin_idx:processor.tokenizer.vocab_size].view(4,-1,action_tokenizer.n_bins)
#         target_logits = output_logits[target_mask][:,action_tokenizer.action_token_begin_idx:processor.tokenizer.vocab_size].view(4,-1,action_tokenizer.n_bins)

#         # object
#         pred_object = action_tokenizer.decode(object_logits, soft = True)
#         gt_object = batch['target_item_poses'].to(device_id)
#         assert pred_object.shape == gt_object.shape, f"Object shape {pred_object.shape} != {gt_object.shape}"
#         object_position_loss = F.mse_loss(pred_object[:,:3], gt_object[:,:3])
#         object_orientation_loss = F.mse_loss(pred_object[:,3:], gt_object[:,3:])
#         test_object_position_loss.append(object_position_loss)
#         test_object_orientation_loss.append(object_orientation_loss)

#         # target
#         pred_target = action_tokenizer.decode(target_logits, soft = True)
#         gt_target = batch['basket_positions'].to(device_id)
#         assert pred_target.shape == gt_target.shape, f"Target shape {pred_target.shape} != {gt_target.shape}"
#         target_position_loss = F.mse_loss(pred_target[:,:3], gt_target[:,:3])
#         test_target_position_loss.append(target_position_loss)

#         # gripper
#         pred_gripper = action_tokenizer.decode(gripper_logits, soft = True)
#         gt_gripper = batch['gripper_poses'].to(device_id)
#         assert pred_gripper.shape == gt_gripper.shape, f"Gripper shape {pred_gripper.shape} != {gt_gripper.shape}"
#         gripper_position_loss = F.mse_loss(pred_gripper[:,:3], gt_gripper[:,:3])
#         gripper_orientation_loss = F.mse_loss(pred_gripper[:,3:6], gt_gripper[:,3:6])
#         gripper_open_loss = F.mse_loss(pred_gripper[:,6], gt_gripper[:,6])
#         test_gripper_position_loss.append(gripper_position_loss)
#         test_gripper_orientation_loss.append(gripper_orientation_loss)
#         test_gripper_open_loss.append(gripper_open_loss)

#         #action
#         pred_action = action_tokenizer.decode(action_logits, soft = True)
#         gt_action = batch['actions'].to(device_id)
#         assert pred_action.shape == gt_action.shape, f"Action shape {pred_action.shape} != {gt_action.shape}"
#         action_position_loss = F.mse_loss(pred_action[:,:3], gt_action[:,:3])
#         action_orientation_loss = F.mse_loss(pred_action[:,3:6], gt_action[:,3:6])
#         action_open_loss = F.mse_loss(pred_action[:,6], gt_action[:,6])
#         test_action_position_loss.append(action_position_loss)
#         test_action_orientation_loss.append(action_orientation_loss)
#         test_action_open_loss.append(action_open_loss)

#         test_nll_loss = torch.stack(test_nll_loss).mean()
#         test_object_position_loss = torch.stack(test_object_position_loss).mean()
#         test_object_orientation_loss = torch.stack(test_object_orientation_loss).mean()
#         test_target_position_loss = torch.stack(test_target_position_loss).mean()
#         test_gripper_position_loss = torch.stack(test_gripper_position_loss).mean()
#         test_gripper_orientation_loss = torch.stack(test_gripper_orientation_loss).mean()
#         test_gripper_open_loss = torch.stack(test_gripper_open_loss).mean()
#         test_action_position_loss = torch.stack(test_action_position_loss).mean()
#         test_action_orientation_loss = torch.stack(test_action_orientation_loss).mean()
#         test_action_open_loss = torch.stack(test_action_open_loss).mean()
        