From 98a09a08776f4edd653fa99f9769e37c74a604db Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 7 Jul 2025 20:32:18 +0800 Subject: [PATCH 1/5] cuda graph pool with LRU --- lightllm/common/basemodel/basemodel.py | 8 +++-- lightllm/common/basemodel/cuda_graph.py | 34 ++++++++++++++++---- lightllm/server/api_cli.py | 7 ++++ lightllm/server/core/objs/start_args_type.py | 1 + 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index c7760e995..433ed2d63 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -346,6 +346,8 @@ def _decode( ) -> ModelOutput: if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) + assert find_graph_batch_size is not None + padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) copy_kv_index_to_req( @@ -356,7 +358,7 @@ def _decode( ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) - if self.graph.need_capture(find_graph_batch_size): + if self.graph.get_graph(find_graph_batch_size) is None: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode( self._token_forward, padded_model_input.input_ids, infer_state @@ -497,6 +499,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) + assert find_graph_batch_size is not None + padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) infer_state0 = self._create_inferstate(padded_model_input0, 0) @@ -516,7 +520,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) - if self.graph.need_capture(find_graph_batch_size): + if self.graph.get_graph(find_graph_batch_size) is None: infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index dc615eb46..7805ae6a9 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -2,6 +2,7 @@ import torch import copy import bisect +from collections import OrderedDict from typing import Optional from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args @@ -9,7 +10,6 @@ from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput from .infer_struct import InferStateInfo - logger = init_logger(__name__) @@ -17,12 +17,14 @@ class CudaGraph: # CudaGraph forward pass for the decoding stage. 
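The core of this first commit is turning the graph store below into an LRU-managed pool: captured graphs are keyed by padded batch size in an OrderedDict, a lookup re-inserts the key so it becomes the most recently used entry, and the oldest entry is evicted once the pool reaches --max_graph_pool_size. The standalone sketch below isolates just that bookkeeping; the LRUGraphPool name and the string payloads are illustrative, while the real entries are the (graph_obj, input_ids, infer_state, model_output) tuples stored by CudaGraph.

```python
from collections import OrderedDict

class LRUGraphPool:
    """Minimal sketch of the LRU bookkeeping behind the CUDA graph pool."""

    def __init__(self, max_size: int = 16):
        self.max_size = max_size
        self.pool = OrderedDict()  # padded batch size -> captured graph entry

    def get(self, batch_size):
        # On a hit, pop and re-insert so the entry becomes most recently used.
        if batch_size in self.pool:
            entry = self.pool.pop(batch_size)
            self.pool[batch_size] = entry
            return entry
        return None

    def put(self, batch_size, entry):
        # Drop the least recently used graph before storing a new capture.
        if len(self.pool) >= self.max_size:
            self.pool.popitem(last=False)
        self.pool[batch_size] = entry

pool = LRUGraphPool(max_size=2)
pool.put(8, "graph-8")
pool.put(16, "graph-16")
pool.get(8)               # touch 8: it is now the most recently used entry
pool.put(32, "graph-32")  # evicts 16, the least recently used entry
assert list(pool.pool) == [8, 32]
```

In the diff this logic is spread across get_graph (later renamed back to need_capture) and evict_oldest_graph, with eviction triggered from _capture_decode right before a new graph is captured.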
def __init__(self, max_batch_size=8, max_len_in_batch=8192): - self.graph = {} + self.graph = OrderedDict() # for LRU + self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None self.max_batch_size = max_batch_size self.graph_max_len_in_batch = max_len_in_batch self.args = get_env_start_args() self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap + self.max_graph_pool_size = self.args.max_graph_pool_size # gen cuda graph batch_sizes # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size] @@ -47,12 +49,22 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch - def need_capture(self, batch_size): - find_batch_size = self.find_closest_graph_batch_size(batch_size) - if find_batch_size is not None: - return find_batch_size not in self.graph + def get_graph(self, batch_size): + # we assume batch_size is already dealed with find_closest_graph_batch_size outside + # If the graph already exists, dequeue it and then enqueue it, + # thus move it to the most recently used position. + if batch_size in self.graph: + find_graph = self.graph.pop(batch_size) + self.graph[batch_size] = find_graph + return find_graph else: - assert False, "dead code" + return None + + def evict_oldest_graph(self): + if self.graph: + oldest_batch_size, oldest_graph = self.graph.popitem(last=False) + del oldest_graph + torch.cuda.empty_cache() def find_closest_graph_batch_size(self, batch_size): index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) @@ -64,6 +76,9 @@ def find_closest_graph_batch_size(self, batch_size): def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo): dist_group: CustomProcessGroup = infer_state.dist_group + if len(self.graph) >= self.max_graph_pool_size: + self.evict_oldest_graph() + graph_obj = torch.cuda.CUDAGraph() batch_size = input_ids.shape[0] infer_state.max_len_in_batch = self.graph_max_len_in_batch @@ -84,6 +99,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(input_ids, infer_state) + # we assume batch_size is already dealed with find_closest_graph_batch_size outside self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) graph_obj.replay() return model_output @@ -97,6 +113,9 @@ def _capture_decode_overlap( infer_state1: InferStateInfo, ): dist_group: CustomProcessGroup = infer_state.dist_group + if len(self.graph) >= self.max_graph_pool_size: + self.evict_oldest_graph() + dist_group1 = infer_state1.dist_group graph_obj = torch.cuda.CUDAGraph() batch_size = input_ids.shape[0] @@ -113,6 +132,7 @@ def _capture_decode_overlap( with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) + # we assume batch_size is already dealed with find_closest_graph_batch_size outside self.graph[batch_size] = ( graph_obj, input_ids, diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 3f3eaf96f..60bfc3cc6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -335,6 +335,13 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--disable_cudagraph", action="store_true", 
help="Disable the cudagraph of the decoding stage") + parser.add_argument( + "--max_graph_pool_size", + type=int, + default=16, + help="""Maximum cuda graph pool size for decoding stage.""", + ) + parser.add_argument( "--graph_max_batch_size", type=int, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ec1eb427e..eb2464f3f 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -77,6 +77,7 @@ class StartArgs: visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) enable_monitor_auth: bool = field(default=False) disable_cudagraph: bool = field(default=False) + max_graph_pool_size: int = field(default=16) graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) From 07ad3bca27474e06d78feacf2f537a1c7386ac29 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 9 Jul 2025 11:41:47 +0800 Subject: [PATCH 2/5] fix gemini review comments --- lightllm/common/basemodel/basemodel.py | 12 ++++++++++-- lightllm/common/basemodel/cuda_graph.py | 15 +++++++-------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 433ed2d63..dbb687bb7 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -346,7 +346,9 @@ def _decode( ) -> ModelOutput: if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) - assert find_graph_batch_size is not None + if find_graph_batch_size is None: + logger.error("No suitable graph batch size found for batch_size={model_input.batch_size}, return None.") + return None padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) infer_state = self._create_inferstate(padded_model_input) @@ -358,6 +360,8 @@ def _decode( ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) + # Check if a graph needs to be captured. + # get_graph returns None if a graph for the batch_size doesn't exist. if self.graph.get_graph(find_graph_batch_size) is None: infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode( @@ -499,7 +503,9 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) - assert find_graph_batch_size is not None + if find_graph_batch_size is None: + logger.error("No suitable graph batch size found for batch_size={origin_batch_size}, return None.") + return None padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size) @@ -520,6 +526,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) + # Check if a graph needs to be captured. + # get_graph returns None if a graph for the batch_size doesn't exist. 
if self.graph.get_graph(find_graph_batch_size) is None: infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 7805ae6a9..74730f6f0 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -50,12 +50,11 @@ def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch def get_graph(self, batch_size): - # we assume batch_size is already dealed with find_closest_graph_batch_size outside - # If the graph already exists, dequeue it and then enqueue it, - # thus move it to the most recently used position. + # We assume batch_size has already been adjusted to the closest supported graph batch size + # If the graph already exists, get it and move it to the most recently used position. if batch_size in self.graph: - find_graph = self.graph.pop(batch_size) - self.graph[batch_size] = find_graph + find_graph = self.graph.pop(batch_size) # Dequeue the graph + self.graph[batch_size] = find_graph # Enqueue the graph for LRU return find_graph else: return None @@ -64,7 +63,7 @@ def evict_oldest_graph(self): if self.graph: oldest_batch_size, oldest_graph = self.graph.popitem(last=False) del oldest_graph - torch.cuda.empty_cache() + logger.info(f"Evicted CUDA graph for batch size: {oldest_batch_size}") def find_closest_graph_batch_size(self, batch_size): index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size) @@ -99,7 +98,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output = decode_func(input_ids, infer_state) - # we assume batch_size is already dealed with find_closest_graph_batch_size outside + # We assume batch_size has already been adjusted to the closest supported graph batch size self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output) graph_obj.replay() return model_output @@ -132,7 +131,7 @@ def _capture_decode_overlap( with lightllm_capture_graph(dist_group): with torch.cuda.graph(graph_obj, pool=self.mempool): model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1) - # we assume batch_size is already dealed with find_closest_graph_batch_size outside + # We assume batch_size has already been adjusted to the closest supported graph batch size self.graph[batch_size] = ( graph_obj, input_ids, From 55fbf5f2424091f1ccc7fe6d245211102e9f607c Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 23 Jul 2025 13:49:37 +0800 Subject: [PATCH 3/5] use global max batch_size for cuda graph --- lightllm/common/basemodel/basemodel.py | 31 ++++++++++++++++++++------ 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index dbb687bb7..67840b927 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from typing import final +import torch.distributed as dist from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.common.mem_manager import MemoryManager @@ -18,7 +19,7 @@ from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.quantization import Quantcfg from 
lightllm.utils.log_utils import init_logger -from lightllm.utils.dist_utils import get_dp_world_size +from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size, get_global_rank from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput @@ -344,10 +345,18 @@ def _decode( self, model_input: ModelInput, ) -> ModelOutput: - if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch): - find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size) + # collect global max batch_size + world_size = get_global_world_size() + rank = get_global_rank() + all_batch_sizes = [None] * world_size + all_batch_sizes[rank] = model_input.batch_size + dist.all_gather_object(all_batch_sizes, model_input.batch_size) + global_max_batch_size = max(all_batch_sizes) + + if self.graph is not None and self.graph.can_run(global_max_batch_size, model_input.max_len_in_batch): + find_graph_batch_size = self.graph.find_closest_graph_batch_size(global_max_batch_size) if find_graph_batch_size is None: - logger.error("No suitable graph batch size found for batch_size={model_input.batch_size}, return None.") + logger.error("No suitable graph batch size found for batch_size={global_max_batch_size}, return None.") return None padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size) @@ -501,10 +510,18 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode origin_batch_size = model_input0.batch_size max_len_in_batch = max(model_input0.max_len_in_batch, model_input1.max_len_in_batch) - if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch): - find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size) + # collect global max batch_size + world_size = get_global_world_size() + rank = get_global_rank() + all_batch_sizes = [None] * world_size + all_batch_sizes[rank] = origin_batch_size + dist.all_gather_object(all_batch_sizes, origin_batch_size) + global_max_batch_size = max(all_batch_sizes) + + if self.graph is not None and self.graph.can_run(global_max_batch_size, max_len_in_batch): + find_graph_batch_size = self.graph.find_closest_graph_batch_size(global_max_batch_size) if find_graph_batch_size is None: - logger.error("No suitable graph batch size found for batch_size={origin_batch_size}, return None.") + logger.error("No suitable graph batch size found for batch_size={global_max_batch_size}, return None.") return None padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size) From fe6b1677a101c0160189b3add3e1672369b2d338 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 23 Jul 2025 13:51:39 +0800 Subject: [PATCH 4/5] use cuda graph pool for lazy graph capture --- lightllm/common/basemodel/cuda_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 74730f6f0..ab594a5eb 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -210,7 +210,7 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: + for batch_size in (self.cuda_graph_batch_sizes[-1],): seq_len = 2 total_token_num = batch_size * 
seq_len max_len_in_batch = self.graph_max_len_in_batch @@ -261,7 +261,7 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + for batch_size in (self.cuda_graph_batch_sizes[-1],): decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph From e63cae7d0591f1025f5618222a9cd529050199fc Mon Sep 17 00:00:00 2001 From: STwangyingrui Date: Wed, 23 Jul 2025 14:33:57 +0800 Subject: [PATCH 5/5] recover need_capture --- lightllm/common/basemodel/basemodel.py | 8 ++------ lightllm/common/basemodel/cuda_graph.py | 6 +++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 007d0f4b8..80b36295e 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -381,9 +381,7 @@ def _decode( ) infer_state.init_some_extra_state(self, padded_model_input.input_ids) - # Check if a graph needs to be captured. - # get_graph returns None if a graph for the batch_size doesn't exist. - if self.graph.get_graph(find_graph_batch_size) is None: + if self.graph.need_capture(find_graph_batch_size): infer_state.is_cuda_graph = True model_output: ModelOutput = self.graph.capture_decode( self._token_forward, padded_model_input.input_ids, infer_state @@ -574,9 +572,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode ) infer_state1.init_some_extra_state(self, padded_model_input1.input_ids) - # Check if a graph needs to be captured. - # get_graph returns None if a graph for the batch_size doesn't exist. - if self.graph.get_graph(find_graph_batch_size) is None: + if self.graph.need_capture(find_graph_batch_size): infer_state0.is_cuda_graph = True infer_state1.is_cuda_graph = True diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 43cccd452..8d436e62f 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -49,15 +49,15 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): def can_run(self, batch_size, max_len_in_batch): return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch - def get_graph(self, batch_size): + def need_capture(self, batch_size): # We assume batch_size has already been adjusted to the closest supported graph batch size # If the graph already exists, get it and move it to the most recently used position. if batch_size in self.graph: find_graph = self.graph.pop(batch_size) # Dequeue the graph self.graph[batch_size] = find_graph # Enqueue the graph for LRU - return find_graph + return False else: - return None + return True def evict_oldest_graph(self): if self.graph:
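A detail worth highlighting from the third commit: before graph selection, every rank gathers the decode batch sizes of all ranks with dist.all_gather_object and pads to the global maximum, presumably so that all ranks agree on the same captured graph batch size. A self-contained sketch of that synchronization, assuming an initialized torch.distributed process group (the single-process gloo setup below exists only to make the example runnable):

```python
import os
import torch.distributed as dist

def global_max_batch_size(local_batch_size: int) -> int:
    # Each rank contributes its local decode batch size; all ranks then use
    # the global maximum when picking the padded CUDA graph batch size.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_batch_size)
    return max(gathered)

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    assert global_max_batch_size(5) == 5
    dist.destroy_process_group()
```

In the diff the gathered list is also pre-seeded with the local value before the collective; the result is the same, since all_gather_object overwrites every slot of the output list.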