From 65747abcde1283356465cfc9836bd600ff354535 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Tue, 28 Jun 2022 21:50:35 +0200 Subject: [PATCH] feat(#1602): new rubrix dataset listeners (#1507, #1586, #1583, #1596) (cherry picked from commit b1658da3bb00232a78814167e0c539222767b802) - refactor(listener): support dynamic query parameters (cherry picked from commit 2f54f5bb67ce0d8e1c395a46c337fbbaf63ce8e0) - docs(listeners): add to python reference (cherry picked from commit 3e225fe3f508c1d1fd95a4f546143470032db9b0) - fix: include missing packages for rb.listener reference (cherry picked from commit 870e71df17b8c9937ba7211d0aaa4bf0e009b16e) --- docs/reference/python/index.rst | 2 + docs/reference/python/python_listeners.rst | 12 + environment_dev.yml | 2 +- environment_docs.yml | 2 +- pyproject.toml | 4 + src/rubrix/__init__.py | 2 + src/rubrix/client/api.py | 10 + src/rubrix/client/apis/metrics.py | 26 +++ src/rubrix/client/apis/searches.py | 70 ++++++ src/rubrix/client/sdk/client.py | 4 +- src/rubrix/listeners/__init__.py | 2 + src/rubrix/listeners/listener.py | 242 +++++++++++++++++++++ src/rubrix/listeners/models.py | 70 ++++++ tests/listeners/__init__.py | 0 tests/listeners/test_listener.py | 84 +++++++ 15 files changed, 528 insertions(+), 4 deletions(-) create mode 100644 docs/reference/python/python_listeners.rst create mode 100644 src/rubrix/client/apis/metrics.py create mode 100644 src/rubrix/client/apis/searches.py create mode 100644 src/rubrix/listeners/__init__.py create mode 100644 src/rubrix/listeners/listener.py create mode 100644 src/rubrix/listeners/models.py create mode 100644 tests/listeners/__init__.py create mode 100644 tests/listeners/test_listener.py diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst index 031b3e497b..7944ae5fdf 100644 --- a/docs/reference/python/index.rst +++ b/docs/reference/python/index.rst @@ -8,6 +8,7 @@ The python reference guide for Rubrix. 
This section contains: * :ref:`python_client`: The base client module * :ref:`python_metrics`: The module for dataset metrics * :ref:`python_labeling`: A toolbox to enhance your labeling workflow (weak labels, noisy labels, etc.) +* :ref:`python_listeners`: This module contains all you need to define and configure dataset rubrix listeners .. toctree:: :maxdepth: 2 @@ -17,3 +18,4 @@ The python reference guide for Rubrix. This section contains: python_client python_metrics python_labeling + python_listeners diff --git a/docs/reference/python/python_listeners.rst b/docs/reference/python/python_listeners.rst new file mode 100644 index 0000000000..e475ebc7ac --- /dev/null +++ b/docs/reference/python/python_listeners.rst @@ -0,0 +1,12 @@ +.. _python_listeners: + +Listeners +========= + +Here we describe the Rubrix listeners capabilities + +.. automodule:: rubrix.listeners + :members: listener, RBDatasetListener, Metrics, RBListenerContext, Search + + + diff --git a/environment_dev.yml b/environment_dev.yml index ee9f39278d..d254312228 100644 --- a/environment_dev.yml +++ b/environment_dev.yml @@ -45,4 +45,4 @@ dependencies: - transformers[torch]~=4.18.0 - loguru # install Rubrix in editable mode - - -e .[server] + - -e .[server,listeners] diff --git a/environment_docs.yml b/environment_docs.yml index cb81c27ded..ff97bcbcfc 100644 --- a/environment_docs.yml +++ b/environment_docs.yml @@ -105,4 +105,4 @@ dependencies: - webencodings==0.5.1 - wrapt==1.13.3 - zipp==3.7.0 - - -e . 
+ - -e ".[listeners]" diff --git a/pyproject.toml b/pyproject.toml index ea188309a0..da1f7409a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,10 @@ server = [ "hurry.filesize", # TODO: remove "psutil ~= 5.8.0", ] +listeners = [ + "schedule ~= 1.1.0", + "prodict ~= 0.8.0" +] [project.urls] homepage = "https://www.rubrix.ml" diff --git a/src/rubrix/__init__.py b/src/rubrix/__init__.py index e9592fd6c6..2651bca4e4 100644 --- a/src/rubrix/__init__.py +++ b/src/rubrix/__init__.py @@ -57,6 +57,7 @@ TokenClassificationSettings, configure_dataset, ) + from rubrix.listeners import Metrics, RBListenerContext, Search, listener from rubrix.monitoring.model_monitor import monitor from rubrix.server.server import app @@ -85,6 +86,7 @@ "read_pandas", ], "monitoring.model_monitor": ["monitor"], + "listeners.listener": ["listener", "RBListenerContext", "Search", "Metrics"], "datasets": [ "configure_dataset", "TextClassificationSettings", diff --git a/src/rubrix/client/api.py b/src/rubrix/client/api.py index 9fd0a45b83..c3e95f34db 100644 --- a/src/rubrix/client/api.py +++ b/src/rubrix/client/api.py @@ -30,6 +30,8 @@ RUBRIX_WORKSPACE_HEADER_NAME, ) from rubrix.client.apis.datasets import Datasets +from rubrix.client.apis.metrics import MetricsAPI +from rubrix.client.apis.searches import Searches from rubrix.client.datasets import ( Dataset, DatasetForText2Text, @@ -161,6 +163,14 @@ def client(self): def datasets(self) -> Datasets: return Datasets(client=self._client) + @property + def searches(self): + return Searches(client=self._client) + + @property + def metrics(self): + return MetricsAPI(client=self.client) + def set_workspace(self, workspace: str): """Sets the active workspace. 
class Searches(AbstractApi):
    """Client API for the dataset ``:search`` endpoint."""

    _API_URL_PATTERN = "/api/datasets/{name}/{task}:search"

    def search_records(
        self,
        name: str,
        task: TaskType,
        query: Optional[str],
        size: Optional[int] = None,
    ):
        """
        Searches records over a dataset

        Args:
            name: The dataset name
            task: The dataset task type
            query: The query string
            size: If provided, only the provided number of records will be fetched

        Returns:
            An instance of ``SearchResults`` class containing the search results

        Raises:
            ValueError: If ``task`` is not one of the supported task types
        """

        if task == TaskType.text_classification:
            record_class = TextClassificationRecord
        elif task == TaskType.token_classification:
            record_class = TokenClassificationRecord
        elif task == TaskType.text2text:
            record_class = Text2TextRecord
        else:
            raise ValueError(f"Task {task} not supported")

        url = self._API_URL_PATTERN.format(name=name, task=task)
        # Compare against None rather than truthiness: ``size=0`` is an explicit
        # request and must not be silently dropped.
        if size is not None:
            # Append only the query string; the previous code appended the whole
            # url to itself, producing a malformed path.
            url += f"?size={size}"

        query_request = {}
        if query:
            query_request["query_text"] = query

        response = self.__client__.post(
            path=url,
            json={"query": query_request},
        )

        return SearchResults(
            total=response["total"],
            records=[
                record_class.parse_obj(r).to_client() for r in response["records"]
            ],
        )
@dataclasses.dataclass
class RBDatasetListener:
    """
    The Rubrix dataset listener class

    Args:
        dataset: The dataset over which listener is created
        action: The action to execute when condition is satisfied
        metrics: A list of metrics ids that will be required in condition
        query: The query string to apply
        query_params: Defined parameters used dynamically in the provided query
        condition: The condition to satisfy to execute the action
        query_records: If ``False``, the records won't be passed as argument to the action.
            Default: ``True``
        interval_in_seconds: How often the listener is executed. Defaults to 30 seconds
    """

    _LOGGER = logging.getLogger(__name__)

    dataset: str
    action: ListenerAction
    metrics: Optional[List[str]] = None
    query: Optional[str] = None
    query_params: Optional[Dict[str, Any]] = None
    condition: Optional[ListenerCondition] = None
    query_records: bool = True
    interval_in_seconds: int = 30

    @property
    def formatted_query(self) -> Optional[str]:
        """Formatted query using defined query params, if any"""
        if self.query is None:
            return None
        return self.query.format(**(self.query_params or {}))

    __listener_job__: Optional[schedule.Job] = dataclasses.field(
        init=False, default=None
    )
    # Set by ``start``: the event that tells the polling thread to exit.
    __stop_schedule_event__ = None
    # Set by ``start``: the thread that polls the scheduler.
    __current_thread__ = None
    # NOTE(review): this attribute is unannotated, so it is a plain class attribute
    # and the same scheduler object is shared by every listener instance.
    __scheduler__ = schedule.Scheduler()

    def __post_init__(self):
        self.metrics = self.metrics or []
        self._validate()

    def _validate(self):
        """Fail fast if the query references a parameter missing from ``query_params``."""
        try:
            query = self.formatted_query
            if query:
                self._LOGGER.debug(f"Initial listener query {query}")
        except KeyError as kex:
            raise KeyError("Missing query parameter:", kex) from kex

    def is_running(self):
        """True if listener is running"""
        return self.__listener_job__ is not None

    def start(self, *action_args, **action_kwargs):
        """
        Start listening to changes in the dataset. Additionally, args and kwargs can be
        passed to the action by using the `action_*` arguments

        If the listener is already started, a ``ValueError`` will be raised

        """
        if self.is_running():
            raise ValueError("Listener is already running")

        self.__listener_job__ = self.__scheduler__.every(
            self.interval_in_seconds
        ).seconds.do(self.__listener_iteration_job__, *action_args, **action_kwargs)

        stop_event = threading.Event()
        # Poll one second ahead of the job interval, but never with a negative
        # timeout (intervals below one second used to crash the thread).
        poll_timeout = max(self.interval_in_seconds - 1, 0)

        def _poll_scheduler():
            self._LOGGER.debug("Running listener thread...")
            while not stop_event.is_set():
                self.__scheduler__.run_pending()
                # Waiting on the event (instead of sleeping) lets ``stop`` wake
                # the thread immediately rather than after a full poll period.
                stop_event.wait(poll_timeout)
            self._LOGGER.debug("Stopping listener thread...")

        self.__stop_schedule_event__ = stop_event
        self.__current_thread__ = threading.Thread(target=_poll_scheduler)
        self.__current_thread__.start()

    def stop(self):
        """
        Stops listener if it's still running.

        If listener is already stopped, a ``ValueError`` will be raised

        """
        if not self.is_running():
            raise ValueError("Listener is not running")

        self.__scheduler__.cancel_job(self.__listener_job__)
        self.__listener_job__ = None
        self.__stop_schedule_event__.set()
        # A thread cannot join itself: skip the join when ``stop`` is invoked
        # from inside the listener action.
        if self.__current_thread__ is not threading.current_thread():
            self.__current_thread__.join()

    def __listener_iteration_job__(self, *args, **kwargs):
        """
        Execute a complete listener iteration. The iteration consists on:

        1. Query data and fetch configured metrics
        2. Check search results and metrics with provided condition
        3. Execute the action if condition is satisfied

        """
        current_api = api.active_api()
        try:
            dataset = current_api.datasets.find_by_name(self.dataset)
            self._LOGGER.debug(f"Found listener dataset {dataset.name}")
        except NotFoundApiError:
            self._LOGGER.warning(f"Not found dataset <{self.dataset}>")
            return

        ctx = RBListenerContext(
            listener=self,
            query_params=self.query_params,
            metrics=self.__compute_metrics__(
                current_api, dataset, query=self.formatted_query
            ),
        )
        if self.condition is None:
            self._LOGGER.debug("No condition found! Running action...")
            return self.__run_action__(ctx, *args, **kwargs)

        # Only the total is needed to evaluate the condition, hence ``size=0``.
        search_results = current_api.searches.search_records(
            name=self.dataset, task=dataset.task, query=self.formatted_query, size=0
        )

        ctx.search = Search(total=search_results.total)
        # The condition receives the metrics only when some were configured.
        condition_args = [ctx.search]
        if self.metrics:
            condition_args.append(ctx.metrics)

        self._LOGGER.debug(f"Evaluate condition with arguments: {condition_args}")
        if self.condition(*condition_args):
            self._LOGGER.debug("Condition passed! Running action...")
            return self.__run_action__(ctx, *args, **kwargs)

    def __compute_metrics__(
        self, current_api, dataset, query: Optional[str]
    ) -> Metrics:
        """Fetch the summary of every configured metric for the given query."""
        metrics = {
            metric: current_api.metrics.metric_summary(
                name=self.dataset,
                task=dataset.task,
                metric=metric,
                query=query,
            )
            for metric in self.metrics
        }
        return Metrics.from_dict(metrics)

    def __run_action__(self, ctx: Optional[RBListenerContext] = None, *args, **kwargs):
        """
        Run the configured action, passing the records first (when ``query_records``
        is enabled) and then the listener context.

        Returns:
            The action result, or ``schedule.CancelJob`` if the action raised.
        """
        try:
            action_args = [ctx] if ctx else []
            if self.query_records:
                action_args.insert(
                    0,
                    rubrix.load(
                        name=self.dataset, query=self.formatted_query, as_pandas=False
                    ),
                )
            self._LOGGER.debug(f"Running action with arguments: {action_args}")
            return self.action(*args, *action_args, **kwargs)
        except Exception:
            # Catch ``Exception`` only, so SystemExit/KeyboardInterrupt are not
            # swallowed, and report through the logger with the full traceback.
            self._LOGGER.exception("Listener action failed. Cancelling listener job")
            # NOTE(review): ``__listener_job__`` is not reset here, so
            # ``is_running()`` keeps returning True until ``stop`` is called.
            return schedule.CancelJob
@dataclasses.dataclass
class RBListenerContext:
    """
    The Rubrix listener execution context. This class keeps the context components related to a listener

    Args:

        listener: The rubrix listener instance
        search: Search results for current execution
        metrics: Metrics results for current execution
        query_params: Dynamic parameters used in the listener query
    """

    # ``compare=False``: the attribute is removed in ``__post_init__``, so the
    # generated ``__eq__`` must not read it (it would raise AttributeError).
    listener: "RBDatasetListener" = dataclasses.field(
        repr=False, hash=False, compare=False
    )
    search: Optional[Search] = None
    metrics: Optional[Metrics] = None
    query_params: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Keep the listener reachable only through ``__listener__`` and the
        # computed properties below; the public ``listener`` attribute is dropped.
        self.__listener__ = self.listener
        del self.listener

    @property
    def dataset(self) -> str:
        """Computed property that returns the configured listener dataset name"""
        return self.__listener__.dataset

    @property
    def query(self) -> Optional[str]:
        """Computed property that returns the configured listener query string"""
        return self.__listener__.formatted_query
{param}", None, lambda q: True, {"param": 100}), + ], +) +def test_listener_with_parameters( + mocked_client, dataset, query, metrics, condition, query_params +): + rb.delete(dataset) + + class TestListener: + executed = False + error = None + + @listener( + dataset=dataset, + query=query, + metrics=metrics, + condition=condition, + execution_interval_in_seconds=1, + **(query_params or {}), + ) + def action(self, records: List[Record], ctx: RBListenerContext): + try: + assert ctx.dataset == dataset + if ctx.query_params: + assert ctx.query == query.format(**ctx.query_params) + if self.executed: + assert ctx.query_params != query_params + ctx.query_params["params"] += 1 + else: + assert ctx.query == query + + self.executed = True + + if metrics: + for metric in metrics: + assert metric in ctx.metrics + except Exception as error: + self.error = error + + test = TestListener() + test.action.start(test) + + time.sleep(1.5) + assert test.action.is_running() + rb.log(rb.TextClassificationRecord(text="This is a text"), name=dataset) + + with pytest.raises(ValueError): + test.action.start() + + time.sleep(1.5) + test.action.stop() + assert not test.action.is_running() + + with pytest.raises(ValueError): + test.action.stop() + + if condition: + res = condition(None, None) if metrics else condition(None) + if not res: + assert not test.executed, "Condition is False but action was executed" + + else: + assert test.executed + if test.error: + raise test.error