- High-Performance Inference: Model inference optimized with NVIDIA TensorRT
- Dynamic Batching: Supports automatic batch aggregation to improve inference efficiency
- Multi-Model Support: Simultaneously supports Embedding, Reranker, and NLI models
- RESTful API: Provides a standard HTTP API interface
- OpenAI Compatible: Supports API calls in OpenAI SDK format
- GPU Memory Optimization: Efficient GPU memory management
Currently, only Embedding, Reranker, and NLI models are supported.
git clone https://github.com/FubonDS/TensorrtServer.git
cd TensorrtServer
Refer to docker/docker-compose.yaml for the base image.
pip install -r requirements.txt
Model conversion consists of two steps: PyTorch → ONNX and ONNX → TensorRT.
The relevant scripts are located in the trt_convert/ directory and all support argparse arguments.
python trt_convert/embedding2onnx.py \
--model_path ./embedding_engine/model/embedding_model/bge-m3-model \
--tokenizer_path ./embedding_engine/model/embedding_model/bge-m3-tokenizer \
--output_path ./embedding_models/model_dynamic/bge_m3_embedding_dynamic.onnx \
--max_length 256

Parameters & Script Details

| Parameter | Default | Description |
|---|---|---|
| `--model_path` | `./embedding_engine/model/embedding_model/bge-m3-model` | Model path |
| `--tokenizer_path` | `./embedding_engine/model/embedding_model/bge-m3-tokenizer` | Tokenizer path |
| `--output_path` | `./embedding_models/model_dynamic/bge_m3_embedding_dynamic.onnx` | Output ONNX path |
| `--max_length` | `256` | Maximum sequence length |

Script notes:
- Loads the model with `AutoModel` and uses `EmbeddingWrapper` to extract the CLS token (`last_hidden_state[:, 0, :]`) as output
- Inputs: `input_ids`, `attention_mask` (int32); Output: `embeddings`
- Dynamic axis: `batch_size` (dimension 0), ONNX opset: 17
python trt_convert/rerank2onnx.py \
--model_path ./reranking_model/bge-reranker-large-model \
--tokenizer_path ./reranking_model/bge-reranker-large-tokenizer \
--output_path ./model_dynamic/bge_reranker_large_dynamic.onnx \
--max_length 256

Parameters & Script Details

| Parameter | Default | Description |
|---|---|---|
| `--model_path` | `./reranking_model/bge-reranker-large-model` | Model path |
| `--tokenizer_path` | `./reranking_model/bge-reranker-large-tokenizer` | Tokenizer path |
| `--output_path` | `./model_dynamic/bge_reranker_large_dynamic.onnx` | Output ONNX path |
| `--max_length` | `256` | Maximum sequence length |

Script notes:
- Loads the model with `AutoModelForSequenceClassification`; `RerankerWrapper` outputs `logits.squeeze(-1)` as relevance scores
- Inputs: `input_ids`, `attention_mask` (int32); Output: `scores`
- Dynamic axis: `batch_size` (dimension 0), ONNX opset: 17
python trt_convert/nli2onnx.py \
--model_path joeddav/xlm-roberta-large-xnli \
--tokenizer_path joeddav/xlm-roberta-large-xnli \
--output_path ./model_dynamic_bs/nli_model_dynamic_bs.onnx \
--max_length 256

Parameters & Script Details

| Parameter | Default | Description |
|---|---|---|
| `--model_path` | `joeddav/xlm-roberta-large-xnli` | Model path (supports HuggingFace Hub) |
| `--tokenizer_path` | `joeddav/xlm-roberta-large-xnli` | Tokenizer path |
| `--output_path` | `./model_dynamic_bs/nli_model_dynamic_bs.onnx` | Output ONNX path |
| `--max_length` | `256` | Maximum sequence length |

Script notes:
- Loads the model with `AutoModelForSequenceClassification` and directly outputs `logits` (3 NLI class scores)
- Inputs: `input_ids`, `attention_mask` (int32); Output: `logits`
- Dynamic axis: `batch_size` (dimension 0), ONNX opset: 17
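
All three export scripts describe the same interface: `input_ids` and `attention_mask` inputs, one named output, a dynamic batch dimension, and opset 17. As a rough sketch, those settings could map onto keyword arguments of `torch.onnx.export` as shown below. The helper name `export_kwargs` is hypothetical, and marking the output's batch axis as dynamic is an assumption; the actual scripts in trt_convert/ may be written differently.

```python
def export_kwargs(output_name, opset=17):
    """Build keyword arguments for torch.onnx.export (hypothetical helper).

    output_name is "embeddings", "scores", or "logits", matching the
    script notes above.
    """
    input_names = ["input_ids", "attention_mask"]
    # Dimension 0 of every tensor is the dynamic batch axis.
    # Assumption: the output is marked dynamic in the same way as the inputs.
    dynamic_axes = {name: {0: "batch_size"} for name in input_names + [output_name]}
    return {
        "input_names": input_names,
        "output_names": [output_name],
        "dynamic_axes": dynamic_axes,
        "opset_version": opset,
    }


kwargs = export_kwargs("embeddings")
print(kwargs["dynamic_axes"])

# With PyTorch installed, the call would look roughly like:
# torch.onnx.export(wrapper, (input_ids, attention_mask), output_path, **kwargs)
```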
Run this step inside a TensorRT Docker container (see docker/docker-compose.yaml) using the trtexec tool.
trtexec \
--onnx=./model/nli_model_dynamic_bs.onnx \
--saveEngine=nli_model_bs8.trt \
--fp16
Dynamic batch size allows the model to accept batches of varying sizes at inference time. Examples for each model:
Embedding Model
trtexec \
--onnx=./embedding_models/bge_m3_embedding_dynamic.onnx \
--saveEngine=bge_m3_model_dynamic_bs.trt \
--fp16 \
--minShapes=input_ids:1x256,attention_mask:1x256 \
--optShapes=input_ids:8x256,attention_mask:8x256 \
--maxShapes=input_ids:32x256,attention_mask:32x256
Reranker Model
trtexec \
--onnx=./reranker_models/bge_reranker_large_dynamic.onnx \
--saveEngine=bge_reranker_large_dynamic_bs.trt \
--fp16 \
--minShapes=input_ids:1x256,attention_mask:1x256 \
--optShapes=input_ids:8x256,attention_mask:8x256 \
--maxShapes=input_ids:32x256,attention_mask:32x256
NLI Model
trtexec \
--onnx=./nli_models/nli_model_dynamic_bs.onnx \
--saveEngine=nli_model_dynamic_bs.trt \
--fp16 \
--minShapes=input_ids:1x256,attention_mask:1x256 \
--optShapes=input_ids:8x256,attention_mask:8x256 \
--maxShapes=input_ids:32x256,attention_mask:32x256

| Parameter | Description |
|---|---|
| `--onnx` | Input ONNX model path |
| `--saveEngine` | Output TensorRT engine path |
| `--fp16` | Enable FP16 precision to accelerate inference and reduce memory usage |
| `--minShapes` | Minimum input size for dynamic shapes: `tensor_name:dim0xdim1` |
| `--optShapes` | Optimal input size for dynamic shapes (affects TRT optimization focus) |
| `--maxShapes` | Maximum input size for dynamic shapes |
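
The three shape flags follow one pattern: each input tensor is listed as `name:batchxlength`, with only the batch size changing between minimum, optimal, and maximum. A small sketch that generates the flag strings for a chosen batch range is shown below. The helper `shape_flags` is hypothetical and only formats strings; it does not invoke trtexec.

```python
def shape_flags(min_batch=1, opt_batch=8, max_batch=32, seq_len=256,
                inputs=("input_ids", "attention_mask")):
    """Return the --minShapes/--optShapes/--maxShapes flags for trtexec."""
    if not 1 <= min_batch <= opt_batch <= max_batch:
        raise ValueError("expected 1 <= min_batch <= opt_batch <= max_batch")

    def spec(batch):
        # One "name:BxL" entry per input tensor, comma-separated.
        return ",".join(f"{name}:{batch}x{seq_len}" for name in inputs)

    return [
        f"--minShapes={spec(min_batch)}",
        f"--optShapes={spec(opt_batch)}",
        f"--maxShapes={spec(max_batch)}",
    ]


for flag in shape_flags():
    print(flag)
```

With the defaults, the printed flags match the ones used in the examples above.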
After conversion, set the `.trt` file path in the `model_path` field of `configs/config.yaml`.
Edit configs/config.yaml to configure your models:
nli_models:
  xlm-roberta-large-xnli:
    model_name: "xlm-roberta-large-xnli"
    model_path: "./model/nlimodels/trtmodels/nli_model_dynamic_bs.trt"
    tokenizer_path: "joeddav/xlm-roberta-large-xnli"
    reuse_dynamic_buffer: true
    cuda_graph_list:
      - 1
      - 3
      - 5
embedding_models:
  bge-m3:
    model_name: "bge-m3"
    model_path: "./model/embedding_models/trt_models/bge_m3_model_dynamic_bs.trt"
    tokenizer_path: "./model/embedding_models/bge-m3-tokenizer"
    reuse_dynamic_buffer: true
    cuda_graph_list:
      - 1
      - 3
      - 5
reranking_models:
  bge-reranker-large:
    model_name: "bge-reranker-large"
    model_path: "./model/reranker_models/trt_models/bge_reranker_large_dynamic_bs.trt"
    tokenizer_path: "./model/reranker_models/bge-reranker-large-tokenizer"
    reuse_dynamic_buffer: true
    cuda_graph_list:
      - 1
      - 3
      - 5

- `model_name`: Model identifier name
- `model_path`: TensorRT model file path
- `tokenizer_path`: Tokenizer path (can be a local path or a Hugging Face model name)
- `reuse_dynamic_buffer`: Whether to pre-allocate buffers during initialization for dynamic batch sizes, avoiding dynamic allocation on each inference call
- `cuda_graph_list`: Pre-generates a fixed sequence of GPU kernel calls for the specified batch sizes, reducing CPU→GPU kernel launch overhead
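
Before starting the server, it can help to sanity-check a model entry. The sketch below is not part of the project: `check_model_entry` is a hypothetical helper, treating the three path/name fields as mandatory is an assumption, and so is the rule that every `cuda_graph_list` batch size must fit within the engine's maximum batch size (32 in the trtexec examples).

```python
REQUIRED_FIELDS = ("model_name", "model_path", "tokenizer_path")


def check_model_entry(entry, max_batch=32):
    """Return a list of problems found in one model entry (empty if none)."""
    problems = []
    for field in REQUIRED_FIELDS:
        if not entry.get(field):
            problems.append(f"missing field: {field}")
    # Assumption: a CUDA graph can only be captured for a batch size
    # that the engine's shape profile accepts.
    for size in entry.get("cuda_graph_list", []):
        if not isinstance(size, int) or not 1 <= size <= max_batch:
            problems.append(f"cuda_graph_list entry out of range: {size!r}")
    return problems


entry = {
    "model_name": "bge-m3",
    "model_path": "./model/embedding_models/trt_models/bge_m3_model_dynamic_bs.trt",
    "tokenizer_path": "./model/embedding_models/bge-m3-tokenizer",
    "reuse_dynamic_buffer": True,
    "cuda_graph_list": [1, 3, 5],
}
print(check_model_entry(entry))
```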
chmod +x start_tensorrt_server.sh
./start_tensorrt_server.sh
After startup, the API will be available at http://{ip}:{port}.
The service provides two API formats: the native API and the OpenAI-compatible API.
curl http://localhost:8887/models
Response:
{
"models": {
"embedding_models": ["bge-m3"],
"reranking_models": ["bge-reranker-large"],
"nli_models": ["xlm-roberta-large-xnli"]
}
}
Usage Instructions
import requests
url = "http://localhost:8887/infer/bge-m3"
payload = {"documents": ["This is a test text", "Another test text"]}
response = requests.post(url, json=payload)
print(response.json())
# Output:
# {
# "embeddings": [[0.1, 0.2, ...], [0.3, 0.4, ...]],
# "elapsed_ms": 15.2
# }
import requests
url = "http://localhost:8887/infer/bge-reranker-large"
payload = {
"query": "Theory in machine learning is important",
"documents": [
"Theory is very important for understanding machine learning",
"Practical experience is also crucial in machine learning"
]
}
response = requests.post(url, json=payload)
print(response.json())
# Output:
# {
# "scores": [9.5078125, 7.2421875],
# "elapsed_ms": 5.78
# }
import requests
url = "http://localhost:8887/infer/xlm-roberta-large-xnli"
payload = {
"premises": ["The weather is nice today", "Cats are animals"],
"hypotheses": ["Today is sunny", "Dogs are animals"]
}
response = requests.post(url, json=payload)
print(response.json())
# Output:
# {
# "predictions": ["entailment", "neutral"],
# "logits": [[2.18, -1.38, -0.72], [1.02, -0.42, -0.65]],
# "elapsed_ms": 12.45
# }
Note: The OpenAI-compatible API supports only Embedding and Reranker models; NLI models are not available in this format.
Usage Instructions
from openai import OpenAI
client = OpenAI(
api_key="EMPTY", # any value works
base_url="http://localhost:8887/v1"
)
text = "This is a test text"
response = client.embeddings.create(
input=[text],
model="bge-m3"
)
print(response.data[0].embedding)
from openai import OpenAI
client = OpenAI(
api_key="EMPTY",
base_url="http://localhost:8887/v1"
)
documents = [
"Machine learning is best learned through projects",
"Theory is important for understanding machine learning"
]
response = client.embeddings.create(
model="bge-reranker-large",
input=documents,
extra_body={"query": "Theory is important for understanding machine learning"}
)
# Get reranking scores
scores = [data.embedding for data in response.data]
print(scores)

Embedding request parameters:
- `documents` (required): A string or list of strings to encode
- `model` (optional): Model name; required when using the OpenAI API

Reranker request parameters:
- `query` (required): Query string
- `documents` (required): List of candidate documents
- `model` (optional): Model name; required when using the OpenAI API

NLI request parameters:
- `premises` (required): List of premise sentences
- `hypotheses` (required): List of hypothesis sentences
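
The required fields above can be checked on the client side before a request is sent. The following is a minimal sketch: `build_payload` and the `kind` labels are hypothetical names, and the rule that `premises` and `hypotheses` must have equal length is an assumption inferred from the paired example earlier in this document.

```python
REQUIRED = {
    "embedding": ("documents",),
    "reranker": ("query", "documents"),
    "nli": ("premises", "hypotheses"),
}


def build_payload(kind, **fields):
    """Return a JSON-ready body for POST /infer/{model}, or raise ValueError."""
    missing = [name for name in REQUIRED[kind] if name not in fields]
    if missing:
        raise ValueError(f"{kind} request is missing: {', '.join(missing)}")
    payload = {name: fields[name] for name in REQUIRED[kind]}
    # Assumption: NLI inputs are paired element-wise, so the lists must align.
    if kind == "nli" and len(payload["premises"]) != len(payload["hypotheses"]):
        raise ValueError("premises and hypotheses must have the same length")
    return payload


payload = build_payload(
    "reranker",
    query="Theory in machine learning is important",
    documents=["Theory is very important for understanding machine learning"],
)
print(payload)
```

The resulting dictionary can be passed as `json=payload` to `requests.post`, as in the native API examples.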
Embedding response:
{
"embeddings": [[0.1, 0.2, ...]], // list of embedding vectors
"elapsed_ms": 15.2 // inference time (milliseconds)
}
Reranker response:
{
"scores": [9.5078125], // list of relevance scores
"elapsed_ms": 5.78 // inference time (milliseconds)
}
NLI response:
{
"predictions": ["entailment"], // list of predicted labels
"logits": [[2.18, -1.38, -0.72]], // raw scores
"elapsed_ms": 12.45 // inference time (milliseconds)
}

NLI labels:
- `entailment`: The premise supports the hypothesis
- `neutral`: The premise neither supports nor contradicts the hypothesis
- `contradiction`: The premise and hypothesis conflict
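
Two common post-processing steps on these responses are ordering documents by reranker score and turning raw NLI logits into probabilities. The sketch below uses only the standard library; both function names are hypothetical. The order of the three NLI classes within `logits` is not stated in this document, so the example stops at probabilities and does not map positions to labels.

```python
import math


def rank_documents(documents, scores):
    """Pair each document with its score, highest score first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(documents[i], scores[i]) for i in order]


def softmax(logits):
    """Convert one row of raw logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


ranked = rank_documents(["doc A", "doc B"], [7.2421875, 9.5078125])
print(ranked[0][0])

probs = softmax([2.18, -1.38, -0.72])
print([round(p, 3) for p in probs])
```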
The following results were obtained on an NVIDIA A100 GPU:
Test Configuration:
- Model: BGE-M3
- Batch size: 1–64
- Inference comparison: Torch vs. TensorRT
- Server comparison: Sequential inference vs. Dynamic batching
Test Configuration:
- Model: BGE-Reranker-Large
- Batch size: 1–64
- Inference comparison: Torch vs. TensorRT
- Server comparison: Sequential inference vs. Dynamic batching
Test Configuration:
- Model: XLM-RoBERTa-Large-XNLI
- Batch size: 1–64
- Inference comparison: Torch vs. TensorRT
- Server comparison: Sequential inference vs. Dynamic batching
sequenceDiagram
participant Client as Client
participant API as FastAPI Service
participant Worker as Worker Handler
participant TRT as TensorRT Inferencer
participant GPU as GPU (CUDA + TensorRT)
Client->>API: HTTP Request (/infer, /v1/embeddings)
API->>Worker: Dynamic Queue (payload, Future)
Worker->>Worker: Collect Batch (max_batch=32, max_wait=10ms)
Worker->>TRT: Call model.infer(all_docs)
TRT->>TRT: Tokenizer → input_ids, attention_mask
TRT->>TRT: Convert dtype → engine dtype
TRT->>GPU: H2D (memcpy_htod_async)
Note right of GPU: GPU buffer ready
TRT->>GPU: set_input_shape / set_tensor_address
TRT->>GPU: execute_async_v3()
GPU-->>TRT: GPU inference complete
Note over GPU: Engine executes inference on GPU
TRT->>GPU: D2H (memcpy_dtoh_async)
TRT->>TRT: Reshape and truncate to original length
TRT-->>Worker: Predictions / logits
Worker->>Worker: Split batch results
Worker-->>API: Future.set_result()
API-->>Client: Return JSON result
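
The batching step in the diagram (collect up to `max_batch=32` requests or wait at most `max_wait=10ms`) can be sketched with `asyncio`. This is an illustration under stated assumptions, not the project's worker: the function names, the `(payload, future)` tuple layout, and the stand-in `fake_infer` are all hypothetical, and a real worker would call the TensorRT inferencer instead.

```python
import asyncio

MAX_BATCH = 32      # from the diagram
MAX_WAIT_S = 0.010  # 10 ms, from the diagram


async def collect_batch(queue, max_batch=MAX_BATCH, max_wait_s=MAX_WAIT_S):
    """Wait for one request, then gather more until full or the window closes."""
    batch = [await queue.get()]
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_s
    while len(batch) < max_batch:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch


async def worker(queue, infer_fn):
    """Run one inference call per collected batch and split results per request."""
    while True:
        batch = await collect_batch(queue)
        all_docs = [doc for payload, _ in batch for doc in payload]
        results = infer_fn(all_docs)
        offset = 0
        for payload, future in batch:
            future.set_result(results[offset:offset + len(payload)])
            offset += len(payload)


async def submit(queue, payload):
    """Enqueue a request and wait for its slice of the batch result."""
    future = asyncio.get_running_loop().create_future()
    await queue.put((payload, future))
    return await future


async def demo():
    queue = asyncio.Queue()
    batch_sizes = []

    def fake_infer(docs):
        # Stand-in for model.infer: records the batch size, returns text lengths.
        batch_sizes.append(len(docs))
        return [len(doc) for doc in docs]

    task = asyncio.create_task(worker(queue, fake_infer))
    results = await asyncio.gather(
        submit(queue, ["a", "bb"]),
        submit(queue, ["ccc"]),
    )
    task.cancel()
    return results, batch_sizes


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Each caller receives only the results for its own documents, even though the documents were processed together.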
This project is licensed under the MIT License.







