40 changes: 24 additions & 16 deletions docs/source/developer-guide/api-change.md
@@ -34,7 +34,7 @@ TensorRT LLM classifies APIs into two categories:
All API schemas are:
- Stored as YAML files in the codebase
- Protected by unit tests in `tests/unittest/api_stability/`
- Automatically validated to ensure consistency

## API Change Principles

@@ -44,22 +44,26 @@ All API schemas are:

Argument names should describe what the argument represents, not how it is used internally.

✅ **Good**: `max_new_tokens` (clear meaning)

❌ **Bad**: `num` (ambiguous)

**Reflect Argument Type and Granularity**

- For **boolean** knobs, prefix with a verb such as `enable_`.

Examples: `enable_cache`, `enable_flash_attention`

- For **numerical threshold** knobs, suffix with `_limit`, `_size`, `_count`, `_len`, or `_ratio` (see the sketch below)

Examples: `max_seq_len`, `prefill_batch_size`
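
A minimal sketch of these naming conventions (the knob names here are illustrative, not actual TensorRT LLM arguments):

```python
from dataclasses import dataclass


@dataclass
class AttentionKnobs:
    # Boolean knobs: a verb prefix makes the on/off semantics obvious
    enable_cache: bool = True
    enable_flash_attention: bool = False
    # Numerical threshold knobs: the suffix states the unit or granularity
    max_seq_len: int = 4096
    prefill_batch_size: int = 8
```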

**Avoid Redundant Prefixes**

Example (in `MoeConfig`):

✅ **Good**: `backend`

❌ **Bad**: `moe_backend` (redundant since it's already in `MoeConfig`)

**Use Specific Names for Narrow Scenarios**
@@ -68,7 +72,8 @@ When adding knobs for specific use cases, make the name convey the restriction c

Example (argument to the LLM class):

✅ **Good**: `rope_scaling_factor` → clearly indicates it's for RoPE

❌ **Bad**: `scaling_factor` → too generic and prone to misuse

### 2. Hierarchical Configuration
@@ -77,13 +82,16 @@ Organize complex or hierarchical arguments into **dedicated configuration datacl

**Guidelines**

- Use the `XxxConfig` suffix consistently

Examples: `ModelConfig`, `ParallelConfig`, `MoeConfig`

- **Reflect conceptual hierarchy**

The dataclass name should represent a coherent functional unit, not an arbitrary grouping

- **Avoid over-nesting**

Use only one level of configuration hierarchy whenever possible (e.g., `LlmArgs → ParallelConfig`) to balance readability and modularity (see the sketch below)
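
A minimal sketch of this one-level nesting (field names are illustrative and do not mirror the real `LlmArgs`/`ParallelConfig` definitions):

```python
from dataclasses import dataclass, field


@dataclass
class ParallelConfig:
    tp_size: int = 1
    pp_size: int = 1


@dataclass
class LlmArgs:
    model: str = ""
    # One level of nesting: all parallelism knobs live in their own config
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)


args = LlmArgs(model="my-model",
               parallel_config=ParallelConfig(tp_size=2, pp_size=2))
```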

### 3. Prefer `LlmArgs` Over Environment Variables
@@ -154,15 +162,15 @@ garbage_collection_gen0_threshold: int = Field(

Add the field to the appropriate schema file:

- **Non-committed arguments**: `tests/unittest/api_stability/references/llm_args.yaml`
- **Non-committed arguments**: `tests/unittest/api_stability/references/llm.yaml`
```yaml
garbage_collection_gen0_threshold:
type: int
default: 20000
status: beta # Must match the status in code
```

- **Committed arguments**: `tests/unittest/api_stability/references_committed/llm_args.yaml`
- **Committed arguments**: `tests/unittest/api_stability/references_committed/llm.yaml`
```yaml
garbage_collection_gen0_threshold:
type: int
@@ -196,16 +204,16 @@ For non-committed APIs, use the `@set_api_status` decorator:
```python
@set_api_status("beta")
def generate_with_streaming(
    self,
    prompts: List[str],
**kwargs
) -> Iterator[GenerationOutput]:
"""Generate text with streaming output.

Args:
prompts: Input prompts for generation
**kwargs: Additional generation parameters

Returns:
Iterator of generation outputs
"""
9 changes: 9 additions & 0 deletions tensorrt_llm/llmapi/llm_args.py
@@ -326,6 +326,7 @@ class _ParallelConfig(StrictBaseModel):
moe_tp_size: int = -1
moe_ep_size: int = -1
cp_config: dict = Field(default_factory=dict)
pp_partition: Optional[List[int]] = Field(default=None)
enable_attention_dp: bool = False
enable_lm_head_tp_in_adp: bool = False

@@ -372,6 +373,7 @@ def to_mapping(self) -> Mapping:
gpus_per_node=self.gpus_per_node,
tp_size=self.tp_size,
pp_size=self.pp_size,
pp_partition=self.pp_partition,
cp_size=self.cp_size,
cp_config=self.cp_config,
enable_attention_dp=self.enable_attention_dp,
@@ -1587,6 +1589,12 @@ class BaseLlmArgs(StrictBaseModel):
description="Enable LM head TP in attention dp.",
status="prototype")

pp_partition: Optional[List[int]] = Field(
    default=None,
    description=
    "Pipeline parallel partition: the number of layers assigned to each pipeline rank.",
    status="prototype")

cp_config: Optional[dict] = Field(default_factory=dict,
description="Context parallel config.",
status="prototype")
@@ -1843,6 +1851,7 @@ def validate_parallel_config(self):
moe_ep_size=self.moe_expert_parallel_size,
enable_attention_dp=self.enable_attention_dp,
enable_lm_head_tp_in_adp=self.enable_lm_head_tp_in_adp,
pp_partition=self.pp_partition,
cp_config=self.cp_config)
return self

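
Taken together with the `Mapping` changes below, the new knob is reachable from the LLM API roughly as follows; this is a hedged sketch in which the model path and partition values are made up, and the partition must sum to the model's hidden layer count:

```python
from tensorrt_llm import LLM

# Hypothetical 30-layer model split unevenly across two pipeline ranks;
# sum(pp_partition) must equal the model's number of layers.
llm = LLM("path/to/model",            # illustrative model path
          pipeline_parallel_size=2,
          pp_partition=[18, 12])
```
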
23 changes: 20 additions & 3 deletions tensorrt_llm/mapping.py
@@ -49,6 +49,7 @@ def __init__(
cp_config=None,
tp_size=1,
pp_size=1,
pp_partition=None,
moe_cluster_size=-1, # -1 means no moe
moe_tp_size=-1, # -1 means no moe
moe_ep_size=-1, # -1 means no moe
@@ -126,6 +127,7 @@ def __init__(
self.cp_size = cp_size
self.cp_config = cp_config if cp_config is not None else {}
self.pp_size = pp_size
self.pp_partition = pp_partition
self.moe_tp_size = moe_tp_size
self.moe_ep_size = moe_ep_size
self.moe_cluster_size = moe_cluster_size
@@ -156,6 +158,7 @@ def __eq__(self, other):
and self.tp_size == other.tp_size
and self.moe_cluster_size == other.moe_cluster_size
and self.pp_size == other.pp_size
and self.pp_partition == other.pp_partition
and self.moe_tp_size == other.moe_tp_size
and self.moe_ep_size == other.moe_ep_size
and self.attn_tp_size == other.attn_tp_size
@@ -177,6 +180,7 @@ def __hash__(self):
self.attn_cp_size,
# note: we do not allow updating cp_config after initialization
tuple(sorted(self.cp_config.items())),
tuple(self.pp_partition) if self.pp_partition is not None else (),
))

@property
@@ -299,9 +303,20 @@ def has_moe_ep(self):
return self.moe_ep_size > 1

def pp_layers(self, num_layers: int) -> List[int]:
# If num_layers % pp_size = n != 0, first n ranks get one extra layer
return torch.tensor_split(torch.arange(num_layers),
self.pp_size)[self.pp_rank].tolist()
if self.pp_partition is not None:
if len(self.pp_partition) != self.pp_size:
raise ValueError(
f"{len(self.pp_partition)=} does not match {self.pp_size=}."
)
if sum(self.pp_partition) != num_layers:
raise ValueError(
f"{sum(self.pp_partition)=} does not match {num_layers=}.")
return torch.arange(num_layers).split(
self.pp_partition)[self.pp_rank].tolist()
else:
# If num_layers % pp_size = n != 0, first n ranks get one extra layer
return torch.tensor_split(torch.arange(num_layers),
self.pp_size)[self.pp_rank].tolist()

def ep_experts(self, num_experts: int) -> List[int]:
assert self.cp_size == 1
@@ -446,6 +461,7 @@ def __init__(
cp_config=None,
tp_size=1,
pp_size=1,
pp_partition=None,
moe_cluster_size=-1, # -1 means no moe
moe_tp_size=-1, # -1 means no moe
moe_ep_size=-1, # -1 means no moe
@@ -460,6 +476,7 @@ def __init__(
cp_config=cp_config,
tp_size=tp_size,
pp_size=pp_size,
pp_partition=pp_partition,
moe_cluster_size=moe_cluster_size,
moe_tp_size=moe_tp_size,
moe_ep_size=moe_ep_size,
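
For reference, a standalone sketch of the two layer-assignment paths in `Mapping.pp_layers`, run outside the class with values chosen only for illustration:

```python
import torch

num_layers, pp_size = 30, 4

# Default path: near-even split; the first (num_layers % pp_size) ranks
# get one extra layer -> chunk sizes [8, 8, 7, 7]
default_split = [t.tolist() for t in
                 torch.tensor_split(torch.arange(num_layers), pp_size)]

# Explicit partition: each rank gets exactly the requested layer count
pp_partition = [8, 8, 8, 6]
custom_split = [t.tolist() for t in
                torch.arange(num_layers).split(pp_partition)]

print(len(default_split[0]), len(custom_split[-1]))  # 8 6
```
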
8 changes: 8 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1290,6 +1290,13 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
torch_compile):
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")

if pp_size > 1 and mtp_nextn > 0:
num_hidden_layers = 30
pp_partition = [num_hidden_layers // pp_size + 1] * pp_size
pp_partition[-1] = num_hidden_layers - sum(pp_partition[:-1])
else:
pp_partition = None
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
@@ -1307,6 +1314,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
with LLM(self.MODEL_PATH,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
pp_partition=pp_partition,
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
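
A worked example of the partition computed in this test, assuming `pp_size = 2` (the actual parametrization lies outside this hunk):

```python
num_hidden_layers = 30
pp_size = 2

pp_partition = [num_hidden_layers // pp_size + 1] * pp_size    # [16, 16]
pp_partition[-1] = num_hidden_layers - sum(pp_partition[:-1])  # -> [16, 14]
# Earlier ranks take one extra layer, leaving the last rank lighter,
# presumably to make room for the MTP module when mtp_nextn > 0.
```
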
4 changes: 4 additions & 0 deletions tests/unittest/api_stability/references/llm.yaml
@@ -18,6 +18,10 @@ methods:
annotation: Optional[dict]
default: null
status: prototype
pp_partition:
annotation: Optional[List[int]]
default: null
status: prototype
# Stats
iter_stats_max_iterations:
annotation: Optional[int]