diff --git a/docs/guides/monitoring.ipynb b/docs/guides/monitoring.ipynb
index a406b9b519..b273581b40 100644
--- a/docs/guides/monitoring.ipynb
+++ b/docs/guides/monitoring.ipynb
@@ -216,6 +216,120 @@
     "dataset.map(make_prediction)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "6987c362-61b3-4682-aa7b-693bed30d3ae",
+   "metadata": {},
+   "source": [
+    "## Using `rb.log` in background mode\n",
+    "\n",
+    "You can monitor your own models without adding a response delay by using the `background` param of `rb.log`.\n",
+    "\n",
+    "Let's see an example using [BentoML](https://www.bentoml.com/) with a spaCy NER pipeline:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5a6e24b1-b665-4e44-be52-62e024927643",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import spacy\n",
+    "\n",
+    "nlp = spacy.load(\"en_core_web_sm\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "45740aee-e022-452d-8f77-fb8872eb7d11",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%writefile spacy_model.py\n",
+    "\n",
+    "\n",
+    "from bentoml import BentoService, api, artifacts, env\n",
+    "from bentoml.adapters import JsonInput\n",
+    "from bentoml.frameworks.spacy import SpacyModelArtifact\n",
+    "\n",
+    "import rubrix as rb\n",
+    "\n",
+    "\n",
+    "@env(infer_pip_packages=True)\n",
+    "@artifacts([SpacyModelArtifact(\"nlp\")])\n",
+    "class SpacyNERService(BentoService):\n",
+    "\n",
+    "    @api(input=JsonInput(), batch=True)\n",
+    "    def predict(self, parsed_json_list):\n",
+    "        result, rb_records = ([], [])\n",
+    "        for index, parsed_json in enumerate(parsed_json_list):\n",
+    "            doc = self.artifacts.nlp(parsed_json[\"text\"])\n",
+    "            prediction = [{\"entity\": ent.text, \"label\": ent.label_} for ent in doc.ents]\n",
+    "            rb_records.append(\n",
+    "                rb.TokenClassificationRecord(\n",
+    "                    text=doc.text,\n",
+    "                    tokens=[t.text for t in doc],\n",
+    "                    prediction=[\n",
+    "                        (ent.label_, ent.start_char, ent.end_char) for ent in doc.ents\n",
+    "                    ],\n",
+    "                )\n",
+    "            )\n",
+    "            result.append(prediction)\n",
+    "        \n",
+    "        rb.log(\n",
+    "            name=\"monitor-for-spacy-ner\",\n",
+    "            records=rb_records,\n",
+    "            tags={\"framework\": \"bentoml\"},\n",
+    "            background=True,\n",
+    "            verbose=False\n",
+    "        )  # With background=True, the model latency won't be affected\n",
+    "        \n",
+    "        return result\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "126b826f-c708-4b1d-898d-960058a7635d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from spacy_model import SpacyNERService\n",
+    "\n",
+    "svc = SpacyNERService()\n",
+    "svc.pack('nlp', nlp)\n",
+    "\n",
+    "saved_path = svc.save()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b5fadd39-87ed-4b7c-9f62-7cfdba0e2813",
+   "metadata": {},
+   "source": [
+    "You can predict some data without serving the model. Just launch the following command:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "69736fa1-e263-4086-9a53-a0be399b22a1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!bentoml run SpacyNERService:latest predict --input \"{\\\"text\\\":\\\"I am driving BMW\\\"}\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8a1fe7dd-3ef0-42c6-bb2b-25a52a95c582",
+   "metadata": {},
+   "source": [
+    "If you're running Rubrix locally, go to http://localhost:6900/datasets/rubrix/monitor-for-spacy-ner to see that the new dataset `monitor-for-spacy-ner` contains your data."
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "c71b49ea-7384-423d-b2d8-ac6280b7a200",
@@ -229,7 +343,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -243,9 +357,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.5"
+   "version": "3.8.12"
   }
  },
  "nbformat": 4,
 "nbformat_minor": 5
-}
+}
\ No newline at end of file
diff --git a/rublanding/rubrix-landing b/rublanding/rubrix-landing
deleted file mode 160000
index 342bc092df..0000000000
--- a/rublanding/rubrix-landing
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 342bc092dfca362d1f1fb89521d1f1f868ada7b2
diff --git a/src/rubrix/__init__.py b/src/rubrix/__init__.py
index 82865de196..761e429d3c 100644
--- a/src/rubrix/__init__.py
+++ b/src/rubrix/__init__.py
@@ -36,6 +36,7 @@
     init,
     load,
     log,
+    log_async,
     set_workspace,
 )
 from rubrix.client.datasets import (
@@ -62,6 +63,7 @@
         "init",
         "load",
         "log",
+        "log_async",
         "set_workspace",
     ],
     "client.models": [
diff --git a/src/rubrix/client/api.py b/src/rubrix/client/api.py
index 9fd10c2688..e1015b247d 100644
--- a/src/rubrix/client/api.py
+++ b/src/rubrix/client/api.py
@@ -12,12 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
 import logging
 import os
 import re
+from asyncio import Future
 from functools import wraps
 from inspect import signature
-from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import pandas
 from tqdm.auto import tqdm
@@ -42,8 +44,8 @@
     TokenClassificationRecord,
 )
 from rubrix.client.sdk.client import AuthenticatedClient
+from rubrix.client.sdk.commons.api import async_bulk, bulk
 from rubrix.client.sdk.commons.errors import RubrixClientError
-from rubrix.client.sdk.commons.models import Response
 from rubrix.client.sdk.datasets import api as datasets_api
 from rubrix.client.sdk.datasets.models import CopyDatasetRequest, TaskType
 from rubrix.client.sdk.metrics import api as metrics_api
@@ -70,15 +72,45 @@
 )
 from rubrix.client.sdk.users.api import whoami
 from rubrix.client.sdk.users.models import User
+from rubrix.utils import setup_loop_in_thread
 
 _LOGGER = logging.getLogger(__name__)
 
 _WARNED_ABOUT_AS_PANDAS = False
 
-# Larger sizes will trigger a warning
-_MAX_CHUNK_SIZE = 5000
+
+class _RubrixLogAgent:
+    def __init__(self, api: "Api"):
+        self.__api__ = api
+        self.__loop__, self.__thread__ = setup_loop_in_thread()
+
+    @staticmethod
+    async def __log_internal__(api: "Api", *args, **kwargs):
+
+        try:
+            return await api.log_async(*args, **kwargs)
+        except Exception as ex:
+            _LOGGER.error(
+                f"Cannot log data {args, kwargs}\n"
+                f"Error of type {type(ex)}: {ex} ({ex.args})"
+            )
+            raise ex
+
+    def log(self, *args, **kwargs) -> Future:
+        return asyncio.run_coroutine_threadsafe(
+            self.__log_internal__(self.__api__, *args, **kwargs), self.__loop__
+        )
+
+    def __del__(self):
+        # The loop runs in another thread, so stop it via call_soon_threadsafe
+        self.__loop__.call_soon_threadsafe(self.__loop__.stop)
+
+        del self.__loop__
+        del self.__thread__
 
 
 class Api:
