Skip to content

Commit

Permalink
fix(#1238): show prediction labels when annotating rule (#1239)
Browse files Browse the repository at this point in the history
This PR shows prediction labels when annotating rule in Weak Supervision

Closes #1238

* fix(metrics): compute dataset labels as python metric

* test: fix tests

* fix: compute dataset label properly

* fix: show all labels for empty query view

* empty query

* refactor: revert comp. changes and compute labels in model

* chore: lint fix

fix: remove unused imports

fix: include missing imports

test: fix tests

Co-authored-by: Francisco Aranda <francisco@recogn.ai>
(cherry picked from commit 6f19e40)
  • Loading branch information
leiyre authored and frascuchon committed Mar 11, 2022
1 parent 87894cd commit 28e97c6
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 35 deletions.
Expand Up @@ -68,12 +68,12 @@ export default {
maxVisibleLabels() {
return DatasetViewSettings.MAX_VISIBLE_LABELS;
},
query() {
return this.dataset.query.text;
},
labels() {
return this.dataset.labels.map((l) => ({ class: l, selected: false }));
},
query() {
  // Text of the query currently applied to the dataset.
  // NOTE(review): this getter appears twice in the pasted diff (a moved
  // line); only one definition should survive in the real file — confirm.
  return this.dataset.query.text;
},
sortedLabels() {
  // Copy of `labels` ordered by descending `score`.
  // NOTE(review): the `labels` getter above builds objects with only
  // `class`/`selected` and no `score`, so this comparison may see
  // `undefined` on both sides and leave the order unchanged — confirm
  // which label shape actually reaches this component.
  return this.labels.slice().sort((a, b) => (a.score > b.score ? -1 : 1));
},
Expand Down
12 changes: 12 additions & 0 deletions frontend/models/TextClassification.js
Expand Up @@ -272,16 +272,28 @@ class TextClassificationDataset extends ObservationDataset {
get labels() {
const { labels } = (this.metadata || {})[USER_DATA_METADATA_KEY] || {};
const aggregations = this.globalResults.aggregations;
const label2str = (label) => label.class;
const recordsLabels = this.results.records.flatMap((record) => {
return []
.concat(
record.annotation ? record.annotation.labels.map(label2str) : []
)
.concat(
record.prediction ? record.prediction.labels.map(label2str) : []
);
});

const uniqueLabels = [
...new Set(
(labels || [])
.filter((l) => l && l.trim())
.concat(this._labels || [])
.concat(recordsLabels)
.concat(Object.keys(aggregations.annotated_as))
.concat(Object.keys(aggregations.predicted_as))
),
];

uniqueLabels.sort();
return uniqueLabels;
}
Expand Down
47 changes: 15 additions & 32 deletions src/rubrix/server/tasks/text_classification/metrics.py
@@ -1,4 +1,4 @@
from typing import Any, ClassVar, Dict, Iterable, List, Optional
from typing import Any, ClassVar, Dict, Iterable, List

from pydantic import Field
from sklearn.metrics import precision_recall_fscore_support
Expand Down Expand Up @@ -87,39 +87,22 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
}


class DatasetLabels(TermsAggregation):
class DatasetLabels(PythonMetric):
id: str = Field("dataset_labels", const=True)
name: str = Field("The dataset labels", const=True)
fixed_size: int = Field(1500, const=True)
script: Dict[str, Any] = Field(
{
"lang": "painless",
"source": """
def all_labels = [];
def predicted = doc.containsKey('prediction.labels.class_label.keyword')
? doc['prediction.labels.class_label.keyword'] : null;
def annotated = doc.containsKey('annotation.labels.class_label.keyword')
? doc['annotation.labels.class_label.keyword'] : null;
if (predicted != null && predicted.size() > 0) {
for (int i = 0; i < predicted.length; i++) {
all_labels.add(predicted[i])
}
}
if (annotated != null && annotated.size() > 0) {
for (int i = 0; i < annotated.length; i++) {
all_labels.add(annotated[i])
}
}
return all_labels;
""",
},
const=True,
)

def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
    """Reshape a raw terms-aggregation result into the metric payload.

    Returns a dict whose ``labels`` entry lists the aggregation bucket
    keys; a ``None`` or empty aggregation yields an empty list.
    """
    buckets = aggregation_result or {}
    return {"labels": list(buckets.keys())}

def apply(self, records: Iterable["TextClassificationRecord"]) -> Dict[str, Any]:
    """Collect every distinct label seen in a stream of records.

    Scans each record's annotation and prediction labels and returns
    them deduplicated and sorted. Returning a sorted list (rather than
    the raw ``set`` the original built) makes the summary deterministic
    and serializable by the stdlib ``json`` encoder; it also makes the
    ``ds_labels or []`` fallback unnecessary, since ``sorted(set())``
    is already ``[]``.

    Args:
        records: records to scan; only the ``class_label`` of each
            annotation/prediction label is read.

    Returns:
        ``{"labels": [...]}`` with the sorted unique label names.
    """
    ds_labels = set()
    for record in records:
        # Either side may be absent; gather class names from both.
        if record.annotation:
            ds_labels.update(label.class_label for label in record.annotation.labels)
        if record.prediction:
            ds_labels.update(label.class_label for label in record.prediction.labels)
    return {"labels": sorted(ds_labels)}


class TextClassificationMetrics(BaseTaskMetrics[TextClassificationRecord]):
Expand Down
56 changes: 56 additions & 0 deletions tests/server/metrics/test_api.py
Expand Up @@ -206,3 +206,59 @@ def test_dataset_metrics():
json={},
)
assert response.status_code == 200, f"{metric}: {response.json()}"


def test_dataset_labels_for_text_classification():
    """dataset_labels metric returns the union of annotation and prediction labels.

    Fix: the original duplicated both the ``dataset`` assignment and the
    dataset-delete request; one of each is enough.
    """
    dataset = "test_dataset_labels_for_text_classification"
    # Start from a clean slate so labels from earlier runs cannot leak in.
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord.parse_obj(data)
        for data in [
            {
                "id": 0,
                "inputs": {"text": "Some test data"},
                "prediction": {"agent": "test", "labels": [{"class": "A"}]},
            },
            {
                "id": 1,
                "inputs": {"text": "Some test data"},
                "annotation": {"agent": "test", "labels": [{"class": "B"}]},
            },
            {
                # Record carrying both prediction and annotation labels.
                "id": 2,
                "inputs": {"text": "Some test data"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "A", "score": 0.5},
                        {"class": "D", "score": 0.5},
                    ],
                },
                "annotation": {"agent": "test", "labels": [{"class": "E"}]},
            },
        ]
    ]
    request = TextClassificationBulkData(records=records)

    # Bulk-index the records so the metric has data to scan.
    assert (
        client.post(
            f"/api/datasets/{dataset}/TextClassification:bulk",
            json=request.dict(by_alias=True),
        ).status_code
        == 200
    )

    response = client.post(
        f"/api/datasets/TextClassification/{dataset}/metrics/dataset_labels:summary",
        json={},
    )
    assert response.status_code == 200
    labels = response.json()["labels"]
    # "A" appears in two records but must be reported exactly once.
    assert sorted(labels) == ["A", "B", "D", "E"]

0 comments on commit 28e97c6

Please sign in to comment.