In [1]:
import os
from dataclasses import dataclass
from pathlib import Path

import draccus
import torch
import torch.distributed as dist
import tqdm
import wandb
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics

# Sane Defaults =>> force the `TOKENIZERS_PARALLELISM` environment variable to "false" for this process
os.environ.update(TOKENIZERS_PARALLELISM="false")


2024-07-08 22:36:48.032423: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-08 22:36:48.032540: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-08 22:36:48.090694: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-08 22:36:48.199996: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-08 22:36:52.604417: I external/local_xla/xla/

In [2]:
@dataclass
class FinetuneConfig:
    """Settings for a single OpenVLA fine-tuning run.

    Every field has a default, so the config can be constructed with no arguments.
    A few defaults carried a second value in their original inline comment
    (e.g. `2#16`); those are recorded next to the field as the presumed
    full-scale setting -- TODO confirm against the upstream script.
    """

    # fmt: off

    # Path to OpenVLA model
    vla_path: str = "/media/lawrence/Work/checkpoints/openvla-7b"

    # === Directory Paths ===
    # Path to Open-X dataset directory
    data_root_dir: Path = Path("./datasets")
    # Name of fine-tuning dataset (e.g., `droid_wipe`)
    dataset_name: str = "imperialcollege_sawyer_wrist_cam"
    # Path to directory to store logs & checkpoints
    run_root_dir: Path = Path("./runs")
    # Temporary directory for LoRA weights before fusing
    adapter_tmp_dir: Path = Path("./adapter-tmp")

    # === Fine-tuning Parameters ===
    # Fine-tuning batch size (presumed full-scale value: 16)
    batch_size: int = 2
    # Max number of fine-tuning steps (presumed full-scale value: 200_000)
    max_steps: int = 200
    # Interval for checkpoint saving (presumed full-scale value: 5000)
    save_steps: int = 5
    # Fine-tuning learning rate
    learning_rate: float = 2e-5
    # Gradient accumulation steps
    grad_accumulation_steps: int = 1
    # Whether to train with image augmentations
    image_aug: bool = True
    # Dataloader shuffle buffer size; can reduce if OOM (presumed full-scale value: 100_000)
    shuffle_buffer_size: int = 100

    # === LoRA Arguments ===
    # Whether to use LoRA fine-tuning
    use_lora: bool = True
    # Rank of LoRA weight matrix
    lora_rank: int = 32
    # Dropout applied to LoRA weights
    lora_dropout: float = 0.0
    # Whether to 4-bit quantize VLA for LoRA fine-tuning
    #   => CAUTION: Reduces memory but hurts performance
    use_quantization: bool = True

    # === Tracking Parameters ===
    # Name of W&B project to log to (use default!)
    wandb_project: str = "openvla"
    # Name of entity to log under
    wandb_entity: str = "lawrence-rs-lin"

    # fmt: on


In [3]:
# Instantiate the config. Binding the class itself (`cfg = FinetuneConfig`) only works by reading
# class-level defaults, and any later `cfg.x = ...` would silently mutate the class for every user.
cfg = FinetuneConfig()
print(f"Fine-tuning OpenVLA Model `{cfg.vla_path}` on `{cfg.dataset_name}`")

# [Validate] Ensure GPU Available & Set Device / Distributed Context
assert torch.cuda.is_available(), "Fine-tuning assumes at least one GPU is available!"
distributed_state = PartialState()

# `device_id` is reused by later cells for tensor / model placement
device_id = distributed_state.local_process_index
torch.cuda.set_device(device_id)
torch.cuda.empty_cache()

Fine-tuning OpenVLA Model `/media/lawrence/Work/checkpoints/openvla-7b` on `imperialcollege_sawyer_wrist_cam`


In [4]:
# Configure Unique Experiment ID & Log Directory
#   =>> `Path(...).name` takes the final path component and, unlike `split('/')[-1]`, does not
#       collapse to an empty string when `vla_path` ends with a trailing slash.
exp_id = (
    f"{Path(cfg.vla_path).name}+{cfg.dataset_name}"
    f"+b{cfg.batch_size * cfg.grad_accumulation_steps}"
    f"+lr-{cfg.learning_rate}"
)
if cfg.use_lora:
    exp_id += f"+lora-r{cfg.lora_rank}+dropout-{cfg.lora_dropout}"
if cfg.use_quantization:
    exp_id += "+q-4bit"

In [5]:
# Start =>> Build Directories
#   =>> Both directories are namespaced by `exp_id`; only the run directory is created here.
run_dir = cfg.run_root_dir / exp_id
adapter_dir = cfg.adapter_tmp_dir / exp_id
run_dir.mkdir(parents=True, exist_ok=True)

