In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to the project source are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Point the Hugging Face Hub cache at a local disk. This assignment sits before the
# transformers/peft imports below — presumably so those libraries see it when they
# are first imported (verify).
# NOTE(review): absolute, user-specific path; consider reading it from a config value.
os.environ["HF_HUB_CACHE"] = "/mnt/sda/home/zijianwang/HF_CACHE"
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import draccus
import torch
import torch.distributed as dist
import tqdm
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers import AutoConfig, AutoImageProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast

import wandb
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset, EpisodicRLDSDataset
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics

from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor

# Sane defaults: set TOKENIZERS_PARALLELISM to "false" for this process
# (commonly done to silence fork-related warnings from the `tokenizers` library
# when it is used together with data-loading workers — confirm this is still needed).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm
2025-06-26 19:41:01.890857: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-26 19:41:01.891072: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-26 19:41:02.088490: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-26 19:41:02.611248: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler fl

In [3]:
# Make the Hugging Face Auto* factories aware of the local OpenVLA classes.
# The config is registered first under the "openvla" model type; the remaining
# factories are then keyed on that config class, in the same order as before.
AutoConfig.register("openvla", OpenVLAConfig)
for auto_factory, openvla_cls in (
    (AutoImageProcessor, PrismaticImageProcessor),
    (AutoProcessor, PrismaticProcessor),
    (AutoModelForVision2Seq, OpenVLAForActionPrediction),
):
    auto_factory.register(OpenVLAConfig, openvla_cls)

In [4]:
# Load the processor for the OpenVLA-7B checkpoint; later cells read its
# `.tokenizer` and `.image_processor` attributes.
processor = AutoProcessor.from_pretrained(
    "openvla/openvla-7b",
    trust_remote_code=True,
)

In [6]:
# Load the OpenVLA-7B checkpoint in bfloat16, then move it onto the first GPU.
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
)
vla = vla.to("cuda:0")

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  3.83it/s]


In [5]:
# Wrap the processor's text tokenizer in an ActionTokenizer and report the
# vocabulary size it exposes (the printed label is Chinese for "vocabulary size").
action_tokenizer = ActionTokenizer(processor.tokenizer)
vocab_size = action_tokenizer.vocab_size
print("词表大小:", vocab_size)

词表大小: 256


In [6]:
# Load only the checkpoint's configuration; later cells read `image_sizes` from it
# to choose the dataset resize resolution.
vla_model_config = OpenVLAConfig.from_pretrained("openvla/openvla-7b")

In [7]:
# Show the image size(s) recorded in the model config.
print(vla_model_config.image_sizes)

[224, 224]


In [14]:
# Choose the prompt builder from the checkpoint id: ids containing "v01" get the
# Vicuna v1.5 chat builder, everything else the plain builder.
prompt_builder_cls = VicunaV15ChatPromptBuilder if "v01" in "openvla/openvla-7b" else PurePromptBuilder

# Per-sample transform shared by both datasets below.
batch_transform = RLDSBatchTransform(
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    prompt_builder_fn=prompt_builder_cls,
)

# Both dataset flavours are built from the same data root, dataset name, transform
# and options, so the arguments are spelled out once and reused.
# NOTE(review): the data root is an absolute, user-specific path.
rlds_args = (
    "/mnt/sda/home/zijianwang/openvla/modified_libero_rlds",
    "libero_goal_no_noops",
    batch_transform,
)
rlds_kwargs = dict(
    resize_resolution=tuple(vla_model_config.image_sizes),
    shuffle_buffer_size=100_000,
    image_aug=True,
)

vla_dataset = RLDSDataset(*rlds_args, **rlds_kwargs)
episodic_vla_dataset = EpisodicRLDSDataset(*rlds_args, **rlds_kwargs)

2025-06-26 20:05:33.916763: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2025-06-26 20:05:34.091089: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization



######################################################################################
# Loading the following 1 datasets (incl. sampling weight):                         #
######################################################################################



2025-06-26 20:05:34.315863: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2025-06-26 20:05:34.921486: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


2025-06-26 20:05:35.101424: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization


In [18]:
# Print which statistics each dataset variant exposes for the LIBERO goal split.
for rlds_ds in (vla_dataset, episodic_vla_dataset):
    print(rlds_ds.dataset_statistics['libero_goal_no_noops'].keys())

dict_keys(['action', 'proprio', 'num_transitions', 'num_trajectories'])
dict_keys(['action', 'proprio', 'num_transitions', 'num_trajectories'])


In [11]:
# Collator that pads on the right, using the tokenizer's maximum length and pad token id.
collator = PaddedCollatorForActionPrediction(
    processor.tokenizer.model_max_length,
    processor.tokenizer.pad_token_id,
    padding_side="right",
)

