From 2cde8e6d8f1823e25bc9113e9cac8bd3a3351dcd Mon Sep 17 00:00:00 2001
From: SiYu Wu
Date: Wed, 20 Nov 2024 23:03:55 +0800
Subject: [PATCH] feat(misc): Profiler support

Use --enable_profiling=MODE to enable it; the torch_profiler and nvtx
modes (the latter for use with NVIDIA Nsight Systems) are currently
supported.
---
 lightllm/server/api_cli.py                         |  15 ++
 lightllm/server/api_http.py                        |  18 +++
 lightllm/server/httpserver/manager.py              |  13 +-
 lightllm/server/router/manager.py                  |  19 ++-
 .../model_infer/mode_backend/base_backend.py       |  28 +++-
 lightllm/server/visualserver/manager.py            |  19 ++-
 .../visualserver/model_infer/model_rpc.py          |  22 +++
 lightllm/utils/profiler.py                         | 134 ++++++++++++++++++
 8 files changed, 261 insertions(+), 7 deletions(-)
 create mode 100644 lightllm/utils/profiler.py

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index ec11f8f1d..a8303d875 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -537,4 +537,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
     )
+    parser.add_argument(
+        "--enable_profiling",
+        type=str,
+        choices=["torch_profiler", "nvtx"],
+        default=None,
+        help="""Enable profiler support.
+        This exposes the '/profiler_start' and '/profiler_stop' APIs; the profiling features
+        below are only active between those two calls.
+        Options:
+        'torch_profiler': sets up torch.profiler.profile(); trace files are saved to './trace',
+        or to the directory given by the 'LIGHTLLM_TRACE_DIR' env var;
+        'nvtx': adds NVTX marks for an external profiler such as NVIDIA Nsight Systems
+        (which you must set up yourself).
+        An NVTX range named 'LIGHTLLM_PROFILE' is opened for the duration of the profiling range.""",
+    )
     return parser
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 8bda50fb7..eddce01a9 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -335,6 +335,24 @@ async def kv_move_status(websocket: WebSocket):
     return
 
 
+@app.get("/profiler_start")
+async def profiler_start() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("start")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
+@app.get("/profiler_stop")
+async def profiler_stop() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("stop")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
 @app.on_event("shutdown")
 async def shutdown():
     logger.info("Received signal to shutdown. Performing graceful shutdown...")
Performing graceful shutdown...") diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 11919398e..a65dd15c5 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -13,7 +13,7 @@ from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional +from typing import Literal, Union, List, Tuple, Dict, Optional from websockets import ClientConnection from fastapi import Request from ..tokenizer import get_tokenizer @@ -35,6 +35,7 @@ from lightllm.utils.config_utils import get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken +from lightllm.utils.profiler import ProfilerCmd from rpyc.utils.classic import obtain logger = init_logger(__name__) @@ -642,6 +643,16 @@ async def abort(self, group_req_id: int) -> bool: logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") return True + async def profiler_cmd(self, cmd: Literal["start", "stop"]): + receivers = [self.send_to_router] + if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal: + receivers.append(self.send_to_visual) + for receiver in receivers: + receiver.send_pyobj( + ProfilerCmd(cmd), + protocol=pickle.HIGHEST_PROTOCOL, + ) + async def recycle_resource_loop(self): pre_time_mark = time.time() diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 3c8ca2399..c0f78d468 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -26,6 +26,7 @@ from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer from lightllm.utils.log_utils import init_logger, log_time_ready +from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd from lightllm.server.router.token_load import TokenLoad from lightllm.server.metrics.manager import MetricClient from lightllm.common.basemodel.infer_lock import g_router_lock @@ -106,6 +107,10 @@ def __init__(self, args: StartArgs): if not self.args.enable_cpu_cache else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False) ) + + self.profiler = ( + ProcessProfiler(mode=args.enable_profiling, name="lightllm-router") if args.enable_profiling else None + ) return async def wait_to_model_ready(self): @@ -508,6 +513,16 @@ def _multinode_tp_generate_new_batch(self): raise e return + async def _profiler_cmd(self, cmd_obj: ProfilerCmd): + self.profiler.cmd(cmd_obj) + + cmd = ProfilerCmd(cmd=cmd_obj.cmd) + while not self.shm_reqs_io_buffer.is_empty(): + await asyncio.sleep(0.02) + + self.shm_reqs_io_buffer.write_obj([cmd]) + self.shm_reqs_io_buffer.set_ready() + async def _recv_new_reqs_and_schedule(self): if not hasattr(self, "recv_max_count"): self.recv_max_count = 64 @@ -515,9 +530,11 @@ async def _recv_new_reqs_and_schedule(self): try: # 一次最多从 zmq 中取 recv_max_count 个请求,防止 zmq 队列中请求数量过多导致阻塞了主循环。 for _ in range(self.recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self._add_req(recv_req) + elif isinstance(recv_req, ProfilerCmd): + await self._profiler_cmd(recv_req) else: assert False, f"Error Req Inf {recv_req}" diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py 
index 95f0c9951..f7e67a164 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -4,7 +4,7 @@
 import time
 import threading
 import torch.distributed as dist
-from typing import List, Tuple, Callable, Optional
+from typing import Dict, List, Literal, Tuple, Callable, Optional
 from transformers.configuration_utils import PretrainedConfig
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.log_utils import init_logger
@@ -39,6 +39,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 
 
 class ModeBackend:
@@ -218,11 +219,19 @@ def init_model(self, kvargs):
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
+        self.profiler: Optional[ProcessProfiler] = None
+        if self.args.enable_profiling:
+            self.profiler = ProcessProfiler(
+                mode=self.args.enable_profiling,
+                name=f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}",
+            )
+        self.profiling_active = False
+
         # 启动infer_loop_thread, 启动两个线程进行推理,对于具备双batch推理折叠得场景
         # 可以降低 cpu overhead,大幅提升gpu得使用率。
-        self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
+        self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True, name="loop0")
         self.infer_loop_thread.start()
-        self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
+        self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True, name="loop1")
         self.infer_loop_thread1.start()
         return
 
@@ -308,6 +317,14 @@ def _try_read_new_reqs(self):
             self._try_read_new_reqs_multinode_tp()
         else:
             self._try_read_new_reqs_normal()
+
+        # checked on each infer loop thread: toggle the profiler when the requested state changes
+        if self.profiler is not None:
+            if self.profiler.is_active != self.profiling_active:
+                if self.profiling_active:
+                    self.profiler.start()
+                else:
+                    self.profiler.stop()
         return
 
     def _try_read_new_reqs_normal(self):
@@ -373,6 +390,11 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+            elif isinstance(obj, ProfilerCmd):
+                if obj.cmd == "start":
+                    self.profiling_active = True
+                elif obj.cmd == "stop":
+                    self.profiling_active = False
             else:
                 assert False, f"error type {type(obj)}"
         if init_reqs:
diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py
index b7e1ac10c..adb971ea7 100644
--- a/lightllm/server/visualserver/manager.py
+++ b/lightllm/server/visualserver/manager.py
@@ -7,7 +7,7 @@
 import pickle
 import inspect
 import setproctitle
-from typing import List
+from typing import List, Union
 from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 
@@ -18,6 +18,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from rpyc.utils.classic import obtain
 
 
@@ -58,6 +59,9 @@ def __init__(
         self.args = args
         self.visual_model_rpc_ports = visual_model_rpc_ports
         self.shm_req_manager = ShmReqManager()
+        self.profiler: "ProcessProfiler|None" = (
+            ProcessProfiler(args.enable_profiling, name="lightllm-visual_server") if args.enable_profiling else None
+        )
 
     async def wait_to_model_ready(self):
 
@@ -90,6 +94,7 @@ async def wait_to_model_ready(self):
                     "quant_type": self.args.vit_quant_type,
                     "quant_cfg": self.args.vit_quant_cfg,
                     "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
+                    "profiler": self.args.enable_profiling,
                 }
                 init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
         await asyncio.gather(*init_model_ret)
@@ -171,9 +176,19 @@ async def loop_for_netio_req(self):
         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         self.waiting_reqs.append(recv_req)
+                    elif isinstance(recv_req, ProfilerCmd):
+                        self.profiler.cmd(recv_req)
+                        tasks = []
+                        for vit_dp_rank in range(self.vit_dp):
+                            for vit_tp_rank in range(self.vit_tp):
+                                task = asyncio.create_task(
+                                    self.model_rpcs[vit_dp_rank][vit_tp_rank].profiler_cmd(recv_req)
+                                )
+                                tasks.append(task)
+                        await asyncio.gather(*tasks)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
                 self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index a9409ceb9..dab3dde85 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -24,6 +24,7 @@
 from lightllm.utils.dist_utils import init_vision_distributed_env
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.profiler import ProcessProfiler
 
 
 class VisualModelRpcServer(rpyc.Service):
@@ -42,6 +43,13 @@ def exposed_init_model(self, kvargs):
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
         self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.data_type = kvargs["data_type"]
+        self.profiler = (
+            ProcessProfiler(
+                mode=kvargs["profiler"], name=f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
+            )
+            if kvargs["profiler"]
+            else None
+        )
 
         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
@@ -116,6 +124,10 @@ def exposed_encode(self, images: List[ImageItem]):
             self.cache_client.root.set_items_embed(ids_to_set)
         return
 
+    def exposed_profiler_cmd(self, cmd_obj):
+        cmd_obj = obtain(cmd_obj)
+        self.profiler.cmd(cmd_obj)
+
 
 class VisualModelRpcClient:
     def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -138,9 +150,11 @@ async def _func(*args, **kwargs):
 
             self._init_model = async_wrap(self.model.init_model)
             self._encode = async_wrap(self.model.encode)
+            self._profiler_cmd = async_wrap(self.model.profiler_cmd)
         else:
             self._init_model = self.model.exposed_init_model
             self._encode = self.model.exposed_encode
+            self._profiler_cmd = self.model.exposed_profiler_cmd
         return
 
     async def init_model(self, kvargs):
@@ -158,6 +172,14 @@ async def encode(self, images: List[ImageItem]):
         else:
             return ans
 
+    async def profiler_cmd(self, cmd_obj):
+        ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
+        if self.use_rpc:
+            await ans
+            return
+        else:
+            return
+
 
 def _init_env(port, device_id):
     # 注册graceful 退出的处理
diff --git a/lightllm/utils/profiler.py b/lightllm/utils/profiler.py
new file mode 100644
index 000000000..0ca2ac630
--- /dev/null
+++ b/lightllm/utils/profiler.py
@@ -0,0 +1,134 @@
+from dataclasses import dataclass
+import os
+from typing import Any, Literal, Optional
+import threading
+import torch
+
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class ProfilerCmd:
+    cmd: Literal["start", "stop"]
+
+
+class ProcessProfiler:
+    def __init__(self, mode: Literal["torch_profiler", "nvtx"], name: Optional[str] = None):
+        self.mode: Literal["torch_profiler", "nvtx"] = mode
+        self.name: Optional[str] = name
+        self.is_active: bool = False
+        self.lock = threading.Lock()
+        self.tid = threading.get_native_id() if hasattr(threading, "get_native_id") else threading.get_ident()
+
+        logger.warning("-" * 50)
+        logger.warning(f"[tgid={os.getpid()} pid={self.tid}] Profiler <{self.name}> initialized with mode: {self.mode}")
+        if self.mode == "torch_profiler":
+            trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace")
+            self._torch_profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,  # additional overhead
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir, worker_name=name, use_gzip=True),
+            )
+            logger.warning(
+                "Profiler support for torch.profiler enabled (--enable_profiling=torch_profiler), "
+                "trace files will be saved to %s (change it with the LIGHTLLM_TRACE_DIR env var)",
+                trace_dir,
+            )
+        elif self.mode == "nvtx":
+            self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE"
+            logger.warning(
+                "Profiler support for NVTX enabled (--enable_profiling=nvtx), toplevel NVTX mark is '%s'\n"
+                "you can use it with external profiling tools like NVIDIA Nsight Systems.",
+                self._nvtx_toplevel_mark,
+            )
+            logger.warning(
+                "e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx "
+                "-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [other nsys options] "
+                "python -m lightllm.server.api_server --enable_profiling=nvtx [other lightllm options]",
+                self._nvtx_toplevel_mark,
+            )
+        elif self.mode is not None:
+            raise ValueError("invalid profiler mode")
+        logger.warning("Use the /profiler_start and /profiler_stop HTTP GET APIs to start/stop profiling")
+        logger.warning("DO NOT enable this feature in a production environment")
+        logger.warning("-" * 50)
+
+    def _torch_profiler_start(self) -> None:
+        torch.cuda.synchronize()
+        with self.lock:
+            if not hasattr(self, "_torch_profiler_start_tid"):
+                # torch profiler only needs to start once per process
+                self._torch_profiler_start_tid = self.tid
+                self._torch_profiler.start()
+        torch.cuda.synchronize()
+
+    def _nvtx_start(self) -> None:
+        torch.cuda.synchronize()
+        with self.lock:
+            if not hasattr(self, "_nvtx_toplevel_ids"):
+                self._nvtx_toplevel_ids = {}
+            self._nvtx_toplevel_ids[self.tid] = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark)
+        torch.cuda.synchronize()
+
+    def start(self) -> None:
+        if self.is_active:
+            logger.error("profiler already started, ignore")
+            return
+        logger.warning(f"[tgid={os.getpid()} pid={self.tid}] Profiler <{self.name}>: profiling start")
+        self.is_active = True
+        if self.mode == "torch_profiler":
+            self._torch_profiler_start()
+        elif self.mode == "nvtx":
+            self._nvtx_start()
+
+    def _torch_profiler_stop(self) -> None:
+        torch.cuda.synchronize()
+        with self.lock:
+            if hasattr(self, "_torch_profiler_start_tid") and self._torch_profiler_start_tid == self.tid:
+                # torch profiler only needs to stop once per process, in the same thread that started it
+                del self._torch_profiler_start_tid
+                logger.warning(f"Profiler <{self.name}>: torch profiler stopping and saving trace, please wait...")
+                try:
+                    self._torch_profiler.stop()
+                except RuntimeError as e:
+                    logger.error(f"Profiler <{self.name}>: torch profiler stop failed: {e}, maybe too short")
+                    import traceback
+
+                    traceback.print_exc()
+                    return
+                logger.warning(f"Profiler <{self.name}>: torch profiler trace saved.")
+        torch.cuda.synchronize()
+
+    def _nvtx_stop(self) -> None:
+        torch.cuda.synchronize()
+        with self.lock:
+            if hasattr(self, "_nvtx_toplevel_ids") and self.tid in self._nvtx_toplevel_ids:
+                torch.cuda.nvtx.range_end(self._nvtx_toplevel_ids[self.tid])
+                del self._nvtx_toplevel_ids[self.tid]
+            else:
+                logger.error("nvtx profiler stop called without matching start for tid %s", self.tid)
+        torch.cuda.synchronize()
+
+    def stop(self) -> None:
+        if not self.is_active:
+            logger.error("profiler not started, ignore")
+            return
+        logger.warning(f"[tgid={os.getpid()} pid={self.tid}] Profiler <{self.name}>: profiling stop")
+        self.is_active = False
+        if self.mode == "torch_profiler":
+            self._torch_profiler_stop()
+        elif self.mode == "nvtx":
+            self._nvtx_stop()
+
+    def cmd(self, cmd_obj: ProfilerCmd) -> None:
+        if cmd_obj.cmd == "start":
+            self.start()
+        elif cmd_obj.cmd == "stop":
+            self.stop()
+        else:
+            raise ValueError(f"invalid profiler ops: {cmd_obj.cmd}")
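
Usage sketch (not part of the patch): a minimal client-side example of driving the new endpoints, assuming the API server was started with --enable_profiling=torch_profiler and is reachable at http://localhost:8000; the host, port, and the /generate payload shape are illustrative assumptions and may differ per deployment.

    import requests

    BASE = "http://localhost:8000"  # hypothetical address of the lightllm api_server

    # Open the profiling window; the httpserver fans the command out to the router,
    # the visual server (if enabled), and the backend processes.
    requests.get(f"{BASE}/profiler_start").raise_for_status()

    # Drive some load while profiling is active.
    requests.post(
        f"{BASE}/generate",
        json={"inputs": "What is AI?", "parameters": {"max_new_tokens": 64}},
    )

    # Close the window; in torch_profiler mode each worker writes a gzipped trace
    # under ./trace (or $LIGHTLLM_TRACE_DIR) that can be viewed in TensorBoard.
    requests.get(f"{BASE}/profiler_stop").raise_for_status()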
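A second, in-process sketch of the new lightllm.utils.profiler helpers themselves, mirroring how the router and backend threads drive them in this patch; it assumes a CUDA device and, for mode="nvtx", an externally attached Nsight Systems session.

    import torch
    from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd

    profiler = ProcessProfiler(mode="nvtx", name="example-worker")

    profiler.cmd(ProfilerCmd(cmd="start"))   # opens the LIGHTLLM_PROFILE NVTX range
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x                                # GPU work captured by the attached nsys session
    torch.cuda.synchronize()
    profiler.cmd(ProfilerCmd(cmd="stop"))    # closes the NVTX range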