-
Notifications
You must be signed in to change notification settings - Fork 1.9k
[TRTC-1921][feat] Add trtllm-configure CLI tool and constraints/profiles schemas #9160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
anish-shanbhag
wants to merge
8
commits into
NVIDIA:main
Choose a base branch
from
anish-shanbhag:ashanbhag/trtllm-configure-cli
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
e3fe37e
[TRTC-1921][feat] Add trtllm-configure CLI tool and scenario/profile …
anish-shanbhag 4788dab
Add max_batch_size to mock config
anish-shanbhag 478ffb2
Address review comments
anish-shanbhag d7cb3ff
Address review comments and refactor to use a profile registry
anish-shanbhag 01fa195
Update test
anish-shanbhag 841807a
Remove debug line
anish-shanbhag 8527251
Change to absolute imports
anish-shanbhag 906c665
Fix docstring
anish-shanbhag File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| from tensorrt_llm.configure.cli import TRTLLMConfigure | ||
|
|
||
|
|
||
def main() -> None:
    """Instantiate the trtllm-configure CLI object and run it."""
    cli = TRTLLMConfigure()
    cli.run()


# Only run the CLI when this module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| from enum import StrEnum | ||
| from pathlib import Path | ||
| from typing import Optional | ||
|
|
||
| import yaml | ||
| from pydantic import AliasChoices, Field, model_validator | ||
| from pydantic_settings import BaseSettings, CliSubCommand, SettingsConfigDict, get_subcommand | ||
|
|
||
| from tensorrt_llm.configure.constraints import BaseConstraints, BenchmarkConstraints | ||
| from tensorrt_llm.configure.profile import PROFILE_REGISTRY | ||
|
|
||
|
|
||
def generate_subcommand_description(constraints_cls: type[BaseConstraints]) -> str:
    """Build the help text for the subcommand backed by ``constraints_cls``.

    The text is made of the constraints class's own CLI description, a short
    explanation of the ``--profile`` flag, and one bullet line per profile
    registered for the class in ``PROFILE_REGISTRY``.
    """
    registered_profiles = PROFILE_REGISTRY[constraints_cls]
    header = constraints_cls._get_cli_description() + "\n\n"
    profile_intro = (
        "The --profile flag can be used to specify which profile to use. A profile defines the strategy used to "
        "generate the optimized config. The available profiles are:\n\n"
    )

    # One "- name: description" line per registered profile, in registration order
    bullets = []
    for profile_cls in registered_profiles:
        info = profile_cls._get_metadata()
        bullets.append(f"- {info.cli_name}: {info.description}\n")

    return header + profile_intro + "".join(bullets)