# Batches of 100 samples drawn from the (non-episodic) RLDS dataset.
dataloader = DataLoader(
    vla_dataset,
    batch_size=100,
    collate_fn=collator,
    sampler=None,
    num_workers=0,  # Important =>> Set to 0 if using RLDS; TFDS rolls its own parallelism!
)
print(len(dataloader))

521


In [12]:
# Pull batches through the data pipeline, showing progress.
# The progress bar advertises `max_steps` iterations, so the loop must stop there:
# without the explicit break nothing bounds the iteration (the previous run of this
# cell only ended via KeyboardInterrupt), which breaks "Restart & Run All".
max_steps = 20000
with tqdm.tqdm(total=max_steps, leave=False) as progress:
    for batch_idx, batch in enumerate(dataloader):
        progress.update()
        # batch_idx is 0-based, so batch_idx + 1 batches have been consumed so far.
        if batch_idx + 1 >= max_steps:
            break

  0%|          | 0/20000 [00:00<?, ?it/s]W0000 00:00:1750820233.214303 2045336 op_level_cost_estimator.cc:699] Error in PredictCost() for the op: op: "CropAndResize" attr { key: "T" value { type: DT_FLOAT } } attr { key: "extrapolation_value" value { f: 0 } } attr { key: "method" value { s: "bilinear" } } inputs { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 224 } dim { size: 224 } dim { size: -7 } } } inputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: 4 } } } inputs { dtype: DT_INT32 shape { dim { size: -2 } } } inputs { dtype: DT_INT32 shape { dim { size: 2 } } } device { type: "CPU" vendor: "GenuineIntel" model: "101" frequency: 2500 num_cores: 80 environment { key: "cpu_instruction_set" value: "AVX SSE, SSE2, SSE3, SSSE3, SSE4.1, SSE4.2" } environment { key: "eigen" value: "3.4.90" } l1_cache_size: 32768 l2_cache_size: 1048576 l3_cache_size: 28835840 memory_size: 268435456 } outputs { dtype: DT_FLOAT shape { dim { size: -2 } dim { size: -8 } dim { size: -9 } dim {

KeyboardInterrupt: 

In [26]:
# NOTE(review): `data` is not assigned by any cell above; this cell relies on it
# already existing in the kernel. Define it explicitly before running from a fresh kernel.
batch = data
device_id = vla.device

# Device-placed model inputs, built once and shared by the forward pass and
# action prediction. Labels are passed through without an explicit device move.
inputs = {
    "input_ids": batch["input_ids"].to(device_id),
    # "attention_mask": batch["attention_mask"].to(device_id),
    "pixel_values": batch["pixel_values"].to(torch.bfloat16).to(device_id),
    "labels": batch["labels"],
}

# Plain forward pass with labels supplied.
output: CausalLMOutputWithPast = vla(**inputs)

# Greedy (do_sample=False) action prediction on the same inputs.
# NOTE(review): unnorm_key is "bridge_orig" while the data comes from
# "libero_goal_no_noops" — confirm these are the intended un-normalization statistics.
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# action_ids = vla.predict_action_ids(**inputs, unnorm_key="bridge_orig", do_sample=False)

In [None]:
print(data["input_ids"][0].shape)
print(processor.tokenizer.decode(data["input_ids"][0]))

In [None]:
from PIL import Image
import numpy as np
import torch
import matplotlib.pyplot as plt

# Preview channels 0-2 of the first sample's pixel_values as an RGB image.
img = data["pixel_values"][0, 0:3]  # first three channels, channel-first layout
print(img.shape)

# Channel-first -> channel-last, scale by 255, clip to [0, 255] and cast to uint8.
# NOTE(review): if pixel_values are mean/std-normalized by the image processor,
# this preview will not reproduce the original colours — confirm.
img = img.permute(1, 2, 0).mul(255).clamp(0, 255).to(torch.uint8)

# Hand the (already channel-last) tensor to matplotlib as a numpy array.
img_np = img.numpy()

# Draw on a fresh 8x8 inch figure with the axes hidden.
plt.figure(figsize=(8, 8))
plt.imshow(img_np)
plt.axis('off')
plt.show()


In [None]:
# Preview channels 3-5 of the first sample's pixel_values as an RGB image.
img = data["pixel_values"][0, 3:6]  # second group of three channels, channel-first layout
print(img.shape)

# Channel-first -> channel-last, scale by 255, clip to [0, 255] and cast to uint8.
# NOTE(review): if pixel_values are mean/std-normalized by the image processor,
# this preview will not reproduce the original colours — confirm.
img = img.permute(1, 2, 0).mul(255).clamp(0, 255).to(torch.uint8)

# Hand the (already channel-last) tensor to matplotlib as a numpy array.
img_np = img.numpy()

# Draw on a fresh 8x8 inch figure with the axes hidden.
plt.figure(figsize=(8, 8))
plt.imshow(img_np)
plt.axis('off')
plt.show()