In [None]:
# Automatically reload edited Python modules (e.g. the local `prismatic` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ["HF_HUB_CACHE"] = "/mnt/sda/home/zijianwang/HF_CACHE"
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import draccus
import torch
import torch.distributed as dist
import tqdm
from accelerate import PartialState
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from transformers import AutoConfig, AutoImageProcessor
from transformers.modeling_outputs import CausalLMOutputWithPast

import wandb
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset, EpisodicRLDSDataset
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics

from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor

# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Make the local OpenVLA config / image-processor / processor / model classes resolvable
# through the Hugging Face Auto* factories under the "openvla" model type.
AutoConfig.register("openvla", OpenVLAConfig)
for _auto_cls, _impl_cls in (
    (AutoImageProcessor, PrismaticImageProcessor),
    (AutoProcessor, PrismaticProcessor),
    (AutoModelForVision2Seq, OpenVLAForActionPrediction),
):
    _auto_cls.register(OpenVLAConfig, _impl_cls)
del _auto_cls, _impl_cls

In [None]:
# Tokenizer (`processor.tokenizer`) and image preprocessor (`processor.image_processor`) for OpenVLA-7B.
processor = AutoProcessor.from_pretrained(
    "openvla/openvla-7b",
    trust_remote_code=True,
)

In [None]:
# Load the OpenVLA-7B weights in bfloat16 and place the model on the first GPU.
# FlashAttention-2 is optional and requires the `flash_attn` package.
vla = (
    AutoModelForVision2Seq
    .from_pretrained(
        "openvla/openvla-7b",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    .to("cuda:0")
)

In [None]:
# Action tokenizer built on top of the LLM tokenizer; it is handed to RLDSBatchTransform below.
action_tokenizer = ActionTokenizer(processor.tokenizer)
vocab_size = action_tokenizer.vocab_size
print("词表大小:", vocab_size)  # label reads "vocabulary size:"

In [None]:
# Load only the OpenVLA configuration; its `image_sizes` is used as the dataset resize resolution below.
vla_model_config = OpenVLAConfig.from_pretrained("openvla/openvla-7b")

In [None]:
# Inspect the configured image size(s); they become `resize_resolution` for the RLDS datasets.
print(vla_model_config.image_sizes)

In [None]:
# Build the RLDS batch transform and two datasets over the same LIBERO-Goal data:
#   * vla_dataset          - RLDSDataset with image augmentation enabled
#   * episodic_vla_dataset - EpisodicRLDSDataset without augmentation
#     (presumably yields whole episodes; `if_random_start=False` presumably starts each at step 0 - confirm)
# The checkpoint path, data root, dataset name and resolution are named once here instead of
# being repeated as literals (the prompt-builder check previously compared two string literals).
VLA_PATH = "openvla/openvla-7b"
DATA_ROOT_DIR = "/mnt/sda/home/zijianwang/openvla/modified_libero_rlds"
DATASET_NAME = "libero_goal_no_noops"
RESIZE_RESOLUTION = tuple(vla_model_config.image_sizes)

batch_transform = RLDSBatchTransform(
    action_tokenizer,
    processor.tokenizer,
    image_transform=processor.image_processor.apply_transform,
    # Checkpoints whose path contains "v01" use the Vicuna v1.5 chat prompt; all others the pure prompt.
    prompt_builder_fn=PurePromptBuilder if "v01" not in VLA_PATH else VicunaV15ChatPromptBuilder,
)

vla_dataset = RLDSDataset(
    DATA_ROOT_DIR,
    DATASET_NAME,
    batch_transform,
    resize_resolution=RESIZE_RESOLUTION,
    shuffle_buffer_size=100_000,
    image_aug=True,
)

episodic_vla_dataset = EpisodicRLDSDataset(
    DATA_ROOT_DIR,
    DATASET_NAME,
    batch_transform,
    resize_resolution=RESIZE_RESOLUTION,
    shuffle_buffer_size=100_000,
    image_aug=False,
    if_random_start=False,
)

In [None]:
import imageio, os
import numpy as np
# Write a sequence of frames to an MP4 (30 fps) for quick visual inspection.
# NOTE(review): this cell reads a `data` dict with a "replay_images" entry, but no earlier cell
# defines `data` (the dataset sample is only created further down), so it fails on a fresh
# kernel - confirm where `data` is supposed to come from.
imgs = data["replay_images"]
mp4_path = "test3.mp4"
# The context manager closes/finalizes the file even if appending a frame raises.
with imageio.get_writer(mp4_path, fps=30) as video_writer:
    for img in imgs:
        video_writer.append_data(img)

In [None]:
# Pull a single sample from the dataset to inspect its structure.
# `next(iter(...))` with an explicit check fails loudly on an empty dataset instead of silently
# leaving `data` bound to whatever an earlier cell assigned.
data = next(iter(vla_dataset), None)
if data is None:
    raise RuntimeError("vla_dataset yielded no samples")

print(data["input_ids"].shape)
print(data["labels"].shape)
print(data["pixel_values"].shape)
print(data["dataset_name"])
print(data.keys())

In [None]:
# NOTE(review): confirm the batch transform in this checkout emits "action" and "img" keys
# (compare with the `data.keys()` output above); otherwise these lookups raise KeyError.
print(data["action"])
print(data["img"])

In [None]:
# Collate samples into right-padded batches using the tokenizer's max length and pad token id.
collator = PaddedCollatorForActionPrediction(
    processor.tokenizer.model_max_length,
    processor.tokenizer.pad_token_id,
    padding_side="right",
)

# RLDS/TFDS manages its own parallelism, so the DataLoader must stay single-process (num_workers=0).
dataloader = DataLoader(
    vla_dataset,
    collate_fn=collator,
    batch_size=100,
    num_workers=0,
)
print(len(dataloader))

In [None]:
# Fetch only the first collated batch for inspection. A progress bar is pointless for a single
# item, and an explicit check fails loudly on an empty loader instead of keeping an older `batch`.
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError("dataloader yielded no batches")

In [None]:
# Show which fields the collator produced for this batch.
print(batch.keys())

In [None]:
# batch = data
# Run the model on the first collated batch: one teacher-forced forward pass (loss/logits end up in
# `output`) and one greedy action prediction. Nothing here calls backward(), so autograd is disabled
# to avoid keeping the 7B model's activations for the whole batch alive in GPU memory.
device_id = vla.device
inputs = {
    "input_ids": batch["input_ids"].to(device_id),
    # Single call: move to the model device and cast to bfloat16 (the model was loaded in bf16).
    "pixel_values": batch["pixel_values"].to(device_id, dtype=torch.bfloat16),
    # Labels are passed as-is (on CPU). NOTE(review): relies on the model moving labels to the
    # logits device when computing the loss - confirm.
    "labels": batch["labels"],
}

with torch.no_grad():
    output: CausalLMOutputWithPast = vla(**inputs)

    # NOTE(review): these `input_ids` come from the training transform, so they presumably already
    # contain the ground-truth action tokens plus right padding; upstream `predict_action` is written
    # for a single prompt (batch size 1); and "bridge_orig" un-normalization statistics are not the
    # LIBERO ones. Confirm all three are intended before trusting `action`.
    action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)

