From a0fdd8e4eb835d6c6b52f9c3a29d070f4ee77342 Mon Sep 17 00:00:00 2001 From: Francisco Aranda Date: Fri, 25 Mar 2022 13:13:55 +0100 Subject: [PATCH] feat(#1130): cleanup rb namespace by refactoring client API (#1160) --- .gitignore | 3 + pyproject.toml | 1 + src/rubrix/__init__.py | 301 +++------ src/rubrix/client/api.py | 4 +- src/rubrix/client/rubrix_client.py | 16 +- src/rubrix/metrics/commons.py | 8 +- .../metrics/text_classification/metrics.py | 8 +- .../metrics/token_classification/metrics.py | 37 +- src/rubrix/monitoring/asgi.py | 9 +- src/rubrix/server/app.py | 28 + .../security/auth_provider/local/settings.py | 3 +- src/rubrix/utils.py | 117 ++++ tests/client/sdk/users/test_api.py | 2 +- tests/client/test_api.py | 587 ++++++++++++++++++ tests/client/test_asgi.py | 5 +- tests/client/test_rubrix_client.py | 203 +++--- .../test_log_for_text_classification.py | 86 +-- tests/metrics/test_common_metrics.py | 25 +- tests/metrics/test_text_classification.py | 25 +- tests/server/test_app.py | 28 + tests/test_init.py | 295 +-------- tests/test_log.py | 333 ---------- tests/test_utils.py | 67 ++ 23 files changed, 1142 insertions(+), 1049 deletions(-) create mode 100644 src/rubrix/server/app.py create mode 100644 src/rubrix/utils.py create mode 100644 tests/client/test_api.py create mode 100644 tests/server/test_app.py delete mode 100644 tests/test_log.py create mode 100644 tests/test_utils.py diff --git a/.gitignore b/.gitignore index 77bc747bb1..86cd036736 100644 --- a/.gitignore +++ b/.gitignore @@ -135,3 +135,6 @@ package-lock.json # App generated files src/**/server/static/ + +# setuptools_scm generated file +src/rubrix/_version.py diff --git a/pyproject.toml b/pyproject.toml index 34c5e81892..a7d52caa8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ requires = ["setuptools", "wheel", "setuptools_scm[toml]"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] +write_to = "src/rubrix/_version.py" [tool.pytest.ini_options] log_format = "%(asctime)s %(name)s %(levelname)s %(message)s" diff --git a/src/rubrix/__init__.py b/src/rubrix/__init__.py index c967692403..82865de196 100644 --- a/src/rubrix/__init__.py +++ b/src/rubrix/__init__.py @@ -13,240 +13,89 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""This file reflects the user facing API. +If you want to add something here, remember to add it as normal import in the _TYPE_CHECKING section (for IDEs), +as well as in the `_import_structure` dictionary. """ -This module contains the interface to access Rubrix's REST API. -""" - -import os -import re -from logging import getLogger -from typing import Any, Dict, Iterable, List, Optional, Union - -import pandas -import pkg_resources - -from rubrix._constants import DEFAULT_API_KEY -from rubrix.client import RubrixClient -from rubrix.client.datasets import ( - Dataset, - DatasetForText2Text, - DatasetForTextClassification, - DatasetForTokenClassification, - read_datasets, - read_pandas, -) -from rubrix.client.models import ( - BulkResponse, - Record, - Text2TextRecord, - TextClassificationRecord, - TokenAttributions, - TokenClassificationRecord, -) -from rubrix.logging import configure_logging -from rubrix.monitoring.model_monitor import monitor - -try: - __version__ = pkg_resources.get_distribution(__name__).version -except pkg_resources.DistributionNotFound: - # package is not installed - pass - -try: - from rubrix.server.server import app -except ModuleNotFoundError as ex: - module_name = ex.name - - def fallback_app(*args, **kwargs): - raise RuntimeError( - "\n" - f"Cannot start rubrix server. Some dependencies was not found:[{module_name}].\n" - "Please, install missing modules or reinstall rubrix with server extra deps:\n" - "pip install rubrix[server]" - ) - - app = fallback_app - -configure_logging() - -_client: Optional[ - RubrixClient -] = None # Client will be stored here to pass it through functions - - -_LOGGER = getLogger(__name__) - - -def _client_instance() -> RubrixClient: - """Checks module instance client and init if not initialized.""" - - global _client - # Calling a by-default-init if it was not called before - if _client is None: - init() - return _client - - -def init( - api_url: Optional[str] = None, - api_key: Optional[str] = None, - workspace: Optional[str] = None, - timeout: int = 60, -) -> None: - """Init the python client. - - Passing an api_url disables environment variable reading, which will provide - default values. - - Args: - api_url: Address of the REST API. If `None` (default) and the env variable ``RUBRIX_API_URL`` is not set, - it will default to `http://localhost:6900`. - api_key: Authentification key for the REST API. If `None` (default) and the env variable ``RUBRIX_API_KEY`` - is not set, it will default to `rubrix.apikey`. - workspace: The workspace to which records will be logged/loaded. If `None` (default) and the - env variable ``RUBRIX_WORKSPACE`` is not set, it will default to the private user workspace. - timeout: Wait `timeout` seconds for the connection to timeout. Default: 60. - - Examples: - >>> import rubrix as rb - >>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y") - """ - - global _client - - final_api_url = api_url or os.getenv("RUBRIX_API_URL", "http://localhost:6900") - # Checking that the api_url does not end in '/' - final_api_url = re.sub(r"\/$", "", final_api_url) +import sys as _sys +from typing import TYPE_CHECKING as _TYPE_CHECKING - # If an api_url is passed, tokens obtained via environ vars are disabled - final_key = api_key or os.getenv("RUBRIX_API_KEY", DEFAULT_API_KEY) +from rubrix.logging import configure_logging as _configure_logging - workspace = workspace or os.getenv("RUBRIX_WORKSPACE") +from . import _version +from .utils import _LazyRubrixModule - _LOGGER.info(f"Rubrix has been initialized on {final_api_url}") +__version__ = _version.version - _client = RubrixClient( - api_url=final_api_url, - api_key=final_key, - workspace=workspace, - timeout=timeout, +if _TYPE_CHECKING: + from rubrix.client.api import ( + copy, + delete, + get_workspace, + init, + load, + log, + set_workspace, ) - - -def get_workspace() -> str: - """Returns the name of the active workspace for the current client session. - - Returns: - The name of the active workspace as a string. - """ - return _client_instance().active_workspace - - -def set_workspace(ws: str) -> None: - """Sets the active workspace for the current client session. - - Args: - ws: The new workspace - """ - _client_instance().set_workspace(ws) - - -def log( - records: Union[Record, Iterable[Record], Dataset], - name: str, - tags: Optional[Dict[str, str]] = None, - metadata: Optional[Dict[str, Any]] = None, - chunk_size: int = 500, - verbose: bool = True, -) -> BulkResponse: - """Log Records to Rubrix. - - Args: - records: The record or an iterable of records. - name: The dataset name. - tags: A dictionary of tags related to the dataset. - metadata: A dictionary of extra info for the dataset. - chunk_size: The chunk size for a data bulk. - verbose: If True, shows a progress bar and prints out a quick summary at the end. - - Returns: - Summary of the response from the REST API - - Examples: - >>> import rubrix as rb - >>> record = rb.TextClassificationRecord( - ... inputs={"text": "my first rubrix example"}, - ... prediction=[('spam', 0.8), ('ham', 0.2)] - ... ) - >>> response = rb.log(record, name="example-dataset") - """ - # noinspection PyTypeChecker,PydanticTypeChecker - return _client_instance().log( - records=records, - name=name, - tags=tags, - metadata=metadata, - chunk_size=chunk_size, - verbose=verbose, + from rubrix.client.datasets import ( + DatasetForText2Text, + DatasetForTextClassification, + DatasetForTokenClassification, + read_datasets, + read_pandas, ) - - -def copy(dataset: str, name_of_copy: str, workspace: str = None): - """Creates a copy of a dataset including its tags and metadata - - Args: - dataset: Name of the source dataset - name_of_copy: Name of the copied dataset - workspace: If provided, dataset will be copied to that workspace - - Examples: - >>> import rubrix as rb - >>> rb.copy("my_dataset", name_of_copy="new_dataset") - >>> dataframe = rb.load("new_dataset") - """ - _client_instance().copy( - source=dataset, target=name_of_copy, target_workspace=workspace - ) - - -def load( - name: str, - query: Optional[str] = None, - ids: Optional[List[Union[str, int]]] = None, - limit: Optional[int] = None, - as_pandas: bool = True, -) -> Union[pandas.DataFrame, Dataset]: - """Loads a dataset as a pandas DataFrame or a Dataset. - - Args: - name: The dataset name. - query: An ElasticSearch query with the - `query string syntax `_ - ids: If provided, load dataset records with given ids. - limit: The number of records to retrieve. - as_pandas: If True, return a pandas DataFrame. If False, return a Dataset. - - Returns: - The dataset as a pandas Dataframe or a Dataset. - - Examples: - >>> import rubrix as rb - >>> dataframe = rb.load(name="example-dataset") - """ - return _client_instance().load( - name=name, query=query, limit=limit, ids=ids, as_pandas=as_pandas + from rubrix.client.models import ( + Text2TextRecord, + TextClassificationRecord, + TokenAttributions, + TokenClassificationRecord, ) + from rubrix.monitoring.model_monitor import monitor + from rubrix.server.server import app +_import_structure = { + "client.api": [ + "copy", + "delete", + "get_workspace", + "init", + "load", + "log", + "set_workspace", + ], + "client.models": [ + "Text2TextRecord", + "TextClassificationRecord", + "TokenClassificationRecord", + "TokenAttributions", + ], + "client.datasets": [ + "DatasetForText2Text", + "DatasetForTextClassification", + "DatasetForTokenClassification", + "read_datasets", + "read_pandas", + ], + "monitoring.model_monitor": ["monitor"], + "server.app": ["app"], +} + +# can be removed in a future version +_deprecated_import_structure = { + "client.models": ["Record", "BulkResponse"], + "client.datasets": ["Dataset"], + "client.rubrix_client": ["RubrixClient"], + "_constants": ["DEFAULT_API_KEY"], +} + +_sys.modules[__name__] = _LazyRubrixModule( + __name__, + globals()["__file__"], + _import_structure, + deprecated_import_structure=_deprecated_import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, +) -def delete(name: str) -> None: - """Delete a dataset. - - Args: - name: The dataset name. - - Examples: - >>> import rubrix as rb - >>> rb.delete(name="example-dataset") - """ - _client_instance().delete(name=name) +_configure_logging() diff --git a/src/rubrix/client/api.py b/src/rubrix/client/api.py index 348f74caee..5e3dd85959 100644 --- a/src/rubrix/client/api.py +++ b/src/rubrix/client/api.py @@ -84,8 +84,8 @@ def __init__( ): """Init the Python client. - We will automatically init a default client for you when calling other client methods. - The arguments provided here will overwrite your corresponding environment variables. + Passing an api_url disables environment variable reading, which will provide + default values. Args: api_url: Address of the REST API. If `None` (default) and the env variable ``RUBRIX_API_URL`` is not set, diff --git a/src/rubrix/client/rubrix_client.py b/src/rubrix/client/rubrix_client.py index 5ec1e0d55d..ee4febeddf 100644 --- a/src/rubrix/client/rubrix_client.py +++ b/src/rubrix/client/rubrix_client.py @@ -12,11 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""The Rubrix client, used by the rubrix.__init__ module""" +""" +The Rubrix client, used by the rubrix.__init__ module. +DEPRECATED, CAN BE REMOVED IN A FUTURE VERSION. USE THE rubrix.client.api MODULE INSTEAD! +""" import logging import socket +import warnings from typing import Any, Dict, Iterable, List, Optional, Union import pandas @@ -80,7 +83,7 @@ class InputValueError(RubrixClientError): class RubrixClient: - """Class definition for Rubrix Client""" + """DEPRECATED. Class definition for Rubrix Client""" _LOGGER = logging.getLogger(__name__) _WARNED_ABOUT_AS_PANDAS = False @@ -97,7 +100,7 @@ def __init__( workspace: Optional[str] = None, timeout: int = 60, ): - """Client setup function. + """DEPRECATED. Client setup function. Args: api_url: Address from which the API is serving. @@ -105,6 +108,11 @@ def __init__( workspace: Active workspace for this client session. timeout: Seconds to wait before raising a connection timeout. """ + warnings.warn( + f"The 'RubrixClient' class is deprecated and will be removed in a future version! " + f"Use the `rubrix.client.api` module instead. Make sure to adapt your code.", + category=FutureWarning, + ) self._client = AuthenticatedClient( base_url=api_url, token=api_key, timeout=timeout diff --git a/src/rubrix/metrics/commons.py b/src/rubrix/metrics/commons.py index 7c148cb548..12763c6fe6 100644 --- a/src/rubrix/metrics/commons.py +++ b/src/rubrix/metrics/commons.py @@ -1,6 +1,6 @@ from typing import Optional -from rubrix import _client_instance as client +from rubrix.client import api from rubrix.metrics import helpers from rubrix.metrics.models import MetricSummary @@ -23,8 +23,7 @@ def text_length(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot an histogram with results >>> summary.data # returns the raw result data """ - current_client = client() - metric = current_client.compute_metric(name, metric="text_length", query=query) + metric = api.ACTIVE_API.compute_metric(name, metric="text_length", query=query) return MetricSummary.new_summary( data=metric.results, @@ -52,8 +51,7 @@ def records_status(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot an histogram with results >>> summary.data # returns the raw result data """ - current_client = client() - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric="status_distribution", query=query ) diff --git a/src/rubrix/metrics/text_classification/metrics.py b/src/rubrix/metrics/text_classification/metrics.py index 88dc8594a1..c126ab33a5 100644 --- a/src/rubrix/metrics/text_classification/metrics.py +++ b/src/rubrix/metrics/text_classification/metrics.py @@ -1,6 +1,6 @@ from typing import Optional -from rubrix import _client_instance as client +from rubrix.client import api from rubrix.metrics import helpers from rubrix.metrics.models import MetricSummary @@ -23,8 +23,7 @@ def f1(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot a bar chart with results >>> summary.data # returns the raw result data """ - current_client = client() - metric = current_client.compute_metric(name, metric="F1", query=query) + metric = api.ACTIVE_API.compute_metric(name, metric="F1", query=query) return MetricSummary.new_summary( data=metric.results, @@ -50,8 +49,7 @@ def f1_multilabel(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot a bar chart with results >>> summary.data # returns the raw result data """ - current_client = client() - metric = current_client.compute_metric(name, metric="MultiLabelF1", query=query) + metric = api.ACTIVE_API.compute_metric(name, metric="MultiLabelF1", query=query) return MetricSummary.new_summary( data=metric.results, diff --git a/src/rubrix/metrics/token_classification/metrics.py b/src/rubrix/metrics/token_classification/metrics.py index 321dccff00..7d64fc6855 100644 --- a/src/rubrix/metrics/token_classification/metrics.py +++ b/src/rubrix/metrics/token_classification/metrics.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Optional, Union -from rubrix import _client_instance as client +from rubrix.client import api from rubrix.metrics import helpers from rubrix.metrics.models import MetricSummary @@ -26,9 +26,7 @@ def tokens_length( >>> summary.visualize() # will plot a histogram with results >>> summary.data # the raw histogram data with bins of size 5 """ - current_client = client() - - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric="tokens_length", query=query, interval=interval ) @@ -62,9 +60,7 @@ def token_frequency( >>> summary.visualize() # will plot a histogram with results >>> summary.data # the top-50 tokens frequency """ - current_client = client() - - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric="token_frequency", query=query, size=tokens ) @@ -94,9 +90,7 @@ def token_length(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot a histogram with results >>> summary.data # The token length distribution """ - current_client = client() - - metric = current_client.compute_metric(name, metric="token_length", query=query) + metric = api.ACTIVE_API.compute_metric(name, metric="token_length", query=query) return MetricSummary.new_summary( data=metric.results, @@ -125,9 +119,7 @@ def token_capitalness(name: str, query: Optional[str] = None) -> MetricSummary: >>> summary.visualize() # will plot a histogram with results >>> summary.data # The token capitalness distribution """ - current_client = client() - - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric="token_capitalness", query=query ) @@ -196,14 +188,13 @@ def mention_length( >>> summary.visualize() # will plot a histogram chart with results >>> summary.data # the raw histogram data with bins of size 2 """ - current_client = client() level = (level or "token").lower().strip() accepted_levels = ["token", "char"] assert ( level in accepted_levels ), f"Unexpected value for level. Accepted values are {accepted_levels}" - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric=f"{_check_compute_for(compute_for)}_mention_{level}_length", query=query, @@ -245,9 +236,7 @@ def entity_labels( >>> summary.visualize() # will plot a bar chart with results >>> summary.data # The top-20 entity tags """ - current_client = client() - - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric=f"{_check_compute_for(compute_for)}_entity_labels", query=query, @@ -288,8 +277,7 @@ def entity_density( >>> summary = entity_density(name="example-dataset") >>> summary.visualize() """ - current_client = client() - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric=f"{_check_compute_for(compute_for)}_entity_density", query=query, @@ -334,8 +322,7 @@ def entity_capitalness( >>> summary = entity_capitalness(name="example-dataset") >>> summary.visualize() """ - current_client = client() - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric=f"{_check_compute_for(compute_for)}_entity_capitalness", query=query, @@ -384,8 +371,7 @@ def entity_consistency( # TODO: Warning??? threshold = 2 - current_client = client() - metric = current_client.compute_metric( + metric = api.ACTIVE_API.compute_metric( name, metric=f"{_check_compute_for(compute_for)}_entity_consistency", query=query, @@ -431,8 +417,7 @@ def f1(name: str, query: Optional[str] = None) -> MetricSummary: >>> import pandas as pd >>> pd.DataFrame(summary.data.values(), index=summary.data.keys()) """ - current_client = client() - metric = current_client.compute_metric(name, metric="F1", query=query) + metric = api.ACTIVE_API.compute_metric(name, metric="F1", query=query) return MetricSummary.new_summary( data=metric.results, diff --git a/src/rubrix/monitoring/asgi.py b/src/rubrix/monitoring/asgi.py index 43a8ddf0e0..e21a1c9740 100644 --- a/src/rubrix/monitoring/asgi.py +++ b/src/rubrix/monitoring/asgi.py @@ -21,13 +21,18 @@ from queue import Queue from typing import Any, Callable, Dict, List, Optional, Tuple -import rubrix -from rubrix import Record, TextClassificationRecord, TokenClassificationRecord from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.types import Message, Receive +import rubrix +from rubrix.client.models import ( + Record, + TextClassificationRecord, + TokenClassificationRecord, +) + _logger = logging.getLogger(__name__) _spaces_regex = re.compile(r"\s+") diff --git a/src/rubrix/server/app.py b/src/rubrix/server/app.py new file mode 100644 index 0000000000..7fa456f677 --- /dev/null +++ b/src/rubrix/server/app.py @@ -0,0 +1,28 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from rubrix.server.server import app +except ModuleNotFoundError as ex: + _module_name = ex.name + + def fallback_app(*args, **kwargs): + raise RuntimeError( + "\n" + f"Cannot start rubrix server. Some dependencies were not found:[{_module_name}].\n" + "Please, install missing modules or reinstall rubrix with server extra deps:\n" + "pip install rubrix[server]" + ) + + app = fallback_app diff --git a/src/rubrix/server/security/auth_provider/local/settings.py b/src/rubrix/server/security/auth_provider/local/settings.py index 5a5c663b51..d3eed93483 100644 --- a/src/rubrix/server/security/auth_provider/local/settings.py +++ b/src/rubrix/server/security/auth_provider/local/settings.py @@ -14,7 +14,8 @@ # limitations under the License. from pydantic import BaseSettings -from rubrix import DEFAULT_API_KEY + +from rubrix._constants import DEFAULT_API_KEY class Settings(BaseSettings): diff --git a/src/rubrix/utils.py b/src/rubrix/utils.py new file mode 100644 index 0000000000..61d5d3945f --- /dev/null +++ b/src/rubrix/utils.py @@ -0,0 +1,117 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +import warnings +from itertools import chain +from types import ModuleType +from typing import Any, Optional + + +class _LazyRubrixModule(ModuleType): + """Module class that surfaces all objects but only performs associated imports when the objects are requested. + + Shamelessly copied and adapted from the Hugging Face transformers implementation. + """ + + def __init__( + self, + name, + module_file, + import_structure, + deprecated_import_structure=None, + module_spec=None, + extra_objects=None, + ): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list( + chain(*import_structure.values()) + ) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # deprecated stuff + deprecated_import_structure = deprecated_import_structure or {} + self._deprecated_modules = set(deprecated_import_structure.keys()) + self._deprecated_class_to_module = {} + for key, values in deprecated_import_structure.items(): + for value in values: + self._deprecated_class_to_module[value] = key + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + elif name in self._deprecated_modules: + value = self._get_module(name, deprecated=True) + elif name in self._deprecated_class_to_module.keys(): + module = self._get_module( + self._deprecated_class_to_module[name], deprecated=True, class_name=name + ) + value = getattr(module, name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module( + self, + module_name: str, + deprecated: bool = False, + class_name: Optional[str] = None, + ): + if deprecated: + warnings.warn( + f"Importing '{class_name or module_name}' from the rubrix namespace (that is " + f"`rubrix.{class_name or module_name}`) is deprecated and will not work in a future version. " + f"Make sure you update your code accordingly.", + category=FutureWarning, + ) + + try: + return importlib.import_module("." + module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f"Failed to import {self.__name__}.{module_name} because of the following error " + f"(look up to see its traceback):\n{e}" + ) from e + + def __reduce__(self): + return self.__class__, (self._name, self.__file__, self._import_structure) diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py index d8bb7fa4d5..eabc332dfc 100644 --- a/tests/client/sdk/users/test_api.py +++ b/tests/client/sdk/users/test_api.py @@ -1,7 +1,7 @@ import httpx import pytest -from rubrix import DEFAULT_API_KEY +from rubrix._constants import DEFAULT_API_KEY from rubrix.client.sdk.client import AuthenticatedClient from rubrix.client.sdk.commons.errors import UnauthorizedApiError from rubrix.client.sdk.users.api import whoami diff --git a/tests/client/test_api.py b/tests/client/test_api.py new file mode 100644 index 0000000000..661edd3e32 --- /dev/null +++ b/tests/client/test_api.py @@ -0,0 +1,587 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +from time import sleep +from typing import Iterable + +import datasets +import httpx +import pandas +import pandas as pd +import pytest + +import rubrix as rb +from rubrix.client import api +from rubrix.client.sdk.client import AuthenticatedClient +from rubrix.client.sdk.commons.errors import ( + AlreadyExistsApiError, + ForbiddenApiError, + GenericApiError, + NotFoundApiError, + UnauthorizedApiError, + ValidationApiError, +) +from rubrix.server.tasks.text_classification import TextClassificationSearchResults +from tests.server.test_api import create_some_data_for_text_classification + + +@pytest.fixture +def mock_response_200(monkeypatch): + """Creating of mock_get method from the class, and monkeypatch application. + + It will return a 200 status code, emulating the correct login. + """ + + def mock_get(url, *args, **kwargs): + if "/api/me" in url: + return httpx.Response(status_code=200, json={"username": "booohh"}) + return httpx.Response(status_code=200) + + monkeypatch.setattr(httpx, "get", mock_get) + + +@pytest.fixture +def mock_response_500(monkeypatch): + """Creating of mock_get method from the class, and monkeypatch application. + + It will return a 500 status code, emulating an invalid state of the API error. + """ + + def mock_get(*args, **kwargs): + return httpx.Response(status_code=500) + + monkeypatch.setattr(httpx, "get", mock_get) + + +@pytest.fixture +def mock_response_token_401(monkeypatch): + """Creating of mock_get method from the class, and monkeypatch application. + + It will return a 401 status code, emulating an invalid credentials error when using tokens to log in. + Iterable structure to be able to pass the first 200 status code check + """ + response_200 = httpx.Response(status_code=200) + response_401 = httpx.Response(status_code=401) + + def mock_get(*args, **kwargs): + if kwargs["url"] == "fake_url/api/me": + return response_401 + elif kwargs["url"] == "fake_url/api/docs/spec.json": + return response_200 + + monkeypatch.setattr(httpx, "get", mock_get) + + +def test_init_correct(mock_response_200): + """Testing correct default initalization + + It checks if the _client created is a RubrixClient object. + """ + + api.init() + assert api.ACTIVE_API._client == AuthenticatedClient( + base_url="http://localhost:6900", token="rubrix.apikey", timeout=60.0 + ) + assert api.ACTIVE_API._user == api.User(username="booohh") + + api.init(api_url="mock_url", api_key="mock_key", workspace="mock_ws", timeout=42) + assert api.ACTIVE_API._client == AuthenticatedClient( + base_url="mock_url", + token="mock_key", + timeout=42, + headers={"X-Rubrix-Workspace": "mock_ws"}, + ) + + +def test_init_incorrect(mock_response_500): + """Testing incorrect default initalization + + It checks an Exception is raised with the correct message. + """ + + with pytest.raises( + Exception, + match="Rubrix server returned an error with http status: 500\nError details: \[\{'response': None\}\]", + ): + api.init() + + +def test_init_token_auth_fail(mock_response_token_401): + """Testing initalization with failed authentication + + It checks an Exception is raised with the correct message. + """ + with pytest.raises(UnauthorizedApiError): + api.init(api_url="fake_url", api_key="422") + + +def test_init_evironment_url(mock_response_200, monkeypatch): + """Testing initalization with api_url provided via environment variable + + It checks the url in the environment variable gets passed to client. + """ + monkeypatch.setenv("RUBRIX_API_URL", "mock_url") + monkeypatch.setenv("RUBRIX_API_KEY", "mock_key") + monkeypatch.setenv("RUBRIX_WORKSPACE", "mock_workspace") + api.init() + + assert api.ACTIVE_API._client == AuthenticatedClient( + base_url="mock_url", + token="mock_key", + timeout=60, + headers={"X-Rubrix-Workspace": "mock_workspace"}, + ) + + +def test_trailing_slash(mock_response_200): + """Testing initalization with provided api_url via environment variable and argument + + It checks the trailing slash is removed in all cases + """ + api.init(api_url="http://mock.com/") + assert api.ACTIVE_API._client.base_url == "http://mock.com" + + +def test_log_something(monkeypatch, mocked_client): + dataset_name = "test-dataset" + mocked_client.delete(f"/api/datasets/{dataset_name}") + + response = api.log( + name=dataset_name, + records=rb.TextClassificationRecord(inputs={"text": "This is a test"}), + ) + + assert response.processed == 1 + assert response.failed == 0 + + response = mocked_client.post( + f"/api/datasets/{dataset_name}/TextClassification:search" + ) + assert response.status_code == 200, response.json() + + results = TextClassificationSearchResults.parse_obj(response.json()) + assert results.total == 1 + assert len(results.records) == 1 + assert results.records[0].inputs["text"] == "This is a test" + + +def test_load_limits(mocked_client): + dataset = "test_load_limits" + api_ds_prefix = f"/api/datasets/{dataset}" + mocked_client.delete(api_ds_prefix) + + create_some_data_for_text_classification(mocked_client, dataset, 50) + + limit_data_to = 10 + ds = api.load(name=dataset, limit=limit_data_to) + assert isinstance(ds, pandas.DataFrame) + assert len(ds) == limit_data_to + + ds = api.load(name=dataset, limit=limit_data_to) + assert isinstance(ds, pandas.DataFrame) + assert len(ds) == limit_data_to + + +def test_log_records_with_too_long_text(mocked_client): + dataset_name = "test_log_records_with_too_long_text" + mocked_client.delete(f"/api/datasets/{dataset_name}") + item = rb.TextClassificationRecord( + inputs={"text": "This is a toooooo long text\n" * 10000} + ) + + api.log([item], name=dataset_name) + + +def test_not_found_response(mocked_client): + + with pytest.raises(NotFoundApiError): + api.load(name="not-found") + + +def test_log_without_name(mocked_client): + with pytest.raises( + api.InputValueError, match="Empty project name has been passed as argument." + ): + api.log( + rb.TextClassificationRecord( + inputs={"text": "This is a single record. Only this. No more."} + ), + name=None, + ) + + +def test_log_passing_empty_records_list(mocked_client): + + with pytest.raises( + api.InputValueError, match="Empty record list has been passed as argument." + ): + api.log(records=[], name="ds") + + +@pytest.mark.parametrize( + "status,error_type", + [ + (401, UnauthorizedApiError), + (403, ForbiddenApiError), + (404, NotFoundApiError), + (422, ValidationApiError), + (500, GenericApiError), + ], +) +def test_delete_with_errors(mocked_client, monkeypatch, status, error_type): + def send_mock_response_with_http_status(status: int): + def inner(*args, **kwargs): + return httpx.Response( + status_code=status, + json={"detail": {"code": "error:code", "params": {"message": "Mock"}}}, + ) + + return inner + + with pytest.raises(error_type): + monkeypatch.setattr( + httpx, "delete", send_mock_response_with_http_status(status) + ) + api.delete("dataset") + + +@pytest.mark.parametrize( + "records, dataset_class", + [ + ("singlelabel_textclassification_records", rb.DatasetForTextClassification), + ("multilabel_textclassification_records", rb.DatasetForTextClassification), + ("tokenclassification_records", rb.DatasetForTokenClassification), + ("text2text_records", rb.DatasetForText2Text), + ], +) +def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_class): + dataset_names = [ + f"test_general_log_load_{dataset_class.__name__.lower()}_" + input_type + for input_type in ["single", "list", "dataset"] + ] + for name in dataset_names: + mocked_client.delete(f"/api/datasets/{name}") + + records = request.getfixturevalue(records) + + # log single records + api.log(records[0], name=dataset_names[0]) + dataset = api.load(dataset_names[0], as_pandas=False) + records[0].metrics = dataset[0].metrics + assert dataset[0] == records[0] + + # log list of records + api.log(records, name=dataset_names[1]) + dataset = api.load(dataset_names[1], as_pandas=False) + # check if returned records can be converted to other formats + assert isinstance(dataset.to_datasets(), datasets.Dataset) + assert isinstance(dataset.to_pandas(), pd.DataFrame) + assert len(dataset) == len(records) + for record, expected in zip(dataset, records): + expected.metrics = record.metrics + assert record == expected + + # log dataset + api.log(dataset_class(records), name=dataset_names[2]) + dataset = api.load(dataset_names[2], as_pandas=False) + assert len(dataset) == len(records) + for record, expected in zip(dataset, records): + record.metrics = expected.metrics + assert record == expected + + +def test_passing_wrong_iterable_data(mocked_client): + dataset_name = "test_log_single_records" + mocked_client.delete(f"/api/datasets/{dataset_name}") + with pytest.raises(Exception, match="Unknown record type passed"): + api.log({"a": "010", "b": 100}, name=dataset_name) + + +def test_log_with_generator(mocked_client, monkeypatch): + dataset_name = "test_log_with_generator" + mocked_client.delete(f"/api/datasets/{dataset_name}") + + def generator(items: int = 10) -> Iterable[rb.TextClassificationRecord]: + for i in range(0, items): + yield rb.TextClassificationRecord(id=i, inputs={"text": "The text data"}) + + api.log(generator(), name=dataset_name) + + +def test_create_ds_with_wrong_name(mocked_client): + dataset_name = "Test Create_ds_with_wrong_name" + + with pytest.raises(ValidationApiError): + api.log( + rb.TextClassificationRecord( + inputs={"text": "The text data"}, + ), + name=dataset_name, + ) + + +def test_delete_dataset(mocked_client): + dataset_name = "test_delete_dataset" + mocked_client.delete(f"/api/datasets/{dataset_name}") + + api.log( + rb.TextClassificationRecord( + id=0, + inputs={"text": "The text data"}, + annotation_agent="test", + annotation=["T"], + ), + name=dataset_name, + ) + api.load(name=dataset_name) + api.delete(name=dataset_name) + sleep(1) + with pytest.raises(NotFoundApiError): + api.load(name=dataset_name) + + +def test_dataset_copy(mocked_client): + dataset = "test_dataset_copy" + dataset_copy = "new_dataset" + new_workspace = "new-workspace" + + mocked_client.delete(f"/api/datasets/{dataset}") + mocked_client.delete(f"/api/datasets/{dataset_copy}") + mocked_client.delete(f"/api/datasets/{dataset_copy}?workspace={new_workspace}") + + api.log( + rb.TextClassificationRecord( + id=0, + inputs="This is the record input", + annotation_agent="test", + annotation=["T"], + ), + name=dataset, + ) + api.copy(dataset, name_of_copy=dataset_copy) + df = api.load(name=dataset) + df_copy = api.load(name=dataset_copy) + + assert df.equals(df_copy) + + with pytest.raises(AlreadyExistsApiError): + api.copy(dataset, name_of_copy=dataset_copy) + + api.copy(dataset, name_of_copy=dataset_copy, workspace=new_workspace) + + try: + api.set_workspace(new_workspace) + df_copy = api.load(dataset_copy) + assert df.equals(df_copy) + + with pytest.raises(AlreadyExistsApiError): + api.copy(dataset_copy, name_of_copy=dataset_copy, workspace=new_workspace) + finally: + api.init() # reset workspace + + +def test_update_record(mocked_client): + dataset = "test_update_record" + mocked_client.delete(f"/api/datasets/{dataset}") + + expected_inputs = ["This is a text"] + record = rb.TextClassificationRecord( + id=0, + inputs=expected_inputs, + annotation_agent="test", + annotation=["T"], + ) + api.log( + record, + name=dataset, + ) + + df = api.load(name=dataset) + records = df.to_dict(orient="records") + assert len(records) == 1 + assert records[0]["annotation"] == "T" + # This record will replace the old one + record = rb.TextClassificationRecord( + id=0, + inputs=expected_inputs, + ) + + api.log( + record, + name=dataset, + ) + + df = api.load(name=dataset) + records = df.to_dict(orient="records") + assert len(records) == 1 + assert records[0]["annotation"] is None + assert records[0]["annotation_agent"] is None + + +def test_text_classifier_with_inputs_list(mocked_client): + dataset = "test_text_classifier_with_inputs_list" + mocked_client.delete(f"/api/datasets/{dataset}") + + expected_inputs = ["A", "List", "of", "values"] + api.log( + rb.TextClassificationRecord( + id=0, + inputs=expected_inputs, + annotation_agent="test", + annotation=["T"], + ), + name=dataset, + ) + + df = api.load(name=dataset) + records = df.to_dict(orient="records") + assert len(records) == 1 + assert records[0]["inputs"]["text"] == expected_inputs + + +def test_load_with_ids_list(mocked_client): + dataset = "test_load_with_ids_list" + mocked_client.delete(f"/api/datasets/{dataset}") + + expected_data = 100 + create_some_data_for_text_classification(mocked_client, dataset, n=expected_data) + ds = api.load(name=dataset, ids=[3, 5]) + assert len(ds) == 2 + + +def test_load_with_query(mocked_client): + dataset = "test_load_with_query" + mocked_client.delete(f"/api/datasets/{dataset}") + sleep(1) + + expected_data = 4 + create_some_data_for_text_classification(mocked_client, dataset, n=expected_data) + ds = api.load(name=dataset, query="id:1") + assert len(ds) == 1 + assert ds.id.iloc[0] == 1 + + +@pytest.mark.parametrize("as_pandas", [True, False]) +def test_load_as_pandas(mocked_client, as_pandas): + dataset = "test_sorted_load" + mocked_client.delete(f"/api/datasets/{dataset}") + sleep(1) + + expected_data = 3 + create_some_data_for_text_classification(mocked_client, dataset, n=expected_data) + + # Check that the default value is True + if as_pandas: + records = api.load(name=dataset) + assert isinstance(records, pandas.DataFrame) + assert list(records.id) == [0, 1, 2, 3] + else: + records = api.load(name=dataset, as_pandas=False) + assert isinstance(records, rb.DatasetForTextClassification) + assert isinstance(records[0], rb.TextClassificationRecord) + assert [record.id for record in records] == [0, 1, 2, 3] + + +def test_token_classification_spans(mocked_client): + dataset = "test_token_classification_with_consecutive_spans" + texto = "Esto es una prueba" + item = api.TokenClassificationRecord( + text=texto, + tokens=texto.split(), + prediction=[("test", 1, 2)], # Inicio y fin son consecutivos + prediction_agent="test", + ) + with pytest.raises( + Exception, match=r"Defined offset \[s\] is a misaligned entity mention" + ): + api.log(item, name=dataset) + + item.prediction = [("test", 0, 6)] + with pytest.raises( + Exception, match=r"Defined offset \[Esto e\] is a misaligned entity mention" + ): + api.log(item, name=dataset) + + item.prediction = [("test", 0, 4)] + api.log(item, name=dataset) + + +def test_load_text2text(mocked_client): + records = [ + rb.Text2TextRecord( + text="test text", + prediction=["test prediction"], + annotation="test annotation", + prediction_agent="test_model", + annotation_agent="test_annotator", + id=i, + metadata={"metadata": "test"}, + status="Default", + event_timestamp=datetime.datetime(2000, 1, 1), + ) + for i in range(0, 2) + ] + + dataset = "test_load_text2text" + api.delete(dataset) + api.log(records, name=dataset) + + df = api.load(name=dataset) + assert len(df) == 2 + + +def test_client_workspace(mocked_client): + try: + ws = api.get_workspace() + assert ws == "rubrix" + + api.set_workspace("other-workspace") + assert api.get_workspace() == "other-workspace" + + with pytest.raises(Exception, match="Must provide a workspace"): + api.set_workspace(None) + + # Mocking user + api.ACTIVE_API._user.workspaces = ["a", "b"] + + with pytest.raises(Exception, match="Wrong provided workspace c"): + api.set_workspace("c") + + api.set_workspace("rubrix") + assert api.get_workspace() == "rubrix" + finally: + api.init() # reset workspace + + +def test_load_sort(mocked_client): + records = [ + rb.TextClassificationRecord( + inputs="test text", + id=i, + ) + for i in ["1str", 1, 2, 11, "2str", "11str"] + ] + + dataset = "test_load_sort" + api.delete(dataset) + api.log(records, name=dataset) + + # check sorting policies + df = api.load(name=dataset) + assert list(df.id) == [1, 11, "11str", "1str", 2, "2str"] + df = api.load(name=dataset, ids=[1, 2, 11]) + assert list(df.id) == [1, 2, 11] + df = api.load(name=dataset, ids=["1str", "2str", "11str"]) + assert list(df.id) == ["11str", "1str", "2str"] diff --git a/tests/client/test_asgi.py b/tests/client/test_asgi.py index 40c3ae09d5..e88ae90169 100644 --- a/tests/client/test_asgi.py +++ b/tests/client/test_asgi.py @@ -22,7 +22,6 @@ from starlette.testclient import TestClient import rubrix -from rubrix import TextClassificationRecord, TokenClassificationRecord from rubrix.monitoring.asgi import RubrixLogHTTPMiddleware, token_classification_mapper @@ -59,7 +58,7 @@ def __call__(self, records, name: str, **kwargs): self.was_called = True assert name == expected_dataset_name assert len(records) == 2 - assert isinstance(records[0], TextClassificationRecord) + assert isinstance(records[0], rubrix.TextClassificationRecord) mock_log = MockLog() monkeypatch.setattr(rubrix, "log", mock_log) @@ -113,7 +112,7 @@ def __call__(self, records, name: str, **kwargs): self.was_called = True assert name == expected_dataset_name assert len(records) == 2 - assert isinstance(records[0], TokenClassificationRecord) + assert isinstance(records[0], rubrix.TokenClassificationRecord) mock_log = MockLog() monkeypatch.setattr(rubrix, "log", mock_log) diff --git a/tests/client/test_rubrix_client.py b/tests/client/test_rubrix_client.py index 872ab7de91..f9b2236700 100644 --- a/tests/client/test_rubrix_client.py +++ b/tests/client/test_rubrix_client.py @@ -31,7 +31,9 @@ Text2TextRecord, TextClassificationRecord, ) -from rubrix.client.rubrix_client import InputValueError +from rubrix._constants import DEFAULT_API_KEY +from rubrix.client import rubrix_client +from rubrix.client.api import InputValueError from rubrix.client.sdk.commons.errors import ( AlreadyExistsApiError, ForbiddenApiError, @@ -44,11 +46,18 @@ from tests.server.test_api import create_some_data_for_text_classification -def test_log_something(monkeypatch, mocked_client): +@pytest.fixture +def rb_client(mocked_client): + return rubrix_client.RubrixClient( + api_url="http://localhost:6900", api_key=DEFAULT_API_KEY + ) + + +def test_log_something(monkeypatch, mocked_client, rb_client): dataset_name = "test-dataset" mocked_client.delete(f"/api/datasets/{dataset_name}") - response = rubrix.log( + response = rb_client.log( name=dataset_name, records=rubrix.TextClassificationRecord(inputs={"text": "This is a test"}), ) @@ -67,7 +76,7 @@ def test_log_something(monkeypatch, mocked_client): assert results.records[0].inputs["text"] == "This is a test" -def test_load_limits(mocked_client): +def test_load_limits(mocked_client, rb_client): dataset = "test_load_limits" api_ds_prefix = f"/api/datasets/{dataset}" mocked_client.delete(api_ds_prefix) @@ -75,36 +84,37 @@ def test_load_limits(mocked_client): create_some_data_for_text_classification(mocked_client, dataset, 50) limit_data_to = 10 - ds = rubrix.load(name=dataset, limit=limit_data_to) + ds = rb_client.load(name=dataset, limit=limit_data_to) assert isinstance(ds, pandas.DataFrame) assert len(ds) == limit_data_to - ds = rubrix.load(name=dataset, limit=limit_data_to) + ds = rb_client.load(name=dataset, limit=limit_data_to) assert isinstance(ds, pandas.DataFrame) assert len(ds) == limit_data_to -def test_log_records_with_too_long_text(mocked_client): +def test_log_records_with_too_long_text(mocked_client, rb_client): dataset_name = "test_log_records_with_too_long_text" mocked_client.delete(f"/api/datasets/{dataset_name}") item = TextClassificationRecord( inputs={"text": "This is a toooooo long text\n" * 10000} ) - rubrix.log([item], name=dataset_name) + rb_client.log([item], name=dataset_name) -def test_not_found_response(mocked_client): +def test_not_found_response(rb_client): with pytest.raises(NotFoundApiError): - rubrix.load(name="not-found") + rb_client.load(name="not-found") -def test_log_without_name(mocked_client): +def test_log_without_name(rb_client): with pytest.raises( - InputValueError, match="Empty project name has been passed as argument." + rubrix_client.InputValueError, + match="Empty project name has been passed as argument.", ): - rubrix.log( + rb_client.log( TextClassificationRecord( inputs={"text": "This is a single record. Only this. No more."} ), @@ -112,12 +122,13 @@ def test_log_without_name(mocked_client): ) -def test_log_passing_empty_records_list(mocked_client): +def test_log_passing_empty_records_list(rb_client): with pytest.raises( - InputValueError, match="Empty record list has been passed as argument." + rubrix_client.InputValueError, + match="Empty record list has been passed as argument.", ): - rubrix.log(records=[], name="ds") + rb_client.log(records=[], name="ds") @pytest.mark.parametrize( @@ -130,7 +141,7 @@ def test_log_passing_empty_records_list(mocked_client): (500, GenericApiError), ], ) -def test_delete_with_errors(mocked_client, monkeypatch, status, error_type): +def test_delete_with_errors(rb_client, monkeypatch, status, error_type): def send_mock_response_with_http_status(status: int): def inner(*args, **kwargs): return httpx.Response( @@ -144,7 +155,7 @@ def inner(*args, **kwargs): monkeypatch.setattr( httpx, "delete", send_mock_response_with_http_status(status) ) - rubrix.delete("dataset") + rb_client.delete("dataset") @pytest.mark.parametrize( @@ -156,7 +167,9 @@ def inner(*args, **kwargs): ("text2text_records", DatasetForText2Text), ], ) -def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_class): +def test_general_log_load( + mocked_client, monkeypatch, request, records, dataset_class, rb_client +): dataset_names = [ f"test_general_log_load_{dataset_class.__name__.lower()}_" + input_type for input_type in ["single", "list", "dataset"] @@ -167,14 +180,14 @@ def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_ records = request.getfixturevalue(records) # log single records - rubrix.log(records[0], name=dataset_names[0]) - dataset = rubrix.load(dataset_names[0], as_pandas=False) + rb_client.log(records[0], name=dataset_names[0]) + dataset = rb_client.load(dataset_names[0], as_pandas=False) records[0].metrics = dataset[0].metrics assert dataset[0] == records[0] # log list of records - rubrix.log(records, name=dataset_names[1]) - dataset = rubrix.load(dataset_names[1], as_pandas=False) + rb_client.log(records, name=dataset_names[1]) + dataset = rb_client.load(dataset_names[1], as_pandas=False) # check if returned records can be converted to other formats assert isinstance(dataset.to_datasets(), datasets.Dataset) assert isinstance(dataset.to_pandas(), pd.DataFrame) @@ -184,22 +197,22 @@ def test_general_log_load(mocked_client, monkeypatch, request, records, dataset_ assert record == expected # log dataset - rubrix.log(dataset_class(records), name=dataset_names[2]) - dataset = rubrix.load(dataset_names[2], as_pandas=False) + rb_client.log(dataset_class(records), name=dataset_names[2]) + dataset = rb_client.load(dataset_names[2], as_pandas=False) assert len(dataset) == len(records) for record, expected in zip(dataset, records): record.metrics = expected.metrics assert record == expected -def test_passing_wrong_iterable_data(mocked_client): +def test_passing_wrong_iterable_data(mocked_client, rb_client): dataset_name = "test_log_single_records" mocked_client.delete(f"/api/datasets/{dataset_name}") with pytest.raises(Exception, match="Unknown record type passed"): - rubrix.log({"a": "010", "b": 100}, name=dataset_name) + rb_client.log({"a": "010", "b": 100}, name=dataset_name) -def test_log_with_generator(mocked_client, monkeypatch): +def test_log_with_generator(mocked_client, monkeypatch, rb_client): dataset_name = "test_log_with_generator" mocked_client.delete(f"/api/datasets/{dataset_name}") @@ -207,14 +220,14 @@ def generator(items: int = 10) -> Iterable[TextClassificationRecord]: for i in range(0, items): yield TextClassificationRecord(id=i, inputs={"text": "The text data"}) - rubrix.log(generator(), name=dataset_name) + rb_client.log(generator(), name=dataset_name) -def test_create_ds_with_wrong_name(mocked_client): +def test_create_ds_with_wrong_name(rb_client): dataset_name = "Test Create_ds_with_wrong_name" with pytest.raises(ValidationApiError): - rubrix.log( + rb_client.log( TextClassificationRecord( inputs={"text": "The text data"}, ), @@ -222,11 +235,11 @@ def test_create_ds_with_wrong_name(mocked_client): ) -def test_delete_dataset(mocked_client): +def test_delete_dataset(mocked_client, rb_client): dataset_name = "test_delete_dataset" mocked_client.delete(f"/api/datasets/{dataset_name}") - rubrix.log( + rb_client.log( TextClassificationRecord( id=0, inputs={"text": "The text data"}, @@ -235,14 +248,14 @@ def test_delete_dataset(mocked_client): ), name=dataset_name, ) - rubrix.load(name=dataset_name) - rubrix.delete(name=dataset_name) + rb_client.load(name=dataset_name) + rb_client.delete(name=dataset_name) sleep(1) with pytest.raises(NotFoundApiError): - rubrix.load(name=dataset_name) + rb_client.load(name=dataset_name) -def test_dataset_copy(mocked_client): +def test_dataset_copy(mocked_client, rb_client): dataset = "test_dataset_copy" dataset_copy = "new_dataset" new_workspace = "new-workspace" @@ -251,7 +264,7 @@ def test_dataset_copy(mocked_client): mocked_client.delete(f"/api/datasets/{dataset_copy}") mocked_client.delete(f"/api/datasets/{dataset_copy}?workspace={new_workspace}") - rubrix.log( + rb_client.log( TextClassificationRecord( id=0, inputs="This is the record input", @@ -260,31 +273,28 @@ def test_dataset_copy(mocked_client): ), name=dataset, ) - rubrix.copy(dataset, name_of_copy=dataset_copy) - df = rubrix.load(name=dataset) - df_copy = rubrix.load(name=dataset_copy) + rb_client.copy(dataset, target=dataset_copy) + df = rb_client.load(name=dataset) + df_copy = rb_client.load(name=dataset_copy) assert df.equals(df_copy) with pytest.raises(AlreadyExistsApiError): - rubrix.copy(dataset, name_of_copy=dataset_copy) + rb_client.copy(dataset, target=dataset_copy) - rubrix.copy(dataset, name_of_copy=dataset_copy, workspace=new_workspace) + rb_client.copy(dataset, target=dataset_copy, target_workspace=new_workspace) - try: - rubrix.set_workspace(new_workspace) - df_copy = rubrix.load(dataset_copy) - assert df.equals(df_copy) + rb_client.set_workspace(new_workspace) + df_copy = rb_client.load(dataset_copy) + assert df.equals(df_copy) - with pytest.raises(AlreadyExistsApiError): - rubrix.copy( - dataset_copy, name_of_copy=dataset_copy, workspace=new_workspace - ) - finally: - rubrix.init() # reset workspace + with pytest.raises(AlreadyExistsApiError): + rb_client.copy( + dataset_copy, target=dataset_copy, target_workspace=new_workspace + ) -def test_update_record(mocked_client): +def test_update_record(mocked_client, rb_client): dataset = "test_update_record" mocked_client.delete(f"/api/datasets/{dataset}") @@ -295,12 +305,12 @@ def test_update_record(mocked_client): annotation_agent="test", annotation=["T"], ) - rubrix.log( + rb_client.log( record, name=dataset, ) - df = rubrix.load(name=dataset) + df = rb_client.load(name=dataset) records = df.to_dict(orient="records") assert len(records) == 1 assert records[0]["annotation"] == "T" @@ -310,24 +320,24 @@ def test_update_record(mocked_client): inputs=expected_inputs, ) - rubrix.log( + rb_client.log( record, name=dataset, ) - df = rubrix.load(name=dataset) + df = rb_client.load(name=dataset) records = df.to_dict(orient="records") assert len(records) == 1 assert records[0]["annotation"] is None assert records[0]["annotation_agent"] is None -def test_text_classifier_with_inputs_list(mocked_client): +def test_text_classifier_with_inputs_list(mocked_client, rb_client): dataset = "test_text_classifier_with_inputs_list" mocked_client.delete(f"/api/datasets/{dataset}") expected_inputs = ["A", "List", "of", "values"] - rubrix.log( + rb_client.log( TextClassificationRecord( id=0, inputs=expected_inputs, @@ -337,36 +347,36 @@ def test_text_classifier_with_inputs_list(mocked_client): name=dataset, ) - df = rubrix.load(name=dataset) + df = rb_client.load(name=dataset) records = df.to_dict(orient="records") assert len(records) == 1 assert records[0]["inputs"]["text"] == expected_inputs -def test_load_with_ids_list(mocked_client): +def test_load_with_ids_list(mocked_client, rb_client): dataset = "test_load_with_ids_list" mocked_client.delete(f"/api/datasets/{dataset}") expected_data = 100 create_some_data_for_text_classification(mocked_client, dataset, n=expected_data) - ds = rubrix.load(name=dataset, ids=[3, 5]) + ds = rb_client.load(name=dataset, ids=[3, 5]) assert len(ds) == 2 -def test_load_with_query(mocked_client): +def test_load_with_query(mocked_client, rb_client): dataset = "test_load_with_query" mocked_client.delete(f"/api/datasets/{dataset}") sleep(1) expected_data = 4 create_some_data_for_text_classification(mocked_client, dataset, n=expected_data) - ds = rubrix.load(name=dataset, query="id:1") + ds = rb_client.load(name=dataset, query="id:1") assert len(ds) == 1 assert ds.id.iloc[0] == 1 @pytest.mark.parametrize("as_pandas", [True, False]) -def test_load_as_pandas(mocked_client, as_pandas): +def test_load_as_pandas(mocked_client, as_pandas, rb_client): dataset = "test_sorted_load" mocked_client.delete(f"/api/datasets/{dataset}") sleep(1) @@ -376,17 +386,17 @@ def test_load_as_pandas(mocked_client, as_pandas): # Check that the default value is True if as_pandas: - records = rubrix.load(name=dataset) + records = rb_client.load(name=dataset) assert isinstance(records, pandas.DataFrame) assert list(records.id) == [0, 1, 2, 3] else: - records = rubrix.load(name=dataset, as_pandas=False) + records = rb_client.load(name=dataset, as_pandas=False) assert isinstance(records, DatasetForTextClassification) assert isinstance(records[0], TextClassificationRecord) assert [record.id for record in records] == [0, 1, 2, 3] -def test_token_classification_spans(mocked_client): +def test_token_classification_spans(rb_client): dataset = "test_token_classification_with_consecutive_spans" texto = "Esto es una prueba" item = rubrix.TokenClassificationRecord( @@ -398,19 +408,19 @@ def test_token_classification_spans(mocked_client): with pytest.raises( Exception, match=r"Defined offset \[s\] is a misaligned entity mention" ): - rubrix.log(item, name=dataset) + rb_client.log(item, name=dataset) item.prediction = [("test", 0, 6)] with pytest.raises( Exception, match=r"Defined offset \[Esto e\] is a misaligned entity mention" ): - rubrix.log(item, name=dataset) + rb_client.log(item, name=dataset) item.prediction = [("test", 0, 4)] - rubrix.log(item, name=dataset) + rb_client.log(item, name=dataset) -def test_load_text2text(mocked_client): +def test_load_text2text(rb_client): records = [ Text2TextRecord( text="test text", @@ -427,37 +437,34 @@ def test_load_text2text(mocked_client): ] dataset = "test_load_text2text" - rubrix.delete(dataset) - rubrix.log(records, name=dataset) + rb_client.delete(dataset) + rb_client.log(records, name=dataset) - df = rubrix.load(name=dataset) + df = rb_client.load(name=dataset) assert len(df) == 2 -def test_client_workspace(mocked_client): - try: - ws = rubrix.get_workspace() - assert ws == "rubrix" +def test_client_workspace(rb_client): + ws = rb_client.active_workspace + assert ws == "rubrix" - rubrix.set_workspace("other-workspace") - assert rubrix.get_workspace() == "other-workspace" + rb_client.set_workspace("other-workspace") + assert rb_client.active_workspace == "other-workspace" - with pytest.raises(Exception, match="Must provide a workspace"): - rubrix.set_workspace(None) + with pytest.raises(Exception, match="Must provide a workspace"): + rb_client.set_workspace(None) - # Mocking user - rubrix._client_instance().__current_user__.workspaces = ["a", "b"] + # Mocking user + rb_client.__current_user__.workspaces = ["a", "b"] - with pytest.raises(Exception, match="Wrong provided workspace c"): - rubrix.set_workspace("c") + with pytest.raises(Exception, match="Wrong provided workspace c"): + rb_client.set_workspace("c") - rubrix.set_workspace("rubrix") - assert rubrix.get_workspace() == "rubrix" - finally: - rubrix.init() # reset workspace + rb_client.set_workspace("rubrix") + assert rb_client.active_workspace == "rubrix" -def test_load_sort(mocked_client): +def test_load_sort(rb_client): records = [ TextClassificationRecord( inputs="test text", @@ -467,13 +474,13 @@ def test_load_sort(mocked_client): ] dataset = "test_load_sort" - rubrix.delete(dataset) - rubrix.log(records, name=dataset) + rb_client.delete(dataset) + rb_client.log(records, name=dataset) # check sorting policies - df = rubrix.load(name=dataset) + df = rb_client.load(name=dataset) assert list(df.id) == [1, 11, "11str", "1str", 2, "2str"] - df = rubrix.load(name=dataset, ids=[1, 2, 11]) + df = rb_client.load(name=dataset, ids=[1, 2, 11]) assert list(df.id) == [1, 2, 11] - df = rubrix.load(name=dataset, ids=["1str", "2str", "11str"]) + df = rb_client.load(name=dataset, ids=["1str", "2str", "11str"]) assert list(df.id) == ["11str", "1str", "2str"] diff --git a/tests/functional_tests/test_log_for_text_classification.py b/tests/functional_tests/test_log_for_text_classification.py index f09a57c40b..c4adddb99d 100644 --- a/tests/functional_tests/test_log_for_text_classification.py +++ b/tests/functional_tests/test_log_for_text_classification.py @@ -1,7 +1,6 @@ import pytest -import rubrix -from rubrix import TextClassificationRecord, TokenClassificationRecord +import rubrix as rb from rubrix.client.sdk.commons.errors import BadRequestApiError, ValidationApiError from rubrix.server.commons.settings import settings @@ -10,14 +9,14 @@ def test_log_records_with_multi_and_single_label_task(mocked_client): dataset = "test_log_records_with_multi_and_single_label_task" expected_inputs = ["This is a text"] - rubrix.delete(dataset) + rb.delete(dataset) records = [ - TextClassificationRecord( + rb.TextClassificationRecord( id=0, inputs=expected_inputs, multi_label=False, ), - TextClassificationRecord( + rb.TextClassificationRecord( id=1, inputs=expected_inputs, multi_label=True, @@ -25,29 +24,30 @@ def test_log_records_with_multi_and_single_label_task(mocked_client): ] with pytest.raises(ValidationApiError): - rubrix.log( + rb.log( records, name=dataset, ) - rubrix.log(records[0], name=dataset) + rb.log(records[0], name=dataset) with pytest.raises(Exception): - rubrix.log(records[1], name=dataset) + rb.log(records[1], name=dataset) def test_delete_and_create_for_different_task(mocked_client): dataset = "test_delete_and_create_for_different_task" text = "This is a text" - rubrix.delete(dataset) - rubrix.log(TextClassificationRecord(id=0, inputs=text), name=dataset) - rubrix.load(dataset) + rb.delete(dataset) + rb.log(rb.TextClassificationRecord(id=0, inputs=text), name=dataset) + rb.load(dataset) - rubrix.delete(dataset) - rubrix.log( - TokenClassificationRecord(id=0, text=text, tokens=text.split(" ")), name=dataset + rb.delete(dataset) + rb.log( + rb.TokenClassificationRecord(id=0, text=text, tokens=text.split(" ")), + name=dataset, ) - rubrix.load(dataset) + rb.load(dataset) def test_search_keywords(mocked_client): @@ -55,12 +55,12 @@ def test_search_keywords(mocked_client): from datasets import load_dataset dataset_ds = load_dataset("Recognai/sentiment-banking", split="train") - dataset_rb = rubrix.read_datasets(dataset_ds, task="TextClassification") + dataset_rb = rb.read_datasets(dataset_ds, task="TextClassification") - rubrix.delete(dataset) - rubrix.log(name=dataset, records=dataset_rb) + rb.delete(dataset) + rb.log(name=dataset, records=dataset_rb) - df = rubrix.load(dataset, query="lim*") + df = rb.load(dataset, query="lim*") assert not df.empty assert "search_keywords" in df.columns top_keywords = set( @@ -78,16 +78,22 @@ def test_search_keywords(mocked_client): def test_log_records_with_empty_metadata_list(mocked_client): dataset = "test_log_records_with_empty_metadata_list" - rubrix.delete(dataset) + rb.delete(dataset) expected_records = [ - TextClassificationRecord(inputs="The input text", metadata={"emptyList": []}), - TextClassificationRecord(inputs="The input text", metadata={"emptyTuple": ()}), - TextClassificationRecord(inputs="The input text", metadata={"emptyDict": {}}), - TextClassificationRecord(inputs="The input text", metadata={"none": None}), + rb.TextClassificationRecord( + inputs="The input text", metadata={"emptyList": []} + ), + rb.TextClassificationRecord( + inputs="The input text", metadata={"emptyTuple": ()} + ), + rb.TextClassificationRecord( + inputs="The input text", metadata={"emptyDict": {}} + ), + rb.TextClassificationRecord(inputs="The input text", metadata={"none": None}), ] - rubrix.log(expected_records, name=dataset) + rb.log(expected_records, name=dataset) - df = rubrix.load(dataset) + df = rb.load(dataset) assert len(df) == len(expected_records) for meta in df.metadata.values.tolist(): @@ -97,46 +103,46 @@ def test_log_records_with_empty_metadata_list(mocked_client): def test_logging_with_metadata_limits_exceeded(mocked_client): dataset = "test_logging_with_metadata_limits_exceeded" - rubrix.delete(dataset) - expected_record = TextClassificationRecord( + rb.delete(dataset) + expected_record = rb.TextClassificationRecord( inputs="The input text", metadata={k: k for k in range(0, settings.metadata_fields_limit + 1)}, ) with pytest.raises(BadRequestApiError): - rubrix.log(expected_record, name=dataset) + rb.log(expected_record, name=dataset) expected_record.metadata = {k: k for k in range(0, settings.metadata_fields_limit)} - rubrix.log(expected_record, name=dataset) + rb.log(expected_record, name=dataset) expected_record.metadata["new_key"] = "value" with pytest.raises(BadRequestApiError): - rubrix.log(expected_record, name=dataset) + rb.log(expected_record, name=dataset) def test_log_with_other_task(mocked_client): dataset = "test_log_with_other_task" - rubrix.delete(dataset) - record = TextClassificationRecord( + rb.delete(dataset) + record = rb.TextClassificationRecord( inputs="The input text", ) - rubrix.log(record, name=dataset) + rb.log(record, name=dataset) with pytest.raises(BadRequestApiError): - rubrix.log( - TokenClassificationRecord(text="The text", tokens=["The", "text"]), + rb.log( + rb.TokenClassificationRecord(text="The text", tokens=["The", "text"]), name=dataset, ) def test_dynamics_metadata(mocked_client): dataset = "test_dynamics_metadata" - rubrix.log( - TextClassificationRecord(inputs="This is a text", metadata={"a": "value"}), + rb.log( + rb.TextClassificationRecord(inputs="This is a text", metadata={"a": "value"}), name=dataset, ) - rubrix.log( - TextClassificationRecord(inputs="Another text", metadata={"b": "value"}), + rb.log( + rb.TextClassificationRecord(inputs="Another text", metadata={"b": "value"}), name=dataset, ) diff --git a/tests/metrics/test_common_metrics.py b/tests/metrics/test_common_metrics.py index 65b4da18c7..fee0336f4f 100644 --- a/tests/metrics/test_common_metrics.py +++ b/tests/metrics/test_common_metrics.py @@ -1,11 +1,24 @@ -import httpx +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import rubrix as rb +from rubrix.metrics.commons import records_status, text_length def test_status_distribution(mocked_client): dataset = "test_status_distribution" - import rubrix as rb - rb.delete(dataset) rb.log( @@ -27,8 +40,6 @@ def test_status_distribution(mocked_client): name=dataset, ) - from rubrix.metrics.commons import records_status - results = records_status(dataset) assert results assert results.data == {"Default": 1, "Validated": 1} @@ -38,8 +49,6 @@ def test_status_distribution(mocked_client): def test_text_length(mocked_client): dataset = "test_text_length" - import rubrix as rb - rb.delete(dataset) rb.log( @@ -67,8 +76,6 @@ def test_text_length(mocked_client): name=dataset, ) - from rubrix.metrics.commons import text_length - results = text_length(dataset) assert results assert results.data == { diff --git a/tests/metrics/test_text_classification.py b/tests/metrics/test_text_classification.py index 0f19602160..f8330ff4b7 100644 --- a/tests/metrics/test_text_classification.py +++ b/tests/metrics/test_text_classification.py @@ -1,8 +1,24 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import rubrix as rb +from rubrix.metrics.text_classification import f1, f1_multilabel + + def test_metrics_for_text_classification(mocked_client): dataset = "test_metrics_for_text_classification" - import rubrix as rb - rb.log( [ rb.TextClassificationRecord( @@ -21,8 +37,6 @@ def test_metrics_for_text_classification(mocked_client): name=dataset, ) - from rubrix.metrics.text_classification import f1, f1_multilabel - results = f1(dataset) assert results assert results.data == { @@ -62,7 +76,6 @@ def test_metrics_for_text_classification(mocked_client): def test_f1_without_results(mocked_client): dataset = "test_f1_without_results" - import rubrix as rb rb.log( [ @@ -78,8 +91,6 @@ def test_f1_without_results(mocked_client): name=dataset, ) - from rubrix.metrics.text_classification import f1 - results = f1(dataset) assert results assert results.data == {} diff --git a/tests/server/test_app.py b/tests/server/test_app.py new file mode 100644 index 0000000000..67f54c1102 --- /dev/null +++ b/tests/server/test_app.py @@ -0,0 +1,28 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +from importlib import reload + +import pytest + +from rubrix.server import app + + +def test_fallback_app(monkeypatch): + monkeypatch.setitem(sys.modules, "rubrix.server.server", None) + reload(app) + + with pytest.raises(RuntimeError, match="Cannot start rubrix server"): + app.app() diff --git a/tests/test_init.py b/tests/test_init.py index 7403a7d45a..3113533e7d 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -13,295 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Rubrix Client Init Testing File""" +import logging +import sys -import os +from rubrix.logging import LoguruLoggerHandler +from rubrix.utils import _LazyRubrixModule -import httpx -import pytest -import rubrix -from rubrix.client import RubrixClient -from rubrix.client.sdk.client import AuthenticatedClient -from rubrix.client.sdk.commons.errors import GenericApiError, UnauthorizedApiError +def test_lazy_module(): + assert isinstance(sys.modules["rubrix"], _LazyRubrixModule) -@pytest.fixture -def mock_response_200(monkeypatch): - """Creating of mock_get method from the class, and monkeypatch application. - - It will return a 200 status code, emulating the correct login. - - Parameters - ---------- - monkeypatch - Mockup function - """ - - def mock_get(url, *args, **kwargs): - if "/api/me" in url: - return httpx.Response(status_code=200, json={"username": "booohh"}) - return httpx.Response(status_code=200) - - monkeypatch.setattr( - httpx, "get", mock_get - ) # apply the monkeypatch for requests.get to mock_get - - -@pytest.fixture -def mock_response_500(monkeypatch): - """Creating of mock_get method from the class, and monkeypatch application. - - It will return a 500 status code, emulating an invalid state of the API error. - - Parameters - ---------- - monkeypatch - Mockup function - - """ - - def mock_get(*args, **kwargs): - return httpx.Response(status_code=500) - - monkeypatch.setattr( - httpx, "get", mock_get - ) # apply the monkeypatch for requests.get to mock_get - - -@pytest.fixture -def mock_response_token_401(monkeypatch): - """Creating of mock_get method from the class, and monkeypatch application. - - It will return a 401 status code, emulating an invalid credentials error when using tokens to log in. - Iterable stucture to be able to pass the first 200 status code check - - Parameters - ---------- - monkeypatch - Mockup function - - """ - response_200 = httpx.Response(status_code=200) - response_401 = httpx.Response(status_code=401) - - def mock_get(*args, **kwargs): - if kwargs["url"] == "fake_url/api/me": - return response_401 - elif kwargs["url"] == "fake_url/api/docs/spec.json": - return response_200 - - monkeypatch.setattr( - httpx, "get", mock_get - ) # apply the monkeypatch for requests.get to mock_get - - -@pytest.fixture -def api_url_env_var(): - """Sets an api_url via environment variable""" - - os.environ["RUBRIX_API_URL"] = "http://fakeurl.com" - - -@pytest.fixture -def api_url_env_var_trailing_slash(): - """Sets an api_url via environment variable. The url has trailing slash""" - - os.environ["RUBRIX_API_URL"] = "http://fakeurl.com/" - - -@pytest.fixture -def token_env_var(): - """Sets an api_url via environment variable""" - - os.environ["RUBRIX_API_KEY"] = "622" - - -def test_init_correct(mock_response_200): - """Testing correct default initalization - - It checks if the _client created is a RubrixClient object. - - Parameters - ---------- - mock_response_200 - Mocked correct http response - """ - - rubrix.init() - - assert isinstance(rubrix._client, RubrixClient) - - -def test_init_incorrect(mock_response_500): - """Testing incorrect default initalization - - It checks an Exception is raised with the correct message. - - Parameters - ---------- - mock_response_500 - Mocked incorrect http response - """ - - rubrix._client = None # assert empty client - with pytest.raises(GenericApiError): - rubrix.init() - - -def test_init_token_correct(mock_response_200): - """Testing correct token initalization - - It checks if the _client created is a RubrixClient object. - - Parameters - ---------- - mock_response_200 - Mocked correct http response - """ - rubrix._client = None # assert empty client - rubrix.init(api_key="fjkjdf333") - - assert isinstance(rubrix._client, RubrixClient) - - -def test_init_token_incorrect(mock_response_500): - """Testing incorrect token initalization - - It checks an Exception is raised with the correct message. - - Parameters - ---------- - mock_response_500 - Mocked correct http response - """ - rubrix._client = None # assert empty client - with pytest.raises(GenericApiError): - rubrix.init(api_key="422") - - -def test_init_token_auth_fail(mock_response_token_401): - """Testing initalization with failed authentication - - It checks an Exception is raised with the correct message. - - Parameters - ---------- - mock_response_401 - Mocked correct http response - """ - rubrix._client = None # assert empty client - with pytest.raises(UnauthorizedApiError): - rubrix.init(api_url="fake_url", api_key="422") - - -def test_init_evironment_url(api_url_env_var, mock_response_200): - """Testing initalization with api_url provided via environment variable - - It checks the url in the environment variable gets passed to client. - - Parameters - ---------- - api_url_env_var - Fixture to set the fake url in the env variable - mock_response_200 - Mocked correct http response - """ - rubrix._client = None # assert empty client - - rubrix.init() - - assert isinstance(rubrix._client, RubrixClient) - assert isinstance(rubrix._client._client, AuthenticatedClient) - assert rubrix._client._client.base_url == "http://fakeurl.com" - - -def test_init_evironment_url_token(api_url_env_var, token_env_var, mock_response_200): - """Testing initalization with api_url and tokenprovided via environment variable - - It checks the url and token in the environment variable gets passed to client. - - Parameters - ---------- - api_url_env_var - Fixture to set the fake url in the env variable - token_env_var - Fixture to set the fake token in the env variable - mock_response_200 - Mocked correct http response - """ - rubrix._client = None # assert empty client - - rubrix.init() - - assert isinstance(rubrix._client, RubrixClient) - assert isinstance(rubrix._client._client, AuthenticatedClient) - assert rubrix._client._client.base_url == "http://fakeurl.com" - assert rubrix._client._client.token == str(622) - - -def test_init_evironment_no_url_token(token_env_var, mock_response_200): - """Testing initalization with token provided via environment variable and api_url via args - - It checks a non-secured Client is created - - Parameters - ---------- - token_env_var - Fixture to set the fake token in the env variable - mock_response_200 - Mocked correct http response - """ - rubrix._client = None # assert empty client - - rubrix.init(api_url="http://anotherfakeurl.com") - - assert isinstance(rubrix._client, RubrixClient) - assert isinstance(rubrix._client._client, AuthenticatedClient) - assert rubrix._client._client.base_url == "http://anotherfakeurl.com" - - -def test_trailing_slash(api_url_env_var_trailing_slash, mock_response_200): - """Testing initalization with provided api_url via environment variable and argument - - It checks the trailing slash is removed in all cases - - Parameters - ---------- - api_url_env_var - Fixture to set the fake url in the env variable, with trailing slash - mock_response_200 - Mocked correct http response - """ - - rubrix._client = None # assert empty client - - # Environment variable case - rubrix.init(api_url="http://anotherfakeurl.com/") - assert rubrix._client._client.base_url == "http://anotherfakeurl.com" - - rubrix._client = None # assert empty client - - # Argument case - rubrix.init() - assert rubrix._client._client.base_url == "http://fakeurl.com" - - -def test_default_init(mock_response_200): - rubrix._client = None - - if "RUBRIX_API_URL" in os.environ: - del os.environ["RUBRIX_API_URL"] - - if "RUBRIX_API_KEY" in os.environ: - del os.environ["RUBRIX_API_KEY"] - - rubrix.init() - - assert isinstance(rubrix._client._client, AuthenticatedClient) - assert rubrix._client._client.base_url == "http://localhost:6900" - - expected_token = "blablabla" - rubrix.init(api_key=expected_token) - assert isinstance(rubrix._client._client, AuthenticatedClient) - assert rubrix._client._client.token == expected_token +def test_configure_logging_call(): + assert isinstance(logging.getLogger("rubrix").handlers[0], LoguruLoggerHandler) diff --git a/tests/test_log.py b/tests/test_log.py deleted file mode 100644 index bd1dcfd0b4..0000000000 --- a/tests/test_log.py +++ /dev/null @@ -1,333 +0,0 @@ -# coding=utf-8 -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Rubrix Log Test Unit - -This pytest modules aims to test the correct state to the log function. -Interaction with the client will be mocked, as this test are independent from the API, -which could or could not be mounted. -""" - -import logging -from typing import cast - -import httpx -import pytest - -import rubrix -from rubrix import ( - BulkResponse, - Text2TextRecord, - TextClassificationRecord, - TokenClassificationRecord, -) -from rubrix.client.sdk.commons.models import Response - - -@pytest.fixture -def mock_response_200(monkeypatch): - """Mock_get method from the class, and monkeypatch application. - - It will return a 200 status code in the init function, emulating the correct login. - - Parameters - ---------- - monkeypatch - Mockup function - """ - - def mock_get(url, *args, **kwargs): - if "/api/me" in url: - return httpx.Response(200, json={"username": "rubrix"}) - return httpx.Response(200) - - monkeypatch.setattr( - httpx, "get", mock_get - ) # apply the monkeypatch for requests.get to mock_get - - -@pytest.fixture -def mock_response_text(monkeypatch): - """Mock log response for TextClassification records""" - - def mock_get(*args, **kwargs): - return Response( - status_code=200, - content=b"Everything's fine", - headers={ - "date": "Tue, 09 Mar 2021 10:18:23 GMT", - "server": "uvicorn", - "content-length": "43", - "content-type": "application/json", - }, - parsed=BulkResponse(dataset="test", processed=500, failed=0), - ) - - monkeypatch.setattr( - "rubrix.client.rubrix_client.text_classification_bulk", mock_get - ) # apply the monkeypatch for requests.get to mock_get - - -@pytest.fixture -def mock_response_token(monkeypatch): - """Mock log response for TokenClassification records""" - - def mock_get(*args, **kwargs): - return Response( - status_code=200, - content=b"Everything's fine", - headers={ - "date": "Tue, 09 Mar 2021 10:18:23 GMT", - "server": "uvicorn", - "content-length": "43", - "content-type": "application/json", - }, - parsed=BulkResponse(dataset="test", processed=500, failed=0), - ) - - monkeypatch.setattr( - "rubrix.client.rubrix_client.token_classification_bulk", mock_get - ) # apply the monkeypatch for requests.get to mock_get - - -@pytest.fixture -def mock_response_text2text(monkeypatch): - """Mock log response for Text2Text records""" - - _response = BulkResponse(dataset="test", processed=500, failed=0) - - def mock_get(*args, **kwargs): - return Response( - status_code=200, - content=b"Everything's fine", - headers={ - "date": "Tue, 09 Mar 2021 10:18:23 GMT", - "server": "uvicorn", - "content-length": "43", - "content-type": "application/json", - }, - parsed=_response, - ) - - monkeypatch.setattr( - "rubrix.client.rubrix_client.text2text_bulk", - mock_get, - ) # apply the monkeypatch for requests.get to mock_get - - -def test_text_classification(mock_response_text, mocked_client): - """Testing text classification with log function - - It checks a Response is generated. - - Parameters - ---------- - mock_response_text - Mocked response for the text_classification bulk API call - """ - records = [ - TextClassificationRecord( - inputs={"review_body": "increible test"}, - prediction=[("test", 0.9), ("test2", 0.1)], - annotation="test", - metadata={"product_category": "test de pytest"}, - id="test", - ) - ] - - assert ( - rubrix.log( - name="test", - records=records, - tags={"type": "sentiment classifier", "lang": "spanish"}, - ) - == BulkResponse(dataset="test", processed=500, failed=0) - ) - - -def test_token_classification(mock_response_token): - """Testing token classification with log function - - It checks a Response is generated. - - Parameters - ---------- - mock_response_token - Mocked response for the token_classification bulk API call - """ - records = [ - TokenClassificationRecord( - text="Super test", - tokens=["Super", "test"], - prediction=[("test", 6, 10)], - annotation=[("test", 6, 10)], - prediction_agent="spacy", - annotation_agent="recognai", - metadata={"model": "spacy_es_core_news_sm"}, - id=1, - ) - ] - - assert ( - rubrix.log( - name="test", - records=records[0], - tags={"type": "sentiment classifier", "lang": "spanish"}, - ) - == BulkResponse(dataset="test", processed=500, failed=0) - ) - - -def test_text2text(mock_response_text2text): - """Testing text2text with log function - - It checks a Response is generated. - - Parameters - ---------- - mock_response_text2text - Mocked response for the text2text bulk API call - """ - records = [ - Text2TextRecord( - text="Super test", - prediction=[("test", 0.5)], - annotation="test", - prediction_agent="spacy", - annotation_agent="recognai", - metadata={"model": "spacy_es_core_news_sm"}, - id=1, - ) - ] - - assert ( - rubrix.log( - name="test", - records=records[0], - tags={"type": "text2text", "lang": "spanish"}, - ) - == BulkResponse(dataset="test", processed=500, failed=0) - ) - - -def test_no_name(mock_response_200): - """Testing classification with no input name - - It checks an Exception is raised, with the corresponding message. - - Parameters - ---------- - mock_response_200 - Mocked correct http response, emulating API init - """ - - with pytest.raises( - Exception, match="Empty project name has been passed as argument." - ): - assert rubrix.log(name="", records=cast(TextClassificationRecord, None)) - - -def test_empty_records(mock_response_200): - """Testing classification with empty record list - - It checks an Exception is raised, with the corresponding message. - - Parameters - ---------- - mock_response_200 - Mocked correct http response, emulating API init - """ - - with pytest.raises( - Exception, match="Empty record list has been passed as argument." - ): - rubrix.log(name="test", records=[]) - - -def test_unknow_record_type(mock_response_200): - """Testing classification with unknown record type - - It checks an Exception is raised, with the corresponding message. - - Parameters - ---------- - mock_response_200 - Mocked correct http response, emulating API init - """ - - with pytest.raises(Exception, match="Unknown record type passed as argument."): - rubrix.log(name="test", records=["12"]) - - -@pytest.fixture -def mock_wrong_bulk_response(monkeypatch): - def mock(*args, **kwargs): - return Response( - status_code=500, - headers={}, - content=b"", - parsed={"error": "the error message "}, - ) - - monkeypatch.setattr("rubrix.client.rubrix_client.text_classification_bulk", mock) - - -def test_wrong_response(mock_response_200, mock_wrong_bulk_response): - rubrix._client = None - with pytest.raises( - Exception, - match="Connection error: API is not responding. The API answered with", - ): - rubrix.log( - name="dataset", - records=[TextClassificationRecord(inputs={"text": "The textual info"})], - tags={"env": "Test"}, - ) - - -@pytest.mark.skip -def test_info_message(mock_response_200, mock_response_text, caplog): - """Testing initialization info message - - Parameters - ---------- - mock_response_200 - Mocked correct http response, emulating API init - mock_response_text - Mocked response given by the sync method, emulating the log of data - caplog - Captures the logging output - """ - - rubrix._client = None # Force client initialization - caplog.set_level(logging.INFO) - - records = [ - TextClassificationRecord( - inputs={"review_body": "increible test"}, - prediction=[("test", 0.9), ("test2", 0.1)], - annotation="test", - metadata={"product_category": "test de pytest"}, - id="test", - ) - ] - - rubrix.log( - name="test", - records=records, - tags={"type": "sentiment classifier", "lang": "spanish"}, - ) - - assert "Rubrix has been initialized on http://localhost:6900" in caplog.text diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..2410797e44 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from rubrix.utils import _LazyRubrixModule + + +def test_lazy_rubrix_module(monkeypatch): + def mock_import_module(name, package): + return name + + monkeypatch.setattr("importlib.import_module", mock_import_module) + + lazy_module = _LazyRubrixModule( + name="rb_mock", + module_file="rb_mock_file", + import_structure={"mock_module": ["title"]}, + extra_objects={"string": str}, + deprecated_import_structure={"dep_mock_module": ["upper"]}, + ) + assert all(attr in dir(lazy_module) for attr in ["mock_module", "title"]) + assert lazy_module.mock_module == ".mock_module" + assert lazy_module.title() == ".mock_module".title() + assert lazy_module.string == str + + with pytest.warns( + FutureWarning, match="Importing 'dep_mock_module' from the rubrix namespace" + ): + assert lazy_module.dep_mock_module == ".dep_mock_module" + + with pytest.warns( + FutureWarning, match="Importing 'upper' from the rubrix namespace" + ): + assert lazy_module.upper() == ".dep_mock_module".upper() + + with pytest.raises(AttributeError): + lazy_module.not_available_mock + + assert lazy_module.__reduce__() + + +def test_lazy_rubrix_module_import_error(monkeypatch): + def mock_import_module(*args, **kwargs): + raise Exception + + monkeypatch.setattr("importlib.import_module", mock_import_module) + + lazy_module = _LazyRubrixModule( + name="rb_mock", + module_file=__file__, + import_structure={"mock_module": ["title"]}, + ) + + with pytest.raises(RuntimeError, match="Failed to import rb_mock.mock_module"): + lazy_module.mock_module