In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["HF_HUB_CACHE"] = "/mnt/sda/home/zijianwang/HF_CACHE"
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import draccus
import torch
import torch.distributed as dist
import tqdm
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers import AutoConfig, AutoImageProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast

import wandb
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset, EpisodicRLDSDataset
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics

from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor

# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm
2025-07-13 20:19:38.567015: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-13 20:19:38.567060: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-13 20:19:38.568839: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-13 20:19:38.577800: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler fl

In [3]:
# Register the OpenVLA classes with the Hugging Face `Auto*` factories.
# The config is registered under the "openvla" key; the image processor,
# processor and model classes are each registered against `OpenVLAConfig`.
AutoConfig.register("openvla", OpenVLAConfig)
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)

In [4]:
# Load the model configuration and the processor for "openvla/openvla-7b",
# then wrap the processor's tokenizer in an `ActionTokenizer`.
# NOTE(review): the model id is repeated as a string literal here and in later
# cells; keep the copies in sync if it changes.
vla_model_config = OpenVLAConfig.from_pretrained("openvla/openvla-7b")
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
action_tokenizer = ActionTokenizer(processor.tokenizer)

# Add the action-unit separator token

In [5]:
# Add the literal "<A>" to the tokenizer vocabulary and print the tokenizer
# length before and after, followed by the id assigned to the new token.
# NOTE(review): this mutates `processor.tokenizer` in place, so every later
# cell sees the enlarged vocabulary.
print("len(tokenizer):", len(processor.tokenizer))
DEFAULT_ACT_TOKEN = "<A>"
# Return value of `add_tokens` is kept but not used anywhere in this cell.
num_added_toks = processor.tokenizer.add_tokens(DEFAULT_ACT_TOKEN)
print("len(tokenizer):", len(processor.tokenizer))
# NOTE(review): the lookup repeats the "<A>" literal instead of using
# DEFAULT_ACT_TOKEN — keep the two in sync.
print("id of <A>:", processor.tokenizer.convert_tokens_to_ids("<A>"))

len(tokenizer): 32001
len(tokenizer): 32002
id of <A>: 32001


In [12]:
# Print the tokenizer's end-of-sequence token, its id, and the full
# special-token map for reference.
print("End of sequence token:", processor.tokenizer.eos_token)
print("End of sequence token id:", processor.tokenizer.eos_token_id)
print("All special tokens:", processor.tokenizer.special_tokens_map)

End of sequence token: </s>
End of sequence token id: 2
All special tokens: {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<PAD>'}


In [6]:
# Load a checkpoint from a local directory through `AutoModelForVision2Seq`
# in bfloat16 and move the resulting model to the device "cuda:3".
# NOTE(review): both the checkpoint path and the device index are hard-coded
# for one specific machine; adjust them before running elsewhere.
vla = AutoModelForVision2Seq.from_pretrained(
    "/mnt/sda/home/zijianwang/openvla/FT_res/openvla-7b+libero_goal_no_noops+b4+lr-0.0005+lora-r24+dropout-0.0--image_aug--2025-07-08_15-39-30", 
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16, 
    low_cpu_mem_usage=True, 
    trust_remote_code=True
).to("cuda:3")

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  4.01it/s]


In [7]:
# Show which keys are present in the loaded model's `norm_stats` mapping.
print(vla.norm_stats.keys())

dict_keys(['austin_buds_dataset_converted_externally_to_rlds', 'austin_sailor_dataset_converted_externally_to_rlds', 'austin_sirius_dataset_converted_externally_to_rlds', 'bc_z', 'berkeley_autolab_ur5', 'berkeley_cable_routing', 'berkeley_fanuc_manipulation', 'bridge_orig', 'cmu_stretch', 'dlr_edan_shared_control_converted_externally_to_rlds', 'dobbe', 'fmb_dataset', 'fractal20220817_data', 'furniture_bench_dataset_converted_externally_to_rlds', 'iamlab_cmu_pickup_insert_converted_externally_to_rlds', 'jaco_play', 'kuka', 'nyu_franka_play_dataset_converted_externally_to_rlds', 'roboturk', 'stanford_hydra_dataset_converted_externally_to_rlds', 'taco_play', 'toto', 'ucsd_kitchen_dataset_converted_externally_to_rlds', 'utaustin_mutex', 'viola'])


In [None]:
# NOTE(review): called with no arguments, so no new vocabulary size is
# requested here. If the intent is to make room for the "<A>" token added to
# the tokenizer above, a target size (e.g. the tokenizer length) presumably has
# to be passed — confirm against the checkpoint's embedding size first.
vla.resize_token_embeddings()