In [6]:
# Quantization Config =>> only if LoRA fine-tuning
#   =>> Stays `None` (no quantization) unless `cfg.use_quantization` is set.
quantization_config = None
if cfg.use_quantization:
    assert cfg.use_lora, "Quantized training only supported for LoRA fine-tuning!"
    quantization_config = BitsAndBytesConfig(
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        load_in_4bit=True,
    )

In [7]:
# Load OpenVLA Processor and Model using HF AutoClasses
#   =>> Both are loaded from `cfg.vla_path` with `trust_remote_code=True`; the model is loaded in
#       bfloat16 and receives the (possibly `None`) quantization config built above.
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    cfg.vla_path,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Device Placement =>> note that BitsAndBytes automatically handles for quantized training
#   =>> Non-quantized models are moved to `device_id` explicitly; quantized models are instead
#       passed through PEFT's k-bit training preparation.
if not cfg.use_quantization:
    vla = vla.to(device_id)
else:
    vla = prepare_model_for_kbit_training(vla)

In [9]:
# [LoRA] Wrap Model w/ PEFT `LoraConfig` =>> by default we set `target_modules=all-linear`
if cfg.use_lora:
    # `lora_alpha` follows the rank but is capped at 16
    lora_alpha = cfg.lora_rank if cfg.lora_rank < 16 else 16
    lora_config = LoraConfig(
        target_modules="all-linear",
        init_lora_weights="gaussian",
        r=cfg.lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=cfg.lora_dropout,
    )
    vla = get_peft_model(vla, lora_config)
    vla.print_trainable_parameters()

trainable params: 110,828,288 || all params: 7,652,065,472 || trainable%: 1.4483


In [10]:
# Create Optimizer =>> note that we default to a simple constant learning rate!
#   =>> Only parameters with `requires_grad=True` are handed to AdamW.
trainable_params = list(filter(lambda p: p.requires_grad, vla.parameters()))
optimizer = AdamW(trainable_params, lr=cfg.learning_rate)


In [11]:
# Create Action Tokenizer
#   =>> Built on the processor's text tokenizer; later cells pass it to the RLDS batch transform,
#       read its `action_token_begin_idx` to build the action mask, and use it to decode token ids
#       back into continuous actions.
action_tokenizer = ActionTokenizer(processor.tokenizer)


In [12]:
# Load Fine-tuning Dataset =>> note that we use an RLDS-formatted dataset following Open X-Embodiment by default.
#   =>> If you want to use a non-RLDS dataset (e.g., a standard PyTorch Dataset) see the following commented block.
#   =>> Note that our training code does not loop over epochs because the RLDS loader does this implicitly; if using
#       your own Dataset, make sure to add the appropriate logic to the training loop!
#
# ---
# from prismatic.vla.datasets import DummyDataset
#
# vla_dataset = DummyDataset(
#     action_tokenizer,
#     processor.tokenizer,
#     image_transform=processor.image_processor.apply_transform,
#     prompt_builder_fn=PurePromptBuilder if "v01" not in cfg.vla_path else VicunaV15ChatPromptBuilder,
# )
# ---
# Per-sample transform: tokenizes actions / prompts and applies the processor's image transform.
# The Vicuna chat prompt builder is used only when the checkpoint path contains "v01".
batch_transform = RLDSBatchTransform(
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=VicunaV15ChatPromptBuilder if "v01" in cfg.vla_path else PurePromptBuilder,
)

# NOTE: the resize resolution is hard-coded to 224x224 here; the upstream variant read it from
#       `vla.module.config.image_size`, but the model is not wrapped in DDP in this notebook, so
#       there is no `.module` attribute to go through.
vla_dataset = RLDSDataset(
    cfg.data_root_dir,
    cfg.dataset_name,
    batch_transform,
    resize_resolution=(224, 224),
    image_aug=cfg.image_aug,
    shuffle_buffer_size=cfg.shuffle_buffer_size,
)

2024-07-08 22:37:44.401186: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
2024-07-08 22:37:46.057879: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2024-07-08 22:37:46.533527: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization



######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################



2024-07-08 22:37:47.967546: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [13]:
# [Important] Save Dataset Statistics =>> used to de-normalize actions for inference!
#   =>> Written into `run_dir`, and only by the main process so that multiple processes do not
#       each try to write the same statistics.
if distributed_state.is_main_process:
    save_dataset_statistics(vla_dataset.dataset_statistics, run_dir)


In [14]:
# Create Collator and DataLoader
#   =>> The collator is configured from the tokenizer's max length and pad token id and pads on
#       the right.
collator = PaddedCollatorForActionPrediction(
    processor.tokenizer.model_max_length,
    processor.tokenizer.pad_token_id,
    padding_side="right",
)
dataloader = DataLoader(
    vla_dataset,
    collate_fn=collator,
    batch_size=cfg.batch_size,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
    sampler=None,
)

In [15]:
# Put the model into training mode and clear any existing gradients before the first
# forward / backward pass.
vla.train()
optimizer.zero_grad()

