Skip to content

Commit

Permalink
fix(api): correct the way the task completion is notified to the even…
Browse files Browse the repository at this point in the history
…t bus (#1301)
  • Loading branch information
laurent-laporte-pro committed Feb 9, 2023
1 parent 12f4e92 commit b9cea1e
Show file tree
Hide file tree
Showing 17 changed files with 152 additions and 79 deletions.
6 changes: 4 additions & 2 deletions antarest/core/cache/business/local_chache.py
Expand Up @@ -20,11 +20,13 @@ class LocalCacheElement(BaseModel):

class LocalCache(ICache):
def __init__(self, config: CacheConfig = CacheConfig()):
self.cache: Dict[str, LocalCacheElement] = dict()
self.cache: Dict[str, LocalCacheElement] = {}
self.lock = threading.Lock()
self.checker_delay = config.checker_delay
self.checker_thread = threading.Thread(
target=self.checker, daemon=True
target=self.checker,
name=self.__class__.__name__,
daemon=True,
)

def start(self) -> None:
Expand Down
9 changes: 5 additions & 4 deletions antarest/core/config.py
Expand Up @@ -73,9 +73,9 @@ class WorkspaceConfig:
def from_dict(data: JSON) -> "WorkspaceConfig":
return WorkspaceConfig(
path=Path(data["path"]),
groups=data.get("groups", list()),
groups=data.get("groups", []),
filter_in=data.get("filter_in", [".*"]),
filter_out=data.get("filter_out", list()),
filter_out=data.get("filter_out", []),
)


Expand Down Expand Up @@ -254,7 +254,7 @@ class LoggingConfig:
@staticmethod
def from_dict(data: JSON) -> "LoggingConfig":
logging_config: Dict[str, Any] = data or {}
logfile: Optional[str] = logging_config.get("logfile", None)
logfile: Optional[str] = logging_config.get("logfile")
return LoggingConfig(
logfile=Path(logfile) if logfile is not None else None,
json=logging_config.get("json", False),
Expand Down Expand Up @@ -287,6 +287,7 @@ class EventBusConfig:
Sub config object dedicated to eventbus module
"""

# noinspection PyUnusedLocal
@staticmethod
def from_dict(data: JSON) -> "EventBusConfig":
return EventBusConfig()
Expand All @@ -298,7 +299,7 @@ class CacheConfig:
Sub config object dedicated to cache module
"""

checker_delay: float = 0.2 # in ms
checker_delay: float = 0.2 # in seconds

@staticmethod
def from_dict(data: JSON) -> "CacheConfig":
Expand Down
6 changes: 5 additions & 1 deletion antarest/core/interfaces/service.py
Expand Up @@ -4,7 +4,11 @@

class IService(ABC):
def __init__(self) -> None:
self.thread = threading.Thread(target=self._loop, daemon=True)
self.thread = threading.Thread(
target=self._loop,
name=self.__class__.__name__,
daemon=True,
)

def start(self, threaded: bool = True) -> None:
if threaded:
Expand Down
6 changes: 5 additions & 1 deletion antarest/core/maintenance/service.py
Expand Up @@ -33,7 +33,11 @@ def __init__(
self._init()

def _init(self) -> None:
self.thread = Thread(target=self.check_disk_usage, daemon=True)
self.thread = Thread(
target=self.check_disk_usage,
name=self.__class__.__name__,
daemon=True,
)
self.thread.start()

def check_disk_usage(self) -> None:
Expand Down
7 changes: 5 additions & 2 deletions antarest/eventbus/service.py
Expand Up @@ -124,8 +124,11 @@ def _async_loop(self, new_loop: bool = True) -> None:

def start(self, threaded: bool = True) -> None:
if threaded:
t = threading.Thread(target=self._async_loop)
t.setDaemon(True)
t = threading.Thread(
target=self._async_loop,
name=self.__class__.__name__,
daemon=True,
)
logger.info("Starting event bus")
t.start()
else:
Expand Down
21 changes: 10 additions & 11 deletions antarest/launcher/adapters/local_launcher/local_launcher.py
Expand Up @@ -84,6 +84,7 @@ def run_study(
job_id,
launcher_parameters,
),
name=f"{self.__class__.__name__}-JobRunner",
)
job.start()

Expand Down Expand Up @@ -144,23 +145,21 @@ def stop_reading_output() -> bool:
stop_reading_output,
None,
),
name=f"{self.__class__.__name__}-LogsWatcher",
daemon=True,
)
thread.start()

while True:
if process.poll() is not None:
break
while process.poll() is None:
time.sleep(1)

