Skip to content

Commit

Permalink
feat(#735): add warning when agent but no prediction/annotation is pr…
Browse files Browse the repository at this point in the history
…ovided (#987)

* feat: add warning when only agent is provided

* refactor: avoid global import

* docs: remove not important members from docs

* test: add test

* refactor: use root_validators instead of init

(cherry picked from commit 974ecb2)
  • Loading branch information
David Fidalgo authored and frascuchon committed Jan 31, 2022
1 parent afb6611 commit 99baf63
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/reference/python/python_client.rst
Expand Up @@ -19,3 +19,4 @@ Models

.. automodule:: rubrix.client.models
:members:
:exclude-members: BaseRecord, BulkResponse
11 changes: 9 additions & 2 deletions src/rubrix/__init__.py
Expand Up @@ -21,14 +21,21 @@
import logging
import os
import re
from typing import Iterable
from typing import Any, Dict, Iterable, List, Optional, Union

import pandas
import pkg_resources

from rubrix._constants import DEFAULT_API_KEY
from rubrix.client import RubrixClient
from rubrix.client.models import *
from rubrix.client.models import (
BulkResponse,
Record,
Text2TextRecord,
TextClassificationRecord,
TokenAttributions,
TokenClassificationRecord,
)
from rubrix.monitoring.model_monitor import monitor

try:
Expand Down
20 changes: 6 additions & 14 deletions src/rubrix/server/tasks/token_classification/service/service.py
Expand Up @@ -17,27 +17,22 @@

from fastapi import Depends

from rubrix import MAX_KEYWORD_LENGTH
from rubrix.server.commons.es_helpers import (
aggregations,
sort_by2elasticsearch,
)
from rubrix._constants import MAX_KEYWORD_LENGTH
from rubrix.server.commons.es_helpers import aggregations, sort_by2elasticsearch
from rubrix.server.datasets.model import Dataset
from rubrix.server.tasks.commons import (
BulkResponse,
EsRecordDataFieldNames,
SortableField,
)
from rubrix.server.tasks.commons.dao import (
extends_index_properties,
)
from rubrix.server.tasks.commons.dao import extends_index_properties
from rubrix.server.tasks.commons.dao.dao import DatasetRecordsDAO, dataset_records_dao
from rubrix.server.tasks.commons.dao.model import RecordSearch
from rubrix.server.tasks.commons.metrics.service import MetricsService
from rubrix.server.tasks.token_classification.api.model import (
CreationTokenClassificationRecord,
MENTIONS_ES_FIELD_NAME,
PREDICTED_MENTIONS_ES_FIELD_NAME,
CreationTokenClassificationRecord,
TokenClassificationAggregations,
TokenClassificationQuery,
TokenClassificationRecord,
Expand Down Expand Up @@ -177,7 +172,7 @@ def search(
),
size=size,
record_from=record_from,
exclude_fields=["metrics"] if exclude_metrics else None
exclude_fields=["metrics"] if exclude_metrics else None,
)
return TokenClassificationSearchResults(
total=results.total,
Expand Down Expand Up @@ -239,8 +234,5 @@ def token_classification_service(
"""
global _instance
if not _instance:
_instance = TokenClassificationService(
dao=dao,
metrics=metrics
)
_instance = TokenClassificationService(dao=dao, metrics=metrics)
return _instance

0 comments on commit 99baf63

Please sign in to comment.