In [1]:
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers import AutoConfig, AutoImageProcessor
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.vla.action_tokenizer import ActionTokenizer
from typing import Type, Any
from prismatic.extern.hf.configuration_prismatic import OpenVLAConfig
from prismatic.extern.hf.modeling_prismatic import OpenVLAForActionPrediction
from prismatic.extern.hf.processing_prismatic import PrismaticImageProcessor, PrismaticProcessor
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
# 注册 OpenVLA 模型到 HF Auto Classes
AutoConfig.register("openvla", OpenVLAConfig)
AutoImageProcessor.register(OpenVLAConfig, PrismaticImageProcessor)
AutoProcessor.register(OpenVLAConfig, PrismaticProcessor)
AutoModelForVision2Seq.register(OpenVLAConfig, OpenVLAForActionPrediction)


  from .autonotebook import tqdm as notebook_tqdm
2025-05-22 22:24:07.322359: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-22 22:24:07.322423: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-22 22:24:07.323649: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-22 22:24:07.329503: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from typing import Dict, List, Optional, Union, Any
import json
import os

class OpenVLAInference:
    def __init__(self, model_path, device="cuda"):
        """
        初始化 OpenVLA 推理器
        
        参数:
            model_path: 模型路径 (可以是 HuggingFace Hub 路径或本地路径)
            device: 推理设备 ('cuda' 或 'cpu')
        """
        self.device = device if torch.cuda.is_available() and device.startswith("cuda") else "cpu"
        dataset_statistics_path = os.path.join(model_path, "dataset_statistics.json")
        if os.path.isfile(dataset_statistics_path):
            with open(dataset_statistics_path, "r") as f:
                norm_stats = json.load(f)
            self.norm_stats = norm_stats
        self.unnorm_key = "example_dataset"
        assert self.unnorm_key in self.norm_stats, f"Action un-norm key {self.unnorm_key} not found in VLA `norm_stats`!"
        # 加载处理器和模型
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        ).to(self.device)
        self.model.norm_stats = self.norm_stats
        # 创建 action tokenizer
        self.action_tokenizer = ActionTokenizer(self.processor.tokenizer)
        
        # 根据模型版本选择 prompt builder
        self.prompt_builder_cls = (
            PurePromptBuilder if "v01" not in model_path 
            else VicunaV15ChatPromptBuilder
        )
        
        # 设置模型为评估模式
        self.model.eval()
    
    def _build_prompt(self, text_instruction: str) -> str:
        """
        构建推理时使用的 prompt (直接实现原RLDSBatchTransform的逻辑)
        
        参数:
            text_instruction: 文本指令 (如 "wipe the table")
            
        返回:
            格式化后的 prompt 文本
        """
        # 初始化 prompt builder
        self.prompt_builder = self.prompt_builder_cls("openvla")
        # print("self.prompt_builder.turn_count",self.prompt_builder.turn_count)
        conversation = [
            {"from": "human", "value":f"What action should the robot take to {text_instruction.lower()}?"},
            # {"from": "gpt", "value":""},
        ]
        # 添加对话轮次 (只包含人类指令部分)
        for turn in conversation:
            self.prompt_builder.add_turn(turn["from"], turn["value"])
        # print("self.prompt_builder.turn_count",self.prompt_builder.turn_count)
        self.turn_count = self.prompt_builder.turn_count
        return self.prompt_builder.get_prompt()
    
    def preprocess_inputs(self, text_prompt: str, image: Any) -> Dict[str, torch.Tensor]:
        """
        预处理输入文本和图像
        
        参数:
            text_prompt: 文本指令
            image: PIL.Image 对象或图像路径
            
        返回:
            处理后的模型输入字典
        """
        # 如果 image 是路径，则加载图像
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        
        # 构建 prompt
        prompt_text = self._build_prompt(text_prompt)
        self.input_ids = self.processor.tokenizer(prompt_text, add_special_tokens=True).input_ids
        self.input_ids = torch.tensor(self.input_ids)
        # 使用处理器处理输入
        inputs = self.processor(
            text=prompt_text,
            images=image,
            return_tensors="pt",
            truncation=True
        )
        
        # 将输入移动到设备
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        
        return inputs
    
    def generate_actions(self, text_prompt: str, image: Any, 
                       max_new_tokens: int = 512, 
                       temperature: float = 0) -> tuple:
        """
        生成动作序列
        
        参数:
            text_prompt: 文本指令
            image: PIL.Image 对象或图像路径
            max_new_tokens: 最大生成 token 数
            temperature: 采样温度
            
        返回:
            tuple: (action_sequence, decoded_actions)
                - action_sequence: 动作序列 (numpy 数组)
                - decoded_actions: 解码后的动作 (人类可读格式)
        """
        # 预处理输入
        inputs = self.preprocess_inputs(text_prompt, image)
        
        # 生成动作 token
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            self.output = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True if temperature > 0 else False,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                eos_token_id=self.processor.tokenizer.eos_token_id,
            )
            self.action = self.model.predict_action(**inputs,
                                               unnorm_key=self.unnorm_key,
                                               temperature=temperature,
                                               do_sample=True if temperature > 0 else False,
                                               )
        mask = self.output > self.action_tokenizer.action_token_begin_idx
        # 解码动作 token
        action_tokens = self.output[mask].cpu().numpy()
        action_sequence = self.action_tokenizer.decode_token_ids_to_actions(action_tokens)
        
        # 获取人类可读的动作描述
        # decoded_actions = self.action_tokenizer.decode_actions_to_readable(action_sequence)
        
        return action_sequence
    
    def __call__(self, text_prompt: str, image: Any, 
                max_new_tokens: int = 512, 
                temperature: float = 0) -> tuple:
        """便捷调用方法"""
        return self.generate_actions(text_prompt, image, max_new_tokens, temperature)