if launcher_parameters is not None:
if (
launcher_parameters.post_processing
or launcher_parameters.adequacy_patch is not None
):
subprocess.run(
["Rscript", "post-processing.R"], cwd=export_path
)
if launcher_parameters is not None and (
launcher_parameters.post_processing
or launcher_parameters.adequacy_patch is not None
):
subprocess.run(
["Rscript", "post-processing.R"], cwd=export_path
)

output_id: Optional[str] = None
try:
Expand Down
1 change: 1 addition & 0 deletions antarest/launcher/adapters/log_manager.py
Expand Up @@ -32,6 +32,7 @@ def track(
target=lambda: self._follow(
log_path, handler, self._stop_tracking(str(log_path))
),
name=f"{self.__class__.__name__}-LogsWatcher",
daemon=True,
)
self.tracked_logs[str(log_path)] = thread
Expand Down
7 changes: 6 additions & 1 deletion antarest/launcher/adapters/slurm_launcher/slurm_launcher.py
Expand Up @@ -155,7 +155,11 @@ def _loop(self) -> None:
def start(self) -> None:
logger.info("Starting slurm_launcher loop")
self.check_state = True
self.thread = threading.Thread(target=self._loop, daemon=True)
self.thread = threading.Thread(
target=self._loop,
name=self.__class__.__name__,
daemon=True,
)
self.thread.start()

def stop(self) -> None:
Expand Down Expand Up @@ -610,6 +614,7 @@ def run_study(
thread = threading.Thread(
target=self._run_study,
args=(study_uuid, job_id, launcher_parameters, version),
name=f"{self.__class__.__name__}-JobRunner",
)
thread.start()

Expand Down
Expand Up @@ -66,7 +66,10 @@ def async_denormalize(self) -> Thread:
logger.info(
f"Denormalizing (async) study data for study {self.config.study_id}"
)
thread = Thread(target=self._threaded_denormalize)
thread = Thread(
target=self._threaded_denormalize,
name=f"{self.__class__.__name__}-Denormalizer",
)
thread.start()
return thread

Expand Down
1 change: 1 addition & 0 deletions antarest/study/storage/rawstudy/raw_study_service.py
Expand Up @@ -73,6 +73,7 @@ def __init__(
self.path_resources: Path = path_resources
self.cleanup_thread = Thread(
target=RawStudyService.cleanup_lazynode_zipfilelist_cache,
name=f"{self.__class__.__name__}-Cleaner",
daemon=True,
)
self.cleanup_thread.start()
Expand Down
1 change: 1 addition & 0 deletions antarest/worker/archive_worker.py
Expand Up @@ -41,6 +41,7 @@ def __init__(
def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
logger.info(f"Executing task {task_info.json()}")
try:
# sourcery skip: extract-method
archive_args = ArchiveTaskArgs.parse_obj(task_info.task_args)
dest = self.translate_path(Path(archive_args.dest))
src = self.translate_path(Path(archive_args.src))
Expand Down
5 changes: 2 additions & 3 deletions antarest/worker/simulator_worker.py
Expand Up @@ -118,13 +118,12 @@ def stop_reading() -> bool:
stop_reading,
None,
),
name=f"{self.__class__.__name__}-TS-Generator",
daemon=True,
)
thread.start()

while True:
if process.poll() is not None:
break
while process.poll() is None:
time.sleep(1)

result.success = process.returncode == 0
Expand Down
101 changes: 63 additions & 38 deletions antarest/worker/worker.py
@@ -1,10 +1,8 @@
import logging
import threading
import time
from abc import abstractmethod
from concurrent.futures import Future, ThreadPoolExecutor
from threading import Thread
from typing import Dict, List, Union
from concurrent.futures import ThreadPoolExecutor, Future
from typing import Dict, List, Union, Any

from antarest.core.interfaces.eventbus import Event, EventType, IEventBus
from antarest.core.interfaces.service import IService
Expand All @@ -28,23 +26,70 @@ class WorkerTaskCommand(BaseModel):
task_args: Dict[str, Union[int, float, bool, str]]


class _WorkerTaskEndedCallback:
"""
Callback function which uses the event bus to notify
that the worker task is completed (or cancelled).
"""

def __init__(
self,
event_bus: IEventBus,
task_id: str,
) -> None:
self._event_bus = event_bus
self._task_id = task_id

# NOTE: it seems that mypy has an issue with `concurrent.futures.Future`,
# for this reason we have annotated the `future` parameter with a string.
def __call__(self, future: "Future[Any]") -> None:
result = future.result()
event = Event(
type=EventType.WORKER_TASK_ENDED,
payload=WorkerTaskResult(
task_id=self._task_id, task_result=result
),
# Use `NONE` for internal events
permissions=PermissionInfo(public_mode=PublicMode.NONE),
)
self._event_bus.push(event)


# fixme: `AbstractWorker` should not inherit from `IService`
class AbstractWorker(IService):
def __init__(
self, name: str, event_bus: IEventBus, accept: List[str]
self,
name: str,
event_bus: IEventBus,
accept: List[str],
) -> None:
super(AbstractWorker, self).__init__()
super().__init__()
# fixme: `AbstractWorker` should not have any `thread` attribute
del self.thread
self.name = name
self.event_bus = event_bus
for task_type in accept:
self.event_bus.add_queue_consumer(self.listen_for_tasks, task_type)
self.accept = accept
self.threadpool = ThreadPoolExecutor(
max_workers=MAX_WORKERS, thread_name_prefix="workertask_"
max_workers=MAX_WORKERS,
thread_name_prefix="worker_task_",
)
self.task_watcher = Thread(target=self._loop, daemon=True)
self.lock = threading.Lock()
self.futures: Dict[str, Future[TaskResult]] = {}

async def listen_for_tasks(self, event: Event) -> None:
# fixme: `AbstractWorker.start` should not have any `threaded` parameter
def start(self, threaded: bool = True) -> None:
for task_type in self.accept:
self.event_bus.add_queue_consumer(
self._listen_for_tasks, task_type
)
# Wait a short time to allow the event bus to have the opportunity
# to process the tasks as soon as possible
time.sleep(0.01)

# fixme: `AbstractWorker` should not have any `_loop` function
def _loop(self) -> None:
pass

async def _listen_for_tasks(self, event: Event) -> None:
logger.info(f"Accepting new task {event.json()}")
task_info = WorkerTaskCommand.parse_obj(event.payload)
self.event_bus.push(
Expand All @@ -56,11 +101,13 @@ async def listen_for_tasks(self, event: Event) -> None:
)
)
with self.lock:
self.futures[task_info.task_id] = self.threadpool.submit(
self.safe_execute_task, task_info
)
# fmt: off
future = self.threadpool.submit(self._safe_execute_task, task_info)
callback = _WorkerTaskEndedCallback(self.event_bus, task_info.task_id)
future.add_done_callback(callback)
# fmt: on

def safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
def _safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
try:
return self.execute_task(task_info)
except Exception as e:
Expand All @@ -70,27 +117,5 @@ def safe_execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
)
return TaskResult(success=False, message=repr(e))

@abstractmethod
def execute_task(self, task_info: WorkerTaskCommand) -> TaskResult:
raise NotImplementedError()

def _loop(self) -> None:
while True:
with self.lock:
for task_id, future in list(self.futures.items()):
if future.done():
self.event_bus.push(
Event(
type=EventType.WORKER_TASK_ENDED,
payload=WorkerTaskResult(
task_id=task_id,
task_result=future.result(),
),
# Use `NONE` for internal events
permissions=PermissionInfo(
public_mode=PublicMode.NONE
),
)
)
del self.futures[task_id]
time.sleep(2)
17 changes: 10 additions & 7 deletions tests/conftest.py
@@ -1,6 +1,6 @@
import sys
import time
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from functools import wraps
from pathlib import Path
from typing import Any, Callable
Expand All @@ -24,16 +24,17 @@ def project_path() -> Path:

def with_db_context(f):
@wraps(f)
def wrapper(*args, **kwds):
def wrapper(*args, **kwargs):
engine = create_engine("sqlite:///:memory:", echo=True)
Base.metadata.create_all(engine)
# noinspection PyTypeChecker
DBSessionMiddleware(
Mock(),
custom_engine=engine,
session_args={"autocommit": False, "autoflush": False},
)
with db():
return f(*args, **kwds)
return f(*args, **kwargs)

return wrapper

Expand Down Expand Up @@ -80,10 +81,12 @@ def assert_study(a: SUB_JSON, b: SUB_JSON) -> None:
_assert_others(a, b)


def autoretry_assert(func: Callable[..., bool], timeout: int) -> None:
threshold = datetime.utcnow() + timedelta(seconds=timeout)
while datetime.utcnow() < threshold:
if func():
def auto_retry_assert(
predicate: Callable[..., bool], timeout: int = 2
) -> None:
threshold = datetime.now(timezone.utc) + timedelta(seconds=timeout)
while datetime.now(timezone.utc) < threshold:
if predicate():
return
time.sleep(0.2)
raise AssertionError()

0 comments on commit b9cea1e

Please sign in to comment.