# action_ids = vla.predict_action_ids(**inputs, unnorm_key="bridge_orig", do_sample=False)

In [None]:
# Alias the collated batch as `data` (the image cells below read from it) and decode sample 0's tokens.
data = batch
first_input_ids = data["input_ids"][0]
print(first_input_ids.shape)
print(processor.tokenizer.decode(first_input_ids))

In [None]:
from PIL import Image
import numpy as np
import torch
import matplotlib.pyplot as plt

# Visualize sample 0's first 3-channel view (channels 0:3 of pixel_values).
# pixel_values come from `processor.image_processor.apply_transform`, which normalizes each view with
# its own mean/std. Undo that before converting to 8-bit RGB; scaling normalized values by 255
# directly clips most pixels and distorts the colors.
# NOTE(review): assumes `tvf_normalize_params[0]` holds the mean/std applied to channels 0:3
# (as in the upstream PrismaticImageProcessor) - confirm for this checkout.
norm_params = processor.image_processor.tvf_normalize_params[0]
pixel_mean = torch.tensor(norm_params["mean"]).view(3, 1, 1)
pixel_std = torch.tensor(norm_params["std"]).view(3, 1, 1)

img = data["pixel_values"][0, 0:3].float().cpu()  # Shape: [3, H, W]
print(img.shape)

img = img * pixel_std + pixel_mean  # undo normalization -> roughly [0, 1]
img = img.permute(1, 2, 0)  # CHW -> HWC for matplotlib
img = (img * 255).clamp(0, 255).to(torch.uint8)

# Convert the tensor to a numpy array for plotting.
img_np = img.numpy()

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_np)
ax.set_title("Sample 0 - pixel_values channels 0:3")
ax.axis("off")  # hide the axes
plt.show()

