Skip to content

Commit

Permalink
test: update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Jul 27, 2022
1 parent 2ffec9d commit 83a5965
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 17 deletions.
5 changes: 5 additions & 0 deletions tests/client/sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ def remove_description(self, schema: dict):
def remove_pattern(self, schema: dict):
return self.remove_key(schema, key="pattern")

def are_compatible_api_schemas(self, client_schema: dict, server_schema: dict):
return (
client_schema["properties"].items() <= server_schema["properties"].items()
)


@pytest.fixture(scope="session")
def helpers():
Expand Down
4 changes: 1 addition & 3 deletions tests/client/sdk/text2text/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def test_query_schema(helpers):
client_schema = Text2TextQuery.schema()
server_schema = ServerText2TextQuery.schema()

assert helpers.remove_description(client_schema) == helpers.remove_description(
server_schema
)
assert helpers.are_compatible_api_schemas(client_schema, server_schema)


@pytest.mark.parametrize(
Expand Down
4 changes: 1 addition & 3 deletions tests/client/sdk/text_classification/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def test_query_schema(helpers):
client_schema = TextClassificationQuery.schema()
server_schema = ServerTextClassificationQuery.schema()

assert helpers.remove_description(client_schema) == helpers.remove_description(
server_schema
)
assert helpers.are_compatible_api_schemas(client_schema, server_schema)


def test_labeling_rule_schema(helpers):
Expand Down
4 changes: 1 addition & 3 deletions tests/client/sdk/token_classification/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def test_query_schema(helpers):
client_schema = TokenClassificationQuery.schema()
server_schema = ServerTokenClassificationQuery.schema()

assert helpers.remove_description(client_schema) == helpers.remove_description(
server_schema
)
assert helpers.are_compatible_api_schemas(client_schema, server_schema)


@pytest.mark.parametrize(
Expand Down
13 changes: 8 additions & 5 deletions tests/functional_tests/search/test_search_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@


@pytest.fixture
def es_wrapper():
def backend():
return ElasticsearchBackend.get_instance()


@pytest.fixture
def dao(es_wrapper: ElasticsearchBackend):
return DatasetRecordsDAO.get_instance(es=es_wrapper)
def dao(backend: ElasticsearchBackend):
return DatasetRecordsDAO.get_instance(es=backend)


@pytest.fixture
Expand All @@ -38,7 +38,9 @@ def service(dao: DatasetRecordsDAO, metrics: MetricsService):

def test_query_builder_with_query_range(backend: ElasticsearchBackend):
es_query = backend.query_builder(
"ds", query=TextClassificationQuery(score=ScoreRange(range_from=10))
"ds",
schema=None,
query=TextClassificationQuery(score=ScoreRange(range_from=10)),
)
assert es_query == {
"bool": {
Expand All @@ -53,7 +55,7 @@ def test_query_builder_with_query_range(backend: ElasticsearchBackend):
}


def test_query_builder_with_nested(mocked_client, backend: ElasticsearchBackend):
def test_query_builder_with_nested(mocked_client, dao, backend: ElasticsearchBackend):
dataset = Dataset(
name="test_query_builder_with_nested",
owner=rubrix.get_workspace(),
Expand All @@ -71,6 +73,7 @@ def test_query_builder_with_nested(mocked_client, backend: ElasticsearchBackend)

es_query = backend.query_builder(
dataset=dataset,
schema=dao.get_dataset_schema(dataset),
query=TokenClassificationQuery(
advanced_query_dsl=True,
query_text="metrics.predicted.mentions:(label:NAME AND score:[* TO 0.1])",
Expand Down
2 changes: 1 addition & 1 deletion tests/labeling/text_classification/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def test_rule_metrics_without_annotated(
)

metrics = rule.metrics(log_dataset_without_annotations)
assert metrics == expected_metrics
assert expected_metrics == metrics


def delete_rule_silently(client, dataset: str, rule: Rule):
Expand Down
2 changes: 1 addition & 1 deletion tests/server/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

import os

from rubrix.server.apis.v0.models.commons.model import TaskStatus
from rubrix.server.apis.v0.models.text_classification import (
TaskType,
TextClassificationBulkData,
TextClassificationRecord,
)
from rubrix.server.commons.models import TaskStatus


def create_some_data_for_text_classification(client, name: str, n: int):
Expand Down
2 changes: 1 addition & 1 deletion tests/server/text_classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from pydantic import ValidationError

from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.apis.v0.models.commons.model import TaskStatus
from rubrix.server.apis.v0.models.text_classification import (
ClassPrediction,
PredictionStatus,
TextClassificationAnnotation,
TextClassificationQuery,
TextClassificationRecord,
)
from rubrix.server.commons.models import TaskStatus
from rubrix.server.elasticseach.search.query_builder import EsQueryBuilder


Expand Down

0 comments on commit 83a5965

Please sign in to comment.