Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--nccl_port", type=int, default=28765, help="the nccl_port to build a distributed environment for PyTorch"
)
parser.add_argument(
"--use_config_server_to_init_nccl",
action="store_true",
help="""use tcp store server started by config_server to init nccl, default is False, when set to True,
the --nccl_host must equal to the config_server_host, and the --nccl_port must be unique for a config_server,
dont use same nccl_port for different inference node, it will be critical error""",
)

parser.add_argument(
"--mode",
type=str,
Expand Down
12 changes: 8 additions & 4 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import signal
from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
from lightllm.utils.start_utils import process_manager
from lightllm.utils.start_utils import process_manager, kill_recursive
from .metrics.manager import start_metric_manager
from .embed_cache.manager import start_cache_manager
from .visualserver.manager import start_visual_process
Expand All @@ -25,8 +25,8 @@ def setup_signal_handlers(http_server_process, process_manager):
def signal_handler(sig, frame):
if sig == signal.SIGINT:
logger.info("Received SIGINT (Ctrl+C), forcing immediate exit...")
if http_server_process and http_server_process.poll() is None:
http_server_process.kill()
if http_server_process:
kill_recursive(http_server_process)

process_manager.terminate_all_processes()
logger.info("All processes have been forcefully terminated.")
Expand All @@ -47,7 +47,7 @@ def signal_handler(sig, frame):
logger.info("HTTP server has exited gracefully")
else:
logger.warning("HTTP server did not exit in time, killing it...")
http_server_process.kill()
kill_recursive(http_server_process)

process_manager.terminate_all_processes()
logger.info("All processes have been terminated gracefully.")
Expand Down Expand Up @@ -82,6 +82,10 @@ def normal_or_p_d_start(args):

logger.info(f"use tgi api: {args.use_tgi_api}")

# 当使用config_server来初始化nccl时,nccl_host和config_server_host必须一致
if args.use_config_server_to_init_nccl:
assert args.config_server_host == args.nccl_host

assert (
args.mem_fraction > 0 and args.mem_fraction < 1
), f"Invalid mem_fraction {args.mem_fraction}, The expected value is between 0 and 1."
Expand Down
102 changes: 97 additions & 5 deletions lightllm/server/config_server/api_http.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import time
import asyncio
import base64
import pickle
import multiprocessing as mp
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Query
from threading import Lock
from typing import Dict
from typing import Dict, List
from fastapi.responses import JSONResponse
from lightllm.utils.log_utils import init_logger
from ..pd_io_struct import PD_Master_Obj
import base64
import pickle
import os
import requests
from .nccl_tcp_store import start_tcp_store_server
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.process_check import start_parent_check_thread


logger = init_logger(__name__)
app = FastAPI()
Expand Down Expand Up @@ -112,3 +117,90 @@ async def allocate_global_unique_multimodal_id_range():
end_id = global_multimodal_embedding_id

return {"start_id": start_id, "end_id": end_id}


# Maps a TCPStore port -> the subprocess running that store server (spawned by rank 0).
global_store_port_to_process: Dict[int, mp.Process] = {}
# Maps a TCPStore port -> per-rank "has not checked in yet" flags; a rank flips its
# own slot to False once it has called this endpoint.
global_store_port_to_client_states: Dict[int, List[bool]] = {}
# Guards the two dicts above across the await points in the handler below.
global_store_port_lock = asyncio.Lock()


@app.get("/start_tcp_store_server")
async def http_start_tcp_store_server(
    tcp_store_port: int = Query(...), rank_id: int = Query(...), world_size: int = Query(...)
):
    """
    Start (rank 0) or join (other ranks) a TCP store server for NCCL communication.

    Rank 0 spawns the TCPStore server process for ``tcp_store_port``, publishes a
    per-rank check-in list, and blocks until every rank in the group has checked in
    (3-minute timeout). Non-zero ranks wait for that list to appear, then mark their
    own slot as checked in and return immediately.

    Args:
        tcp_store_port (int): The port number for the TCP store server.
        rank_id (int): The rank ID of inference process.
        world_size (int): The world size of nccl group.

    Returns:
        dict: A dictionary containing the status of the server.

    Raises:
        Exception: if the group does not fully assemble within the 3-minute timeout.
    """
    global global_store_port_to_process
    global global_store_port_to_client_states
    global global_store_port_lock

    args = get_env_start_args()

    if rank_id == 0:
        async with global_store_port_lock:
            if tcp_store_port in global_store_port_to_client_states:
                # A check-in list for this port already exists: a previous group on the
                # same port has not finished assembling — refuse to start a second one.
                # NOTE(review): ``assert`` is stripped under ``python -O``; consider
                # raising an explicit exception / HTTP error instead.
                logger.error(f"tcp store server {tcp_store_port} already started, rank_id 0 find client state exists")
                assert False, f"tcp store server {tcp_store_port} already started, rank_id 0 find client state exists"

            if tcp_store_port in global_store_port_to_process:
                # Stale server from an earlier (fully assembled or crashed) group:
                # kill it and start a fresh one for this group.
                logger.warning(f"tcp store server {tcp_store_port} already started, kill and restart it")
                process = global_store_port_to_process[tcp_store_port]
                process.kill()
                process.join()

            global_store_port_to_process[tcp_store_port] = start_tcp_store_server(
                args.config_server_host, tcp_store_port
            )

            # True == "this rank has not checked in yet".
            world_size_state = [True for _ in range(world_size)]
            global_store_port_to_client_states[tcp_store_port] = world_size_state

        # Rank 0 checks itself in.
        world_size_state[rank_id] = False

        # Poll until every rank has checked in, or give up after 3 minutes.
        start_time = time.time()
        while any(world_size_state):
            await asyncio.sleep(1)
            if time.time() - start_time > 60 * 3:
                logger.error(
                    f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} wait all quit timeout"
                )
                # Drop the check-in list so a retried group can start over on this port.
                async with global_store_port_lock:
                    global_store_port_to_client_states.pop(tcp_store_port, None)
                raise Exception(
                    f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} wait timeout"
                )

        # Group fully assembled: remove the check-in list (the store process keeps running).
        async with global_store_port_lock:
            global_store_port_to_client_states.pop(tcp_store_port, None)

        return {"status": "ok"}
    else:
        # Non-zero ranks wait for rank 0 to publish the check-in list for this port.
        start_time = time.time()
        while tcp_store_port not in global_store_port_to_client_states:
            await asyncio.sleep(1)
            if time.time() - start_time > 60 * 3:
                logger.error(f"tcp store port {tcp_store_port} rank_id {rank_id} world_size {world_size} state timeout")
                raise Exception(
                    f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} state timeout"
                )

        world_size_state = global_store_port_to_client_states[tcp_store_port]

        # Each rank may check in exactly once per group; a False slot here means a
        # duplicate request or a rank-id collision.
        assert (
            world_size_state[rank_id] is True
        ), f"tcp store server {tcp_store_port} rank_id {rank_id} world_size {world_size} world_size_state error"
        world_size_state[rank_id] = False
        return {"status": "ok"}


