Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion labelbox/data/annotation_types/metrics/scalar.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Optional, Union
from enum import Enum

from pydantic import confloat
from pydantic import confloat, validator

from .base import ConfidenceValue, BaseMetric

Expand All @@ -16,6 +16,11 @@ class ScalarMetricAggregation(Enum):
SUM = "SUM"


# Names that ScalarMetric.validate_metric_name rejects for `metric_name`;
# presumably reserved for metrics computed by the platform itself — confirm.
RESERVED_METRIC_NAMES = (
    'true_positive_count',
    'false_positive_count',
    'true_negative_count',
    'false_negative_count',
    'precision',
    'recall',
    'f1',
    'iou',
)


class ScalarMetric(BaseMetric):
""" Class representing scalar metrics

Expand All @@ -28,6 +33,16 @@ class ScalarMetric(BaseMetric):
value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN

@validator('metric_name')
def validate_metric_name(cls, name: Union[str, None]):
    """Reject user-supplied metric names that collide with reserved ones.

    The comparison is case-insensitive and ignores surrounding whitespace;
    a valid name is returned exactly as the caller provided it.
    """
    if name is None:
        return None
    normalized = name.lower().strip()
    if normalized not in RESERVED_METRIC_NAMES:
        return name
    raise ValueError(f"`{normalized}` is a reserved metric name. "
                     "Please provide another value for `metric_name`.")

def dict(self, *args, **kwargs):
res = super().dict(*args, **kwargs)
if res.get('metric_name') is None:
Expand Down
6 changes: 4 additions & 2 deletions labelbox/data/metrics/iou/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def miou_metric(ground_truths: List[Union[ObjectAnnotation,
# If both gt and preds are empty there is no metric
if iou is None:
return []
return [ScalarMetric(metric_name="iou", value=iou)]
return [ScalarMetric(metric_name="custom_iou", value=iou)]


def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation,
Expand Down Expand Up @@ -62,7 +62,9 @@ def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation,
if value is None:
continue
metrics.append(
ScalarMetric(metric_name="iou", feature_name=key, value=value))
ScalarMetric(metric_name="custom_iou",
feature_name=key,
value=value))
return metrics


Expand Down
12 changes: 10 additions & 2 deletions tests/data/annotation_types/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric
from labelbox.data.annotation_types.collection import LabelList
from labelbox.data.annotation_types import ScalarMetric, Label, ImageData
from labelbox.data.annotation_types.metrics.scalar import RESERVED_METRIC_NAMES


def test_legacy_scalar_metric():
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_legacy_scalar_metric():
])
def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
metric = ScalarMetric(metric_name="iou",
metric = ScalarMetric(metric_name="custom_iou",
value=value,
feature_name=feature_name,
subclass_name=subclass_name,
Expand All @@ -80,7 +81,7 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
'value':
value,
'metric_name':
'iou',
'custom_iou',
**({
'feature_name': feature_name
} if feature_name else {}),
Expand Down Expand Up @@ -192,3 +193,10 @@ def test_invalid_number_of_confidence_scores():
metric_name="too many scores",
value={i / 20.: [0, 1, 2, 3] for i in range(20)})
assert "Number of confidence scores must be greater" in str(exc_info.value)


@pytest.mark.parametrize("metric_name", RESERVED_METRIC_NAMES)
def test_reserved_names(metric_name: str):
    # Every reserved name must be rejected at ScalarMetric construction time.
    with pytest.raises(ValidationError) as exc_info:
        ScalarMetric(value=0.5, metric_name=metric_name)
    first_error = exc_info.value.errors()[0]
    assert 'is a reserved metric name' in first_error['msg']
6 changes: 3 additions & 3 deletions tests/data/assets/ndjson/custom_scalar_import.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "SUM"},
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "aggregation" : "SUM"},
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : { "0.1" : 0.1, "0.2" : 0.5}, "metricName" : "iou", "aggregation" : "SUM"}]
[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "custom_iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "SUM"},
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "custom_iou", "featureName" : "sample_class", "aggregation" : "SUM"},
{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : { "0.1" : 0.1, "0.2" : 0.5}, "metricName" : "custom_iou", "aggregation" : "SUM"}]
2 changes: 1 addition & 1 deletion tests/data/metrics/iou/feature/test_feature_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def check_iou(pair):
assert math.isclose(result[key], pair.expected[key])

for metric in metrics:
assert metric.metric_name == "iou"
assert metric.metric_name == "custom_iou"

if len(pair.expected):
assert len(one_metrics)
Expand Down