# Implements OpenVLA

In [14]:
# Install minimal dependencies (`torch`, `transformers`, `timm`, `tokenizers`, ...)
# > pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

import torch

# Load Processor & VLA
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b", 
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16, 
    low_cpu_mem_usage=True, 
    trust_remote_code=True
).to("cuda:0")

# Grab image input & format prompt
image: Image.Image = get_from_camera(...)
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"

# Predict Action (7-DoF; un-normalize for BridgeData V2)
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# Execute...
robot.act(action, ...)

A new version of the following files was downloaded from https://huggingface.co/openvla/openvla-7b:
- processing_prismatic.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/openvla/openvla-7b:
- configuration_prismatic.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/openvla/openvla-7b:
- modeling_prismatic.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Downloading shards: 100%|██████████| 3/3 [18:43<00:00, 374.42s/it]


ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.

# Explore Dataset

In [None]:
import tensorflow_datasets as tfds

# 예시: 가장 유명한 데이터셋 중 하나인 'fractal20220825_data' (RT-1 데이터)
ds = tfds.load('fractal20220825_data', split='train', data_dir='YOUR_DATA_PATH')
for episode in ds.take(1):
    for step in episode['steps'].take(1):
        print(step['observation']['image'])  # 이미지 데이터 확인
        print(step['action'])               # 액션 값 확인
        print(step['observation']['natural_language_instruction']) # 언어 명령 확인

# RLDS (Robot Learning Dataset)

In [3]:
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image

# 1. 데이터셋 경로 설정 (tfrecord 파일이 들어있는 '상위 폴더' 경로를 넣으세요)
DATA_PATH = "./data/" 

# 2. RLDS 데이터셋 로드
# 해당 경로에 metadata가 포함되어 있어야 작동합니다.
builder = tfds.builder_from_directory(DATA_PATH)
ds = builder.as_dataset(split='train')

# 3. 데이터 확인하기 (첫 번째 에피소드의 첫 번째 스텝)
for episode in ds.take(1):
    steps = list(episode['steps'])
    first_step = steps[0]
    
    # 이미지 데이터 추출
    image = first_step['observation']['image'].numpy()
    instruction = first_step['observation']['natural_language_instruction'].numpy().decode('utf-8')
    action = first_step['action'].numpy()
    
    print(f"Task: {instruction}")
    print(f"Action (Joint velocities/Pose): {action}")
    
    # 이미지 시각화
    display(Image.fromarray(image))

2026-01-07 10:26:38.421542: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-07 10:26:38.556332: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-01-07 10:26:38.556407: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-01-07 10:26:38.577467: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-01-07 10:26:38.633800: I tensorflow/core/platform/cpu_feature_guar

NotFoundError: {{function_node __wrapped__IteratorGetNext_output_types_8_device_/job:localhost/replica:0/task:0/device:CPU:0}} data/bridge_dataset-train.tfrecord-00000-of-01024; No such file or directory [Op:IteratorGetNext] name: 

# Make Dataloader

For training, the raw data needs to be converted into a format that is compatible with a data loader. 

In [3]:
import pickle

with open('./data/bridge_dataset_scripted_6_18/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/obs_dict.pkl', 'rb') as f:
    obs = pickle.load(f)

FileNotFoundError: [Errno 2] No such file or directory: './data/bridge_dataset_scripted_6_18/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/obs_dict.pkl'

In [2]:
with open('./data/bridge_dataset_scripted_6_18/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/policy_out.pkl', 'rb') as f:
    policy = pickle.load(f)

In [3]:
import os
img_pth = './data/bridge_dataset_scripted_6_18/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/images0'

len(os.listdir(img_pth))

50

# Explore processor

In [None]:
# Install minimal dependencies (`torch`, `transformers`, `timm`, `tokenizers`, ...)
# > pip install -r https://raw.githubusercontent.com/openvla/openvla/main/requirements-min.txt
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

import torch

# Load Processor & VLA
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to("cuda:0")

# 2. 입력 데이터 준비 (이미지 + 명령어)
image = Image.open("data/bridge_dataset_scripted_6_18/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0/images0/im_0.jpg") 
# 로봇 카메라 이미지
prompt = "In order to pick up the can, the robot should" # Bridge 데이터셋 스타일 프롬프트

# 3. 추론 실행
inputs = processor(prompt, image, return_tensors="pt").to("cuda", dtype=torch.bfloat16)
# action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# print(f"Predicted Action: {action}") # [x, y, z, roll, pitch, yaw, gripper]

  from .autonotebook import tqdm as notebook_tqdm
2026-01-09 14:48:59.862285: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-09 14:48:59.887553: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-01-09 14:48:59.887585: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-01-09 14:48:59.888267: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-01-09 14:48:59.8

