Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 76 additions & 25 deletions flocks/channel/builtin/feishu/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import importlib
import json
import threading
import time
import uuid
from urllib.parse import urlsplit
from typing import Any, Awaitable, Callable, Optional
Expand Down Expand Up @@ -107,31 +108,50 @@ def do_without_validation(self, payload: bytes) -> None:

class _CompatWSClient:
def __init__(self) -> None:
self._client = native_client_cls(
app_id=app_id,
app_secret=app_secret,
log_level=lark.LogLevel.WARNING,
event_handler=_Dispatcher(),
domain=domain,
auto_reconnect=False,
)
self._client: Any | None = None
self._thread: Optional[threading.Thread] = None
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._receive_task: Optional[asyncio.Task] = None
self._ping_task: Optional[asyncio.Task] = None
self._start_error: Optional[BaseException] = None
self._stop_requested = False
self._finished = threading.Event()

def start(self) -> None:
self._finished.clear()
self._start_error = None
self._stop_requested = False

def _run() -> None:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
ws_module.loop = self._loop
self._client = native_client_cls(
app_id=app_id,
app_secret=app_secret,
log_level=lark.LogLevel.WARNING,
event_handler=_Dispatcher(),
domain=domain,
auto_reconnect=False,
)

original_ping_loop = getattr(self._client, "_ping_loop", None)
if callable(original_ping_loop):
async def _tracked_ping_loop() -> None:
self._ping_task = asyncio.current_task()
try:
await original_ping_loop()
finally:
self._ping_task = None

self._client._ping_loop = _tracked_ping_loop

async def _receive_message_loop() -> None:
self._receive_task = asyncio.current_task()
try:
while True:
if self._client is None:
return
if self._stop_requested and self._client._conn is None:
return
if self._client._conn is None:
Expand Down Expand Up @@ -160,6 +180,8 @@ async def _receive_message_loop() -> None:

self._client._receive_message_loop = _receive_message_loop
try:
if self._stop_requested:
return
self._client.start()
except RuntimeError as e:
if "Event loop stopped before Future completed" not in str(e):
Expand All @@ -175,42 +197,71 @@ async def _receive_message_loop() -> None:
daemon=True,
)
self._thread.start()
self._finished.wait(timeout=0.2)
deadline = time.monotonic() + 0.2
while self._client is None and not self._finished.is_set():
remaining = deadline - time.monotonic()
if remaining <= 0:
break
self._finished.wait(timeout=min(remaining, 0.01))
if self._start_error:
raise RuntimeError(str(self._start_error)) from self._start_error

def stop(self) -> None:
self._stop_requested = True
if self._client is None and self._thread and self._thread.is_alive():
deadline = time.monotonic() + 0.5
while self._client is None and not self._finished.is_set():
remaining = deadline - time.monotonic()
if remaining <= 0:
break
self._finished.wait(timeout=min(remaining, 0.01))
if self._loop is None:
if self._thread:
self._thread.join(timeout=5)
self._thread = None
return
self._stop_requested = True
loop_running = self._loop.is_running()

async def _drain_receive_task() -> None:
task = self._receive_task
async def _drain_task(task: Optional[asyncio.Task], timeout: float) -> None:
if task is None or task.done():
return
try:
await asyncio.wait_for(asyncio.shield(task), timeout=1.0)
await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
except asyncio.TimeoutError:
task.cancel()
with contextlib.suppress(asyncio.CancelledError, Exception):
await task

with contextlib.suppress(Exception):
future = asyncio.run_coroutine_threadsafe(
self._client._disconnect(),
self._loop,
)
future.result(timeout=5)
with contextlib.suppress(Exception):
future = asyncio.run_coroutine_threadsafe(
_drain_receive_task(),
self._loop,
)
future.result(timeout=2)
if loop_running and self._client is not None:
with contextlib.suppress(Exception):
future = asyncio.run_coroutine_threadsafe(
self._client._disconnect(),
self._loop,
)
future.result(timeout=5)
if loop_running:
with contextlib.suppress(Exception):
future = asyncio.run_coroutine_threadsafe(
_drain_task(self._receive_task, timeout=1.0),
self._loop,
)
future.result(timeout=2)
if loop_running:
with contextlib.suppress(Exception):
future = asyncio.run_coroutine_threadsafe(
_drain_task(self._ping_task, timeout=1.0),
self._loop,
)
future.result(timeout=2)
with contextlib.suppress(Exception):
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=5)
self._thread = None
self._loop = None
self._client = None
self._receive_task = None
self._ping_task = None

