diff --git a/antarest/core/cache/business/local_chache.py b/antarest/core/cache/business/local_chache.py index a3275b5d7f..f6b544260a 100644 --- a/antarest/core/cache/business/local_chache.py +++ b/antarest/core/cache/business/local_chache.py @@ -20,11 +20,13 @@ class LocalCacheElement(BaseModel): class LocalCache(ICache): def __init__(self, config: CacheConfig = CacheConfig()): - self.cache: Dict[str, LocalCacheElement] = dict() + self.cache: Dict[str, LocalCacheElement] = {} self.lock = threading.Lock() self.checker_delay = config.checker_delay self.checker_thread = threading.Thread( - target=self.checker, daemon=True + target=self.checker, + name=self.__class__.__name__, + daemon=True, ) def start(self) -> None: diff --git a/antarest/core/config.py b/antarest/core/config.py index 7bf772ef19..86acdd8bfd 100644 --- a/antarest/core/config.py +++ b/antarest/core/config.py @@ -73,9 +73,9 @@ class WorkspaceConfig: def from_dict(data: JSON) -> "WorkspaceConfig": return WorkspaceConfig( path=Path(data["path"]), - groups=data.get("groups", list()), + groups=data.get("groups", []), filter_in=data.get("filter_in", [".*"]), - filter_out=data.get("filter_out", list()), + filter_out=data.get("filter_out", []), ) @@ -254,7 +254,7 @@ class LoggingConfig: @staticmethod def from_dict(data: JSON) -> "LoggingConfig": logging_config: Dict[str, Any] = data or {} - logfile: Optional[str] = logging_config.get("logfile", None) + logfile: Optional[str] = logging_config.get("logfile") return LoggingConfig( logfile=Path(logfile) if logfile is not None else None, json=logging_config.get("json", False), @@ -287,6 +287,7 @@ class EventBusConfig: Sub config object dedicated to eventbus module """ + # noinspection PyUnusedLocal @staticmethod def from_dict(data: JSON) -> "EventBusConfig": return EventBusConfig() @@ -298,7 +299,7 @@ class CacheConfig: Sub config object dedicated to cache module """ - checker_delay: float = 0.2 # in ms + checker_delay: float = 0.2 # in seconds @staticmethod def from_dict(data: JSON) -> "CacheConfig": diff --git a/antarest/core/interfaces/service.py b/antarest/core/interfaces/service.py index 3696e71e5f..7adacc6182 100644 --- a/antarest/core/interfaces/service.py +++ b/antarest/core/interfaces/service.py @@ -4,7 +4,11 @@ class IService(ABC): def __init__(self) -> None: - self.thread = threading.Thread(target=self._loop, daemon=True) + self.thread = threading.Thread( + target=self._loop, + name=self.__class__.__name__, + daemon=True, + ) def start(self, threaded: bool = True) -> None: if threaded: diff --git a/antarest/core/maintenance/service.py b/antarest/core/maintenance/service.py index ca1ad5a704..f252591013 100644 --- a/antarest/core/maintenance/service.py +++ b/antarest/core/maintenance/service.py @@ -33,7 +33,11 @@ def __init__( self._init() def _init(self) -> None: - self.thread = Thread(target=self.check_disk_usage, daemon=True) + self.thread = Thread( + target=self.check_disk_usage, + name=self.__class__.__name__, + daemon=True, + ) self.thread.start() def check_disk_usage(self) -> None: diff --git a/antarest/eventbus/service.py b/antarest/eventbus/service.py index 870859dc8e..8411547f8d 100644 --- a/antarest/eventbus/service.py +++ b/antarest/eventbus/service.py @@ -124,8 +124,11 @@ def _async_loop(self, new_loop: bool = True) -> None: def start(self, threaded: bool = True) -> None: if threaded: - t = threading.Thread(target=self._async_loop) - t.setDaemon(True) + t = threading.Thread( + target=self._async_loop, + name=self.__class__.__name__, + daemon=True, + ) logger.info("Starting event bus") t.start() else: diff --git a/antarest/launcher/adapters/local_launcher/local_launcher.py b/antarest/launcher/adapters/local_launcher/local_launcher.py index 4b1aa6349c..84a9b6fff3 100644 --- a/antarest/launcher/adapters/local_launcher/local_launcher.py +++ b/antarest/launcher/adapters/local_launcher/local_launcher.py @@ -84,6 +84,7 @@ def run_study( job_id, launcher_parameters, ), + name=f"{self.__class__.__name__}-JobRunner", ) job.start() @@ -144,23 +145,21 @@ def stop_reading_output() -> bool: stop_reading_output, None, ), + name=f"{self.__class__.__name__}-LogsWatcher", daemon=True, ) thread.start() - while True: - if process.poll() is not None: - break + while process.poll() is None: time.sleep(1) - if launcher_parameters is not None: - if ( - launcher_parameters.post_processing - or launcher_parameters.adequacy_patch is not None - ): - subprocess.run( - ["Rscript", "post-processing.R"], cwd=export_path - ) + if launcher_parameters is not None and ( + launcher_parameters.post_processing + or launcher_parameters.adequacy_patch is not None + ): + subprocess.run( + ["Rscript", "post-processing.R"], cwd=export_path + ) output_id: Optional[str] = None try: diff --git a/antarest/launcher/adapters/log_manager.py b/antarest/launcher/adapters/log_manager.py index a7d53b16f8..afac2a2da1 100644 --- a/antarest/launcher/adapters/log_manager.py +++ b/antarest/launcher/adapters/log_manager.py @@ -32,6 +32,7 @@ def track( target=lambda: self._follow( log_path, handler, self._stop_tracking(str(log_path)) ), + name=f"{self.__class__.__name__}-LogsWatcher", daemon=True, ) self.tracked_logs[str(log_path)] = thread diff --git a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py index 0b2233e62a..8cc2dc80bd 100644 --- a/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py +++ b/antarest/launcher/adapters/slurm_launcher/slurm_launcher.py @@ -155,7 +155,11 @@ def _loop(self) -> None: def start(self) -> None: logger.info("Starting slurm_launcher loop") self.check_state = True - self.thread = threading.Thread(target=self._loop, daemon=True) + self.thread = threading.Thread( + target=self._loop, + name=self.__class__.__name__, + daemon=True, + ) self.thread.start() def stop(self) -> None: @@ -610,6 +614,7 @@ def run_study( thread = threading.Thread( target=self._run_study, args=(study_uuid, job_id, launcher_parameters, version), + name=f"{self.__class__.__name__}-JobRunner", ) thread.start() diff --git a/antarest/study/storage/rawstudy/model/filesystem/root/filestudytree.py b/antarest/study/storage/rawstudy/model/filesystem/root/filestudytree.py index 39b38c0ea7..2c1ed018ce 100644 --- a/antarest/study/storage/rawstudy/model/filesystem/root/filestudytree.py +++ b/antarest/study/storage/rawstudy/model/filesystem/root/filestudytree.py @@ -66,7 +66,10 @@ def async_denormalize(self) -> Thread: logger.info( f"Denormalizing (async) study data for study {self.config.study_id}" ) - thread = Thread(target=self._threaded_denormalize) + thread = Thread( + target=self._threaded_denormalize, + name=f"{self.__class__.__name__}-Denormalizer", + ) thread.start() return thread diff --git a/antarest/study/storage/rawstudy/raw_study_service.py b/antarest/study/storage/rawstudy/raw_study_service.py index a5cbd0394a..c56ace5dc5 100644 --- a/antarest/study/storage/rawstudy/raw_study_service.py +++ b/antarest/study/storage/rawstudy/raw_study_service.py @@ -73,6 +73,7 @@ def __init__( self.path_resources: Path = path_resources self.cleanup_thread = Thread( target=RawStudyService.cleanup_lazynode_zipfilelist_cache, + name=f"{self.__class__.__name__}-Cleaner", daemon=True, ) self.cleanup_thread.start() diff --git a/antarest/worker/archive_worker.py b/antarest/worker/archive_worker.py index 9a09e15821..83d56d9c7d 100644 --- a/antarest/worker/archive_worker.py +++ b/antarest/worker/archive_worker.py @@ -41,6 +41,7 @@ def __init__( def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: logger.info(f"Executing task {task_info.json()}") try: + # sourcery skip: extract-method archive_args = ArchiveTaskArgs.parse_obj(task_info.task_args) dest = self.translate_path(Path(archive_args.dest)) src = self.translate_path(Path(archive_args.src)) diff --git a/antarest/worker/simulator_worker.py b/antarest/worker/simulator_worker.py index 5dc9a6ee2d..e419a54178 100644 --- a/antarest/worker/simulator_worker.py +++ b/antarest/worker/simulator_worker.py @@ -118,13 +118,12 @@ def stop_reading() -> bool: stop_reading, None, ), + name=f"{self.__class__.__name__}-TS-Generator", daemon=True, ) thread.start() - while True: - if process.poll() is not None: - break + while process.poll() is None: time.sleep(1) result.success = process.returncode == 0 diff --git a/antarest/worker/worker.py b/antarest/worker/worker.py index 701ee39b61..ced8b7d834 100644 --- a/antarest/worker/worker.py +++ b/antarest/worker/worker.py @@ -1,10 +1,8 @@ import logging import threading import time -from abc import abstractmethod -from concurrent.futures import Future, ThreadPoolExecutor -from threading import Thread -from typing import Dict, List, Union +from concurrent.futures import ThreadPoolExecutor, Future +from typing import Dict, List, Union, Any from antarest.core.interfaces.eventbus import Event, EventType, IEventBus from antarest.core.interfaces.service import IService @@ -28,23 +26,70 @@ class WorkerTaskCommand(BaseModel): task_args: Dict[str, Union[int, float, bool, str]] +class _WorkerTaskEndedCallback: + """ + Callback function which uses the event bus to notify + that the worker task is completed (or cancelled). + """ + + def __init__( + self, + event_bus: IEventBus, + task_id: str, + ) -> None: + self._event_bus = event_bus + self._task_id = task_id + + # NOTE: it seems that mypy has an issue with `concurrent.futures.Future`, + # for this reason we have annotated the `future` parameter with a string. + def __call__(self, future: "Future[Any]") -> None: + result = future.result() + event = Event( + type=EventType.WORKER_TASK_ENDED, + payload=WorkerTaskResult( + task_id=self._task_id, task_result=result + ), + # Use `NONE` for internal events + permissions=PermissionInfo(public_mode=PublicMode.NONE), + ) + self._event_bus.push(event) + + +# fixme: `AbstractWorker` should not inherit from `IService` class AbstractWorker(IService): def __init__( - self, name: str, event_bus: IEventBus, accept: List[str] + self, + name: str, + event_bus: IEventBus, + accept: List[str], ) -> None: - super(AbstractWorker, self).__init__() + super().__init__() + # fixme: `AbstractWorker` should not have any `thread` attribute + del self.thread self.name = name self.event_bus = event_bus - for task_type in accept: - self.event_bus.add_queue_consumer(self.listen_for_tasks, task_type) + self.accept = accept self.threadpool = ThreadPoolExecutor( - max_workers=MAX_WORKERS, thread_name_prefix="workertask_" + max_workers=MAX_WORKERS, + thread_name_prefix="worker_task_", ) - self.task_watcher = Thread(target=self._loop, daemon=True) self.lock = threading.Lock() - self.futures: Dict[str, Future[TaskResult]] = {} - async def listen_for_tasks(self, event: Event) -> None: + # fixme: `AbstractWorker.start` should not have any `threaded` parameter + def start(self, threaded: bool = True) -> None: + for task_type in self.accept: + self.event_bus.add_queue_consumer( + self._listen_for_tasks, task_type + ) + # Wait a short time to allow the event bus to have the opportunity + # to process the tasks as soon as possible + time.sleep(0.01) + + # fixme: `AbstractWorker` should not have any `_loop` function + def _loop(self) -> None: + pass + + async def _listen_for_tasks(self, event: Event) -> None: logger.info(f"Accepting new task {event.json()}") task_info = WorkerTaskCommand.parse_obj(event.payload) self.event_bus.push( @@ -56,11 +101,13 @@ async def listen_for_tasks(self, event: Event) -> None: ) ) with self.lock: - self.futures[task_info.task_id] = self.threadpool.submit( - self.safe_execute_task, task_info - ) + # fmt: off + future = self.threadpool.submit(self._safe_execute_task, task_info) + callback = _WorkerTaskEndedCallback(self.event_bus, task_info.task_id) + future.add_done_callback(callback) + # fmt: on - def safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: try: return self.execute_task(task_info) except Exception as e: @@ -70,27 +117,5 @@ def safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: ) return TaskResult(success=False, message=repr(e)) - @abstractmethod def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: raise NotImplementedError() - - def _loop(self) -> None: - while True: - with self.lock: - for task_id, future in list(self.futures.items()): - if future.done(): - self.event_bus.push( - Event( - type=EventType.WORKER_TASK_ENDED, - payload=WorkerTaskResult( - task_id=task_id, - task_result=future.result(), - ), - # Use `NONE` for internal events - permissions=PermissionInfo( - public_mode=PublicMode.NONE - ), - ) - ) - del self.futures[task_id] - time.sleep(2) diff --git a/tests/conftest.py b/tests/conftest.py index f8384538de..7eff60d0a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import sys import time -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from functools import wraps from pathlib import Path from typing import Any, Callable @@ -24,16 +24,17 @@ def project_path() -> Path: def with_db_context(f): @wraps(f) - def wrapper(*args, **kwds): + def wrapper(*args, **kwargs): engine = create_engine("sqlite:///:memory:", echo=True) Base.metadata.create_all(engine) + # noinspection PyTypeChecker DBSessionMiddleware( Mock(), custom_engine=engine, session_args={"autocommit": False, "autoflush": False}, ) with db(): - return f(*args, **kwds) + return f(*args, **kwargs) return wrapper @@ -80,10 +81,12 @@ def assert_study(a: SUB_JSON, b: SUB_JSON) -> None: _assert_others(a, b) -def autoretry_assert(func: Callable[..., bool], timeout: int) -> None: - threshold = datetime.utcnow() + timedelta(seconds=timeout) - while datetime.utcnow() < threshold: - if func(): +def auto_retry_assert( + predicate: Callable[..., bool], timeout: int = 2 +) -> None: + threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout) + while datetime.now(timezone.utc) < threshold: + if predicate(): return time.sleep(0.2) raise AssertionError() diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 415fd016f9..31ef310a1c 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -1,5 +1,6 @@ import datetime from pathlib import Path +import time from typing import Callable, List from unittest.mock import ANY, Mock, call @@ -271,6 +272,8 @@ def __init__( self.tmp_path = tmp_path def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + # simulate a "long" task ;-) + time.sleep(0.01) relative_path = task_info.task_args["file"] (self.tmp_path / relative_path).touch() return TaskResult(success=True, message="") diff --git a/tests/eventbus/test_service.py b/tests/eventbus/test_service.py index 39c18a9e47..542daa5bd3 100644 --- a/tests/eventbus/test_service.py +++ b/tests/eventbus/test_service.py @@ -5,7 +5,7 @@ from antarest.core.interfaces.eventbus import Event, EventType from antarest.core.model import PermissionInfo, PublicMode from antarest.eventbus.main import build_eventbus -from tests.conftest import autoretry_assert +from tests.conftest import auto_retry_assert def test_service_factory(): @@ -53,7 +53,7 @@ async def _append_to_bucket(event: Event): permissions=PermissionInfo(public_mode=PublicMode.READ), ) ) - autoretry_assert(lambda: len(test_bucket) == 3, 2) + auto_retry_assert(lambda: len(test_bucket) == 3, timeout=2) event_bus.remove_listener(lid1) event_bus.remove_listener(lid2) @@ -65,7 +65,7 @@ async def _append_to_bucket(event: Event): permissions=PermissionInfo(public_mode=PublicMode.READ), ) ) - autoretry_assert(lambda: len(test_bucket) == 0, 2) + auto_retry_assert(lambda: len(test_bucket) == 0, timeout=2) queue_name = "some work job" event_bus.add_queue_consumer(append_to_bucket(test_bucket), queue_name) @@ -80,4 +80,4 @@ async def _append_to_bucket(event: Event): ), queue_name, ) - autoretry_assert(lambda: len(test_bucket) == 1, 2) + auto_retry_assert(lambda: len(test_bucket) == 1, timeout=2) diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py index ef9d468261..12053b225a 100644 --- a/tests/worker/test_worker.py +++ b/tests/worker/test_worker.py @@ -1,3 +1,4 @@ +import time from pathlib import Path from typing import List from unittest.mock import MagicMock @@ -8,7 +9,7 @@ from antarest.core.tasks.model import TaskResult from antarest.eventbus.main import build_eventbus from antarest.worker.worker import AbstractWorker, WorkerTaskCommand -from tests.conftest import autoretry_assert +from tests.conftest import auto_retry_assert class DummyWorker(AbstractWorker): @@ -19,6 +20,8 @@ def __init__( self.tmp_path = tmp_path def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult: + # simulate a "long" task ;-) + time.sleep(0.01) relative_path = task_info.task_args["file"] (self.tmp_path / relative_path).touch() return TaskResult(success=True, message="") @@ -40,9 +43,25 @@ def test_simple_task(tmp_path: Path): task_queue, ) - assert not (tmp_path / "foo").exists() + # Add some listeners to debug the event bus notifications + msg = [] + async def notify(event: Event): + msg.append(event.type.value) + + event_bus.add_listener(notify, [EventType.WORKER_TASK_STARTED]) + event_bus.add_listener(notify, [EventType.WORKER_TASK_ENDED]) + + # Initialize and start a worker worker = DummyWorker(event_bus, [task_queue], tmp_path) - worker.start(threaded=True) + worker.start() + + # Wait for the end of the processing + # Set a big value to `timeout` if you want to debug the worker + auto_retry_assert(lambda: (tmp_path / "foo").exists(), timeout=60) + + # Wait a short time to allow the event bus to have the opportunity + # to process the notification of the end event. + time.sleep(0.01) - autoretry_assert(lambda: (tmp_path / "foo").exists(), 2) + assert msg == ["WORKER_TASK_STARTED", "WORKER_TASK_ENDED"]