Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
FROM nvcr.io/nvidia/tritonserver:25.10-trtllm-python-py3

# Environment variables
ENV PYTHONPATH=/usr/local/lib/python3.12/dist-packages:$PYTHONPATH
ENV PYTHONDONTWRITEBYTECODE=1
ENV DEBIAN_FRONTEND=noninteractive
ENV HF_HOME=/persistent-storage/models
ENV TORCH_CUDA_ARCH_LIST=8.6

# Install system dependencies
RUN apt-get update && apt-get install -y \
git \
git-lfs \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python dependencies
RUN pip install --break-system-packages \
huggingface_hub \
transformers \
|| true

# Create required directories
RUN mkdir -p \
/app/model_repository/llama3_2/1 \
/persistent-storage/models \
/persistent-storage/engines

# Copy application files
COPY --chmod=755 download_model.py /app/
COPY model.py /app/model_repository/llama3_2/1/
COPY config.pbtxt /app/model_repository/llama3_2/

# Expose Triton ports
EXPOSE 8000 8001 8002

# Start Triton server directly
CMD ["tritonserver", "--model-repository=/app/model_repository", "--http-port=8000", "--grpc-port=8001", "--metrics-port=8002"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
[cerebrium.deployment]
name = "tensorrt-triton-demo"
python_version = "3.12"
disable_auth = true
include = ['./*', 'cerebrium.toml']
exclude = ['.*']
deployment_initialization_timeout = 830

[cerebrium.hardware]
cpu = 4.0
memory = 40.0
compute = "AMPERE_A10"
gpu_count = 1
provider = "aws"
region = "us-east-1"

[cerebrium.scaling]
min_replicas = 1
max_replicas = 5
cooldown = 300
replica_concurrency = 128
scaling_metric = "concurrency_utilization"

[cerebrium.dependencies.pip]
huggingface_hub = "latest"
transformers = "latest"

[cerebrium.runtime.custom]
port = 8000
healthcheck_endpoint = "/v2/health/live"
readycheck_endpoint = "/v2/health/ready"
dockerfile_path = "./Dockerfile"
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: "llama3_2"
backend: "python"
max_batch_size: 128

dynamic_batching {
max_queue_delay_microseconds: 800
}

instance_group [
{
count: 1
kind: KIND_GPU
}
]

input [
{
name: "text_input"
data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "max_tokens"
data_type: TYPE_INT32
dims: [ 1 ]
optional: true
},
{
name: "temperature"
data_type: TYPE_FP32
dims: [ 1 ]
optional: true
},
{
name: "top_p"
data_type: TYPE_FP32
dims: [ 1 ]
optional: true
}
]

output [
{
name: "text_output"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
"""
Download HuggingFace model to persistent storage.
Only downloads if model doesn't already exist.
"""

import os
from pathlib import Path
from huggingface_hub import snapshot_download, login

MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = Path("/persistent-storage/models") / MODEL_ID


def download_model():
"""Download model from HuggingFace if not already present."""
hf_token = os.environ.get("HF_AUTH_TOKEN")

if not hf_token:
print("WARNING: HF_AUTH_TOKEN not set, model download may fail")
return

if MODEL_DIR.exists() and any(MODEL_DIR.iterdir()):
print("✓ Model already exists")
return

print("Downloading model from HuggingFace...")
login(token=hf_token)
snapshot_download(
MODEL_ID,
local_dir=str(MODEL_DIR),
token=hf_token
)
print("✓ Model downloaded successfully")


if __name__ == "__main__":
download_model()
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""
Triton Python Backend for TensorRT-LLM.

This module implements a Triton Inference Server Python backend that uses
TensorRT-LLM's PyTorch backend for optimized LLM inference.
"""

import numpy as np
import triton_python_backend_utils as pb_utils
import torch
from tensorrt_llm import LLM, SamplingParams, BuildConfig
from tensorrt_llm.plugin.plugin import PluginConfig
from transformers import AutoTokenizer
from pathlib import Path

# Model configuration
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_DIR = f"/persistent-storage/models/{MODEL_ID}"


def ensure_model_downloaded():
"""Check if model exists, download if not available."""
model_path = Path(MODEL_DIR)

# Check if model directory exists and has content
if not model_path.exists() or not any(model_path.iterdir()):
print("Model not found, downloading...")
try:
# Import download function from download_model
from download_model import download_model
download_model()
except Exception as e:
print(f"Error downloading model: {e}")
raise
else:
print("✓ Model already exists")


class TritonPythonModel:
"""
Triton Python Backend model for TensorRT-LLM inference.

This class handles model initialization, inference requests, and cleanup.
"""

def initialize(self, args):
"""
Initialize the model - called once when Triton loads the model.

Loads tokenizer and initializes TensorRT-LLM with PyTorch backend.
"""
# Ensure model is downloaded before loading
ensure_model_downloaded()

print("Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

print("Initializing TensorRT-LLM...")

plugin_config = PluginConfig.from_dict({
"paged_kv_cache": True, # Efficient memory usage for KV cache
})

build_config = BuildConfig(
plugin_config=plugin_config,
max_input_len=4096,
max_batch_size=128, # Matches Triton max_batch_size in config.pbtxt
)

self.llm = LLM(
model=MODEL_DIR,
build_config=build_config,
tensor_parallel_size=torch.cuda.device_count(),
)
print("✓ Model ready")

def execute(self, requests):
"""
Execute inference on batched requests.

Triton automatically batches requests (up to max_batch_size: 32).
This function processes the batch that Triton provides.
"""
try:
prompts = []
sampling_params_list = []
original_prompts = [] # Store original prompts to strip from output if needed

# Extract data from each request in the batch
for request in requests:
try:
# Get input text - handle batched tensor structures
input_tensor = pb_utils.get_input_tensor_by_name(request, "text_input")
text_array = input_tensor.as_numpy()

# Extract text handling different array structures (batched vs non-batched)
if text_array.ndim == 0:
# Scalar
text = text_array.item()
elif text_array.dtype == object:
# Object dtype array (common for BYTES/STRING with batching)
text = text_array.flat[0] if text_array.size > 0 else text_array.item()
else:
# Regular array - get first element
text = text_array.flat[0] if text_array.size > 0 else text_array.item()

# Decode if bytes, otherwise use as string
if isinstance(text, bytes):
text = text.decode('utf-8')
elif isinstance(text, np.str_):
text = str(text)

# Get optional parameters with defaults
max_tokens = 1024
if pb_utils.get_input_tensor_by_name(request, "max_tokens") is not None:
max_tokens_array = pb_utils.get_input_tensor_by_name(request, "max_tokens").as_numpy()
max_tokens = int(max_tokens_array.item() if max_tokens_array.ndim == 0 else max_tokens_array.flat[0])

temperature = 0.8
if pb_utils.get_input_tensor_by_name(request, "temperature") is not None:
temp_array = pb_utils.get_input_tensor_by_name(request, "temperature").as_numpy()
temperature = float(temp_array.item() if temp_array.ndim == 0 else temp_array.flat[0])

top_p = 0.95
if pb_utils.get_input_tensor_by_name(request, "top_p") is not None:
top_p_array = pb_utils.get_input_tensor_by_name(request, "top_p").as_numpy()
top_p = float(top_p_array.item() if top_p_array.ndim == 0 else top_p_array.flat[0])

# Format prompt using chat template
prompt = self.tokenizer.apply_chat_template(
[{"role": "user", "content": text}],
tokenize=False,
add_generation_prompt=True
)

prompts.append(prompt)
original_prompts.append(prompt) # Store for potential stripping
sampling_params_list.append(SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
))
except Exception as e:
print(f"Error processing request: {e}", flush=True)
import traceback
traceback.print_exc()
# Use default max_tokens instead of 1 to avoid single token output
prompts.append("")
original_prompts.append("")
sampling_params_list.append(SamplingParams(max_tokens=1024))

# Batch inference
if not prompts:
return []

outputs = self.llm.generate(prompts, sampling_params_list)

# Create responses
responses = []
for i, output in enumerate(outputs):
try:
# Extract generated text
generated_text = output.outputs[0].text

# Remove the prompt from generated text if it's included
if original_prompts[i] and original_prompts[i] in generated_text:
generated_text = generated_text.replace(original_prompts[i], "").strip()

responses.append(pb_utils.InferenceResponse(
output_tensors=[pb_utils.Tensor(
"text_output",
np.array([generated_text.encode('utf-8')], dtype=object)
)]
))
except Exception as e:
print(f"Error creating response {i}: {e}", flush=True)
responses.append(pb_utils.InferenceResponse(
output_tensors=[pb_utils.Tensor(
"text_output",
np.array([f"Error: {str(e)}".encode('utf-8')], dtype=object)
)]
))

return responses

except Exception as e:
print(f"Error in execute: {e}", flush=True)
import traceback
traceback.print_exc()
# Return error responses
return [
pb_utils.InferenceResponse(
output_tensors=[pb_utils.Tensor(
"text_output",
np.array([f"Batch error: {str(e)}".encode('utf-8')], dtype=object)
)]
)
for _ in requests
]

def finalize(self):
"""
Cleanup when model is being unloaded.

Shuts down the TensorRT-LLM engine and clears GPU memory.
"""
if hasattr(self, 'llm'):
self.llm.shutdown()
torch.cuda.empty_cache()