@property
def start_error(self) -> Optional[BaseException]:
Expand Down
146 changes: 146 additions & 0 deletions tests/channel/test_feishu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import asyncio
import json
import sys
import threading
import time
import types
from pathlib import Path
from unittest.mock import AsyncMock
Expand Down Expand Up @@ -670,6 +672,150 @@ def fake_import_module(name, package=None):
assert ws_client.start_error is None


def test_build_ws_client_initializes_legacy_client_on_worker_loop(monkeypatch, caplog) -> None:
captured: dict[str, object] = {}

class _LoopBoundConnection:
def __init__(self, client) -> None:
self._client = client

async def recv(self):
if asyncio.get_running_loop() is not self._client.bound_loop:
future = self._client.bound_loop.create_future()
return await future
while not self._client.closed:
await asyncio.sleep(0.01)
return b"ignored"

async def close(self) -> None:
self._client.closed = True

class _FakeClient:
def __init__(self, **kwargs):
captured.update(kwargs)
captured["client"] = self
self.bound_loop = asyncio.get_event_loop()
self.constructed_thread = threading.get_ident()
self._conn = None
self._auto_reconnect = kwargs["auto_reconnect"]
self.closed = False
self.disconnect_calls = 0

def start(self):
loop = asyncio.get_event_loop()
self._conn = _LoopBoundConnection(self)
loop.create_task(self._receive_message_loop())
loop.run_forever()

async def _handle_message(self, _msg):
return None

async def _disconnect(self):
self.disconnect_calls += 1
self.closed = True
self._conn = None

fake_lark = types.ModuleType("lark_oapi")
fake_lark.LogLevel = types.SimpleNamespace(WARNING="warning")
fake_ws_client = types.ModuleType("lark_oapi.ws.client")
fake_ws_client.Client = _FakeClient
fake_ws_client.loop = None

real_import_module = __import__("importlib").import_module

def fake_import_module(name, package=None):
if name == "lark_oapi":
return fake_lark
if name == "lark_oapi.ws.client":
return fake_ws_client
if name == "lark_oapi.adapter.websocket":
raise ImportError("legacy websocket adapter missing")
return real_import_module(name, package)

monkeypatch.setattr(
"flocks.channel.builtin.feishu.monitor.importlib.import_module",
fake_import_module,
)

ws_client = _build_ws_client(
app_id="app-id",
app_secret="app-secret",
event_handler=lambda _data: None,
domain="https://open.feishu.cn",
)

ws_client.start()
time.sleep(0.1)
ws_client.stop()

fake_client = captured["client"]
assert fake_client.constructed_thread != threading.get_ident()
assert fake_client.disconnect_calls >= 1
assert "attached to a different loop" not in caplog.text


def test_build_ws_client_stop_waits_for_legacy_init_and_skips_start(monkeypatch) -> None:
captured: dict[str, object] = {"start_calls": 0}

class _SlowClient:
def __init__(self, **kwargs):
time.sleep(0.35)
captured.update(kwargs)
captured["client"] = self
self._conn = None
self._auto_reconnect = kwargs["auto_reconnect"]
self.disconnect_calls = 0

def start(self):
captured["start_calls"] += 1
asyncio.get_event_loop().run_forever()

async def _disconnect(self):
self.disconnect_calls += 1
self._conn = None

fake_lark = types.ModuleType("lark_oapi")
fake_lark.LogLevel = types.SimpleNamespace(WARNING="warning")
fake_ws_client = types.ModuleType("lark_oapi.ws.client")
fake_ws_client.Client = _SlowClient
fake_ws_client.loop = None

real_import_module = __import__("importlib").import_module

def fake_import_module(name, package=None):
if name == "lark_oapi":
return fake_lark
if name == "lark_oapi.ws.client":
return fake_ws_client
if name == "lark_oapi.adapter.websocket":
raise ImportError("legacy websocket adapter missing")
return real_import_module(name, package)

monkeypatch.setattr(
"flocks.channel.builtin.feishu.monitor.importlib.import_module",
fake_import_module,
)

ws_client = _build_ws_client(
app_id="app-id",
app_secret="app-secret",
event_handler=lambda _data: None,
domain="https://open.feishu.cn",
)

ws_client.start()
assert ws_client._client is None

ws_client.stop()

fake_client = captured["client"]
assert captured["start_calls"] == 0
assert fake_client.disconnect_calls == 0
assert ws_client._thread is None
assert ws_client._loop is None
assert ws_client._client is None


@pytest.mark.asyncio
async def test_parse_reaction_event_falls_back_to_user_id(monkeypatch) -> None:
from flocks.channel.builtin.feishu.monitor import _parse_reaction_event
Expand Down