# Exit this config-server worker automatically if its parent process dies.
logger.info("config server start_parent_check_thread...")
start_parent_check_thread()
60 changes: 60 additions & 0 deletions lightllm/server/config_server/nccl_tcp_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import psutil
import time
import torch.distributed as dist
import torch.multiprocessing as mp
from lightllm.utils.log_utils import init_logger
from lightllm.utils.process_check import start_parent_check_thread

logger = init_logger(__name__)


def start_tcp_store_server(nccl_store_host, nccl_store_port):
    """Spawn a daemon subprocess running a TCPStore server and return its handle.

    Args:
        nccl_store_host: Host/IP the TCPStore server binds to.
        nccl_store_port: Port the TCPStore server listens on.

    Returns:
        The started ``mp.Process`` object (daemonized, already running).
    """
    server_proc = mp.Process(
        target=_start_tcp_store_server,
        args=(nccl_store_host, nccl_store_port),
    )
    server_proc.daemon = True
    server_proc.start()
    return server_proc


def _start_tcp_store_server(nccl_store_host, nccl_store_port):
    """
    Run a TCPStore server in the current process (never returns normally).

    Intended as the subprocess target of ``start_tcp_store_server``. Binds a
    master TCPStore on ``nccl_store_host:nccl_store_port`` and stays alive,
    periodically logging the key count as a heartbeat. Exits only when the
    process is killed (or the parent-check thread terminates it).
    """
    # Exit automatically if the parent (config server) process dies.
    start_parent_check_thread()

    try:
        from torch._C._distributed_c10d import _DEFAULT_PG_NCCL_TIMEOUT

        default_pg_nccl_timeout = _DEFAULT_PG_NCCL_TIMEOUT
    except ImportError:
        # if C++ NCCL support is not compiled, we don't have access to the default nccl value.
        # if anyone is actually trying to use nccl in this state, it should error.
        default_pg_nccl_timeout = None

    logger.info(f"default_pg_nccl_timeout: {default_pg_nccl_timeout}")
    # Bug fix: the original logged "retrying ..." on failure but never actually
    # retried — the process silently exited. Loop so transient failures (e.g. the
    # port briefly unavailable after a restart) really are retried.
    while True:
        logger.info(f"[Server] TCPStore start: {nccl_store_host}:{nccl_store_port}")
        try:
            store = dist.TCPStore(
                host_name=nccl_store_host,
                port=nccl_store_port,
                world_size=None,
                is_master=True,
                wait_for_workers=False,
                timeout=default_pg_nccl_timeout,
                multi_tenant=True,
                use_libuv=True,
            )

            # Keep the server alive; log the key count as a liveness heartbeat.
            while True:
                keys_num = store.num_keys()
                logger.info(f"[Server] TCPStore start: {nccl_store_host}:{nccl_store_port} keys num: {keys_num}")
                time.sleep(20)

        except Exception as e:
            logger.warning(str(e))
            logger.info(f"TCPStore server {nccl_store_host}:{nccl_store_port} start failed, retrying ...")
            # Back off briefly before re-attempting to bind the store.
            time.sleep(10)