|
|
||
|
|
||
def create_subcommand(constraints_cls: type[BaseConstraints]) -> CliSubCommand:
    """Create a Pydantic CLI subcommand for the given constraints class.

    The generated model subclasses ``constraints_cls`` and adds the ``output``
    and ``profile`` options that are common to every subcommand.

    Args:
        constraints_cls: Constraints class whose fields become the subcommand's options.

    Returns:
        ``CliSubCommand`` parameterized with the generated subcommand model.

    Raises:
        ValueError: If none of the profiles registered for ``constraints_cls``
            is marked as the default profile.
    """
    profiles = PROFILE_REGISTRY[constraints_cls]
    default_profile = next(
        (profile for profile in profiles if profile._get_metadata().is_default),
        None,
    )
    if default_profile is None:
        # Without this check a bare next() would surface as an opaque StopIteration.
        raise ValueError(
            f"No default profile is registered for constraints class {constraints_cls.__name__}."
        )

    ProfileEnum = StrEnum(
        "ProfileEnum",
        [profile._get_metadata().cli_name for profile in profiles],
    )

    class SubCommand(constraints_cls):
        # The docstring is shown as the help message for the subcommand
        __doc__ = generate_subcommand_description(constraints_cls)

        # Common options for all subcommands
        output: Optional[Path] = Field(
            default=None,
            description="YAML file path where the optimized config will be written.",
            validation_alias=AliasChoices("output", "o"),
        )

        profile: ProfileEnum = Field(
            default=default_profile._get_metadata().cli_name,
            description="Name of the profile to use, which defines the strategy used to generate the optimized config. "
            "See above for a description of the available profiles.",
        )

        @model_validator(mode="after")
        def validate_output(self) -> "SubCommand":
            """Verify that the output path has a YAML extension; warn if the file will be overwritten."""
            if self.output is not None:
                # Both common YAML extensions are accepted, regardless of case.
                if self.output.suffix.lower() not in (".yaml", ".yml"):
                    raise ValueError(f"Output file must be a YAML file. Got '{self.output}'.")
                if self.output.exists():
                    print(f"Output file '{self.output}' already exists, will overwrite it.")
            return self

        def run(self) -> None:
            """Generate the config with the selected profile, print it, and write it to ``output`` if given."""
            # Dispatch to the appropriate profile
            profile_cls = next(
                profile
                for profile in PROFILE_REGISTRY[constraints_cls]
                if profile._get_metadata().cli_name == self.profile
            )
            config = profile_cls().get_config(self)

            # Serialize once so the printed and written YAML are guaranteed to match
            config_yaml = yaml.safe_dump(config)
            print(f"Found optimized config: \n\n{config_yaml}")

            if self.output is None:
                print(
                    "No output file specified. To write the optimized config to a file, use the --output / -o flag."
                )
                return

            with open(self.output, "w", encoding="utf-8") as f:
                f.write(config_yaml)
            print(f"Optimized config written to {self.output}")
            print("To serve the model with optimized settings, run the following command:\n")
            print(f"trtllm-serve {self.model} --config {self.output}")

    return CliSubCommand[SubCommand]
|
|
||
|
|
||
# Subcommand type generated from BenchmarkConstraints; used as a field annotation below
BenchmarkSubCommand = create_subcommand(BenchmarkConstraints)
# TODO: add support for throughput/latency subcommand
# ThroughputLatencySubCommand = create_subcommand(ThroughputLatencyConstraints)


class TRTLLMConfigure(BaseSettings):
    # The docstring below is used to generate the CLI help message
    """The trtllm-configure CLI tool allows you to optimize the configuration of TensorRT LLM for your specific
    inference scenario.
    """  # noqa: D205

    model_config = SettingsConfigDict(
        cli_prog_name="trtllm-configure",
        cli_parse_args=True,
        cli_enforce_required=True,  # Make required fields enforced at CLI level
        cli_implicit_flags=True,  # Boolean fields will be exposed as e.g. --flag and --no-flag
        cli_avoid_json=True,  # Do not expose JSON string options for nested models
    )

    benchmark: BenchmarkSubCommand = Field(
        description=BenchmarkConstraints._get_cli_description(),
    )
    # TODO: add support for throughput/latency SLA subcommand
    # throughput_latency: CliSubCommand[ThroughputLatencySubCommand] = Field(
    #     description=ThroughputLatencySubCommand.__doc__
    # )

    def run(self) -> None:
        """Main entrypoint for the trtllm-configure CLI tool."""
        # Look up whichever subcommand was selected and hand control to it
        get_subcommand(self).run()
|
|
||
|
|
||
def main() -> None:
    """Instantiate ``TRTLLMConfigure`` and run the selected subcommand."""
    configure_cli = TRTLLMConfigure()
    configure_cli.run()


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| from abc import ABC, abstractmethod | ||
| from enum import StrEnum | ||
| from typing import Optional | ||
|
|
||
| from pydantic import BaseModel, Field, NonNegativeInt, PositiveFloat, PositiveInt, model_validator | ||
|
|
||
|
|
||
| class GPU(StrEnum): | ||
| GB200 = "GB200" | ||
| H200_SXM = "H200_SXM" | ||
|
|
||
|
|
||
class BaseConstraints(BaseModel, ABC):
    """Base class for all constraints containing common fields.

    A set of constraints fully defines the requirements for a specific inference workload
    and the goals for optimization (e.g. SLA targets).
    """

    # Fields shared by every constraints class
    model: str = Field(
        description="HuggingFace ID of the model being deployed",
    )
    gpu: GPU = Field(
        description="GPU SKU used in the deployment",
    )
    num_gpus: PositiveInt = Field(
        description="Number of GPUs available in the deployment",
    )

    @classmethod
    @abstractmethod
    def _get_cli_description(cls) -> str:
        """Return the text that describes these constraints in CLI help messages."""
