diff --git a/labelbox/__init__.py b/labelbox/__init__.py
index 115023bd4..680e44068 100644
--- a/labelbox/__init__.py
+++ b/labelbox/__init__.py
@@ -14,3 +14,4 @@
 from labelbox.schema.asset_metadata import AssetMetadata
 from labelbox.schema.webhook import Webhook
 from labelbox.schema.prediction import Prediction, PredictionModel
+from labelbox.schema.ontology import Ontology
diff --git a/labelbox/orm/model.py b/labelbox/orm/model.py
index ee93eea22..15f8d5cee 100644
--- a/labelbox/orm/model.py
+++ b/labelbox/orm/model.py
@@ -42,6 +42,7 @@ class Type(Enum):
         Boolean = auto()
         ID = auto()
         DateTime = auto()
+        Json = auto()
 
     class EnumType:
 
@@ -85,6 +86,10 @@ def DateTime(*args):
     def Enum(enum_cls: type, *args):
         return Field(Field.EnumType(enum_cls), *args)
 
+    @staticmethod
+    def Json(*args):
+        return Field(Field.Type.Json, *args)
+
     def __init__(self,
                  field_type: Union[Type, EnumType],
                  name,
diff --git a/labelbox/schema/__init__.py b/labelbox/schema/__init__.py
index eadb49ab8..223920a1b 100644
--- a/labelbox/schema/__init__.py
+++ b/labelbox/schema/__init__.py
@@ -12,3 +12,4 @@
 import labelbox.schema.user
 import labelbox.schema.webhook
 import labelbox.schema.prediction
+import labelbox.schema.ontology
diff --git a/labelbox/schema/ontology.py b/labelbox/schema/ontology.py
new file mode 100644
index 000000000..301507884
--- /dev/null
+++ b/labelbox/schema/ontology.py
@@ -0,0 +1,132 @@
+"""Client side object for interacting with the ontology."""
+import abc
+from dataclasses import dataclass
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from labelbox.orm import query
+from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
+from labelbox.orm.model import Entity, Field, Relationship
+from labelbox.utils import snake_case, camel_case
+
+
+@dataclass
+class OntologyEntity:
+    """Fields shared by `Tool` and `Classification`."""
+    required: bool
+    name: str
+
+
+@dataclass
+class Option:
+    """One selectable option of a `Classification`."""
+    label: str
+    value: str
+    feature_schema_id: Optional[str] = None
+    schema_node_id: Optional[str] = None
+
+    @classmethod
+    def from_json(cls, json_dict):
+        """Builds an `Option` from a dict after snake_casing its keys."""
+        _dict = convert_keys(json_dict, snake_case)
+        return cls(**_dict)
+
+
+@dataclass
+class Classification(OntologyEntity):
+    """A classification entry of the normalized ontology."""
+    type: str
+    instructions: str
+    options: List[Option]
+    feature_schema_id: Optional[str] = None
+    schema_node_id: Optional[str] = None
+
+    @classmethod
+    def from_json(cls, json_dict):
+        """Builds a `Classification`, parsing its options into `Option`s."""
+        _dict = convert_keys(json_dict, snake_case)
+        _dict['options'] = [
+            Option.from_json(option) for option in _dict['options']
+        ]
+        return cls(**_dict)
+
+
+@dataclass
+class Tool(OntologyEntity):
+    """A tool entry of the normalized ontology."""
+    tool: str
+    color: str
+    classifications: List[Classification]
+    feature_schema_id: Optional[str] = None
+    schema_node_id: Optional[str] = None
+
+    @classmethod
+    def from_json(cls, json_dict):
+        """Builds a `Tool`, parsing its nested classifications."""
+        _dict = convert_keys(json_dict, snake_case)
+        _dict['classifications'] = [
+            Classification.from_json(classification)
+            for classification in _dict['classifications']
+        ]
+        return cls(**_dict)
+
+
+class Ontology(DbObject):
+    """ An ontology specifies which tools and classifications are available
+    to a project.
+
+    NOTE: This is read only for now.
+
+    >>> project = client.get_project(name="")
+    >>> ontology = project.ontology()
+    >>> ontology.normalized
+
+    """
+
+    name = Field.String("name")
+    description = Field.String("description")
+    updated_at = Field.DateTime("updated_at")
+    created_at = Field.DateTime("created_at")
+    normalized = Field.Json("normalized")
+    object_schema_count = Field.Int("object_schema_count")
+    classification_schema_count = Field.Int("classification_schema_count")
+
+    projects = Relationship.ToMany("Project", True)
+    created_by = Relationship.ToOne("User", False, "created_by")
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._tools: Optional[List[Tool]] = None
+        self._classifications: Optional[List[Classification]] = None
+
+    def tools(self) -> List[Tool]:
+        """Returns `normalized['tools']` parsed into `Tool`s (cached)."""
+        if self._tools is None:
+            self._tools = [
+                Tool.from_json(tool) for tool in self.normalized['tools']
+            ]
+        return self._tools  # type: ignore
+
+    def classifications(self) -> List[Classification]:
+        """Returns `normalized['classifications']` parsed (cached)."""
+        if self._classifications is None:
+            self._classifications = [
+                Classification.from_json(classification)
+                for classification in self.normalized['classifications']
+            ]
+        return self._classifications  # type: ignore
+
+
+def convert_keys(json_dict: Any, converter: Callable) -> Any:
+    """Recursively applies `converter` to every dict key in `json_dict`.
+
+    Lists are walked element by element; other values are returned as is.
+    """
+    if isinstance(json_dict, dict):
+        return {
+            converter(key): convert_keys(value, converter)
+            for key, value in json_dict.items()
+        }
+    if isinstance(json_dict, list):
+        return [convert_keys(ele, converter) for ele in json_dict]
+    return json_dict
diff --git a/labelbox/schema/project.py b/labelbox/schema/project.py
index 9ebb0b43f..c83b2db96 100644
--- a/labelbox/schema/project.py
+++ b/labelbox/schema/project.py
@@ -46,6 +46,7 @@ class Project(DbObject, Updateable, Deletable):
     active_prediction_model = Relationship.ToOne("PredictionModel", False,
                                                  "active_prediction_model")
     predictions = Relationship.ToMany("Prediction", False)
+    ontology = Relationship.ToOne("Ontology", True)
 
     def create_label(self, **kwargs):
         """ Creates a label on this Project.
diff --git a/tests/integration/test_ontology.py b/tests/integration/test_ontology.py
new file mode 100644
index 000000000..9f6f7e257
--- /dev/null
+++ b/tests/integration/test_ontology.py
@@ -0,0 +1,73 @@
+import unittest
+from typing import Any, Dict, List, Union
+
+
+def sample_ontology() -> Dict[str, Any]:
+    return {
+        "tools": [{
+            "required": False,
+            "name": "Dog",
+            "color": "#FF0000",
+            "tool": "rectangle",
+            "classifications": []
+        }],
+        "classifications": [{
+            "required":
+                True,
+            "instructions":
+                "This is a question.",
+            "name":
+                "this_is_a_question.",
+            "type":
+                "radio",
+            "options": [{
+                "label": "Yes",
+                "value": "yes"
+            }, {
+                "label": "No",
+                "value": "no"
+            }]
+        }]
+    }
+
+
+def test_create_ontology(client, project) -> None:
+    """ Tests that the ontology that a project was set up with can be grabbed."""
+    frontend = list(client.get_labeling_frontends())[0]
+    project.setup(frontend, sample_ontology())
+    normalized_ontology = project.ontology().normalized
+
+    def _remove_schema_ids(
+            ontology_part: Union[List, Dict[str, Any]]) -> Any:
+        """ Recursively scrub the normalized ontology of any schema information."""
+        removals = {'featureSchemaId', 'schemaNodeId'}
+
+        if isinstance(ontology_part, list):
+            return [_remove_schema_ids(part) for part in ontology_part]
+        if isinstance(ontology_part, dict):
+            return {
+                key: _remove_schema_ids(value)
+                for key, value in ontology_part.items()
+                if key not in removals
+            }
+        return ontology_part
+
+    removed = _remove_schema_ids(normalized_ontology)
+    assert removed == sample_ontology()
+
+    ontology = project.ontology()
+
+    tools = ontology.tools()
+    assert tools
+    for tool in tools:
+        assert tool.feature_schema_id
+        assert tool.schema_node_id
+
+    classifications = ontology.classifications()
+    assert classifications
+    for classification in classifications:
+        assert classification.feature_schema_id
+        assert classification.schema_node_id
+        for option in classification.options:
+            assert option.feature_schema_id
+            assert option.schema_node_id