diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
index 6139131e478..67d83c3c43f 100644
--- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
+++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
@@ -13,7 +13,8 @@
 
 from ..utils import (get_model_extra_attrs,
                      get_per_request_piecewise_cuda_graph_flag,
-                     get_piecewise_cuda_graph_flag, make_weak_ref)
+                     get_piecewise_cuda_graph_flag, make_weak_ref,
+                     set_piecewise_running)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
 
@@ -27,6 +28,7 @@ def __init__(
         compile_time_num_tokens: Union[int | torch.SymInt],
         capture_num_tokens: list[int],
         exclude_modules_id: list[int],
+        piecewise_runner_num: int,
         graph_pool_handle: tuple[int, int],
         garbage_collect_values: bool = True,
         graph=None,
@@ -38,6 +40,8 @@ def __init__(
 
         self.compile_time_num_tokens = compile_time_num_tokens
         self.capture_num_tokens = capture_num_tokens
+        self.piecewise_runner_num = piecewise_runner_num
+        self.piecewise_runner_idx = 0
         self.exclude_modules = [f"submod_{i}" for i in exclude_modules_id]
         self.graph_pool_handle = graph_pool_handle
         self.enable_inductor = enable_inductor
@@ -90,8 +94,10 @@ def call_module(self, target, args, kwargs):
                 self.graph_pool_handle,
                 compile_fx(submod, args) if self.enable_inductor else submod,
                 self.enable_inductor,
+                self.piecewise_runner_idx == 0,
+                self.piecewise_runner_idx == self.piecewise_runner_num - 1,
             )
-
+            self.piecewise_runner_idx += 1
         return output
 
 
@@ -124,6 +130,8 @@ def __init__(
         graph_pool_handle,
         default_callable: Callable,
         enable_inductor: bool,
+        is_first_runner: bool,
+        is_last_runner: bool,
     ):
         if runtime_num_tokens_idx != None:
             assert isinstance(compile_time_num_tokens, torch.SymInt)
@@ -138,6 +146,8 @@ def __init__(
 
         self.enable_inductor = enable_inductor
         self.entries: dict[int, Entry] = {}
+        self.is_first_runner = is_first_runner
+        self.is_last_runner = is_last_runner
 
         for num_tokens in capture_num_tokens:
             self.entries[num_tokens] = Entry(
@@ -161,6 +171,12 @@ def __call__(self, *args):
                 or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)
 
+        if self.is_first_runner or self.is_last_runner:
+            if self.is_first_runner == self.is_last_runner:
+                set_piecewise_running(False)
+            else:
+                set_piecewise_running(self.is_first_runner)
+
         entry = self.entries[runtime_num_of_token]
 
         if entry.enable_inductor and not entry.compiled:
@@ -267,6 +283,7 @@ def piecewise_optimizer(
         input_num_tokens,
         capture_num_tokens,
         exclude_modules_id,
+        len(set(node_to_graph_id.values())) - len(exclude_modules_id),
        graph_pool_handle,
         max_num_streams=max_num_streams,
     )
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index a7f0a82eaf1..e9181b422f3 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_torch_compiling)
+                     is_piecewise_running, is_torch_compiling)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,13 +76,24 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer
 
 
-@torch.compile
-def compiled_copy_(dst, src):
+def maybe_compile(func):
+
+    def wrapper(*args, **kwargs):
+        if is_piecewise_running():
+            # When running piecewise, skip torch.compile to avoid host overhead in the attention op.
+            return func(*args, **kwargs)
+        return torch.compile(func)(*args, **kwargs)
+
+    return wrapper
+
+
+@maybe_compile
+def maybe_compiled_copy_(dst, src):
     dst.copy_(src)
 
 
-@torch.compile
-def compiled_cat(tensors, dim):
+@maybe_compile
+def maybe_compiled_cat(tensors, dim):
     return torch.cat(tensors, dim)
 
 
@@ -1222,8 +1233,9 @@ def forward_context_default(
         )
 
         k = torch.empty_like(q).view(-1, self.num_heads, self.qk_head_dim)
-        compiled_copy_(k[..., :self.qk_nope_head_dim],
-                       k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
+        maybe_compiled_copy_(
+            k[..., :self.qk_nope_head_dim],
+            k_nope.view(-1, self.num_heads, self.qk_nope_head_dim))
         if self.apply_rotary_emb:
             k[..., self.qk_nope_head_dim:] = k_pe.view(-1, 1,
                                                        self.qk_rope_head_dim)
@@ -1317,7 +1329,7 @@ def forward_context_with_cached_kv(
         full_k_nope = full_k_nope.view(-1, self.num_heads,
                                        self.qk_nope_head_dim)
         full_k_pe = full_k_pe.view(-1, 1, self.qk_rope_head_dim)
-        full_k = compiled_cat(
+        full_k = maybe_compiled_cat(
             (full_k_nope, full_k_pe.expand(-1, self.num_heads, -1)), dim=-1)
         full_k = full_k.view(-1, self.num_heads * self.qk_head_dim)
 
@@ -1412,7 +1424,7 @@ def forward_context_with_chunked_prefill(
             chunked_k_nope = chunked_k_nope.view(-1, self.num_heads,
                                                  self.qk_nope_head_dim)
             chunked_k_pe = chunked_k_pe.view(-1, 1, self.qk_rope_head_dim)
-            chunked_k = compiled_cat(
+            chunked_k = maybe_compiled_cat(
                 (chunked_k_nope, chunked_k_pe.expand(-1, self.num_heads, -1)),
                 dim=-1)
             chunked_k = chunked_k.view(-1, self.num_heads * self.qk_head_dim)
@@ -1470,7 +1482,8 @@ def forward_context_with_chunked_prefill(
 
         k_nope = k_nope.view(-1, self.num_heads, self.qk_nope_head_dim)
         k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
-        k = compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)), dim=-1)
+        k = maybe_compiled_cat((k_nope, k_pe.expand(-1, self.num_heads, -1)),
+                               dim=-1)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
         # copy q_lens to replace kv_lens_runtime
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index de586e3adba..525f9f86f97 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -12,6 +12,7 @@
 from tensorrt_llm.quantization.utils import fp4_utils
 
 is_torch_compiling_flag = False
+is_piecewise_running_flag = False
 
 aux_stream_name_list = [
     'Attention',
@@ -40,6 +41,16 @@ def is_torch_compiling() -> bool:
     return is_torch_compiling_flag
 
 
+def set_piecewise_running(enable: bool):
+    global is_piecewise_running_flag
+    is_piecewise_running_flag = enable
+
+
+def is_piecewise_running() -> bool:
+    global is_piecewise_running_flag
+    return is_piecewise_running_flag
+
+
 _global_attrs = threading.local()
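
A minimal standalone sketch (not part of the patch; the helper name on_runner_call is hypothetical) of how the new is_first_runner / is_last_runner flags drive set_piecewise_running() in PiecewiseRunner.__call__: the first piecewise runner of a forward pass turns the flag on, the last one turns it off, and a runner that is both first and last leaves it off, so maybe_compile in attention.py falls back to eager execution only while control is between piecewise CUDA graph segments.

is_piecewise_running_flag = False


def set_piecewise_running(enable: bool):
    global is_piecewise_running_flag
    is_piecewise_running_flag = enable


def on_runner_call(is_first_runner: bool, is_last_runner: bool):
    # Mirrors the branch added to PiecewiseRunner.__call__ above.
    if is_first_runner or is_last_runner:
        if is_first_runner == is_last_runner:
            # Single runner: nothing executes between piecewise segments.
            set_piecewise_running(False)
        else:
            # On at the first runner, off at the last.
            set_piecewise_running(is_first_runner)


if __name__ == "__main__":
    # Three runners in one forward pass: the flag is True only between them.
    for first, last in [(True, False), (False, False), (False, True)]:
        on_runner_call(first, last)
        print(first, last, is_piecewise_running_flag)
    # Prints: True False True / False False True / False True False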