|
|
||
|
|
||
class BenchmarkConstraints(BaseConstraints):
    # Constraints describing a benchmark workload: sequence lengths, concurrency and parallelism.
    isl: NonNegativeInt = Field(
        description="Target input sequence length",
    )
    osl: NonNegativeInt = Field(
        description="Target output sequence length",
    )
    concurrency: PositiveInt = Field(
        description="Target number of concurrent requests",
    )
    # TODO: make this optional and add logic to choose best parallelization mapping automatically
    tp_size: PositiveInt = Field(
        description="Specific tensor parallel size that should be used",
    )

    @classmethod
    def _get_cli_description(cls) -> str:
        """Return the one-line summary shown for the benchmark subcommand."""
        summary = "Optimize TensorRT LLM for a benchmark workload with a specific number of concurrent requests."
        return summary
|
|
||
|
|
||
class ThroughputLatencyConstraints(BaseConstraints):
    # SLA targets; each one is optional, but the validator below requires at least one to be set.
    tps_per_gpu: Optional[PositiveFloat] = Field(
        default=None,
        description="Target minimum throughput per GPU in tokens per second",
    )
    tps_per_user: Optional[PositiveFloat] = Field(
        default=None,
        description="Target minimum throughput per user in tokens per second.",
    )
    ttft: Optional[PositiveFloat] = Field(
        default=None,
        description="Target maximum time to first token in seconds.",
    )

    @classmethod
    def _get_cli_description(cls) -> str:
        """Return the one-line summary shown for the throughput/latency subcommand."""
        return "Optimize TensorRT LLM to meet a throughput and/or latency SLA."

    @model_validator(mode="after")
    def validate_has_at_least_one_constraint(self) -> "ThroughputLatencyConstraints":
        """Reject instances in which none of the three targets is set."""
        if self.tps_per_gpu or self.tps_per_user or self.ttft:
            return self

        raise ValueError(
            "At least one of target throughput per GPU, target throughput per user, or target time to first token "
            "must be specified."
        )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| from abc import ABC, abstractmethod | ||
| from collections import defaultdict | ||
| from typing import Any | ||
|
|
||
| from pydantic import BaseModel, Field | ||
|
|
||
| from tensorrt_llm.configure.constraints import BaseConstraints, BenchmarkConstraints | ||
|
|
||
|
|
||
class ProfileMetadata(BaseModel):
    """Metadata about a profile."""

    # Key under which the profile is grouped in the profile registry
    constraints_cls: type[BaseConstraints] = Field(
        description="The constraints class that this profile is compatible with.",
    )
    # Value users pass to the --profile flag
    cli_name: str = Field(
        description="The name of the profile as it will be exposed in the CLI.",
    )
    description: str = Field(
        description="A description of the profile which will be shown in the CLI help message.",
    )
    is_default: bool = Field(
        default=False,
        description=(
            "Whether this profile is the default profile for the constraints class. There can only be one "
            "default profile per constraints class."
        ),
    )
