In [1]:
"""
finetune.py

Simple script for parameter-efficient fine-tuning of OpenVLA models loaded through the HuggingFace AutoClasses, using
HuggingFace PEFT library for low-rank adaptation (LoRA).

Notes & Benchmarks:
    - Requires PEFT (`pip install peft==0.11.1`)
    - LoRA fine-tuning (see parameters below -- no quantization, LoRA rank = 32, target_modules = all-linear):
        + One 48 GB GPU can fit a Batch Size of 12
        + One 80 GB GPU can fit a Batch Size of 24

Run with:
    - [Single Node Multi-GPU (= $K) ]: torchrun --standalone --nnodes 1 --nproc-per-node $K vla-scripts/finetune.py
    - [Override Config Values]: torchrun --standalone --nnodes 1 --nproc-per-node $K vla-scripts/finetune.py \
                                    --data_root_dir <PATH/TO/RLDS/DATASETS/DIRECTORY> \
                                    --dataset_name <DATASET_NAME> \
                                    --run_root_dir <PATH/TO/LOGS/DIR> \
                                    ...
"""

import os
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import draccus
import torch
import torch.distributed as dist
import tqdm
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers import AutoConfig, AutoImageProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast

import wandb
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics

from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor

# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"



  from .autonotebook import tqdm as notebook_tqdm
2025-05-20 15:18:50.320479: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-20 15:18:50.320573: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-20 15:18:50.321782: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-20 15:18:50.328471: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:



@dataclass
class FinetuneConfig:
    # fmt: off
    vla_path: str = "/home/chuangzhi/zhq/yjc/runs/3-2+b2+lr-4e-06+lora-r32-steps10000"                            # Path to OpenVLA model (on HuggingFace Hub)

    # Directory Paths
    data_root_dir: Path = Path("/home/chuangzhi/zhq/yjc/mydata_tensorflow_datasets")        # Path to Open-X dataset directory
    dataset_name: str = "example_dataset"           # Name of fine-tuning dataset (e.g., `droid_wipe`)
    # data_root_dir = '/home/chuangzhi/zhq/yjc/modified_libero_rlds'
    # dataset_name = 'libero_spatial_no_noops'
    run_root_dir: Path = Path("runs")                               # Path to directory to store logs & checkpoints
    adapter_tmp_dir: Path = Path("adapter-tmp")                     # Temporary directory for LoRA weights before fusing

    # Fine-tuning Parameters
    batch_size: int = 1                                            # Fine-tuning batch size
    max_steps: int = 200_000                                        # Max number of fine-tuning steps
    save_steps: int = 5000                                          # Interval for checkpoint saving
    learning_rate: float = 5e-4                                     # Fine-tuning learning rate
    grad_accumulation_steps: int = 2                                # Gradient accumulation steps
    image_aug: bool = True                                          # Whether to train with image augmentations
    shuffle_buffer_size: int = 100_000                              # Dataloader shuffle buffer size (can reduce if OOM)
    save_latest_checkpoint_only: bool = True                        # Whether to save only one checkpoint per run and
                                                                    #   continually overwrite the latest checkpoint
                                                                    #   (If False, saves all checkpoints)

    # LoRA Arguments
    use_lora: bool = True                                           # Whether to use LoRA fine-tuning
    lora_rank: int = 32                                             # Rank of LoRA weight matrix
    lora_dropout: float = 0.0                                       # Dropout applied to LoRA weights
    use_quantization: bool = False                                  # Whether to 4-bit quantize VLA for LoRA fine-tuning
                                                                    #   => CAUTION: Reduces memory but hurts performance

    # Tracking Parameters
    wandb_project: str = "123"                                  # Name of W&B project to log to (use default!)
    wandb_entity: str = "1275259847-tianjin-university"                          # Name of entity to log under
    run_id_note: Optional[str] = None                               # Extra note for logging, Weights & Biases

    # fmt: on


In [3]:
cfg = FinetuneConfig()
cfg

