diff --git a/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue b/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue index 4ba78129a8..08ac5a15cf 100644 --- a/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue +++ b/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue @@ -68,12 +68,12 @@ export default { maxVisibleLabels() { return DatasetViewSettings.MAX_VISIBLE_LABELS; }, - query() { - return this.dataset.query.text; - }, labels() { return this.dataset.labels.map((l) => ({ class: l, selected: false })); }, + query() { + return this.dataset.query.text; + }, sortedLabels() { return this.labels.slice().sort((a, b) => (a.score > b.score ? -1 : 1)); }, diff --git a/frontend/models/TextClassification.js b/frontend/models/TextClassification.js index aa6570a4df..fb1edb58e9 100644 --- a/frontend/models/TextClassification.js +++ b/frontend/models/TextClassification.js @@ -272,16 +272,28 @@ class TextClassificationDataset extends ObservationDataset { get labels() { const { labels } = (this.metadata || {})[USER_DATA_METADATA_KEY] || {}; const aggregations = this.globalResults.aggregations; + const label2str = (label) => label.class; + const recordsLabels = this.results.records.flatMap((record) => { + return [] + .concat( + record.annotation ? record.annotation.labels.map(label2str) : [] + ) + .concat( + record.prediction ? record.prediction.labels.map(label2str) : [] + ); + }); const uniqueLabels = [ ...new Set( (labels || []) .filter((l) => l && l.trim()) .concat(this._labels || []) + .concat(recordsLabels) .concat(Object.keys(aggregations.annotated_as)) .concat(Object.keys(aggregations.predicted_as)) ), ]; + uniqueLabels.sort(); return uniqueLabels; } diff --git a/src/rubrix/server/tasks/text_classification/metrics.py b/src/rubrix/server/tasks/text_classification/metrics.py index 8ec35ccf14..15b5707ac4 100644 --- a/src/rubrix/server/tasks/text_classification/metrics.py +++ b/src/rubrix/server/tasks/text_classification/metrics.py @@ -1,13 +1,12 @@ -from typing import Any, ClassVar, Dict, Iterable, List, Optional +from typing import Any, ClassVar, Dict, Iterable, List from pydantic import Field from sklearn.metrics import precision_recall_fscore_support from sklearn.preprocessing import MultiLabelBinarizer -from rubrix.server.tasks.commons.metrics import CommonTasksMetrics +from rubrix.server.tasks.commons.metrics import CommonTasksMetrics, GenericRecord from rubrix.server.tasks.commons.metrics.model.base import ( BaseMetric, - BaseTaskMetrics, PythonMetric, TermsAggregation, ) @@ -87,39 +86,22 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any: } -class DatasetLabels(TermsAggregation): +class DatasetLabels(PythonMetric): id: str = Field("dataset_labels", const=True) name: str = Field("The dataset labels", const=True) - fixed_size: int = Field(1500, const=True) - script: Dict[str, Any] = Field( - { - "lang": "painless", - "source": """ - def all_labels = []; - def predicted = doc.containsKey('prediction.labels.class_label.keyword') - ? doc['prediction.labels.class_label.keyword'] : null; - def annotated = doc.containsKey('annotation.labels.class_label.keyword') - ? doc['annotation.labels.class_label.keyword'] : null; - - if (predicted != null && predicted.size() > 0) { - for (int i = 0; i < predicted.length; i++) { - all_labels.add(predicted[i]) - } - } - - if (annotated != null && annotated.size() > 0) { - for (int i = 0; i < annotated.length; i++) { - all_labels.add(annotated[i]) - } - } - return all_labels; - """, - }, - const=True, - ) - - def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]: - return {"labels": [k for k in (aggregation_result or {}).keys()]} + + def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]: + ds_labels = set() + for record in records: + if record.annotation: + ds_labels.update( + [label.class_label for label in record.annotation.labels] + ) + if record.prediction: + ds_labels.update( + [label.class_label for label in record.prediction.labels] + ) + return {"labels": ds_labels or []} class TextClassificationMetrics(BaseTaskMetrics[TextClassificationRecord]): diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py index 627852e0bb..77887c7ef2 100644 --- a/tests/server/metrics/test_api.py +++ b/tests/server/metrics/test_api.py @@ -206,3 +206,57 @@ def test_dataset_metrics(): json={}, ) assert response.status_code == 200, f"{metric}: {response.json()}" + + +def test_dataset_labels_for_text_classification(mocked_client): + records = [ + TextClassificationRecord.parse_obj(data) + for data in [ + { + "id": 0, + "inputs": {"text": "Some test data"}, + "prediction": {"agent": "test", "labels": [{"class": "A"}]}, + }, + { + "id": 1, + "inputs": {"text": "Some test data"}, + "annotation": {"agent": "test", "labels": [{"class": "B"}]}, + }, + { + "id": 2, + "inputs": {"text": "Some test data"}, + "prediction": { + "agent": "test", + "labels": [ + {"class": "A", "score": 0.5}, + { + "class": "D", + "score": 0.5, + }, + ], + }, + "annotation": {"agent": "test", "labels": [{"class": "E"}]}, + }, + ] + ] + request = TextClassificationBulkData(records=records) + dataset = "test_dataset_labels_for_text_classification" + + assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200 + + assert ( + mocked_client.post( + f"/api/datasets/{dataset}/TextClassification:bulk", + json=request.dict(by_alias=True), + ).status_code + == 200 + ) + + response = mocked_client.post( + f"/api/datasets/TextClassification/{dataset}/metrics/dataset_labels:summary", + json={}, + ) + assert response.status_code == 200 + response = response.json() + labels = response["labels"] + assert sorted(labels) == ["A", "B", "D", "E"]