In [8]:
# Build the per-sample transform from the action tokenizer, the text tokenizer
# and the processor's image transform.
# The `prompt_builder_fn` condition compares two string literals; "v01" is not
# a substring of "openvla/openvla-7b", so `PurePromptBuilder` is always chosen.
batch_transform = RLDSBatchTransform(
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=PurePromptBuilder if "v01" not in "openvla/openvla-7b" else VicunaV15ChatPromptBuilder,
)


# NOTE(review): `task_name` is not used anywhere in this cell, and its value
# ("LIBERO-Long") names a different suite than the dataset loaded below
# ("libero_goal_no_noops") — confirm which one is intended.
task_name = "LIBERO-Long"  # one of: LIBERO-Object, LIBERO-Goal, LIBERO-Long, LIBERO-Spatial

# Episode-level dataset read from a local RLDS directory, resized to the
# model config's image sizes, with image augmentation switched on.
# NOTE(review): the data root is an absolute path on one specific machine.
episodic_vla_dataset = EpisodicRLDSDataset(
    "/mnt/sda/home/zijianwang/openvla/modified_libero_rlds",
    "libero_goal_no_noops",
    batch_transform,
    resize_resolution=tuple(vla_model_config.image_sizes),
    shuffle_buffer_size=100_000,
    image_aug=True,
)

2025-07-13 19:48:53.136434: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2025-07-13 19:48:53.682092: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [None]:
# Walk over every episode yielded by the dataset, print each episode's
# 'length' entry, and print the running total once iteration finishes.
sum_length = 0
for episode in episodic_vla_dataset:
    length = episode['length']
    print(length)
    sum_length += length
print(sum_length)

In [9]:
from typing import Any, Dict, Optional, Sequence, List
from torch.nn.utils.rnn import pad_sequence
IGNORE_INDEX = -100  # fill value used when padding `labels`