|
|
||
|
|
||
class BaseProfile(ABC):
    """Base class for all profiles.

    A profile defines a particular strategy used to find an optimized config for a given set of constraints
    (e.g. database lookup, heuristics, etc.)

    Each profile is compatible with a specific type of constraints.

    Concrete subclasses are added to ``PROFILE_REGISTRY`` automatically when they are defined. Subclasses
    that leave ``_get_metadata`` or ``get_config`` abstract are treated as intermediate base classes and
    are not registered.
    """

    @classmethod
    @abstractmethod
    def _get_metadata(cls) -> "ProfileMetadata":
        """Get the metadata associated with this profile."""

    @abstractmethod
    def get_config(self, constraints: "BaseConstraints") -> dict[str, Any]:
        """Retrieve or generate the optimal config for the given constraints."""

    def __init_subclass__(cls, **kwargs):
        """Validate and register each concrete profile subclass as soon as it is defined.

        Raises:
            ValueError: If the new profile is marked as default and another default profile
                is already registered for the same constraints class.
        """
        super().__init_subclass__(**kwargs)

        # This hook runs before ABCMeta has computed __abstractmethods__ for the new class, so the
        # abstract markers are checked on the methods themselves. Calling the abstract
        # _get_metadata stub would return None and break the validation below.
        for method_name in ("_get_metadata", "get_config"):
            if getattr(getattr(cls, method_name), "__isabstractmethod__", False):
                return

        # Validate that there is only one default profile per constraints class
        metadata = cls._get_metadata()
        registered = PROFILE_REGISTRY[metadata.constraints_cls]
        if metadata.is_default:
            for other_profile in registered:
                if other_profile._get_metadata().is_default:
                    raise ValueError(
                        f"Multiple default profiles found for constraints class {metadata.constraints_cls.__name__}: "
                        f"{other_profile.__name__} and {cls.__name__}"
                    )

        # Add this class to the profile registry
        registered.append(cls)


# Maps constraints classes to the list of profiles compatible with those constraints
PROFILE_REGISTRY: "defaultdict[type[BaseConstraints], list[type[BaseProfile]]]" = defaultdict(list)
|
|
||
|
|
||
class InferenceMaxProfile(BaseProfile):
    """Default profile for ``BenchmarkConstraints``, exposed in the CLI as ``inferencemax``."""

    @classmethod
    def _get_metadata(cls) -> ProfileMetadata:
        """Return the static metadata describing this profile."""
        profile_description = (
            "Retrieve optimized settings from a database of configs used for SemiAnalysis InferenceMax "
            "benchmarks."
        )
        return ProfileMetadata(
            constraints_cls=BenchmarkConstraints,
            cli_name="inferencemax",
            description=profile_description,
            is_default=True,
        )

    def get_config(self, constraints: BenchmarkConstraints) -> dict[str, Any]:
        """Return the config for ``constraints``; currently a placeholder that yields an empty dict."""
        # TODO: add logic to retrieve optimal config from database
        return {}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| from pathlib import Path | ||
| from unittest.mock import patch | ||
|
|
||
| import yaml | ||
|
|
||
| from tensorrt_llm.configure.cli import BenchmarkSubCommand, TRTLLMConfigure | ||
| from tensorrt_llm.configure.profile import InferenceMaxProfile | ||
|
|
||
|
|
||
def test_trtllm_configure_subcommand_basic(tmp_path: Path):
    """Running the benchmark subcommand writes the profile's config to the requested YAML file."""
    # NOTE(review): BenchmarkSubCommand is the CliSubCommand[...] annotation returned by
    # create_subcommand, not the model class itself - confirm it can be instantiated directly.
    expected_config = {
        "max_batch_size": 1,
        "kv_cache_free_gpu_memory_fraction": 0.9,
    }
    config_file = tmp_path / "test_config.yaml"

    benchmark_cmd = BenchmarkSubCommand(
        model="meta-llama/Llama-3.1-8B",
        gpu="H200_SXM",
        num_gpus=1,
        isl=1000,
        osl=2000,
        concurrency=64,
        tp_size=1,
        output=config_file,
        profile="inferencemax",
    )
    cli = TRTLLMConfigure(benchmark=benchmark_cmd)

    # Mock get_config so the profile returns the expected config
    with patch.object(InferenceMaxProfile, "get_config", return_value=expected_config):
        cli.run()

    assert config_file.exists()
    with open(config_file, "r") as f:
        written_config = yaml.safe_load(f)

    # Every expected key must be present in the written file with the expected value
    for key, value in expected_config.items():
        assert key in written_config
        assert written_config[key] == value
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.