In [9]:
import os
import pickle

print(f'image length:{len(os.listdir("./data/bridge_dataset_scripted_6_18/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0/images0"))}')

with open('./data/bridge_dataset_scripted_6_18/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0/obs_dict.pkl', 'rb') as f:
    obs = pickle.load(f)

with open('./data/bridge_dataset_scripted_6_18/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0/policy_out.pkl', 'rb') as f:
    policy = pickle.load(f)

image length:30


In [12]:
len(obs['joint_effort']), obs.keys()

(30,
 dict_keys(['joint_effort', 'qpos', 'qvel', 'full_state', 'state', 'desired_state', 'time_stamp', 'eef_transform', 'high_bound', 'low_bound', 'env_done', 't_get_obs']))

In [16]:
len(policy[-1]['actions']), policy[-1]

(7,
 {'actions': array([-0.00148566, -0.00214222, -0.00269785, -0.00634548,  0.01319404,
          0.02281726, -0.0005469 ])})

In [18]:
inputs.keys()

dict_keys(['input_ids', 'attention_mask', 'pixel_values'])

# Build Dataloader

In [8]:
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

import torch

processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
image = Image.open("./data/scripted_raw/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0/images0/im_0.jpg") 
prompt = "In order to pick up the can, the robot should" # Bridge 데이터셋 스타일 프롬프트
inputs = processor(prompt, image, return_tensors="pt").to("cuda", dtype=torch.bfloat16)
print(*inputs)



input_ids attention_mask pixel_values


In [10]:
inputs['input_ids'], inputs['attention_mask'], inputs['pixel_values'].shape, inputs.keys()

(tensor([[    1,   512,  1797,   304,  5839,   701,   278,   508, 29892,   278,
          19964,   881]], device='cuda:0'),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'),
 torch.Size([1, 6, 224, 224]),
 dict_keys(['input_ids', 'attention_mask', 'pixel_values']))

In [7]:
import pandas as pd

# df = pd.read_csv('instruction.csv', encoding='utf-8')
# df = pd.DataFrame({
#     'path' : [],
#     'instruction' : []
# }).to_csv('instruction.csv', encoding='utf-8', index=False)

df

Unnamed: 0,path,instruction


In [24]:
import os
import pandas as pd

root_dir = './data/scripted_raw'

category = []
path = []

actions = 0
for root, dirs, files in os.walk(root_dir):
    if 'obs_dict.pkl' in files:
        if len(dirs) > 1:
            continue
        else:
            if 'images0' not in dirs:
                raise Exception
            actions += 1
            path.append(root)
print(actions)

9701


In [21]:
import re

sample = './data/scripted_raw/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0'

# re.search(r'\d{4}-\d{2}-\d{2}_(.+?)/', sample).group(1)

# sample.split('/')[3]
re.search(r'\d{4}-\d{2}-\d{2}_(.+?)/', sample)

<re.Match object; span=(20, 49), match='2022-12-08_pnp_rigid_objects/'>

In [22]:
import re

data = []
for p in path:
    w = p.split('/')[3]
    if re.search(r'\d{4}-\d{2}-\d{2}_(.+?)/', w):
        w = re.search(r'\d{4}-\d{2}-\d{2}_(.+?)/', w).group(1)
    else:
        w = w
    data.append({
        "path" : p,
        "category" : w,
        "instruction" : "",
        "anno" : False
    })

In [23]:
pd.DataFrame(data).to_csv(
    "instruction.csv", index=False, encoding='utf-8'
)

# Build Dataloader

In [4]:
import pickle

print(f'image length:{len(os.listdir("/workspace/openvla-LoRA/data/scripted_raw/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/images0"))}')

with open('/workspace/openvla-LoRA/data/scripted_raw/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/obs_dict.pkl', 'rb') as f:
    obs = pickle.load(f)

with open('/workspace/openvla-LoRA/data/scripted_raw/2022-12-08_pnp_rigid_objects/2022-12-08_15-22-17/raw/traj_group0/traj0/policy_out.pkl', 'rb') as f:
    policy = pickle.load(f)

image length:50


In [7]:
policy[0], len(policy)

({'actions': array([-0.01376923, -0.03019569, -0.00570136, -0.0037535 ,  0.00319256,
         -0.06449221,  0.99812681])},
 49)

In [2]:
import os
import pandas as pd

from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

import torch

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForVision2Seq
import torch

# 1. 모델 로드
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

# 2. LoRA 설정
# OpenVLA의 언어 모델 파트(Llama)의 특정 레이어를 타겟팅합니다.
config = LoraConfig(
    r=32,                         # Rank
    lora_alpha=64,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Attention 레이어 타겟
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# 3. LoRA 모델로 변환
vla = get_peft_model(vla, config)
vla.print_trainable_parameters()

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.15s/it]