@dataclass
class DataCollatorForCoASupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Each instance is a dict with the keys 'input_ids', 'labels', 'pixel_values'
    and 'length', and optionally 'dataset_name'. Per the original author's note,
    an episode is a list of such dicts, and `input_ids` / `labels` have the same
    length, namely len(text) + 7 + eos token.

    Attributes are documented inline below.
    """
    # Sequences are cut to at most this many tokens after padding.
    model_max_length: int
    # Fill value for `input_ids`; also used to derive the attention mask.
    pad_token_id: int
    # Only "right" is accepted (checked in `__call__`).
    padding_side: str = "right"
    # Not read by `__call__`; kept for interface compatibility.
    pixel_values_dtype: torch.dtype = torch.float32

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, Any]:
        """Pad, truncate and stack a list of instances into one batch.

        Args:
            instances: non-empty sequence of per-sample dicts (see class docstring).

        Returns:
            A dict with:
              - 'pixel_values': stacked tensor, or a dict of stacked tensors when
                the per-sample value is itself a dict of tensors;
              - 'input_ids': right-padded with `pad_token_id`, then truncated;
              - 'attention_mask': boolean mask, False where `input_ids` equals
                `pad_token_id`;
              - 'labels': right-padded with `IGNORE_INDEX`, then truncated;
              - 'lengths': plain Python list of each instance's 'length' value;
              - 'dataset_names': list of names, present only when the first
                instance carries a 'dataset_name' key.

        Raises:
            AssertionError: if `padding_side` is not "right" or any
                `pixel_values` entry is None.
            ValueError: if the per-sample `pixel_values` is neither a tensor
                nor a dict.
        """
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        pixel_values = [instance["pixel_values"] for instance in instances]
        lengths = [instance["length"] for instance in instances]

        # Presence of 'dataset_name' is decided by the first instance only.
        if "dataset_name" in instances[0]:
            dataset_names = [instance["dataset_name"] for instance in instances]
        else:
            dataset_names = None

        # For now, we only support Tokenizers with `padding_side = "right"` during training
        #   => Handle padding via RNN Utils => `pad_sequence`
        assert self.padding_side == "right", f"Invalid Tokenizer `{self.padding_side = }`"
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Truncate (if necessary)
        input_ids, labels = input_ids[:, : self.model_max_length], labels[:, : self.model_max_length]

        # Get `attention_mask` by checking for `pad_token_id`
        attention_mask = input_ids.ne(self.pad_token_id)

        # [Contract] For VLA Training =>> No "Unimodal" Data!
        assert all([pv is not None for pv in pixel_values]), "Invalid VLA Example with `pixel_values = None`!"

        # Stack all `pixel_values` --> depending on type is torch.Tensor or Dict[str, torch.Tensor]
        first_pixels = pixel_values[0]
        if isinstance(first_pixels, torch.Tensor):
            stacked_pixels = torch.stack(pixel_values)
        elif isinstance(first_pixels, dict):
            stacked_pixels = {k: torch.stack([pv[k] for pv in pixel_values]) for k in first_pixels}
        else:
            # Report the element type: the container itself is always a list.
            raise ValueError(f"Unsupported `pixel_values` type = {type(first_pixels)}")

        output = dict(
            pixel_values=stacked_pixels,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            lengths=lengths,
        )
        if dataset_names is not None:
            output["dataset_names"] = dataset_names
        return output

In [10]:
# Instantiate the collator with the tokenizer's max length and pad token id,
# and wrap the episodic dataset in a DataLoader that emits batches of two.
collator = DataCollatorForCoASupervisedDataset(
    processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
)
dataloader = DataLoader(
    episodic_vla_dataset,
    collate_fn=collator,
    sampler=None,
    batch_size=2,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)
# Report how many batches the dataloader says it will yield.
print(f"Length of dataloader: {len(dataloader)}")

Length of dataloader: 214


In [None]:
# Iterate over the dataloader, advancing a manually managed progress bar by
# one step per batch; the batches themselves are not inspected.
# NOTE(review): `total=20000` is hard-coded, while the dataloader reported a
# length of 214 in the recorded output above — confirm whether the bar should
# use `len(dataloader)` instead.
with tqdm.tqdm(total=20000, leave=False) as progress:
    for batch_idx, batch in enumerate(dataloader):
        progress.update()

In [11]:
# Repeatedly drain the dataloader to exercise the input pipeline.
# NOTE(review): this cell never finishes on its own and must be interrupted
# manually, so it blocks "Restart & Run All".
while True:  # Infinite loop to keep reading data
    # Wrap the dataloader itself rather than `enumerate(...)`: tqdm can then
    # read `len(dataloader)` and display a total and ETA instead of a bare
    # iteration counter.
    for index, batch in enumerate(tqdm.tqdm(dataloader)):
        lengths = batch['lengths']
        print(index)

W0000 00:00:1752400158.194683 2253589 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 224 } dim { size: 224 } dim { size: -7 } } } inputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -2 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } } device { type: "CPU" vendor: "GenuineIntel" model: "101" frequency: 2500 num_cores: 80 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 1048576 l3_cache_size: 28835840 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: -8 } dim { size: -9 } dim { size: -7 } } }
W0000 00:00:1752400158.19

cutting episode at index half 39


1it [00:03,  3.55s/it]

cutting episode at index half 37
0
cutting episode at index half 101


2it [00:04,  1.78s/it]

cutting episode at index half 81
1
cutting episode at index half 35


3it [00:04,  1.30s/it]

cutting episode at index half 128
2
cutting episode at index half 1


4it [00:05,  1.03it/s]

cutting episode at index half 2
3
cutting episode at index half 11


5it [00:05,  1.22it/s]

cutting episode at index half 77
4
cutting episode at index half 32


6it [00:06,  1.40it/s]

cutting episode at index half 58
5
cutting episode at index half 80


7it [00:07,  1.43it/s]

cutting episode at index half 58
6
cutting episode at index half 140


8it [00:07,  1.41it/s]

cutting episode at index half 90
7
cutting episode at index half 23


9it [00:08,  1.57it/s]

cutting episode at index half 31
8
cutting episode at index half 162


10it [00:08,  1.54it/s]

cutting episode at index half 35
9
cutting episode at index half 26


10it [00:09,  1.04it/s]


KeyboardInterrupt: 

In [None]:
# Inspect the `batch` variable left behind by an earlier loop cell: tensor
# shapes first, then the dataset names and per-instance lengths.
# NOTE(review): `.shape` on 'pixel_values' only works when the collator
# produced a single tensor; its dict branch would fail here.
print(batch['pixel_values'].shape)
print(batch['input_ids'].shape)
print(batch['labels'].shape)
# 'dataset_names' is only present when the instances carried 'dataset_name'.
print(batch['dataset_names'])
print(batch['lengths'])

In [None]:
# Show the raw token ids of the first sequence in the batch.
print(batch['input_ids'][0])

# Decode every sequence in the batch back to text with the tokenizer.
print(processor.tokenizer.batch_decode(batch['input_ids']))

In [None]:
# Draw a random integer start index in the inclusive range [0, 110] and print it.
# NOTE(review): the seed line is commented out, so the value differs on every
# run; the upper bound 110 is a hard-coded magic number — confirm its origin.
import random
# random.seed(42)  # Fix random seed
start_index = random.randint(0, 110)
print(start_index)