Skip to content

Commit

Permalink
refactor(label): use "AttrsMixin" in label module
Browse files Browse the repository at this point in the history
PR Closed: #635
  • Loading branch information
zhen.chen authored and AChenQ committed Jun 2, 2021
1 parent 1dff66e commit c435598
Show file tree
Hide file tree
Showing 22 changed files with 329 additions and 534 deletions.
3 changes: 2 additions & 1 deletion tensorbay/label/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"""Label related classes."""

from .attributes import AttributeInfo, Items
from .basic import Label, LabelType
from .basic import LabelType
from .catalog import Catalog, Subcatalogs
from .label import Label
from .label_box import Box2DSubcatalog, Box3DSubcatalog, LabeledBox2D, LabeledBox3D
from .label_classification import Classification, ClassificationSubcatalog
from .label_keypoints import Keypoints2DSubcatalog, LabeledKeypoints2D
Expand Down
20 changes: 5 additions & 15 deletions tensorbay/label/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union

from ..utility import EqMixin, NameMixin, ReprMixin, common_loads
from ..utility import EqMixin, NameMixin, ReprMixin, attr, attr_base, camel, common_loads

_AvailaleType = Union[list, bool, int, float, str, None]
_SingleArgType = Union[str, None, Type[_AvailaleType]]
Expand Down Expand Up @@ -306,6 +306,9 @@ class AttributeInfo(NameMixin, Items):
_repr_attrs = ("name", "parent_categories") + Items._repr_attrs
_repr_maxlevel = 2

_attrs_base: Items = attr_base(key=None)
parent_categories: List[str] = attr(is_dynamic=True, key=camel)