trainable params: 33,554,432 || all params: 7,574,791,616 || trainable%: 0.4430


In [16]:
processor.tokenizer

LlamaTokenizerFast(name_or_path='openvla/openvla-7b', vocab_size=32000, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<PAD>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	32000: AddedToken("<PAD>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [17]:
dir(vla)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_compiled_call_impl',
 '_create_repo',
 '_enable_peft_forward_hooks',
 '_forward_hooks',
 '_forward_hooks_always_called',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_base_model_class',
 '_get_files_timestamps',
 '_get_name',
 '_is_full_backward_hook',
 '_is_prompt_learning',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_loa

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class BridgeV2Dataset(Dataset):
    def __init__(self, data_list, processor, action_tokenizer):
        """
        Args:
            data_list: [{'image_path': '...', 'instruction': '...', 'action': [...]}] 형태의 리스트
            processor: OpenVLA Processor
            action_tokenizer: OpenVLA 내부의 action tokenizer (vla.action_tokenizer)
        """
        self.data_list = data_list
        self.processor = processor
        self.action_tokenizer = action_tokenizer

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        item = self.data_list[idx]
        image = Image.open(item['image_path']).convert("RGB")
        instruction = item['instruction']
        
        # 1. OpenVLA 전용 입력 생성 (Prompt + Image)
        # padding과 truncation을 적용하여 배치 크기를 맞춥니다.
        inputs = self.processor(instruction, image, return_tensors="pt")
        
        # 2. Action 레이블 처리
        # OpenVLA는 액션을 텍스트 토큰 뒤에 붙는 특수 토큰으로 처리합니다.
        # item['action']은 7차원 연속 수치형 리스트라고 가정합니다.
        raw_action = torch.tensor(item['action'], dtype=torch.float32)
        
        # Action을 토큰 ID로 변환 (분산화/양자화 과정)
        # 이 토큰들이 파인튜닝의 target(labels)이 됩니다.
        action_tokens = self.action_tokenizer.encode(raw_action.numpy())
        
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(action_tokens, dtype=torch.long)
        }

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = [item["labels"] for item in batch]

    # 입력 텍스트 패딩
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
    # 레이블 패딩 (보통 -100으로 설정하여 Loss 계산에서 제외)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)

    return {
        "input_ids": input_ids_padded,
        "pixel_values": pixel_values,
        "labels": labels_padded
    }

# 실행 예시
# vla 모델이 이미 로드되어 있다고 가정 (vla.action_tokenizer 사용을 위해)
train_dataset = BridgeV2Dataset(your_data_list, processor, vla.action_tokenizer)
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=4, 
    shuffle=True, 
    collate_fn=collate_fn
)

In [12]:
import os
import pickle
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image

class SingleTrajBridgeDataset(Dataset):
    def __init__(self, traj_dir, instruction, processor, action_tokenizer):
        """
        Args:
            traj_dir: 'traj0' 폴더의 절대 경로 (안에 images0 폴더와 policy_out.pkl 존재)
            instruction: 해당 trajectory의 작업 명령어 (e.g., "pick up the can")
            processor: OpenVLA Processor
            action_tokenizer: vla.action_tokenizer
        """
        self.processor = processor
        self.action_tokenizer = action_tokenizer
        self.instruction = instruction
        
        # 1. Action 데이터 로드 (policy_out.pkl)
        pkl_path = os.path.join(traj_dir, "policy_out.pkl")
        with open(pkl_path, "rb") as f:
            # Bridge 데이터셋은 보통 list of dicts 형태입니다.
            raw_data = pickle.load(f)
        
        # policy_out 내부에 'actions' 키가 있는지 확인하여 데이터 추출
        # 만약 raw_data 자체가 numpy array라면 바로 사용합니다.
        if isinstance(raw_data[0], dict):
            self.actions = [d['actions'] for d in raw_data] 
        else:
            self.actions = raw_data

        # 2. 이미지 경로 설정 (images0)
        self.img_dir = os.path.join(traj_dir, "images0")
        
        # 3. 유효한 샘플 개수 정의 (Action 개수 기준)
        self.num_samples = len(self.actions)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # t 시점의 이미지와 t 시점의 액션을 매칭
        img_path = os.path.join(self.img_dir, f"im_{idx}.jpg")
        image = Image.open(img_path).convert("RGB")
        
        # OpenVLA 입력 처리 (Image + Text)
        inputs = self.processor(self.instruction, image, return_tensors="pt")
        
        # Action 토큰화 (7차원 수치 -> 7개 토큰 ID)
        raw_action = np.array(self.actions[idx], dtype=np.float32)
        action_tokens = self.action_tokenizer.encode(raw_action)
        
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(action_tokens, dtype=torch.long)
        }