In [3]:
# # 使用示例
# !CUDA_VISIBLE_DEVICES=1
vla_path = r"/home/chuangzhi/zhq/yjc/runs/openvla7b_huggingfacemodel+libero_spatial_no_noops+b2+lr-0.0005+lora-r32+dropout-0.0+example_dataset+b1+lr-0.0005+lora-r32+dropout-0.0"
# vla_path = r"/home/chuangzhi/zhq/yjc/runs/openvla7b_huggingfacemodel+libero_spatial_no_noops+b2+lr-0.0005+lora-r32+dropout-0.0"
# 初始化推理器 - 替换为你的模型路径  # 可以是本地路径或 HuggingFace Hub 路径
vla_inference = OpenVLAInference(vla_path,device="cuda:2")
vla_inference.device, vla_inference.norm_stats.keys()

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00,  4.99it/s]


('cuda:2', dict_keys(['example_dataset']))

In [4]:
# 示例输入
text_instruction = "Move the drinking glass to the basket."  # 注意使用小写，与训练时一致
image_path = "/home/chuangzhi/zhq/yjc/myScripts/9.png"  # 你的图像路径

# 生成动作
action_sequence = vla_inference(text_instruction, image_path)

# 打印结果
print("use .generate(), Generated Action Sequence(norm):", action_sequence)
print(f'"action(去归一化前)" 序列长度为:{len(action_sequence)}')
print("use .predict_action(), Generated Action Sequence(unnorm):", vla_inference.action)



use .generate(), Generated Action Sequence(norm): [0.         0.23529412 0.90980392 0.03137255 0.02352941 0.21176471
 0.99607843]
"action(去归一化前)" 序列长度为:7
use .predict_action(), Generated Action Sequence(unnorm): [-5.52274287e-06 -9.16986781e-05  3.00139446e-02  7.77396050e-04
 -8.98995658e-06  8.26287587e-04  9.96078431e-01]


use .generate(), Generated Action Sequence(norm): [0.90980392 0.24313725 0.         0.09411765 0.10980392 0.19607843
 0.99607843]

"action(去归一化前)" 序列长度为:7

[3.00022413e-02 1.17139460e-04 2.68705189e-05 1.69970930e-03 2.35824911e-06 6.13276555e-04 9.96078431e-01]

In [None]:
# 示例输入
text_instruction = "Move the drinking glass to the basket."  # 注意使用小写，与训练时一致
image_path = "/home/chuangzhi/zhq/yjc/myScripts/0.png"  # 你的图像路径

# 生成动作
action_sequence = vla_inference(text_instruction, image_path)

# 打印结果
print("Generated Action Sequence:", action_sequence)
print(f'"action" 序列长度为:{len(action_sequence)}')

In [4]:
from fastapi import FastAPI, UploadFile, File
import uvicorn

app = FastAPI()

@app.get("/healthcheck")
def healthcheck():
    return {"status": "OK"}

@app.post("/generate_action")
async def generate_action(instruction: str, image: UploadFile = File(...)):
    # TODO: 调用你的模型
    return {"action": [0.1, 0.2, 0.3]}  # 示例返回

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Form data requires "python-multipart" to be installed. 
You can install "python-multipart" with: 