4 changes: 3 additions & 1 deletion lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Optional, Union
from typing import List, Optional, Tuple

# 只是为了更好的编程提示

Expand Down Expand Up @@ -31,7 +31,9 @@ class StartArgs:
tp: int = field(default=1)
dp: int = field(default=1)
max_req_total_len: int = field(default=2048 + 1024)
nccl_host: str = field(default="127.0.0.1")
nccl_port: int = field(default=28765)
use_config_server_to_init_nccl: bool = field(default=False)
mode: List[str] = field(default_factory=list)
trust_remote_code: bool = field(default=False)
disable_log_stats: bool = field(default=False)
Expand Down
32 changes: 32 additions & 0 deletions lightllm/utils/dist_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.distributed as dist
import os
import torch
import requests

# 规范 rank 的含义,在 llm 推理的相关代码中下述的 rank 的含义如下:
# global_rank 全局 rank 序列id, 如两节点 8卡,会存在 0 - 15 16个global_rank
Expand Down Expand Up @@ -93,6 +94,7 @@ def init_distributed_env(kvargs):
dp_size_in_node = max(1, get_dp_size() // nnodes)
set_dp_rank_in_node(get_global_dp_rank() % dp_size_in_node)

_init_nccl_env()
device_id = kvargs["rank_id"] % get_node_world_size()
set_current_device_id(device_id)
torch.cuda.set_device(device_id)
Expand Down Expand Up @@ -199,3 +201,33 @@ def create_new_group_for_current_dp(backend):
if get_global_dp_rank() == iter_dp_rank:
ans_group = device_group
return ans_group


def _init_nccl_env():
    """
    Optionally bootstrap NCCL initialization through the config server's TCPStore.

    When ``--use_config_server_to_init_nccl`` is set, every rank calls the config
    server's ``/start_tcp_store_server`` endpoint: rank 0's call makes the config
    server spawn the TCPStore master, and all calls block until the whole group has
    checked in (the endpoint enforces a 3-minute timeout). Afterwards the normal
    NCCL init connects to ``nccl_host:nccl_port``, which is that store.

    No-op when the flag is not set.
    """
    from lightllm.utils.envs_utils import get_env_start_args

    args = get_env_start_args()

    # 配置使用外部的 tcp store server 来创建 nccl 连接
    if not args.use_config_server_to_init_nccl:
        return

    # Tell torch an external (agent-provided) store is in use, so it does not try
    # to start its own TCPStore master during process-group init.
    os.environ["TORCHELASTIC_USE_AGENT_STORE"] = "True"
    rank_id = get_global_rank()
    world_size = get_global_world_size()

    # The externally started TCPStore lives on the config server host, and NCCL
    # init will connect to nccl_host:nccl_port — so the two hosts must match.
    # (The original code duplicated the request logic in identical rank-0 /
    # non-rank-0 branches; they are merged here since the code was byte-identical.)
    assert args.config_server_host == args.nccl_host

    ip_port = f"{args.config_server_host}:{args.config_server_port}"
    params = f"tcp_store_port={args.nccl_port}&&rank_id={rank_id}&&world_size={world_size}"
    url = f"http://{ip_port}/start_tcp_store_server?{params}"
    # Blocks until the whole group has checked in on the config server (<= 3 min).
    response = requests.get(url, timeout=60 * 3)
    assert response.status_code == 200, f"Failed to init config server nccl tcp store: {response.status_code}"

    return
13 changes: 13 additions & 0 deletions lightllm/utils/start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,17 @@ def start_submodule_processes(start_funcs=[], start_args=[]):
return


def kill_recursive(proc):
    """
    Forcefully kill ``proc`` and all of its descendant processes.

    Args:
        proc: Any object with a ``pid`` attribute (e.g. ``subprocess.Popen`` or
            ``multiprocessing.Process``) identifying the root of the process tree.

    Children are killed before the parent. Already-exited processes are logged
    and skipped rather than raising.
    """
    try:
        parent = psutil.Process(proc.pid)
        children = parent.children(recursive=True)
        for child in children:
            logger.info(f"Killing child process {child.pid}")
            try:
                child.kill()
            except psutil.NoSuchProcess:
                # Bug fix: a child exiting between enumeration and kill used to
                # abort the loop, leaving the remaining children and the parent
                # alive. Skip the dead child and keep going instead.
                logger.warning(f"Process {child.pid} does not exist.")
        logger.info(f"Killing parent process {proc.pid}")
        parent.kill()
    except psutil.NoSuchProcess:
        logger.warning(f"Process {proc.pid} does not exist.")


process_manager = SubmoduleManager()