Add Triton + TensorRT-LLM inference example #86
Merged
39 changes: 39 additions & 0 deletions
5-large-language-models/8-faster-inference-with-triton-tensorrt/Dockerfile
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| FROM nvcr.io/nvidia/tritonserver:25.10-trtllm-python-py3 | ||
|
|
||
| # Environment variables | ||
| ENV PYTHONPATH=/usr/local/lib/python3.12/dist-packages:$PYTHONPATH | ||
| ENV PYTHONDONTWRITEBYTECODE=1 | ||
| ENV DEBIAN_FRONTEND=noninteractive | ||
| ENV HF_HOME=/persistent-storage/models | ||
| ENV TORCH_CUDA_ARCH_LIST=8.6 | ||
|
|
||
| # Install system dependencies | ||
| RUN apt-get update && apt-get install -y \ | ||
| git \ | ||
| git-lfs \ | ||
| && rm -rf /var/lib/apt/lists/* | ||
|
|
||
| WORKDIR /app | ||
|
|
||
| # Install Python dependencies | ||
| RUN pip install --break-system-packages \ | ||
| huggingface_hub \ | ||
| transformers \ | ||
| || true | ||
|
|
||
| # Create required directories | ||
| RUN mkdir -p \ | ||
| /app/model_repository/llama3_2/1 \ | ||
| /persistent-storage/models \ | ||
| /persistent-storage/engines | ||
|
|
||
| # Copy application files | ||
| COPY --chmod=755 download_model.py /app/ | ||
| COPY model.py /app/model_repository/llama3_2/1/ | ||
| COPY config.pbtxt /app/model_repository/llama3_2/ | ||
|
|
||
| # Expose Triton ports | ||
| EXPOSE 8000 8001 8002 | ||
|
|
||
| # Start Triton server directly | ||
| CMD ["tritonserver", "--model-repository=/app/model_repository", "--http-port=8000", "--grpc-port=8001", "--metrics-port=8002"] |
32 changes: 32 additions & 0 deletions
5-large-language-models/8-faster-inference-with-triton-tensorrt/cerebrium.toml
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| [cerebrium.deployment] | ||
| name = "tensorrt-triton-demo" | ||
| python_version = "3.12" | ||
| disable_auth = true | ||
| include = ['./*', 'cerebrium.toml'] | ||
| exclude = ['.*'] | ||
| deployment_initialization_timeout = 830 | ||
|
|
||
| [cerebrium.hardware] | ||
| cpu = 4.0 | ||
| memory = 40.0 | ||
| compute = "AMPERE_A10" | ||
| gpu_count = 1 | ||
| provider = "aws" | ||
| region = "us-east-1" | ||
|
|
||
| [cerebrium.scaling] | ||
| min_replicas = 1 | ||
| max_replicas = 5 | ||
| cooldown = 300 | ||
| replica_concurrency = 128 | ||
| scaling_metric = "concurrency_utilization" | ||
|
|
||
| [cerebrium.dependencies.pip] | ||
| huggingface_hub = "latest" | ||
| transformers = "latest" | ||
|
|
||
| [cerebrium.runtime.custom] | ||
| port = 8000 | ||
| healthcheck_endpoint = "/v2/health/live" | ||
| readycheck_endpoint = "/v2/health/ready" | ||
| dockerfile_path = "./Dockerfile" |
48 changes: 48 additions & 0 deletions
5-large-language-models/8-faster-inference-with-triton-tensorrt/config.pbtxt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| name: "llama3_2" | ||
| backend: "python" | ||
| max_batch_size: 128 | ||
|
|
||
| dynamic_batching { | ||
| max_queue_delay_microseconds: 800 | ||
| } | ||
|
|
||
| instance_group [ | ||
| { | ||
| count: 1 | ||
| kind: KIND_GPU | ||
| } | ||
| ] | ||
|
|
||
| input [ | ||
| { | ||
| name: "text_input" | ||
| data_type: TYPE_STRING | ||
| dims: [ 1 ] | ||
| }, | ||
| { | ||
| name: "max_tokens" | ||
| data_type: TYPE_INT32 | ||
| dims: [ 1 ] | ||
| optional: true | ||
| }, | ||
| { | ||
| name: "temperature" | ||
| data_type: TYPE_FP32 | ||
| dims: [ 1 ] | ||
| optional: true | ||
| }, | ||
| { | ||
| name: "top_p" | ||
| data_type: TYPE_FP32 | ||
| dims: [ 1 ] | ||
| optional: true | ||
| } | ||
| ] | ||
|
|
||
| output [ | ||
| { | ||
| name: "text_output" | ||
| data_type: TYPE_STRING | ||
| dims: [ 1 ] | ||
| } | ||
| ] | ||
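
The inputs declared in `config.pbtxt` can be exercised over Triton's HTTP endpoint using the KServe v2 inference protocol. The sketch below only builds the JSON request body and sends nothing. The helper name `build_infer_payload` is hypothetical and not part of this PR. It assumes that, because `max_batch_size` is greater than zero, each tensor shape carries a leading batch dimension, and that `TYPE_STRING` is sent as `BYTES` on the wire.

```python
import json
from typing import Any, Dict, Optional


def build_infer_payload(
    prompt: str,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
) -> Dict[str, Any]:
    """Build a v2 inference request body for the inputs in config.pbtxt."""
    # dims: [ 1 ] plus the batch dimension gives a request shape of [1, 1].
    inputs = [
        {"name": "text_input", "shape": [1, 1], "datatype": "BYTES", "data": [prompt]}
    ]
    # Optional inputs are only included when the caller sets them, so the
    # defaults in model.py apply otherwise.
    optional = [
        ("max_tokens", "INT32", max_tokens),
        ("temperature", "FP32", temperature),
        ("top_p", "FP32", top_p),
    ]
    for name, datatype, value in optional:
        if value is not None:
            inputs.append(
                {"name": name, "shape": [1, 1], "datatype": datatype, "data": [value]}
            )
    return {"inputs": inputs}


if __name__ == "__main__":
    body = build_infer_payload("What is Triton?", max_tokens=64)
    print(json.dumps(body, indent=2))
```

The resulting body would be posted to `/v2/models/llama3_2/infer` on the HTTP port that the Dockerfile exposes (8000).
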
38 changes: 38 additions & 0 deletions
5-large-language-models/8-faster-inference-with-triton-tensorrt/download_model.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| #!/usr/bin/env python3 | ||
| """ | ||
| Download HuggingFace model to persistent storage. | ||
| Only downloads if model doesn't already exist. | ||
| """ | ||
|
|
||
| import os | ||
| from pathlib import Path | ||
| from huggingface_hub import snapshot_download, login | ||
|
|
||
| MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" | ||
| MODEL_DIR = Path("/persistent-storage/models") / MODEL_ID | ||
|
|
||
|
|
||
| def download_model(): | ||
| """Download model from HuggingFace if not already present.""" | ||
| hf_token = os.environ.get("HF_AUTH_TOKEN") | ||
|
|
||
| if not hf_token: | ||
| print("WARNING: HF_AUTH_TOKEN not set, skipping model download") | ||
| return | ||
|
|
||
| if MODEL_DIR.exists() and any(MODEL_DIR.iterdir()): | ||
| print("✓ Model already exists") | ||
| return | ||
|
|
||
| print("Downloading model from HuggingFace...") | ||
| login(token=hf_token) | ||
| snapshot_download( | ||
| MODEL_ID, | ||
| local_dir=str(MODEL_DIR), | ||
| token=hf_token | ||
| ) | ||
| print("✓ Model downloaded successfully") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| download_model() | ||
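
The "download only if missing" guard in `download_model.py` can be tried without network access by injecting the download step. This is a minimal sketch under that assumption: `ensure_downloaded` and `fake_fetch` are hypothetical names, and `fake_fetch` stands in for `snapshot_download` by writing a single placeholder file.

```python
import tempfile
from pathlib import Path
from typing import Callable


def ensure_downloaded(model_dir: Path, fetch: Callable[[Path], None]) -> bool:
    """Call fetch(model_dir) only when model_dir is missing or empty.

    Returns True if a download was triggered, False if it was skipped.
    """
    # Same check as download_model.py: the directory exists and is non-empty.
    if model_dir.exists() and any(model_dir.iterdir()):
        return False
    model_dir.mkdir(parents=True, exist_ok=True)
    fetch(model_dir)
    return True


def fake_fetch(target: Path) -> None:
    # Hypothetical stand-in for snapshot_download: writes one placeholder file.
    (target / "config.json").write_text("{}")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        model_dir = Path(tmp) / "models" / "demo"
        print(ensure_downloaded(model_dir, fake_fetch))  # first call fetches
        print(ensure_downloaded(model_dir, fake_fetch))  # second call is skipped
```

Note that a non-empty directory is treated as a complete download, so an interrupted transfer would also be skipped on the next start. Checking for a specific expected file would be stricter, but which file to check is an assumption this sketch does not make.
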
209 changes: 209 additions & 0 deletions
5-large-language-models/8-faster-inference-with-triton-tensorrt/model.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| """ | ||
| Triton Python Backend for TensorRT-LLM. | ||
|
|
||
| This module implements a Triton Inference Server Python backend that uses | ||
| TensorRT-LLM's PyTorch backend for optimized LLM inference. | ||
| """ | ||
|
|
||
| import numpy as np | ||
| import triton_python_backend_utils as pb_utils | ||
| import torch | ||
| from tensorrt_llm import LLM, SamplingParams, BuildConfig | ||
| from tensorrt_llm.plugin.plugin import PluginConfig | ||
| from transformers import AutoTokenizer | ||
| from pathlib import Path | ||
|
|
||
| # Model configuration | ||
| MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" | ||
| MODEL_DIR = f"/persistent-storage/models/{MODEL_ID}" | ||
|
|
||
|
|
||
| def ensure_model_downloaded(): | ||
| """Check if model exists, download if not available.""" | ||
| model_path = Path(MODEL_DIR) | ||
|
|
||
| # Check if model directory exists and has content | ||
| if not model_path.exists() or not any(model_path.iterdir()): | ||
| print("Model not found, downloading...") | ||
| try: | ||
| # Import download function from download_model | ||
| from download_model import download_model | ||
| download_model() | ||
| except Exception as e: | ||
| print(f"Error downloading model: {e}") | ||
| raise | ||
| else: | ||
| print("✓ Model already exists") | ||
|
|
||
|
|
||
| class TritonPythonModel: | ||
| """ | ||
| Triton Python Backend model for TensorRT-LLM inference. | ||
|
|
||
| This class handles model initialization, inference requests, and cleanup. | ||
| """ | ||
|
|
||
| def initialize(self, args): | ||
| """ | ||
| Initialize the model - called once when Triton loads the model. | ||
|
|
||
| Loads tokenizer and initializes TensorRT-LLM with PyTorch backend. | ||
| """ | ||
| # Ensure model is downloaded before loading | ||
| ensure_model_downloaded() | ||
|
|
||
| print("Loading tokenizer...") | ||
| self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) | ||
|
|
||
| print("Initializing TensorRT-LLM...") | ||
|
|
||
| plugin_config = PluginConfig.from_dict({ | ||
| "paged_kv_cache": True, # Efficient memory usage for KV cache | ||
| }) | ||
|
|
||
| build_config = BuildConfig( | ||
| plugin_config=plugin_config, | ||
| max_input_len=4096, | ||
| max_batch_size=128, # Matches Triton max_batch_size in config.pbtxt | ||
| ) | ||
|
|
||
| self.llm = LLM( | ||
| model=MODEL_DIR, | ||
| build_config=build_config, | ||
| tensor_parallel_size=torch.cuda.device_count(), | ||
| ) | ||
| print("✓ Model ready") | ||
|
|
||
| def execute(self, requests): | ||
| """ | ||
| Execute inference on batched requests. | ||
|
|
||
| Triton automatically batches requests (up to max_batch_size: 128). | ||
| This function processes the batch that Triton provides. | ||
| """ | ||
| try: | ||
| prompts = [] | ||
| sampling_params_list = [] | ||
| original_prompts = [] # Store original prompts to strip from output if needed | ||
|
|
||
| # Extract data from each request in the batch | ||
| for request in requests: | ||
| try: | ||
| # Get input text - handle batched tensor structures | ||
| input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input") | ||
| text_array = input_tensor.as_numpy() | ||
|
|
||
| # Extract text handling different array structures (batched vs non-batched) | ||
| if text_array.ndim == 0: | ||
| # Scalar | ||
| text = text_array.item() | ||
| elif text_array.dtype == object: | ||
| # Object dtype array (common for BYTES/STRING with batching) | ||
| text = text_array.flat[0] if text_array.size > 0 else text_array.item() | ||
| else: | ||
| # Regular array - get first element | ||
| text = text_array.flat[0] if text_array.size > 0 else text_array.item() | ||
|
|
||
| # Decode if bytes, otherwise use as string | ||
| if isinstance(text, bytes): | ||
| text = text.decode('utf-8') | ||
| elif isinstance(text, np.str_): | ||
| text = str(text) | ||
|
|
||
| # Get optional parameters with defaults | ||
| max_tokens = 1024 | ||
| if pb_utils.get_input_tensor_by_name(request, "max_tokens") is not None: | ||
| max_tokens_array = pb_utils.get_input_tensor_by_name(request, "max_tokens").as_numpy() | ||
| max_tokens = int(max_tokens_array.item() if max_tokens_array.ndim == 0 else max_tokens_array.flat[0]) | ||
|
|
||
| temperature = 0.8 | ||
| if pb_utils.get_input_tensor_by_name(request, "temperature") is not None: | ||
| temp_array = pb_utils.get_input_tensor_by_name(request, "temperature").as_numpy() | ||
| temperature = float(temp_array.item() if temp_array.ndim == 0 else temp_array.flat[0]) | ||
|
|
||
| top_p = 0.95 | ||
| if pb_utils.get_input_tensor_by_name(request, "top_p") is not None: | ||
| top_p_array = pb_utils.get_input_tensor_by_name(request, "top_p").as_numpy() | ||
| top_p = float(top_p_array.item() if top_p_array.ndim == 0 else top_p_array.flat[0]) | ||
|
|
||
| # Format prompt using chat template | ||
| prompt = self.tokenizer.apply_chat_template( | ||
| [{"role": "user", "content": text}], | ||
| tokenize=False, | ||
| add_generation_prompt=True | ||
| ) | ||
|
|
||
| prompts.append(prompt) | ||
| original_prompts.append(prompt) # Store for potential stripping | ||
| sampling_params_list.append(SamplingParams( | ||
| temperature=temperature, | ||
| top_p=top_p, | ||
| max_tokens=max_tokens, | ||
| )) | ||
| except Exception as e: | ||
| print(f"Error processing request: {e}", flush=True) | ||
| import traceback | ||
| traceback.print_exc() | ||
| # Use default max_tokens instead of 1 to avoid single token output | ||
| prompts.append("") | ||
| original_prompts.append("") | ||
| sampling_params_list.append(SamplingParams(max_tokens=1024)) | ||
|
|
||
| # Batch inference | ||
| if not prompts: | ||
| return [] | ||
|
|
||
| outputs = self.llm.generate(prompts, sampling_params_list) | ||
|
|
||
| # Create responses | ||
| responses = [] | ||
| for i, output in enumerate(outputs): | ||
| try: | ||
| # Extract generated text | ||
| generated_text = output.outputs[0].text | ||
|
|
||
| # Remove the prompt from generated text if it's included | ||
| if original_prompts[i] and original_prompts[i] in generated_text: | ||
| generated_text = generated_text.replace(original_prompts[i], "").strip() | ||
|
|
||
| responses.append(pb_utils.InferenceResponse( | ||
| output_tensors=[pb_utils.Tensor( | ||
| "text_output", | ||
| np.array([generated_text.encode('utf-8')], dtype=object) | ||
| )] | ||
| )) | ||
| except Exception as e: | ||
| print(f"Error creating response {i}: {e}", flush=True) | ||
| responses.append(pb_utils.InferenceResponse( | ||
| output_tensors=[pb_utils.Tensor( | ||
| "text_output", | ||
| np.array([f"Error: {str(e)}".encode('utf-8')], dtype=object) | ||
| )] | ||
| )) | ||
|
|
||
| return responses | ||
|
|
||
| except Exception as e: | ||
| print(f"Error in execute: {e}", flush=True) | ||
| import traceback | ||
| traceback.print_exc() | ||
| # Return error responses | ||
| return [ | ||
| pb_utils.InferenceResponse( | ||
| output_tensors=[pb_utils.Tensor( | ||
| "text_output", | ||
| np.array([f"Batch error: {str(e)}".encode('utf-8')], dtype=object) | ||
| )] | ||
| ) | ||
| for _ in requests | ||
| ] | ||
|
|
||
| def finalize(self): | ||
| """ | ||
| Cleanup when model is being unloaded. | ||
|
|
||
| Shuts down the TensorRT-LLM engine and clears GPU memory. | ||
| """ | ||
| if hasattr(self, 'llm'): | ||
| self.llm.shutdown() | ||
| torch.cuda.empty_cache() |
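
The response loop in `execute` removes the prompt with `str.replace`, which drops every occurrence of the prompt text, including any that appear later in the generated answer. A possible alternative, sketched here with a hypothetical helper `strip_prompt_prefix` that is not part of this PR, strips the prompt only when it leads the output.

```python
def strip_prompt_prefix(generated: str, prompt: str) -> str:
    """Remove prompt from generated only when it appears as a leading prefix."""
    if prompt and generated.startswith(prompt):
        # Drop the echoed prompt and surrounding whitespace, as execute() does.
        return generated[len(prompt):].strip()
    # Leave the text untouched when the prompt is empty or not a prefix.
    return generated


if __name__ == "__main__":
    prompt = "Q: hi\n"
    print(strip_prompt_prefix("Q: hi\nA: hello", prompt))
    print(strip_prompt_prefix("A: you said Q: hi\n earlier", prompt))
```

Whether the engine echoes the prompt at all depends on the TensorRT-LLM version and settings, so this helper is a no-op when it does not.
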