Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(#1238): show prediction labels when annotating rule #1239

Merged
merged 9 commits into from Mar 10, 2022
Expand Up @@ -68,12 +68,12 @@ export default {
maxVisibleLabels() {
return DatasetViewSettings.MAX_VISIBLE_LABELS;
},
query() {
return this.dataset.query.text;
},
labels() {
return this.dataset.labels.map((l) => ({ class: l, selected: false }));
},
query() {
return this.dataset.query.text;
},
sortedLabels() {
return this.labels.slice().sort((a, b) => (a.score > b.score ? -1 : 1));
},
Expand Down
12 changes: 12 additions & 0 deletions frontend/models/TextClassification.js
Expand Up @@ -272,16 +272,28 @@ class TextClassificationDataset extends ObservationDataset {
get labels() {
const { labels } = (this.metadata || {})[USER_DATA_METADATA_KEY] || {};
const aggregations = this.globalResults.aggregations;
const label2str = (label) => label.class;
const recordsLabels = this.results.records.flatMap((record) => {
return []
.concat(
record.annotation ? record.annotation.labels.map(label2str) : []
)
.concat(
record.prediction ? record.prediction.labels.map(label2str) : []
);
});

const uniqueLabels = [
...new Set(
(labels || [])
.filter((l) => l && l.trim())
.concat(this._labels || [])
.concat(recordsLabels)
.concat(Object.keys(aggregations.annotated_as))
.concat(Object.keys(aggregations.predicted_as))
),
];

uniqueLabels.sort();
return uniqueLabels;
}
Expand Down
50 changes: 16 additions & 34 deletions src/rubrix/server/tasks/text_classification/metrics.py
@@ -1,13 +1,12 @@
from typing import Any, ClassVar, Dict, Iterable, List, Optional
from typing import Any, ClassVar, Dict, Iterable, List

from pydantic import Field
from sklearn.metrics import precision_recall_fscore_support
from sklearn.preprocessing import MultiLabelBinarizer

from rubrix.server.tasks.commons.metrics import CommonTasksMetrics
from rubrix.server.tasks.commons.metrics import CommonTasksMetrics, GenericRecord
from rubrix.server.tasks.commons.metrics.model.base import (
BaseMetric,
BaseTaskMetrics,
PythonMetric,
TermsAggregation,
)
Expand Down Expand Up @@ -87,39 +86,22 @@ def apply(self, records: Iterable[TextClassificationRecord]) -> Any:
}


class DatasetLabels(TermsAggregation):
class DatasetLabels(PythonMetric):
id: str = Field("dataset_labels", const=True)
name: str = Field("The dataset labels", const=True)
fixed_size: int = Field(1500, const=True)
script: Dict[str, Any] = Field(
{
"lang": "painless",
"source": """
def all_labels = [];
def predicted = doc.containsKey('prediction.labels.class_label.keyword')
? doc['prediction.labels.class_label.keyword'] : null;
def annotated = doc.containsKey('annotation.labels.class_label.keyword')
? doc['annotation.labels.class_label.keyword'] : null;

if (predicted != null && predicted.size() > 0) {
for (int i = 0; i < predicted.length; i++) {
all_labels.add(predicted[i])
}
}

if (annotated != null && annotated.size() > 0) {
for (int i = 0; i < annotated.length; i++) {
all_labels.add(annotated[i])
}
}
return all_labels;
""",
},
const=True,
)

def aggregation_result(self, aggregation_result: Dict[str, Any]) -> Dict[str, Any]:
return {"labels": [k for k in (aggregation_result or {}).keys()]}

def apply(self, records: Iterable[TextClassificationRecord]) -> Dict[str, Any]:
    """Collect every label used in the given records.

    Both the annotation and the prediction side of each record are
    inspected, so labels that only appear in predictions are reported too.

    Args:
        records: Records of the dataset to scan.

    Returns:
        A dict with a single ``labels`` key holding the de-duplicated,
        lexicographically sorted list of label names (empty list when no
        record carries labels).
    """
    ds_labels = set()
    for record in records:
        # A record may have an annotation, a prediction, both, or neither.
        if record.annotation:
            ds_labels.update(label.class_label for label in record.annotation.labels)
        if record.prediction:
            ds_labels.update(label.class_label for label in record.prediction.labels)
    # Sort into a list: a raw set has no stable order and is not
    # serializable by plain json.dumps, and the previous `ds_labels or []`
    # returned a set or a list depending on emptiness.
    return {"labels": sorted(ds_labels)}


class TextClassificationMetrics(CommonTasksMetrics[TextClassificationRecord]):
Expand Down
54 changes: 54 additions & 0 deletions tests/server/metrics/test_api.py
Expand Up @@ -209,3 +209,57 @@ def test_dataset_metrics(mocked_client):
json={},
)
assert response.status_code == 200, f"{metric}: {response.json()}"


def test_dataset_labels_for_text_classification(mocked_client):
    """The dataset_labels metric must report labels coming from both the
    prediction and the annotation side of the records."""
    raw_records = [
        {
            "id": 0,
            "inputs": {"text": "Some test data"},
            "prediction": {"agent": "test", "labels": [{"class": "A"}]},
        },
        {
            "id": 1,
            "inputs": {"text": "Some test data"},
            "annotation": {"agent": "test", "labels": [{"class": "B"}]},
        },
        {
            "id": 2,
            "inputs": {"text": "Some test data"},
            "prediction": {
                "agent": "test",
                "labels": [
                    {"class": "A", "score": 0.5},
                    {"class": "D", "score": 0.5},
                ],
            },
            "annotation": {"agent": "test", "labels": [{"class": "E"}]},
        },
    ]
    records = [TextClassificationRecord.parse_obj(raw) for raw in raw_records]
    request = TextClassificationBulkData(records=records)
    dataset = "test_dataset_labels_for_text_classification"

    # Start from a clean dataset.
    assert mocked_client.delete(f"/api/datasets/{dataset}").status_code == 200

    bulk_response = mocked_client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=request.dict(by_alias=True),
    )
    assert bulk_response.status_code == 200

    metric_response = mocked_client.post(
        f"/api/datasets/TextClassification/{dataset}/metrics/dataset_labels:summary",
        json={},
    )
    assert metric_response.status_code == 200
    # Labels from predictions (A, D) and annotations (B, E) must all appear.
    assert sorted(metric_response.json()["labels"]) == ["A", "B", "D", "E"]