diff --git a/benchmarks/paddleocr_vl/PaddleOCR-VL.yaml b/benchmarks/paddleocr_vl/PaddleOCR-VL.yaml
new file mode 100644
index 00000000000..f22fb98cc3c
--- /dev/null
+++ b/benchmarks/paddleocr_vl/PaddleOCR-VL.yaml
@@ -0,0 +1,97 @@
+
+pipeline_name: PaddleOCR-VL
+
+batch_size: 64
+
+use_queues: True
+
+use_doc_preprocessor: False
+use_layout_detection: True
+use_chart_recognition: False
+format_block_content: False
+
+SubModules:
+ LayoutDetection:
+ module_name: layout_detection
+ model_name: PP-DocLayoutV2
+ model_dir: null
+ batch_size: 8
+ threshold:
+ 0: 0.5 # abstract
+ 1: 0.5 # algorithm
+ 2: 0.5 # aside_text
+ 3: 0.5 # chart
+ 4: 0.5 # content
+      5: 0.4 # display_formula
+ 6: 0.4 # doc_title
+ 7: 0.5 # figure_title
+ 8: 0.5 # footer
+ 9: 0.5 # footer
+ 10: 0.5 # footnote
+ 11: 0.5 # formula_number
+ 12: 0.5 # header
+ 13: 0.5 # header
+ 14: 0.5 # image
+      15: 0.4 # inline_formula
+ 16: 0.5 # number
+ 17: 0.4 # paragraph_title
+ 18: 0.5 # reference
+ 19: 0.5 # reference_content
+ 20: 0.45 # seal
+ 21: 0.5 # table
+ 22: 0.4 # text
+ 23: 0.4 # text
+ 24: 0.5 # vision_footnote
+ layout_nms: True
+ layout_unclip_ratio: [1.0, 1.0]
+ layout_merge_bboxes_mode:
+ 0: "union" # abstract
+ 1: "union" # algorithm
+ 2: "union" # aside_text
+ 3: "large" # chart
+ 4: "union" # content
+ 5: "large" # display_formula
+ 6: "large" # doc_title
+ 7: "union" # figure_title
+ 8: "union" # footer
+ 9: "union" # footer
+ 10: "union" # footnote
+ 11: "union" # formula_number
+ 12: "union" # header
+ 13: "union" # header
+ 14: "union" # image
+ 15: "large" # inline_formula
+ 16: "union" # number
+ 17: "large" # paragraph_title
+ 18: "union" # reference
+ 19: "union" # reference_content
+ 20: "union" # seal
+ 21: "union" # table
+ 22: "union" # text
+ 23: "union" # text
+ 24: "union" # vision_footnote
+ VLRecognition:
+ module_name: vl_recognition
+ model_name: PaddleOCR-VL-0.9B
+ model_dir: null
+ batch_size: 4096
+ genai_config:
+ backend: fastdeploy-server
+ server_url: http://127.0.0.1:8118/v1
+
+SubPipelines:
+ DocPreprocessor:
+ pipeline_name: doc_preprocessor
+ batch_size: 8
+ use_doc_orientation_classify: True
+ use_doc_unwarping: True
+ SubModules:
+ DocOrientationClassify:
+ module_name: doc_text_orientation
+ model_name: PP-LCNet_x1_0_doc_ori
+ model_dir: null
+ batch_size: 8
+ DocUnwarping:
+ module_name: image_unwarping
+ model_name: UVDoc
+ model_dir: null
diff --git a/benchmarks/paddleocr_vl/README.md b/benchmarks/paddleocr_vl/README.md
new file mode 100644
index 00000000000..3dbe96c3898
--- /dev/null
+++ b/benchmarks/paddleocr_vl/README.md
@@ -0,0 +1,139 @@
+## FastDeploy Serving Performance Benchmark Tool (PaddleOCR-VL)
+
+This document describes how to run a performance benchmark of the [PaddleOCR-VL](https://www.paddleocr.ai/latest/version3.x/pipeline_usage/PaddleOCR-VL.html) pipeline.
+
+### Dataset
+
+Download the dataset locally for performance testing:
+
+| Dataset | Download |
+| --- | --- |
+| OmniDocBench v1 dataset, 981 PDF files in total | https://github.com/opendatalab/OmniDocBench |
+
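+As a convenience, a minimal download sketch using `huggingface_hub` is shown below; the repository id and layout are assumptions, so adjust the paths and point the benchmark at the directory that actually contains the PDF files (e.g. `./test_data`):
+
+```python
+# Assumption: OmniDocBench is also mirrored on the Hugging Face Hub under this id;
+# download it from the GitHub page above instead if that is not the case.
+from huggingface_hub import snapshot_download
+
+local_dir = snapshot_download(
+    repo_id="opendatalab/OmniDocBench",
+    repo_type="dataset",
+    local_dir="./omnidocbench",
+)
+print(f"Dataset downloaded to {local_dir}")
+```
+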
+### Usage
+
+1. Start the FastDeploy server. The parameters below were used when testing on an A100-80G; adjust them to your environment:
+
+ ```shell
+ python -m fastdeploy.entrypoints.openai.api_server \
+ --model PaddlePaddle/PaddleOCR-VL \
+ --port 8118 \
+ --metrics-port 8471 \
+ --engine-worker-queue-port 8472 \
+ --cache-queue-port 55660 \
+ --max-model-len 16384 \
+ --max-num-batched-tokens 16384 \
+ --gpu-memory-utilization 0.7 \
+ --max-num-seqs 256 \
+ --workers 2 \
+ --graph-optimization-config '{"graph_opt_level":0, "use_cudagraph":true}'
+ ```
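+
+    The `server_url` configured in `PaddleOCR-VL.yaml` (`http://127.0.0.1:8118/v1`) must match the port passed via `--port`. Before benchmarking, the server can be sanity-checked with a minimal sketch like the one below, assuming the standard OpenAI-compatible `/v1/models` route is exposed:
+
+    ```python
+    import requests  # assumes the requests package is installed
+
+    # List the models served by the FastDeploy OpenAI-compatible server.
+    resp = requests.get("http://127.0.0.1:8118/v1/models", timeout=10)
+    resp.raise_for_status()
+    print(resp.json())
+    ```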
+
+2. Install the dependencies in the same environment, then run the test script:
+
+ ```shell
+    # Install dependencies
+ pip install -U paddlex
+    # Run the test script
+ python benchmark.py ./test_data -b 512 -o ./benchmark.json --paddlex_config_path ./PaddleOCR-VL.yaml --gpu_ids 0
+ ```
+
+    Test script parameters:
+
+    | Parameter | Description |
+    | --- | --- |
+    | `input_dirs` | Input directory path(s); PDF and image files under each directory are detected automatically. One or more can be provided. |
+    | `-b, --batch_size` | Batch size used for inference. |
+    | `-o, --output_path` | Path of the output results file. |
+    | `--paddlex_config_path` | Path to the PaddleX YAML configuration file. |
+    | `--gpu_ids` | GPU device ID(s) to use; one or more can be provided. |
+
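+    For reference, the script drives the pipeline through PaddleX. A minimal sketch of the same flow, using a placeholder input path, is:
+
+    ```python
+    from paddlex import create_pipeline
+
+    # Build the PaddleOCR-VL pipeline from the same YAML used by benchmark.py.
+    pipeline = create_pipeline("./PaddleOCR-VL.yaml")
+
+    # Predict on a batch of files and join the Markdown output,
+    # mirroring PaddleXPredictor._predict in benchmark.py.
+    results = list(pipeline.predict(["./test_data/sample.pdf"]))
+    markdown = "\n\n".join(res._to_markdown(pretty=False)["markdown_texts"] for res in results)
+    print(markdown[:500])
+    pipeline.close()
+    ```
+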
+3. When the test finishes, results similar to the following are printed:
+
+ ```text
+ Throughput (file): 1.3961 files per second
+ Average latency (batch): 351.0812 seconds
+ Processed pages: 981
+ Throughput (page): 1.3961 pages per second
+ Generated tokens: 1510337
+ Throughput (token): 2149.5 tokens per second
+ GPU utilization (%): 100.0, 0.0, 68.1
+ GPU memory usage (MB): 77664.8, 58802.8, 74402.7
+ ```
+
+    Output fields:
+
+    | Field | Description |
+    | --- | --- |
+    | Throughput (file) | Number of files processed per second |
+    | Average latency (batch) | Average processing latency per batch, in seconds |
+    | Processed pages | Total number of pages processed |
+    | Throughput (page) | Number of pages processed per second |
+    | Generated tokens | Total number of generated tokens |
+    | Throughput (token) | Number of tokens generated per second |
+    | GPU utilization (%) | Maximum, minimum, and average GPU utilization |
+    | GPU memory usage (MB) | Maximum, minimum, and average GPU memory usage, in MB |
+
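+    Besides the console output, the full configuration and metrics are written to the JSON file given by `-o`. A minimal sketch for reading it back (field names taken from `benchmark.py`):
+
+    ```python
+    import json
+
+    # Load the results written by benchmark.py (-o ./benchmark.json).
+    with open("./benchmark.json", encoding="utf-8") as f:
+        results = json.load(f)
+
+    # Headline metrics; keys mirror the dict dumped at the end of benchmark.py.
+    print(f"Files processed      : {results['total_files']}")
+    print(f"Throughput (file/s)  : {results['throughput_file']:.4f}")
+    print(f"Avg batch latency (s): {results['avg_latency_batch']:.4f}")
+    print(f"Throughput (token/s) : {results['throughput_token']:.1f}")
+    ```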
diff --git a/benchmarks/paddleocr_vl/benchmark.py b/benchmarks/paddleocr_vl/benchmark.py
new file mode 100644
index 00000000000..c09d91c5360
--- /dev/null
+++ b/benchmarks/paddleocr_vl/benchmark.py
@@ -0,0 +1,226 @@
+#!/usr/bin/env python
+
+import argparse
+import glob
+import json
+import os
+import sys
+import time
+import uuid
+from operator import itemgetter
+from threading import Thread
+
+import pynvml
+import tiktoken
+from tqdm import tqdm
+
+shutdown = False
+
+encoding = tiktoken.get_encoding("cl100k_base")
+
+
+class Predictor(object):
+ def predict(self, task_info, batch_data):
+ task_info["start_time"] = get_curr_time()
+ try:
+ markdown, num_pages = self._predict(batch_data)
+ except Exception as e:
+ task_info["successful"] = False
+ print(e)
+ raise
+        finally:
+            task_info["end_time"] = get_curr_time()
+        # Record success metrics only after _predict returned; setting them in the
+        # finally block would also run on failure and reference undefined names.
+        task_info["successful"] = True
+        task_info["processed_pages"] = num_pages
+        task_info["generated_tokens"] = len(encoding.encode(markdown))
+        return markdown
+
+ def _predict(self, batch_data):
+ raise NotImplementedError
+
+ def close(self):
+ pass
+
+
+class PaddleXPredictor(Predictor):
+ def __init__(self, config_path):
+ from paddlex import create_pipeline
+
+ super().__init__()
+ self.pipeline = create_pipeline(config_path)
+
+ def _predict(self, batch_data):
+ results = list(self.pipeline.predict(batch_data))
+ return "\n\n".join(res._to_markdown(pretty=False)["markdown_texts"] for res in results), len(results)
+
+ def close(self):
+ self.pipeline.close()
+
+
+def monitor_device(gpu_ids, gpu_metrics_list):
+ try:
+ pynvml.nvmlInit()
+ handles = [pynvml.nvmlDeviceGetHandleByIndex(gpu_id) for gpu_id in gpu_ids]
+
+ time.sleep(5)
+ while not shutdown:
+ try:
+ gpu_util = 0
+ mem_bytes = 0
+
+ for handle in handles:
+ gpu_util += pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
+ mem_bytes += pynvml.nvmlDeviceGetMemoryInfo(handle).used
+
+ gpu_metrics_list.append(
+ {
+ "utilization": gpu_util,
+ "memory": mem_bytes,
+ }
+ )
+ except Exception as e:
+ print(f"Error monitoring GPUs: {e}")
+
+ time.sleep(0.5)
+
+ except Exception as e:
+ print(f"Error initializing the GPU monitor: {e}")
+ finally:
+ try:
+ pynvml.nvmlShutdown()
+        except Exception:
+ pass
+
+
+def get_curr_time():
+ return time.perf_counter()
+
+
+def new_task_info():
+ task_info = {}
+ task_info["id"] = uuid.uuid4().hex
+ return task_info
+
+
+def create_and_submit_new_task(executor, requestor, task_info_dict, input_path):
+ task_info = new_task_info()
+ task = executor.submit(
+ requestor.make_request,
+ task_info,
+ input_path,
+ )
+ task_info_dict[task] = task_info
+
+ return task
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input_dirs", type=str, nargs="+", metavar="INPUT_DIR")
+ parser.add_argument("-b", "--batch_size", type=int, default=1)
+ parser.add_argument("-o", "--output_path", type=str, default="benchmark.json")
+ parser.add_argument("--paddlex_config_path", type=str, default="PaddleOCR-VL.yaml")
+ parser.add_argument("--gpu_ids", type=int, nargs="+", default=[0])
+ args = parser.parse_args()
+
+ task_info_list = []
+
+ all_input_paths = []
+ for input_dir in args.input_dirs:
+ all_input_paths += glob.glob(os.path.join(input_dir, "*"))
+ all_input_paths.sort()
+ if len(all_input_paths) == 0:
+ print("No valid data")
+ sys.exit(1)
+
+ predictor = PaddleXPredictor(args.paddlex_config_path)
+
+ if args.batch_size < 1:
+ print("Invalid batch size")
+ sys.exit(2)
+
+ gpu_metrics_list = []
+ thread_device_monitor = Thread(
+ target=monitor_device,
+ args=(args.gpu_ids, gpu_metrics_list),
+ )
+ thread_device_monitor.start()
+
+ try:
+ start_time = get_curr_time()
+ batch_data = []
+ with open("generated_markdown.md", "w", encoding="utf-8") as f:
+ for i, input_path in tqdm(enumerate(all_input_paths), total=len(all_input_paths)):
+ batch_data.append(input_path)
+ if len(batch_data) == args.batch_size or i == len(all_input_paths) - 1:
+ task_info = new_task_info()
+ try:
+ markdown = predictor.predict(task_info, batch_data)
+ f.write(markdown)
+ f.write("\n\n")
+ except Exception as e:
+ print(e)
+ continue
+ task_info_list.append(task_info)
+ batch_data.clear()
+ end_time = get_curr_time()
+ finally:
+ shutdown = True
+ thread_device_monitor.join()
+ predictor.close()
+
+ total_files = len(all_input_paths)
+ throughput_file = total_files / (end_time - start_time)
+ print(f"Throughput (file): {throughput_file:.4f} files per second")
+ duration_list_batch = [info["end_time"] - info["start_time"] for info in task_info_list]
+ avg_latency_batch = sum(duration_list_batch) / len(duration_list_batch)
+ print(f"Average latency (batch): {avg_latency_batch:.4f} seconds")
+
+ successful_files = sum(map(lambda x: x["successful"], task_info_list))
+ if successful_files:
+ processed_pages = sum(info.get("processed_pages", 0) for info in task_info_list)
+ throughput_page = processed_pages / (end_time - start_time)
+ print(f"Processed pages: {processed_pages}")
+ print(f"Throughput (page): {throughput_page:.4f} pages per second")
+ generated_tokens = sum(info.get("generated_tokens", 0) for info in task_info_list)
+ throughput_token = generated_tokens / (end_time - start_time)
+ print(f"Generated tokens: {generated_tokens}")
+ print(f"Throughput (token): {throughput_token:.1f} tokens per second")
+ else:
+ processed_pages = None
+ throughput_page = None
+ generated_tokens = None
+ throughput_token = None
+
+ if gpu_metrics_list:
+ gpu_util_list = list(map(itemgetter("utilization"), gpu_metrics_list))
+ print(
+ f"GPU utilization (%): {max(gpu_util_list):.1f}, {min(gpu_util_list):.1f}, {sum(gpu_util_list) / len(gpu_util_list):.1f}"
+ )
+ gpu_mem_list = list(map(itemgetter("memory"), gpu_metrics_list))
+ print(
+ f"GPU memory usage (MB): {max(gpu_mem_list) / 1024**2:.1f}, {min(gpu_mem_list) / 1024**2:.1f}, {sum(gpu_mem_list) / len(gpu_mem_list) / 1024**2:.1f}"
+ )
+
+ dic = {
+ "input_dirs": args.input_dirs,
+ "batch_size": args.batch_size,
+ "total_files": total_files,
+ "throughput_file": throughput_file,
+ "avg_latency_batch": avg_latency_batch,
+ "duration_list": duration_list_batch,
+ "successful_files": successful_files,
+ "processed_pages": processed_pages,
+ "throughput_page": throughput_page,
+ "generated_tokens": generated_tokens,
+ "throughput_token": throughput_token,
+ "gpu_metrics_list": gpu_metrics_list,
+ }
+ with open(args.output_path, "w", encoding="utf-8") as f:
+ json.dump(
+ dic,
+ f,
+ ensure_ascii=False,
+ indent=2,
+ )
+ print(f"Config and results saved to {args.output_path}")