97 changes: 59 additions & 38 deletions tensorrt_llm/_torch/autotuner.py
@@ -99,7 +99,7 @@ class TuningConfig:
constraint_specs: Tuple[ConstraintSpec, ...] = ()
tune_max_num_tokens: int = None
inputs_pre_hook: Callable = None
use_cuda_graph: bool = False
use_cuda_graph: bool = True


@dataclass(unsafe_hash=True)
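With this default flipped, every tuning config now uses CUDA-graph profiling unless a runner opts out. A minimal sketch of the two call sites, assuming the remaining TuningConfig fields keep their defaults:

from tensorrt_llm._torch.autotuner import TuningConfig

# CUDA-graph profiling is now the default; no explicit flag needed.
graph_profiled = TuningConfig()

# Runners whose kernels cannot be captured in a CUDA graph must now opt out explicitly.
eager_profiled = TuningConfig(use_cuda_graph=False)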
@@ -526,7 +526,7 @@ class AutoTuner:
_CUDA_GRAPH_DELAY_MICRO_SECS = 100
_instance = None

def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
self.repeat = repeat
self.warmup = warmup
self.stream_delay_micro_secs = stream_delay_micro_secs
@@ -698,23 +698,25 @@ def choose_one(
})

input_shapes = tuple(self._get_input_sizes(inputs))
is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
custom_op, runners, input_shapes, tuning_config)

# Early return if not in tuning mode: use the cached result if found, otherwise the fallback one
if not self.is_tuning_mode:
is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
custom_op, runners, input_shapes, tuning_config)
best_runner = runners[best_runner_id]
# TODO: check here that the stored runner and tactic can implement this shape.
# Do not directly try (runner, tactic) here, or it will significantly hurt inference performance.

# Record the cache miss config.
# Expect no cache miss in inference. Thus, any cache miss should be recorded.
# Log the cache miss. Expect no cache miss in inference.
if not is_cache_hit:
logger.warning_once(
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
key=(custom_op, "warning_autotuning_cache_miss_fallback"))

return (best_runner, best_tactic)

# If it's tuning mode and cache hit, return the best runner and tactic to avoid redundant profiling.
if self.is_tuning_mode and is_cache_hit:
return (runners[best_runner_id], best_tactic)

assert len(runners) > 0, "At least one runner is required"
assert all([isinstance(r, TunableRunner) for r in runners]), \
"All Given runners must be subclass of TunableRunner"
@@ -881,43 +883,62 @@ def _profile_single_kernel(
are used to ensure accurate timing.
"""
stream = torch.cuda.current_stream()
graph = torch.cuda.CUDAGraph()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.cuda.stream(stream):
# warm up, no timing
for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)

if use_cuda_graph:
with torch.cuda.graph(graph):
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)
# Profile with only a few repeats first; if that estimate already exceeds the threshold below, skip the full-repeat profiling.
profile_fewer_repeat = 2
short_profile_threshold_ms = 1

avg_time = float('inf')

def pure_profile(stream: torch.cuda.Stream, repeat: int):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
graph = torch.cuda.CUDAGraph()

with torch.cuda.stream(stream):
if use_cuda_graph:
with torch.cuda.graph(graph):
for _ in range(repeat):
runner(inputs, tactic=tactic, **kwargs)

stream.synchronize()

# Delay the profiled kernel launch to eliminate the effects of host-time overhead in profiling.
# TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops).
# Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
if use_cuda_graph:
delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
else:
delay_kernel(self.stream_delay_micro_secs, stream)

stream.synchronize()
start.record()

# Delay the profiled kernel launch to eliminate the effects of host-time overhead in profiling.
# TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops).
# Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
if use_cuda_graph:
delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
else:
delay_kernel(self.stream_delay_micro_secs, stream)
if use_cuda_graph:
graph.replay()
else:
for _ in range(repeat):
runner(inputs, tactic=tactic, **kwargs)

start.record()
end.record()
stream.synchronize()

if use_cuda_graph:
graph.replay()
else:
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)
return start.elapsed_time(end) / repeat

end.record()
for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)

stream.synchronize()
fewer_repeat_avg_time = pure_profile(stream, profile_fewer_repeat)

avg_time = start.elapsed_time(end) / self.repeat
disable_short_profile = os.environ.get(
"TLLM_AUTOTUNER_DISABLE_SHORT_PROFILE", "0") == "1"
if fewer_repeat_avg_time > short_profile_threshold_ms and not disable_short_profile:
print(
f"[Autotuner] The short-profile estimate exceeds {short_profile_threshold_ms}ms; using it directly to avoid redundant profiling."
)
# Use the short-profile estimate directly to avoid redundant profiling.
avg_time = fewer_repeat_avg_time
else:
# Profile with the full repeat count to get a precise measurement.
avg_time = pure_profile(stream, self.repeat)

shapes = self._get_input_sizes(inputs)
logger.debug(
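The new two-stage heuristic can be summarized as a standalone sketch; CUDA-graph capture and the delay kernel are omitted, and kernel stands in for runner(inputs, tactic=tactic, **kwargs):

import torch

def profile_kernel(kernel, warmup=2, few_repeat=2, full_repeat=10, threshold_ms=1.0):
    """Sketch of the short-profile heuristic: cheap estimate first, full repeats only for fast kernels."""
    stream = torch.cuda.current_stream()

    def timed_run(repeat):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        with torch.cuda.stream(stream):
            start.record()
            for _ in range(repeat):
                kernel()
            end.record()
        stream.synchronize()
        return start.elapsed_time(end) / repeat  # milliseconds per call

    # Warm up outside the timed region.
    for _ in range(warmup):
        kernel()

    estimate = timed_run(few_repeat)
    if estimate > threshold_ms:
        # The kernel is slow enough that the cheap estimate suffices.
        return estimate
    return timed_run(full_repeat)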
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py
@@ -37,7 +37,6 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
use_cuda_graph=True,
)

def __init__(self, alpha: float, output_dtype: torch.dtype):