From 28e97c6f0599c0487b2c03ee8334fea430d572df Mon Sep 17 00:00:00 2001
From: leiyre
Date: Thu, 10 Mar 2022 15:29:12 +0100
Subject: [PATCH] fix(#1238): show prediction labels when annotating rule
 (#1239)

This PR shows prediction labels when annotating a rule in Weak Supervision

Closes #1238

* fix(metrics): compute dataset labels as python metric

* test: fix tests

* fix: compute dataset label properly

* fix show all labels for empty query view

* empty query

* refactor: revert comp. changes and compute labels in model

* chore: lint fix

fix: remove unused imports

fix: include missing imports

test: fix tests

Co-authored-by: Francisco Aranda
(cherry picked from commit 6f19e405356dda26c6d9d57bff833f6a339ec38d)
---
 .../labeling-rules/RuleEmptyQuery.vue         |  6 +-
 frontend/models/TextClassification.js         | 12 ++++
 .../tasks/text_classification/metrics.py      | 47 +++++----------
 tests/server/metrics/test_api.py              | 56 +++++++++++++++++++
 4 files changed, 86 insertions(+), 35 deletions(-)

diff --git a/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue b/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue
index 4ba78129a8..08ac5a15cf 100644
--- a/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue
+++ b/frontend/components/text-classifier/labeling-rules/RuleEmptyQuery.vue
@@ -68,12 +68,12 @@ export default {
     maxVisibleLabels() {
       return DatasetViewSettings.MAX_VISIBLE_LABELS;
     },
-    query() {
-      return this.dataset.query.text;
-    },
     labels() {
       return this.dataset.labels.map((l) => ({ class: l, selected: false }));
     },
+    query() {
+      return this.dataset.query.text;
+    },
     sortedLabels() {
       return this.labels.slice().sort((a, b) => (a.score > b.score ? -1 : 1));
     },
diff --git a/frontend/models/TextClassification.js b/frontend/models/TextClassification.js
index aa6570a4df..fb1edb58e9 100644
--- a/frontend/models/TextClassification.js
+++ b/frontend/models/TextClassification.js
@@ -272,16 +272,28 @@ class TextClassificationDataset extends ObservationDataset {
   get labels() {
     const { labels } = (this.metadata || {})[USER_DATA_METADATA_KEY] || {};
     const aggregations = this.globalResults.aggregations;
+    const label2str = (label) => label.class;
+    const recordsLabels = this.results.records.flatMap((record) => {
+      return []
+        .concat(
+          record.annotation ? record.annotation.labels.map(label2str) : []
+        )
+        .concat(
+          record.prediction ? record.prediction.labels.map(label2str) : []
+        );
+    });
 
     const uniqueLabels = [
       ...new Set(
         (labels || [])
           .filter((l) => l && l.trim())
           .concat(this._labels || [])
+          .concat(recordsLabels)
          .concat(Object.keys(aggregations.annotated_as))
          .concat(Object.keys(aggregations.predicted_as))
       ),
     ];
 
+    uniqueLabels.sort();
     return uniqueLabels;
   }
diff --git a/src/rubrix/server/tasks/text_classification/metrics.py b/src/rubrix/server/tasks/text_classification/metrics.py
index 8ec35ccf14..d1e802e808 100644
--- a/src/rubrix/server/tasks/text_classification/metrics.py
+++ b/src/rubrix/server/tasks/text_classification/metrics.py
@@ -1,4 +1,4 @@
-from typing import Any, ClassVar, Dict, Iterable, List, Optional
+from typing import Any, ClassVar, Dict, Iterable, List
 
 from pydantic import Field
 from sklearn.metrics import precision_recall_fscore_support
@@ -87,39 +87,22 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
         }
 
 
-class DatasetLabels(TermsAggregation):
+class DatasetLabels(PythonMetric):
     id: str = Field("dataset_labels", const=True)
     name: str = Field("The dataset labels", const=True)
-    fixed_size: int = Field(1500, const=True)
-    script: Dict[str, Any] = Field(
-        {
-            "lang": "painless",
-            "source": """
-            def all_labels = [];
-            def predicted = doc.containsKey('prediction.labels.class_label.keyword')
-                ? doc['prediction.labels.class_label.keyword'] : null;
-            def annotated = doc.containsKey('annotation.labels.class_label.keyword')
-                ? doc['annotation.labels.class_label.keyword'] : null;
-
-            if (predicted != null && predicted.size() > 0) {
-                for (int i = 0; i < predicted.length; i++) {
-                    all_labels.add(predicted[i])
-                }
-            }
-
-            if (annotated != null && annotated.size() > 0) {
-                for (int i = 0; i < annotated.length; i++) {
-                    all_labels.add(annotated[i])
-                }
-            }
-            return all_labels;
-            """,
-        },
-        const=True,
-    )
-
-    def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
-        return {"labels": [k for k in (aggregation_result or {}).keys()]}
+
+    def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]:
+        ds_labels = set()
+        for record in records:
+            if record.annotation:
+                ds_labels.update(
+                    [label.class_label for label in record.annotation.labels]
+                )
+            if record.prediction:
+                ds_labels.update(
+                    [label.class_label for label in record.prediction.labels]
+                )
+        return {"labels": ds_labels or []}
 
 
 class TextClassificationMetrics(BaseTaskMetrics[TextClassificationRecord]):
diff --git a/tests/server/metrics/test_api.py b/tests/server/metrics/test_api.py
index 627852e0bb..ec3d17ba69 100644
--- a/tests/server/metrics/test_api.py
+++ b/tests/server/metrics/test_api.py
@@ -206,3 +206,59 @@ def test_dataset_metrics():
         json={},
     )
     assert response.status_code == 200, f"{metric}: {response.json()}"
+
+
+def test_dataset_labels_for_text_classification():
+    dataset = "test_dataset_labels_for_text_classification"
+    assert client.delete(f"/api/datasets/{dataset}").status_code == 200
+    records = [
+        TextClassificationRecord.parse_obj(data)
+        for data in [
+            {
+                "id": 0,
+                "inputs": {"text": "Some test data"},
+                "prediction": {"agent": "test", "labels": [{"class": "A"}]},
+            },
+            {
+                "id": 1,
+                "inputs": {"text": "Some test data"},
+                "annotation": {"agent": "test", "labels": [{"class": "B"}]},
+            },
+            {
+                "id": 2,
+                "inputs": {"text": "Some test data"},
+                "prediction": {
+                    "agent": "test",
+                    "labels": [
+                        {"class": "A", "score": 0.5},
+                        {
+                            "class": "D",
+                            "score": 0.5,
+                        },
+                    ],
+                },
+                "annotation": {"agent": "test", "labels": [{"class": "E"}]},
+            },
+        ]
+    ]
+    request = TextClassificationBulkData(records=records)
+    dataset = "test_dataset_labels_for_text_classification"
+
+    assert client.delete(f"/api/datasets/{dataset}").status_code == 200
+
+    assert (
+        client.post(
+            f"/api/datasets/{dataset}/TextClassification:bulk",
+            json=request.dict(by_alias=True),
+        ).status_code
+        == 200
+    )
+
+    response = client.post(
+        f"/api/datasets/TextClassification/{dataset}/metrics/dataset_labels:summary",
+        json={},
+    )
+    assert response.status_code == 200
+    response = response.json()
+    labels = response["labels"]
+    assert sorted(labels) == ["A", "B", "D", "E"]
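
A quick way to exercise the new metric outside the test suite is to call the same endpoint the test uses. The snippet below is only a sketch: the base URL and dataset name are placeholder assumptions, and any authentication the server may require is omitted; the endpoint path and the {"labels": [...]} response shape come from the patch above.

    # Hedged sketch: query the dataset_labels metric of an existing text
    # classification dataset over HTTP. BASE_URL and DATASET are assumptions.
    import requests

    BASE_URL = "http://localhost:6900"          # assumed local server address
    DATASET = "my_text_classification_dataset"  # hypothetical dataset name

    # An empty JSON body summarizes the metric over all records of the dataset.
    response = requests.post(
        f"{BASE_URL}/api/datasets/TextClassification/{DATASET}/metrics/dataset_labels:summary",
        json={},
    )
    response.raise_for_status()

    # With DatasetLabels implemented as a PythonMetric, the summary lists every
    # label seen in either the annotations or the predictions of the records.
    print(sorted(response.json()["labels"]))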