def __init__(
self,
name: str,
Expand All @@ -329,13 +332,6 @@ def __init__(
else:
self.parent_categories = list(parent_categories)

def _loads(self, contents: Dict[str, Any]) -> None:
NameMixin._loads(self, contents)
Items._loads(self, contents)

if "parentCategories" in contents:
self.parent_categories = contents["parentCategories"]

@classmethod
def loads(cls: Type[_T], contents: Dict[str, Any]) -> _T:
"""Load an AttributeInfo from a dict containing the attribute information.
Expand Down Expand Up @@ -414,10 +410,4 @@ def dumps(self) -> Dict[str, Any]:
}
"""
contents: Dict[str, Any] = NameMixin._dumps(self)
contents.update(Items.dumps(self))

if hasattr(self, "parent_categories"):
contents["parentCategories"] = self.parent_categories

return contents
return self._dumps()
187 changes: 11 additions & 176 deletions tensorbay/label/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,21 @@
# Copyright 2021 Graviti. Licensed under MIT License.
#

"""LabelType, SubcatalogBase, Label.
"""LabelType, SubcatalogBase.
:class:`LabelType` is an enumeration type
which includes all the supported label types within :class:`Label`.
:class:`Subcatalogbase` is the base class for different types of subcatalogs,
which defines the basic concept of Subcatalog.
A :class:`~.tensorbay.dataset.data.Data` instance contains one or several types of labels,
all of which are stored in :attr:`~tensorbay.dataset.data.Data.label`.
A subcatalog class extends :class:`SubcatalogBase` and needed :class:`SubcatalogMixin` classes.
Different label types correspond to different label classes classes.
.. table:: label classes
:widths: auto
============================================================= ===================================
label classes explaination
============================================================= ===================================
:class:`~tensorbay.label.label_classification.Classification` classification type of label
:class:`~tensorbay.label.label_box.LabeledBox2D` 2D bounding box type of label
:class:`~tensorbay.label.label_box.LabeledBox3D` 3D bounding box type of label
:class:`~tensorbay.label.label_polygon.LabeledPolygon2D` 2D polygon type of label
:class:`~tensorbay.label.label_polyline.LabeledPolyline2D` 2D polyline type of label
:class:`~tensorbay.label.label_keypoints.LabeledKeypoints2D` 2D keypoints type of label
:class:`~tensorbay.label.label_sentence.LabeledSentence` transcripted sentence type of label
============================================================= ===================================
"""

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

from ..utility import EqMixin, ReprMixin, ReprType, TypeEnum, TypeMixin, common_loads
from .supports import SubcatalogMixin
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

if TYPE_CHECKING:
from .label_box import LabeledBox2D, LabeledBox3D
from .label_classification import Classification
from .label_keypoints import LabeledKeypoints2D
from .label_polygon import LabeledPolygon2D
from .label_polyline import LabeledPolyline2D
from .label_sentence import LabeledSentence
from ..utility import AttrsMixin, ReprMixin, ReprType, TypeEnum, TypeMixin, attr, common_loads


class LabelType(TypeEnum):
Expand Down Expand Up @@ -91,7 +62,7 @@ def subcatalog_type(self) -> Type[Any]:
return self.__subcatalog_registry__[self]


class SubcatalogBase(TypeMixin[LabelType], ReprMixin, EqMixin):
class SubcatalogBase(TypeMixin[LabelType], ReprMixin, AttrsMixin):
"""This is the base class for different types of subcatalogs.
It defines the basic concept of Subcatalog, which is the collection of the labels information.
Expand Down Expand Up @@ -120,22 +91,11 @@ class SubcatalogBase(TypeMixin[LabelType], ReprMixin, EqMixin):
"attributes",
"lexicon",
)

_supports: Tuple[Type[SubcatalogMixin], ...]
description: str = attr(default="")

def __init__(self, description: str = "") -> None:
self.description = description

def __init_subclass__(cls) -> None:
cls._supports = tuple(
filter(lambda class_: issubclass(class_, SubcatalogMixin), cls.__bases__)
)

def _loads(self, contents: Dict[str, Any]) -> None:
self.description = contents.get("description", "")
for support in self._supports:
support._loads(self, contents) # pylint: disable=protected-access

@classmethod
def loads(cls: Type[_T], contents: Dict[str, Any]) -> _T:
"""Loads a subcatalog from a dict containing the information of the subcatalog.
Expand All @@ -156,16 +116,10 @@ def dumps(self) -> Dict[str, Any]:
A dict containing all the information of the subcatalog.
"""
contents: Dict[str, Any] = {}
if self.description:
contents["description"] = self.description
return self._dumps()

for support in self._supports:
contents.update(support._dumps(self)) # pylint: disable=protected-access
return contents


class _LabelBase(TypeMixin[LabelType], ReprMixin, EqMixin):
class _LabelBase(AttrsMixin, TypeMixin[LabelType], ReprMixin):
"""This class defines the basic concept of label.
:class:`_LabelBase` is the most basic label level in the TensorBay dataset structure,
Expand All @@ -192,9 +146,9 @@ class _LabelBase(TypeMixin[LabelType], ReprMixin, EqMixin):

_AttributeType = Dict[str, Union[str, int, float, bool, List[Union[str, int, float, bool]]]]

category: str
attributes: _AttributeType
instance: str
category: str = attr(is_dynamic=True)
attributes: _AttributeType = attr(is_dynamic=True)
instance: str = attr(is_dynamic=True)

def __init__(
self,
Expand All @@ -209,11 +163,6 @@ def __init__(
if instance:
self.instance = instance

def _loads(self, contents: Dict[str, Any]) -> None:
for attribute_name in self._label_attrs:
if attribute_name in contents:
setattr(self, attribute_name, contents[attribute_name])

def dumps(self) -> Dict[str, Any]:
"""Dumps the label into a dict.
Expand All @@ -222,118 +171,4 @@ def dumps(self) -> Dict[str, Any]:
See dict format details in ``dumps()`` of different label classes .
"""
contents: Dict[str, Any] = {}
for attribute_name in self._label_attrs:
attribute_value = getattr(self, attribute_name, None)
if attribute_value:
contents[attribute_name] = attribute_value
return contents


class Label(ReprMixin, EqMixin):
"""This class defines :attr:`~tensorbay.dataset.data.Data.label`.
It contains growing types of labels referring to different tasks.
Examples:
>>> from tensorbay.label import Classification
>>> label = Label()
>>> label.classification = Classification("example_category", {"example_attribute1": "a"})
>>> label
Label(
(classification): Classification(
(category): 'example_category',
(attributes): {...}
)
)
"""

_T = TypeVar("_T", bound="Label")

_repr_type = ReprType.INSTANCE
_repr_attrs = tuple(label_type.value for label_type in LabelType)
_repr_maxlevel = 2

classification: "Classification"
box2d: List["LabeledBox2D"]
box3d: List["LabeledBox3D"]
polygon2d: List["LabeledPolygon2D"]
polyline2d: List["LabeledPolyline2D"]
keypoints2d: List["LabeledKeypoints2D"]
sentence: List["LabeledSentence"]

def __bool__(self) -> bool:
for label_type in LabelType:
if hasattr(self, label_type.value):
return True
return False

def _loads(self, contents: Dict[str, Any]) -> None:
for key, labels in contents.items():
if key not in LabelType.__members__:
continue

label_type = LabelType[key]
if label_type == LabelType.CLASSIFICATION:
setattr(self, label_type.value, label_type.type.loads(labels))
else:
setattr(
self,
label_type.value,
[label_type.type.loads(label) for label in labels],
)

@classmethod
def loads(cls: Type[_T], contents: Dict[str, Any]) -> _T:
"""Loads data from a dict containing the labels information.
Arguments:
contents: A dict containing the labels information.
Returns:
A :class:`Label` instance containing labels information from the given dict.
Examples:
>>> contents = {
... "CLASSIFICATION": {
... "category": "example_category",
... "attributes": {"example_attribute1": "a"}
... }
... }
>>> Label.loads(contents)
Label(
(classification): Classification(
(category): 'example_category',
(attributes): {...}
)
)
"""
return common_loads(cls, contents)

def dumps(self) -> Dict[str, Any]:
"""Dumps all labels into a dict.
Returns:
Dumped labels dict.
Examples:
>>> from tensorbay.label import Classification
>>> label = Label()
>>> label.classification = Classification("category1", {"attribute1": "a"})
>>> label.dumps()
{'CLASSIFICATION': {'category': 'category1', 'attributes': {'attribute1': 'a'}}}
"""
contents: Dict[str, Any] = {}
for label_type in LabelType:
labels = getattr(self, label_type.value, None)
if labels is None:
continue
if label_type == LabelType.CLASSIFICATION:
contents[label_type.name] = labels.dumps()
else:
contents[label_type.name] = [label.dumps() for label in labels]

return contents
return self._dumps()
30 changes: 10 additions & 20 deletions tensorbay/label/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from typing import Any, Dict, Type, TypeVar, Union

from ..label import LabelType
from ..utility import EqMixin, ReprMixin, ReprType, common_loads
from ..utility import AttrsMixin, ReprMixin, ReprType, attr, common_loads, upper
from .label_box import Box2DSubcatalog, Box3DSubcatalog
from .label_classification import ClassificationSubcatalog
from .label_keypoints import Keypoints2DSubcatalog
Expand All @@ -51,7 +51,7 @@
]


class Catalog(ReprMixin, EqMixin):
class Catalog(ReprMixin, AttrsMixin):
"""This class defines the concept of catalog.
:class:`Catalog` is used to describe the types of labels
Expand Down Expand Up @@ -87,25 +87,20 @@ class Catalog(ReprMixin, EqMixin):
_repr_attrs = tuple(label_type.value for label_type in LabelType)
_repr_maxlevel = 2

classification: ClassificationSubcatalog
box2d: Box2DSubcatalog
box3d: Box3DSubcatalog
polygon2d: Polygon2DSubcatalog
polyline2d: Polyline2DSubcatalog
keypoints2d: Keypoints2DSubcatalog
sentence: SentenceSubcatalog
classification: ClassificationSubcatalog = attr(is_dynamic=True, key=upper)
box2d: Box2DSubcatalog = attr(is_dynamic=True, key=upper)
box3d: Box3DSubcatalog = attr(is_dynamic=True, key=upper)
polygon2d: Polygon2DSubcatalog = attr(is_dynamic=True, key=upper)
polyline2d: Polyline2DSubcatalog = attr(is_dynamic=True, key=upper)
keypoints2d: Keypoints2DSubcatalog = attr(is_dynamic=True, key=upper)
sentence: SentenceSubcatalog = attr(is_dynamic=True, key=upper)

def __bool__(self) -> bool:
for label_type in LabelType:
if hasattr(self, label_type.value):
return True
return False

def _loads(self, contents: Dict[str, Any]) -> None:
for type_name, subcatalog in contents.items():
label_type = LabelType[type_name]
setattr(self, label_type.value, label_type.subcatalog_type.loads(subcatalog))

@classmethod
def loads(cls: Type[_T], contents: Dict[str, Any]) -> _T:
"""Load a Catalog from a dict containing the catalog information.
Expand Down Expand Up @@ -159,9 +154,4 @@ def dumps(self) -> Dict[str, Any]:
{'CLASSIFICATION': {'categories': [{'name': 'example'}]}}
"""
contents: Dict[str, Any] = {}
for label_type in LabelType:
subcatalog = getattr(self, label_type.value, None)
if subcatalog:
contents[label_type.name] = subcatalog.dumps()
return contents
return self._dumps()

0 comments on commit c435598

Please sign in to comment.