pip install python-multipart



RuntimeError: Form data requires "python-multipart" to be installed. 
You can install "python-multipart" with: 

pip install python-multipart


In [4]:
import pika
import json

connection = pika.BlockingConnection(pika.ConnectionParameters('localhost', port=9025))
channel = connection.channel()

channel.queue_declare(queue='action_requests')
channel.queue_declare(queue='action_responses')

def callback(ch, method, properties, body):
    request = json.loads(body)
    action = vla_inference(request['instruction'], request['image_path'])
    
    response = {
        'action': vla_inference.action,
        'request_id': request['request_id']
    }
    
    channel.basic_publish(
        exchange='',
        routing_key='action_responses',
        body=json.dumps(response)
    )

channel.basic_consume(queue='action_requests', on_message_callback=callback, auto_ack=True)
channel.start_consuming()

IncompatibleProtocolError: StreamLostError: ("Stream connection lost: ConnectionResetError(104, 'Connection reset by peer')",)

In [8]:
# 远程服务器代码
import socket

server_ip = "0.0.0.0"  # 监听所有IP
server_port = 9025

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    sock.bind((server_ip, server_port))
    sock.listen(1)
    print("Server is listening on port 9025")
    conn, addr = sock.accept()
    with conn:
        print(f"Connected by {addr}")
        data = conn.recv(1024)
        print(f"Received from client: {data.decode()}")
        conn.sendall(b"Hello from server")

OSError: [Errno 98] Address already in use

In [4]:
import socket

host = '0.0.0.0'  # 监听所有网络接口
port = 6789  # 确保端口号与客户端一致

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind((host, port))
    s.listen()
    print(f"Server listening on port {port}")
    conn, addr = s.accept()
    with conn:
        print(f"Connected by {addr}")
        while True:
            data = conn.recv(1024)
            if not data:
                break
            print(f"Received data: {data}")
            conn.sendall(b"Hello, client!")  # 发送固定响应

Server listening on port 6789


KeyboardInterrupt: 

In [4]:
import socket

host = '0.0.0.0'  # 监听所有网络接口
port = 8000  # 确保端口号与客户端一致

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind((host, port))
    s.listen()
    s.settimeout(10)  # 设置超时时间为10秒
    print(f"Server listening on port {port}")
    
    while True:
        conn, addr = s.accept()
        with conn:
            print(f"Connected by {addr}")
            while True:
                data = conn.recv(1024)
                if not data:
                    print("No data received, closing connection.")
                    break
                print(f"Received data: {data.decode()}")
                response = f"Server received: {data.decode()}"
                print(f"Sending response: {response}")
                conn.sendall(response.encode())

Server listening on port 8000


TimeoutError: timed out

In [None]:
import socket

host = '0.0.0.0'
port = 8000

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind((host, port))
    s.listen()
    print(f"Server listening on port {port}")
    
    conn, addr = s.accept()
    with conn:
        print(f"Connected by {addr}")
        data = conn.recv(1024)
        print(f"Received data: {data.decode()}")
        response = "Hello from server"
        conn.sendall(response.encode())

In [1]:
import socket
import pickle

# 创建TCP socket
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind(('0.0.0.0', 12345))  # 监听所有接口，端口为12345
server_socket.listen(1)
print("Server is listening for incoming connections...")

# 接受客户端连接
client_socket, client_address = server_socket.accept()
print(f"Connection from {client_address}")

# 接收数据
data = client_socket.recv(4096)  # 根据字典大小调整接收缓冲区大小
if data:
    # 反序列化字典
    received_dict = pickle.loads(data)
    print("Received dictionary:", received_dict)

client_socket.close()
server_socket.close()

Server is listening for incoming connections...


KeyboardInterrupt: 

In [None]:
img = Image.open(image_path).convert("RGB")

In [None]:
inputs = vla_inference.preprocess_inputs(text_instruction, img)
inputs

In [None]:
vla_inference.input_ids

In [None]:
inputs.keys(),inputs['attention_mask']

In [None]:
vla_inference._build_prompt(text_instruction)

In [None]:
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    output = vla_inference.model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.1,
        do_sample=True,
        pad_token_id=vla_inference.processor.tokenizer.pad_token_id,
        eos_token_id=vla_inference.processor.tokenizer.eos_token_id,
    )
output

In [None]:
mask = output > vla_inference.action_tokenizer.action_token_begin_idx

output[mask].cpu().numpy()