FinetuneConfig(vla_path='/home/chuangzhi/zhq/yjc/runs/3-2+b2+lr-4e-06+lora-r32-steps10000', data_root_dir=PosixPath('/home/chuangzhi/zhq/yjc/mydata_tensorflow_datasets'), dataset_name='example_dataset', run_root_dir=PosixPath('runs'), adapter_tmp_dir=PosixPath('adapter-tmp'), batch_size=1, max_steps=200000, save_steps=5000, learning_rate=0.0005, grad_accumulation_steps=2, image_aug=True, shuffle_buffer_size=100000, save_latest_checkpoint_only=True, use_lora=True, lora_rank=32, lora_dropout=0.0, use_quantization=False, wandb_project='123', wandb_entity='1275259847-tianjin-university', run_id_note=None)

In [4]:

distributed_state = PartialState()
torch.cuda.set_device(device_id := distributed_state.local_process_index)
torch.cuda.empty_cache()

# Quantization Config =>> only if LoRA fine-tuning
quantization_config = None

# Register OpenVLA model to HF Auto Classes (not needed if the model is on HF Hub)
AutoConfig.register("openvla", OpenVLAConfig)
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)

# Load OpenVLA Processor and Model using HF AutoClasses
processor = AutoProcessor.from_pretrained(cfg.vla_path, trust_remote_code=True)


# Create Action Tokenizer
action_tokenizer = ActionTokenizer(processor.tokenizer)

batch_transform = RLDSBatchTransform(
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=PurePromptBuilder if "v01" not in cfg.vla_path else VicunaV15ChatPromptBuilder,
)
vla = AutoModelForVision2Seq.from_pretrained(
    cfg.vla_path,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

vla = vla.to(device_id)
vla_dataset = RLDSDataset(
    cfg.data_root_dir,
    cfg.dataset_name,
    batch_transform,
    resize_resolution=tuple(vla.config.image_sizes), #tuple(vla.module.config.image_sizes)
    shuffle_buffer_size=cfg.shuffle_buffer_size,
    image_aug=cfg.image_aug,
)







Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  4.90it/s]
2025-05-20 12:35:00.480938: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2025-05-20 12:35:00.920937: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization



######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################



2025-05-20 12:35:01.321489: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [5]:

# Create Collator and DataLoader
collator = PaddedCollatorForActionPrediction(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)
dataloader = DataLoader(
    vla_dataset,
    batch_size=cfg.batch_size,
    sampler=None,
    collate_fn=collator,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)







In [6]:
first_batch = next(iter(dataloader))

W0000 00:00:1747715707.168457 2529976 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 224 } dim { size: 224 } dim { size: -7 } } } inputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -2 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } } device { type: "CPU" vendor: "AuthenticAMD" model: "241" frequency: 2800 num_cores: 128 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 524288 l3_cache_size: 268435456 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: -8 } dim { size: -9 } dim { size: -7 } } }
W0000 00:00:1747715707.1

In [7]:
first_batch.keys()

dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels', 'dataset_names'])

In [8]:
first_batch['input_ids']

tensor([[    1,   512, 29901,  1724,  3158,   881,   278, 19964,  2125,   304,
          4337,   278, 13748,   292, 12917,   304,   278, 25972, 29889, 29973,
            13,  3744, 29901, 31988, 31842, 31872, 31903, 31853, 31830, 31872,
             2]])

In [10]:
first_batch['input_ids']

tensor([[    1,   512, 29901,  1724,  3158,   881,   278, 19964,  2125,   304,
          4337,   278, 13748,   292, 12917,   304,   278, 25972, 29889, 29973,
            13,  3744, 29901, 29871, 31872, 31899, 31872, 31843, 31837, 31824,
         31744,     2]])

In [10]:
first_batch['input_ids'][0][:-8]

tensor([    1,   512, 29901,  1724,  3158,   881,   278, 19964,  2125,   304,
         4337,   278, 13748,   292, 12917,   304,   278, 25972, 29889, 29973,
           13,  3744, 29901])

In [30]:
processor.decode(first_batch['input_ids'][0][:-8].cpu().numpy())