In [16]:
# Pull the first batch from the dataloader; it is treated as step 0 of training.
batch = next(iter(dataloader))
step_idx = 0

W0000 00:00:1720492670.370368   50945 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 224 } dim { size: 224 } dim { size: -7 } } } inputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -2 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } } device { type: "CPU" vendor: "AuthenticAMD" model: "241" frequency: 3592 num_cores: 16 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 524288 l3_cache_size: 16777216 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: -8 } dim { size: -9 } dim { size: -7 } } }
W0000 00:00:1720492670.370

In [30]:
# Inspect the shape of the label tensor in the sampled batch
batch["labels"].shape

torch.Size([2, 27])

In [17]:
# Forward pass under bf16 autocast. Input ids, attention mask and pixel values are moved to
# `device_id`; labels are passed as they come out of the collator.
with torch.autocast("cuda", dtype=torch.bfloat16):
    pixel_values = batch["pixel_values"].to(torch.bfloat16)
    output: CausalLMOutputWithPast = vla(
        labels=batch["labels"],
        pixel_values=pixel_values.to(device_id),
        attention_mask=batch["attention_mask"].to(device_id),
        input_ids=batch["input_ids"].to(device_id),
    )

# The training loss is read off the model output
loss = output.loss

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


In [18]:
# Backward!
#   =>> Scale the loss by the number of accumulation steps so that gradients summed over
#       `cfg.grad_accumulation_steps` micro-batches average rather than add up. With the default
#       of 1 this is identical to calling `loss.backward()` directly.
normalized_loss = loss / cfg.grad_accumulation_steps
normalized_loss.backward()

In [31]:
# Inspect the patch count of the vision featurizer; the same value is used below as the start
# offset when slicing action logits out of `output.logits`.
vla.vision_backbone.featurizer.patch_embed.num_patches

256

In [32]:
# Inspect the labels with the first position dropped (the same slice used below for `action_gt`)
batch["labels"][:, 1:]

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 31958, 31898,
         31819, 31857, 31852, 31876, 31744,     2],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100, 31869, 31895, 31840,
         31855, 31916, 31821, 31872,     2,  -100]])

In [19]:
# Compute Accuracy and L1 Loss for Logging
#   =>> Logits are sliced from the `num_patches` position up to (but excluding) the last position,
#       and compared against the labels shifted left by one. `num_patches` is read directly from
#       `vla` (no `.module` indirection) because the model is not wrapped in DDP in this notebook.
action_logits = output.logits[:, vla.vision_backbone.featurizer.patch_embed.num_patches : -1]
action_preds = torch.argmax(action_logits, dim=-1)
action_gt = batch["labels"][:, 1:].to(action_preds.device)

# Keep only positions whose label id lies above the tokenizer's `action_token_begin_idx`
mask = action_tokenizer.action_token_begin_idx < action_gt


In [34]:
# Inspect the shifted ground-truth labels after they were moved to the predictions' device
action_gt

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 31958, 31898,
         31819, 31857, 31852, 31876, 31744,     2],
        [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100, 31869, 31895, 31840,
         31855, 31916, 31821, 31872,     2,  -100]], device='cuda:0')

In [33]:
# Inspect the boolean mask of positions whose label id exceeds `action_token_begin_idx`
mask

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False,  True,  True,
          True,  True,  True,  True,  True, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False,  True,  True,  True,
          True,  True,  True,  True, False, False]], device='cuda:0')

In [26]:
# Inspect the shape of the sliced action logits
action_logits.shape

torch.Size([2, 26, 32064])

In [20]:
# Compute Accuracy
#   =>> A position counts as correct only if it is inside the action mask and the predicted token
#       equals the label; accuracy is the fraction of masked positions that are correct.
correct_preds = mask & (action_preds == action_gt)
action_accuracy = correct_preds.sum().float().div(mask.sum().float())


In [21]:
# Compute L1 Loss on Predicted (Continuous) Actions
#   =>> Masked token ids are moved to CPU / NumPy, decoded back into actions by the action
#       tokenizer, and re-wrapped as tensors before taking the L1 loss.
pred_token_ids = action_preds[mask].cpu().numpy()
continuous_actions_pred = torch.tensor(action_tokenizer.decode_token_ids_to_actions(pred_token_ids))

gt_token_ids = action_gt[mask].cpu().numpy()
continuous_actions_gt = torch.tensor(action_tokenizer.decode_token_ids_to_actions(gt_token_ids))

action_l1_loss = torch.nn.functional.l1_loss(continuous_actions_pred, continuous_actions_gt)


In [22]:
# Optimizer Step
#   =>> Only step once every `cfg.grad_accumulation_steps` micro-batches. Gradients are cleared
#       right after the step; without this, re-running the forward / backward cells would keep
#       adding new gradients on top of ones that were already applied.
if (step_idx + 1) % cfg.grad_accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad()
