Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions flocks/cli/service_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,30 @@ def _tracked_processes_stopped(
return not any(pid_is_running(pid) for pid in tracked_pids)


def _runtime_record_pids(record: RuntimeRecord | None) -> list[int]:
"""Collect the latest pids implied by a runtime record."""
if record is None:
return []

result: list[int] = []
if record.pid > 0:
result = append_unique_pids(result, collect_process_tree_pids(record.pid))
if record.pgid is not None and sys.platform != "win32":
result = append_unique_pids(result, _process_group_member_pids(record.pgid))
return result


def _current_stop_targets(
port: int,
record: RuntimeRecord | None,
tracked_pids: Iterable[int],
) -> list[int]:
"""Refresh the pid list that stop_one() should verify or force kill."""
result = append_unique_pids([], tracked_pids)
result = append_unique_pids(result, _runtime_record_pids(record))
return append_unique_pids(result, port_owner_pids(port))


def signal_process_group(sig: signal.Signals, pgid: int | None) -> None:
"""Signal an entire Unix process group when it exists."""
if sys.platform == "win32" or pgid is None or pgid <= 0:
Expand Down Expand Up @@ -888,25 +912,34 @@ def stop_one(port: int, pid_file: Path, name: str, console) -> None:
else:
signal_pid_list(signal.SIGTERM, target_pids)
for _ in range(10):
if _tracked_processes_stopped(port, runtime_record, target_pids):
current_targets = _current_stop_targets(port, runtime_record, target_pids)
if _tracked_processes_stopped(port, runtime_record, current_targets):
pid_file.unlink(missing_ok=True)
console.print(f"[flocks] {name} 已停止。")
return
time.sleep(1)

console.print(f"[flocks] {name} 未在预期时间内退出,强制终止...")
force_targets = _current_stop_targets(port, runtime_record, target_pids)
if runtime_record and runtime_record.pgid is not None:
signal_process_group(signal.SIGKILL, runtime_record.pgid)
signal_pid_list(signal.SIGKILL, append_unique_pids(target_pids, port_owner_pids(port)))
signal_pid_list(signal.SIGKILL, force_targets)

for _ in range(10):
if _tracked_processes_stopped(port, runtime_record, append_unique_pids(target_pids, port_owner_pids(port))):
force_targets = _current_stop_targets(port, runtime_record, target_pids)
if _tracked_processes_stopped(port, runtime_record, force_targets):
pid_file.unlink(missing_ok=True)
console.print(f"[flocks] {name} 已停止。")
return
if sys.platform == "win32":
for pid in force_targets:
subprocess.run(["taskkill", "/PID", str(pid), "/T", "/F"], check=False, capture_output=True)
else:
if runtime_record and runtime_record.pgid is not None:
signal_process_group(signal.SIGKILL, runtime_record.pgid)
signal_pid_list(signal.SIGKILL, force_targets)
time.sleep(1)

pid_file.unlink(missing_ok=True)
raise ServiceError(f"{name} 未在预期时间内退出,请手动检查端口 {port}。")


Expand Down
63 changes: 63 additions & 0 deletions tests/cli/test_service_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,6 +853,69 @@ def fake_run(args, **kwargs):
]


def test_stop_one_force_kill_refreshes_process_group_members(monkeypatch, tmp_path: Path) -> None:
pid_file = tmp_path / "backend.pid"
service_manager.write_runtime_record(
pid_file,
service_manager.RuntimeRecord(pid=111, pgid=222, port=8000),
)
console = DummyConsole()
pid_signals: list[tuple[signal.Signals, list[int]]] = []
group_signals: list[tuple[signal.Signals, int | None]] = []
alive_group_members = {333}

monkeypatch.setattr(service_manager.sys, "platform", "darwin")
monkeypatch.setattr(service_manager, "collect_process_tree_pids", lambda _pid: [111])
monkeypatch.setattr(service_manager, "_process_group_member_pids", lambda pgid: [333] if pgid == 222 and alive_group_members else [])
monkeypatch.setattr(service_manager, "port_owner_pids", lambda _port: [])
monkeypatch.setattr(service_manager, "pid_is_running", lambda pid: pid in alive_group_members)
monkeypatch.setattr(service_manager, "process_group_is_running", lambda pgid: bool(pgid == 222 and alive_group_members))
monkeypatch.setattr(service_manager.time, "sleep", lambda _delay: None)

def fake_signal_group(sig, pgid):
group_signals.append((sig, pgid))

def fake_signal_pid_list(sig, pids):
pid_list = list(pids)
pid_signals.append((sig, pid_list))
if sig == signal.SIGKILL and 333 in pid_list:
alive_group_members.clear()

monkeypatch.setattr(service_manager, "signal_process_group", fake_signal_group)
monkeypatch.setattr(service_manager, "signal_pid_list", fake_signal_pid_list)

service_manager.stop_one(8000, pid_file, "后端", console)

assert (signal.SIGTERM, 222) in group_signals
assert any(sig == signal.SIGKILL and 333 in pids for sig, pids in pid_signals)
assert not pid_file.exists()
assert console.messages[-1] == "[flocks] 后端 已停止。"


def test_stop_one_keeps_runtime_record_when_force_kill_still_times_out(monkeypatch, tmp_path: Path) -> None:
pid_file = tmp_path / "backend.pid"
service_manager.write_runtime_record(
pid_file,
service_manager.RuntimeRecord(pid=111, pgid=222, port=8000),
)
console = DummyConsole()

monkeypatch.setattr(service_manager.sys, "platform", "darwin")
monkeypatch.setattr(service_manager, "collect_process_tree_pids", lambda _pid: [111])
monkeypatch.setattr(service_manager, "_process_group_member_pids", lambda pgid: [333] if pgid == 222 else [])
monkeypatch.setattr(service_manager, "port_owner_pids", lambda _port: [])
monkeypatch.setattr(service_manager, "pid_is_running", lambda _pid: False)
monkeypatch.setattr(service_manager, "process_group_is_running", lambda pgid: pgid == 222)
monkeypatch.setattr(service_manager, "signal_process_group", lambda *_args: None)
monkeypatch.setattr(service_manager, "signal_pid_list", lambda *_args: None)
monkeypatch.setattr(service_manager.time, "sleep", lambda _delay: None)

with pytest.raises(service_manager.ServiceError, match="未在预期时间内退出"):
service_manager.stop_one(8000, pid_file, "后端", console)

assert pid_file.exists()


@contextlib.contextmanager
def _record_call(call_order: list[str], name: str):
call_order.append(name)
Expand Down
Loading