'<s> In: What action should the robot take to move the drinking glass to the basket.?\nOut: '

In [10]:
first_batch['pixel_values'].shape

torch.Size([1, 6, 224, 224])

In [11]:
first_batch['attention_mask']

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True]])

In [12]:
first_batch["labels"]

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100, 31872, 31922, 31872, 31801, 31859, 31851,
         31872,     2]])

In [13]:
processor

PrismaticProcessor:
- image_processor: PrismaticImageProcessor {
  "auto_map": {
    "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
    "AutoProcessor": "processing_prismatic.PrismaticProcessor"
  },
  "image_processor_type": "PrismaticImageProcessor",
  "image_resize_strategy": "resize-naive",
  "input_sizes": [
    [
      3,
      224,
      224
    ],
    [
      3,
      224,
      224
    ]
  ],
  "interpolations": [
    "bicubic",
    "bicubic"
  ],
  "means": [
    [
      0.485,
      0.456,
      0.406
    ],
    [
      0.5,
      0.5,
      0.5
    ]
  ],
  "processor_class": "PrismaticProcessor",
  "stds": [
    [
      0.229,
      0.224,
      0.225
    ],
    [
      0.5,
      0.5,
      0.5
    ]
  ],
  "tvf_crop_params": [
    {
      "output_size": [
        224,
        224
      ]
    },
    {
      "output_size": [
        224,
        224
      ]
    }
  ],
  "tvf_do_letterbox": false,
  "tvf_letterbox_fill": null,
  "tvf_normalize_params"

In [14]:
processor.image_processor

PrismaticImageProcessor {
  "auto_map": {
    "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor",
    "AutoProcessor": "processing_prismatic.PrismaticProcessor"
  },
  "image_processor_type": "PrismaticImageProcessor",
  "image_resize_strategy": "resize-naive",
  "input_sizes": [
    [
      3,
      224,
      224
    ],
    [
      3,
      224,
      224
    ]
  ],
  "interpolations": [
    "bicubic",
    "bicubic"
  ],
  "means": [
    [
      0.485,
      0.456,
      0.406
    ],
    [
      0.5,
      0.5,
      0.5
    ]
  ],
  "processor_class": "PrismaticProcessor",
  "stds": [
    [
      0.229,
      0.224,
      0.225
    ],
    [
      0.5,
      0.5,
      0.5
    ]
  ],
  "tvf_crop_params": [
    {
      "output_size": [
        224,
        224
      ]
    },
    {
      "output_size": [
        224,
        224
      ]
    }
  ],
  "tvf_do_letterbox": false,
  "tvf_letterbox_fill": null,
  "tvf_normalize_params": [
    {
      "inplace": false,
     

In [15]:
action_tokenizer.decode_token_ids_to_actions

<bound method ActionTokenizer.decode_token_ids_to_actions of <prismatic.vla.action_tokenizer.ActionTokenizer object at 0x7f911a5da2c0>>

In [16]:
action_tokenizer.action_token_begin_idx

31743

In [17]:
vla.predict_action

<bound method OpenVLAForActionPrediction.predict_action of OpenVLAForActionPrediction(
  (vision_backbone): PrismaticVisionBackbone(
    (featurizer): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (patch_drop): Identity()
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((

In [18]:
# vla.train()
# output: CausalLMOutputWithPast = vla(
#     input_ids=first_batch["input_ids"].to(device_id),
#     attention_mask=first_batch["attention_mask"].to(device_id),
#     pixel_values=first_batch["pixel_values"].to(torch.bfloat16).to(device_id),
#     labels=first_batch["labels"],
# )
# output

In [19]:
first_batch.keys()

dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels', 'dataset_names'])

In [20]:
prompt_builder_fn=PurePromptBuilder if "v01" not in cfg.vla_path else VicunaV15ChatPromptBuilder



# Construct Chat-based Prompt =>> Input is default query + language instruction, output are the action tokens
prompt_builder = prompt_builder_fn("openvla")

input_ids = processor.tokenizer(prompt_builder.get_prompt(), add_special_tokens=True).input_ids
input_ids



[1]