In [13]:
# 1. 초기 설정
TRAJ_PATH = "./data/scripted_raw/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0"
INST = "In order to pick up the can, the robot should"

# 2. 데이터셋 및 데이터로더 생성
train_dataset = SingleTrajBridgeDataset(
    traj_dir=TRAJ_PATH,
    instruction=INST,
    processor=processor,
    action_tokenizer=vla.action_tokenizer
)

def collate_fn(batch):
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collate_fn
)

# 3. 동작 확인
batch = next(iter(train_dataloader))
print(f"Input IDs shape: {batch['input_ids'].shape}")     # [BS, Seq_Len]
print(f"Pixel Values shape: {batch['pixel_values'].shape}") # [BS, 3, 224, 224]
print(f"Labels shape: {batch['labels'].shape}")           # [BS, 7] (7 action tokens)

AttributeError: 'OpenVLAForActionPrediction' object has no attribute 'action_tokenizer'

In [18]:
import numpy as np

class BridgeV2Dataset(Dataset):
    def __init__(self, traj_dir, instruction, processor, vla_config):
        self.processor = processor
        self.instruction = instruction
        self.vla_config = vla_config # vla.config 전달
        
        # 데이터 로드
        with open(os.path.join(traj_dir, "policy_out.pkl"), "rb") as f:
            raw_data = pickle.load(f)
        self.actions = [d['actions'] for d in raw_data] if isinstance(raw_data[0], dict) else raw_data
        self.img_dir = os.path.join(traj_dir, "images0")

    def __len__(self):
        return len(self.actions)

    def __getitem__(self, idx):
        # 1. 이미지 로드 및 전처리
        image = Image.open(os.path.join(self.img_dir, f"im_{idx}.jpg")).convert("RGB")
        inputs = self.processor(self.instruction, image, return_tensors="pt")
        
        # 2. 액션 양자화 (Action -> Token IDs)
        # OpenVLA 7B는 보통 31000번 이후의 256개 토큰을 사용합니다.
        raw_action = np.array(self.actions[idx], dtype=np.float32)
        
        # 공식 가이드에 따른 정규화 및 토큰화 로직
        # OpenVLA는 [-1, 1] 범위를 256개 bin으로 나눕니다.
        bin_indices = np.clip((raw_action + 1.0) / 2.0 * 255, 0, 255).astype(np.int32)
        
        # OpenVLA 7B의 액션 토큰 시작 인덱스는 보통 31000입니다.
        # 정확한 값은 vla.config.vocab_size - 256 근처입니다.
        action_token_ids = bin_indices + 31000 
        
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(action_token_ids, dtype=torch.long)
        }

In [19]:
# 1. 초기 설정
TRAJ_PATH = "./data/scripted_raw/sweep_12-03/2022-12-04_14-56-20/raw/traj_group0/traj0"
INST = "In order to pick up the can, the robot should"

# 2. 데이터셋 및 데이터로더 생성
train_dataset = BridgeV2Dataset(
    traj_dir=TRAJ_PATH,
    instruction=INST,
    processor=processor,
    vla_config=vla.config
)

def collate_fn(batch):
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "labels": torch.stack([item["labels"] for item in batch])
    }

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collate_fn
)

# 3. 동작 확인
batch = next(iter(train_dataloader))
print(f"Input IDs shape: {batch['input_ids'].shape}")     # [BS, Seq_Len]
print(f"Pixel Values shape: {batch['pixel_values'].shape}") # [BS, 3, 224, 224]
print(f"Labels shape: {batch['labels'].shape}")           # [BS, 7] (7 action tokens)

Input IDs shape: torch.Size([2, 12])
Pixel Values shape: torch.Size([2, 6, 224, 224])
Labels shape: torch.Size([2, 7])


In [None]:
from torch.optim import AdamW

optimizer = AdamW(vla.parameters(), lr=2e-5)
vla.to('cuda')
vla.train()
for batch in train_dataloader:
    # 데이터를 GPU로 이동
    input_ids = batch["input_ids"].to("cuda")
    pixel_values = batch["pixel_values"].to("cuda", dtype=torch.bfloat16)
    labels = batch["labels"].to("cuda")

    # Forward Pass
    # OpenVLA는 input_ids와 pixel_values를 받아 마지막에 액션 토큰을 예측하도록 설계됨
    outputs = vla(input_ids=input_ids, pixel_values=pixel_values, labels=labels)
    
    loss = outputs.loss
    loss.backward()
    
    optimizer.step()
    optimizer.zero_grad()
    
    print(f"Loss: {loss.item()}")

    break

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacity of 12.00 GiB of which 0 bytes is free. Of the allocated memory 11.25 GiB is allocated by PyTorch, and 60.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)