In [None]:
# Visualize sample 0's second 3-channel view (channels 3:6 of pixel_values).
# These channels were normalized by `apply_transform` with the second view's mean/std, so undo that
# before converting to 8-bit RGB; scaling normalized values by 255 directly clips and distorts colors.
# NOTE(review): assumes `tvf_normalize_params[1]` holds the mean/std applied to channels 3:6
# (as in the upstream PrismaticImageProcessor) - confirm for this checkout.
norm_params = processor.image_processor.tvf_normalize_params[1]
pixel_mean = torch.tensor(norm_params["mean"]).view(3, 1, 1)
pixel_std = torch.tensor(norm_params["std"]).view(3, 1, 1)

img = data["pixel_values"][0, 3:6].float().cpu()  # Shape: [3, H, W]
print(img.shape)

img = img * pixel_std + pixel_mean  # undo normalization -> roughly [0, 1]
img = img.permute(1, 2, 0)  # CHW -> HWC for matplotlib
img = (img * 255).clamp(0, 255).to(torch.uint8)

# Convert the tensor to a numpy array for plotting.
img_np = img.numpy()

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_np)
ax.set_title("Sample 0 - pixel_values channels 3:6")
ax.axis("off")  # hide the axes
plt.show()

In [None]:
import re, pickle, os, random
from experiments.robot.libero.libero_utils import (
    get_libero_dummy_action,
    get_libero_env,
    get_libero_image)
import numpy as np

In [None]:
# Export every successful "winner" trajectory with episode number > 0 under `base_path` as an MP4.
# Each trajectory folder holds per-step pickles named `step_<n>.pkl`; each step's frame is taken
# from its "obs" entry and frames are written in numeric step order at 30 fps.
import imageio
from collections.abc import Mapping

resize_size = 224
base_path = "/mnt/sda/home/zijianwang/openvla/vla-scripts/DPO/winner_trajectory/libero_10"
trajectory_pattern = re.compile(r'task_(\d+)_episode_(\d+)_success')
step_pattern = re.compile(r'step_(\d+)\.pkl')

# Find all folders with positive episode numbers (episode 0 is intentionally skipped).
# Sorted so the processing order does not depend on os.listdir's arbitrary ordering.
trajectory_folders = []
for folder_name in sorted(os.listdir(base_path)):
    match = trajectory_pattern.search(folder_name)
    if match and int(match.group(2)) > 0:
        trajectory_folders.append(folder_name)

print(f"Found {len(trajectory_folders)} trajectories with positive episode numbers")

start_idx = 0  # first step included in each video
action_sperate_token_id = 32001  # NOTE(review): unused in this cell; kept in case other cells read it

# Process each trajectory folder
for folder_name in trajectory_folders:
    trajectory_folder_path = os.path.join(base_path, folder_name)
    print(f"Processing: {folder_name}")

    # Keep only step pickles and order them numerically (step_10 after step_9).
    pkl_files = [f for f in os.listdir(trajectory_folder_path) if f.endswith(".pkl")]
    pkl_files_sorted = sorted(
        (f for f in pkl_files if step_pattern.search(f)),
        key=lambda f: int(step_pattern.search(f).group(1)),
    )

    imgs = []
    for pkl_name in pkl_files_sorted[start_idx:]:
        # NOTE: pickle.load can execute arbitrary code - only run this on trusted, locally produced files.
        with open(os.path.join(trajectory_folder_path, pkl_name), "rb") as f:
            data = pickle.load(f)
        state = data["obs"]
        if isinstance(state, np.ndarray):
            img = state  # already an image array
        elif isinstance(state, Mapping):
            img = get_libero_image(state, resize_size)  # observation dict -> resized image
        else:
            # Fail loudly: otherwise the previous step's frame would be re-appended
            # (or a NameError raised on the very first step).
            raise TypeError(
                f"Unsupported obs type {type(state).__name__!r} in "
                f"{os.path.join(trajectory_folder_path, pkl_name)}"
            )
        imgs.append(img)

    if imgs:
        mp4_path = os.path.join(trajectory_folder_path, f"Avideo_{len(imgs)}.mp4")
        # The context manager closes/finalizes the file even if appending a frame raises.
        with imageio.get_writer(mp4_path, fps=30) as video_writer:
            for img in imgs:
                video_writer.append_data(img)
        print(f"Saved video: {mp4_path}")
    else:
        print(f"No images found for {folder_name}")

In [None]:
# Inspect leftovers from the export loop above: the type of the last step's "obs" and the shape of the last frame.
print(type(state))
print(img.shape)