Skip to content

Commit

Permalink
feat(#950): include search keywords as part of record results (#1201)
Browse files Browse the repository at this point in the history
* chore: include search_keywords in client records

* chore: signatures

* feat: include search_records as part of client records

* fix: add highlight on dataset scan

* test: add missing tests

* test: stabilize tests

* Apply suggestions from code review

Co-authored-by: David Fidalgo <david@recogn.ai>

* test: try to fix push to hf hub

Co-authored-by: David Fidalgo <david@recogn.ai>

(cherry picked from commit 0678043)
  • Loading branch information
frascuchon committed Mar 30, 2022
1 parent 1313bab commit 364ba1b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/rubrix/server/tasks/storage/service.py
Expand Up @@ -37,7 +37,7 @@ def store_records(
self,
dataset: BaseDatasetDB,
records: List[Record],
record_type: Type[BaseRecord],
record_type: Type[Record],
) -> int:
"""Store a set of records"""
self._compute_record_metrics(dataset, records)
Expand Down
11 changes: 6 additions & 5 deletions tests/client/test_dataset.py
Expand Up @@ -282,23 +282,24 @@ def test_to_from_pandas(self, records, request):
reason="You need a HF Hub access token to test the push_to_hub feature",
)
@pytest.mark.parametrize(
"records",
"name",
[
"singlelabel_textclassification_records",
"multilabel_textclassification_records",
],
)
def test_push_to_hub(self, request, records):
records = request.getfixturevalue(records)
def test_push_to_hub(self, request, name: str):
records = request.getfixturevalue(name)
dataset_name = f"rubrix/_test_text_classification_records-{name}"
dataset_rb = rb.DatasetForTextClassification(records)
dataset_rb.to_datasets().push_to_hub(
"rubrix/_test_text_classification_records",
dataset_name,
token=_HF_HUB_ACCESS_TOKEN,
private=True,
)
sleep(1)
dataset_ds = datasets.load_dataset(
"rubrix/_test_text_classification_records",
dataset_name,
use_auth_token=_HF_HUB_ACCESS_TOKEN,
split="train",
)
Expand Down
Expand Up @@ -452,7 +452,7 @@ def test_search_keywords(mocked_client):
dataset = "test_search_keywords"
from datasets import load_dataset

dataset_ds = load_dataset("rubrix/gutenberg_spacy-ner", split="train")
dataset_ds = load_dataset("rubrix/gutenberg_spacy-ner_sm", split="train")
dataset_rb = rubrix.read_datasets(dataset_ds, task="TokenClassification")

rubrix.delete(dataset)
Expand Down

0 comments on commit 364ba1b

Please sign in to comment.