+    # Larger sizes will trigger a warning
+    _MAX_CHUNK_SIZE = 5000
+
     def __init__(
         self,
         api_url: Optional[str] = None,
@@ -118,6 +150,13 @@ def __init__(
         if workspace is not None:
             self.set_workspace(workspace)
 
+        self._agent = _RubrixLogAgent(self)
+
+    @property
+    def client(self):
+        """The underlying authenticated client"""
+        return self._client
+
     def set_workspace(self, workspace: str):
         """Sets the active workspace.
 
@@ -179,8 +218,7 @@ def delete(self, name: str) -> None:
             >>> import rubrix as rb
             >>> rb.delete(name="example-dataset")
         """
-        response = datasets_api.delete_dataset(client=self._client, name=name)
-        self.check_response_errors(response)
+        datasets_api.delete_dataset(client=self._client, name=name)
 
     def log(
         self,
@@ -190,9 +228,12 @@ def log(
         metadata: Optional[Dict[str, Any]] = None,
         chunk_size: int = 500,
         verbose: bool = True,
-    ) -> BulkResponse:
+        background: bool = False,
+    ) -> Union[BulkResponse, Future]:
         """Logs Records to Rubrix.
 
+        The logging happens asynchronously in a background thread.
+
         Args:
             records: The record, an iterable of records, or a dataset to log.
             name: The dataset name.
@@ -200,9 +241,12 @@ def log(
             metadata: A dictionary of extra info for the dataset.
             chunk_size: The chunk size for a data bulk.
            verbose: If True, shows a progress bar and prints out a quick summary at the end.
+            background: If True, the logging does not block; the call returns an ``asyncio.Future``
+                right away. You probably want to set ``verbose`` to False in that case.
 
         Returns:
-            Summary of the response from the REST API
+            Summary of the response from the REST API.
+            If the ``background`` argument is set to True, an ``asyncio.Future`` will be returned instead.
 
         Examples:
             >>> import rubrix as rb
@@ -213,7 +257,58 @@ def log(
             >>> rb.log(record, name="example-dataset")
             1 records logged to http://localhost:6900/datasets/rubrix/example-dataset
             BulkResponse(dataset='example-dataset', processed=1, failed=0)
+            >>>
+            >>> # Logging records in the background
+            >>> rb.log(record, name="example-dataset", background=True, verbose=False)
+
+        """
+        future = self._agent.log(
+            records=records,
+            name=name,
+            tags=tags,
+            metadata=metadata,
+            chunk_size=chunk_size,
+            verbose=verbose,
+        )
+        if background:
+            return future
+        return future.result()
+
+    async def log_async(
+        self,
+        records: Union[Record, Iterable[Record], Dataset],
+        name: str,
+        tags: Optional[Dict[str, str]] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        chunk_size: int = 500,
+        verbose: bool = True,
+    ) -> BulkResponse:
+        """Logs Records to Rubrix with asyncio.
+
+        Args:
+            records: The record, an iterable of records, or a dataset to log.
+            name: The dataset name.
+            tags: A dictionary of tags related to the dataset.
+            metadata: A dictionary of extra info for the dataset.
+            chunk_size: The chunk size for a data bulk.
+            verbose: If True, shows a progress bar and prints out a quick summary at the end.
+
+        Returns:
+            Summary of the response from the REST API
+
+        Examples:
+            >>> # Log asynchronously from your notebook
+            >>> import asyncio
+            >>> import rubrix as rb
+            >>> from rubrix.utils import setup_loop_in_thread
+            >>> loop, _ = setup_loop_in_thread()
+            >>> future_response = asyncio.run_coroutine_threadsafe(
+            ...
rb.log_async(my_records, dataset_name), loop + ... ) """ + tags = tags or {} + metadata = metadata or {} + if not name: raise InputValueError("Empty dataset name has been passed as argument.") @@ -223,65 +318,51 @@ def log( "Please, use a valid name for your dataset" ) + if chunk_size > self._MAX_CHUNK_SIZE: + _LOGGER.warning( + """The introduced chunk size is noticeably large, timeout errors may occur. + Consider a chunk size smaller than %s""", + self._MAX_CHUNK_SIZE, + ) + if isinstance(records, Record.__args__): records = [records] - # this transforms a Dataset* to a list of *Record records = list(records) - tags = tags or {} - metadata = metadata or {} - try: record_type = type(records[0]) except IndexError: raise InputValueError("Empty record list has been passed as argument.") - if chunk_size > _MAX_CHUNK_SIZE: - _LOGGER.warning( - """The introduced chunk size is noticeably large, timeout errors may occur. - Consider a chunk size smaller than %s""", - _MAX_CHUNK_SIZE, - ) - if record_type is TextClassificationRecord: bulk_class = TextClassificationBulkData - bulk_records_function = text_classification_api.bulk - to_sdk_model = CreationTextClassificationRecord.from_client - + creation_class = CreationTextClassificationRecord elif record_type is TokenClassificationRecord: bulk_class = TokenClassificationBulkData - bulk_records_function = token_classification_api.bulk - to_sdk_model = CreationTokenClassificationRecord.from_client - + creation_class = CreationTokenClassificationRecord elif record_type is Text2TextRecord: bulk_class = Text2TextBulkData - bulk_records_function = text2text_api.bulk - to_sdk_model = CreationText2TextRecord.from_client - - # Record type is not recognised + creation_class = CreationText2TextRecord else: raise InputValueError( - f"Unknown record type passed as argument for [{','.join(map(str, records[0:5]))}...] " - f"Available values are {Record.__args__}" + f"Unknown record type {record_type}. Available values are {Record.__args__}" ) - processed = 0 - failed = 0 + processed, failed = 0, 0 progress_bar = tqdm(total=len(records), disable=not verbose) for i in range(0, len(records), chunk_size): chunk = records[i : i + chunk_size] - response = bulk_records_function( + response = await async_bulk( client=self._client, name=name, json_body=bulk_class( tags=tags, metadata=metadata, - records=[to_sdk_model(r) for r in chunk], + records=[creation_class.from_client(r) for r in chunk], ), ) - self.check_response_errors(response) processed += response.parsed.processed failed += response.parsed.failed @@ -290,13 +371,16 @@ def log( # TODO: improve logging policy in library if verbose: + _LOGGER.info( + f"Processed {processed} records in dataset {name}. Failed: {failed}" + ) workspace = self.get_workspace() if ( not workspace ): # Just for backward comp. 
with datasets with no workspaces workspace = "-" print( - f"{processed} records logged to {self._client.base_url + '/datasets/' + workspace + '/' + name}" + f"{processed} records logged to {self._client.base_url}/datasets/{workspace}/{name}" ) # Creating a composite BulkResponse with the total processed and failed @@ -329,7 +413,6 @@ def load( >>> dataset = rb.load(name="example-dataset") """ response = datasets_api.get_dataset(client=self._client, name=name) - self.check_response_errors(response) task = response.parsed.task task_config = { @@ -364,8 +447,6 @@ def load( limit=limit, ) - self.check_response_errors(response) - records = [sdk_record.to_client() for sdk_record in response.parsed] try: records_sorted_by_id = sorted(records, key=lambda x: x.id) @@ -389,12 +470,9 @@ def load( def dataset_metrics(self, name: str) -> List[MetricInfo]: response = datasets_api.get_dataset(self._client, name) - self.check_response_errors(response) - response = metrics_api.get_dataset_metrics( self._client, name=name, task=response.parsed.task ) - self.check_response_errors(response) return response.parsed @@ -413,7 +491,6 @@ def compute_metric( size: Optional[int] = None, ) -> MetricResults: response = datasets_api.get_dataset(self._client, name) - self.check_response_errors(response) metric_ = self.get_metric(name, metric=metric) assert metric_ is not None, f"Metric {metric} not found !!!" @@ -427,14 +504,13 @@ def compute_metric( interval=interval, size=size, ) - self.check_response_errors(response) + return MetricResults(**metric_.dict(), results=response.parsed) def fetch_dataset_labeling_rules(self, dataset: str) -> List[LabelingRule]: response = text_classification_api.fetch_dataset_labeling_rules( self._client, name=dataset ) - self.check_response_errors(response) return [LabelingRule.parse_obj(data) for data in response.parsed] @@ -444,55 +520,9 @@ def rule_metrics_for_dataset( response = text_classification_api.dataset_rule_metrics( self._client, name=dataset, query=rule.query, label=rule.label ) - self.check_response_errors(response) return LabelingRuleMetricsSummary.parse_obj(response.parsed) - @staticmethod - def check_response_errors(response: Response) -> None: - """Checks response status codes and raise corresponding error if found""" - - http_status = response.status_code - response_data = response.parsed - - if http_status == 401: - raise Exception( - "Unauthorized error: invalid credentials. The API answered with a {} code: {}".format( - http_status, response_data - ) - ) - - elif http_status == 403: - raise Exception( - "Forbidden error: you have not been authorised to access this dataset. " - "The API answered with a {} code: {}".format(http_status, response_data) - ) - - elif http_status == 404: - raise Exception( - "Not found error. The API answered with a {} code: {}".format( - http_status, response_data - ) - ) - - elif http_status == 422: - raise Exception( - "Unprocessable entity error: Something is wrong in your records. " - "The API answered with a {} code: {}".format(http_status, response_data) - ) - - elif 400 <= http_status < 500: - raise Exception( - "Request error: API cannot answer. " - "The API answered with a {} code: {}".format(http_status, response_data) - ) - - elif http_status >= 500: - raise Exception( - "Connection error: API is not responding. 
" - "The API answered with a {} code: {}".format(http_status, response_data) - ) - __ACTIVE_API__: Optional[Api] = None @@ -515,9 +545,17 @@ def api_wrapper(api_method: Callable): """ def decorator(func): - @wraps(api_method) - def wrapped_func(*args, **kwargs): - return func(*args, **kwargs) + if asyncio.iscoroutinefunction(api_method): + + @wraps(api_method) + async def wrapped_func(*args, **kwargs): + return await func(*args, **kwargs) + + else: + + @wraps(api_method) + def wrapped_func(*args, **kwargs): + return func(*args, **kwargs) sign = signature(api_method) wrapped_func.__signature__ = sign.replace( @@ -559,6 +597,11 @@ def log(*args, **kwargs): return active_api().log(*args, **kwargs) +@api_wrapper(Api.log_async) +def log_async(*args, **kwargs): + return active_api().log_async(*args, **kwargs) + + @api_wrapper(Api.load) def load(*args, **kwargs): return active_api().load(*args, **kwargs) diff --git a/src/rubrix/client/rubrix_client.py b/src/rubrix/client/rubrix_client.py index ee4febeddf..70ef6eb312 100644 --- a/src/rubrix/client/rubrix_client.py +++ b/src/rubrix/client/rubrix_client.py @@ -41,20 +41,19 @@ TokenClassificationRecord, ) from rubrix.client.sdk.client import AuthenticatedClient +from rubrix.client.sdk.commons.api import bulk from rubrix.client.sdk.commons.errors import RubrixClientError from rubrix.client.sdk.commons.models import Response from rubrix.client.sdk.datasets.api import copy_dataset, delete_dataset, get_dataset from rubrix.client.sdk.datasets.models import CopyDatasetRequest, TaskType from rubrix.client.sdk.metrics.api import compute_metric, get_dataset_metrics from rubrix.client.sdk.metrics.models import MetricInfo -from rubrix.client.sdk.text2text.api import bulk as text2text_bulk from rubrix.client.sdk.text2text.api import data as text2text_data from rubrix.client.sdk.text2text.models import ( CreationText2TextRecord, Text2TextBulkData, Text2TextQuery, ) -from rubrix.client.sdk.text_classification.api import bulk as text_classification_bulk from rubrix.client.sdk.text_classification.api import data as text_classification_data from rubrix.client.sdk.text_classification.api import ( dataset_rule_metrics, @@ -67,7 +66,6 @@ TextClassificationBulkData, TextClassificationQuery, ) -from rubrix.client.sdk.token_classification.api import bulk as token_classification_bulk from rubrix.client.sdk.token_classification.api import data as token_classification_data from rubrix.client.sdk.token_classification.models import ( CreationTokenClassificationRecord, @@ -194,17 +192,17 @@ def log( # Check record type if record_type is TextClassificationRecord: bulk_class = TextClassificationBulkData - bulk_records_function = text_classification_bulk + bulk_records_function = bulk to_sdk_model = CreationTextClassificationRecord.from_client elif record_type is TokenClassificationRecord: bulk_class = TokenClassificationBulkData - bulk_records_function = token_classification_bulk + bulk_records_function = bulk to_sdk_model = CreationTokenClassificationRecord.from_client elif record_type is Text2TextRecord: bulk_class = Text2TextBulkData - bulk_records_function = text2text_bulk + bulk_records_function = bulk to_sdk_model = CreationText2TextRecord.from_client # Record type is not recognised diff --git a/src/rubrix/client/sdk/commons/api.py b/src/rubrix/client/sdk/commons/api.py index d3074d6473..f48411bb94 100644 --- a/src/rubrix/client/sdk/commons/api.py +++ b/src/rubrix/client/sdk/commons/api.py @@ -30,6 +30,7 @@ import httpx +from rubrix.client.sdk.client import 
AuthenticatedClient from rubrix.client.sdk.commons.errors_handler import handle_response_error from rubrix.client.sdk.commons.models import ( BulkResponse, @@ -37,6 +38,56 @@ HTTPValidationError, Response, ) +from rubrix.client.sdk.text2text.models import Text2TextBulkData +from rubrix.client.sdk.text_classification.models import TextClassificationBulkData +from rubrix.client.sdk.token_classification.models import TokenClassificationBulkData + +_TASK_TO_ENDPOINT = { + TextClassificationBulkData: "TextClassification", + TokenClassificationBulkData: "TokenClassification", + Text2TextBulkData: "Text2Text", +} + + +def bulk( + client: AuthenticatedClient, + name: str, + json_body: Union[ + TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData + ], +) -> Response[BulkResponse]: + url = f"{client.base_url}/api/datasets/{name}/{_TASK_TO_ENDPOINT[type(json_body)]}:bulk" + + response = httpx.post( + url=url, + headers=client.get_headers(), + cookies=client.get_cookies(), + timeout=client.get_timeout(), + json=json_body.dict(by_alias=True), + ) + + return build_bulk_response(response, name=name, body=json_body) + + +async def async_bulk( + client: AuthenticatedClient, + name: str, + json_body: Union[ + TextClassificationBulkData, TokenClassificationBulkData, Text2TextBulkData + ], +) -> Response[BulkResponse]: + url = f"{client.base_url}/api/datasets/{name}/{_TASK_TO_ENDPOINT[type(json_body)]}:bulk" + + async with httpx.AsyncClient() as async_client: + response = await async_client.post( + url=url, + headers=client.get_headers(), + cookies=client.get_cookies(), + timeout=client.get_timeout(), + json=json_body.dict(by_alias=True), + ) + + return build_bulk_response(response, name=name, body=json_body) def build_bulk_response( diff --git a/src/rubrix/client/sdk/text2text/api.py b/src/rubrix/client/sdk/text2text/api.py index dfae036f89..b013229844 100644 --- a/src/rubrix/client/sdk/text2text/api.py +++ b/src/rubrix/client/sdk/text2text/api.py @@ -17,36 +17,9 @@ import httpx from rubrix.client.sdk.client import AuthenticatedClient -from rubrix.client.sdk.commons.api import build_bulk_response, build_data_response -from rubrix.client.sdk.commons.models import ( - BulkResponse, - ErrorMessage, - HTTPValidationError, - Response, -) -from rubrix.client.sdk.text2text.models import ( - Text2TextBulkData, - Text2TextQuery, - Text2TextRecord, -) - - -def bulk( - client: AuthenticatedClient, - name: str, - json_body: Text2TextBulkData, -) -> Response[Union[BulkResponse, ErrorMessage, HTTPValidationError]]: - url = "{}/api/datasets/{name}/Text2Text:bulk".format(client.base_url, name=name) - - response = httpx.post( - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=client.get_timeout(), - json=json_body.dict(by_alias=True), - ) - - return build_bulk_response(response, name=name, body=json_body) +from rubrix.client.sdk.commons.api import build_data_response +from rubrix.client.sdk.commons.models import ErrorMessage, HTTPValidationError, Response +from rubrix.client.sdk.text2text.models import Text2TextQuery, Text2TextRecord def data( diff --git a/src/rubrix/client/sdk/text_classification/api.py b/src/rubrix/client/sdk/text_classification/api.py index 31a91a2e15..ddae9bc648 100644 --- a/src/rubrix/client/sdk/text_classification/api.py +++ b/src/rubrix/client/sdk/text_classification/api.py @@ -18,46 +18,19 @@ from rubrix.client.sdk.client import AuthenticatedClient from rubrix.client.sdk.commons.api import ( - build_bulk_response, build_data_response, 
build_list_response, build_typed_response, ) -from rubrix.client.sdk.commons.models import ( - BulkResponse, - ErrorMessage, - HTTPValidationError, - Response, -) +from rubrix.client.sdk.commons.models import ErrorMessage, HTTPValidationError, Response from rubrix.client.sdk.text_classification.models import ( LabelingRule, LabelingRuleMetricsSummary, - TextClassificationBulkData, TextClassificationQuery, TextClassificationRecord, ) -def bulk( - client: AuthenticatedClient, - name: str, - json_body: TextClassificationBulkData, -) -> Response[BulkResponse]: - url = "{}/api/datasets/{name}/TextClassification:bulk".format( - client.base_url, name=name - ) - - response = httpx.post( - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=client.get_timeout(), - json=json_body.dict(by_alias=True), - ) - - return build_bulk_response(response, name=name, body=json_body) - - def data( client: AuthenticatedClient, name: str, diff --git a/src/rubrix/client/sdk/token_classification/api.py b/src/rubrix/client/sdk/token_classification/api.py index 6a6077008d..093d996caf 100644 --- a/src/rubrix/client/sdk/token_classification/api.py +++ b/src/rubrix/client/sdk/token_classification/api.py @@ -18,40 +18,14 @@ import httpx from rubrix.client.sdk.client import AuthenticatedClient -from rubrix.client.sdk.commons.api import build_bulk_response, build_data_response -from rubrix.client.sdk.commons.models import ( - BulkResponse, - ErrorMessage, - HTTPValidationError, - Response, -) +from rubrix.client.sdk.commons.api import build_data_response +from rubrix.client.sdk.commons.models import ErrorMessage, HTTPValidationError, Response from rubrix.client.sdk.token_classification.models import ( - TokenClassificationBulkData, TokenClassificationQuery, TokenClassificationRecord, ) -def bulk( - client: AuthenticatedClient, - name: str, - json_body: TokenClassificationBulkData, -) -> Response[Union[BulkResponse, ErrorMessage, HTTPValidationError]]: - url = "{}/api/datasets/{name}/TokenClassification:bulk".format( - client.base_url, name=name - ) - - response = httpx.post( - url=url, - headers=client.get_headers(), - cookies=client.get_cookies(), - timeout=client.get_timeout(), - json=json_body.dict(by_alias=True), - ) - - return build_bulk_response(response, name=name, body=json_body) - - def data( client: AuthenticatedClient, name: str, diff --git a/src/rubrix/monitoring/_flair.py b/src/rubrix/monitoring/_flair.py index add3ee23fe..4d52b0489a 100644 --- a/src/rubrix/monitoring/_flair.py +++ b/src/rubrix/monitoring/_flair.py @@ -1,16 +1,15 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, Union -import rubrix from rubrix import TokenClassificationRecord +from rubrix.client.models import BulkResponse from rubrix.monitoring.base import BaseMonitor from rubrix.monitoring.types import MissingType try: - + from flair import __version__ as _flair_version from flair.data import Sentence from flair.models import SequenceTagger - from flair import __version__ as _flair_version except ModuleNotFoundError: Sentence = MissingType SequenceTagger = MissingType @@ -18,26 +17,30 @@ class FlairMonitor(BaseMonitor): - def _log2rubrix(self, data: List[Tuple[Sentence, Dict[str, Any]]]): - records = [ - TokenClassificationRecord( - text=sentence.to_original_text(), - tokens=[token.text for token in sentence.tokens], - metadata=meta, - prediction_agent=self.agent, - event_timestamp=datetime.utcnow(), - prediction=[ - (label.value, label.span.start_pos, label.span.end_pos, 
label.score) - for label in sentence.get_labels(self.__model__.tag_type) - ], - ) - for sentence, meta in data - ] - - rubrix.log( - records, + def _prepare_log_data( + self, data: List[Tuple[Sentence, Dict[str, Any]]] + ) -> Dict[str, Any]: + return dict( + records=[ + TokenClassificationRecord( + text=sentence.to_original_text(), + tokens=[token.text for token in sentence.tokens], + metadata=meta, + prediction_agent=self.agent, + event_timestamp=datetime.utcnow(), + prediction=[ + ( + label.value, + label.span.start_pos, + label.span.end_pos, + label.score, + ) + for label in sentence.get_labels(self.__model__.tag_type) + ], + ) + for sentence, meta in data + ], name=self.dataset, - verbose=False, tags={**(self.tags or {}), "flair_version": _flair_version}, ) @@ -57,7 +60,7 @@ def predict(self, sentences: Union[List[Sentence], Sentence], *args, **kwargs): if self.is_record_accepted() ] if filtered_data: - self.log_async(filtered_data) + self._log_future = self.log_async(filtered_data) return result diff --git a/src/rubrix/monitoring/_spacy.py b/src/rubrix/monitoring/_spacy.py index 7f1f96e882..7e42f0e959 100644 --- a/src/rubrix/monitoring/_spacy.py +++ b/src/rubrix/monitoring/_spacy.py @@ -1,7 +1,6 @@ from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple -import rubrix as rb from rubrix import TokenClassificationRecord from rubrix.monitoring.base import BaseMonitor from rubrix.monitoring.types import MissingType @@ -45,22 +44,27 @@ def doc2token_classification( event_timestamp=datetime.utcnow(), ) - def _log2rubrix(self, doc: Doc, metadata: Optional[Dict[str, Any]] = None): - record = self.doc2token_classification( - doc, agent=self.__wrapped__.path.name, metadata=metadata - ) - rb.log( - record, + def _prepare_log_data( + self, docs_info: Tuple[Doc, Optional[Dict[str, Any]]] + ) -> Dict[str, Any]: + + return dict( + records=[ + self.doc2token_classification( + doc, agent=self.__wrapped__.path.name, metadata=metadata + ) + for doc, metadata in docs_info + ], name=self.dataset, - tags={k: v for k, v in self.__wrapped__.meta.items() if isinstance(v, str)}, + tags={k: v for k, v in self.__model__.meta.items() if isinstance(v, str)}, metadata=self.__model__.meta, - verbose=False, ) def pipe(self, *args, **kwargs): as_tuples = kwargs.get("as_tuples") results = self.__model__.pipe(*args, **kwargs) + log_info = [] for r in results: metadata = {} if as_tuples: @@ -68,15 +72,17 @@ def pipe(self, *args, **kwargs): else: doc = r if self.is_record_accepted(): - self.log_async(doc, metadata) + log_info.append((doc, metadata)) yield r + self.log_async(log_info) + def __call__(self, *args, **kwargs): metadata = kwargs.pop("metadata", None) doc = self.__wrapped__(*args, **kwargs) try: if self.is_record_accepted(): - self.log_async(doc, metadata) + self.log_async([(doc, metadata)]) finally: return doc diff --git a/src/rubrix/monitoring/_transformers.py b/src/rubrix/monitoring/_transformers.py index 31e8702c47..49bca11cfa 100644 --- a/src/rubrix/monitoring/_transformers.py +++ b/src/rubrix/monitoring/_transformers.py @@ -3,13 +3,13 @@ from pydantic import BaseModel -import rubrix +import rubrix as rb from rubrix import TextClassificationRecord +from rubrix.client.models import BulkResponse from rubrix.monitoring.base import BaseMonitor from rubrix.monitoring.types import MissingType try: - from transformers import ( Pipeline, TextClassificationPipeline, @@ -34,16 +34,19 @@ def fetch_transformers_version(config) -> Optional[str]: config_dict = 
config.to_dict() return config_dict.get("transformers_version", _transformers_version) - def _log2rubrix( + @property + def model_config(self): + return self.__model__.model.config + + def _prepare_log_data( self, data: List[Tuple[str, Dict[str, Any], List[LabelPrediction]]], multi_label: bool = False, - ): - """Register a list of tuples including inputs and its predictions for text classification task""" - records = [] - config = self.__model__.model.config - agent = config.name_or_path + ) -> Dict[str, Any]: + + agent = self.model_config.name_or_path + records = [] for input_, metadata, predictions in data: record = TextClassificationRecord( text=input_ if isinstance(input_, str) else None, @@ -62,19 +65,20 @@ def _log2rubrix( if multi_label: dataset_name += "_multi" - rubrix.log( - records, + return dict( + records=records, name=dataset_name, tags={ - "name": config.name_or_path, - "transformers_version": self.fetch_transformers_version(config), - "model_type": config.model_type, + "name": self.model_config.name_or_path, + "transformers_version": self.fetch_transformers_version( + self.model_config + ), + "model_type": self.model_config.model_type, "task": self.__model__.task, }, - metadata=config.to_dict(), + metadata=self.model_config.to_dict(), verbose=False, ) - pass class ZeroShotMonitor(HuggingFaceMonitor): diff --git a/src/rubrix/monitoring/base.py b/src/rubrix/monitoring/base.py index d7d6ec1598..dc4c8036f1 100644 --- a/src/rubrix/monitoring/base.py +++ b/src/rubrix/monitoring/base.py @@ -1,19 +1,11 @@ import asyncio import random -from typing import Dict, Optional +import threading +from typing import Any, Dict, Optional import wrapt -from rubrix.monitoring.helpers import start_loop_in_thread - -_LOGGING_LOOP = None - - -def _get_current_loop(): - global _LOGGING_LOOP - if not _LOGGING_LOOP: - _LOGGING_LOOP = start_loop_in_thread() - return _LOGGING_LOOP +import rubrix class ModelNotSupportedError(Exception): @@ -26,13 +18,10 @@ class BaseMonitor(wrapt.ObjectProxy): Attributes: ----------- - dataset: Rubrix dataset name - sample_rate: The portion of the data to store in Rubrix. 
Default = 0.2
-
     """
 
     def __init__(
@@ -65,14 +54,21 @@ def is_record_accepted(self) -> bool:
         """Return True if a record should be logged to rubrix"""
         return random.uniform(0.0, 1.0) <= self.sample_rate
 
-    def _log2rubrix(self, *args, **kwargs):
+    def _prepare_log_data(self, *args, **kwargs) -> Dict[str, Any]:
         raise NotImplementedError()
 
     def log_async(self, *args, **kwargs):
-        wrapped_func = self._log2rubrix
-        loop = _get_current_loop()
-
-        async def f():
-            return wrapped_func(*args, **kwargs)
-
-        asyncio.run_coroutine_threadsafe(f(), loop)
+        log_args = self._prepare_log_data(*args, **kwargs)
+        log_args.pop("verbose", None)
+        log_args.pop("background", None)
+        return rubrix.log(**log_args, verbose=False, background=True)
+
+    def _start_event_loop_if_needed(self):
+        """Recreate loop/thread if needed"""
+        if self._event_loop is None:
+            self._event_loop = asyncio.new_event_loop()
+        if self._event_loop_thread is None or not self._event_loop_thread.is_alive():
+            self._event_loop_thread = threading.Thread(
+                target=self._event_loop.run_forever, daemon=True
+            )
+            self._event_loop_thread.start()
diff --git a/src/rubrix/monitoring/helpers.py b/src/rubrix/monitoring/helpers.py
deleted file mode 100644
index 944ddad167..0000000000
--- a/src/rubrix/monitoring/helpers.py
+++ /dev/null
@@ -1,14 +0,0 @@
-def start_loop_in_thread():
-    """Launches a asyncio loop in a different thread and start it"""
-    from threading import Thread
-    import asyncio
-
-    def start_loop(loop):
-        asyncio.set_event_loop(loop)
-        loop.run_forever()
-
-    new_loop = asyncio.new_event_loop()
-    t = Thread(target=start_loop, args=(new_loop,), daemon=True)
-    t.start()
-
-    return new_loop
diff --git a/src/rubrix/server/commons/es_wrapper.py b/src/rubrix/server/commons/es_wrapper.py
index ed351eb407..f07a1f44eb 100644
--- a/src/rubrix/server/commons/es_wrapper.py
+++ b/src/rubrix/server/commons/es_wrapper.py
@@ -69,6 +69,9 @@ def get_instance(cls) -> "ElasticsearchWrapper":
         es_client = OpenSearch(
             hosts=settings.elasticsearch,
             verify_certs=settings.elasticsearch_ssl_verify,
+            # Extra args to es configuration -> TODO: extensible by settings
+            retry_on_timeout=True,
+            max_retries=5,
         )
 
         cls._INSTANCE = cls(es_client)
@@ -198,6 +201,7 @@ def create_index(
         self.__client__.indices.create(
             index=index,
             body={"settings": settings or {}, "mappings": mappings or {}},
+            ignore=400,
         )
 
     def create_index_template(
@@ -418,7 +422,11 @@ def update_document(
         """
         if partial_update:
             self.__client__.update(
-                index=index, id=doc_id, body={"doc": document}, refresh=True
+                index=index,
+                id=doc_id,
+                body={"doc": document},
+                refresh=True,
+                retry_on_conflict=500,  # TODO: configurable
             )
         else:
             self.__client__.index(index=index, id=doc_id, body=document, refresh=True)
diff --git a/src/rubrix/utils.py b/src/rubrix/utils.py
index 61d5d3945f..269b2362d7 100644
--- a/src/rubrix/utils.py
+++ b/src/rubrix/utils.py
@@ -12,12 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio import importlib import os +import threading import warnings from itertools import chain from types import ModuleType -from typing import Any, Optional +from typing import Any, Optional, Tuple class _LazyRubrixModule(ModuleType): @@ -115,3 +117,16 @@ def _get_module( def __reduce__(self): return self.__class__, (self._name, self.__file__, self._import_structure) + + +def setup_loop_in_thread() -> Tuple[asyncio.AbstractEventLoop, threading.Thread]: + """Sets up a new asyncio event loop in a new thread, and runs it forever. + + Returns: + A tuple containing the event loop and the thread. + """ + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + + return loop, thread diff --git a/tests/client/sdk/commons/api.py b/tests/client/sdk/commons/api.py index 792ca22dfe..a03696fd6f 100644 --- a/tests/client/sdk/commons/api.py +++ b/tests/client/sdk/commons/api.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import httpx import pytest from httpx import Response as HttpxResponse -from rubrix.client.sdk.commons.api import build_bulk_response, build_data_response +from rubrix.client.sdk.commons.api import build_bulk_response, build_data_response, bulk from rubrix.client.sdk.commons.models import ( BulkResponse, ErrorMessage, @@ -26,6 +27,39 @@ from rubrix.client.sdk.text_classification.models import TextClassificationRecord +def test_text2text_bulk(sdk_client, mocked_client, bulk_text2text_data, monkeypatch): + monkeypatch.setattr(httpx, "post", mocked_client.post) + + dataset_name = "test_dataset" + mocked_client.delete(f"/api/datasets/{dataset_name}") + response = bulk(sdk_client, name=dataset_name, json_body=bulk_text2text_data) + + assert response.status_code == 200 + assert isinstance(response.parsed, BulkResponse) + + +def test_textclass_bulk(sdk_client, mocked_client, bulk_textclass_data, monkeypatch): + monkeypatch.setattr(httpx, "post", mocked_client.post) + + dataset_name = "test_dataset" + mocked_client.delete(f"/api/datasets/{dataset_name}") + response = bulk(sdk_client, name=dataset_name, json_body=bulk_textclass_data) + + assert response.status_code == 200 + assert isinstance(response.parsed, BulkResponse) + + +def test_tokenclass_bulk(sdk_client, mocked_client, bulk_tokenclass_data, monkeypatch): + monkeypatch.setattr(httpx, "post", mocked_client.post) + + dataset_name = "test_dataset" + mocked_client.delete(f"/api/datasets/{dataset_name}") + response = bulk(sdk_client, name=dataset_name, json_body=bulk_tokenclass_data) + + assert response.status_code == 200 + assert isinstance(response.parsed, BulkResponse) + + @pytest.mark.parametrize( "status_code, expected", [ diff --git a/tests/client/sdk/conftest.py b/tests/client/sdk/conftest.py index e9a29309ec..e7e8627368 100644 --- a/tests/client/sdk/conftest.py +++ b/tests/client/sdk/conftest.py @@ -12,11 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from datetime import datetime import pytest +import rubrix as rb from rubrix._constants import DEFAULT_API_KEY from rubrix.client.sdk.client import AuthenticatedClient +from rubrix.client.sdk.text2text.models import ( + CreationText2TextRecord, + Text2TextBulkData, +) +from rubrix.client.sdk.text_classification.models import ( + CreationTextClassificationRecord, + TextClassificationBulkData, +) +from rubrix.client.sdk.token_classification.models import ( + CreationTokenClassificationRecord, + TokenClassificationBulkData, +) class Helpers: @@ -45,3 +59,81 @@ def helpers(): @pytest.fixture(scope="session") def sdk_client(): return AuthenticatedClient(base_url="http://localhost:6900", token=DEFAULT_API_KEY) + + +@pytest.fixture +def bulk_textclass_data(): + explanation = { + "text": [rb.TokenAttributions(token="test", attributions={"test": 0.5})] + } + records = [ + rb.TextClassificationRecord( + text="test", + prediction=[("test", 0.5)], + prediction_agent="agent", + annotation="test1", + annotation_agent="agent", + multi_label=False, + explanation=explanation, + id=i, + metadata={"mymetadata": "str"}, + event_timestamp=datetime(2020, 1, 1), + status="Validated", + ) + for i in range(3) + ] + + return TextClassificationBulkData( + records=[CreationTextClassificationRecord.from_client(rec) for rec in records], + tags={"Mytag": "tag"}, + metadata={"MyMetadata": 5}, + ) + + +@pytest.fixture +def bulk_text2text_data(): + records = [ + rb.Text2TextRecord( + text="test", + prediction=[("prueba", 0.5), ("intento", 0.5)], + prediction_agent="agent", + annotation="prueba", + annotation_agent="agent", + id=i, + metadata={"mymetadata": "str"}, + event_timestamp=datetime(2020, 1, 1), + status="Validated", + ) + for i in range(3) + ] + + return Text2TextBulkData( + records=[CreationText2TextRecord.from_client(rec) for rec in records], + tags={"Mytag": "tag"}, + metadata={"MyMetadata": 5}, + ) + + +@pytest.fixture +def bulk_tokenclass_data(): + records = [ + rb.TokenClassificationRecord( + text="a raw text", + tokens=["a", "raw", "text"], + prediction=[("test", 2, 5, 0.9)], + prediction_agent="agent", + annotation=[("test", 2, 5)], + annotation_agent="agent", + id=i, + metadata={"mymetadata": "str"}, + event_timestamp=datetime(2020, 1, 1), + status="Validated", + ) + for i in range(3) + ] + + return TokenClassificationBulkData( + records=[CreationTokenClassificationRecord.from_client(rec) for rec in records], + tags={"Mytag": "tag"}, + metadata={"MyMetadata": 5}, + ) diff --git a/tests/client/sdk/text2text/test_api.py b/tests/client/sdk/text2text/test_api.py index 1e0478a22d..402e6d8ef5 100644 --- a/tests/client/sdk/text2text/test_api.py +++ b/tests/client/sdk/text2text/test_api.py @@ -12,58 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from datetime import datetime - import httpx import pytest -from rubrix.client.models import Text2TextRecord as ClientText2TextRecord -from rubrix.client.sdk.commons.models import BulkResponse -from rubrix.client.sdk.text2text.api import bulk, data -from rubrix.client.sdk.text2text.models import ( - CreationText2TextRecord, - Text2TextBulkData, - Text2TextRecord, -) - - -@pytest.fixture -def bulk_data(): - records = [ - ClientText2TextRecord( - text="test", - prediction=[("prueba", 0.5), ("intento", 0.5)], - prediction_agent="agent", - annotation="prueba", - annotation_agent="agent", - id=i, - metadata={"mymetadata": "str"}, - event_timestamp=datetime(2020, 1, 1), - status="Validated", - ) - for i in range(3) - ] - - return Text2TextBulkData( - records=[CreationText2TextRecord.from_client(rec) for rec in records], - tags={"Mytag": "tag"}, - metadata={"MyMetadata": 5}, - ) - - -def test_bulk(sdk_client, mocked_client, bulk_data, monkeypatch): - monkeypatch.setattr(httpx, "post", mocked_client.post) - - dataset_name = "test_dataset" - mocked_client.delete(f"/api/datasets/{dataset_name}") - response = bulk(sdk_client, name=dataset_name, json_body=bulk_data) - - assert response.status_code == 200 - assert isinstance(response.parsed, BulkResponse) +from rubrix.client.sdk.text2text.api import data +from rubrix.client.sdk.text2text.models import Text2TextRecord @pytest.mark.parametrize("limit,expected", [(None, 3), (2, 2)]) -def test_data(limit, mocked_client, expected, sdk_client, bulk_data, monkeypatch): +def test_data( + limit, mocked_client, expected, sdk_client, bulk_text2text_data, monkeypatch +): # TODO: Not sure how to test the streaming part of the response here monkeypatch.setattr(httpx, "stream", mocked_client.stream) @@ -71,7 +30,7 @@ def test_data(limit, mocked_client, expected, sdk_client, bulk_data, monkeypatch mocked_client.delete(f"/api/datasets/{dataset_name}") mocked_client.post( f"/api/datasets/{dataset_name}/Text2Text:bulk", - json=bulk_data.dict(by_alias=True), + json=bulk_text2text_data.dict(by_alias=True), ) response = data(sdk_client, name=dataset_name, limit=limit) diff --git a/tests/client/sdk/text_classification/test_api.py b/tests/client/sdk/text_classification/test_api.py index 3311d212b5..9d8e8bb6ac 100644 --- a/tests/client/sdk/text_classification/test_api.py +++ b/tests/client/sdk/text_classification/test_api.py @@ -12,66 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from datetime import datetime - import httpx import pytest -from rubrix.client.models import ( - TextClassificationRecord as ClientTextClassificationRecord, -) -from rubrix.client.models import TokenAttributions -from rubrix.client.sdk.commons.models import BulkResponse -from rubrix.client.sdk.text_classification.api import bulk, data -from rubrix.client.sdk.text_classification.models import ( - CreationTextClassificationRecord, - TextClassificationBulkData, - TextClassificationRecord, -) - - -@pytest.fixture -def bulk_data(): - explanation = { - "text": [TokenAttributions(token="test", attributions={"test": 0.5})] - } - records = [ - ClientTextClassificationRecord( - text="test", - prediction=[("test", 0.5)], - prediction_agent="agent", - annotation="test1", - annotation_agent="agent", - multi_label=False, - explanation=explanation, - id=i, - metadata={"mymetadata": "str"}, - event_timestamp=datetime(2020, 1, 1), - status="Validated", - ) - for i in range(3) - ] - - return TextClassificationBulkData( - records=[CreationTextClassificationRecord.from_client(rec) for rec in records], - tags={"Mytag": "tag"}, - metadata={"MyMetadata": 5}, - ) - - -def test_bulk(sdk_client, mocked_client, bulk_data, monkeypatch): - monkeypatch.setattr(httpx, "post", mocked_client.post) - - dataset_name = "test_dataset" - mocked_client.delete(f"/api/datasets/{dataset_name}") - response = bulk(sdk_client, name=dataset_name, json_body=bulk_data) - - assert response.status_code == 200 - assert isinstance(response.parsed, BulkResponse) +from rubrix.client.sdk.text_classification.api import data +from rubrix.client.sdk.text_classification.models import TextClassificationRecord @pytest.mark.parametrize("limit,expected", [(None, 3), (2, 2)]) -def test_data(mocked_client, limit, expected, bulk_data, sdk_client, monkeypatch): +def test_data( + mocked_client, limit, expected, bulk_textclass_data, sdk_client, monkeypatch +): # TODO: Not sure how to test the streaming part of the response here monkeypatch.setattr(httpx, "stream", mocked_client.stream) @@ -79,7 +30,7 @@ def test_data(mocked_client, limit, expected, bulk_data, sdk_client, monkeypatch mocked_client.delete(f"/api/datasets/{dataset_name}") mocked_client.post( f"/api/datasets/{dataset_name}/TextClassification:bulk", - json=bulk_data.dict(by_alias=True), + json=bulk_textclass_data.dict(by_alias=True), ) response = data(sdk_client, name=dataset_name, limit=limit) diff --git a/tests/client/sdk/token_classification/test_api.py b/tests/client/sdk/token_classification/test_api.py index 5665fc8716..e88aa46acb 100644 --- a/tests/client/sdk/token_classification/test_api.py +++ b/tests/client/sdk/token_classification/test_api.py @@ -12,61 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from datetime import datetime
-
 import httpx
 import pytest
 
-from rubrix.client.models import (
-    TokenClassificationRecord as ClientTokenClassificationRecord,
-)
-from rubrix.client.sdk.commons.models import BulkResponse
-from rubrix.client.sdk.token_classification.api import bulk, data
-from rubrix.client.sdk.token_classification.models import (
-    CreationTokenClassificationRecord,
-    TokenClassificationBulkData,
-    TokenClassificationRecord,
-)
-
-
-@pytest.fixture
-def bulk_data():
-    records = [
-        ClientTokenClassificationRecord(
-            text="a raw text",
-            tokens=["a", "raw", "text"],
-            prediction=[("test", 2, 5, 0.9)],
-            prediction_agent="agent",
-            annotation=[("test", 2, 5)],
-            annotation_agent="agent",
-            id=i,
-            metadata={"mymetadata": "str"},
-            event_timestamp=datetime(2020, 1, 1),
-            status="Validated",
-        )
-        for i in range(3)
-    ]
-
-    return TokenClassificationBulkData(
-        records=[CreationTokenClassificationRecord.from_client(rec) for rec in records],
-        tags={"Mytag": "tag"},
-        metadata={"MyMetadata": 5},
-    )
-
-
-def test_bulk(sdk_client, mocked_client, bulk_data, monkeypatch):
-    monkeypatch.setattr(httpx, "post", mocked_client.post)
-
-    dataset_name = "test_dataset"
-    mocked_client.delete(f"/api/datasets/{dataset_name}")
-    response = bulk(sdk_client, name=dataset_name, json_body=bulk_data)
-
-    assert response.status_code == 200
-    assert isinstance(response.parsed, BulkResponse)
+from rubrix.client.sdk.token_classification.api import data
+from rubrix.client.sdk.token_classification.models import TokenClassificationRecord
 
 
 @pytest.mark.parametrize("limit,expected", [(None, 3), (2, 2)])
-def test_data(mocked_client, limit, expected, sdk_client, bulk_data, monkeypatch):
+def test_data(
+    mocked_client, limit, expected, sdk_client, bulk_tokenclass_data, monkeypatch
+):
     # TODO: Not sure how to test the streaming part of the response here
     monkeypatch.setattr(httpx, "stream", mocked_client.stream)
 
@@ -74,7 +30,7 @@ def test_data(mocked_client, limit, expected, sdk_client, bulk_data, monkeypatch
     mocked_client.delete(f"/api/datasets/{dataset_name}")
     mocked_client.post(
         f"/api/datasets/{dataset_name}/TokenClassification:bulk",
-        json=bulk_data.dict(by_alias=True),
+        json=bulk_tokenclass_data.dict(by_alias=True),
     )
 
     response = data(sdk_client, name=dataset_name, limit=limit)
diff --git a/tests/client/test_api.py b/tests/client/test_api.py
index cafb7b023a..25394b9b9c 100644
--- a/tests/client/test_api.py
+++ b/tests/client/test_api.py
@@ -307,7 +307,7 @@ def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_
 def test_passing_wrong_iterable_data(mocked_client):
     dataset_name = "test_log_single_records"
     mocked_client.delete(f"/api/datasets/{dataset_name}")
-    with pytest.raises(Exception, match="Unknown record type passed"):
+    with pytest.raises(Exception, match="Unknown record type"):
         api.log({"a": "010", "b": 100}, name=dataset_name)
 
 
diff --git a/tests/conftest.py b/tests/conftest.py
index e61a8a7f70..f9a0967724 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,15 +10,17 @@
 
 @pytest.fixture
 def mocked_client(monkeypatch):
-    client = SecuredClient(TestClient(app, raise_server_exceptions=False))
+    with TestClient(app, raise_server_exceptions=False) as _client:
+        client = SecuredClient(_client)
 
-    monkeypatch.setattr(httpx, "post", client.post)
-    monkeypatch.setattr(httpx, "get", client.get)
-    monkeypatch.setattr(httpx, "delete", client.delete)
-    monkeypatch.setattr(httpx, "put", client.put)
-    monkeypatch.setattr(httpx, "stream", client.stream)
+        monkeypatch.setattr(httpx, "post", client.post)
+        monkeypatch.setattr(httpx.AsyncClient, "post", client.post_async)
+        monkeypatch.setattr(httpx, "get", client.get)
+        monkeypatch.setattr(httpx, "delete", client.delete)
+        monkeypatch.setattr(httpx, "put", client.put)
+        monkeypatch.setattr(httpx, "stream", client.stream)
 
-    return client
+        yield client
 
 
 @pytest.fixture
diff --git a/tests/helpers.py b/tests/helpers.py
index 3e4b9c387d..0e7f53e0dc 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -42,6 +42,9 @@ def post(self, *args, **kwargs):
         headers = {**self._header, **request_headers}
         return self._client.post(*args, headers=headers, **kwargs)
 
+    async def post_async(self, *args, **kwargs):
+        return self.post(*args, **kwargs)
+
     def get(self, *args, **kwargs):
         request_headers = kwargs.pop("headers", {})
         headers = {**self._header, **request_headers}
diff --git a/tests/monitoring/helpers.py b/tests/monitoring/helpers.py
index 7601da18e2..b502af6e3c 100644
--- a/tests/monitoring/helpers.py
+++ b/tests/monitoring/helpers.py
@@ -1,5 +1,12 @@
+import rubrix
 from rubrix.monitoring.base import BaseMonitor
 
 
 def mock_monitor(monitor: BaseMonitor, monkeypatch):
-    monkeypatch.setattr(monitor, "log_async", monitor._log2rubrix)
+    def log(*args, **kwargs):
+        log_args = monitor._prepare_log_data(*args, **kwargs)
+        log_args.pop("verbose", None)
+        log_args.pop("background", None)
+        return rubrix.log(**log_args, background=False)
+
+    monkeypatch.setattr(monitor, "log_async", log)